Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
6
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/application/mod.rs
vendored
Normal file
6
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/application/mod.rs
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
//! Application layer for embedding bounded context.
|
||||
//!
|
||||
//! Contains application services that orchestrate domain logic
|
||||
//! and coordinate between domain entities and infrastructure.
|
||||
|
||||
pub mod services;
|
||||
567
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/application/services.rs
vendored
Normal file
567
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/application/services.rs
vendored
Normal file
@@ -0,0 +1,567 @@
|
||||
//! Application services for embedding generation.
|
||||
//!
|
||||
//! Provides high-level services for generating embeddings from audio
|
||||
//! spectrograms using the Perch 2.0 ONNX model.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use ndarray::Array3;
|
||||
use rayon::prelude::*;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
use crate::domain::entities::{
|
||||
Embedding, EmbeddingBatch, EmbeddingMetadata, SegmentId, StorageTier,
|
||||
};
|
||||
use crate::infrastructure::model_manager::ModelManager;
|
||||
use crate::normalization;
|
||||
use crate::{EmbeddingError, EMBEDDING_DIM, MEL_BINS, MEL_FRAMES};
|
||||
|
||||
/// Input spectrogram for embedding generation.
|
||||
///
|
||||
/// Represents a mel spectrogram with shape [1, MEL_FRAMES, MEL_BINS] = [1, 500, 128].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Spectrogram {
|
||||
/// The spectrogram data as a 3D array [batch, frames, bins]
|
||||
pub data: Array3<f32>,
|
||||
|
||||
/// Associated segment ID
|
||||
pub segment_id: SegmentId,
|
||||
|
||||
/// Additional metadata
|
||||
pub metadata: SpectrogramMetadata,
|
||||
}
|
||||
|
||||
/// Metadata about the spectrogram
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SpectrogramMetadata {
|
||||
/// Sample rate of the original audio
|
||||
pub sample_rate: Option<u32>,
|
||||
|
||||
/// Duration of the audio segment in seconds
|
||||
pub duration_secs: Option<f32>,
|
||||
|
||||
/// SNR of the audio segment
|
||||
pub snr: Option<f32>,
|
||||
}
|
||||
|
||||
impl Spectrogram {
|
||||
/// Create a new spectrogram from raw data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - 2D array of shape [MEL_FRAMES, MEL_BINS] (will be expanded to 3D)
|
||||
/// * `segment_id` - ID of the source audio segment
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the data dimensions are incorrect.
|
||||
pub fn new(
|
||||
data: ndarray::Array2<f32>,
|
||||
segment_id: SegmentId,
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
let shape = data.shape();
|
||||
if shape[0] != MEL_FRAMES || shape[1] != MEL_BINS {
|
||||
return Err(EmbeddingError::InvalidDimensions {
|
||||
expected: MEL_FRAMES * MEL_BINS,
|
||||
actual: shape[0] * shape[1],
|
||||
});
|
||||
}
|
||||
|
||||
// Expand to 3D: [1, frames, bins]
|
||||
let data = data.insert_axis(ndarray::Axis(0));
|
||||
|
||||
Ok(Self {
|
||||
data,
|
||||
segment_id,
|
||||
metadata: SpectrogramMetadata::default(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from a 3D array directly
|
||||
pub fn from_array3(data: Array3<f32>, segment_id: SegmentId) -> Result<Self, EmbeddingError> {
|
||||
let shape = data.shape();
|
||||
if shape[1] != MEL_FRAMES || shape[2] != MEL_BINS {
|
||||
return Err(EmbeddingError::InvalidDimensions {
|
||||
expected: MEL_FRAMES * MEL_BINS,
|
||||
actual: shape[1] * shape[2],
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
data,
|
||||
segment_id,
|
||||
metadata: SpectrogramMetadata::default(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Set metadata for the spectrogram
|
||||
pub fn with_metadata(mut self, metadata: SpectrogramMetadata) -> Self {
|
||||
self.metadata = metadata;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Output from the embedding service
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingOutput {
|
||||
/// The generated embedding
|
||||
pub embedding: Embedding,
|
||||
|
||||
/// Whether GPU was used for inference
|
||||
pub gpu_used: bool,
|
||||
|
||||
/// Inference latency in milliseconds
|
||||
pub latency_ms: f32,
|
||||
}
|
||||
|
||||
/// Configuration for the embedding service
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingServiceConfig {
|
||||
/// Maximum batch size for inference
|
||||
pub batch_size: usize,
|
||||
|
||||
/// Whether to L2 normalize embeddings
|
||||
pub normalize: bool,
|
||||
|
||||
/// Default storage tier for new embeddings
|
||||
pub default_tier: StorageTier,
|
||||
|
||||
/// Whether to validate embeddings after generation
|
||||
pub validate_embeddings: bool,
|
||||
|
||||
/// Maximum allowed sparsity (fraction of near-zero values)
|
||||
pub max_sparsity: f32,
|
||||
}
|
||||
|
||||
impl Default for EmbeddingServiceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
batch_size: 8,
|
||||
normalize: true,
|
||||
default_tier: StorageTier::Hot,
|
||||
validate_embeddings: true,
|
||||
max_sparsity: 0.9,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Service for generating embeddings from spectrograms.
|
||||
///
|
||||
/// This is the main application service for the embedding bounded context.
|
||||
/// It coordinates between the model manager, ONNX inference, and domain entities.
|
||||
pub struct EmbeddingService {
|
||||
/// Model manager for loading and caching ONNX models
|
||||
model_manager: Arc<ModelManager>,
|
||||
|
||||
/// Configuration for the service
|
||||
config: EmbeddingServiceConfig,
|
||||
}
|
||||
|
||||
impl EmbeddingService {
|
||||
/// Create a new embedding service.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_manager` - The model manager for ONNX model access
|
||||
/// * `batch_size` - Maximum batch size for inference
|
||||
#[must_use]
|
||||
pub fn new(model_manager: Arc<ModelManager>, batch_size: usize) -> Self {
|
||||
Self {
|
||||
model_manager,
|
||||
config: EmbeddingServiceConfig {
|
||||
batch_size,
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom configuration
|
||||
#[must_use]
|
||||
pub fn with_config(model_manager: Arc<ModelManager>, config: EmbeddingServiceConfig) -> Self {
|
||||
Self {
|
||||
model_manager,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate an embedding from a single spectrogram.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `spectrogram` - The input spectrogram
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if inference fails or the embedding is invalid.
|
||||
#[instrument(skip(self, spectrogram), fields(segment_id = %spectrogram.segment_id))]
|
||||
pub async fn embed_segment(
|
||||
&self,
|
||||
spectrogram: &Spectrogram,
|
||||
) -> Result<EmbeddingOutput, EmbeddingError> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Get the inference session
|
||||
let inference = self.model_manager.get_inference().await?;
|
||||
let model_version = self.model_manager.current_version();
|
||||
|
||||
// Run inference
|
||||
let raw_embedding = inference.run(&spectrogram.data)?;
|
||||
|
||||
// Convert to vector
|
||||
let mut vector: Vec<f32> = raw_embedding.iter().copied().collect();
|
||||
|
||||
// Calculate original norm before normalization
|
||||
let original_norm = normalization::compute_norm(&vector);
|
||||
|
||||
// L2 normalize if configured
|
||||
if self.config.normalize {
|
||||
normalization::l2_normalize(&mut vector);
|
||||
}
|
||||
|
||||
// Validate embedding
|
||||
if self.config.validate_embeddings {
|
||||
self.validate_embedding(&vector)?;
|
||||
}
|
||||
|
||||
// Calculate sparsity
|
||||
let sparsity = normalization::compute_sparsity(&vector);
|
||||
|
||||
// Create embedding entity
|
||||
let mut embedding = Embedding::new(
|
||||
spectrogram.segment_id,
|
||||
vector,
|
||||
model_version.full_version(),
|
||||
)?;
|
||||
|
||||
// Set metadata
|
||||
let latency_ms = start.elapsed().as_secs_f32() * 1000.0;
|
||||
embedding.metadata = EmbeddingMetadata {
|
||||
inference_latency_ms: Some(latency_ms),
|
||||
batch_id: None,
|
||||
gpu_used: inference.is_gpu(),
|
||||
original_norm: Some(original_norm),
|
||||
sparsity: Some(sparsity),
|
||||
quality_score: Some(self.compute_quality_score(&embedding)),
|
||||
};
|
||||
|
||||
embedding.tier = self.config.default_tier;
|
||||
|
||||
debug!(
|
||||
latency_ms = latency_ms,
|
||||
norm = embedding.norm(),
|
||||
sparsity = sparsity,
|
||||
"Generated embedding"
|
||||
);
|
||||
|
||||
Ok(EmbeddingOutput {
|
||||
embedding,
|
||||
gpu_used: inference.is_gpu(),
|
||||
latency_ms,
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate embeddings for multiple spectrograms in batches.
|
||||
///
|
||||
/// This is more efficient than calling `embed_segment` multiple times
|
||||
/// as it uses batched inference.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `spectrograms` - Slice of input spectrograms
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if any inference fails. Partial results are not returned.
|
||||
#[instrument(skip(self, spectrograms), fields(count = spectrograms.len()))]
|
||||
pub async fn embed_batch(
|
||||
&self,
|
||||
spectrograms: &[Spectrogram],
|
||||
) -> Result<Vec<EmbeddingOutput>, EmbeddingError> {
|
||||
if spectrograms.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let total_start = Instant::now();
|
||||
let batch_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
info!(
|
||||
batch_id = %batch_id,
|
||||
total_segments = spectrograms.len(),
|
||||
batch_size = self.config.batch_size,
|
||||
"Starting batch embedding"
|
||||
);
|
||||
|
||||
// Get the inference session
|
||||
let inference = self.model_manager.get_inference().await?;
|
||||
let model_version = self.model_manager.current_version();
|
||||
|
||||
// Process in batches
|
||||
let mut all_outputs = Vec::with_capacity(spectrograms.len());
|
||||
|
||||
for (batch_idx, chunk) in spectrograms.chunks(self.config.batch_size).enumerate() {
|
||||
let batch_start = Instant::now();
|
||||
|
||||
// Prepare batch input
|
||||
let inputs: Vec<&Array3<f32>> = chunk.iter().map(|s| &s.data).collect();
|
||||
|
||||
// Run batched inference
|
||||
let raw_embeddings = inference.run_batch(&inputs)?;
|
||||
|
||||
let batch_latency_ms = batch_start.elapsed().as_secs_f32() * 1000.0;
|
||||
let per_item_latency = batch_latency_ms / chunk.len() as f32;
|
||||
|
||||
// Process each embedding in the batch (parallelize normalization)
|
||||
let outputs: Vec<Result<EmbeddingOutput, EmbeddingError>> = chunk
|
||||
.par_iter()
|
||||
.zip(raw_embeddings.par_iter())
|
||||
.map(|(spectrogram, raw_emb)| {
|
||||
let mut vector: Vec<f32> = raw_emb.iter().copied().collect();
|
||||
let original_norm = normalization::compute_norm(&vector);
|
||||
|
||||
if self.config.normalize {
|
||||
normalization::l2_normalize(&mut vector);
|
||||
}
|
||||
|
||||
if self.config.validate_embeddings {
|
||||
self.validate_embedding(&vector)?;
|
||||
}
|
||||
|
||||
let sparsity = normalization::compute_sparsity(&vector);
|
||||
|
||||
let mut embedding = Embedding::new(
|
||||
spectrogram.segment_id,
|
||||
vector,
|
||||
model_version.full_version(),
|
||||
)?;
|
||||
|
||||
embedding.metadata = EmbeddingMetadata {
|
||||
inference_latency_ms: Some(per_item_latency),
|
||||
batch_id: Some(batch_id.clone()),
|
||||
gpu_used: inference.is_gpu(),
|
||||
original_norm: Some(original_norm),
|
||||
sparsity: Some(sparsity),
|
||||
quality_score: Some(self.compute_quality_score(&embedding)),
|
||||
};
|
||||
|
||||
embedding.tier = self.config.default_tier;
|
||||
|
||||
Ok(EmbeddingOutput {
|
||||
embedding,
|
||||
gpu_used: inference.is_gpu(),
|
||||
latency_ms: per_item_latency,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Check for errors
|
||||
let batch_outputs: Result<Vec<_>, _> = outputs.into_iter().collect();
|
||||
all_outputs.extend(batch_outputs?);
|
||||
|
||||
debug!(
|
||||
batch_idx = batch_idx,
|
||||
batch_size = chunk.len(),
|
||||
latency_ms = batch_latency_ms,
|
||||
"Completed batch"
|
||||
);
|
||||
}
|
||||
|
||||
let total_latency_ms = total_start.elapsed().as_secs_f32() * 1000.0;
|
||||
let throughput = spectrograms.len() as f32 / (total_latency_ms / 1000.0);
|
||||
|
||||
info!(
|
||||
batch_id = %batch_id,
|
||||
total_segments = spectrograms.len(),
|
||||
total_latency_ms = total_latency_ms,
|
||||
throughput_per_sec = throughput,
|
||||
"Completed batch embedding"
|
||||
);
|
||||
|
||||
Ok(all_outputs)
|
||||
}
|
||||
|
||||
/// Create a batch tracking object for monitoring progress.
|
||||
#[must_use]
|
||||
pub fn create_batch(&self, segment_ids: Vec<SegmentId>) -> EmbeddingBatch {
|
||||
EmbeddingBatch::new(segment_ids)
|
||||
}
|
||||
|
||||
/// Validate an embedding vector.
|
||||
fn validate_embedding(&self, vector: &[f32]) -> Result<(), EmbeddingError> {
|
||||
// Check dimensions
|
||||
if vector.len() != EMBEDDING_DIM {
|
||||
return Err(EmbeddingError::InvalidDimensions {
|
||||
expected: EMBEDDING_DIM,
|
||||
actual: vector.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for NaN values
|
||||
if vector.iter().any(|x| x.is_nan()) {
|
||||
return Err(EmbeddingError::Validation(
|
||||
"Embedding contains NaN values".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check for infinite values
|
||||
if vector.iter().any(|x| x.is_infinite()) {
|
||||
return Err(EmbeddingError::Validation(
|
||||
"Embedding contains infinite values".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check sparsity
|
||||
let sparsity = normalization::compute_sparsity(vector);
|
||||
if sparsity > self.config.max_sparsity {
|
||||
warn!(
|
||||
sparsity = sparsity,
|
||||
max_sparsity = self.config.max_sparsity,
|
||||
"Embedding has high sparsity"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute a quality score for an embedding.
|
||||
fn compute_quality_score(&self, embedding: &Embedding) -> f32 {
|
||||
let mut score = 1.0_f32;
|
||||
|
||||
// Penalize deviation from unit norm
|
||||
let norm = embedding.norm();
|
||||
let norm_deviation = (norm - 1.0).abs();
|
||||
score -= norm_deviation * 0.5;
|
||||
|
||||
// Penalize high sparsity
|
||||
if let Some(sparsity) = embedding.metadata.sparsity {
|
||||
score -= sparsity * 0.3;
|
||||
}
|
||||
|
||||
score.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Get the current model version being used.
|
||||
#[must_use]
|
||||
pub fn model_version(&self) -> String {
|
||||
self.model_manager.current_version().full_version()
|
||||
}
|
||||
|
||||
/// Check if the service is ready for inference.
|
||||
pub async fn is_ready(&self) -> bool {
|
||||
self.model_manager.is_ready().await
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating embedding service instances
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingServiceBuilder {
|
||||
model_manager: Option<Arc<ModelManager>>,
|
||||
config: EmbeddingServiceConfig,
|
||||
}
|
||||
|
||||
impl EmbeddingServiceBuilder {
|
||||
/// Create a new builder
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
model_manager: None,
|
||||
config: EmbeddingServiceConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the model manager
|
||||
#[must_use]
|
||||
pub fn model_manager(mut self, manager: Arc<ModelManager>) -> Self {
|
||||
self.model_manager = Some(manager);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the batch size
|
||||
#[must_use]
|
||||
pub fn batch_size(mut self, size: usize) -> Self {
|
||||
self.config.batch_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to normalize embeddings
|
||||
#[must_use]
|
||||
pub fn normalize(mut self, normalize: bool) -> Self {
|
||||
self.config.normalize = normalize;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the default storage tier
|
||||
#[must_use]
|
||||
pub fn default_tier(mut self, tier: StorageTier) -> Self {
|
||||
self.config.default_tier = tier;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to validate embeddings
|
||||
#[must_use]
|
||||
pub fn validate_embeddings(mut self, validate: bool) -> Self {
|
||||
self.config.validate_embeddings = validate;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the embedding service
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the model manager is not set.
|
||||
pub fn build(self) -> Result<EmbeddingService, EmbeddingError> {
|
||||
let model_manager = self.model_manager.ok_or_else(|| {
|
||||
EmbeddingError::Validation("Model manager is required".to_string())
|
||||
})?;
|
||||
|
||||
Ok(EmbeddingService::with_config(model_manager, self.config))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbeddingServiceBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ndarray::Array2;
|
||||
|
||||
#[test]
|
||||
fn test_spectrogram_creation() {
|
||||
let data = Array2::zeros((MEL_FRAMES, MEL_BINS));
|
||||
let segment_id = SegmentId::new();
|
||||
let spec = Spectrogram::new(data, segment_id);
|
||||
assert!(spec.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spectrogram_invalid_dimensions() {
|
||||
let data = Array2::zeros((100, 100)); // Wrong dimensions
|
||||
let segment_id = SegmentId::new();
|
||||
let spec = Spectrogram::new(data, segment_id);
|
||||
assert!(spec.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_config_default() {
|
||||
let config = EmbeddingServiceConfig::default();
|
||||
assert_eq!(config.batch_size, 8);
|
||||
assert!(config.normalize);
|
||||
assert!(config.validate_embeddings);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_builder() {
|
||||
let builder = EmbeddingServiceBuilder::new()
|
||||
.batch_size(16)
|
||||
.normalize(false)
|
||||
.default_tier(StorageTier::Warm);
|
||||
|
||||
assert_eq!(builder.config.batch_size, 16);
|
||||
assert!(!builder.config.normalize);
|
||||
assert_eq!(builder.config.default_tier, StorageTier::Warm);
|
||||
}
|
||||
}
|
||||
627
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/domain/entities.rs
vendored
Normal file
627
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/domain/entities.rs
vendored
Normal file
@@ -0,0 +1,627 @@
|
||||
//! Domain entities for the embedding bounded context.
|
||||
//!
|
||||
//! This module defines the core domain entities:
|
||||
//! - `Embedding`: A 1536-dimensional vector representation of an audio segment
|
||||
//! - `EmbeddingModel`: Configuration and metadata for Perch 2.0 ONNX model
|
||||
//! - `EmbeddingBatch`: Collection of embeddings processed together
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::EMBEDDING_DIM;
|
||||
|
||||
/// Unique identifier for an embedding
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct EmbeddingId(Uuid);
|
||||
|
||||
impl EmbeddingId {
|
||||
/// Create a new unique embedding ID
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// Create from an existing UUID
|
||||
#[must_use]
|
||||
pub const fn from_uuid(uuid: Uuid) -> Self {
|
||||
Self(uuid)
|
||||
}
|
||||
|
||||
/// Get the inner UUID
|
||||
#[must_use]
|
||||
pub const fn as_uuid(&self) -> &Uuid {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbeddingId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EmbeddingId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier for an audio segment
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct SegmentId(Uuid);
|
||||
|
||||
impl SegmentId {
|
||||
/// Create a new unique segment ID
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// Create from an existing UUID
|
||||
#[must_use]
|
||||
pub const fn from_uuid(uuid: Uuid) -> Self {
|
||||
Self(uuid)
|
||||
}
|
||||
|
||||
/// Get the inner UUID
|
||||
#[must_use]
|
||||
pub const fn as_uuid(&self) -> &Uuid {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SegmentId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SegmentId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Storage tier for embeddings based on access patterns
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum StorageTier {
|
||||
/// Hot storage: frequently accessed, lowest latency
|
||||
/// Full f32 precision, in-memory or SSD
|
||||
Hot,
|
||||
|
||||
/// Warm storage: occasional access, moderate latency
|
||||
/// F16 quantized, SSD storage
|
||||
Warm,
|
||||
|
||||
/// Cold storage: rare access, higher latency acceptable
|
||||
/// INT8 quantized, archive storage
|
||||
Cold,
|
||||
}
|
||||
|
||||
impl Default for StorageTier {
|
||||
fn default() -> Self {
|
||||
Self::Hot
|
||||
}
|
||||
}
|
||||
|
||||
impl StorageTier {
|
||||
/// Get the bytes per dimension for this tier
|
||||
#[must_use]
|
||||
pub const fn bytes_per_dim(&self) -> usize {
|
||||
match self {
|
||||
Self::Hot => 4, // f32
|
||||
Self::Warm => 2, // f16
|
||||
Self::Cold => 1, // i8
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the total bytes for an embedding in this tier
|
||||
#[must_use]
|
||||
pub const fn embedding_bytes(&self) -> usize {
|
||||
self.bytes_per_dim() * EMBEDDING_DIM
|
||||
}
|
||||
}
|
||||
|
||||
/// Timestamp type alias for consistency
|
||||
pub type Timestamp = DateTime<Utc>;
|
||||
|
||||
/// A 1536-dimensional embedding vector representing an audio segment.
|
||||
///
|
||||
/// This is the aggregate root for the embedding context. Each embedding
|
||||
/// is generated by the Perch 2.0 model from a preprocessed audio segment.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Embedding {
|
||||
/// Unique identifier for this embedding
|
||||
pub id: EmbeddingId,
|
||||
|
||||
/// Reference to the source audio segment
|
||||
pub segment_id: SegmentId,
|
||||
|
||||
/// The 1536-dimensional embedding vector (L2 normalized)
|
||||
pub vector: Vec<f32>,
|
||||
|
||||
/// Version of the model used to generate this embedding
|
||||
pub model_version: String,
|
||||
|
||||
/// When this embedding was created
|
||||
pub created_at: Timestamp,
|
||||
|
||||
/// Storage tier for this embedding
|
||||
pub tier: StorageTier,
|
||||
|
||||
/// Additional metadata about the embedding
|
||||
pub metadata: EmbeddingMetadata,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
/// Create a new embedding with the given parameters
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the vector dimension is not 1536
|
||||
pub fn new(
|
||||
segment_id: SegmentId,
|
||||
vector: Vec<f32>,
|
||||
model_version: String,
|
||||
) -> Result<Self, crate::EmbeddingError> {
|
||||
if vector.len() != EMBEDDING_DIM {
|
||||
return Err(crate::EmbeddingError::InvalidDimensions {
|
||||
expected: EMBEDDING_DIM,
|
||||
actual: vector.len(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
id: EmbeddingId::new(),
|
||||
segment_id,
|
||||
vector,
|
||||
model_version,
|
||||
created_at: Utc::now(),
|
||||
tier: StorageTier::default(),
|
||||
metadata: EmbeddingMetadata::default(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the L2 norm of the embedding vector
|
||||
#[must_use]
|
||||
pub fn norm(&self) -> f32 {
|
||||
self.vector
|
||||
.iter()
|
||||
.map(|x| x * x)
|
||||
.sum::<f32>()
|
||||
.sqrt()
|
||||
}
|
||||
|
||||
/// Check if the embedding is properly L2 normalized (norm close to 1.0)
|
||||
#[must_use]
|
||||
pub fn is_normalized(&self) -> bool {
|
||||
let norm = self.norm();
|
||||
(0.99..=1.01).contains(&norm)
|
||||
}
|
||||
|
||||
/// Check if the embedding contains any invalid values (NaN or Inf)
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
!self.vector.iter().any(|x| x.is_nan() || x.is_infinite())
|
||||
}
|
||||
|
||||
/// Compute cosine similarity with another embedding
|
||||
///
|
||||
/// For L2-normalized embeddings, this is equivalent to dot product
|
||||
#[must_use]
|
||||
pub fn cosine_similarity(&self, other: &Self) -> f32 {
|
||||
self.vector
|
||||
.iter()
|
||||
.zip(other.vector.iter())
|
||||
.map(|(a, b)| a * b)
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Change the storage tier for this embedding
|
||||
pub fn set_tier(&mut self, tier: StorageTier) {
|
||||
self.tier = tier;
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata about an embedding's generation
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct EmbeddingMetadata {
|
||||
/// Time taken for inference in milliseconds
|
||||
pub inference_latency_ms: Option<f32>,
|
||||
|
||||
/// Batch ID if processed in a batch
|
||||
pub batch_id: Option<String>,
|
||||
|
||||
/// Whether GPU was used for inference
|
||||
pub gpu_used: bool,
|
||||
|
||||
/// Original norm before L2 normalization
|
||||
pub original_norm: Option<f32>,
|
||||
|
||||
/// Sparsity of the embedding (fraction of near-zero values)
|
||||
pub sparsity: Option<f32>,
|
||||
|
||||
/// Quality score based on various metrics
|
||||
pub quality_score: Option<f32>,
|
||||
}
|
||||
|
||||
/// Model version identifier
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ModelVersion {
|
||||
/// Model name (e.g., "perch-v2")
|
||||
pub name: String,
|
||||
|
||||
/// Semantic version string (e.g., "2.0.0")
|
||||
pub version: String,
|
||||
|
||||
/// Model variant (e.g., "base", "quantized", "pruned")
|
||||
pub variant: String,
|
||||
}
|
||||
|
||||
impl ModelVersion {
|
||||
/// Create a new model version
|
||||
#[must_use]
|
||||
pub fn new(name: impl Into<String>, version: impl Into<String>, variant: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
version: version.into(),
|
||||
variant: variant.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default Perch 2.0 base model version
|
||||
#[must_use]
|
||||
pub fn perch_v2_base() -> Self {
|
||||
Self::new("perch-v2", "2.0.0", "base")
|
||||
}
|
||||
|
||||
/// Get the Perch 2.0 quantized model version
|
||||
#[must_use]
|
||||
pub fn perch_v2_quantized() -> Self {
|
||||
Self::new("perch-v2", "2.0.0", "quantized")
|
||||
}
|
||||
|
||||
/// Get the full version string
|
||||
#[must_use]
|
||||
pub fn full_version(&self) -> String {
|
||||
format!("{}-{}-{}", self.name, self.version, self.variant)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ModelVersion {
|
||||
fn default() -> Self {
|
||||
Self::perch_v2_base()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelVersion {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.full_version())
|
||||
}
|
||||
}
|
||||
|
||||
/// Input specification for the embedding model
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InputSpecification {
|
||||
/// Expected sample rate in Hz
|
||||
pub sample_rate: u32,
|
||||
|
||||
/// Window duration in seconds
|
||||
pub window_duration: f32,
|
||||
|
||||
/// Number of samples per window
|
||||
pub window_samples: usize,
|
||||
|
||||
/// Number of mel frequency bins
|
||||
pub mel_bins: usize,
|
||||
|
||||
/// Number of time frames
|
||||
pub mel_frames: usize,
|
||||
|
||||
/// Frequency range (low, high) in Hz
|
||||
pub frequency_range: (f32, f32),
|
||||
}
|
||||
|
||||
impl Default for InputSpecification {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sample_rate: crate::TARGET_SAMPLE_RATE,
|
||||
window_duration: crate::TARGET_WINDOW_SECONDS,
|
||||
window_samples: crate::TARGET_WINDOW_SAMPLES,
|
||||
mel_bins: crate::MEL_BINS,
|
||||
mel_frames: crate::MEL_FRAMES,
|
||||
frequency_range: (60.0, 16000.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Status of an embedding model
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum ModelStatus {
|
||||
/// Model is available and ready for inference
|
||||
Active,
|
||||
|
||||
/// Model is being loaded
|
||||
Loading,
|
||||
|
||||
/// Model failed to load
|
||||
Failed,
|
||||
|
||||
/// Model is deprecated and should not be used
|
||||
Deprecated,
|
||||
}
|
||||
|
||||
/// Configuration and metadata for an embedding model.
|
||||
///
|
||||
/// Represents the Perch 2.0 ONNX model used for generating embeddings.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingModel {
|
||||
/// Model name
|
||||
pub name: String,
|
||||
|
||||
/// Model version
|
||||
pub version: ModelVersion,
|
||||
|
||||
/// Output embedding dimensions
|
||||
pub dimensions: usize,
|
||||
|
||||
/// SHA-256 checksum of the model file
|
||||
pub checksum: String,
|
||||
|
||||
/// Input specification
|
||||
pub input_spec: InputSpecification,
|
||||
|
||||
/// Current model status
|
||||
pub status: ModelStatus,
|
||||
|
||||
/// When the model was last loaded
|
||||
pub loaded_at: Option<Timestamp>,
|
||||
|
||||
/// Path to the model file
|
||||
pub model_path: Option<String>,
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
/// Create a new embedding model configuration
|
||||
#[must_use]
|
||||
pub fn new(name: String, version: ModelVersion, checksum: String) -> Self {
|
||||
Self {
|
||||
name,
|
||||
version,
|
||||
dimensions: EMBEDDING_DIM,
|
||||
checksum,
|
||||
input_spec: InputSpecification::default(),
|
||||
status: ModelStatus::Loading,
|
||||
loaded_at: None,
|
||||
model_path: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a default Perch 2.0 model configuration
|
||||
#[must_use]
|
||||
pub fn perch_v2_default() -> Self {
|
||||
Self::new(
|
||||
"perch-v2".to_string(),
|
||||
ModelVersion::perch_v2_base(),
|
||||
String::new(), // Checksum will be computed on load
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if the model is ready for inference
|
||||
#[must_use]
|
||||
pub const fn is_ready(&self) -> bool {
|
||||
matches!(self.status, ModelStatus::Active)
|
||||
}
|
||||
|
||||
/// Mark the model as active
|
||||
pub fn mark_active(&mut self) {
|
||||
self.status = ModelStatus::Active;
|
||||
self.loaded_at = Some(Utc::now());
|
||||
}
|
||||
|
||||
/// Mark the model as failed
|
||||
pub fn mark_failed(&mut self) {
|
||||
self.status = ModelStatus::Failed;
|
||||
}
|
||||
}
|
||||
|
||||
/// A batch of embeddings processed together
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingBatch {
|
||||
/// Unique batch identifier
|
||||
pub id: String,
|
||||
|
||||
/// Segment IDs in this batch
|
||||
pub segment_ids: Vec<SegmentId>,
|
||||
|
||||
/// Batch processing status
|
||||
pub status: BatchStatus,
|
||||
|
||||
/// When batch processing started
|
||||
pub started_at: Timestamp,
|
||||
|
||||
/// When batch processing completed
|
||||
pub completed_at: Option<Timestamp>,
|
||||
|
||||
/// Batch processing metrics
|
||||
pub metrics: BatchMetrics,
|
||||
}
|
||||
|
||||
/// Status of a batch operation.
///
/// Terminal states are `Completed` and `Failed`; a batch starts in `Pending`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchStatus {
    /// Batch is queued for processing (initial state)
    Pending,

    /// Batch is currently being processed
    Processing,

    /// Batch processing completed successfully
    Completed,

    /// Batch processing failed
    Failed,
}
|
||||
|
||||
/// Metrics for batch processing.
///
/// All fields default to zero; counters and rates are filled in when the batch
/// is marked completed.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchMetrics {
    /// Total segments in batch (set at batch construction)
    pub total_segments: u32,

    /// Successfully processed segments
    pub success_count: u32,

    /// Failed segments
    pub failure_count: u32,

    /// Average inference latency in milliseconds
    pub avg_latency_ms: f32,

    /// Throughput in segments per second, derived from wall-clock batch
    /// duration on completion (stays 0.0 for sub-millisecond batches)
    pub throughput: f32,
}
|
||||
|
||||
impl EmbeddingBatch {
|
||||
/// Create a new embedding batch
|
||||
#[must_use]
|
||||
pub fn new(segment_ids: Vec<SegmentId>) -> Self {
|
||||
let total = segment_ids.len() as u32;
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
segment_ids,
|
||||
status: BatchStatus::Pending,
|
||||
started_at: Utc::now(),
|
||||
completed_at: None,
|
||||
metrics: BatchMetrics {
|
||||
total_segments: total,
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark batch as processing
|
||||
pub fn mark_processing(&mut self) {
|
||||
self.status = BatchStatus::Processing;
|
||||
}
|
||||
|
||||
/// Mark batch as completed with metrics
|
||||
pub fn mark_completed(&mut self, success_count: u32, failure_count: u32, avg_latency_ms: f32) {
|
||||
self.status = BatchStatus::Completed;
|
||||
self.completed_at = Some(Utc::now());
|
||||
self.metrics.success_count = success_count;
|
||||
self.metrics.failure_count = failure_count;
|
||||
self.metrics.avg_latency_ms = avg_latency_ms;
|
||||
|
||||
if let Some(completed) = self.completed_at {
|
||||
let duration = (completed - self.started_at).num_milliseconds() as f32 / 1000.0;
|
||||
if duration > 0.0 {
|
||||
self.metrics.throughput = self.metrics.total_segments as f32 / duration;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark batch as failed
|
||||
pub fn mark_failed(&mut self) {
|
||||
self.status = BatchStatus::Failed;
|
||||
self.completed_at = Some(Utc::now());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Two freshly generated IDs must never collide.
    #[test]
    fn test_embedding_id_generation() {
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        assert_ne!(id1, id2);
    }

    // A vector of exactly EMBEDDING_DIM elements is accepted.
    #[test]
    fn test_embedding_creation() {
        let segment_id = SegmentId::new();
        let vector = vec![0.0; EMBEDDING_DIM];
        let embedding = Embedding::new(segment_id, vector, "perch-v2-2.0.0-base".to_string());
        assert!(embedding.is_ok());
    }

    // Constructing with a wrong vector length must be rejected.
    #[test]
    fn test_embedding_invalid_dimensions() {
        let segment_id = SegmentId::new();
        let vector = vec![0.0; 100]; // Wrong dimension
        let result = Embedding::new(segment_id, vector, "perch-v2-2.0.0-base".to_string());
        assert!(result.is_err());
    }

    // A one-hot vector has L2 norm 1 and is reported as normalized.
    #[test]
    fn test_embedding_norm() {
        let segment_id = SegmentId::new();
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = 1.0; // Unit vector
        let embedding = Embedding::new(segment_id, vector, "perch-v2-2.0.0-base".to_string()).unwrap();
        assert!((embedding.norm() - 1.0).abs() < 1e-6);
        assert!(embedding.is_normalized());
    }

    // Identical unit vectors must have cosine similarity 1.0.
    #[test]
    fn test_cosine_similarity() {
        let segment_id = SegmentId::new();

        // Create two identical normalized vectors
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = 1.0;

        let emb1 = Embedding::new(segment_id, vector.clone(), "test".to_string()).unwrap();
        let emb2 = Embedding::new(segment_id, vector, "test".to_string()).unwrap();

        let similarity = emb1.cosine_similarity(&emb2);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    // Per-dimension byte widths shrink as data moves to colder tiers.
    #[test]
    fn test_storage_tier_bytes() {
        assert_eq!(StorageTier::Hot.bytes_per_dim(), 4);
        assert_eq!(StorageTier::Warm.bytes_per_dim(), 2);
        assert_eq!(StorageTier::Cold.bytes_per_dim(), 1);

        assert_eq!(StorageTier::Hot.embedding_bytes(), EMBEDDING_DIM * 4);
    }

    // The canonical Perch v2 version and its "name-version-variant" rendering.
    #[test]
    fn test_model_version() {
        let version = ModelVersion::perch_v2_base();
        assert_eq!(version.name, "perch-v2");
        assert_eq!(version.version, "2.0.0");
        assert_eq!(version.variant, "base");
        assert_eq!(version.full_version(), "perch-v2-2.0.0-base");
    }

    // Default input spec matches the Perch 2.0 audio front-end constants.
    #[test]
    fn test_input_specification_defaults() {
        let spec = InputSpecification::default();
        assert_eq!(spec.sample_rate, 32000);
        assert_eq!(spec.window_samples, 160_000);
        assert_eq!(spec.mel_bins, 128);
        assert_eq!(spec.mel_frames, 500);
    }

    // Full pending -> processing -> completed transition with metrics.
    #[test]
    fn test_embedding_batch_lifecycle() {
        let segment_ids = vec![SegmentId::new(), SegmentId::new()];
        let mut batch = EmbeddingBatch::new(segment_ids);

        assert_eq!(batch.status, BatchStatus::Pending);
        assert_eq!(batch.metrics.total_segments, 2);

        batch.mark_processing();
        assert_eq!(batch.status, BatchStatus::Processing);

        batch.mark_completed(2, 0, 50.0);
        assert_eq!(batch.status, BatchStatus::Completed);
        assert_eq!(batch.metrics.success_count, 2);
        assert!(batch.completed_at.is_some());
    }
}
|
||||
7
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/domain/mod.rs
vendored
Normal file
7
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/domain/mod.rs
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Domain layer for embedding bounded context.
|
||||
//!
|
||||
//! Contains core entities, value objects, and repository traits that define
|
||||
//! the embedding domain model.
|
||||
|
||||
pub mod entities;
|
||||
pub mod repository;
|
||||
376
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/domain/repository.rs
vendored
Normal file
376
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/domain/repository.rs
vendored
Normal file
@@ -0,0 +1,376 @@
|
||||
//! Repository traits for the embedding bounded context.
|
||||
//!
|
||||
//! Defines the interfaces for persisting and retrieving embeddings,
|
||||
//! following the repository pattern from Domain-Driven Design.
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::entities::{Embedding, EmbeddingId, EmbeddingModel, SegmentId, StorageTier};
|
||||
use crate::EmbeddingError;
|
||||
|
||||
/// Repository trait for embedding persistence.
///
/// Implementations may use various storage backends:
/// - In-memory (for testing)
/// - Vector databases (Qdrant, Milvus)
/// - Relational databases (PostgreSQL with pgvector)
/// - File-based storage
#[async_trait]
pub trait EmbeddingRepository: Send + Sync {
    /// Save a single embedding to the repository.
    ///
    /// # Errors
    ///
    /// Returns an error if the embedding cannot be persisted.
    async fn save(&self, embedding: &Embedding) -> Result<(), EmbeddingError>;

    /// Find an embedding by its unique identifier.
    ///
    /// # Returns
    ///
    /// - `Ok(Some(embedding))` if found
    /// - `Ok(None)` if not found
    /// - `Err(...)` on storage errors
    async fn find_by_id(&self, id: &EmbeddingId) -> Result<Option<Embedding>, EmbeddingError>;

    /// Find an embedding by its source segment ID.
    ///
    /// # Returns
    ///
    /// - `Ok(Some(embedding))` if found
    /// - `Ok(None)` if not found
    /// - `Err(...)` on storage errors
    async fn find_by_segment(&self, segment_id: &SegmentId) -> Result<Option<Embedding>, EmbeddingError>;

    /// Save multiple embeddings in a batch operation.
    ///
    /// This is more efficient than calling `save` multiple times
    /// as it can use bulk insert operations.
    ///
    /// # Errors
    ///
    /// Returns an error if any embedding cannot be persisted.
    /// Implementations should document their atomicity guarantees.
    async fn batch_save(&self, embeddings: &[Embedding]) -> Result<(), EmbeddingError>;

    /// Delete an embedding by its ID.
    ///
    /// # Returns
    ///
    /// - `Ok(true)` if the embedding was deleted
    /// - `Ok(false)` if the embedding was not found
    /// - `Err(...)` on storage errors
    async fn delete(&self, id: &EmbeddingId) -> Result<bool, EmbeddingError>;

    /// Delete embeddings by segment ID.
    ///
    /// Useful when a segment is reprocessed with a new model version.
    ///
    /// # Returns
    ///
    /// Number of embeddings deleted.
    async fn delete_by_segment(&self, segment_id: &SegmentId) -> Result<usize, EmbeddingError>;

    /// Count total embeddings in the repository.
    async fn count(&self) -> Result<u64, EmbeddingError>;

    /// Count embeddings by storage tier.
    async fn count_by_tier(&self, tier: StorageTier) -> Result<u64, EmbeddingError>;

    /// Find embeddings by model version.
    ///
    /// Useful for identifying embeddings that need re-generation
    /// after a model update.
    async fn find_by_model_version(
        &self,
        model_version: &str,
        limit: usize,
        offset: usize,
    ) -> Result<Vec<Embedding>, EmbeddingError>;

    /// Update the storage tier for an embedding.
    ///
    /// Used for tiered storage management (hot -> warm -> cold).
    async fn update_tier(
        &self,
        id: &EmbeddingId,
        tier: StorageTier,
    ) -> Result<bool, EmbeddingError>;

    /// Check if an embedding exists.
    ///
    /// Default implementation delegates to `find_by_id`; backends with a
    /// cheaper existence check may override.
    async fn exists(&self, id: &EmbeddingId) -> Result<bool, EmbeddingError> {
        Ok(self.find_by_id(id).await?.is_some())
    }

    /// Check if an embedding exists for a segment.
    ///
    /// Default implementation delegates to `find_by_segment`.
    async fn exists_for_segment(&self, segment_id: &SegmentId) -> Result<bool, EmbeddingError> {
        Ok(self.find_by_segment(segment_id).await?.is_some())
    }
}
|
||||
|
||||
/// Repository trait for embedding model management.
///
/// Manages the lifecycle of ONNX models used for embedding generation.
#[async_trait]
pub trait ModelRepository: Send + Sync {
    /// Save or update a model configuration.
    async fn save_model(&self, model: &EmbeddingModel) -> Result<(), EmbeddingError>;

    /// Find a model by name and version.
    ///
    /// Returns `Ok(None)` when no matching model is stored.
    async fn find_model(
        &self,
        name: &str,
        version: &str,
    ) -> Result<Option<EmbeddingModel>, EmbeddingError>;

    /// Get the currently active model for a given name.
    async fn get_active_model(&self, name: &str) -> Result<Option<EmbeddingModel>, EmbeddingError>;

    /// List all available models.
    async fn list_models(&self) -> Result<Vec<EmbeddingModel>, EmbeddingError>;

    /// Delete a model configuration.
    async fn delete_model(&self, name: &str, version: &str) -> Result<bool, EmbeddingError>;
}
|
||||
|
||||
/// Query parameters for embedding search operations.
///
/// All filters are optional; an unset field means "no constraint". Build with
/// the fluent `with_*` methods on the `impl` below.
#[derive(Debug, Clone, Default)]
pub struct EmbeddingQuery {
    /// Filter by segment IDs
    pub segment_ids: Option<Vec<SegmentId>>,

    /// Filter by model version
    pub model_version: Option<String>,

    /// Filter by storage tier
    pub tier: Option<StorageTier>,

    /// Filter by creation date (after)
    pub created_after: Option<chrono::DateTime<chrono::Utc>>,

    /// Filter by creation date (before)
    pub created_before: Option<chrono::DateTime<chrono::Utc>>,

    /// Maximum results to return
    pub limit: Option<usize>,

    /// Offset for pagination
    pub offset: Option<usize>,
}
|
||||
|
||||
impl EmbeddingQuery {
|
||||
/// Create a new query builder
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Filter by segment IDs
|
||||
#[must_use]
|
||||
pub fn with_segment_ids(mut self, ids: Vec<SegmentId>) -> Self {
|
||||
self.segment_ids = Some(ids);
|
||||
self
|
||||
}
|
||||
|
||||
/// Filter by model version
|
||||
#[must_use]
|
||||
pub fn with_model_version(mut self, version: impl Into<String>) -> Self {
|
||||
self.model_version = Some(version.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Filter by storage tier
|
||||
#[must_use]
|
||||
pub const fn with_tier(mut self, tier: StorageTier) -> Self {
|
||||
self.tier = Some(tier);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set pagination limit
|
||||
#[must_use]
|
||||
pub const fn with_limit(mut self, limit: usize) -> Self {
|
||||
self.limit = Some(limit);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set pagination offset
|
||||
#[must_use]
|
||||
pub const fn with_offset(mut self, offset: usize) -> Self {
|
||||
self.offset = Some(offset);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Extended repository trait with query support.
///
/// Adds filtered queries and vector-similarity search on top of the basic
/// CRUD operations of [`EmbeddingRepository`].
#[async_trait]
pub trait QueryableEmbeddingRepository: EmbeddingRepository {
    /// Query embeddings with filters.
    async fn query(&self, query: &EmbeddingQuery) -> Result<Vec<Embedding>, EmbeddingError>;

    /// Find k nearest neighbors to a query vector.
    ///
    /// # Arguments
    ///
    /// * `query_vector` - The query embedding vector
    /// * `k` - Number of neighbors to return
    /// * `ef_search` - HNSW search parameter (larger = more accurate, slower);
    ///   `None` leaves the backend's default in effect
    ///
    /// # Returns
    ///
    /// Vector of (embedding, distance) pairs, sorted by distance ascending.
    async fn find_nearest(
        &self,
        query_vector: &[f32],
        k: usize,
        ef_search: Option<usize>,
    ) -> Result<Vec<(Embedding, f32)>, EmbeddingError>;

    /// Find embeddings within a distance threshold.
    async fn find_within_distance(
        &self,
        query_vector: &[f32],
        max_distance: f32,
    ) -> Result<Vec<(Embedding, f32)>, EmbeddingError>;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// In-memory implementation for testing
    struct InMemoryEmbeddingRepository {
        // Async RwLock so the async trait methods can await the guard.
        embeddings: Arc<RwLock<HashMap<EmbeddingId, Embedding>>>,
    }

    impl InMemoryEmbeddingRepository {
        fn new() -> Self {
            Self {
                embeddings: Arc::new(RwLock::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl EmbeddingRepository for InMemoryEmbeddingRepository {
        async fn save(&self, embedding: &Embedding) -> Result<(), EmbeddingError> {
            self.embeddings.write().await.insert(embedding.id, embedding.clone());
            Ok(())
        }

        async fn find_by_id(&self, id: &EmbeddingId) -> Result<Option<Embedding>, EmbeddingError> {
            Ok(self.embeddings.read().await.get(id).cloned())
        }

        // Linear scan over all values; fine for test-sized data.
        async fn find_by_segment(&self, segment_id: &SegmentId) -> Result<Option<Embedding>, EmbeddingError> {
            Ok(self.embeddings.read().await.values()
                .find(|e| e.segment_id == *segment_id)
                .cloned())
        }

        async fn batch_save(&self, embeddings: &[Embedding]) -> Result<(), EmbeddingError> {
            let mut store = self.embeddings.write().await;
            for embedding in embeddings {
                store.insert(embedding.id, embedding.clone());
            }
            Ok(())
        }

        async fn delete(&self, id: &EmbeddingId) -> Result<bool, EmbeddingError> {
            Ok(self.embeddings.write().await.remove(id).is_some())
        }

        async fn delete_by_segment(&self, segment_id: &SegmentId) -> Result<usize, EmbeddingError> {
            let mut store = self.embeddings.write().await;
            // Collect matching ids first to avoid mutating while iterating.
            let to_remove: Vec<_> = store.iter()
                .filter(|(_, e)| e.segment_id == *segment_id)
                .map(|(id, _)| *id)
                .collect();
            let count = to_remove.len();
            for id in to_remove {
                store.remove(&id);
            }
            Ok(count)
        }

        async fn count(&self) -> Result<u64, EmbeddingError> {
            Ok(self.embeddings.read().await.len() as u64)
        }

        async fn count_by_tier(&self, tier: StorageTier) -> Result<u64, EmbeddingError> {
            Ok(self.embeddings.read().await.values()
                .filter(|e| e.tier == tier)
                .count() as u64)
        }

        async fn find_by_model_version(
            &self,
            model_version: &str,
            limit: usize,
            offset: usize,
        ) -> Result<Vec<Embedding>, EmbeddingError> {
            // NOTE(review): pagination order is the HashMap's arbitrary
            // iteration order — acceptable for a test double only.
            Ok(self.embeddings.read().await.values()
                .filter(|e| e.model_version == model_version)
                .skip(offset)
                .take(limit)
                .cloned()
                .collect())
        }

        async fn update_tier(
            &self,
            id: &EmbeddingId,
            tier: StorageTier,
        ) -> Result<bool, EmbeddingError> {
            let mut store = self.embeddings.write().await;
            if let Some(embedding) = store.get_mut(id) {
                embedding.tier = tier;
                Ok(true)
            } else {
                Ok(false)
            }
        }
    }

    // End-to-end exercise of the in-memory repo: save, lookup, count, delete.
    #[tokio::test]
    async fn test_in_memory_repository() {
        let repo = InMemoryEmbeddingRepository::new();
        let segment_id = SegmentId::new();
        let vector = vec![0.0; crate::EMBEDDING_DIM];
        let embedding = Embedding::new(segment_id, vector, "test".to_string()).unwrap();

        // Save
        repo.save(&embedding).await.unwrap();

        // Find by ID
        let found = repo.find_by_id(&embedding.id).await.unwrap();
        assert!(found.is_some());

        // Find by segment
        let found = repo.find_by_segment(&segment_id).await.unwrap();
        assert!(found.is_some());

        // Count
        assert_eq!(repo.count().await.unwrap(), 1);

        // Delete
        assert!(repo.delete(&embedding.id).await.unwrap());
        assert_eq!(repo.count().await.unwrap(), 0);
    }

    // Each builder method should set exactly its own filter field.
    #[test]
    fn test_embedding_query_builder() {
        let query = EmbeddingQuery::new()
            .with_model_version("perch-v2-2.0.0-base")
            .with_tier(StorageTier::Hot)
            .with_limit(100)
            .with_offset(0);

        assert_eq!(query.model_version.as_deref(), Some("perch-v2-2.0.0-base"));
        assert_eq!(query.tier, Some(StorageTier::Hot));
        assert_eq!(query.limit, Some(100));
        assert_eq!(query.offset, Some(0));
    }
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//! Infrastructure layer for embedding bounded context.
|
||||
//!
|
||||
//! Contains implementations for ONNX model loading, inference,
|
||||
//! and integration with external systems.
|
||||
|
||||
pub mod model_manager;
|
||||
pub mod onnx_inference;
|
||||
@@ -0,0 +1,511 @@
|
||||
//! Model management for ONNX embedding models.
|
||||
//!
|
||||
//! Provides thread-safe loading, caching, and hot-swapping of
|
||||
//! Perch 2.0 ONNX models for embedding generation.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use sha2::{Digest, Sha256};
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
use super::onnx_inference::OnnxInference;
|
||||
use crate::domain::entities::{EmbeddingModel, ModelVersion};
|
||||
|
||||
/// Errors that can occur during model management.
///
/// Covers the full model lifecycle: locating the file, loading it,
/// integrity-checking it, and initializing the runtime session.
#[derive(Debug, Error)]
pub enum ModelError {
    /// Model file not found
    #[error("Model not found: {0}")]
    NotFound(String),

    /// Failed to load model
    #[error("Failed to load model: {0}")]
    LoadFailed(String),

    /// Checksum verification failed (file on disk does not match the
    /// checksum registered for the model)
    #[error("Checksum mismatch for model {model}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        /// Model name
        model: String,
        /// Expected checksum
        expected: String,
        /// Actual checksum
        actual: String,
    },

    /// Model initialization failed
    #[error("Model initialization failed: {0}")]
    InitializationFailed(String),

    /// IO error (automatically converted from `std::io::Error`)
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// ONNX Runtime error
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(String),

    /// Model not ready
    #[error("Model not ready: {0}")]
    NotReady(String),
}
|
||||
|
||||
/// Configuration for the model manager.
///
/// See `ModelConfig::default()` for the out-of-the-box values.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Directory containing model files
    pub model_dir: PathBuf,

    /// Number of threads for intra-op parallelism
    pub intra_op_threads: usize,

    /// Number of threads for inter-op parallelism
    pub inter_op_threads: usize,

    /// Whether to verify model checksums on load
    pub verify_checksums: bool,

    /// Execution providers in priority order (first available one wins)
    pub execution_providers: Vec<ExecutionProvider>,

    /// Maximum number of cached sessions
    pub max_cached_sessions: usize,
}
|
||||
|
||||
impl Default for ModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_dir: PathBuf::from("models"),
|
||||
intra_op_threads: num_cpus::get().min(4),
|
||||
inter_op_threads: 1,
|
||||
verify_checksums: true,
|
||||
execution_providers: vec![
|
||||
ExecutionProvider::Cuda { device_id: 0 },
|
||||
ExecutionProvider::CoreML,
|
||||
ExecutionProvider::Cpu,
|
||||
],
|
||||
max_cached_sessions: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execution provider for ONNX Runtime.
///
/// Listed in `ModelConfig::execution_providers` in priority order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// CPU execution
    Cpu,

    /// NVIDIA CUDA execution
    Cuda {
        /// GPU device ID
        device_id: i32,
    },

    /// Apple CoreML execution
    CoreML,

    /// DirectML execution (Windows)
    DirectML {
        /// Device ID
        device_id: i32,
    },
}
|
||||
|
||||
/// Thread-safe model session manager with caching and hot-swap support.
///
/// Manages the lifecycle of ONNX models used for embedding generation,
/// including loading, caching, and version management. All shared state is
/// guarded by `parking_lot::RwLock`, so `&self` methods are safe to call
/// from multiple threads.
pub struct ModelManager {
    /// Cached model sessions, keyed by the full version string
    sessions: RwLock<HashMap<String, Arc<OnnxInference>>>,

    /// Model metadata, keyed by the full version string
    models: RwLock<HashMap<String, EmbeddingModel>>,

    /// Currently active model version
    active_version: RwLock<ModelVersion>,

    /// Configuration (immutable after construction)
    config: ModelConfig,
}
|
||||
|
||||
impl ModelManager {
|
||||
/// Create a new model manager with the given configuration.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the model directory doesn't exist and can't be created.
|
||||
pub fn new(config: ModelConfig) -> Result<Self, ModelError> {
|
||||
// Ensure model directory exists
|
||||
if !config.model_dir.exists() {
|
||||
std::fs::create_dir_all(&config.model_dir)?;
|
||||
debug!(path = ?config.model_dir, "Created model directory");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
models: RwLock::new(HashMap::new()),
|
||||
active_version: RwLock::new(ModelVersion::perch_v2_base()),
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn with_defaults() -> Result<Self, ModelError> {
|
||||
Self::new(ModelConfig::default())
|
||||
}
|
||||
|
||||
/// Load a model from a file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `name` - Model name (e.g., "perch-v2")
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the model file doesn't exist or fails to load.
|
||||
#[instrument(skip(self), fields(model = %name))]
|
||||
pub fn load_model(&self, name: &str) -> Result<Arc<OnnxInference>, ModelError> {
|
||||
let version = self.active_version.read().clone();
|
||||
let version_key = version.full_version();
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let sessions = self.sessions.read();
|
||||
if let Some(session) = sessions.get(&version_key) {
|
||||
debug!("Using cached session for {}", version_key);
|
||||
return Ok(Arc::clone(session));
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve model path
|
||||
let model_path = self.resolve_model_path(name, &version)?;
|
||||
|
||||
// Verify checksum if configured
|
||||
if self.config.verify_checksums {
|
||||
if let Some(model) = self.models.read().get(&version_key) {
|
||||
if !model.checksum.is_empty() {
|
||||
self.verify_checksum(&model_path, &model.checksum)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create new session
|
||||
info!(path = ?model_path, "Loading model");
|
||||
let session = self.create_session(&model_path)?;
|
||||
let session = Arc::new(session);
|
||||
|
||||
// Cache the session
|
||||
{
|
||||
let mut sessions = self.sessions.write();
|
||||
|
||||
// Evict old sessions if at capacity
|
||||
while sessions.len() >= self.config.max_cached_sessions {
|
||||
if let Some(key) = sessions.keys().next().cloned() {
|
||||
sessions.remove(&key);
|
||||
debug!("Evicted cached session: {}", key);
|
||||
}
|
||||
}
|
||||
|
||||
sessions.insert(version_key.clone(), Arc::clone(&session));
|
||||
}
|
||||
|
||||
// Update model metadata
|
||||
{
|
||||
let mut models = self.models.write();
|
||||
if let Some(model) = models.get_mut(&version_key) {
|
||||
model.mark_active();
|
||||
}
|
||||
}
|
||||
|
||||
info!(version = %version_key, "Model loaded successfully");
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// Verify the checksum of a model file.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the checksum doesn't match.
|
||||
pub fn verify_checksum(&self, path: &Path, expected: &str) -> Result<bool, ModelError> {
|
||||
let actual = self.compute_checksum(path)?;
|
||||
|
||||
if actual != expected {
|
||||
return Err(ModelError::ChecksumMismatch {
|
||||
model: path.display().to_string(),
|
||||
expected: expected.to_string(),
|
||||
actual,
|
||||
});
|
||||
}
|
||||
|
||||
debug!(path = ?path, "Checksum verified");
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Compute the SHA-256 checksum of a file.
|
||||
fn compute_checksum(&self, path: &Path) -> Result<String, ModelError> {
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let mut hasher = Sha256::new();
|
||||
std::io::copy(&mut file, &mut hasher)?;
|
||||
let hash = hasher.finalize();
|
||||
Ok(hex::encode(hash))
|
||||
}
|
||||
|
||||
/// Hot-swap to a new model version without restart.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `name` - Model name
|
||||
/// * `new_path` - Path to the new model file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the new model fails to load.
|
||||
#[instrument(skip(self, new_path), fields(model = %name, path = ?new_path))]
|
||||
pub fn hot_swap(&self, name: &str, new_path: &Path) -> Result<(), ModelError> {
|
||||
// Validate the new model can be loaded
|
||||
info!("Attempting hot-swap to new model");
|
||||
let new_session = self.create_session(new_path)?;
|
||||
|
||||
// Compute checksum for the new model
|
||||
let checksum = self.compute_checksum(new_path)?;
|
||||
|
||||
// Create new version
|
||||
let old_version = self.active_version.read().clone();
|
||||
let new_version = ModelVersion::new(
|
||||
name,
|
||||
&old_version.version, // Keep same semantic version
|
||||
"hot-swap",
|
||||
);
|
||||
let version_key = new_version.full_version();
|
||||
|
||||
// Update sessions cache
|
||||
{
|
||||
let mut sessions = self.sessions.write();
|
||||
sessions.insert(version_key.clone(), Arc::new(new_session));
|
||||
}
|
||||
|
||||
// Update model metadata
|
||||
{
|
||||
let mut models = self.models.write();
|
||||
let mut model = EmbeddingModel::new(
|
||||
name.to_string(),
|
||||
new_version.clone(),
|
||||
checksum,
|
||||
);
|
||||
model.model_path = Some(new_path.to_string_lossy().to_string());
|
||||
model.mark_active();
|
||||
models.insert(version_key, model);
|
||||
}
|
||||
|
||||
// Update active version
|
||||
*self.active_version.write() = new_version.clone();
|
||||
|
||||
info!(
|
||||
old_version = %old_version,
|
||||
new_version = %new_version,
|
||||
"Hot-swap completed successfully"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the ONNX inference engine for the current model.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if no model is loaded.
|
||||
pub async fn get_inference(&self) -> Result<Arc<OnnxInference>, ModelError> {
|
||||
let version = self.active_version.read().clone();
|
||||
self.load_model(&version.name)
|
||||
}
|
||||
|
||||
/// Get the currently active model version.
|
||||
#[must_use]
|
||||
pub fn current_version(&self) -> ModelVersion {
|
||||
self.active_version.read().clone()
|
||||
}
|
||||
|
||||
/// Set the active model version.
|
||||
pub fn set_active_version(&self, version: ModelVersion) {
|
||||
*self.active_version.write() = version;
|
||||
}
|
||||
|
||||
/// Check if a model is loaded and ready.
|
||||
pub async fn is_ready(&self) -> bool {
|
||||
let version_key = self.active_version.read().full_version();
|
||||
self.sessions.read().contains_key(&version_key)
|
||||
}
|
||||
|
||||
/// Get model metadata for a version.
|
||||
#[must_use]
|
||||
pub fn get_model(&self, version_key: &str) -> Option<EmbeddingModel> {
|
||||
self.models.read().get(version_key).cloned()
|
||||
}
|
||||
|
||||
/// List all loaded models.
|
||||
#[must_use]
|
||||
pub fn list_models(&self) -> Vec<EmbeddingModel> {
|
||||
self.models.read().values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Clear all cached sessions.
|
||||
pub fn clear_cache(&self) {
|
||||
self.sessions.write().clear();
|
||||
info!("Cleared model session cache");
|
||||
}
|
||||
|
||||
/// Resolve the path to a model file.
|
||||
fn resolve_model_path(&self, name: &str, version: &ModelVersion) -> Result<PathBuf, ModelError> {
|
||||
// Try various naming conventions
|
||||
let candidates = vec![
|
||||
self.config.model_dir.join(format!("{}.onnx", version.full_version())),
|
||||
self.config.model_dir.join(format!("{}_{}.onnx", name, version.version)),
|
||||
self.config.model_dir.join(format!("{}.onnx", name)),
|
||||
self.config.model_dir.join(format!("{}/{}.onnx", name, version.version)),
|
||||
];
|
||||
|
||||
for path in &candidates {
|
||||
if path.exists() {
|
||||
return Ok(path.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if there's a model metadata entry with a path
|
||||
let version_key = version.full_version();
|
||||
if let Some(model) = self.models.read().get(&version_key) {
|
||||
if let Some(ref path_str) = model.model_path {
|
||||
let path = PathBuf::from(path_str);
|
||||
if path.exists() {
|
||||
return Ok(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(ModelError::NotFound(format!(
|
||||
"Model {} not found in {:?}. Tried: {:?}",
|
||||
name, self.config.model_dir, candidates
|
||||
)))
|
||||
}
|
||||
|
||||
/// Create an ONNX inference session from a model file.
|
||||
fn create_session(&self, path: &Path) -> Result<OnnxInference, ModelError> {
|
||||
OnnxInference::new(
|
||||
path,
|
||||
self.config.intra_op_threads,
|
||||
self.config.inter_op_threads,
|
||||
&self.config.execution_providers,
|
||||
)
|
||||
.map_err(|e| ModelError::LoadFailed(e.to_string()))
|
||||
}
|
||||
|
||||
/// Register a model without loading it.
|
||||
pub fn register_model(&self, model: EmbeddingModel) {
|
||||
let version_key = model.version.full_version();
|
||||
self.models.write().insert(version_key, model);
|
||||
}
|
||||
|
||||
/// Unload a specific model version from cache.
|
||||
pub fn unload_model(&self, version_key: &str) -> bool {
|
||||
let removed = self.sessions.write().remove(version_key).is_some();
|
||||
if removed {
|
||||
info!(version = %version_key, "Unloaded model from cache");
|
||||
}
|
||||
removed
|
||||
}
|
||||
}
|
||||
|
||||
// Manual Debug impl: report summary fields (directory, active version,
// cache size) rather than dumping the cached session contents, which
// are presumably not Debug-friendly — confirm if sessions gain Debug.
impl std::fmt::Debug for ModelManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelManager")
            .field("model_dir", &self.config.model_dir)
            .field("active_version", &*self.active_version.read())
            .field("cached_sessions", &self.sessions.read().len())
            .finish()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::tempdir;

    // Default config should provide sane threading and enable checksums.
    #[test]
    fn test_model_config_default() {
        let config = ModelConfig::default();
        assert!(config.intra_op_threads > 0);
        assert!(config.verify_checksums);
    }

    // Constructing a manager against an empty temp dir should succeed.
    #[test]
    fn test_model_manager_creation() {
        let dir = tempdir().unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config);
        assert!(manager.is_ok());
    }

    // Checksum of a real file should be a 64-char SHA-256 hex digest.
    #[test]
    fn test_checksum_computation() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.bin");

        let mut file = std::fs::File::create(&file_path).unwrap();
        file.write_all(b"test content").unwrap();

        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config).unwrap();

        let checksum = manager.compute_checksum(&file_path).unwrap();
        assert!(!checksum.is_empty());
        assert_eq!(checksum.len(), 64); // SHA-256 hex length
    }

    // Pin the version-key format used as the registry/cache key.
    #[test]
    fn test_model_version_key() {
        let version = ModelVersion::perch_v2_base();
        assert_eq!(version.full_version(), "perch-v2-2.0.0-base");
    }

    // A registered model should be retrievable under its version key.
    #[test]
    fn test_register_model() {
        let dir = tempdir().unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config).unwrap();

        let model = EmbeddingModel::perch_v2_default();
        let version_key = model.version.full_version();

        manager.register_model(model);

        let retrieved = manager.get_model(&version_key);
        assert!(retrieved.is_some());
    }

    // Clearing an empty cache must be a harmless no-op.
    #[test]
    fn test_clear_cache() {
        let dir = tempdir().unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config).unwrap();

        manager.clear_cache();
        // Should not panic
    }
}
|
||||
@@ -0,0 +1,426 @@
|
||||
//! ONNX Runtime inference for Perch 2.0 embeddings.
|
||||
//!
|
||||
//! Provides efficient inference using the `ort` crate for
|
||||
//! ONNX Runtime integration in Rust.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use ndarray::{Array1, Array3};
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use super::model_manager::ExecutionProvider;
|
||||
use crate::{EMBEDDING_DIM, MEL_BINS, MEL_FRAMES};
|
||||
|
||||
/// Errors during ONNX inference
///
/// Each variant corresponds to one stage of the inference pipeline
/// (session setup, input build, execution, output extraction) and
/// carries the underlying runtime message as a string.
#[derive(Debug, Error)]
pub enum InferenceError {
    /// Session creation failed
    #[error("Failed to create session: {0}")]
    SessionCreation(String),

    /// Input tensor creation failed
    #[error("Failed to create input tensor: {0}")]
    InputTensor(String),

    /// Inference execution failed
    #[error("Inference failed: {0}")]
    Execution(String),

    /// Output extraction failed
    #[error("Failed to extract output: {0}")]
    OutputExtraction(String),

    /// Invalid input dimensions
    #[error("Invalid input dimensions: expected {expected:?}, got {actual:?}")]
    InvalidDimensions {
        /// Expected shape
        expected: Vec<usize>,
        /// Actual shape
        actual: Vec<usize>,
    },

    /// Model not initialized
    #[error("Model not initialized")]
    NotInitialized,
}
|
||||
|
||||
/// ONNX inference engine for embedding generation.
///
/// Wraps an `ort` session together with the model's resolved input and
/// output tensor names. The session sits behind a `Mutex` because the
/// ort 2.0 `run` API requires mutable access to the session.
pub struct OnnxInference {
    /// The ONNX Runtime session (wrapped in Mutex for interior mutability)
    session: Mutex<Session>,

    /// Whether GPU is being used (currently always false — `new` only
    /// wires up the CPU provider and warns on GPU requests)
    gpu_enabled: AtomicBool,

    /// Input name for the model (first declared input, or "input")
    input_name: String,

    /// Output name for embeddings (first declared output, or "embedding")
    output_name: String,
}
|
||||
|
||||
impl OnnxInference {
|
||||
/// Create a new ONNX inference engine from a model file.
|
||||
#[instrument(skip(providers), fields(path = ?model_path))]
|
||||
pub fn new(
|
||||
model_path: &Path,
|
||||
intra_op_threads: usize,
|
||||
inter_op_threads: usize,
|
||||
providers: &[ExecutionProvider],
|
||||
) -> Result<Self, InferenceError> {
|
||||
let builder = Session::builder()
|
||||
.map_err(|e| InferenceError::SessionCreation(e.to_string()))?
|
||||
.with_intra_threads(intra_op_threads)
|
||||
.map_err(|e| InferenceError::SessionCreation(e.to_string()))?
|
||||
.with_inter_threads(inter_op_threads)
|
||||
.map_err(|e| InferenceError::SessionCreation(e.to_string()))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| InferenceError::SessionCreation(e.to_string()))?;
|
||||
|
||||
let gpu_enabled = false;
|
||||
|
||||
for provider in providers {
|
||||
match provider {
|
||||
ExecutionProvider::Cuda { device_id } => {
|
||||
warn!("CUDA device {} requested, using CPU fallback", device_id);
|
||||
}
|
||||
ExecutionProvider::CoreML => {
|
||||
warn!("CoreML requested, using CPU fallback");
|
||||
}
|
||||
ExecutionProvider::DirectML { device_id } => {
|
||||
warn!("DirectML device {} requested, using CPU fallback", device_id);
|
||||
}
|
||||
ExecutionProvider::Cpu => {
|
||||
debug!("Using CPU execution provider");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let session = builder
|
||||
.commit_from_file(model_path)
|
||||
.map_err(|e| InferenceError::SessionCreation(e.to_string()))?;
|
||||
|
||||
let inputs = session.inputs();
|
||||
let outputs = session.outputs();
|
||||
|
||||
let input_name = inputs
|
||||
.first()
|
||||
.map(|i| i.name().to_string())
|
||||
.unwrap_or_else(|| "input".to_string());
|
||||
|
||||
let output_name = outputs
|
||||
.first()
|
||||
.map(|o| o.name().to_string())
|
||||
.unwrap_or_else(|| "embedding".to_string());
|
||||
|
||||
debug!(
|
||||
input = %input_name,
|
||||
output = %output_name,
|
||||
gpu = gpu_enabled,
|
||||
"ONNX session created"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
session: Mutex::new(session),
|
||||
gpu_enabled: AtomicBool::new(gpu_enabled),
|
||||
input_name,
|
||||
output_name,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run inference on a single spectrogram.
|
||||
#[instrument(skip(self, input))]
|
||||
pub fn run(&self, input: &Array3<f32>) -> Result<Array1<f32>, InferenceError> {
|
||||
let shape = input.shape();
|
||||
if shape[1] != MEL_FRAMES || shape[2] != MEL_BINS {
|
||||
return Err(InferenceError::InvalidDimensions {
|
||||
expected: vec![1, MEL_FRAMES, MEL_BINS],
|
||||
actual: shape.to_vec(),
|
||||
});
|
||||
}
|
||||
|
||||
// Create input tensor using ort 2.0 API with shape tuple
|
||||
let input_vec: Vec<f32> = input.iter().cloned().collect();
|
||||
let tensor_shape = vec![1i64, MEL_FRAMES as i64, MEL_BINS as i64];
|
||||
let input_tensor = ort::value::Tensor::from_array((tensor_shape, input_vec))
|
||||
.map_err(|e| InferenceError::InputTensor(e.to_string()))?;
|
||||
|
||||
// Run inference (lock session for mutable access required by ort 2.0)
|
||||
let inputs = ort::inputs![&self.input_name => input_tensor];
|
||||
let mut session = self.session.lock()
|
||||
.map_err(|e| InferenceError::Execution(format!("Lock error: {}", e)))?;
|
||||
let outputs = session
|
||||
.run(inputs)
|
||||
.map_err(|e| InferenceError::Execution(e.to_string()))?;
|
||||
|
||||
// Extract embedding output
|
||||
let output = outputs
|
||||
.get(&self.output_name)
|
||||
.ok_or_else(|| InferenceError::OutputExtraction("No output found".to_string()))?;
|
||||
|
||||
// Extract tensor data using ort 2.0 API - returns (&Shape, &[f32]) tuple
|
||||
let (_shape, flat_slice) = output
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| InferenceError::OutputExtraction(e.to_string()))?;
|
||||
|
||||
// Handle different output shapes
|
||||
let embedding_data: Vec<f32> = if flat_slice.len() == EMBEDDING_DIM {
|
||||
flat_slice.to_vec()
|
||||
} else if flat_slice.len() > EMBEDDING_DIM {
|
||||
flat_slice[..EMBEDDING_DIM].to_vec()
|
||||
} else {
|
||||
return Err(InferenceError::OutputExtraction(format!(
|
||||
"Unexpected embedding size: {} (expected {})",
|
||||
flat_slice.len(),
|
||||
EMBEDDING_DIM
|
||||
)));
|
||||
};
|
||||
|
||||
debug!("Inference completed");
|
||||
Ok(Array1::from_vec(embedding_data))
|
||||
}
|
||||
|
||||
/// Run inference on a batch of spectrograms.
|
||||
#[instrument(skip(self, inputs), fields(batch_size = inputs.len()))]
|
||||
pub fn run_batch(&self, inputs: &[&Array3<f32>]) -> Result<Vec<Array1<f32>>, InferenceError> {
|
||||
if inputs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let batch_size = inputs.len();
|
||||
|
||||
for input in inputs.iter() {
|
||||
let shape = input.shape();
|
||||
if shape[1] != MEL_FRAMES || shape[2] != MEL_BINS {
|
||||
return Err(InferenceError::InvalidDimensions {
|
||||
expected: vec![1, MEL_FRAMES, MEL_BINS],
|
||||
actual: shape.to_vec(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Stack inputs into a batch tensor
|
||||
let mut batch_data = Vec::with_capacity(batch_size * MEL_FRAMES * MEL_BINS);
|
||||
for input in inputs {
|
||||
let view = input.view();
|
||||
for frame in 0..MEL_FRAMES {
|
||||
for bin in 0..MEL_BINS {
|
||||
batch_data.push(view[[0, frame, bin]]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create batch tensor using shape tuple API
|
||||
let tensor_shape = vec![batch_size as i64, MEL_FRAMES as i64, MEL_BINS as i64];
|
||||
let input_tensor = ort::value::Tensor::from_array((tensor_shape, batch_data))
|
||||
.map_err(|e| InferenceError::InputTensor(e.to_string()))?;
|
||||
|
||||
// Run inference (lock session for mutable access required by ort 2.0)
|
||||
let ort_inputs = ort::inputs![&self.input_name => input_tensor];
|
||||
let mut session = self.session.lock()
|
||||
.map_err(|e| InferenceError::Execution(format!("Lock error: {}", e)))?;
|
||||
let outputs = session
|
||||
.run(ort_inputs)
|
||||
.map_err(|e| InferenceError::Execution(e.to_string()))?;
|
||||
|
||||
// Extract embeddings using ort 2.0 API - returns (&Shape, &[f32]) tuple
|
||||
let output = outputs
|
||||
.get(&self.output_name)
|
||||
.ok_or_else(|| InferenceError::OutputExtraction("No output found".to_string()))?;
|
||||
|
||||
let (_shape, flat_slice) = output
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| InferenceError::OutputExtraction(e.to_string()))?;
|
||||
|
||||
// Split into individual embeddings
|
||||
let total_expected = batch_size * EMBEDDING_DIM;
|
||||
if flat_slice.len() < total_expected {
|
||||
return Err(InferenceError::OutputExtraction(format!(
|
||||
"Unexpected output size: {} (expected at least {})",
|
||||
flat_slice.len(),
|
||||
total_expected
|
||||
)));
|
||||
}
|
||||
|
||||
let result: Vec<Array1<f32>> = (0..batch_size)
|
||||
.map(|i| {
|
||||
let start = i * EMBEDDING_DIM;
|
||||
let end = start + EMBEDDING_DIM;
|
||||
Array1::from_vec(flat_slice[start..end].to_vec())
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!(batch_size = batch_size, "Batch inference completed");
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Check if GPU is being used for inference.
|
||||
#[must_use]
|
||||
pub fn is_gpu(&self) -> bool {
|
||||
self.gpu_enabled.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get the input name expected by the model.
|
||||
#[must_use]
|
||||
pub fn input_name(&self) -> &str {
|
||||
&self.input_name
|
||||
}
|
||||
|
||||
/// Get the output name for embeddings.
|
||||
#[must_use]
|
||||
pub fn output_name(&self) -> &str {
|
||||
&self.output_name
|
||||
}
|
||||
|
||||
/// Get information about the model's expected input shape.
|
||||
#[must_use]
|
||||
pub fn input_info(&self) -> Option<InputInfo> {
|
||||
self.session.lock().ok().and_then(|session| {
|
||||
session.inputs().first().map(|input| InputInfo {
|
||||
name: input.name().to_string(),
|
||||
dimensions: Vec::new(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Get information about the model's output shape.
|
||||
#[must_use]
|
||||
pub fn output_info(&self) -> Option<OutputInfo> {
|
||||
self.session.lock().ok().and_then(|session| {
|
||||
session.outputs().first().map(|output| OutputInfo {
|
||||
name: output.name().to_string(),
|
||||
dimensions: Vec::new(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about model input
#[derive(Debug, Clone)]
pub struct InputInfo {
    /// Input tensor name
    pub name: String,
    /// Expected dimensions (currently left empty by `input_info`)
    pub dimensions: Vec<usize>,
}
|
||||
|
||||
/// Information about model output
#[derive(Debug, Clone)]
pub struct OutputInfo {
    /// Output tensor name
    pub name: String,
    /// Output dimensions (currently left empty by `output_info`)
    pub dimensions: Vec<usize>,
}
|
||||
|
||||
// Manual Debug impl: the wrapped ort `Session` is omitted; only the
// summary fields (GPU flag and tensor names) are reported.
impl std::fmt::Debug for OnnxInference {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OnnxInference")
            .field("gpu_enabled", &self.is_gpu())
            .field("input_name", &self.input_name)
            .field("output_name", &self.output_name)
            .finish()
    }
}
|
||||
|
||||
/// Configuration for ONNX inference
///
/// NOTE(review): `OnnxInference::new` consumes only the thread counts
/// and providers; `optimize_memory` and `max_batch_size` are not read
/// anywhere in this module — confirm their consumers elsewhere.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Number of threads for intra-op parallelism
    pub intra_op_threads: usize,
    /// Number of threads for inter-op parallelism
    pub inter_op_threads: usize,
    /// Execution providers in priority order
    pub providers: Vec<ExecutionProvider>,
    /// Whether to enable memory optimization
    pub optimize_memory: bool,
    /// Maximum batch size for inference
    pub max_batch_size: usize,
}
|
||||
|
||||
impl Default for InferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
intra_op_threads: num_cpus::get().min(4),
|
||||
inter_op_threads: 1,
|
||||
providers: vec![
|
||||
ExecutionProvider::Cuda { device_id: 0 },
|
||||
ExecutionProvider::CoreML,
|
||||
ExecutionProvider::Cpu,
|
||||
],
|
||||
optimize_memory: true,
|
||||
max_batch_size: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InferenceConfig {
|
||||
/// Configuration optimized for field devices
|
||||
#[must_use]
|
||||
pub fn field_device() -> Self {
|
||||
Self {
|
||||
intra_op_threads: 2,
|
||||
inter_op_threads: 1,
|
||||
providers: vec![ExecutionProvider::Cpu],
|
||||
optimize_memory: true,
|
||||
max_batch_size: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration optimized for server deployment
|
||||
#[must_use]
|
||||
pub fn server() -> Self {
|
||||
Self {
|
||||
intra_op_threads: 4,
|
||||
inter_op_threads: 2,
|
||||
providers: vec![
|
||||
ExecutionProvider::Cuda { device_id: 0 },
|
||||
ExecutionProvider::Cpu,
|
||||
],
|
||||
optimize_memory: false,
|
||||
max_batch_size: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must provide at least one thread and one provider.
    #[test]
    fn test_inference_config_default() {
        let config = InferenceConfig::default();
        assert!(config.intra_op_threads > 0);
        assert!(!config.providers.is_empty());
    }

    // Field-device preset: minimal threads, single-item batches.
    #[test]
    fn test_inference_config_field_device() {
        let config = InferenceConfig::field_device();
        assert_eq!(config.intra_op_threads, 2);
        assert_eq!(config.max_batch_size, 1);
        assert!(config.optimize_memory);
    }

    // Server preset: large batches with memory optimization off.
    #[test]
    fn test_inference_config_server() {
        let config = InferenceConfig::server();
        assert_eq!(config.max_batch_size, 64);
        assert!(!config.optimize_memory);
    }

    // Sanity-check the shape constants the input validation relies on.
    #[test]
    fn test_input_validation() {
        let valid_shape = vec![1, MEL_FRAMES, MEL_BINS];
        let invalid_shape = vec![1, 100, 100];

        assert_eq!(valid_shape[1], MEL_FRAMES);
        assert_eq!(valid_shape[2], MEL_BINS);
        assert_ne!(invalid_shape[1], MEL_FRAMES);
    }
}
|
||||
143
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/lib.rs
vendored
Normal file
143
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/lib.rs
vendored
Normal file
@@ -0,0 +1,143 @@
|
||||
//! # sevensense-embedding
|
||||
//!
|
||||
//! Embedding bounded context for 7sense bioacoustics platform.
|
||||
//!
|
||||
//! This crate provides Perch 2.0 ONNX integration for generating 1536-dimensional
|
||||
//! embeddings from preprocessed audio segments. It handles model loading, inference,
|
||||
//! normalization, and quantization for efficient storage and retrieval.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The crate follows Domain-Driven Design (DDD) principles:
|
||||
//!
|
||||
//! - **Domain Layer**: Core entities (`Embedding`, `EmbeddingModel`) and repository traits
|
||||
//! - **Application Layer**: Services for embedding generation and batch processing
|
||||
//! - **Infrastructure Layer**: ONNX Runtime integration and model management
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use sevensense_embedding::{
|
||||
//! EmbeddingService, ModelManager, ModelConfig,
|
||||
//! domain::Embedding,
|
||||
//! };
|
||||
//!
|
||||
//! // Initialize model manager
|
||||
//! let config = ModelConfig::default();
|
||||
//! let model_manager = ModelManager::new(config)?;
|
||||
//!
|
||||
//! // Create embedding service
|
||||
//! let service = EmbeddingService::new(model_manager, 8);
|
||||
//!
|
||||
//! // Generate embedding from spectrogram
|
||||
//! let embedding = service.embed_segment(&spectrogram).await?;
|
||||
//! ```
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Perch 2.0 Integration**: Full support for EfficientNet-B3 bioacoustic embeddings
|
||||
//! - **Batch Processing**: Efficient batch inference with configurable batch sizes
|
||||
//! - **Model Hot-Swap**: Update models without service restart
|
||||
//! - **Quantization**: F16 and INT8 quantization for reduced storage
|
||||
//! - **Validation**: Comprehensive embedding validation (NaN detection, dimension checks)
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![warn(clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
pub mod domain;
|
||||
pub mod application;
|
||||
pub mod infrastructure;
|
||||
pub mod normalization;
|
||||
pub mod quantization;
|
||||
|
||||
// Re-export main types for convenience
|
||||
pub use domain::entities::{
|
||||
Embedding, EmbeddingId, EmbeddingModel, EmbeddingMetadata,
|
||||
StorageTier, ModelVersion, InputSpecification,
|
||||
};
|
||||
pub use domain::repository::EmbeddingRepository;
|
||||
pub use application::services::EmbeddingService;
|
||||
pub use infrastructure::model_manager::{ModelManager, ModelConfig};
|
||||
pub use infrastructure::onnx_inference::OnnxInference;
|
||||
|
||||
/// Embedding dimension for Perch 2.0 model
pub const EMBEDDING_DIM: usize = 1536;

/// Target sample rate for Perch 2.0 (32kHz)
pub const TARGET_SAMPLE_RATE: u32 = 32000;

/// Target window duration in seconds (5s)
pub const TARGET_WINDOW_SECONDS: f32 = 5.0;

/// Target window samples (160,000 = 5s at 32kHz)
pub const TARGET_WINDOW_SAMPLES: usize = 160_000;

/// Mel spectrogram bins for Perch 2.0
pub const MEL_BINS: usize = 128;

/// Mel spectrogram frames for Perch 2.0
pub const MEL_FRAMES: usize = 500;

/// Crate version information, taken from Cargo metadata at compile time
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Common result type for embedding operations, using [`EmbeddingError`]
pub type Result<T> = std::result::Result<T, EmbeddingError>;
|
||||
|
||||
/// Unified error type for embedding operations
///
/// Aggregates every failure mode of the crate behind one enum so the
/// crate-level [`Result`] alias can be used throughout; model, inference,
/// and IO errors convert automatically via `#[from]`.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// Model loading or initialization error
    #[error("Model error: {0}")]
    Model(#[from] infrastructure::model_manager::ModelError),

    /// ONNX inference error
    #[error("Inference error: {0}")]
    Inference(#[from] infrastructure::onnx_inference::InferenceError),

    /// Embedding validation error
    #[error("Validation error: {0}")]
    Validation(String),

    /// Invalid input dimensions
    #[error("Invalid dimensions: expected {expected}, got {actual}")]
    InvalidDimensions {
        /// Expected dimension
        expected: usize,
        /// Actual dimension
        actual: usize,
    },

    /// Repository error
    #[error("Repository error: {0}")]
    Repository(String),

    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Checksum verification failed
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumMismatch {
        /// Expected checksum
        expected: String,
        /// Actual checksum
        actual: String,
    },
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Pin the Perch 2.0 model constants so accidental edits are caught.
    #[test]
    fn test_constants() {
        assert_eq!(EMBEDDING_DIM, 1536);
        assert_eq!(TARGET_SAMPLE_RATE, 32000);
        assert_eq!(TARGET_WINDOW_SAMPLES, 160_000);
        assert_eq!(MEL_BINS, 128);
        assert_eq!(MEL_FRAMES, 500);
    }
}
|
||||
460
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/normalization.rs
vendored
Normal file
460
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/normalization.rs
vendored
Normal file
@@ -0,0 +1,460 @@
|
||||
//! Normalization utilities for embedding vectors.
|
||||
//!
|
||||
//! Provides L2 normalization and validation functions to ensure
|
||||
//! embedding vectors are properly normalized for cosine similarity
|
||||
//! operations in vector databases.
|
||||
|
||||
use crate::EMBEDDING_DIM;
|
||||
|
||||
/// L2 normalize an embedding vector in-place.
|
||||
///
|
||||
/// After normalization, the vector will have unit length (norm = 1.0),
|
||||
/// enabling cosine similarity to be computed as a simple dot product.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `embedding` - The embedding vector to normalize in-place
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use sevensense_embedding::normalization::l2_normalize;
|
||||
///
|
||||
/// let mut vector = vec![3.0, 4.0];
|
||||
/// l2_normalize(&mut vector);
|
||||
/// assert!((vector[0] - 0.6).abs() < 1e-6);
|
||||
/// assert!((vector[1] - 0.8).abs() < 1e-6);
|
||||
/// ```
|
||||
pub fn l2_normalize(embedding: &mut [f32]) {
|
||||
let norm = compute_norm(embedding);
|
||||
|
||||
if norm > 1e-12 {
|
||||
for x in embedding.iter_mut() {
|
||||
*x /= norm;
|
||||
}
|
||||
} else {
|
||||
// Handle near-zero embeddings (likely silent input)
|
||||
// Set to unit vector in first dimension
|
||||
embedding.iter_mut().for_each(|x| *x = 0.0);
|
||||
if !embedding.is_empty() {
|
||||
embedding[0] = 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the L2 norm of a vector.
///
/// # Arguments
///
/// * `vector` - The vector to compute the norm for
///
/// # Returns
///
/// The L2 norm (Euclidean length); `0.0` for an empty slice.
#[must_use]
pub fn compute_norm(vector: &[f32]) -> f32 {
    let sum_of_squares = vector.iter().fold(0.0f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
|
||||
|
||||
/// Compute the sparsity of a vector.
///
/// Sparsity is the fraction of near-zero values (|x| < 1e-6); a very
/// sparse embedding may indicate a problem with the embedding model.
///
/// # Arguments
///
/// * `vector` - The vector to analyze
///
/// # Returns
///
/// Sparsity in `[0.0, 1.0]`; an empty vector reports `0.0`.
#[must_use]
pub fn compute_sparsity(vector: &[f32]) -> f32 {
    if vector.is_empty() {
        return 0.0;
    }

    let zeros = vector.iter().filter(|v| v.abs() < 1e-6).count();
    zeros as f32 / vector.len() as f32
}
|
||||
|
||||
/// Validate an embedding vector for common issues.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `embedding` - The embedding vector to validate
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `ValidationResult` containing detailed information about the vector.
|
||||
#[must_use]
|
||||
pub fn validate_embedding(embedding: &[f32]) -> ValidationResult {
|
||||
let dimension_valid = embedding.len() == EMBEDDING_DIM;
|
||||
let has_nan = embedding.iter().any(|x| x.is_nan());
|
||||
let has_inf = embedding.iter().any(|x| x.is_infinite());
|
||||
let norm = compute_norm(embedding);
|
||||
let is_normalized = (0.99..=1.01).contains(&norm);
|
||||
let sparsity = compute_sparsity(embedding);
|
||||
|
||||
let issues = collect_issues(
|
||||
dimension_valid,
|
||||
embedding.len(),
|
||||
has_nan,
|
||||
has_inf,
|
||||
is_normalized,
|
||||
norm,
|
||||
sparsity,
|
||||
);
|
||||
|
||||
ValidationResult {
|
||||
dimension: embedding.len(),
|
||||
dimension_valid,
|
||||
norm,
|
||||
is_normalized,
|
||||
has_nan,
|
||||
has_inf,
|
||||
sparsity,
|
||||
is_valid: dimension_valid && !has_nan && !has_inf,
|
||||
issues,
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_issues(
|
||||
dimension_valid: bool,
|
||||
actual_dim: usize,
|
||||
has_nan: bool,
|
||||
has_inf: bool,
|
||||
is_normalized: bool,
|
||||
norm: f32,
|
||||
sparsity: f32,
|
||||
) -> Vec<ValidationIssue> {
|
||||
let mut issues = Vec::new();
|
||||
|
||||
if !dimension_valid {
|
||||
issues.push(ValidationIssue::InvalidDimension {
|
||||
expected: EMBEDDING_DIM,
|
||||
actual: actual_dim,
|
||||
});
|
||||
}
|
||||
|
||||
if has_nan {
|
||||
issues.push(ValidationIssue::ContainsNaN);
|
||||
}
|
||||
|
||||
if has_inf {
|
||||
issues.push(ValidationIssue::ContainsInfinite);
|
||||
}
|
||||
|
||||
if !is_normalized && !has_nan && !has_inf {
|
||||
issues.push(ValidationIssue::NotNormalized { norm });
|
||||
}
|
||||
|
||||
if sparsity > 0.9 {
|
||||
issues.push(ValidationIssue::HighSparsity { sparsity });
|
||||
}
|
||||
|
||||
issues
|
||||
}
|
||||
|
||||
/// Result of embedding validation
///
/// Note that `is_valid` only covers hard failures (dimension, NaN, Inf);
/// normalization and sparsity problems are reported in `issues` only.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Actual dimension of the embedding
    pub dimension: usize,

    /// Whether the dimension matches expected (1536)
    pub dimension_valid: bool,

    /// L2 norm of the embedding
    pub norm: f32,

    /// Whether the embedding is L2 normalized (norm within 0.99..=1.01)
    pub is_normalized: bool,

    /// Whether the embedding contains NaN values
    pub has_nan: bool,

    /// Whether the embedding contains infinite values
    pub has_inf: bool,

    /// Fraction of near-zero values (|x| < 1e-6)
    pub sparsity: f32,

    /// Overall validity (no NaN, no Inf, correct dimension)
    pub is_valid: bool,

    /// List of specific issues found
    pub issues: Vec<ValidationIssue>,
}
|
||||
|
||||
impl ValidationResult {
|
||||
/// Check if the embedding passes all validation checks
|
||||
#[must_use]
|
||||
pub fn is_ok(&self) -> bool {
|
||||
self.issues.is_empty()
|
||||
}
|
||||
|
||||
/// Get a human-readable summary of the validation result
|
||||
#[must_use]
|
||||
pub fn summary(&self) -> String {
|
||||
if self.issues.is_empty() {
|
||||
return "Embedding is valid".to_string();
|
||||
}
|
||||
|
||||
let issue_strings: Vec<String> = self.issues.iter().map(|i| i.to_string()).collect();
|
||||
format!("Embedding has issues: {}", issue_strings.join(", "))
|
||||
}
|
||||
}
|
||||
|
||||
/// Specific validation issues that can be detected
///
/// Produced by `validate_embedding`; rendered via `Display` into the
/// summary string of a `ValidationResult`.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// Embedding dimension doesn't match expected
    InvalidDimension {
        /// Expected dimension
        expected: usize,
        /// Actual dimension
        actual: usize,
    },

    /// Embedding contains NaN values
    ContainsNaN,

    /// Embedding contains infinite values
    ContainsInfinite,

    /// Embedding is not L2 normalized
    NotNormalized {
        /// Actual norm
        norm: f32,
    },

    /// Embedding has high sparsity (more than 90% near-zero values)
    HighSparsity {
        /// Sparsity value
        sparsity: f32,
    },
}
|
||||
|
||||
// Lowercase, comma-friendly phrases: `ValidationResult::summary` joins
// these with ", " into a single diagnostic line.
impl std::fmt::Display for ValidationIssue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidDimension { expected, actual } => {
                write!(f, "invalid dimension (expected {expected}, got {actual})")
            }
            Self::ContainsNaN => write!(f, "contains NaN values"),
            Self::ContainsInfinite => write!(f, "contains infinite values"),
            Self::NotNormalized { norm } => {
                write!(f, "not normalized (norm = {norm:.4})")
            }
            Self::HighSparsity { sparsity } => {
                write!(f, "high sparsity ({:.1}%)", sparsity * 100.0)
            }
        }
    }
}
|
||||
|
||||
/// L1 normalize a vector in-place (sum of absolute values = 1).
///
/// A (near-)zero vector is left untouched, since it cannot be scaled
/// to unit L1 mass.
pub fn l1_normalize(embedding: &mut [f32]) {
    let total: f32 = embedding.iter().map(|x| x.abs()).sum();
    if total > 1e-12 {
        embedding.iter_mut().for_each(|x| *x /= total);
    }
}
|
||||
|
||||
/// Min-max normalize a vector in-place to the [0, 1] range.
///
/// A constant vector (zero range) maps every element to 0.5; an empty
/// slice is a no-op.
pub fn minmax_normalize(embedding: &mut [f32]) {
    if embedding.is_empty() {
        return;
    }

    // Single pass to find both extremes.
    let (mut lo, mut hi) = (f32::INFINITY, f32::NEG_INFINITY);
    for &x in embedding.iter() {
        lo = lo.min(x);
        hi = hi.max(x);
    }

    let range = hi - lo;
    if range > 1e-12 {
        embedding.iter_mut().for_each(|x| *x = (*x - lo) / range);
    } else {
        // All values identical: map them to the midpoint.
        embedding.iter_mut().for_each(|x| *x = 0.5);
    }
}
|
||||
|
||||
/// Z-score normalize a vector in-place (mean = 0, std = 1).
///
/// Uses the population standard deviation (divide by n). A constant
/// vector (zero variance) is centered to all zeros instead.
pub fn zscore_normalize(embedding: &mut [f32]) {
    if embedding.is_empty() {
        return;
    }

    let n = embedding.len() as f32;
    let mean = embedding.iter().sum::<f32>() / n;
    let variance = embedding.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let std_dev = variance.sqrt();

    if std_dev <= 1e-12 {
        // Degenerate (constant) input: everything sits at the mean.
        embedding.iter_mut().for_each(|x| *x = 0.0);
        return;
    }

    for x in embedding.iter_mut() {
        *x = (*x - mean) / std_dev;
    }
}
|
||||
|
||||
/// Restrict every element to the inclusive range [min, max].
pub fn clamp(embedding: &mut [f32], min: f32, max: f32) {
    embedding.iter_mut().for_each(|x| *x = x.clamp(min, max));
}
|
||||
|
||||
/// Smoothly squash elements toward the open range (-scale, scale) with tanh.
///
/// Each element x becomes `scale * tanh(x / scale)`: values well inside the
/// range pass through nearly unchanged, large magnitudes saturate at ±scale.
pub fn soft_clip(embedding: &mut [f32], scale: f32) {
    embedding
        .iter_mut()
        .for_each(|x| *x = (*x / scale).tanh() * scale);
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Unit tests for the normalization helpers. These pin exact numeric
    // behavior (including zero-vector and degenerate-range edge cases),
    // so their assertions should not be loosened casually.
    use super::*;

    #[test]
    fn test_l2_normalize() {
        // 3-4-5 triangle: normalizing [3, 4] yields [0.6, 0.8].
        let mut vector = vec![3.0, 4.0];
        l2_normalize(&mut vector);
        assert!((vector[0] - 0.6).abs() < 1e-6);
        assert!((vector[1] - 0.8).abs() < 1e-6);

        let norm = compute_norm(&vector);
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_normalize_zero_vector() {
        // Zero vectors are mapped to the first basis vector rather than
        // left as all-zero (behavior pinned here; see l2_normalize).
        let mut vector = vec![0.0, 0.0, 0.0];
        l2_normalize(&mut vector);
        assert_eq!(vector[0], 1.0);
        assert_eq!(vector[1], 0.0);
        assert_eq!(vector[2], 0.0);
    }

    #[test]
    fn test_compute_norm() {
        let vector = vec![3.0, 4.0];
        let norm = compute_norm(&vector);
        assert!((norm - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_compute_sparsity() {
        // 3 of 5 entries are zero -> sparsity 0.6.
        let vector = vec![0.0, 1.0, 0.0, 2.0, 0.0];
        let sparsity = compute_sparsity(&vector);
        assert!((sparsity - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_compute_sparsity_empty() {
        // Empty input is defined as fully dense (0.0), not NaN.
        let vector: Vec<f32> = vec![];
        let sparsity = compute_sparsity(&vector);
        assert_eq!(sparsity, 0.0);
    }

    #[test]
    fn test_validate_embedding_valid() {
        // A unit basis vector of the expected dimension is fully valid.
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = 1.0;
        let result = validate_embedding(&vector);
        assert!(result.is_valid);
        assert!(result.is_normalized);
        assert!(!result.has_nan);
        assert!(!result.has_inf);
    }

    #[test]
    fn test_validate_embedding_wrong_dimension() {
        let vector = vec![1.0; 100];
        let result = validate_embedding(&vector);
        assert!(!result.dimension_valid);
        assert!(result.issues.iter().any(|i| matches!(i, ValidationIssue::InvalidDimension { .. })));
    }

    #[test]
    fn test_validate_embedding_nan() {
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = f32::NAN;
        let result = validate_embedding(&vector);
        assert!(result.has_nan);
        assert!(!result.is_valid);
    }

    #[test]
    fn test_validate_embedding_infinite() {
        let mut vector = vec![0.0; EMBEDDING_DIM];
        vector[0] = f32::INFINITY;
        let result = validate_embedding(&vector);
        assert!(result.has_inf);
        assert!(!result.is_valid);
    }

    #[test]
    fn test_l1_normalize() {
        let mut vector = vec![1.0, 2.0, 3.0];
        l1_normalize(&mut vector);
        let sum: f32 = vector.iter().map(|x| x.abs()).sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_minmax_normalize() {
        let mut vector = vec![0.0, 5.0, 10.0];
        minmax_normalize(&mut vector);
        assert!((vector[0] - 0.0).abs() < 1e-6);
        assert!((vector[1] - 0.5).abs() < 1e-6);
        assert!((vector[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_zscore_normalize() {
        let mut vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        zscore_normalize(&mut vector);
        let mean: f32 = vector.iter().sum::<f32>() / vector.len() as f32;
        assert!(mean.abs() < 1e-6);
    }

    #[test]
    fn test_clamp() {
        let mut vector = vec![-2.0, 0.5, 2.0];
        clamp(&mut vector, -1.0, 1.0);
        assert_eq!(vector, vec![-1.0, 0.5, 1.0]);
    }

    #[test]
    fn test_soft_clip() {
        let mut vector = vec![0.0, 1.0, 2.0];
        soft_clip(&mut vector, 1.0);
        assert!((vector[0] - 0.0).abs() < 1e-6);
        // tanh(1) ≈ 0.7616
        assert!(vector[1] > 0.5 && vector[1] < 0.8);
        // tanh(2) ≈ 0.964
        assert!(vector[2] > 0.9 && vector[2] < 1.0);
    }

    #[test]
    fn test_validation_result_summary() {
        // Create a reasonably distributed embedding (not too sparse)
        let mut vector = vec![0.0; EMBEDDING_DIM];
        // Fill first half with small values that sum to norm 1.0
        let val = 1.0 / (EMBEDDING_DIM as f32 / 2.0).sqrt();
        for i in 0..EMBEDDING_DIM / 2 {
            vector[i] = val;
        }
        let result = validate_embedding(&vector);
        assert!(result.summary().contains("valid"), "Summary: {}", result.summary());
    }
}
|
||||
562
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/quantization.rs
vendored
Normal file
562
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-embedding/src/quantization.rs
vendored
Normal file
@@ -0,0 +1,562 @@
|
||||
//! Quantization utilities for embedding storage optimization.
|
||||
//!
|
||||
//! Provides F16 and INT8 quantization for reduced storage footprint
|
||||
//! while maintaining acceptable precision for similarity search.
|
||||
//!
|
||||
//! ## Storage Comparison
|
||||
//!
|
||||
//! | Format | Bytes/Dim | Total (1536-D) | Precision Loss |
|
||||
//! |--------|-----------|----------------|----------------|
|
||||
//! | f32 | 4 | 6,144 bytes | None (baseline)|
|
||||
//! | f16 | 2 | 3,072 bytes | ~0.1% typical |
|
||||
//! | i8 | 1 | 1,536 bytes | ~1-2% typical |
|
||||
|
||||
use half::f16;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Quantize f32 embedding to f16 (half precision).
|
||||
///
|
||||
/// F16 provides 50% storage reduction with minimal precision loss.
|
||||
/// Suitable for warm storage tier.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `embedding` - The f32 embedding vector to quantize
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Vector of f16 values
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use sevensense_embedding::quantization::quantize_to_f16;
|
||||
///
|
||||
/// let embedding = vec![0.5, -0.3, 0.8];
|
||||
/// let quantized = quantize_to_f16(&embedding);
|
||||
/// assert_eq!(quantized.len(), embedding.len());
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn quantize_to_f16(embedding: &[f32]) -> Vec<f16> {
|
||||
embedding.iter().map(|&x| f16::from_f32(x)).collect()
|
||||
}
|
||||
|
||||
/// Dequantize f16 embedding back to f32.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `quantized` - The f16 quantized embedding
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Vector of f32 values
|
||||
#[must_use]
|
||||
pub fn dequantize_f16(quantized: &[f16]) -> Vec<f32> {
|
||||
quantized.iter().map(|&x| x.to_f32()).collect()
|
||||
}
|
||||
|
||||
/// Quantization parameters for INT8 quantization
///
/// Produced by [`QuantizationParams::from_data`] (symmetric) or
/// [`QuantizationParams::from_data_asymmetric`], and consumed when
/// dequantizing values back to f32.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Scale factor for quantization
    pub scale: f32,
    /// Zero point for asymmetric quantization
    /// (0.0 when the parameters were computed symmetrically)
    pub zero_point: f32,
    /// Minimum value in the original data
    pub min_val: f32,
    /// Maximum value in the original data
    pub max_val: f32,
}
|
||||
|
||||
impl QuantizationParams {
|
||||
/// Compute quantization parameters from data
|
||||
#[must_use]
|
||||
pub fn from_data(data: &[f32]) -> Self {
|
||||
if data.is_empty() {
|
||||
return Self {
|
||||
scale: 1.0,
|
||||
zero_point: 0.0,
|
||||
min_val: 0.0,
|
||||
max_val: 0.0,
|
||||
};
|
||||
}
|
||||
|
||||
let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
|
||||
let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
|
||||
// For symmetric quantization around zero (better for L2-normalized embeddings)
|
||||
let abs_max = min_val.abs().max(max_val.abs());
|
||||
let scale = if abs_max > 1e-12 {
|
||||
abs_max / 127.0
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
Self {
|
||||
scale,
|
||||
zero_point: 0.0, // Symmetric quantization
|
||||
min_val,
|
||||
max_val,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute quantization parameters for asymmetric quantization
|
||||
#[must_use]
|
||||
pub fn from_data_asymmetric(data: &[f32]) -> Self {
|
||||
if data.is_empty() {
|
||||
return Self {
|
||||
scale: 1.0,
|
||||
zero_point: 0.0,
|
||||
min_val: 0.0,
|
||||
max_val: 0.0,
|
||||
};
|
||||
}
|
||||
|
||||
let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
|
||||
let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
|
||||
let range = max_val - min_val;
|
||||
let scale = if range > 1e-12 {
|
||||
range / 255.0
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
// Zero point maps min_val to 0 in quantized space
|
||||
let zero_point = -min_val / scale;
|
||||
|
||||
Self {
|
||||
scale,
|
||||
zero_point,
|
||||
min_val,
|
||||
max_val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of INT8 quantization including the quantized values and parameters
///
/// Pairing the values with their [`QuantizationParams`] makes the struct
/// self-describing: `dequantize` needs no external state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedEmbedding {
    /// Quantized INT8 values
    pub values: Vec<i8>,
    /// Quantization parameters for dequantization
    pub params: QuantizationParams,
}
|
||||
|
||||
impl QuantizedEmbedding {
|
||||
/// Get the storage size in bytes
|
||||
#[must_use]
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
// i8 values + params overhead
|
||||
self.values.len() + std::mem::size_of::<QuantizationParams>()
|
||||
}
|
||||
|
||||
/// Dequantize back to f32
|
||||
#[must_use]
|
||||
pub fn dequantize(&self) -> Vec<f32> {
|
||||
dequantize_i8(&self.values, self.params.scale, self.params.zero_point)
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantize f32 embedding to INT8 with scale and zero point.
|
||||
///
|
||||
/// INT8 provides 75% storage reduction but with some precision loss.
|
||||
/// Suitable for cold storage tier or large-scale deployments.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `embedding` - The f32 embedding vector to quantize
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Tuple of (quantized values, scale, zero_point)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use sevensense_embedding::quantization::quantize_to_i8;
|
||||
///
|
||||
/// let embedding = vec![0.5, -0.3, 0.8, -0.1];
|
||||
/// let (quantized, scale, zero_point) = quantize_to_i8(&embedding);
|
||||
/// assert_eq!(quantized.len(), embedding.len());
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn quantize_to_i8(embedding: &[f32]) -> (Vec<i8>, f32, f32) {
|
||||
let params = QuantizationParams::from_data(embedding);
|
||||
|
||||
let quantized: Vec<i8> = embedding
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
let q = (x / params.scale).round();
|
||||
q.clamp(-128.0, 127.0) as i8
|
||||
})
|
||||
.collect();
|
||||
|
||||
(quantized, params.scale, params.zero_point)
|
||||
}
|
||||
|
||||
/// Quantize f32 embedding to INT8 with full quantization info.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `embedding` - The f32 embedding vector to quantize
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// QuantizedEmbedding containing values and parameters
|
||||
#[must_use]
|
||||
pub fn quantize_to_i8_full(embedding: &[f32]) -> QuantizedEmbedding {
|
||||
let params = QuantizationParams::from_data(embedding);
|
||||
|
||||
let values: Vec<i8> = embedding
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
let q = (x / params.scale).round();
|
||||
q.clamp(-128.0, 127.0) as i8
|
||||
})
|
||||
.collect();
|
||||
|
||||
QuantizedEmbedding { values, params }
|
||||
}
|
||||
|
||||
/// Map signed INT8 values back to f32 using the quantization parameters.
///
/// Each value `q` is reconstructed as `(q - zero_point) * scale`, the
/// inverse of the quantization mapping (up to rounding error).
///
/// # Arguments
///
/// * `quantized` - The INT8 quantized values
/// * `scale` - Scale factor used during quantization
/// * `zero_point` - Zero point used during quantization
///
/// # Returns
///
/// Vector of reconstructed f32 values.
#[must_use]
pub fn dequantize_i8(quantized: &[i8], scale: f32, zero_point: f32) -> Vec<f32> {
    quantized
        .iter()
        .map(|&q| (f32::from(q) - zero_point) * scale)
        .collect()
}
|
||||
|
||||
/// Quantize to unsigned INT8 (0-255 range) for asymmetric quantization
|
||||
#[must_use]
|
||||
pub fn quantize_to_u8(embedding: &[f32]) -> (Vec<u8>, f32, f32) {
|
||||
let params = QuantizationParams::from_data_asymmetric(embedding);
|
||||
|
||||
let quantized: Vec<u8> = embedding
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
let q = (x / params.scale + params.zero_point).round();
|
||||
q.clamp(0.0, 255.0) as u8
|
||||
})
|
||||
.collect();
|
||||
|
||||
(quantized, params.scale, params.zero_point)
|
||||
}
|
||||
|
||||
/// Map unsigned INT8 values back to f32: `(q - zero_point) * scale`.
#[must_use]
pub fn dequantize_u8(quantized: &[u8], scale: f32, zero_point: f32) -> Vec<f32> {
    quantized
        .iter()
        .map(|&q| (f32::from(q) - zero_point) * scale)
        .collect()
}
|
||||
|
||||
/// Mean squared error between an original vector and its dequantized
/// counterpart.
///
/// Returns NaN when the two slices differ in length or are empty, so a
/// meaningless comparison is never silently reported as zero error.
#[must_use]
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::NAN;
    }

    let total: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();

    total / original.len() as f32
}
|
||||
|
||||
/// Compute cosine similarity preservation after quantization
|
||||
///
|
||||
/// Returns the ratio of cosine similarities (quantized / original)
|
||||
#[must_use]
|
||||
pub fn compute_cosine_preservation(
|
||||
original_a: &[f32],
|
||||
original_b: &[f32],
|
||||
dequant_a: &[f32],
|
||||
dequant_b: &[f32],
|
||||
) -> f32 {
|
||||
let original_sim = cosine_similarity(original_a, original_b);
|
||||
let quant_sim = cosine_similarity(dequant_a, dequant_b);
|
||||
|
||||
if original_sim.abs() < 1e-12 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
quant_sim / original_sim
|
||||
}
|
||||
|
||||
/// Cosine similarity of two vectors.
///
/// Returns 0.0 for mismatched lengths or when either vector has
/// (near-)zero norm, rather than propagating NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|y| y * y).sum::<f32>().sqrt();

    let denom = norm_a * norm_b;
    if denom < 1e-12 {
        0.0
    } else {
        dot / denom
    }
}
|
||||
|
||||
/// Statistics about quantization quality
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Mean squared error
    /// (NaN when the compared vectors were empty or mismatched in length,
    /// per `compute_quantization_error`)
    pub mse: f32,
    /// Root mean squared error
    pub rmse: f32,
    /// Maximum absolute error
    pub max_error: f32,
    /// Mean absolute error
    pub mean_error: f32,
    /// Compression ratio (original size / quantized size)
    pub compression_ratio: f32,
}
|
||||
|
||||
impl QuantizationStats {
|
||||
/// Compute statistics comparing original and dequantized embeddings
|
||||
#[must_use]
|
||||
pub fn compute(original: &[f32], dequantized: &[f32], quantized_bytes: usize) -> Self {
|
||||
let mse = compute_quantization_error(original, dequantized);
|
||||
let rmse = mse.sqrt();
|
||||
|
||||
let errors: Vec<f32> = original
|
||||
.iter()
|
||||
.zip(dequantized.iter())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.collect();
|
||||
|
||||
let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
|
||||
let mean_error = errors.iter().sum::<f32>() / errors.len().max(1) as f32;
|
||||
|
||||
let original_bytes = original.len() * std::mem::size_of::<f32>();
|
||||
let compression_ratio = original_bytes as f32 / quantized_bytes.max(1) as f32;
|
||||
|
||||
Self {
|
||||
mse,
|
||||
rmse,
|
||||
max_error,
|
||||
mean_error,
|
||||
compression_ratio,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch quantization for multiple embeddings
///
/// Configured with a target [`QuantizationPrecision`]; see
/// `quantize_batch` for how each precision is encoded.
pub struct BatchQuantizer {
    /// Whether to use symmetric quantization
    /// (true only for `Int8`; informational — not read by the code shown
    /// in this file's `quantize_batch`)
    pub symmetric: bool,
    /// Target precision
    pub precision: QuantizationPrecision,
}
|
||||
|
||||
/// Supported quantization precisions
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationPrecision {
    /// 16-bit floating point
    F16,
    /// 8-bit signed integer (symmetric)
    Int8,
    /// 8-bit unsigned integer (asymmetric)
    // NOTE(review): BatchQuantizer::quantize_batch currently routes this
    // variant through the symmetric signed-i8 path — confirm intent.
    UInt8,
}
|
||||
|
||||
impl BatchQuantizer {
    /// Create a new batch quantizer
    ///
    /// The `symmetric` flag is derived from the precision: true only for
    /// `Int8`.
    #[must_use]
    pub fn new(precision: QuantizationPrecision) -> Self {
        Self {
            symmetric: matches!(precision, QuantizationPrecision::Int8),
            precision,
        }
    }

    /// Quantize a batch of embeddings
    ///
    /// Each embedding is quantized independently according to
    /// `self.precision`.
    ///
    /// NOTE(review): the `UInt8` arm below falls through to the symmetric
    /// signed-i8 path (`quantize_to_i8_full`) even though the enum documents
    /// it as asymmetric — confirm whether this is intentional.
    pub fn quantize_batch(&self, embeddings: &[Vec<f32>]) -> Vec<QuantizedEmbedding> {
        embeddings
            .iter()
            .map(|emb| match self.precision {
                QuantizationPrecision::F16 => {
                    let f16_vals = quantize_to_f16(emb);
                    // Store f16 as i8 pairs for uniform interface
                    // (low byte first, then high byte of each f16).
                    // NOTE(review): the resulting QuantizedEmbedding carries
                    // identity params (scale 1.0), so calling `dequantize()`
                    // on it would misread the raw f16 bytes as i8 magnitudes;
                    // callers must decode the F16 payload themselves.
                    let bytes: Vec<i8> = f16_vals
                        .iter()
                        .flat_map(|v| {
                            let bits = v.to_bits();
                            [(bits & 0xFF) as i8, ((bits >> 8) & 0xFF) as i8]
                        })
                        .collect();
                    QuantizedEmbedding {
                        values: bytes,
                        params: QuantizationParams {
                            scale: 1.0,
                            zero_point: 0.0,
                            min_val: 0.0,
                            max_val: 0.0,
                        },
                    }
                }
                QuantizationPrecision::Int8 | QuantizationPrecision::UInt8 => {
                    quantize_to_i8_full(emb)
                }
            })
            .collect()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Unit tests for the quantization helpers: round-trip accuracy for
    // each precision, parameter derivation, error metrics, and the batch
    // API, plus empty-input edge cases.
    use super::*;

    #[test]
    fn test_f16_roundtrip() {
        let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 1.0, -1.0];
        let quantized = quantize_to_f16(&original);
        let restored = dequantize_f16(&quantized);

        for (orig, rest) in original.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.01, "F16 roundtrip error too large");
        }
    }

    #[test]
    fn test_i8_roundtrip() {
        let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.9, -0.9];
        let (quantized, scale, zero_point) = quantize_to_i8(&original);
        let restored = dequantize_i8(&quantized, scale, zero_point);

        for (orig, rest) in original.iter().zip(restored.iter()) {
            // INT8 has larger quantization error
            assert!((orig - rest).abs() < 0.02, "I8 roundtrip error too large");
        }
    }

    #[test]
    fn test_u8_roundtrip() {
        let original = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        let (quantized, scale, zero_point) = quantize_to_u8(&original);
        let restored = dequantize_u8(&quantized, scale, zero_point);

        for (orig, rest) in original.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.02, "U8 roundtrip error too large");
        }
    }

    #[test]
    fn test_quantization_params() {
        let data = vec![-0.5, 0.0, 0.5, 1.0];
        let params = QuantizationParams::from_data(&data);

        assert!(params.scale > 0.0);
        assert_eq!(params.min_val, -0.5);
        assert_eq!(params.max_val, 1.0);
    }

    #[test]
    fn test_quantization_error() {
        let original = vec![0.5, -0.3, 0.8];
        let modified = vec![0.51, -0.29, 0.79];
        let error = compute_quantization_error(&original, &modified);
        assert!(error < 0.001);
    }

    #[test]
    fn test_cosine_preservation() {
        let a = vec![0.6, 0.8, 0.0];
        let b = vec![0.0, 0.6, 0.8];

        // Slightly perturbed versions
        let a_quant = vec![0.61, 0.79, 0.01];
        let b_quant = vec![0.01, 0.59, 0.81];

        let preservation = compute_cosine_preservation(&a, &b, &a_quant, &b_quant);
        // Should be close to 1.0 if quantization preserves cosine similarity
        assert!(preservation > 0.95 && preservation < 1.05);
    }

    #[test]
    fn test_quantization_stats() {
        let original = vec![0.5, -0.3, 0.8, -0.1];
        let (quantized, scale, zero_point) = quantize_to_i8(&original);
        let restored = dequantize_i8(&quantized, scale, zero_point);

        let stats = QuantizationStats::compute(&original, &restored, quantized.len());

        assert!(stats.mse >= 0.0);
        assert!(stats.rmse >= 0.0);
        assert!(stats.compression_ratio > 1.0); // Should compress
    }

    #[test]
    fn test_batch_quantizer() {
        let embeddings = vec![
            vec![0.5, -0.3, 0.8],
            vec![-0.1, 0.2, 0.9],
        ];

        let quantizer = BatchQuantizer::new(QuantizationPrecision::Int8);
        let quantized = quantizer.quantize_batch(&embeddings);

        assert_eq!(quantized.len(), 2);
        assert_eq!(quantized[0].values.len(), 3);
    }

    #[test]
    fn test_quantized_embedding_dequantize() {
        let original = vec![0.5, -0.3, 0.8, -0.1];
        let quantized = quantize_to_i8_full(&original);
        let restored = quantized.dequantize();

        assert_eq!(restored.len(), original.len());
        for (orig, rest) in original.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.02);
        }
    }

    #[test]
    fn test_empty_input() {
        // Empty embeddings must not panic and must yield identity params.
        let empty: Vec<f32> = vec![];

        let f16_result = quantize_to_f16(&empty);
        assert!(f16_result.is_empty());

        let (i8_result, _, _) = quantize_to_i8(&empty);
        assert!(i8_result.is_empty());

        let params = QuantizationParams::from_data(&empty);
        assert_eq!(params.scale, 1.0);
    }
}
|
||||
Reference in New Issue
Block a user