Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
//! Application layer for embedding bounded context.
//!
//! Contains application services that orchestrate domain logic
//! and coordinate between domain entities and infrastructure.

/// High-level embedding generation services (spectrogram -> embedding).
pub mod services;

View File

@@ -0,0 +1,567 @@
//! Application services for embedding generation.
//!
//! Provides high-level services for generating embeddings from audio
//! spectrograms using the Perch 2.0 ONNX model.
use std::sync::Arc;
use std::time::Instant;
use ndarray::Array3;
use rayon::prelude::*;
use tracing::{debug, info, instrument, warn};
use crate::domain::entities::{
Embedding, EmbeddingBatch, EmbeddingMetadata, SegmentId, StorageTier,
};
use crate::infrastructure::model_manager::ModelManager;
use crate::normalization;
use crate::{EmbeddingError, EMBEDDING_DIM, MEL_BINS, MEL_FRAMES};
/// Input spectrogram for embedding generation.
///
/// Represents a mel spectrogram with shape [1, MEL_FRAMES, MEL_BINS] = [1, 500, 128].
/// The leading axis is a batch dimension of size 1.
#[derive(Debug, Clone)]
pub struct Spectrogram {
    /// The spectrogram data as a 3D array [batch, frames, bins]
    pub data: Array3<f32>,
    /// Associated segment ID
    pub segment_id: SegmentId,
    /// Additional metadata
    pub metadata: SpectrogramMetadata,
}

/// Metadata about the spectrogram.
///
/// All fields are optional; `None` means the value was not recorded for
/// the source audio segment.
#[derive(Debug, Clone, Default)]
pub struct SpectrogramMetadata {
    /// Sample rate of the original audio
    pub sample_rate: Option<u32>,
    /// Duration of the audio segment in seconds
    pub duration_secs: Option<f32>,
    /// Signal-to-noise ratio (SNR) of the audio segment
    pub snr: Option<f32>,
}
impl Spectrogram {
    /// Build a spectrogram from a 2D mel matrix.
    ///
    /// A leading batch axis of size 1 is inserted automatically, producing
    /// the [1, MEL_FRAMES, MEL_BINS] layout expected by the model.
    ///
    /// # Arguments
    ///
    /// * `data` - 2D array of shape [MEL_FRAMES, MEL_BINS]
    /// * `segment_id` - ID of the source audio segment
    ///
    /// # Errors
    ///
    /// Returns `EmbeddingError::InvalidDimensions` if the matrix shape is
    /// not exactly [MEL_FRAMES, MEL_BINS].
    pub fn new(
        data: ndarray::Array2<f32>,
        segment_id: SegmentId,
    ) -> Result<Self, EmbeddingError> {
        let (frames, bins) = data.dim();
        if (frames, bins) != (MEL_FRAMES, MEL_BINS) {
            return Err(EmbeddingError::InvalidDimensions {
                expected: MEL_FRAMES * MEL_BINS,
                actual: frames * bins,
            });
        }
        // Promote to [1, frames, bins] so the model sees a batch of one.
        Ok(Self {
            data: data.insert_axis(ndarray::Axis(0)),
            segment_id,
            metadata: SpectrogramMetadata::default(),
        })
    }

    /// Wrap an already-batched 3D array of shape [batch, MEL_FRAMES, MEL_BINS].
    ///
    /// # Errors
    ///
    /// Returns `EmbeddingError::InvalidDimensions` when the frame or bin
    /// axes do not match the model input specification.
    pub fn from_array3(data: Array3<f32>, segment_id: SegmentId) -> Result<Self, EmbeddingError> {
        let (_, frames, bins) = data.dim();
        if frames != MEL_FRAMES || bins != MEL_BINS {
            return Err(EmbeddingError::InvalidDimensions {
                expected: MEL_FRAMES * MEL_BINS,
                actual: frames * bins,
            });
        }
        Ok(Self {
            data,
            segment_id,
            metadata: SpectrogramMetadata::default(),
        })
    }

    /// Attach metadata, consuming and returning `self` (builder style).
    pub fn with_metadata(mut self, metadata: SpectrogramMetadata) -> Self {
        self.metadata = metadata;
        self
    }
}
/// Output from the embedding service
#[derive(Debug, Clone)]
pub struct EmbeddingOutput {
    /// The generated embedding
    pub embedding: Embedding,
    /// Whether GPU was used for inference
    pub gpu_used: bool,
    /// Wall-clock latency in milliseconds for producing this embedding
    /// (inference plus normalization/validation; in batch mode this is the
    /// batch latency divided evenly across items)
    pub latency_ms: f32,
}
/// Configuration for the embedding service
#[derive(Debug, Clone)]
pub struct EmbeddingServiceConfig {
    /// Maximum batch size for inference
    pub batch_size: usize,
    /// Whether to L2 normalize embeddings after inference
    pub normalize: bool,
    /// Default storage tier assigned to newly generated embeddings
    pub default_tier: StorageTier,
    /// Whether to validate embeddings (dimension, NaN/Inf, sparsity)
    /// after generation
    pub validate_embeddings: bool,
    /// Maximum allowed sparsity (fraction of near-zero values); exceeding
    /// it logs a warning rather than failing
    pub max_sparsity: f32,
}
impl Default for EmbeddingServiceConfig {
fn default() -> Self {
Self {
batch_size: 8,
normalize: true,
default_tier: StorageTier::Hot,
validate_embeddings: true,
max_sparsity: 0.9,
}
}
}
/// Service for generating embeddings from spectrograms.
///
/// This is the main application service for the embedding bounded context.
/// It coordinates between the model manager, ONNX inference, and domain entities.
pub struct EmbeddingService {
    /// Model manager for loading and caching ONNX models (shared, so the
    /// service is cheap to construct around an existing manager)
    model_manager: Arc<ModelManager>,
    /// Configuration for the service
    config: EmbeddingServiceConfig,
}
impl EmbeddingService {
/// Create a new embedding service.
///
/// # Arguments
///
/// * `model_manager` - The model manager for ONNX model access
/// * `batch_size` - Maximum batch size for inference
#[must_use]
pub fn new(model_manager: Arc<ModelManager>, batch_size: usize) -> Self {
Self {
model_manager,
config: EmbeddingServiceConfig {
batch_size,
..Default::default()
},
}
}
/// Create with custom configuration
#[must_use]
pub fn with_config(model_manager: Arc<ModelManager>, config: EmbeddingServiceConfig) -> Self {
Self {
model_manager,
config,
}
}
/// Generate an embedding from a single spectrogram.
///
/// # Arguments
///
/// * `spectrogram` - The input spectrogram
///
/// # Errors
///
/// Returns an error if inference fails or the embedding is invalid.
#[instrument(skip(self, spectrogram), fields(segment_id = %spectrogram.segment_id))]
pub async fn embed_segment(
&self,
spectrogram: &Spectrogram,
) -> Result<EmbeddingOutput, EmbeddingError> {
let start = Instant::now();
// Get the inference session
let inference = self.model_manager.get_inference().await?;
let model_version = self.model_manager.current_version();
// Run inference
let raw_embedding = inference.run(&spectrogram.data)?;
// Convert to vector
let mut vector: Vec<f32> = raw_embedding.iter().copied().collect();
// Calculate original norm before normalization
let original_norm = normalization::compute_norm(&vector);
// L2 normalize if configured
if self.config.normalize {
normalization::l2_normalize(&mut vector);
}
// Validate embedding
if self.config.validate_embeddings {
self.validate_embedding(&vector)?;
}
// Calculate sparsity
let sparsity = normalization::compute_sparsity(&vector);
// Create embedding entity
let mut embedding = Embedding::new(
spectrogram.segment_id,
vector,
model_version.full_version(),
)?;
// Set metadata
let latency_ms = start.elapsed().as_secs_f32() * 1000.0;
embedding.metadata = EmbeddingMetadata {
inference_latency_ms: Some(latency_ms),
batch_id: None,
gpu_used: inference.is_gpu(),
original_norm: Some(original_norm),
sparsity: Some(sparsity),
quality_score: Some(self.compute_quality_score(&embedding)),
};
embedding.tier = self.config.default_tier;
debug!(
latency_ms = latency_ms,
norm = embedding.norm(),
sparsity = sparsity,
"Generated embedding"
);
Ok(EmbeddingOutput {
embedding,
gpu_used: inference.is_gpu(),
latency_ms,
})
}
/// Generate embeddings for multiple spectrograms in batches.
///
/// This is more efficient than calling `embed_segment` multiple times
/// as it uses batched inference.
///
/// # Arguments
///
/// * `spectrograms` - Slice of input spectrograms
///
/// # Errors
///
/// Returns an error if any inference fails. Partial results are not returned.
#[instrument(skip(self, spectrograms), fields(count = spectrograms.len()))]
pub async fn embed_batch(
&self,
spectrograms: &[Spectrogram],
) -> Result<Vec<EmbeddingOutput>, EmbeddingError> {
if spectrograms.is_empty() {
return Ok(Vec::new());
}
let total_start = Instant::now();
let batch_id = uuid::Uuid::new_v4().to_string();
info!(
batch_id = %batch_id,
total_segments = spectrograms.len(),
batch_size = self.config.batch_size,
"Starting batch embedding"
);
// Get the inference session
let inference = self.model_manager.get_inference().await?;
let model_version = self.model_manager.current_version();
// Process in batches
let mut all_outputs = Vec::with_capacity(spectrograms.len());
for (batch_idx, chunk) in spectrograms.chunks(self.config.batch_size).enumerate() {
let batch_start = Instant::now();
// Prepare batch input
let inputs: Vec<&Array3<f32>> = chunk.iter().map(|s| &s.data).collect();
// Run batched inference
let raw_embeddings = inference.run_batch(&inputs)?;
let batch_latency_ms = batch_start.elapsed().as_secs_f32() * 1000.0;
let per_item_latency = batch_latency_ms / chunk.len() as f32;
// Process each embedding in the batch (parallelize normalization)
let outputs: Vec<Result<EmbeddingOutput, EmbeddingError>> = chunk
.par_iter()
.zip(raw_embeddings.par_iter())
.map(|(spectrogram, raw_emb)| {
let mut vector: Vec<f32> = raw_emb.iter().copied().collect();
let original_norm = normalization::compute_norm(&vector);
if self.config.normalize {
normalization::l2_normalize(&mut vector);
}
if self.config.validate_embeddings {
self.validate_embedding(&vector)?;
}
let sparsity = normalization::compute_sparsity(&vector);
let mut embedding = Embedding::new(
spectrogram.segment_id,
vector,
model_version.full_version(),
)?;
embedding.metadata = EmbeddingMetadata {
inference_latency_ms: Some(per_item_latency),
batch_id: Some(batch_id.clone()),
gpu_used: inference.is_gpu(),
original_norm: Some(original_norm),
sparsity: Some(sparsity),
quality_score: Some(self.compute_quality_score(&embedding)),
};
embedding.tier = self.config.default_tier;
Ok(EmbeddingOutput {
embedding,
gpu_used: inference.is_gpu(),
latency_ms: per_item_latency,
})
})
.collect();
// Check for errors
let batch_outputs: Result<Vec<_>, _> = outputs.into_iter().collect();
all_outputs.extend(batch_outputs?);
debug!(
batch_idx = batch_idx,
batch_size = chunk.len(),
latency_ms = batch_latency_ms,
"Completed batch"
);
}
let total_latency_ms = total_start.elapsed().as_secs_f32() * 1000.0;
let throughput = spectrograms.len() as f32 / (total_latency_ms / 1000.0);
info!(
batch_id = %batch_id,
total_segments = spectrograms.len(),
total_latency_ms = total_latency_ms,
throughput_per_sec = throughput,
"Completed batch embedding"
);
Ok(all_outputs)
}
/// Create a batch tracking object for monitoring progress.
#[must_use]
pub fn create_batch(&self, segment_ids: Vec<SegmentId>) -> EmbeddingBatch {
EmbeddingBatch::new(segment_ids)
}
/// Validate an embedding vector.
fn validate_embedding(&self, vector: &[f32]) -> Result<(), EmbeddingError> {
// Check dimensions
if vector.len() != EMBEDDING_DIM {
return Err(EmbeddingError::InvalidDimensions {
expected: EMBEDDING_DIM,
actual: vector.len(),
});
}
// Check for NaN values
if vector.iter().any(|x| x.is_nan()) {
return Err(EmbeddingError::Validation(
"Embedding contains NaN values".to_string(),
));
}
// Check for infinite values
if vector.iter().any(|x| x.is_infinite()) {
return Err(EmbeddingError::Validation(
"Embedding contains infinite values".to_string(),
));
}
// Check sparsity
let sparsity = normalization::compute_sparsity(vector);
if sparsity > self.config.max_sparsity {
warn!(
sparsity = sparsity,
max_sparsity = self.config.max_sparsity,
"Embedding has high sparsity"
);
}
Ok(())
}
/// Compute a quality score for an embedding.
fn compute_quality_score(&self, embedding: &Embedding) -> f32 {
let mut score = 1.0_f32;
// Penalize deviation from unit norm
let norm = embedding.norm();
let norm_deviation = (norm - 1.0).abs();
score -= norm_deviation * 0.5;
// Penalize high sparsity
if let Some(sparsity) = embedding.metadata.sparsity {
score -= sparsity * 0.3;
}
score.clamp(0.0, 1.0)
}
/// Get the current model version being used.
#[must_use]
pub fn model_version(&self) -> String {
self.model_manager.current_version().full_version()
}
/// Check if the service is ready for inference.
pub async fn is_ready(&self) -> bool {
self.model_manager.is_ready().await
}
}
/// Builder for creating embedding service instances
#[derive(Debug)]
pub struct EmbeddingServiceBuilder {
    /// Required collaborator; `build` fails if it was never set
    model_manager: Option<Arc<ModelManager>>,
    /// Accumulated configuration, starting from defaults
    config: EmbeddingServiceConfig,
}
impl EmbeddingServiceBuilder {
/// Create a new builder
#[must_use]
pub fn new() -> Self {
Self {
model_manager: None,
config: EmbeddingServiceConfig::default(),
}
}
/// Set the model manager
#[must_use]
pub fn model_manager(mut self, manager: Arc<ModelManager>) -> Self {
self.model_manager = Some(manager);
self
}
/// Set the batch size
#[must_use]
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_size = size;
self
}
/// Set whether to normalize embeddings
#[must_use]
pub fn normalize(mut self, normalize: bool) -> Self {
self.config.normalize = normalize;
self
}
/// Set the default storage tier
#[must_use]
pub fn default_tier(mut self, tier: StorageTier) -> Self {
self.config.default_tier = tier;
self
}
/// Set whether to validate embeddings
#[must_use]
pub fn validate_embeddings(mut self, validate: bool) -> Self {
self.config.validate_embeddings = validate;
self
}
/// Build the embedding service
///
/// # Errors
///
/// Returns an error if the model manager is not set.
pub fn build(self) -> Result<EmbeddingService, EmbeddingError> {
let model_manager = self.model_manager.ok_or_else(|| {
EmbeddingError::Validation("Model manager is required".to_string())
})?;
Ok(EmbeddingService::with_config(model_manager, self.config))
}
}
impl Default for EmbeddingServiceBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    // A correctly shaped [MEL_FRAMES, MEL_BINS] matrix is accepted.
    #[test]
    fn test_spectrogram_creation() {
        let data = Array2::zeros((MEL_FRAMES, MEL_BINS));
        let segment_id = SegmentId::new();
        let spec = Spectrogram::new(data, segment_id);
        assert!(spec.is_ok());
    }

    // A mis-shaped matrix is rejected with an error.
    #[test]
    fn test_spectrogram_invalid_dimensions() {
        let data = Array2::zeros((100, 100)); // Wrong dimensions
        let segment_id = SegmentId::new();
        let spec = Spectrogram::new(data, segment_id);
        assert!(spec.is_err());
    }

    // Defaults documented on `EmbeddingServiceConfig::default` hold.
    #[test]
    fn test_service_config_default() {
        let config = EmbeddingServiceConfig::default();
        assert_eq!(config.batch_size, 8);
        assert!(config.normalize);
        assert!(config.validate_embeddings);
    }

    // Builder setters land in the internal config.
    #[test]
    fn test_service_builder() {
        let builder = EmbeddingServiceBuilder::new()
            .batch_size(16)
            .normalize(false)
            .default_tier(StorageTier::Warm);
        assert_eq!(builder.config.batch_size, 16);
        assert!(!builder.config.normalize);
        assert_eq!(builder.config.default_tier, StorageTier::Warm);
    }
}

View File

@@ -0,0 +1,627 @@
//! Domain entities for the embedding bounded context.
//!
//! This module defines the core domain entities:
//! - `Embedding`: A 1536-dimensional vector representation of an audio segment
//! - `EmbeddingModel`: Configuration and metadata for Perch 2.0 ONNX model
//! - `EmbeddingBatch`: Collection of embeddings processed together
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::EMBEDDING_DIM;
/// Unique identifier for an embedding
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(Uuid);

impl EmbeddingId {
    /// Generate a fresh, random (v4) embedding ID.
    #[must_use]
    pub fn new() -> Self {
        EmbeddingId(Uuid::new_v4())
    }

    /// Wrap an existing UUID as an embedding ID.
    #[must_use]
    pub const fn from_uuid(uuid: Uuid) -> Self {
        EmbeddingId(uuid)
    }

    /// Borrow the underlying UUID.
    #[must_use]
    pub const fn as_uuid(&self) -> &Uuid {
        &self.0
    }
}

impl Default for EmbeddingId {
    fn default() -> Self {
        EmbeddingId::new()
    }
}

impl std::fmt::Display for EmbeddingId {
    /// Renders as the inner UUID's canonical hyphenated form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Unique identifier for an audio segment
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SegmentId(Uuid);

impl SegmentId {
    /// Generate a fresh, random (v4) segment ID.
    #[must_use]
    pub fn new() -> Self {
        SegmentId(Uuid::new_v4())
    }

    /// Wrap an existing UUID as a segment ID.
    #[must_use]
    pub const fn from_uuid(uuid: Uuid) -> Self {
        SegmentId(uuid)
    }

    /// Borrow the underlying UUID.
    #[must_use]
    pub const fn as_uuid(&self) -> &Uuid {
        &self.0
    }
}

impl Default for SegmentId {
    fn default() -> Self {
        SegmentId::new()
    }
}

impl std::fmt::Display for SegmentId {
    /// Renders as the inner UUID's canonical hyphenated form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Storage tier for embeddings based on access patterns
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum StorageTier {
    /// Hot storage: frequently accessed, lowest latency
    /// Full f32 precision, in-memory or SSD
    Hot,
    /// Warm storage: occasional access, moderate latency
    /// F16 quantized, SSD storage
    Warm,
    /// Cold storage: rare access, higher latency acceptable
    /// INT8 quantized, archive storage
    Cold,
}

impl Default for StorageTier {
    /// New embeddings default to the hot tier.
    fn default() -> Self {
        StorageTier::Hot
    }
}

impl StorageTier {
    /// Bytes used to store one vector component at this tier.
    #[must_use]
    pub const fn bytes_per_dim(&self) -> usize {
        match self {
            StorageTier::Cold => 1, // i8
            StorageTier::Warm => 2, // f16
            StorageTier::Hot => 4,  // f32
        }
    }

    /// Total storage footprint of one full embedding at this tier.
    #[must_use]
    pub const fn embedding_bytes(&self) -> usize {
        EMBEDDING_DIM * self.bytes_per_dim()
    }
}
/// Timestamp type alias for consistency
pub type Timestamp = DateTime<Utc>;

/// A 1536-dimensional embedding vector representing an audio segment.
///
/// This is the aggregate root for the embedding context. Each embedding
/// is generated by the Perch 2.0 model from a preprocessed audio segment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// Unique identifier for this embedding
    pub id: EmbeddingId,
    /// Reference to the source audio segment
    pub segment_id: SegmentId,
    /// The 1536-dimensional embedding vector (L2 normalized when the
    /// generating service has normalization enabled)
    pub vector: Vec<f32>,
    /// Version of the model used to generate this embedding
    pub model_version: String,
    /// When this embedding was created
    pub created_at: Timestamp,
    /// Storage tier for this embedding
    pub tier: StorageTier,
    /// Additional metadata about the embedding
    pub metadata: EmbeddingMetadata,
}
impl Embedding {
    /// Construct an embedding for `segment_id` from a raw vector.
    ///
    /// A fresh ID and creation timestamp are assigned; tier and metadata
    /// start at their defaults.
    ///
    /// # Errors
    ///
    /// Returns `EmbeddingError::InvalidDimensions` unless the vector has
    /// exactly `EMBEDDING_DIM` (1536) components.
    pub fn new(
        segment_id: SegmentId,
        vector: Vec<f32>,
        model_version: String,
    ) -> Result<Self, crate::EmbeddingError> {
        if vector.len() == EMBEDDING_DIM {
            Ok(Self {
                id: EmbeddingId::new(),
                segment_id,
                vector,
                model_version,
                created_at: Utc::now(),
                tier: StorageTier::default(),
                metadata: EmbeddingMetadata::default(),
            })
        } else {
            Err(crate::EmbeddingError::InvalidDimensions {
                expected: EMBEDDING_DIM,
                actual: vector.len(),
            })
        }
    }

    /// L2 (Euclidean) norm of the embedding vector.
    #[must_use]
    pub fn norm(&self) -> f32 {
        self.vector.iter().fold(0.0_f32, |acc, x| acc + x * x).sqrt()
    }

    /// True when the norm lies within 1% of 1.0.
    #[must_use]
    pub fn is_normalized(&self) -> bool {
        let n = self.norm();
        n >= 0.99 && n <= 1.01
    }

    /// True when every component is finite (no NaN, no infinity).
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.vector.iter().all(|x| x.is_finite())
    }

    /// Dot product with another embedding's vector.
    ///
    /// For L2-normalized embeddings this equals cosine similarity; for
    /// unnormalized vectors it is only the raw dot product.
    #[must_use]
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        let mut dot = 0.0_f32;
        for (a, b) in self.vector.iter().zip(other.vector.iter()) {
            dot += a * b;
        }
        dot
    }

    /// Move this embedding to a different storage tier.
    pub fn set_tier(&mut self, tier: StorageTier) {
        self.tier = tier;
    }
}
/// Metadata about an embedding's generation
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EmbeddingMetadata {
    /// Time taken for inference in milliseconds (per-item share when
    /// generated in a batch)
    pub inference_latency_ms: Option<f32>,
    /// Batch ID if processed in a batch, `None` for single-item generation
    pub batch_id: Option<String>,
    /// Whether GPU was used for inference
    pub gpu_used: bool,
    /// Original norm before L2 normalization
    pub original_norm: Option<f32>,
    /// Sparsity of the embedding (fraction of near-zero values)
    pub sparsity: Option<f32>,
    /// Quality score based on norm deviation and sparsity, in [0, 1]
    pub quality_score: Option<f32>,
}
/// Model version identifier
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelVersion {
    /// Model name (e.g., "perch-v2")
    pub name: String,
    /// Semantic version string (e.g., "2.0.0")
    pub version: String,
    /// Model variant (e.g., "base", "quantized", "pruned")
    pub variant: String,
}

impl ModelVersion {
    /// Assemble a version from its three components.
    #[must_use]
    pub fn new(name: impl Into<String>, version: impl Into<String>, variant: impl Into<String>) -> Self {
        let (name, version, variant) = (name.into(), version.into(), variant.into());
        ModelVersion { name, version, variant }
    }

    /// The default Perch 2.0 base model version.
    #[must_use]
    pub fn perch_v2_base() -> Self {
        Self::new("perch-v2", "2.0.0", "base")
    }

    /// The Perch 2.0 quantized model version.
    #[must_use]
    pub fn perch_v2_quantized() -> Self {
        Self::new("perch-v2", "2.0.0", "quantized")
    }

    /// Hyphen-joined "name-version-variant" string.
    #[must_use]
    pub fn full_version(&self) -> String {
        [self.name.as_str(), self.version.as_str(), self.variant.as_str()].join("-")
    }
}

impl Default for ModelVersion {
    /// Defaults to the Perch 2.0 base model.
    fn default() -> Self {
        ModelVersion::perch_v2_base()
    }
}

impl std::fmt::Display for ModelVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.full_version())
    }
}
/// Input specification for the embedding model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputSpecification {
    /// Expected sample rate in Hz
    pub sample_rate: u32,
    /// Window duration in seconds
    pub window_duration: f32,
    /// Number of samples per window
    pub window_samples: usize,
    /// Number of mel frequency bins
    pub mel_bins: usize,
    /// Number of time frames
    pub mel_frames: usize,
    /// Frequency range (low, high) in Hz
    pub frequency_range: (f32, f32),
}

impl Default for InputSpecification {
    /// Defaults mirror the crate-level constants for the Perch 2.0 input
    /// contract; only the 60 Hz - 16 kHz frequency range is hard-coded here.
    fn default() -> Self {
        Self {
            sample_rate: crate::TARGET_SAMPLE_RATE,
            window_duration: crate::TARGET_WINDOW_SECONDS,
            window_samples: crate::TARGET_WINDOW_SAMPLES,
            mel_bins: crate::MEL_BINS,
            mel_frames: crate::MEL_FRAMES,
            frequency_range: (60.0, 16000.0),
        }
    }
}
/// Status of an embedding model
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelStatus {
    /// Model is available and ready for inference
    Active,
    /// Model is being loaded (the initial state for new configurations)
    Loading,
    /// Model failed to load
    Failed,
    /// Model is deprecated and should not be used
    Deprecated,
}
/// Configuration and metadata for an embedding model.
///
/// Represents the Perch 2.0 ONNX model used for generating embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingModel {
    /// Model name
    pub name: String,
    /// Model version
    pub version: ModelVersion,
    /// Output embedding dimensions
    pub dimensions: usize,
    /// SHA-256 checksum of the model file (may be empty until computed
    /// at load time)
    pub checksum: String,
    /// Input specification
    pub input_spec: InputSpecification,
    /// Current model status
    pub status: ModelStatus,
    /// When the model was last loaded; `None` until first activation
    pub loaded_at: Option<Timestamp>,
    /// Path to the model file
    pub model_path: Option<String>,
}
impl EmbeddingModel {
/// Create a new embedding model configuration
#[must_use]
pub fn new(name: String, version: ModelVersion, checksum: String) -> Self {
Self {
name,
version,
dimensions: EMBEDDING_DIM,
checksum,
input_spec: InputSpecification::default(),
status: ModelStatus::Loading,
loaded_at: None,
model_path: None,
}
}
/// Create a default Perch 2.0 model configuration
#[must_use]
pub fn perch_v2_default() -> Self {
Self::new(
"perch-v2".to_string(),
ModelVersion::perch_v2_base(),
String::new(), // Checksum will be computed on load
)
}
/// Check if the model is ready for inference
#[must_use]
pub const fn is_ready(&self) -> bool {
matches!(self.status, ModelStatus::Active)
}
/// Mark the model as active
pub fn mark_active(&mut self) {
self.status = ModelStatus::Active;
self.loaded_at = Some(Utc::now());
}
/// Mark the model as failed
pub fn mark_failed(&mut self) {
self.status = ModelStatus::Failed;
}
}
/// A batch of embeddings processed together
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingBatch {
    /// Unique batch identifier (UUID string)
    pub id: String,
    /// Segment IDs in this batch
    pub segment_ids: Vec<SegmentId>,
    /// Batch processing status
    pub status: BatchStatus,
    /// When batch processing started (set at construction)
    pub started_at: Timestamp,
    /// When batch processing completed; `None` while pending/processing
    pub completed_at: Option<Timestamp>,
    /// Batch processing metrics
    pub metrics: BatchMetrics,
}

/// Status of a batch operation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchStatus {
    /// Batch is queued for processing
    Pending,
    /// Batch is currently being processed
    Processing,
    /// Batch processing completed successfully
    Completed,
    /// Batch processing failed
    Failed,
}

/// Metrics for batch processing
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchMetrics {
    /// Total segments in batch
    pub total_segments: u32,
    /// Successfully processed segments
    pub success_count: u32,
    /// Failed segments
    pub failure_count: u32,
    /// Average inference latency in milliseconds
    pub avg_latency_ms: f32,
    /// Throughput in segments per second (computed on completion)
    pub throughput: f32,
}
impl EmbeddingBatch {
/// Create a new embedding batch
#[must_use]
pub fn new(segment_ids: Vec<SegmentId>) -> Self {
let total = segment_ids.len() as u32;
Self {
id: Uuid::new_v4().to_string(),
segment_ids,
status: BatchStatus::Pending,
started_at: Utc::now(),
completed_at: None,
metrics: BatchMetrics {
total_segments: total,
..Default::default()
},
}
}
/// Mark batch as processing
pub fn mark_processing(&mut self) {
self.status = BatchStatus::Processing;
}
/// Mark batch as completed with metrics
pub fn mark_completed(&mut self, success_count: u32, failure_count: u32, avg_latency_ms: f32) {
self.status = BatchStatus::Completed;
self.completed_at = Some(Utc::now());
self.metrics.success_count = success_count;
self.metrics.failure_count = failure_count;
self.metrics.avg_latency_ms = avg_latency_ms;
if let Some(completed) = self.completed_at {
let duration = (completed - self.started_at).num_milliseconds() as f32 / 1000.0;
if duration > 0.0 {
self.metrics.throughput = self.metrics.total_segments as f32 / duration;
}
}
}
/// Mark batch as failed
pub fn mark_failed(&mut self) {
self.status = BatchStatus::Failed;
self.completed_at = Some(Utc::now());
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two freshly generated IDs must differ (v4 UUIDs are random).
    #[test]
    fn test_embedding_id_generation() {
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        assert_ne!(id1, id2);
    }

    // A vector of the exact EMBEDDING_DIM length is accepted.
    #[test]
    fn test_embedding_creation() {
        let segment_id = SegmentId::new();
        let vector = vec![0.0; EMBEDDING_DIM];
        let embedding = Embedding::new(segment_id, vector, "perch-v2-2.0.0-base".to_string());
        assert!(embedding.is_ok());
    }

    // A wrong-length vector is rejected.
    #[test]
    fn test_embedding_invalid_dimensions() {
        let segment_id = SegmentId::new();
        let vector = vec![0.0; 100]; // Wrong dimension
        let result = Embedding::new(segment_id, vector, "perch-v2-2.0.0-base".to_string());
        assert!(result.is_err());
    }

    // A one-hot vector has unit norm and counts as normalized.
    #[test]
    fn test_embedding_norm() {
        let segment_id = SegmentId::new();
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = 1.0; // Unit vector
        let embedding = Embedding::new(segment_id, vector, "perch-v2-2.0.0-base".to_string()).unwrap();
        assert!((embedding.norm() - 1.0).abs() < 1e-6);
        assert!(embedding.is_normalized());
    }

    // Identical unit vectors have cosine similarity 1.
    #[test]
    fn test_cosine_similarity() {
        let segment_id = SegmentId::new();
        // Create two identical normalized vectors
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = 1.0;
        let emb1 = Embedding::new(segment_id, vector.clone(), "test".to_string()).unwrap();
        let emb2 = Embedding::new(segment_id, vector, "test".to_string()).unwrap();
        let similarity = emb1.cosine_similarity(&emb2);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    // Per-tier component widths: f32 / f16 / i8.
    #[test]
    fn test_storage_tier_bytes() {
        assert_eq!(StorageTier::Hot.bytes_per_dim(), 4);
        assert_eq!(StorageTier::Warm.bytes_per_dim(), 2);
        assert_eq!(StorageTier::Cold.bytes_per_dim(), 1);
        assert_eq!(StorageTier::Hot.embedding_bytes(), EMBEDDING_DIM * 4);
    }

    // full_version joins name-version-variant with hyphens.
    #[test]
    fn test_model_version() {
        let version = ModelVersion::perch_v2_base();
        assert_eq!(version.name, "perch-v2");
        assert_eq!(version.version, "2.0.0");
        assert_eq!(version.variant, "base");
        assert_eq!(version.full_version(), "perch-v2-2.0.0-base");
    }

    // Defaults track the crate-level Perch 2.0 input constants.
    #[test]
    fn test_input_specification_defaults() {
        let spec = InputSpecification::default();
        assert_eq!(spec.sample_rate, 32000);
        assert_eq!(spec.window_samples, 160_000);
        assert_eq!(spec.mel_bins, 128);
        assert_eq!(spec.mel_frames, 500);
    }

    // Pending -> Processing -> Completed, with metrics recorded.
    #[test]
    fn test_embedding_batch_lifecycle() {
        let segment_ids = vec![SegmentId::new(), SegmentId::new()];
        let mut batch = EmbeddingBatch::new(segment_ids);
        assert_eq!(batch.status, BatchStatus::Pending);
        assert_eq!(batch.metrics.total_segments, 2);
        batch.mark_processing();
        assert_eq!(batch.status, BatchStatus::Processing);
        batch.mark_completed(2, 0, 50.0);
        assert_eq!(batch.status, BatchStatus::Completed);
        assert_eq!(batch.metrics.success_count, 2);
        assert!(batch.completed_at.is_some());
    }
}

View File

@@ -0,0 +1,7 @@
//! Domain layer for embedding bounded context.
//!
//! Contains core entities, value objects, and repository traits that define
//! the embedding domain model.

/// Core entities and value objects (Embedding, SegmentId, StorageTier, ...).
pub mod entities;
/// Repository traits for persisting embeddings and model configurations.
pub mod repository;

View File

@@ -0,0 +1,376 @@
//! Repository traits for the embedding bounded context.
//!
//! Defines the interfaces for persisting and retrieving embeddings,
//! following the repository pattern from Domain-Driven Design.
use async_trait::async_trait;
use super::entities::{Embedding, EmbeddingId, EmbeddingModel, SegmentId, StorageTier};
use crate::EmbeddingError;
/// Repository trait for embedding persistence.
///
/// Implementations may use various storage backends:
/// - In-memory (for testing)
/// - Vector databases (Qdrant, Milvus)
/// - Relational databases (PostgreSQL with pgvector)
/// - File-based storage
#[async_trait]
pub trait EmbeddingRepository: Send + Sync {
    /// Save a single embedding to the repository.
    ///
    /// # Errors
    ///
    /// Returns an error if the embedding cannot be persisted.
    async fn save(&self, embedding: &Embedding) -> Result<(), EmbeddingError>;

    /// Find an embedding by its unique identifier.
    ///
    /// # Returns
    ///
    /// - `Ok(Some(embedding))` if found
    /// - `Ok(None)` if not found
    /// - `Err(...)` on storage errors
    async fn find_by_id(&self, id: &EmbeddingId) -> Result<Option<Embedding>, EmbeddingError>;

    /// Find an embedding by its source segment ID.
    ///
    /// # Returns
    ///
    /// - `Ok(Some(embedding))` if found
    /// - `Ok(None)` if not found
    /// - `Err(...)` on storage errors
    async fn find_by_segment(&self, segment_id: &SegmentId) -> Result<Option<Embedding>, EmbeddingError>;

    /// Save multiple embeddings in a batch operation.
    ///
    /// This is more efficient than calling `save` multiple times
    /// as it can use bulk insert operations.
    ///
    /// # Errors
    ///
    /// Returns an error if any embedding cannot be persisted.
    /// Implementations should document their atomicity guarantees.
    async fn batch_save(&self, embeddings: &[Embedding]) -> Result<(), EmbeddingError>;

    /// Delete an embedding by its ID.
    ///
    /// # Returns
    ///
    /// - `Ok(true)` if the embedding was deleted
    /// - `Ok(false)` if the embedding was not found
    /// - `Err(...)` on storage errors
    async fn delete(&self, id: &EmbeddingId) -> Result<bool, EmbeddingError>;

    /// Delete embeddings by segment ID.
    ///
    /// Useful when a segment is reprocessed with a new model version.
    ///
    /// # Returns
    ///
    /// Number of embeddings deleted.
    async fn delete_by_segment(&self, segment_id: &SegmentId) -> Result<usize, EmbeddingError>;

    /// Count total embeddings in the repository.
    async fn count(&self) -> Result<u64, EmbeddingError>;

    /// Count embeddings in a single storage tier.
    async fn count_by_tier(&self, tier: StorageTier) -> Result<u64, EmbeddingError>;

    /// Find embeddings by model version.
    ///
    /// Useful for identifying embeddings that need re-generation
    /// after a model update. `limit`/`offset` paginate the result set.
    async fn find_by_model_version(
        &self,
        model_version: &str,
        limit: usize,
        offset: usize,
    ) -> Result<Vec<Embedding>, EmbeddingError>;

    /// Update the storage tier for an embedding.
    ///
    /// Used for tiered storage management (hot -> warm -> cold).
    async fn update_tier(
        &self,
        id: &EmbeddingId,
        tier: StorageTier,
    ) -> Result<bool, EmbeddingError>;

    /// Check if an embedding exists.
    ///
    /// Default implementation fetches the full embedding via `find_by_id`;
    /// backends with a cheaper existence check should override it.
    async fn exists(&self, id: &EmbeddingId) -> Result<bool, EmbeddingError> {
        Ok(self.find_by_id(id).await?.is_some())
    }

    /// Check if an embedding exists for a segment.
    ///
    /// Default implementation fetches via `find_by_segment`; override for
    /// backends with a cheaper existence check.
    async fn exists_for_segment(&self, segment_id: &SegmentId) -> Result<bool, EmbeddingError> {
        Ok(self.find_by_segment(segment_id).await?.is_some())
    }
}
/// Repository trait for embedding model management.
///
/// Manages the lifecycle of ONNX models used for embedding generation.
#[async_trait]
pub trait ModelRepository: Send + Sync {
    /// Save or update a model configuration.
    async fn save_model(&self, model: &EmbeddingModel) -> Result<(), EmbeddingError>;

    /// Find a model by name and version; `Ok(None)` when no match exists.
    async fn find_model(
        &self,
        name: &str,
        version: &str,
    ) -> Result<Option<EmbeddingModel>, EmbeddingError>;

    /// Get the currently active model for a given name.
    async fn get_active_model(&self, name: &str) -> Result<Option<EmbeddingModel>, EmbeddingError>;

    /// List all available models.
    async fn list_models(&self) -> Result<Vec<EmbeddingModel>, EmbeddingError>;

    /// Delete a model configuration; `Ok(false)` when it did not exist.
    async fn delete_model(&self, name: &str, version: &str) -> Result<bool, EmbeddingError>;
}
/// Query parameters for embedding search operations.
///
/// All filters are optional; `Default` produces an unfiltered query.
/// NOTE(review): filters are presumably combined conjunctively (AND) by
/// implementations — confirm against `QueryableEmbeddingRepository` impls.
#[derive(Debug, Clone, Default)]
pub struct EmbeddingQuery {
    /// Filter by segment IDs
    pub segment_ids: Option<Vec<SegmentId>>,
    /// Filter by model version
    pub model_version: Option<String>,
    /// Filter by storage tier
    pub tier: Option<StorageTier>,
    /// Filter by creation date (after)
    pub created_after: Option<chrono::DateTime<chrono::Utc>>,
    /// Filter by creation date (before)
    pub created_before: Option<chrono::DateTime<chrono::Utc>>,
    /// Maximum results to return
    pub limit: Option<usize>,
    /// Offset for pagination
    pub offset: Option<usize>,
}
impl EmbeddingQuery {
/// Create a new query builder
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Filter by segment IDs
#[must_use]
pub fn with_segment_ids(mut self, ids: Vec<SegmentId>) -> Self {
self.segment_ids = Some(ids);
self
}
/// Filter by model version
#[must_use]
pub fn with_model_version(mut self, version: impl Into<String>) -> Self {
self.model_version = Some(version.into());
self
}
/// Filter by storage tier
#[must_use]
pub const fn with_tier(mut self, tier: StorageTier) -> Self {
self.tier = Some(tier);
self
}
/// Set pagination limit
#[must_use]
pub const fn with_limit(mut self, limit: usize) -> Self {
self.limit = Some(limit);
self
}
/// Set pagination offset
#[must_use]
pub const fn with_offset(mut self, offset: usize) -> Self {
self.offset = Some(offset);
self
}
}
/// Extended repository trait with query support.
#[async_trait]
pub trait QueryableEmbeddingRepository: EmbeddingRepository {
    /// Query embeddings with filters.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying storage query fails.
    async fn query(&self, query: &EmbeddingQuery) -> Result<Vec<Embedding>, EmbeddingError>;

    /// Find k nearest neighbors to a query vector.
    ///
    /// # Arguments
    ///
    /// * `query_vector` - The query embedding vector
    /// * `k` - Number of neighbors to return
    /// * `ef_search` - HNSW search parameter (larger = more accurate, slower);
    ///   `None` presumably lets the implementation pick a default — confirm.
    ///
    /// # Returns
    ///
    /// Vector of (embedding, distance) pairs, sorted by distance ascending.
    async fn find_nearest(
        &self,
        query_vector: &[f32],
        k: usize,
        ef_search: Option<usize>,
    ) -> Result<Vec<(Embedding, f32)>, EmbeddingError>;

    /// Find embeddings within a distance threshold.
    ///
    /// NOTE(review): the distance metric (cosine vs. Euclidean) is not
    /// specified here — implementations should document which one applies.
    async fn find_within_distance(
        &self,
        query_vector: &[f32],
        max_distance: f32,
    ) -> Result<Vec<(Embedding, f32)>, EmbeddingError>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// In-memory implementation for testing.
    ///
    /// Exercises the `EmbeddingRepository` contract without a real database.
    struct InMemoryEmbeddingRepository {
        // Async-guarded store keyed by embedding ID.
        embeddings: Arc<RwLock<HashMap<EmbeddingId, Embedding>>>,
    }

    impl InMemoryEmbeddingRepository {
        /// Create an empty repository.
        fn new() -> Self {
            Self {
                embeddings: Arc::new(RwLock::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl EmbeddingRepository for InMemoryEmbeddingRepository {
        async fn save(&self, embedding: &Embedding) -> Result<(), EmbeddingError> {
            // Insert-or-replace: a second save with the same ID overwrites.
            self.embeddings.write().await.insert(embedding.id, embedding.clone());
            Ok(())
        }

        async fn find_by_id(&self, id: &EmbeddingId) -> Result<Option<Embedding>, EmbeddingError> {
            Ok(self.embeddings.read().await.get(id).cloned())
        }

        async fn find_by_segment(&self, segment_id: &SegmentId) -> Result<Option<Embedding>, EmbeddingError> {
            // Linear scan; returns an arbitrary match if several exist.
            Ok(self.embeddings.read().await.values()
                .find(|e| e.segment_id == *segment_id)
                .cloned())
        }

        async fn batch_save(&self, embeddings: &[Embedding]) -> Result<(), EmbeddingError> {
            // Single write lock for the whole batch — atomic w.r.t. readers.
            let mut store = self.embeddings.write().await;
            for embedding in embeddings {
                store.insert(embedding.id, embedding.clone());
            }
            Ok(())
        }

        async fn delete(&self, id: &EmbeddingId) -> Result<bool, EmbeddingError> {
            Ok(self.embeddings.write().await.remove(id).is_some())
        }

        async fn delete_by_segment(&self, segment_id: &SegmentId) -> Result<usize, EmbeddingError> {
            let mut store = self.embeddings.write().await;
            // Collect matching IDs first to avoid mutating while iterating.
            let to_remove: Vec<_> = store.iter()
                .filter(|(_, e)| e.segment_id == *segment_id)
                .map(|(id, _)| *id)
                .collect();
            let count = to_remove.len();
            for id in to_remove {
                store.remove(&id);
            }
            Ok(count)
        }

        async fn count(&self) -> Result<u64, EmbeddingError> {
            Ok(self.embeddings.read().await.len() as u64)
        }

        async fn count_by_tier(&self, tier: StorageTier) -> Result<u64, EmbeddingError> {
            Ok(self.embeddings.read().await.values()
                .filter(|e| e.tier == tier)
                .count() as u64)
        }

        async fn find_by_model_version(
            &self,
            model_version: &str,
            limit: usize,
            offset: usize,
        ) -> Result<Vec<Embedding>, EmbeddingError> {
            // Pagination over HashMap iteration order is unstable across
            // calls; fine for tests, not for production paging.
            Ok(self.embeddings.read().await.values()
                .filter(|e| e.model_version == model_version)
                .skip(offset)
                .take(limit)
                .cloned()
                .collect())
        }

        async fn update_tier(
            &self,
            id: &EmbeddingId,
            tier: StorageTier,
        ) -> Result<bool, EmbeddingError> {
            let mut store = self.embeddings.write().await;
            if let Some(embedding) = store.get_mut(id) {
                embedding.tier = tier;
                Ok(true)
            } else {
                Ok(false)
            }
        }
    }

    /// End-to-end smoke test of the repository contract:
    /// save -> find (by id and segment) -> count -> delete.
    #[tokio::test]
    async fn test_in_memory_repository() {
        let repo = InMemoryEmbeddingRepository::new();
        let segment_id = SegmentId::new();
        let vector = vec![0.0; crate::EMBEDDING_DIM];
        let embedding = Embedding::new(segment_id, vector, "test".to_string()).unwrap();
        // Save
        repo.save(&embedding).await.unwrap();
        // Find by ID
        let found = repo.find_by_id(&embedding.id).await.unwrap();
        assert!(found.is_some());
        // Find by segment
        let found = repo.find_by_segment(&segment_id).await.unwrap();
        assert!(found.is_some());
        // Count
        assert_eq!(repo.count().await.unwrap(), 1);
        // Delete
        assert!(repo.delete(&embedding.id).await.unwrap());
        assert_eq!(repo.count().await.unwrap(), 0);
    }

    /// Builder methods should populate exactly the fields they name.
    #[test]
    fn test_embedding_query_builder() {
        let query = EmbeddingQuery::new()
            .with_model_version("perch-v2-2.0.0-base")
            .with_tier(StorageTier::Hot)
            .with_limit(100)
            .with_offset(0);
        assert_eq!(query.model_version.as_deref(), Some("perch-v2-2.0.0-base"));
        assert_eq!(query.tier, Some(StorageTier::Hot));
        assert_eq!(query.limit, Some(100));
        assert_eq!(query.offset, Some(0));
    }
}

View File

@@ -0,0 +1,7 @@
//! Infrastructure layer for embedding bounded context.
//!
//! Contains implementations for ONNX model loading, inference,
//! and integration with external systems.
pub mod model_manager;
pub mod onnx_inference;

View File

@@ -0,0 +1,511 @@
//! Model management for ONNX embedding models.
//!
//! Provides thread-safe loading, caching, and hot-swapping of
//! Perch 2.0 ONNX models for embedding generation.
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use parking_lot::RwLock;
use sha2::{Digest, Sha256};
use thiserror::Error;
use tracing::{debug, info, instrument, warn};
use super::onnx_inference::OnnxInference;
use crate::domain::entities::{EmbeddingModel, ModelVersion};
/// Errors that can occur during model management.
#[derive(Debug, Error)]
pub enum ModelError {
    /// Model file not found (the payload names the model and lists the
    /// candidate paths that were tried).
    #[error("Model not found: {0}")]
    NotFound(String),
    /// Failed to load model
    #[error("Failed to load model: {0}")]
    LoadFailed(String),
    /// Checksum verification failed (SHA-256 of the model file did not
    /// match the registered value).
    #[error("Checksum mismatch for model {model}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        /// Model name
        model: String,
        /// Expected checksum
        expected: String,
        /// Actual checksum
        actual: String,
    },
    /// Model initialization failed
    #[error("Model initialization failed: {0}")]
    InitializationFailed(String),
    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// ONNX Runtime error, carried as a string because the underlying
    /// runtime error type is not preserved.
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(String),
    /// Model not ready
    #[error("Model not ready: {0}")]
    NotReady(String),
}
/// Configuration for the model manager.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Directory containing model files
    pub model_dir: PathBuf,
    /// Number of threads for intra-op parallelism
    pub intra_op_threads: usize,
    /// Number of threads for inter-op parallelism
    pub inter_op_threads: usize,
    /// Whether to verify model checksums on load
    pub verify_checksums: bool,
    /// Execution providers in priority order (first available wins)
    pub execution_providers: Vec<ExecutionProvider>,
    /// Maximum number of cached sessions kept by the manager
    pub max_cached_sessions: usize,
}
impl Default for ModelConfig {
    /// Sensible defaults: local `models/` directory, checksum verification
    /// enabled, accelerators preferred over CPU, small session cache.
    fn default() -> Self {
        // Try hardware acceleration first; CPU is the guaranteed fallback.
        let execution_providers = vec![
            ExecutionProvider::Cuda { device_id: 0 },
            ExecutionProvider::CoreML,
            ExecutionProvider::Cpu,
        ];
        Self {
            model_dir: PathBuf::from("models"),
            // Cap intra-op threads at 4 to avoid oversubscribing large hosts.
            intra_op_threads: num_cpus::get().min(4),
            inter_op_threads: 1,
            verify_checksums: true,
            execution_providers,
            max_cached_sessions: 4,
        }
    }
}
/// Execution provider for ONNX Runtime.
///
/// Listed in `ModelConfig::execution_providers` in priority order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// CPU execution
    Cpu,
    /// NVIDIA CUDA execution
    Cuda {
        /// GPU device ID
        device_id: i32,
    },
    /// Apple CoreML execution
    CoreML,
    /// DirectML execution (Windows)
    DirectML {
        /// Device ID
        device_id: i32,
    },
}
/// Thread-safe model session manager with caching and hot-swap support.
///
/// Manages the lifecycle of ONNX models used for embedding generation,
/// including loading, caching, and version management.
pub struct ModelManager {
    /// Cached model sessions keyed by full version string
    sessions: RwLock<HashMap<String, Arc<OnnxInference>>>,
    /// Model metadata keyed by full version string
    models: RwLock<HashMap<String, EmbeddingModel>>,
    /// Currently active model version
    active_version: RwLock<ModelVersion>,
    /// Configuration (immutable after construction)
    config: ModelConfig,
}
impl ModelManager {
    /// Create a new model manager with the given configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the model directory doesn't exist and can't be created.
    pub fn new(config: ModelConfig) -> Result<Self, ModelError> {
        // Ensure model directory exists
        if !config.model_dir.exists() {
            std::fs::create_dir_all(&config.model_dir)?;
            debug!(path = ?config.model_dir, "Created model directory");
        }
        Ok(Self {
            sessions: RwLock::new(HashMap::new()),
            models: RwLock::new(HashMap::new()),
            active_version: RwLock::new(ModelVersion::perch_v2_base()),
            config,
        })
    }

    /// Create with default configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the default model directory can't be created.
    pub fn with_defaults() -> Result<Self, ModelError> {
        Self::new(ModelConfig::default())
    }

    /// Evict cached sessions until there is room for one more entry.
    ///
    /// `HashMap` iteration order is unspecified, so the victim is effectively
    /// arbitrary (not LRU); acceptable for the small caches used here.
    fn evict_to_capacity(sessions: &mut HashMap<String, Arc<OnnxInference>>, capacity: usize) {
        while sessions.len() >= capacity {
            // `keys().next()` is `None` once the map is empty (e.g. when
            // `capacity == 0`): break instead of spinning forever, which
            // the previous unconditional `if let` loop would do.
            let Some(key) = sessions.keys().next().cloned() else { break };
            sessions.remove(&key);
            debug!("Evicted cached session: {}", key);
        }
    }

    /// Load a model from a file.
    ///
    /// Returns a cached session for the active version when one exists;
    /// otherwise resolves the model path, optionally verifies its checksum,
    /// builds a new session, and caches it.
    ///
    /// Note: two concurrent callers that both miss the cache may each build
    /// a session; the last insert wins. This wastes work but is harmless.
    ///
    /// # Arguments
    ///
    /// * `name` - Model name (e.g., "perch-v2")
    ///
    /// # Errors
    ///
    /// Returns an error if the model file doesn't exist or fails to load.
    #[instrument(skip(self), fields(model = %name))]
    pub fn load_model(&self, name: &str) -> Result<Arc<OnnxInference>, ModelError> {
        let version = self.active_version.read().clone();
        let version_key = version.full_version();
        // Check cache first
        {
            let sessions = self.sessions.read();
            if let Some(session) = sessions.get(&version_key) {
                debug!("Using cached session for {}", version_key);
                return Ok(Arc::clone(session));
            }
        }
        // Resolve model path
        let model_path = self.resolve_model_path(name, &version)?;
        // Verify checksum if configured (only when metadata is registered
        // with a non-empty checksum).
        if self.config.verify_checksums {
            if let Some(model) = self.models.read().get(&version_key) {
                if !model.checksum.is_empty() {
                    self.verify_checksum(&model_path, &model.checksum)?;
                }
            }
        }
        // Create new session
        info!(path = ?model_path, "Loading model");
        let session = Arc::new(self.create_session(&model_path)?);
        // Cache the session, evicting old entries if at capacity
        {
            let mut sessions = self.sessions.write();
            Self::evict_to_capacity(&mut sessions, self.config.max_cached_sessions);
            sessions.insert(version_key.clone(), Arc::clone(&session));
        }
        // Update model metadata
        {
            let mut models = self.models.write();
            if let Some(model) = models.get_mut(&version_key) {
                model.mark_active();
            }
        }
        info!(version = %version_key, "Model loaded successfully");
        Ok(session)
    }

    /// Verify the checksum of a model file.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ChecksumMismatch`] if the checksum doesn't match,
    /// or an I/O error if the file can't be read.
    pub fn verify_checksum(&self, path: &Path, expected: &str) -> Result<bool, ModelError> {
        let actual = self.compute_checksum(path)?;
        if actual != expected {
            return Err(ModelError::ChecksumMismatch {
                model: path.display().to_string(),
                expected: expected.to_string(),
                actual,
            });
        }
        debug!(path = ?path, "Checksum verified");
        // Always `true` on success; mismatch is reported via `Err`.
        Ok(true)
    }

    /// Compute the SHA-256 checksum of a file (streamed, not loaded whole).
    fn compute_checksum(&self, path: &Path) -> Result<String, ModelError> {
        let mut file = std::fs::File::open(path)?;
        let mut hasher = Sha256::new();
        std::io::copy(&mut file, &mut hasher)?;
        let hash = hasher.finalize();
        Ok(hex::encode(hash))
    }

    /// Hot-swap to a new model version without restart.
    ///
    /// The new model is fully loaded and checksummed before any shared state
    /// is touched, so a failed swap leaves the old model active.
    ///
    /// # Arguments
    ///
    /// * `name` - Model name
    /// * `new_path` - Path to the new model file
    ///
    /// # Errors
    ///
    /// Returns an error if the new model fails to load.
    #[instrument(skip(self, new_path), fields(model = %name, path = ?new_path))]
    pub fn hot_swap(&self, name: &str, new_path: &Path) -> Result<(), ModelError> {
        // Validate the new model can be loaded
        info!("Attempting hot-swap to new model");
        let new_session = self.create_session(new_path)?;
        // Compute checksum for the new model
        let checksum = self.compute_checksum(new_path)?;
        // Create new version
        let old_version = self.active_version.read().clone();
        let new_version = ModelVersion::new(
            name,
            &old_version.version, // Keep same semantic version
            "hot-swap",
        );
        let version_key = new_version.full_version();
        // Update sessions cache, respecting the configured capacity
        // (previously hot-swap inserted unconditionally and could grow the
        // cache past `max_cached_sessions`).
        {
            let mut sessions = self.sessions.write();
            Self::evict_to_capacity(&mut sessions, self.config.max_cached_sessions);
            sessions.insert(version_key.clone(), Arc::new(new_session));
        }
        // Update model metadata
        {
            let mut models = self.models.write();
            let mut model = EmbeddingModel::new(
                name.to_string(),
                new_version.clone(),
                checksum,
            );
            model.model_path = Some(new_path.to_string_lossy().to_string());
            model.mark_active();
            models.insert(version_key, model);
        }
        // Update active version
        *self.active_version.write() = new_version.clone();
        info!(
            old_version = %old_version,
            new_version = %new_version,
            "Hot-swap completed successfully"
        );
        Ok(())
    }

    /// Get the ONNX inference engine for the current model.
    ///
    /// Kept `async` for interface stability even though the current
    /// implementation does no awaiting.
    ///
    /// # Errors
    ///
    /// Returns an error if no model is loaded.
    pub async fn get_inference(&self) -> Result<Arc<OnnxInference>, ModelError> {
        let version = self.active_version.read().clone();
        self.load_model(&version.name)
    }

    /// Get the currently active model version.
    #[must_use]
    pub fn current_version(&self) -> ModelVersion {
        self.active_version.read().clone()
    }

    /// Set the active model version.
    pub fn set_active_version(&self, version: ModelVersion) {
        *self.active_version.write() = version;
    }

    /// Check if a session for the active model version is cached and ready.
    pub async fn is_ready(&self) -> bool {
        let version_key = self.active_version.read().full_version();
        self.sessions.read().contains_key(&version_key)
    }

    /// Get model metadata for a version.
    #[must_use]
    pub fn get_model(&self, version_key: &str) -> Option<EmbeddingModel> {
        self.models.read().get(version_key).cloned()
    }

    /// List all loaded models.
    #[must_use]
    pub fn list_models(&self) -> Vec<EmbeddingModel> {
        self.models.read().values().cloned().collect()
    }

    /// Clear all cached sessions.
    pub fn clear_cache(&self) {
        self.sessions.write().clear();
        info!("Cleared model session cache");
    }

    /// Resolve the path to a model file.
    ///
    /// Tries several naming conventions in `model_dir`, then falls back to
    /// any path recorded in registered model metadata.
    fn resolve_model_path(&self, name: &str, version: &ModelVersion) -> Result<PathBuf, ModelError> {
        // Try various naming conventions
        let candidates = vec![
            self.config.model_dir.join(format!("{}.onnx", version.full_version())),
            self.config.model_dir.join(format!("{}_{}.onnx", name, version.version)),
            self.config.model_dir.join(format!("{}.onnx", name)),
            self.config.model_dir.join(format!("{}/{}.onnx", name, version.version)),
        ];
        for path in &candidates {
            if path.exists() {
                return Ok(path.clone());
            }
        }
        // Also check if there's a model metadata entry with a path
        let version_key = version.full_version();
        if let Some(model) = self.models.read().get(&version_key) {
            if let Some(ref path_str) = model.model_path {
                let path = PathBuf::from(path_str);
                if path.exists() {
                    return Ok(path);
                }
            }
        }
        Err(ModelError::NotFound(format!(
            "Model {} not found in {:?}. Tried: {:?}",
            name, self.config.model_dir, candidates
        )))
    }

    /// Create an ONNX inference session from a model file.
    fn create_session(&self, path: &Path) -> Result<OnnxInference, ModelError> {
        OnnxInference::new(
            path,
            self.config.intra_op_threads,
            self.config.inter_op_threads,
            &self.config.execution_providers,
        )
        .map_err(|e| ModelError::LoadFailed(e.to_string()))
    }

    /// Register a model without loading it.
    pub fn register_model(&self, model: EmbeddingModel) {
        let version_key = model.version.full_version();
        self.models.write().insert(version_key, model);
    }

    /// Unload a specific model version from cache.
    ///
    /// Returns `true` if a cached session was removed.
    pub fn unload_model(&self, version_key: &str) -> bool {
        let removed = self.sessions.write().remove(version_key).is_some();
        if removed {
            info!(version = %version_key, "Unloaded model from cache");
        }
        removed
    }
}
impl std::fmt::Debug for ModelManager {
    /// Hand-written `Debug`: sessions hold non-`Debug` ONNX handles, so only
    /// summary fields (directory, active version, cache size) are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let active_version = self.active_version.read();
        let cached_sessions = self.sessions.read().len();
        f.debug_struct("ModelManager")
            .field("model_dir", &self.config.model_dir)
            .field("active_version", &*active_version)
            .field("cached_sessions", &cached_sessions)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::tempdir;

    /// Defaults should have sane thread counts and checksums enabled.
    #[test]
    fn test_model_config_default() {
        let config = ModelConfig::default();
        assert!(config.intra_op_threads > 0);
        assert!(config.verify_checksums);
    }

    /// Constructing a manager against a fresh temp dir must succeed.
    #[test]
    fn test_model_manager_creation() {
        let dir = tempdir().unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config);
        assert!(manager.is_ok());
    }

    /// SHA-256 of an on-disk file should be a 64-char hex string.
    #[test]
    fn test_checksum_computation() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.bin");
        let mut file = std::fs::File::create(&file_path).unwrap();
        file.write_all(b"test content").unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config).unwrap();
        let checksum = manager.compute_checksum(&file_path).unwrap();
        assert!(!checksum.is_empty());
        assert_eq!(checksum.len(), 64); // SHA-256 hex length
    }

    /// The default Perch version key format is stable (used as cache key).
    #[test]
    fn test_model_version_key() {
        let version = ModelVersion::perch_v2_base();
        assert_eq!(version.full_version(), "perch-v2-2.0.0-base");
    }

    /// Registering metadata should make it retrievable by version key.
    #[test]
    fn test_register_model() {
        let dir = tempdir().unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config).unwrap();
        let model = EmbeddingModel::perch_v2_default();
        let version_key = model.version.full_version();
        manager.register_model(model);
        let retrieved = manager.get_model(&version_key);
        assert!(retrieved.is_some());
    }

    /// Clearing an empty cache is a no-op and must not panic.
    #[test]
    fn test_clear_cache() {
        let dir = tempdir().unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config).unwrap();
        manager.clear_cache();
        // Should not panic
    }
}

View File

@@ -0,0 +1,426 @@
//! ONNX Runtime inference for Perch 2.0 embeddings.
//!
//! Provides efficient inference using the `ort` crate for
//! ONNX Runtime integration in Rust.
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use ndarray::{Array1, Array3};
use ort::session::{Session, builder::GraphOptimizationLevel};
use thiserror::Error;
use tracing::{debug, instrument, warn};
use super::model_manager::ExecutionProvider;
use crate::{EMBEDDING_DIM, MEL_BINS, MEL_FRAMES};
/// Errors during ONNX inference.
#[derive(Debug, Error)]
pub enum InferenceError {
    /// Session creation failed
    #[error("Failed to create session: {0}")]
    SessionCreation(String),
    /// Input tensor creation failed
    #[error("Failed to create input tensor: {0}")]
    InputTensor(String),
    /// Inference execution failed (includes session-lock poisoning)
    #[error("Inference failed: {0}")]
    Execution(String),
    /// Output extraction failed
    #[error("Failed to extract output: {0}")]
    OutputExtraction(String),
    /// Invalid input dimensions
    #[error("Invalid input dimensions: expected {expected:?}, got {actual:?}")]
    InvalidDimensions {
        /// Expected shape
        expected: Vec<usize>,
        /// Actual shape
        actual: Vec<usize>,
    },
    /// Model not initialized
    #[error("Model not initialized")]
    NotInitialized,
}
/// ONNX inference engine for embedding generation.
pub struct OnnxInference {
    /// The ONNX Runtime session (wrapped in Mutex because `run` requires
    /// mutable access in ort 2.0; inference is therefore serialized)
    session: Mutex<Session>,
    /// Whether GPU is being used
    gpu_enabled: AtomicBool,
    /// Input name for the model (first input, or "input" fallback)
    input_name: String,
    /// Output name for embeddings (first output, or "embedding" fallback)
    output_name: String,
}
impl OnnxInference {
    /// Create a new ONNX inference engine from a model file.
    ///
    /// Only the CPU execution provider is currently wired up; accelerator
    /// requests are logged as warnings and fall back to CPU.
    ///
    /// # Errors
    ///
    /// Returns [`InferenceError::SessionCreation`] if the session builder
    /// fails or the model file cannot be loaded.
    #[instrument(skip(providers), fields(path = ?model_path))]
    pub fn new(
        model_path: &Path,
        intra_op_threads: usize,
        inter_op_threads: usize,
        providers: &[ExecutionProvider],
    ) -> Result<Self, InferenceError> {
        let builder = Session::builder()
            .map_err(|e| InferenceError::SessionCreation(e.to_string()))?
            .with_intra_threads(intra_op_threads)
            .map_err(|e| InferenceError::SessionCreation(e.to_string()))?
            .with_inter_threads(inter_op_threads)
            .map_err(|e| InferenceError::SessionCreation(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| InferenceError::SessionCreation(e.to_string()))?;
        let gpu_enabled = false;
        for provider in providers {
            match provider {
                ExecutionProvider::Cuda { device_id } => {
                    warn!("CUDA device {} requested, using CPU fallback", device_id);
                }
                ExecutionProvider::CoreML => {
                    warn!("CoreML requested, using CPU fallback");
                }
                ExecutionProvider::DirectML { device_id } => {
                    warn!("DirectML device {} requested, using CPU fallback", device_id);
                }
                ExecutionProvider::Cpu => {
                    debug!("Using CPU execution provider");
                    break;
                }
            }
        }
        let session = builder
            .commit_from_file(model_path)
            .map_err(|e| InferenceError::SessionCreation(e.to_string()))?;
        // Discover tensor names from the model; fall back to conventional
        // names if the model declares none.
        let inputs = session.inputs();
        let outputs = session.outputs();
        let input_name = inputs
            .first()
            .map(|i| i.name().to_string())
            .unwrap_or_else(|| "input".to_string());
        let output_name = outputs
            .first()
            .map(|o| o.name().to_string())
            .unwrap_or_else(|| "embedding".to_string());
        debug!(
            input = %input_name,
            output = %output_name,
            gpu = gpu_enabled,
            "ONNX session created"
        );
        Ok(Self {
            session: Mutex::new(session),
            gpu_enabled: AtomicBool::new(gpu_enabled),
            input_name,
            output_name,
        })
    }

    /// Run inference on a single spectrogram.
    ///
    /// # Errors
    ///
    /// Returns `InvalidDimensions` if `input` is not `[1, MEL_FRAMES, MEL_BINS]`,
    /// or an execution/extraction error from the runtime.
    #[instrument(skip(self, input))]
    pub fn run(&self, input: &Array3<f32>) -> Result<Array1<f32>, InferenceError> {
        let shape = input.shape();
        // Validate the full shape including the batch axis: the tensor below
        // is declared as [1, MEL_FRAMES, MEL_BINS], so a multi-batch input
        // would previously slip past validation and only fail (confusingly)
        // inside the runtime.
        if shape[0] != 1 || shape[1] != MEL_FRAMES || shape[2] != MEL_BINS {
            return Err(InferenceError::InvalidDimensions {
                expected: vec![1, MEL_FRAMES, MEL_BINS],
                actual: shape.to_vec(),
            });
        }
        // Create input tensor using ort 2.0 API with shape tuple.
        // `iter()` walks the array in logical (row-major) order regardless of
        // memory layout, matching the declared tensor shape.
        let input_vec: Vec<f32> = input.iter().copied().collect();
        let tensor_shape = vec![1i64, MEL_FRAMES as i64, MEL_BINS as i64];
        let input_tensor = ort::value::Tensor::from_array((tensor_shape, input_vec))
            .map_err(|e| InferenceError::InputTensor(e.to_string()))?;
        // Run inference (lock session for mutable access required by ort 2.0)
        let inputs = ort::inputs![&self.input_name => input_tensor];
        let mut session = self.session.lock()
            .map_err(|e| InferenceError::Execution(format!("Lock error: {}", e)))?;
        let outputs = session
            .run(inputs)
            .map_err(|e| InferenceError::Execution(e.to_string()))?;
        // Extract embedding output
        let output = outputs
            .get(&self.output_name)
            .ok_or_else(|| InferenceError::OutputExtraction("No output found".to_string()))?;
        // Extract tensor data using ort 2.0 API - returns (&Shape, &[f32]) tuple
        let (_shape, flat_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| InferenceError::OutputExtraction(e.to_string()))?;
        // Handle different output shapes: an over-long output is truncated to
        // the leading EMBEDDING_DIM values — presumably extra model outputs;
        // TODO confirm against the Perch 2.0 export.
        let embedding_data: Vec<f32> = if flat_slice.len() == EMBEDDING_DIM {
            flat_slice.to_vec()
        } else if flat_slice.len() > EMBEDDING_DIM {
            flat_slice[..EMBEDDING_DIM].to_vec()
        } else {
            return Err(InferenceError::OutputExtraction(format!(
                "Unexpected embedding size: {} (expected {})",
                flat_slice.len(),
                EMBEDDING_DIM
            )));
        };
        debug!("Inference completed");
        Ok(Array1::from_vec(embedding_data))
    }

    /// Run inference on a batch of spectrograms.
    ///
    /// Each input must individually be `[1, MEL_FRAMES, MEL_BINS]`; they are
    /// stacked into a single `[batch, MEL_FRAMES, MEL_BINS]` tensor.
    ///
    /// # Errors
    ///
    /// Returns `InvalidDimensions` for any malformed input, or an
    /// execution/extraction error from the runtime.
    #[instrument(skip(self, inputs), fields(batch_size = inputs.len()))]
    pub fn run_batch(&self, inputs: &[&Array3<f32>]) -> Result<Vec<Array1<f32>>, InferenceError> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }
        let batch_size = inputs.len();
        for input in inputs.iter() {
            let shape = input.shape();
            // Batch axis must be exactly 1: the copy below flattens each
            // input wholesale, so a larger batch axis would silently smuggle
            // extra (previously ignored) frames into the tensor.
            if shape[0] != 1 || shape[1] != MEL_FRAMES || shape[2] != MEL_BINS {
                return Err(InferenceError::InvalidDimensions {
                    expected: vec![1, MEL_FRAMES, MEL_BINS],
                    actual: shape.to_vec(),
                });
            }
        }
        // Stack inputs into a batch tensor. With shape[0] == 1 enforced,
        // logical-order iteration equals the [0, frame, bin] traversal.
        let mut batch_data = Vec::with_capacity(batch_size * MEL_FRAMES * MEL_BINS);
        for input in inputs {
            batch_data.extend(input.iter().copied());
        }
        // Create batch tensor using shape tuple API
        let tensor_shape = vec![batch_size as i64, MEL_FRAMES as i64, MEL_BINS as i64];
        let input_tensor = ort::value::Tensor::from_array((tensor_shape, batch_data))
            .map_err(|e| InferenceError::InputTensor(e.to_string()))?;
        // Run inference (lock session for mutable access required by ort 2.0)
        let ort_inputs = ort::inputs![&self.input_name => input_tensor];
        let mut session = self.session.lock()
            .map_err(|e| InferenceError::Execution(format!("Lock error: {}", e)))?;
        let outputs = session
            .run(ort_inputs)
            .map_err(|e| InferenceError::Execution(e.to_string()))?;
        // Extract embeddings using ort 2.0 API - returns (&Shape, &[f32]) tuple
        let output = outputs
            .get(&self.output_name)
            .ok_or_else(|| InferenceError::OutputExtraction("No output found".to_string()))?;
        let (_shape, flat_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| InferenceError::OutputExtraction(e.to_string()))?;
        // Split into individual embeddings
        let total_expected = batch_size * EMBEDDING_DIM;
        if flat_slice.len() < total_expected {
            return Err(InferenceError::OutputExtraction(format!(
                "Unexpected output size: {} (expected at least {})",
                flat_slice.len(),
                total_expected
            )));
        }
        let result: Vec<Array1<f32>> = (0..batch_size)
            .map(|i| {
                let start = i * EMBEDDING_DIM;
                let end = start + EMBEDDING_DIM;
                Array1::from_vec(flat_slice[start..end].to_vec())
            })
            .collect();
        debug!(batch_size = batch_size, "Batch inference completed");
        Ok(result)
    }

    /// Check if GPU is being used for inference.
    #[must_use]
    pub fn is_gpu(&self) -> bool {
        self.gpu_enabled.load(Ordering::Relaxed)
    }

    /// Get the input name expected by the model.
    #[must_use]
    pub fn input_name(&self) -> &str {
        &self.input_name
    }

    /// Get the output name for embeddings.
    #[must_use]
    pub fn output_name(&self) -> &str {
        &self.output_name
    }

    /// Get information about the model's expected input shape.
    ///
    /// Returns `None` if the session lock is poisoned or the model declares
    /// no inputs. `dimensions` is currently always empty.
    #[must_use]
    pub fn input_info(&self) -> Option<InputInfo> {
        self.session.lock().ok().and_then(|session| {
            session.inputs().first().map(|input| InputInfo {
                name: input.name().to_string(),
                dimensions: Vec::new(),
            })
        })
    }

    /// Get information about the model's output shape.
    ///
    /// Returns `None` if the session lock is poisoned or the model declares
    /// no outputs. `dimensions` is currently always empty.
    #[must_use]
    pub fn output_info(&self) -> Option<OutputInfo> {
        self.session.lock().ok().and_then(|session| {
            session.outputs().first().map(|output| OutputInfo {
                name: output.name().to_string(),
                dimensions: Vec::new(),
            })
        })
    }
}
/// Information about model input.
#[derive(Debug, Clone)]
pub struct InputInfo {
    /// Input tensor name
    pub name: String,
    /// Expected dimensions — currently always left empty by
    /// `OnnxInference::input_info`, which does not populate shapes.
    pub dimensions: Vec<usize>,
}
/// Information about model output.
#[derive(Debug, Clone)]
pub struct OutputInfo {
    /// Output tensor name
    pub name: String,
    /// Output dimensions — currently always left empty by
    /// `OnnxInference::output_info`, which does not populate shapes.
    pub dimensions: Vec<usize>,
}
impl std::fmt::Debug for OnnxInference {
    /// Hand-written `Debug`: the ONNX session is not `Debug`, so only the
    /// GPU flag and tensor names are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let gpu_enabled = self.is_gpu();
        f.debug_struct("OnnxInference")
            .field("gpu_enabled", &gpu_enabled)
            .field("input_name", &self.input_name)
            .field("output_name", &self.output_name)
            .finish()
    }
}
/// Configuration for ONNX inference.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Number of threads for intra-op parallelism
    pub intra_op_threads: usize,
    /// Number of threads for inter-op parallelism
    pub inter_op_threads: usize,
    /// Execution providers in priority order
    pub providers: Vec<ExecutionProvider>,
    /// Whether to enable memory optimization
    pub optimize_memory: bool,
    /// Maximum batch size for inference — NOTE(review): not enforced by
    /// `OnnxInference::run_batch` itself; presumably honored by callers.
    pub max_batch_size: usize,
}
impl Default for InferenceConfig {
    /// Balanced defaults: accelerators preferred over CPU, memory
    /// optimization on, moderate batch size.
    fn default() -> Self {
        // CPU last so it acts as the guaranteed fallback.
        let providers = vec![
            ExecutionProvider::Cuda { device_id: 0 },
            ExecutionProvider::CoreML,
            ExecutionProvider::Cpu,
        ];
        Self {
            // Cap intra-op threads at 4 to avoid oversubscription.
            intra_op_threads: num_cpus::get().min(4),
            inter_op_threads: 1,
            providers,
            optimize_memory: true,
            max_batch_size: 32,
        }
    }
}
impl InferenceConfig {
    /// Preset for low-power field hardware: CPU-only, single-item batches,
    /// minimal threading, memory optimization enabled.
    #[must_use]
    pub fn field_device() -> Self {
        Self {
            max_batch_size: 1,
            optimize_memory: true,
            providers: vec![ExecutionProvider::Cpu],
            inter_op_threads: 1,
            intra_op_threads: 2,
        }
    }

    /// Preset for server deployment: CUDA with CPU fallback, large batches,
    /// more threads, memory optimization traded for throughput.
    #[must_use]
    pub fn server() -> Self {
        let providers = vec![
            ExecutionProvider::Cuda { device_id: 0 },
            ExecutionProvider::Cpu,
        ];
        Self {
            max_batch_size: 64,
            optimize_memory: false,
            providers,
            inter_op_threads: 2,
            intra_op_threads: 4,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config must have usable threads and at least one provider.
    #[test]
    fn test_inference_config_default() {
        let config = InferenceConfig::default();
        assert!(config.intra_op_threads > 0);
        assert!(!config.providers.is_empty());
    }

    /// Field-device preset: minimal threading, single-item batches.
    #[test]
    fn test_inference_config_field_device() {
        let config = InferenceConfig::field_device();
        assert_eq!(config.intra_op_threads, 2);
        assert_eq!(config.max_batch_size, 1);
        assert!(config.optimize_memory);
    }

    /// Server preset: large batches, memory optimization off.
    #[test]
    fn test_inference_config_server() {
        let config = InferenceConfig::server();
        assert_eq!(config.max_batch_size, 64);
        assert!(!config.optimize_memory);
    }

    /// Sanity check of the shape constants used by input validation.
    #[test]
    fn test_input_validation() {
        let valid_shape = vec![1, MEL_FRAMES, MEL_BINS];
        let invalid_shape = vec![1, 100, 100];
        assert_eq!(valid_shape[1], MEL_FRAMES);
        assert_eq!(valid_shape[2], MEL_BINS);
        assert_ne!(invalid_shape[1], MEL_FRAMES);
    }
}

View File

@@ -0,0 +1,143 @@
//! # sevensense-embedding
//!
//! Embedding bounded context for 7sense bioacoustics platform.
//!
//! This crate provides Perch 2.0 ONNX integration for generating 1536-dimensional
//! embeddings from preprocessed audio segments. It handles model loading, inference,
//! normalization, and quantization for efficient storage and retrieval.
//!
//! ## Architecture
//!
//! The crate follows Domain-Driven Design (DDD) principles:
//!
//! - **Domain Layer**: Core entities (`Embedding`, `EmbeddingModel`) and repository traits
//! - **Application Layer**: Services for embedding generation and batch processing
//! - **Infrastructure Layer**: ONNX Runtime integration and model management
//!
//! ## Usage
//!
//! ```rust,ignore
//! use sevensense_embedding::{
//! EmbeddingService, ModelManager, ModelConfig,
//! domain::Embedding,
//! };
//!
//! // Initialize model manager
//! let config = ModelConfig::default();
//! let model_manager = ModelManager::new(config)?;
//!
//! // Create embedding service
//! let service = EmbeddingService::new(model_manager, 8);
//!
//! // Generate embedding from spectrogram
//! let embedding = service.embed_segment(&spectrogram).await?;
//! ```
//!
//! ## Features
//!
//! - **Perch 2.0 Integration**: Full support for EfficientNet-B3 bioacoustic embeddings
//! - **Batch Processing**: Efficient batch inference with configurable batch sizes
//! - **Model Hot-Swap**: Update models without service restart
//! - **Quantization**: F16 and INT8 quantization for reduced storage
//! - **Validation**: Comprehensive embedding validation (NaN detection, dimension checks)
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
pub mod domain;
pub mod application;
pub mod infrastructure;
pub mod normalization;
pub mod quantization;
// Re-export main types for convenience
pub use domain::entities::{
Embedding, EmbeddingId, EmbeddingModel, EmbeddingMetadata,
StorageTier, ModelVersion, InputSpecification,
};
pub use domain::repository::EmbeddingRepository;
pub use application::services::EmbeddingService;
pub use infrastructure::model_manager::{ModelManager, ModelConfig};
pub use infrastructure::onnx_inference::OnnxInference;
/// Embedding dimension for Perch 2.0 model (one 1536-D vector per window)
pub const EMBEDDING_DIM: usize = 1536;
/// Target sample rate for Perch 2.0 (32kHz)
pub const TARGET_SAMPLE_RATE: u32 = 32000;
/// Target window duration in seconds (5s)
pub const TARGET_WINDOW_SECONDS: f32 = 5.0;
/// Target window samples (160,000 = 5s at 32kHz; keep in sync with
/// `TARGET_SAMPLE_RATE` and `TARGET_WINDOW_SECONDS`)
pub const TARGET_WINDOW_SAMPLES: usize = 160_000;
/// Mel spectrogram bins for Perch 2.0
pub const MEL_BINS: usize = 128;
/// Mel spectrogram frames for Perch 2.0 (time axis of one input window)
pub const MEL_FRAMES: usize = 500;
/// Crate version information (taken from Cargo.toml at compile time)
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Common result type for embedding operations
pub type Result<T> = std::result::Result<T, EmbeddingError>;
/// Unified error type for embedding operations
///
/// Wraps infrastructure-layer errors via `#[from]` conversions together with
/// validation and repository failures so callers can use the crate-wide
/// [`Result`] alias everywhere.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// Model loading or initialization error
    #[error("Model error: {0}")]
    Model(#[from] infrastructure::model_manager::ModelError),
    /// ONNX inference error
    #[error("Inference error: {0}")]
    Inference(#[from] infrastructure::onnx_inference::InferenceError),
    /// Embedding validation error (e.g. NaN values or failed sanity checks)
    #[error("Validation error: {0}")]
    Validation(String),
    /// Invalid input dimensions
    #[error("Invalid dimensions: expected {expected}, got {actual}")]
    InvalidDimensions {
        /// Expected dimension
        expected: usize,
        /// Actual dimension
        actual: usize,
    },
    /// Repository error
    #[error("Repository error: {0}")]
    Repository(String),
    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Checksum verification failed
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumMismatch {
        /// Expected checksum
        expected: String,
        /// Actual checksum
        actual: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pin the Perch 2.0 model constants so an accidental edit is caught.
    #[test]
    fn test_constants() {
        let pinned: [(usize, usize); 4] = [
            (EMBEDDING_DIM, 1536),
            (TARGET_WINDOW_SAMPLES, 160_000),
            (MEL_BINS, 128),
            (MEL_FRAMES, 500),
        ];
        for (actual, expected) in pinned {
            assert_eq!(actual, expected);
        }
        assert_eq!(TARGET_SAMPLE_RATE, 32000);
    }
}

View File

@@ -0,0 +1,460 @@
//! Normalization utilities for embedding vectors.
//!
//! Provides L2 normalization and validation functions to ensure
//! embedding vectors are properly normalized for cosine similarity
//! operations in vector databases.
use crate::EMBEDDING_DIM;
/// L2 normalize an embedding vector in-place.
///
/// After normalization the vector has unit Euclidean length, so cosine
/// similarity against other unit vectors reduces to a plain dot product.
///
/// A degenerate (near-zero) input — typically produced from silent audio —
/// is replaced by the unit basis vector along the first dimension so
/// downstream similarity math never divides by zero.
///
/// # Arguments
///
/// * `embedding` - The embedding vector to normalize in-place
///
/// # Example
///
/// ```rust
/// use sevensense_embedding::normalization::l2_normalize;
///
/// let mut vector = vec![3.0, 4.0];
/// l2_normalize(&mut vector);
/// assert!((vector[0] - 0.6).abs() < 1e-6);
/// assert!((vector[1] - 0.8).abs() < 1e-6);
/// ```
pub fn l2_normalize(embedding: &mut [f32]) {
    let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-12 {
        embedding.iter_mut().for_each(|x| *x /= norm);
    } else {
        // Near-zero input (likely silence): emit the canonical unit vector
        // e_0 instead of propagating a zero/garbage direction.
        for (index, value) in embedding.iter_mut().enumerate() {
            *value = if index == 0 { 1.0 } else { 0.0 };
        }
    }
}
/// Compute the L2 norm of a vector.
///
/// # Arguments
///
/// * `vector` - The vector to compute the norm for
///
/// # Returns
///
/// The L2 norm (Euclidean length) of the vector; `0.0` for an empty slice.
#[must_use]
pub fn compute_norm(vector: &[f32]) -> f32 {
    let mut sum_of_squares = 0.0_f32;
    for &component in vector {
        sum_of_squares += component * component;
    }
    sum_of_squares.sqrt()
}
/// Compute the sparsity of a vector.
///
/// Sparsity is the fraction of near-zero values (|x| < 1e-6) in the vector.
/// High sparsity may indicate issues with the embedding model.
///
/// # Arguments
///
/// * `vector` - The vector to analyze
///
/// # Returns
///
/// Sparsity as a value between 0.0 (no zeros) and 1.0 (all zeros);
/// an empty slice reports 0.0.
#[must_use]
pub fn compute_sparsity(vector: &[f32]) -> f32 {
    let total = vector.len();
    if total == 0 {
        return 0.0;
    }
    let near_zero = vector
        .iter()
        .fold(0_usize, |count, &v| if v.abs() < 1e-6 { count + 1 } else { count });
    near_zero as f32 / total as f32
}
/// Validate an embedding vector for common issues.
///
/// Checks dimension, NaN/Inf contamination, L2 norm (with a ±1% tolerance
/// around unit length) and sparsity, and folds the findings into a
/// [`ValidationResult`].
///
/// # Arguments
///
/// * `embedding` - The embedding vector to validate
///
/// # Returns
///
/// A `ValidationResult` containing detailed information about the vector.
#[must_use]
pub fn validate_embedding(embedding: &[f32]) -> ValidationResult {
    let dimension = embedding.len();
    let dimension_valid = dimension == EMBEDDING_DIM;
    // Both scans short-circuit on the first offending value.
    let has_nan = embedding.iter().any(|x| x.is_nan());
    let has_inf = embedding.iter().any(|x| x.is_infinite());
    let norm = compute_norm(embedding);
    // Tolerate small numeric drift around unit length.
    let is_normalized = (0.99..=1.01).contains(&norm);
    let sparsity = compute_sparsity(embedding);
    ValidationResult {
        dimension,
        dimension_valid,
        norm,
        is_normalized,
        has_nan,
        has_inf,
        sparsity,
        // Normalization/sparsity are reported as issues but do not gate validity.
        is_valid: dimension_valid && !has_nan && !has_inf,
        issues: collect_issues(
            dimension_valid,
            dimension,
            has_nan,
            has_inf,
            is_normalized,
            norm,
            sparsity,
        ),
    }
}
/// Translate the individual validation flags into a list of concrete issues.
fn collect_issues(
    dimension_valid: bool,
    actual_dim: usize,
    has_nan: bool,
    has_inf: bool,
    is_normalized: bool,
    norm: f32,
    sparsity: f32,
) -> Vec<ValidationIssue> {
    let mut found = Vec::new();
    if !dimension_valid {
        found.push(ValidationIssue::InvalidDimension {
            expected: EMBEDDING_DIM,
            actual: actual_dim,
        });
    }
    if has_nan {
        found.push(ValidationIssue::ContainsNaN);
    }
    if has_inf {
        found.push(ValidationIssue::ContainsInfinite);
    }
    // A norm check is meaningless on NaN/Inf data, so only report it for
    // finite vectors.
    let finite = !has_nan && !has_inf;
    if finite && !is_normalized {
        found.push(ValidationIssue::NotNormalized { norm });
    }
    if sparsity > 0.9 {
        found.push(ValidationIssue::HighSparsity { sparsity });
    }
    found
}
/// Result of embedding validation
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Actual dimension of the embedding
    pub dimension: usize,
    /// Whether the dimension matches expected (1536)
    pub dimension_valid: bool,
    /// L2 norm of the embedding
    pub norm: f32,
    /// Whether the embedding is L2 normalized (norm within [0.99, 1.01])
    pub is_normalized: bool,
    /// Whether the embedding contains NaN values
    pub has_nan: bool,
    /// Whether the embedding contains infinite values
    pub has_inf: bool,
    /// Fraction of near-zero values
    pub sparsity: f32,
    /// Overall validity (no NaN, no Inf, correct dimension).
    /// Normalization and sparsity problems are reported in `issues` but do
    /// not affect this flag.
    pub is_valid: bool,
    /// List of specific issues found
    pub issues: Vec<ValidationIssue>,
}
impl ValidationResult {
    /// Check if the embedding passes all validation checks
    /// (stricter than `is_valid`: any recorded issue fails this).
    #[must_use]
    pub fn is_ok(&self) -> bool {
        self.issues.is_empty()
    }
    /// Get a human-readable summary of the validation result
    #[must_use]
    pub fn summary(&self) -> String {
        if self.is_ok() {
            "Embedding is valid".to_string()
        } else {
            let mut parts = Vec::with_capacity(self.issues.len());
            for issue in &self.issues {
                parts.push(issue.to_string());
            }
            format!("Embedding has issues: {}", parts.join(", "))
        }
    }
}
/// Specific validation issues that can be detected
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// Embedding dimension doesn't match expected (1536)
    InvalidDimension {
        /// Expected dimension
        expected: usize,
        /// Actual dimension
        actual: usize,
    },
    /// Embedding contains NaN values
    ContainsNaN,
    /// Embedding contains infinite values
    ContainsInfinite,
    /// Embedding is not L2 normalized (norm outside [0.99, 1.01])
    NotNormalized {
        /// Actual norm
        norm: f32,
    },
    /// Embedding has high sparsity (more than 90% near-zero values)
    HighSparsity {
        /// Sparsity value
        sparsity: f32,
    },
}
impl std::fmt::Display for ValidationIssue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Messages are lowercase fragments so they can be spliced into
        // larger sentences by the caller.
        let message = match self {
            Self::InvalidDimension { expected, actual } => {
                format!("invalid dimension (expected {expected}, got {actual})")
            }
            Self::ContainsNaN => String::from("contains NaN values"),
            Self::ContainsInfinite => String::from("contains infinite values"),
            Self::NotNormalized { norm } => format!("not normalized (norm = {norm:.4})"),
            Self::HighSparsity { sparsity } => {
                format!("high sparsity ({:.1}%)", sparsity * 100.0)
            }
        };
        f.write_str(&message)
    }
}
/// L1 normalize a vector (sum of absolute values = 1)
///
/// Leaves the vector untouched when the absolute sum is effectively zero,
/// since there is no meaningful direction to preserve.
pub fn l1_normalize(embedding: &mut [f32]) {
    let total: f32 = embedding.iter().map(|x| x.abs()).sum();
    if total > 1e-12 {
        embedding.iter_mut().for_each(|x| *x /= total);
    }
}
/// Min-max normalize a vector to [0, 1] range
///
/// A constant (zero-range) vector maps to 0.5 everywhere, the midpoint of
/// the target interval.
pub fn minmax_normalize(embedding: &mut [f32]) {
    if embedding.is_empty() {
        return;
    }
    let mut lo = f32::INFINITY;
    let mut hi = f32::NEG_INFINITY;
    for &value in embedding.iter() {
        lo = lo.min(value);
        hi = hi.max(value);
    }
    let span = hi - lo;
    if span > 1e-12 {
        for value in embedding.iter_mut() {
            *value = (*value - lo) / span;
        }
    } else {
        // All values are the same
        for value in embedding.iter_mut() {
            *value = 0.5;
        }
    }
}
/// Z-score normalize a vector (mean = 0, std = 1)
///
/// Uses the population standard deviation (divide by N). A constant vector
/// (zero variance) is zero-centered instead of divided by ~0.
pub fn zscore_normalize(embedding: &mut [f32]) {
    if embedding.is_empty() {
        return;
    }
    let count = embedding.len() as f32;
    let mean = embedding.iter().sum::<f32>() / count;
    let variance = embedding.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / count;
    let std_dev = variance.sqrt();
    if std_dev > 1e-12 {
        embedding.iter_mut().for_each(|x| *x = (*x - mean) / std_dev);
    } else {
        // Zero variance - all values are the same
        embedding.iter_mut().for_each(|x| *x = 0.0);
    }
}
/// Clamp values to a specified range
pub fn clamp(embedding: &mut [f32], min: f32, max: f32) {
    embedding.iter_mut().for_each(|x| *x = x.clamp(min, max));
}
/// Soft clipping using tanh
///
/// Values are squashed smoothly into (-scale, scale); small values pass
/// through nearly unchanged because tanh is ~linear near zero.
pub fn soft_clip(embedding: &mut [f32], scale: f32) {
    for value in embedding.iter_mut() {
        *value = scale * (*value / scale).tanh();
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the normalization helpers. The 1e-6 tolerances reflect
    // f32 rounding, not algorithmic slack.
    use super::*;
    #[test]
    fn test_l2_normalize() {
        // 3-4-5 triangle: components become 0.6 and 0.8 with unit norm.
        let mut vector = vec![3.0, 4.0];
        l2_normalize(&mut vector);
        assert!((vector[0] - 0.6).abs() < 1e-6);
        assert!((vector[1] - 0.8).abs() < 1e-6);
        let norm = compute_norm(&vector);
        assert!((norm - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_l2_normalize_zero_vector() {
        // Zero input falls back to the canonical unit vector e_0.
        let mut vector = vec![0.0, 0.0, 0.0];
        l2_normalize(&mut vector);
        assert_eq!(vector[0], 1.0);
        assert_eq!(vector[1], 0.0);
        assert_eq!(vector[2], 0.0);
    }
    #[test]
    fn test_compute_norm() {
        let vector = vec![3.0, 4.0];
        let norm = compute_norm(&vector);
        assert!((norm - 5.0).abs() < 1e-6);
    }
    #[test]
    fn test_compute_sparsity() {
        // 3 of 5 entries are near-zero.
        let vector = vec![0.0, 1.0, 0.0, 2.0, 0.0];
        let sparsity = compute_sparsity(&vector);
        assert!((sparsity - 0.6).abs() < 1e-6);
    }
    #[test]
    fn test_compute_sparsity_empty() {
        let vector: Vec<f32> = vec![];
        let sparsity = compute_sparsity(&vector);
        assert_eq!(sparsity, 0.0);
    }
    #[test]
    fn test_validate_embedding_valid() {
        // Unit vector of the correct dimension: valid and normalized
        // (HighSparsity may still be reported, which does not affect is_valid).
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = 1.0;
        let result = validate_embedding(&vector);
        assert!(result.is_valid);
        assert!(result.is_normalized);
        assert!(!result.has_nan);
        assert!(!result.has_inf);
    }
    #[test]
    fn test_validate_embedding_wrong_dimension() {
        let vector = vec![1.0; 100];
        let result = validate_embedding(&vector);
        assert!(!result.dimension_valid);
        assert!(result.issues.iter().any(|i| matches!(i, ValidationIssue::InvalidDimension { .. })));
    }
    #[test]
    fn test_validate_embedding_nan() {
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = f32::NAN;
        let result = validate_embedding(&vector);
        assert!(result.has_nan);
        assert!(!result.is_valid);
    }
    #[test]
    fn test_validate_embedding_infinite() {
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = f32::INFINITY;
        let result = validate_embedding(&vector);
        assert!(result.has_inf);
        assert!(!result.is_valid);
    }
    #[test]
    fn test_l1_normalize() {
        let mut vector = vec![1.0, 2.0, 3.0];
        l1_normalize(&mut vector);
        let sum: f32 = vector.iter().map(|x| x.abs()).sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_minmax_normalize() {
        let mut vector = vec![0.0, 5.0, 10.0];
        minmax_normalize(&mut vector);
        assert!((vector[0] - 0.0).abs() < 1e-6);
        assert!((vector[1] - 0.5).abs() < 1e-6);
        assert!((vector[2] - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_zscore_normalize() {
        let mut vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        zscore_normalize(&mut vector);
        let mean: f32 = vector.iter().sum::<f32>() / vector.len() as f32;
        assert!(mean.abs() < 1e-6);
    }
    #[test]
    fn test_clamp() {
        let mut vector = vec![-2.0, 0.5, 2.0];
        clamp(&mut vector, -1.0, 1.0);
        assert_eq!(vector, vec![-1.0, 0.5, 1.0]);
    }
    #[test]
    fn test_soft_clip() {
        let mut vector = vec![0.0, 1.0, 2.0];
        soft_clip(&mut vector, 1.0);
        assert!((vector[0] - 0.0).abs() < 1e-6);
        // tanh(1) ≈ 0.7616
        assert!(vector[1] > 0.5 && vector[1] < 0.8);
        // tanh(2) ≈ 0.964
        assert!(vector[2] > 0.9 && vector[2] < 1.0);
    }
    #[test]
    fn test_validation_result_summary() {
        // Create a reasonably distributed embedding (not too sparse)
        let mut vector = vec![0.0; EMBEDDING_DIM];
        // Fill first half with small values that sum to norm 1.0
        let val = 1.0 / (EMBEDDING_DIM as f32 / 2.0).sqrt();
        for i in 0..EMBEDDING_DIM / 2 {
            vector[i] = val;
        }
        let result = validate_embedding(&vector);
        assert!(result.summary().contains("valid"), "Summary: {}", result.summary());
    }
}

View File

@@ -0,0 +1,562 @@
//! Quantization utilities for embedding storage optimization.
//!
//! Provides F16 and INT8 quantization for reduced storage footprint
//! while maintaining acceptable precision for similarity search.
//!
//! ## Storage Comparison
//!
//! | Format | Bytes/Dim | Total (1536-D) | Precision Loss |
//! |--------|-----------|----------------|----------------|
//! | f32 | 4 | 6,144 bytes | None (baseline)|
//! | f16 | 2 | 3,072 bytes | ~0.1% typical |
//! | i8 | 1 | 1,536 bytes | ~1-2% typical |
use half::f16;
use serde::{Deserialize, Serialize};
/// Quantize f32 embedding to f16 (half precision).
///
/// F16 provides 50% storage reduction with minimal precision loss.
/// Suitable for warm storage tier.
///
/// # Arguments
///
/// * `embedding` - The f32 embedding vector to quantize
///
/// # Returns
///
/// Vector of f16 values, one per input element
///
/// # Example
///
/// ```rust
/// use sevensense_embedding::quantization::quantize_to_f16;
///
/// let embedding = vec![0.5, -0.3, 0.8];
/// let quantized = quantize_to_f16(&embedding);
/// assert_eq!(quantized.len(), embedding.len());
/// ```
#[must_use]
pub fn quantize_to_f16(embedding: &[f32]) -> Vec<f16> {
    let mut halves = Vec::with_capacity(embedding.len());
    for &value in embedding {
        halves.push(f16::from_f32(value));
    }
    halves
}
/// Dequantize f16 embedding back to f32.
///
/// # Arguments
///
/// * `quantized` - The f16 quantized embedding
///
/// # Returns
///
/// Vector of f32 values
#[must_use]
pub fn dequantize_f16(quantized: &[f16]) -> Vec<f32> {
quantized.iter().map(|&x| x.to_f32()).collect()
}
/// Quantization parameters for INT8 quantization
///
/// Produced by [`QuantizationParams::from_data`] (symmetric, `zero_point`
/// fixed at 0.0) or [`QuantizationParams::from_data_asymmetric`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Scale factor for quantization
    pub scale: f32,
    /// Zero point for asymmetric quantization (0.0 for symmetric)
    pub zero_point: f32,
    /// Minimum value in the original data
    pub min_val: f32,
    /// Maximum value in the original data
    pub max_val: f32,
}
impl QuantizationParams {
    /// Identity parameters used when there is no data to derive a range from.
    fn identity() -> Self {
        Self {
            scale: 1.0,
            zero_point: 0.0,
            min_val: 0.0,
            max_val: 0.0,
        }
    }
    /// Minimum and maximum of `data`, or `None` when `data` is empty.
    fn range_of(data: &[f32]) -> Option<(f32, f32)> {
        if data.is_empty() {
            return None;
        }
        let mut lo = f32::INFINITY;
        let mut hi = f32::NEG_INFINITY;
        for &value in data {
            lo = lo.min(value);
            hi = hi.max(value);
        }
        Some((lo, hi))
    }
    /// Compute quantization parameters from data
    ///
    /// Symmetric around zero (better for L2-normalized embeddings): the
    /// largest magnitude maps to ±127 and the zero point stays at 0.0.
    #[must_use]
    pub fn from_data(data: &[f32]) -> Self {
        let (min_val, max_val) = match Self::range_of(data) {
            Some(bounds) => bounds,
            None => return Self::identity(),
        };
        let abs_max = min_val.abs().max(max_val.abs());
        let scale = if abs_max > 1e-12 { abs_max / 127.0 } else { 1.0 };
        Self {
            scale,
            zero_point: 0.0,
            min_val,
            max_val,
        }
    }
    /// Compute quantization parameters for asymmetric quantization
    ///
    /// Maps the full `[min, max]` range onto 256 levels; the zero point is
    /// chosen so that `min_val` maps to 0 in quantized space.
    #[must_use]
    pub fn from_data_asymmetric(data: &[f32]) -> Self {
        let (min_val, max_val) = match Self::range_of(data) {
            Some(bounds) => bounds,
            None => return Self::identity(),
        };
        let range = max_val - min_val;
        let scale = if range > 1e-12 { range / 255.0 } else { 1.0 };
        let zero_point = -min_val / scale;
        Self {
            scale,
            zero_point,
            min_val,
            max_val,
        }
    }
}
/// Result of INT8 quantization including the quantized values and parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedEmbedding {
    /// Quantized INT8 values (one byte per dimension)
    pub values: Vec<i8>,
    /// Quantization parameters for dequantization
    pub params: QuantizationParams,
}
impl QuantizedEmbedding {
/// Get the storage size in bytes
#[must_use]
pub fn size_bytes(&self) -> usize {
// i8 values + params overhead
self.values.len() + std::mem::size_of::<QuantizationParams>()
}
/// Dequantize back to f32
#[must_use]
pub fn dequantize(&self) -> Vec<f32> {
dequantize_i8(&self.values, self.params.scale, self.params.zero_point)
}
}
/// Quantize f32 embedding to INT8 with scale and zero point.
///
/// INT8 provides 75% storage reduction but with some precision loss.
/// Suitable for cold storage tier or large-scale deployments.
///
/// Symmetric quantization: the largest magnitude maps to ±127 and the
/// returned zero point is always 0.0.
///
/// # Arguments
///
/// * `embedding` - The f32 embedding vector to quantize
///
/// # Returns
///
/// Tuple of (quantized values, scale, zero_point)
///
/// # Example
///
/// ```rust
/// use sevensense_embedding::quantization::quantize_to_i8;
///
/// let embedding = vec![0.5, -0.3, 0.8, -0.1];
/// let (quantized, scale, zero_point) = quantize_to_i8(&embedding);
/// assert_eq!(quantized.len(), embedding.len());
/// ```
#[must_use]
pub fn quantize_to_i8(embedding: &[f32]) -> (Vec<i8>, f32, f32) {
    // Derive the symmetric scale exactly as `QuantizationParams::from_data`
    // does: largest magnitude / 127, with 1.0 for empty or all-zero input.
    let scale = if embedding.is_empty() {
        1.0
    } else {
        let min_val = embedding.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = embedding.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let abs_max = min_val.abs().max(max_val.abs());
        if abs_max > 1e-12 { abs_max / 127.0 } else { 1.0 }
    };
    let quantized: Vec<i8> = embedding
        .iter()
        .map(|&x| (x / scale).round().clamp(-128.0, 127.0) as i8)
        .collect();
    // Symmetric quantization keeps zero at zero.
    (quantized, scale, 0.0)
}
/// Quantize f32 embedding to INT8 with full quantization info.
///
/// Same symmetric scheme as [`quantize_to_i8`], but bundles the values with
/// their [`QuantizationParams`] so they can round-trip via
/// [`QuantizedEmbedding::dequantize`].
///
/// # Arguments
///
/// * `embedding` - The f32 embedding vector to quantize
///
/// # Returns
///
/// QuantizedEmbedding containing values and parameters
#[must_use]
pub fn quantize_to_i8_full(embedding: &[f32]) -> QuantizedEmbedding {
    let params = QuantizationParams::from_data(embedding);
    let mut values = Vec::with_capacity(embedding.len());
    for &x in embedding {
        let scaled = (x / params.scale).round().clamp(-128.0, 127.0);
        values.push(scaled as i8);
    }
    QuantizedEmbedding { values, params }
}
/// Dequantize INT8 embedding back to f32.
///
/// # Arguments
///
/// * `quantized` - The INT8 quantized values
/// * `scale` - Scale factor used during quantization
/// * `zero_point` - Zero point used during quantization
///
/// # Returns
///
/// Vector of f32 values
///
/// # Example
///
/// ```rust
/// use sevensense_embedding::quantization::{quantize_to_i8, dequantize_i8};
///
/// let embedding = vec![0.5, -0.3, 0.8, -0.1];
/// let (quantized, scale, zero_point) = quantize_to_i8(&embedding);
/// let restored = dequantize_i8(&quantized, scale, zero_point);
///
/// // Check that values are close (within quantization error)
/// for (orig, rest) in embedding.iter().zip(restored.iter()) {
///     assert!((orig - rest).abs() < 0.05);
/// }
/// ```
#[must_use]
pub fn dequantize_i8(quantized: &[i8], scale: f32, zero_point: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    for &q in quantized {
        restored.push((f32::from(q) - zero_point) * scale);
    }
    restored
}
/// Quantize to unsigned INT8 (0-255 range) for asymmetric quantization
///
/// Maps `[min, max]` onto `[0, 255]`; `min` maps to 0 via the zero point.
/// Returns (quantized values, scale, zero_point).
#[must_use]
pub fn quantize_to_u8(embedding: &[f32]) -> (Vec<u8>, f32, f32) {
    // Derive the range mapping exactly as
    // `QuantizationParams::from_data_asymmetric` does.
    let (scale, zero_point) = if embedding.is_empty() {
        (1.0, 0.0)
    } else {
        let min_val = embedding.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = embedding.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let range = max_val - min_val;
        let scale = if range > 1e-12 { range / 255.0 } else { 1.0 };
        (scale, -min_val / scale)
    };
    let quantized: Vec<u8> = embedding
        .iter()
        .map(|&x| (x / scale + zero_point).round().clamp(0.0, 255.0) as u8)
        .collect();
    (quantized, scale, zero_point)
}
/// Dequantize unsigned INT8 back to f32
pub fn dequantize_u8(quantized: &[u8], scale: f32, zero_point: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    for &q in quantized {
        restored.push((f32::from(q) - zero_point) * scale);
    }
    restored
}
/// Compute quantization error (MSE) between original and dequantized values
///
/// Mismatched lengths or empty input have no meaningful error and are
/// signalled with `NaN`.
#[must_use]
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::NAN;
    }
    let squared_sum: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let diff = a - b;
            diff * diff
        })
        .sum();
    squared_sum / original.len() as f32
}
/// Compute cosine similarity preservation after quantization
///
/// Returns the ratio of cosine similarities (quantized / original).
/// Near-orthogonal originals make the ratio numerically unstable, so a
/// near-zero original similarity reports 1.0 (treated as preserved).
#[must_use]
pub fn compute_cosine_preservation(
    original_a: &[f32],
    original_b: &[f32],
    dequant_a: &[f32],
    dequant_b: &[f32],
) -> f32 {
    let before = cosine_similarity(original_a, original_b);
    let after = cosine_similarity(dequant_a, dequant_b);
    if before.abs() < 1e-12 {
        1.0
    } else {
        after / before
    }
}
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 for mismatched lengths or (near-)zero-norm inputs, where the
/// similarity is undefined.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denominator = norm_a * norm_b;
    if denominator < 1e-12 {
        0.0
    } else {
        dot / denominator
    }
}
/// Statistics about quantization quality
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Mean squared error (NaN when inputs were empty or mismatched)
    pub mse: f32,
    /// Root mean squared error
    pub rmse: f32,
    /// Maximum absolute error
    pub max_error: f32,
    /// Mean absolute error
    pub mean_error: f32,
    /// Compression ratio (original size / quantized size)
    pub compression_ratio: f32,
}
impl QuantizationStats {
    /// Compute statistics comparing original and dequantized embeddings
    ///
    /// `quantized_bytes` is the serialized size of the quantized form and is
    /// only used for the compression ratio.
    #[must_use]
    pub fn compute(original: &[f32], dequantized: &[f32], quantized_bytes: usize) -> Self {
        let mse = compute_quantization_error(original, dequantized);
        // Accumulate absolute errors in a single pass over the paired values.
        let mut max_error = 0.0_f32;
        let mut error_sum = 0.0_f32;
        for (a, b) in original.iter().zip(dequantized.iter()) {
            let err = (a - b).abs();
            max_error = max_error.max(err);
            error_sum += err;
        }
        let pair_count = original.len().min(dequantized.len());
        let mean_error = error_sum / pair_count.max(1) as f32;
        let original_bytes = original.len() * std::mem::size_of::<f32>();
        let compression_ratio = original_bytes as f32 / quantized_bytes.max(1) as f32;
        Self {
            mse,
            rmse: mse.sqrt(),
            max_error,
            mean_error,
            compression_ratio,
        }
    }
}
/// Batch quantization for multiple embeddings
///
/// NOTE(review): `symmetric` is currently informational only — the actual
/// quantization path is selected solely by `precision` in `quantize_batch`;
/// confirm before relying on this flag.
pub struct BatchQuantizer {
    /// Whether to use symmetric quantization
    pub symmetric: bool,
    /// Target precision
    pub precision: QuantizationPrecision,
}
/// Supported quantization precisions
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationPrecision {
    /// 16-bit floating point
    F16,
    /// 8-bit signed integer (symmetric)
    Int8,
    /// 8-bit unsigned integer (asymmetric)
    UInt8,
}
impl BatchQuantizer {
/// Create a new batch quantizer
#[must_use]
pub fn new(precision: QuantizationPrecision) -> Self {
Self {
symmetric: matches!(precision, QuantizationPrecision::Int8),
precision,
}
}
/// Quantize a batch of embeddings
pub fn quantize_batch(&self, embeddings: &[Vec<f32>]) -> Vec<QuantizedEmbedding> {
embeddings
.iter()
.map(|emb| match self.precision {
QuantizationPrecision::F16 => {
let f16_vals = quantize_to_f16(emb);
// Store f16 as i8 pairs for uniform interface
let bytes: Vec<i8> = f16_vals
.iter()
.flat_map(|v| {
let bits = v.to_bits();
[(bits & 0xFF) as i8, ((bits >> 8) & 0xFF) as i8]
})
.collect();
QuantizedEmbedding {
values: bytes,
params: QuantizationParams {
scale: 1.0,
zero_point: 0.0,
min_val: 0.0,
max_val: 0.0,
},
}
}
QuantizationPrecision::Int8 | QuantizationPrecision::UInt8 => {
quantize_to_i8_full(emb)
}
})
.collect()
}
}
#[cfg(test)]
mod tests {
    // Round-trip and quality tests for the quantization helpers. The
    // tolerances encode the expected worst-case error of each precision.
    use super::*;
    #[test]
    fn test_f16_roundtrip() {
        let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 1.0, -1.0];
        let quantized = quantize_to_f16(&original);
        let restored = dequantize_f16(&quantized);
        for (orig, rest) in original.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.01, "F16 roundtrip error too large");
        }
    }
    #[test]
    fn test_i8_roundtrip() {
        let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.9, -0.9];
        let (quantized, scale, zero_point) = quantize_to_i8(&original);
        let restored = dequantize_i8(&quantized, scale, zero_point);
        for (orig, rest) in original.iter().zip(restored.iter()) {
            // INT8 has larger quantization error
            assert!((orig - rest).abs() < 0.02, "I8 roundtrip error too large");
        }
    }
    #[test]
    fn test_u8_roundtrip() {
        let original = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        let (quantized, scale, zero_point) = quantize_to_u8(&original);
        let restored = dequantize_u8(&quantized, scale, zero_point);
        for (orig, rest) in original.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.02, "U8 roundtrip error too large");
        }
    }
    #[test]
    fn test_quantization_params() {
        // Symmetric params must still record the true data min/max.
        let data = vec![-0.5, 0.0, 0.5, 1.0];
        let params = QuantizationParams::from_data(&data);
        assert!(params.scale > 0.0);
        assert_eq!(params.min_val, -0.5);
        assert_eq!(params.max_val, 1.0);
    }
    #[test]
    fn test_quantization_error() {
        let original = vec![0.5, -0.3, 0.8];
        let modified = vec![0.51, -0.29, 0.79];
        let error = compute_quantization_error(&original, &modified);
        assert!(error < 0.001);
    }
    #[test]
    fn test_cosine_preservation() {
        let a = vec![0.6, 0.8, 0.0];
        let b = vec![0.0, 0.6, 0.8];
        // Slightly perturbed versions
        let a_quant = vec![0.61, 0.79, 0.01];
        let b_quant = vec![0.01, 0.59, 0.81];
        let preservation = compute_cosine_preservation(&a, &b, &a_quant, &b_quant);
        // Should be close to 1.0 if quantization preserves cosine similarity
        assert!(preservation > 0.95 && preservation < 1.05);
    }
    #[test]
    fn test_quantization_stats() {
        let original = vec![0.5, -0.3, 0.8, -0.1];
        let (quantized, scale, zero_point) = quantize_to_i8(&original);
        let restored = dequantize_i8(&quantized, scale, zero_point);
        let stats = QuantizationStats::compute(&original, &restored, quantized.len());
        assert!(stats.mse >= 0.0);
        assert!(stats.rmse >= 0.0);
        assert!(stats.compression_ratio > 1.0); // Should compress
    }
    #[test]
    fn test_batch_quantizer() {
        let embeddings = vec![
            vec![0.5, -0.3, 0.8],
            vec![-0.1, 0.2, 0.9],
        ];
        let quantizer = BatchQuantizer::new(QuantizationPrecision::Int8);
        let quantized = quantizer.quantize_batch(&embeddings);
        assert_eq!(quantized.len(), 2);
        assert_eq!(quantized[0].values.len(), 3);
    }
    #[test]
    fn test_quantized_embedding_dequantize() {
        let original = vec![0.5, -0.3, 0.8, -0.1];
        let quantized = quantize_to_i8_full(&original);
        let restored = quantized.dequantize();
        assert_eq!(restored.len(), original.len());
        for (orig, rest) in original.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.02);
        }
    }
    #[test]
    fn test_empty_input() {
        // Empty input must not panic and must fall back to identity params.
        let empty: Vec<f32> = vec![];
        let f16_result = quantize_to_f16(&empty);
        assert!(f16_result.is_empty());
        let (i8_result, _, _) = quantize_to_i8(&empty);
        assert!(i8_result.is_empty());
        let params = QuantizationParams::from_data(&empty);
        assert_eq!(params.scale, 1.0);
    }
}