Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,94 @@
//! Error types for the audio application layer.
use std::path::PathBuf;
use thiserror::Error;
/// Errors that can occur during audio processing.
///
/// The `thiserror`-derived `Display` impl renders the `#[error]` strings,
/// so variants carry enough context (path, format name, message) to be
/// shown to operators directly.
#[derive(Debug, Error)]
pub enum AudioError {
    /// Failed to read audio file from disk or decode its contents.
    #[error("Failed to read audio file '{path}': {message}")]
    FileRead {
        // Path of the file that could not be read.
        path: PathBuf,
        // Human-readable description of the underlying failure.
        message: String,
    },
    /// Unsupported audio format.
    #[error("Unsupported audio format: {format}")]
    UnsupportedFormat {
        // The offending format name/extension.
        format: String,
    },
    /// Resampling error.
    #[error("Resampling failed: {0}")]
    Resampling(String),
    /// Segmentation error.
    #[error("Segmentation failed: {0}")]
    Segmentation(String),
    /// Spectrogram computation error.
    #[error("Spectrogram computation failed: {0}")]
    Spectrogram(String),
    /// Invalid audio data (e.g. a recording with no samples loaded).
    #[error("Invalid audio data: {0}")]
    InvalidData(String),
    /// I/O error, converted automatically via `#[from]` so `?` works on
    /// `std::io::Result` values.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Repository (persistence) error.
    #[error("Repository error: {0}")]
    Repository(String),
    /// Configuration error.
    #[error("Configuration error: {0}")]
    Config(String),
}
impl AudioError {
/// Creates a FileRead error.
pub fn file_read(path: impl Into<PathBuf>, message: impl Into<String>) -> Self {
Self::FileRead {
path: path.into(),
message: message.into(),
}
}
/// Creates an UnsupportedFormat error.
pub fn unsupported_format(format: impl Into<String>) -> Self {
Self::UnsupportedFormat {
format: format.into(),
}
}
/// Creates a Resampling error.
pub fn resampling(message: impl Into<String>) -> Self {
Self::Resampling(message.into())
}
/// Creates a Segmentation error.
pub fn segmentation(message: impl Into<String>) -> Self {
Self::Segmentation(message.into())
}
/// Creates a Spectrogram error.
pub fn spectrogram(message: impl Into<String>) -> Self {
Self::Spectrogram(message.into())
}
/// Creates an InvalidData error.
pub fn invalid_data(message: impl Into<String>) -> Self {
Self::InvalidData(message.into())
}
/// Creates a Repository error.
pub fn repository(message: impl Into<String>) -> Self {
Self::Repository(message.into())
}
}
/// Convenience alias for audio operations: a `Result` whose error type is
/// [`AudioError`].
pub type AudioResult<T> = Result<T, AudioError>;

View File

@@ -0,0 +1,10 @@
//! Application layer for the audio ingestion bounded context.
//!
//! This module contains application services that orchestrate
//! domain operations and infrastructure components.
pub mod services;
pub mod error;
// Flatten the layer's public API so callers can import from this module
// directly instead of its submodules.
pub use services::*;
pub use error::*;

View File

@@ -0,0 +1,268 @@
//! Application services for audio processing.
//!
//! These services coordinate domain operations with infrastructure components
//! to implement the audio ingestion use cases.
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info, instrument, warn};
use crate::domain::entities::{CallSegment, Recording, RecordingStatus};
use crate::infrastructure::{AudioFileReader, AudioResampler, AudioSegmenter};
use crate::AudioError;
use sevensense_core::{AudioMetadata, Timestamp};
/// Service for ingesting and processing audio files.
///
/// This service orchestrates the audio ingestion pipeline:
/// 1. Read audio from various file formats
/// 2. Resample to standard rate (32kHz)
/// 3. Segment into individual calls
pub struct AudioIngestionService {
    // Decodes audio files into interleaved f32 samples plus metadata.
    reader: Arc<dyn AudioFileReader>,
    // Converts sample rates to the pipeline target rate.
    resampler: Arc<dyn AudioResampler>,
    // Splits a sample stream into candidate call segments.
    segmenter: Arc<dyn AudioSegmenter>,
}
impl AudioIngestionService {
    /// Creates a new AudioIngestionService from its collaborating components.
    #[must_use]
    pub fn new(
        reader: Arc<dyn AudioFileReader>,
        resampler: Arc<dyn AudioResampler>,
        segmenter: Arc<dyn AudioSegmenter>,
    ) -> Self {
        Self {
            reader,
            resampler,
            segmenter,
        }
    }

    /// Ingests an audio file and creates a Recording entity.
    ///
    /// This performs the following steps:
    /// 1. Read the audio file and extract metadata
    /// 2. Convert to mono if stereo
    /// 3. Resample to the target rate (32kHz) if needed
    ///
    /// # Arguments
    /// * `path` - Path to the audio file
    ///
    /// # Returns
    /// A Recording with samples loaded, status set to `Processing`, and
    /// metadata updated to reflect the mono/resampled stream.
    ///
    /// # Errors
    /// Returns an [`AudioError`] if reading or resampling fails.
    #[instrument(skip(self), fields(path = %path.display()))]
    pub async fn ingest_file(&self, path: &Path) -> Result<Recording, AudioError> {
        info!("Starting audio ingestion");
        // Read the audio file
        let (samples, metadata) = self.reader.read(path).await?;
        debug!(
            sample_rate = metadata.sample_rate,
            channels = metadata.channels,
            duration_ms = metadata.duration_ms,
            "Read audio file"
        );
        // Down-mix before resampling so the resampler only sees one channel.
        let mono_samples = if metadata.channels > 1 {
            debug!("Converting {} channels to mono", metadata.channels);
            Self::to_mono(&samples, metadata.channels)
        } else {
            samples
        };
        // Resample only when the source rate differs from the pipeline target.
        let (resampled, final_rate) = if metadata.sample_rate != crate::TARGET_SAMPLE_RATE {
            debug!(
                "Resampling from {} Hz to {} Hz",
                metadata.sample_rate,
                crate::TARGET_SAMPLE_RATE
            );
            let resampled = self
                .resampler
                .resample(&mono_samples, metadata.sample_rate)?;
            (resampled, crate::TARGET_SAMPLE_RATE)
        } else {
            (mono_samples, metadata.sample_rate)
        };
        // Duration must be recomputed: resampling changes the sample count.
        let duration_ms = (resampled.len() as u64 * 1000) / u64::from(final_rate);
        // Metadata now describes the normalized (mono, resampled) stream.
        let final_metadata = AudioMetadata::new(
            final_rate,
            1, // Now mono
            metadata.bits_per_sample,
            duration_ms,
            metadata.format.clone(),
            metadata.file_size_bytes,
        );
        // Create the recording entity
        let mut recording = Recording::new(
            path.to_path_buf(),
            final_metadata,
            None, // Location to be set separately
            Timestamp::now(),
        );
        recording.set_samples(resampled);
        recording.set_status(RecordingStatus::Processing);
        info!(
            recording_id = %recording.id,
            duration_ms = recording.duration_ms(),
            "Audio ingestion complete"
        );
        Ok(recording)
    }

    /// Segments a recording into individual call segments.
    ///
    /// This analyzes the audio to find regions of interest (potential
    /// bird calls) based on energy levels and signal characteristics.
    ///
    /// # Arguments
    /// * `recording` - A Recording with samples loaded
    ///
    /// # Returns
    /// A vector of detected CallSegments, also added to the recording,
    /// whose status is advanced to `Processed`.
    ///
    /// # Errors
    /// Returns `InvalidData` if no samples are loaded, or a segmenter error.
    #[instrument(skip(self, recording), fields(recording_id = %recording.id))]
    pub async fn segment_recording(
        &self,
        recording: &mut Recording,
    ) -> Result<Vec<CallSegment>, AudioError> {
        let samples = Self::loaded_samples(recording)?;
        info!("Starting segmentation");
        let segments = self.segmenter.segment(
            samples,
            recording.metadata.sample_rate,
            recording.id,
        )?;
        let viable_count = segments.iter().filter(|s| s.is_viable()).count();
        info!(
            total_segments = segments.len(),
            viable_segments = viable_count,
            "Segmentation complete"
        );
        // Add segments to recording
        for segment in &segments {
            recording.add_segment(segment.clone());
        }
        recording.set_status(RecordingStatus::Processed);
        Ok(segments)
    }

    /// Borrows a recording's loaded samples or fails with `InvalidData`.
    ///
    /// Shared by `segment_recording` and `extract_segment_samples` so the
    /// error message stays consistent.
    fn loaded_samples(recording: &Recording) -> Result<&[f32], AudioError> {
        recording
            .samples
            .as_deref()
            .ok_or_else(|| AudioError::invalid_data("Recording has no samples loaded"))
    }

    /// Converts interleaved multi-channel audio to mono by averaging channels.
    ///
    /// A trailing partial frame (when `samples.len()` is not a multiple of
    /// `channels`) is dropped, matching the previous truncating behaviour.
    /// `channels <= 1` input is returned unchanged, which also guards
    /// against a divide-by-zero on malformed metadata.
    fn to_mono(samples: &[f32], channels: u16) -> Vec<f32> {
        let channels = usize::from(channels);
        if channels <= 1 {
            return samples.to_vec();
        }
        samples
            .chunks_exact(channels)
            .map(|frame| frame.iter().sum::<f32>() / channels as f32)
            .collect()
    }

    /// Extracts a segment's samples from the recording.
    ///
    /// Ranges extending past the end of the recording are clamped (with a
    /// warning) rather than panicking.
    ///
    /// # Arguments
    /// * `recording` - The source recording
    /// * `segment` - The segment to extract
    ///
    /// # Returns
    /// The audio samples for just this segment.
    ///
    /// # Errors
    /// Returns `InvalidData` if the recording has no samples loaded.
    pub fn extract_segment_samples(
        &self,
        recording: &Recording,
        segment: &CallSegment,
    ) -> Result<Vec<f32>, AudioError> {
        let samples = Self::loaded_samples(recording)?;
        let sample_rate = recording.metadata.sample_rate;
        // Do the ms * rate arithmetic in u64 so it cannot overflow `usize`
        // on 32-bit targets before the division by 1000.
        let start_sample = ((segment.start_ms * u64::from(sample_rate)) / 1000) as usize;
        let end_sample = ((segment.end_ms * u64::from(sample_rate)) / 1000) as usize;
        if end_sample > samples.len() {
            warn!(
                segment_end = end_sample,
                samples_len = samples.len(),
                "Segment extends beyond recording"
            );
        }
        let end_sample = end_sample.min(samples.len());
        let start_sample = start_sample.min(end_sample);
        Ok(samples[start_sample..end_sample].to_vec())
    }
}
/// Configuration for the audio ingestion service.
///
/// NOTE(review): not consumed by `AudioIngestionService` anywhere in this
/// module — presumably wired in by a composition root elsewhere; confirm.
#[derive(Debug, Clone)]
pub struct AudioIngestionConfig {
    /// Target sample rate for all processing, in Hz.
    pub target_sample_rate: u32,
    /// Minimum segment duration in milliseconds.
    pub min_segment_duration_ms: u64,
    /// Maximum segment duration in milliseconds.
    pub max_segment_duration_ms: u64,
    /// Energy threshold for segment detection.
    pub energy_threshold: f32,
}
impl Default for AudioIngestionConfig {
    /// Standard pipeline settings: crate target sample rate, segments
    /// between 100 ms and 10 s, and an energy gate of 0.01.
    fn default() -> Self {
        let target_sample_rate = crate::TARGET_SAMPLE_RATE;
        Self {
            target_sample_rate,
            min_segment_duration_ms: 100,
            max_segment_duration_ms: 10_000,
            energy_threshold: 0.01,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies channel averaging over interleaved stereo input.
    #[test]
    fn test_to_mono_stereo() {
        // Stereo samples: [L, R, L, R, ...]
        let stereo = vec![0.5, 0.3, 0.8, 0.6, 0.2, 0.4];
        let mono = AudioIngestionService::to_mono(&stereo, 2);
        assert_eq!(mono.len(), 3);
        // Each mono sample is the mean of the corresponding L/R pair.
        assert!((mono[0] - 0.4).abs() < 0.001); // (0.5 + 0.3) / 2
        assert!((mono[1] - 0.7).abs() < 0.001); // (0.8 + 0.6) / 2
        assert!((mono[2] - 0.3).abs() < 0.001); // (0.2 + 0.4) / 2
    }

    // Pins the documented defaults of AudioIngestionConfig.
    #[test]
    fn test_config_defaults() {
        let config = AudioIngestionConfig::default();
        assert_eq!(config.target_sample_rate, 32_000);
        assert_eq!(config.min_segment_duration_ms, 100);
    }
}

View File

@@ -0,0 +1,386 @@
//! Domain entities for audio processing.
//!
//! These are the core aggregates of the audio bounded context.
use serde::{Deserialize, Serialize};
use sevensense_core::{
AudioMetadata, GeoLocation, RecordingId, SegmentId, Timestamp,
};
use std::path::PathBuf;
/// Represents an audio recording from the field.
///
/// A Recording is the aggregate root for the audio context. It contains
/// metadata about the source file and a collection of identified call
/// segments extracted during analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Recording {
    /// Unique identifier for this recording.
    pub id: RecordingId,
    /// Path to the original source file.
    pub source_path: PathBuf,
    /// Audio metadata (sample rate, channels, duration, etc.).
    pub metadata: AudioMetadata,
    /// Geographic location where the recording was made, if known.
    pub location: Option<GeoLocation>,
    /// Timestamp when the recording was captured.
    pub recorded_at: Timestamp,
    /// Call segments identified in this recording.
    pub segments: Vec<CallSegment>,
    /// Raw audio samples (mono, resampled to target rate).
    /// Skipped during serialization, so a deserialized Recording always
    /// has `samples == None` and must be reloaded before analysis.
    #[serde(skip)]
    pub samples: Option<Vec<f32>>,
    /// Processing status.
    pub status: RecordingStatus,
    /// When this recording was ingested into the system.
    pub ingested_at: Timestamp,
}
impl Recording {
    /// Builds a new pending Recording with a fresh id, no segments and no
    /// samples loaded; `ingested_at` is stamped with the current time.
    #[must_use]
    pub fn new(
        source_path: PathBuf,
        metadata: AudioMetadata,
        location: Option<GeoLocation>,
        recorded_at: Timestamp,
    ) -> Self {
        Self {
            id: RecordingId::new(),
            source_path,
            metadata,
            location,
            recorded_at,
            segments: Vec::new(),
            samples: None,
            status: RecordingStatus::Pending,
            ingested_at: Timestamp::now(),
        }
    }

    /// Appends a call segment to this recording.
    pub fn add_segment(&mut self, segment: CallSegment) {
        self.segments.push(segment);
    }

    /// Number of segments currently attached.
    #[must_use]
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }

    /// Total duration in milliseconds, as reported by the audio metadata.
    #[must_use]
    pub fn duration_ms(&self) -> u64 {
        self.metadata.duration_ms
    }

    /// Updates the processing status.
    pub fn set_status(&mut self, status: RecordingStatus) {
        self.status = status;
    }

    /// True once the recording has reached the `Processed` state.
    #[must_use]
    pub fn is_processed(&self) -> bool {
        self.status == RecordingStatus::Processed
    }

    /// Returns only the segments judged to be of high signal quality.
    #[must_use]
    pub fn high_quality_segments(&self) -> Vec<&CallSegment> {
        self.segments
            .iter()
            .filter(|segment| segment.signal_quality == SignalQuality::High)
            .collect()
    }

    /// Attaches raw (mono, resampled) audio samples.
    pub fn set_samples(&mut self, samples: Vec<f32>) {
        self.samples = Some(samples);
    }

    /// Moves the samples out of the recording, leaving `None` behind.
    pub fn take_samples(&mut self) -> Option<Vec<f32>> {
        self.samples.take()
    }
}
/// Status of recording processing.
///
/// Lifecycle: `Pending` -> `Processing` -> `Processed` (or `Failed`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RecordingStatus {
    /// Recording is pending processing (initial state).
    Pending,
    /// Recording is currently being processed.
    Processing,
    /// Recording has been fully processed.
    Processed,
    /// Processing failed.
    Failed,
}
impl Default for RecordingStatus {
    // NOTE(review): could be replaced by `#[derive(Default)]` with
    // `#[default]` on `Pending` (Rust 1.62+); kept manual to avoid
    // touching the enum definition.
    fn default() -> Self {
        Self::Pending
    }
}
/// Represents a segment of audio containing a potential vocalization.
///
/// Call segments are extracted from recordings using energy-based
/// segmentation and represent isolated vocalizations suitable for
/// species classification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallSegment {
    /// Unique identifier for this segment.
    pub id: SegmentId,
    /// The recording this segment belongs to.
    pub recording_id: RecordingId,
    /// Start time of the segment in milliseconds from recording start.
    pub start_ms: u64,
    /// End time of the segment in milliseconds from recording start.
    pub end_ms: u64,
    /// Peak amplitude in the segment (0.0 to 1.0).
    pub peak_amplitude: f32,
    /// Root mean square energy of the segment.
    pub rms_energy: f32,
    /// Assessed quality of the signal in this segment.
    pub signal_quality: SignalQuality,
    /// Zero-crossing rate (useful for distinguishing noise from calls).
    /// Defaults to 0.0 until set via the builder.
    pub zero_crossing_rate: f32,
    /// Spectral centroid in Hz (indicates "brightness" of sound);
    /// `None` until spectral analysis has run.
    pub spectral_centroid: Option<f32>,
    /// Dominant frequency in Hz; `None` until spectral analysis has run.
    pub dominant_frequency: Option<f32>,
}
impl CallSegment {
    /// Creates a new CallSegment with the given parameters.
    ///
    /// Spectral metrics start unset; attach them afterwards with the
    /// `with_*` builder methods.
    #[must_use]
    pub fn new(
        recording_id: RecordingId,
        start_ms: u64,
        end_ms: u64,
        peak_amplitude: f32,
        rms_energy: f32,
        signal_quality: SignalQuality,
    ) -> Self {
        Self {
            id: SegmentId::new(),
            recording_id,
            start_ms,
            end_ms,
            peak_amplitude,
            rms_energy,
            signal_quality,
            zero_crossing_rate: 0.0,
            spectral_centroid: None,
            dominant_frequency: None,
        }
    }

    /// Returns the duration of the segment in milliseconds.
    ///
    /// Saturates to 0 when `end_ms < start_ms` instead of panicking.
    #[must_use]
    pub fn duration_ms(&self) -> u64 {
        self.end_ms.saturating_sub(self.start_ms)
    }

    /// Sets the zero-crossing rate (builder style, consumes `self`).
    // #[must_use] added: dropping the returned value would silently
    // discard the whole segment, matching the file-wide convention.
    #[must_use]
    pub fn with_zero_crossing_rate(mut self, rate: f32) -> Self {
        self.zero_crossing_rate = rate;
        self
    }

    /// Sets the spectral centroid in Hz (builder style, consumes `self`).
    #[must_use]
    pub fn with_spectral_centroid(mut self, centroid: f32) -> Self {
        self.spectral_centroid = Some(centroid);
        self
    }

    /// Sets the dominant frequency in Hz (builder style, consumes `self`).
    #[must_use]
    pub fn with_dominant_frequency(mut self, freq: f32) -> Self {
        self.dominant_frequency = Some(freq);
        self
    }

    /// Checks if this segment meets minimum quality standards: not
    /// classified as noise, at least 100 ms long, and with
    /// non-negligible RMS energy.
    #[must_use]
    pub fn is_viable(&self) -> bool {
        !matches!(self.signal_quality, SignalQuality::Noise)
            && self.duration_ms() >= 100
            && self.rms_energy > 0.001
    }
}
/// Quality assessment of the signal in a segment.
///
/// Derived from SNR/energy/zero-crossing metrics via
/// `SignalQuality::from_metrics`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SignalQuality {
    /// High quality signal with clear vocalization.
    High,
    /// Medium quality signal, usable but may have some noise.
    Medium,
    /// Low quality signal, may be difficult to classify.
    Low,
    /// Primarily noise, likely not a vocalization.
    Noise,
}
impl SignalQuality {
    /// Classifies a segment from its signal-to-noise ratio (dB), RMS
    /// energy and zero-crossing rate. Thresholds are evaluated from the
    /// strictest (High) downwards; anything below the Low band is Noise.
    #[must_use]
    pub fn from_metrics(snr_db: f32, rms_energy: f32, zero_crossing_rate: f32) -> Self {
        if snr_db > 20.0 && rms_energy > 0.05 && zero_crossing_rate < 0.3 {
            // Strong SNR, solid energy, and a low crossing rate.
            Self::High
        } else if snr_db > 10.0 && rms_energy > 0.02 {
            // Moderate SNR with usable energy.
            Self::Medium
        } else if snr_db > 3.0 && rms_energy > 0.01 {
            // Weak but detectable signal.
            Self::Low
        } else {
            // Too noisy or no clear signal.
            Self::Noise
        }
    }

    /// Maps the quality level onto a numeric score in [0.0, 1.0].
    #[must_use]
    pub fn score(&self) -> f32 {
        match *self {
            Self::High => 1.0,
            Self::Medium => 0.7,
            Self::Low => 0.4,
            Self::Noise => 0.1,
        }
    }
}
impl Default for SignalQuality {
    // NOTE(review): could be `#[derive(Default)]` with `#[default]` on
    // `Medium` (Rust 1.62+); kept manual to avoid touching the enum.
    fn default() -> Self {
        Self::Medium
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: 5-second mono 32 kHz 16-bit WAV metadata.
    fn create_test_metadata() -> AudioMetadata {
        AudioMetadata::new(32000, 1, 16, 5000, "wav".to_string(), 320000)
    }

    // A fresh recording starts empty and in the Pending state.
    #[test]
    fn test_recording_creation() {
        let recording = Recording::new(
            PathBuf::from("/test/recording.wav"),
            create_test_metadata(),
            None,
            Timestamp::now(),
        );
        assert_eq!(recording.segment_count(), 0);
        assert_eq!(recording.duration_ms(), 5000);
        assert!(!recording.is_processed());
    }

    // Adding a segment increments the aggregate's count.
    #[test]
    fn test_recording_add_segment() {
        let mut recording = Recording::new(
            PathBuf::from("/test/recording.wav"),
            create_test_metadata(),
            None,
            Timestamp::now(),
        );
        let segment = CallSegment::new(
            recording.id,
            1000,
            2000,
            0.8,
            0.3,
            SignalQuality::High,
        );
        recording.add_segment(segment);
        assert_eq!(recording.segment_count(), 1);
    }

    // duration_ms is end - start.
    #[test]
    fn test_segment_duration() {
        let segment = CallSegment::new(
            RecordingId::new(),
            1000,
            2500,
            0.8,
            0.3,
            SignalQuality::High,
        );
        assert_eq!(segment.duration_ms(), 1500);
    }

    #[test]
    fn test_segment_viability() {
        // Medium quality, 500 ms, enough energy -> viable.
        let viable = CallSegment::new(
            RecordingId::new(),
            0,
            500,
            0.5,
            0.1,
            SignalQuality::Medium,
        );
        assert!(viable.is_viable());
        // Noise-classified segments are never viable.
        let noise = CallSegment::new(
            RecordingId::new(),
            0,
            500,
            0.1,
            0.001,
            SignalQuality::Noise,
        );
        assert!(!noise.is_viable());
    }

    // One representative point per quality band of from_metrics.
    #[test]
    fn test_signal_quality_from_metrics() {
        assert_eq!(
            SignalQuality::from_metrics(25.0, 0.1, 0.2),
            SignalQuality::High
        );
        assert_eq!(
            SignalQuality::from_metrics(15.0, 0.05, 0.3),
            SignalQuality::Medium
        );
        assert_eq!(
            SignalQuality::from_metrics(5.0, 0.02, 0.4),
            SignalQuality::Low
        );
        assert_eq!(
            SignalQuality::from_metrics(1.0, 0.005, 0.5),
            SignalQuality::Noise
        );
    }
}

View File

@@ -0,0 +1,12 @@
//! Domain layer for the audio ingestion bounded context.
//!
//! This module contains the core domain model:
//! - Entities: Recording, CallSegment
//! - Value objects: SignalQuality
//! - Repository traits: RecordingRepository
pub mod entities;
pub mod repository;
// Flatten the domain API so callers import from this module directly.
pub use entities::*;
pub use repository::*;

View File

@@ -0,0 +1,187 @@
//! Repository traits for the audio domain.
//!
//! These traits define the persistence interface for domain entities,
//! following the repository pattern from DDD.
use async_trait::async_trait;
use sevensense_core::{GeoLocation, RecordingId, SegmentId};
use super::entities::{CallSegment, Recording, SignalQuality};
use crate::AudioError;
/// Repository trait for Recording entities.
///
/// Implementations handle the persistence of recordings and their
/// associated segments. This trait enables the domain layer to
/// remain independent of the specific storage mechanism.
#[async_trait]
pub trait RecordingRepository: Send + Sync {
    /// Saves a recording to the repository.
    ///
    /// If the recording already exists, it will be updated.
    async fn save(&self, recording: &Recording) -> Result<(), AudioError>;

    /// Finds a recording by its unique identifier; `Ok(None)` means no
    /// recording with that id exists.
    async fn find_by_id(&self, id: &RecordingId) -> Result<Option<Recording>, AudioError>;

    /// Finds all recordings within a radius of a geographic location.
    ///
    /// # Arguments
    /// * `loc` - The center point of the search
    /// * `radius_km` - Search radius in kilometers
    async fn find_by_location(
        &self,
        loc: &GeoLocation,
        radius_km: f64,
    ) -> Result<Vec<Recording>, AudioError>;

    /// Finds recordings by source file path pattern.
    async fn find_by_path_pattern(&self, pattern: &str) -> Result<Vec<Recording>, AudioError>;

    /// Deletes a recording and all its segments.
    // NOTE(review): the bool presumably indicates whether the recording
    // existed — confirm against implementations.
    async fn delete(&self, id: &RecordingId) -> Result<bool, AudioError>;

    /// Returns the total count of recordings.
    async fn count(&self) -> Result<u64, AudioError>;

    /// Lists recordings with offset/limit pagination.
    async fn list(&self, offset: u64, limit: u64) -> Result<Vec<Recording>, AudioError>;
}
/// Repository trait for CallSegment entities.
///
/// While segments are part of the Recording aggregate, this repository
/// provides direct access for querying and analysis purposes.
#[async_trait]
pub trait SegmentRepository: Send + Sync {
    /// Saves a segment to the repository.
    async fn save(&self, segment: &CallSegment) -> Result<(), AudioError>;

    /// Saves multiple segments in batch.
    async fn save_batch(&self, segments: &[CallSegment]) -> Result<(), AudioError>;

    /// Finds a segment by its unique identifier; `Ok(None)` means not found.
    async fn find_by_id(&self, id: &SegmentId) -> Result<Option<CallSegment>, AudioError>;

    /// Finds all segments for a recording.
    async fn find_by_recording(&self, recording_id: &RecordingId) -> Result<Vec<CallSegment>, AudioError>;

    /// Finds segments by quality level.
    async fn find_by_quality(&self, quality: SignalQuality) -> Result<Vec<CallSegment>, AudioError>;

    /// Finds segments within a time range (milliseconds) of a recording.
    async fn find_in_time_range(
        &self,
        recording_id: &RecordingId,
        start_ms: u64,
        end_ms: u64,
    ) -> Result<Vec<CallSegment>, AudioError>;

    /// Deletes a segment.
    // NOTE(review): the bool presumably indicates whether the segment
    // existed — confirm against implementations.
    async fn delete(&self, id: &SegmentId) -> Result<bool, AudioError>;

    /// Deletes all segments for a recording.
    // NOTE(review): the u64 is presumably the number of segments removed.
    async fn delete_by_recording(&self, recording_id: &RecordingId) -> Result<u64, AudioError>;
}
/// Query specification for finding recordings.
///
/// NOTE(review): the derived `Default` leaves `limit` at 0, whereas
/// `RecordingQuery::new()` sets it to 100 — prefer `new()` unless a zero
/// limit is intentional for the target repository.
#[derive(Debug, Clone, Default)]
pub struct RecordingQuery {
    /// Filter by location and radius (center, radius in km).
    pub location: Option<(GeoLocation, f64)>,
    /// Filter by minimum duration in milliseconds.
    pub min_duration_ms: Option<u64>,
    /// Filter by maximum duration in milliseconds.
    pub max_duration_ms: Option<u64>,
    /// Filter by minimum number of segments.
    pub min_segments: Option<usize>,
    /// Filter by source path pattern (glob-style).
    pub path_pattern: Option<String>,
    /// Pagination offset.
    pub offset: u64,
    /// Pagination limit.
    pub limit: u64,
}
impl RecordingQuery {
    /// Creates an unfiltered query with a page size of 100.
    ///
    /// Unlike the derived `Default` (which leaves `limit` at 0), this is
    /// the intended starting point for the builder chain.
    #[must_use]
    pub fn new() -> Self {
        let mut query = Self::default();
        query.limit = 100;
        query
    }

    /// Restricts results to a radius (km) around a location.
    #[must_use]
    pub fn with_location(self, loc: GeoLocation, radius_km: f64) -> Self {
        Self {
            location: Some((loc, radius_km)),
            ..self
        }
    }

    /// Restricts results to recordings at least `ms` long.
    #[must_use]
    pub fn with_min_duration(self, ms: u64) -> Self {
        Self {
            min_duration_ms: Some(ms),
            ..self
        }
    }

    /// Restricts results to recordings at most `ms` long.
    #[must_use]
    pub fn with_max_duration(self, ms: u64) -> Self {
        Self {
            max_duration_ms: Some(ms),
            ..self
        }
    }

    /// Restricts results to recordings with at least `count` segments.
    #[must_use]
    pub fn with_min_segments(self, count: usize) -> Self {
        Self {
            min_segments: Some(count),
            ..self
        }
    }

    /// Restricts results to source paths matching a glob-style pattern.
    #[must_use]
    pub fn with_path_pattern(self, pattern: impl Into<String>) -> Self {
        Self {
            path_pattern: Some(pattern.into()),
            ..self
        }
    }

    /// Sets both pagination parameters at once.
    #[must_use]
    pub fn with_pagination(self, offset: u64, limit: u64) -> Self {
        Self {
            offset,
            limit,
            ..self
        }
    }
}
/// Extended repository trait with query support.
///
/// Supertrait bound means every query repository is also a plain
/// [`RecordingRepository`].
#[async_trait]
pub trait RecordingQueryRepository: RecordingRepository {
    /// Executes a query and returns matching recordings.
    async fn query(&self, query: &RecordingQuery) -> Result<Vec<Recording>, AudioError>;

    /// Counts recordings matching a query.
    async fn query_count(&self, query: &RecordingQuery) -> Result<u64, AudioError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    // Exercises the fluent builder and checks every field it sets.
    #[test]
    fn test_query_builder() {
        let query = RecordingQuery::new()
            .with_min_duration(1000)
            .with_max_duration(60000)
            .with_min_segments(5)
            .with_pagination(0, 50);
        assert_eq!(query.min_duration_ms, Some(1000));
        assert_eq!(query.max_duration_ms, Some(60000));
        assert_eq!(query.min_segments, Some(5));
        assert_eq!(query.limit, 50);
    }
}

View File

@@ -0,0 +1,397 @@
//! Audio file reading implementation using Symphonia.
//!
//! Symphonia provides support for multiple audio formats including
//! WAV, FLAC, MP3, OGG Vorbis, and more.
use async_trait::async_trait;
use std::fs::File;
use std::path::Path;
use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
use tracing::{debug, instrument};
use super::AudioFileReader;
use crate::AudioError;
use sevensense_core::AudioMetadata;
/// Audio file reader using Symphonia for multi-format support.
pub struct SymphoniaFileReader {
    // Lowercase file extensions this reader will accept; compared
    // case-insensitively in `supports_extension`.
    supported_extensions: Vec<&'static str>,
}
impl SymphoniaFileReader {
    /// Creates a new SymphoniaFileReader advertising the extensions of
    /// every format this build supports.
    #[must_use]
    pub fn new() -> Self {
        Self {
            supported_extensions: vec![
                "wav", "wave", "flac", "mp3", "ogg", "oga", "opus", "m4a", "aac", "aiff", "aif",
            ],
        }
    }

    /// Converts a decoded audio buffer of any sample format into
    /// interleaved `f32` samples ([ch0, ch1, ch0, ch1, ...]).
    ///
    /// Symphonia stores channel data planar (one slice per channel via
    /// `chan(ch)`); this re-interleaves frame-by-frame while normalizing
    /// each raw sample to f32. The per-format conversion formulas are
    /// unchanged from the previous implementation — only the ten
    /// duplicated interleave loops were factored into one macro.
    fn buffer_to_samples(buf: AudioBufferRef<'_>) -> Vec<f32> {
        // Interleaves the planar channels of `$buf`, converting each raw
        // sample to f32 with `$convert`. A macro (rather than a generic
        // helper) lets each arm keep its concrete sample type without
        // naming Symphonia's i24/u24 wrapper types.
        macro_rules! interleave {
            ($buf:expr, $convert:expr) => {{
                let channels = $buf.spec().channels.count();
                let frames = $buf.frames();
                let mut samples = Vec::with_capacity(frames * channels);
                for frame in 0..frames {
                    for ch in 0..channels {
                        samples.push($convert($buf.chan(ch)[frame]));
                    }
                }
                samples
            }};
        }
        match buf {
            // Already f32: pass through unchanged.
            AudioBufferRef::F32(buf) => interleave!(buf, |s: f32| s),
            AudioBufferRef::F64(buf) => interleave!(buf, |s: f64| s as f32),
            // Signed integers scale by their positive maximum.
            AudioBufferRef::S8(buf) => interleave!(buf, |s: i8| f32::from(s) / f32::from(i8::MAX)),
            AudioBufferRef::S16(buf) => {
                interleave!(buf, |s: i16| f32::from(s) / f32::from(i16::MAX))
            }
            // i24 payload is accessed via `.inner()`; 8_388_607 = 2^23 - 1.
            AudioBufferRef::S24(buf) => interleave!(buf, |s| s.inner() as f32 / 8_388_607.0),
            AudioBufferRef::S32(buf) => interleave!(buf, |s: i32| s as f32 / i32::MAX as f32),
            // u8 is centred on 128 then scaled (historic convention kept).
            AudioBufferRef::U8(buf) => interleave!(buf, |s: u8| (f32::from(s) - 128.0) / 128.0),
            // Wider unsigned formats map [0, MAX] linearly onto [-1.0, 1.0].
            AudioBufferRef::U16(buf) => {
                interleave!(buf, |s: u16| (s as f32 / u16::MAX as f32) * 2.0 - 1.0)
            }
            // 16_777_215 = 2^24 - 1.
            AudioBufferRef::U24(buf) => {
                interleave!(buf, |s| (s.inner() as f32 / 16_777_215.0) * 2.0 - 1.0)
            }
            AudioBufferRef::U32(buf) => {
                interleave!(buf, |s: u32| (s as f32 / u32::MAX as f32) * 2.0 - 1.0)
            }
        }
    }
}
impl Default for SymphoniaFileReader {
    /// Equivalent to [`SymphoniaFileReader::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl AudioFileReader for SymphoniaFileReader {
    /// Decodes the whole file into interleaved f32 samples plus metadata.
    /// Note: the decode is synchronous despite the async signature — it
    /// blocks the executor for the duration of the file.
    #[instrument(skip(self), fields(path = %path.display()))]
    async fn read(&self, path: &Path) -> Result<(Vec<f32>, AudioMetadata), AudioError> {
        // Stat first so the returned metadata can include the file size.
        let file_metadata = std::fs::metadata(path)
            .map_err(|e| AudioError::file_read(path, e.to_string()))?;
        let file_size = file_metadata.len();
        // Open the file
        let file = File::open(path)
            .map_err(|e| AudioError::file_read(path, e.to_string()))?;
        // Create media source stream
        let mss = MediaSourceStream::new(Box::new(file), Default::default());
        // The extension (when present) helps the probe pick a demuxer.
        let mut hint = Hint::new();
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            hint.with_extension(ext);
        }
        // Probe the container format.
        let probed = symphonia::default::get_probe()
            .format(&hint, mss, &FormatOptions::default(), &MetadataOptions::default())
            .map_err(|e| AudioError::file_read(path, format!("Failed to probe format: {e}")))?;
        let mut format = probed.format;
        // Find the first decodable (non-null codec) audio track.
        let track = format
            .tracks()
            .iter()
            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
            .ok_or_else(|| AudioError::file_read(path, "No audio track found"))?;
        let track_id = track.id;
        // Clone the codec parameters so the immutable borrow of `format`
        // ends before the mutable `next_packet` loop below.
        let codec_params = track.codec_params.clone();
        let sample_rate = codec_params
            .sample_rate
            .ok_or_else(|| AudioError::file_read(path, "Unknown sample rate"))?;
        // Missing channel info is treated as mono.
        let channels = codec_params
            .channels
            .map(|c| c.count() as u16)
            .unwrap_or(1);
        // Missing bit depth is reported as 16 (informational only).
        let bits_per_sample = codec_params
            .bits_per_sample
            .unwrap_or(16) as u16;
        debug!(
            sample_rate = sample_rate,
            channels = channels,
            bits = bits_per_sample,
            "Decoded audio parameters"
        );
        // Create decoder
        let mut decoder = symphonia::default::get_codecs()
            .make(&codec_params, &DecoderOptions::default())
            .map_err(|e| AudioError::file_read(path, format!("Failed to create decoder: {e}")))?;
        // Decode every packet, accumulating interleaved f32 samples.
        let mut all_samples = Vec::new();
        loop {
            let packet = match format.next_packet() {
                Ok(packet) => packet,
                // Symphonia reports a clean end-of-stream as an
                // UnexpectedEof I/O error — treat it as the loop exit.
                Err(symphonia::core::errors::Error::IoError(ref e))
                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    break;
                }
                Err(e) => {
                    return Err(AudioError::file_read(path, format!("Decode error: {e}")));
                }
            };
            // Skip packets from other tracks
            if packet.track_id() != track_id {
                continue;
            }
            // Decode the packet; recoverable DecodeErrors skip the packet
            // only, anything else aborts the read.
            let decoded = match decoder.decode(&packet) {
                Ok(decoded) => decoded,
                Err(symphonia::core::errors::Error::DecodeError(e)) => {
                    debug!("Decode error (skipping packet): {}", e);
                    continue;
                }
                Err(e) => {
                    return Err(AudioError::file_read(path, format!("Decode error: {e}")));
                }
            };
            // Convert to f32 samples
            let samples = Self::buffer_to_samples(decoded);
            all_samples.extend(samples);
        }
        // Duration derives from the decoded frame count, not the header.
        let frame_count = all_samples.len() / channels as usize;
        let duration_ms = (frame_count as u64 * 1000) / u64::from(sample_rate);
        // Format string is taken from the file extension, lowercased.
        let format_str = path
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("unknown")
            .to_lowercase();
        let metadata = AudioMetadata::new(
            sample_rate,
            channels,
            bits_per_sample,
            duration_ms,
            format_str,
            file_size,
        );
        debug!(
            total_samples = all_samples.len(),
            duration_ms = duration_ms,
            "Audio decoding complete"
        );
        Ok((all_samples, metadata))
    }

    /// Case-insensitive membership test against the supported-extension list.
    fn supports_extension(&self, ext: &str) -> bool {
        self.supported_extensions
            .contains(&ext.to_lowercase().as_str())
    }
}
/// Simple WAV file reader using hound (for simple cases).
///
/// Stateless unit struct; construct with `new()` or `Default`.
pub struct HoundWavReader;
impl HoundWavReader {
    /// Creates a new HoundWavReader.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Reads a WAV file synchronously, returning interleaved `f32`
    /// samples (normalized to roughly [-1.0, 1.0]) and the file metadata.
    ///
    /// # Errors
    /// Returns a `FileRead` error when the file cannot be opened, or when
    /// the header reports zero channels or a zero sample rate (which
    /// would otherwise cause divide-by-zero panics below).
    pub fn read_wav(&self, path: &Path) -> Result<(Vec<f32>, AudioMetadata), AudioError> {
        let reader = hound::WavReader::open(path)
            .map_err(|e| AudioError::file_read(path, e.to_string()))?;
        let spec = reader.spec();
        let sample_rate = spec.sample_rate;
        let channels = spec.channels;
        let bits_per_sample = spec.bits_per_sample;
        // Guard against malformed headers before any division below.
        if channels == 0 || sample_rate == 0 {
            return Err(AudioError::file_read(
                path,
                "Invalid WAV header: zero channels or sample rate",
            ));
        }
        // Best-effort size lookup; a failed stat should not abort the read.
        let file_size = std::fs::metadata(path)
            .map(|m| m.len())
            .unwrap_or(0);
        let samples: Vec<f32> = match spec.sample_format {
            hound::SampleFormat::Float => {
                reader.into_samples::<f32>()
                    .filter_map(Result::ok)
                    .collect()
            }
            hound::SampleFormat::Int => {
                // Compute the scale in 64-bit: for 32-bit samples the old
                // `1i32 << 31` overflowed to i32::MIN, flipping the sign
                // of every decoded sample.
                let max_val = (1i64 << (bits_per_sample - 1)) as f32;
                reader.into_samples::<i32>()
                    .filter_map(Result::ok)
                    .map(|s| s as f32 / max_val)
                    .collect()
            }
        };
        // Duration follows from the decoded frame count.
        let frame_count = samples.len() / channels as usize;
        let duration_ms = (frame_count as u64 * 1000) / u64::from(sample_rate);
        let metadata = AudioMetadata::new(
            sample_rate,
            channels,
            bits_per_sample,
            duration_ms,
            "wav".to_string(),
            file_size,
        );
        Ok((samples, metadata))
    }
}
impl Default for HoundWavReader {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension matching is case-insensitive and rejects unknown types.
    #[test]
    fn test_symphonia_supported_extensions() {
        let reader = SymphoniaFileReader::new();
        for supported in ["wav", "WAV", "flac", "mp3"] {
            assert!(
                reader.supports_extension(supported),
                "expected support for {supported}"
            );
        }
        assert!(!reader.supports_extension("txt"));
    }
}

View File

@@ -0,0 +1,72 @@
//! Infrastructure layer for the audio bounded context.
//!
//! This module contains technical implementations for:
//! - Audio file reading (multiple formats via Symphonia)
//! - Resampling (via Rubato)
//! - Signal segmentation (energy-based algorithm)
pub mod file_reader;
pub mod resampler;
pub mod segmenter;
pub use file_reader::*;
pub use resampler::*;
pub use segmenter::*;
use async_trait::async_trait;
use sevensense_core::{AudioMetadata, RecordingId};
use std::path::Path;
use crate::domain::entities::CallSegment;
use crate::AudioError;
/// Trait for reading audio files.
///
/// Implementations decode a file on disk into `f32` samples plus metadata.
#[async_trait]
pub trait AudioFileReader: Send + Sync {
    /// Reads an audio file and returns samples with metadata.
    ///
    /// # Arguments
    /// * `path` - Path to the audio file
    ///
    /// # Returns
    /// A tuple of (samples, metadata). Samples are interleaved if multi-channel.
    ///
    /// # Errors
    /// Returns an [`AudioError`] if the file cannot be opened or decoded.
    async fn read(&self, path: &Path) -> Result<(Vec<f32>, AudioMetadata), AudioError>;
    /// Checks if this reader supports the given file extension.
    ///
    /// NOTE(review): implementations appear to compare case-insensitively
    /// (e.g. "WAV" == "wav") — confirm against each concrete reader.
    fn supports_extension(&self, ext: &str) -> bool;
}
/// Trait for audio resampling.
///
/// Implementations convert mono audio from an arbitrary source rate to a
/// fixed target rate (see [`AudioResampler::target_rate`]).
pub trait AudioResampler: Send + Sync {
    /// Resamples audio to the target sample rate.
    ///
    /// # Arguments
    /// * `samples` - Input samples (mono)
    /// * `source_rate` - Source sample rate in Hz
    ///
    /// # Returns
    /// Resampled audio at the target rate.
    ///
    /// # Errors
    /// Returns an [`AudioError`] if the conversion fails.
    fn resample(&self, samples: &[f32], source_rate: u32) -> Result<Vec<f32>, AudioError>;
    /// Returns the target sample rate.
    fn target_rate(&self) -> u32;
}
/// Trait for audio segmentation.
///
/// Implementations locate regions of interest (candidate vocalizations)
/// within a mono sample buffer.
pub trait AudioSegmenter: Send + Sync {
    /// Segments audio into regions of interest.
    ///
    /// # Arguments
    /// * `samples` - Input samples (mono)
    /// * `sample_rate` - Sample rate in Hz
    /// * `recording_id` - ID of the parent recording
    ///
    /// # Returns
    /// A vector of detected segments.
    ///
    /// # Errors
    /// Returns an [`AudioError`] if segmentation fails.
    fn segment(
        &self,
        samples: &[f32],
        sample_rate: u32,
        recording_id: RecordingId,
    ) -> Result<Vec<CallSegment>, AudioError>;
}

View File

@@ -0,0 +1,323 @@
//! Audio resampling implementation using Rubato.
//!
//! Rubato provides high-quality sample rate conversion using
//! polyphase sinc interpolation.
use rubato::{
FftFixedInOut, Resampler, SincFixedIn,
SincInterpolationParameters, SincInterpolationType, WindowFunction,
};
use tracing::{debug, instrument};
use super::AudioResampler;
use crate::AudioError;
/// High-quality audio resampler using Rubato.
///
/// Converts mono audio to a fixed target rate using polyphase sinc
/// interpolation; filter sharpness is controlled by [`ResamplerQuality`].
pub struct RubatoResampler {
    /// Target sample rate in Hz.
    target_rate: u32,
    /// Resampler quality settings.
    quality: ResamplerQuality,
}
/// Quality presets for resampling.
///
/// Each preset trades processing time for anti-aliasing quality: longer
/// sinc kernels and higher oversampling give sharper filters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ResamplerQuality {
    /// Fast resampling, lower quality.
    Fast,
    /// Balanced quality and speed.
    #[default]
    Normal,
    /// High quality, slower processing.
    High,
    /// Maximum quality for critical applications.
    Best,
}
impl ResamplerQuality {
    /// Returns the sinc kernel length (taps) for this quality level.
    ///
    /// Takes `self` by value: the type is `Copy`, so passing a reference
    /// only adds indirection (clippy `trivially_copy_pass_by_ref`).
    const fn sinc_len(self) -> usize {
        match self {
            Self::Fast => 64,
            Self::Normal => 128,
            Self::High => 256,
            Self::Best => 512,
        }
    }
    /// Returns the oversampling factor for this quality level.
    const fn oversampling_factor(self) -> usize {
        match self {
            Self::Fast => 64,
            Self::Normal => 128,
            Self::High => 256,
            Self::Best => 256,
        }
    }
}
impl RubatoResampler {
    /// Creates a new RubatoResampler with the target sample rate.
    ///
    /// # Arguments
    /// * `target_rate` - Target sample rate in Hz (typically 32000)
    ///
    /// # Errors
    /// Returns [`AudioError::Config`] if `target_rate` is zero.
    pub fn new(target_rate: u32) -> Result<Self, AudioError> {
        if target_rate == 0 {
            return Err(AudioError::Config("Target rate must be positive".into()));
        }
        Ok(Self {
            target_rate,
            quality: ResamplerQuality::default(),
        })
    }
    /// Creates a resampler with specific quality settings.
    ///
    /// Builder-style: consumes and returns `self`; ignoring the return
    /// value silently drops the configured resampler.
    #[must_use]
    pub fn with_quality(mut self, quality: ResamplerQuality) -> Self {
        self.quality = quality;
        self
    }
    /// Creates the Rubato resampler instance for a given source rate.
    ///
    /// `chunk_size` is the fixed number of input frames each `process`
    /// call will consume.
    fn create_resampler(
        &self,
        source_rate: u32,
        chunk_size: usize,
    ) -> Result<SincFixedIn<f32>, AudioError> {
        let params = SincInterpolationParameters {
            sinc_len: self.quality.sinc_len(),
            // Keep 95% of the passband below Nyquist.
            f_cutoff: 0.95,
            interpolation: SincInterpolationType::Cubic,
            oversampling_factor: self.quality.oversampling_factor(),
            window: WindowFunction::BlackmanHarris2,
        };
        let resample_ratio = f64::from(self.target_rate) / f64::from(source_rate);
        SincFixedIn::new(
            resample_ratio,
            2.0, // Max relative deviation from nominal ratio
            params,
            chunk_size,
            1, // Mono
        )
        .map_err(|e| AudioError::resampling(format!("Failed to create resampler: {e}")))
    }
}
impl AudioResampler for RubatoResampler {
    /// Resamples mono audio in fixed-size chunks.
    ///
    /// Output length is approximately `input_len * target/source`; the sinc
    /// filter's internal delay causes small deviations from the exact ratio.
    #[instrument(skip(self, samples), fields(source_rate = source_rate, target_rate = self.target_rate))]
    fn resample(&self, samples: &[f32], source_rate: u32) -> Result<Vec<f32>, AudioError> {
        // Fast paths: matching rates and empty input need no work.
        if source_rate == self.target_rate {
            debug!("No resampling needed, rates match");
            return Ok(samples.to_vec());
        }
        if samples.is_empty() {
            return Ok(Vec::new());
        }
        let resample_ratio = f64::from(self.target_rate) / f64::from(source_rate);
        let expected_output_len = (samples.len() as f64 * resample_ratio).ceil() as usize;
        // Use chunk-based processing for memory efficiency
        // (SincFixedIn consumes exactly `chunk_size` frames per call).
        let chunk_size = 1024.min(samples.len());
        let mut resampler = self.create_resampler(source_rate, chunk_size)?;
        let mut output = Vec::with_capacity(expected_output_len);
        let mut input_pos = 0;
        // Process full chunks
        while input_pos + chunk_size <= samples.len() {
            // Rubato takes one Vec per channel; mono means a single inner Vec.
            let input_chunk = vec![samples[input_pos..input_pos + chunk_size].to_vec()];
            let output_chunk = resampler
                .process(&input_chunk, None)
                .map_err(|e| AudioError::resampling(format!("Resampling failed: {e}")))?;
            output.extend(&output_chunk[0]);
            input_pos += chunk_size;
        }
        // Handle remaining samples
        if input_pos < samples.len() {
            let remaining = samples.len() - input_pos;
            // Pad the remaining samples to chunk size
            // (the resampler requires a full chunk of input).
            let mut padded = samples[input_pos..].to_vec();
            padded.resize(chunk_size, 0.0);
            let input_chunk = vec![padded];
            let output_chunk = resampler
                .process(&input_chunk, None)
                .map_err(|e| AudioError::resampling(format!("Final chunk failed: {e}")))?;
            // Only take the proportional amount of output, discarding the
            // part produced by the zero padding.
            let output_samples = ((remaining as f64) * resample_ratio).ceil() as usize;
            output.extend(&output_chunk[0][..output_samples.min(output_chunk[0].len())]);
        }
        debug!(
            input_samples = samples.len(),
            output_samples = output.len(),
            "Resampling complete"
        );
        Ok(output)
    }
    fn target_rate(&self) -> u32 {
        self.target_rate
    }
}
/// FFT-based resampler for specific ratio resampling.
///
/// More efficient when the ratio is a simple fraction.
pub struct FftResampler {
    /// Target sample rate in Hz.
    target_rate: u32,
}
impl FftResampler {
    /// Creates a new FFT-based resampler.
    #[must_use]
    pub const fn new(target_rate: u32) -> Self {
        Self { target_rate }
    }
    /// Resamples using FFT method.
    ///
    /// Falls back to sinc interpolation when the simplified rate ratio is
    /// too complex for efficient FFT processing.
    ///
    /// # Errors
    /// Returns `AudioError::Resampling` if the underlying resampler fails.
    pub fn resample(&self, samples: &[f32], source_rate: u32) -> Result<Vec<f32>, AudioError> {
        if source_rate == self.target_rate {
            return Ok(samples.to_vec());
        }
        if samples.is_empty() {
            return Ok(Vec::new());
        }
        // Calculate GCD for ratio simplification
        let gcd = Self::gcd(source_rate, self.target_rate);
        let upsample_factor = self.target_rate / gcd;
        let downsample_factor = source_rate / gcd;
        // For very complex ratios, fall back to sinc interpolation
        if upsample_factor > 16 || downsample_factor > 16 {
            let rubato = RubatoResampler::new(self.target_rate)?;
            return rubato.resample(samples, source_rate);
        }
        let chunk_size = 1024.min(samples.len());
        let mut resampler = FftFixedInOut::<f32>::new(
            source_rate as usize,
            self.target_rate as usize,
            chunk_size,
            1, // Mono
        )
        .map_err(|e| AudioError::resampling(format!("FFT resampler creation failed: {e}")))?;
        let resample_ratio = f64::from(self.target_rate) / f64::from(source_rate);
        let expected_len = (samples.len() as f64 * resample_ratio).ceil() as usize;
        let mut output = Vec::with_capacity(expected_len);
        // The FFT resampler fixes its own input chunk length; query it.
        let input_frames = resampler.input_frames_next();
        let mut pos = 0;
        while pos + input_frames <= samples.len() {
            let input = vec![samples[pos..pos + input_frames].to_vec()];
            let result = resampler
                .process(&input, None)
                .map_err(|e| AudioError::resampling(e.to_string()))?;
            output.extend(&result[0]);
            pos += input_frames;
        }
        // Process the trailing partial chunk. Previously these samples were
        // silently dropped (losing up to `input_frames - 1` source samples):
        // zero-pad to a full chunk and keep the proportional output share,
        // mirroring the sinc path above.
        if pos < samples.len() {
            let remaining = samples.len() - pos;
            let mut padded = samples[pos..].to_vec();
            padded.resize(input_frames, 0.0);
            let input = vec![padded];
            let result = resampler
                .process(&input, None)
                .map_err(|e| AudioError::resampling(e.to_string()))?;
            let keep = ((remaining as f64) * resample_ratio).ceil() as usize;
            output.extend(&result[0][..keep.min(result[0].len())]);
        }
        Ok(output)
    }
    /// Calculates greatest common divisor using Euclidean algorithm.
    const fn gcd(mut a: u32, mut b: u32) -> u32 {
        while b != 0 {
            let temp = b;
            b = a % b;
            a = temp;
        }
        a
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `actual` is within 5% of `expected`; the sinc filter's
    /// delay makes output length deviate slightly from the exact ratio.
    fn assert_within_5pct(actual: usize, expected: usize) {
        let diff = (actual as f64 - expected as f64).abs();
        assert!(
            diff < expected as f64 * 0.05,
            "expected ~{expected} samples, got {actual}"
        );
    }

    #[test]
    fn test_resampler_creation() {
        let resampler = RubatoResampler::new(32000);
        assert!(resampler.is_ok());
        assert_eq!(resampler.unwrap().target_rate(), 32000);
    }

    #[test]
    fn test_resampler_invalid_rate() {
        assert!(RubatoResampler::new(0).is_err());
    }

    #[test]
    fn test_no_resample_needed() {
        let input: Vec<f32> = (0..1000).map(|i| (i as f32).sin()).collect();
        let output = RubatoResampler::new(44100)
            .unwrap()
            .resample(&input, 44100)
            .unwrap();
        assert_eq!(output.len(), input.len());
    }

    #[test]
    fn test_downsample() {
        let input: Vec<f32> = (0..44100).map(|i| (i as f32 * 0.01).sin()).collect();
        let output = RubatoResampler::new(32000)
            .unwrap()
            .resample(&input, 44100)
            .unwrap();
        // Output should be approximately 32000/44100 * input length.
        let expected = (input.len() as f64 * (32000.0 / 44100.0)) as usize;
        assert_within_5pct(output.len(), expected);
    }

    #[test]
    fn test_upsample() {
        let input: Vec<f32> = (0..32000).map(|i| (i as f32 * 0.01).sin()).collect();
        let output = RubatoResampler::new(48000)
            .unwrap()
            .resample(&input, 32000)
            .unwrap();
        // Output should be approximately 48000/32000 * input length.
        let expected = (input.len() as f64 * (48000.0 / 32000.0)) as usize;
        assert_within_5pct(output.len(), expected);
    }

    #[test]
    fn test_empty_input() {
        let resampler = RubatoResampler::new(32000).unwrap();
        assert!(resampler.resample(&[], 44100).unwrap().is_empty());
    }

    #[test]
    fn test_quality_settings() {
        assert_eq!(ResamplerQuality::Fast.sinc_len(), 64);
        assert_eq!(ResamplerQuality::Best.sinc_len(), 512);
    }

    #[test]
    fn test_gcd() {
        assert_eq!(FftResampler::gcd(44100, 32000), 100);
        assert_eq!(FftResampler::gcd(48000, 44100), 300);
    }
}

View File

@@ -0,0 +1,475 @@
//! Energy-based audio segmentation for isolating vocalizations.
//!
//! This module implements segmentation algorithms to detect regions
//! of interest (potential bird calls) in audio recordings based on
//! energy levels and signal characteristics.
use rayon::prelude::*;
use sevensense_core::RecordingId;
use tracing::{debug, instrument};
use super::AudioSegmenter;
use crate::domain::entities::{CallSegment, SignalQuality};
use crate::AudioError;
/// Energy-based audio segmenter.
///
/// Uses short-time energy analysis with adaptive thresholding
/// to detect regions containing vocalizations.
pub struct EnergySegmenter {
    /// Window/threshold tuning parameters; see [`SegmenterConfig`].
    config: SegmenterConfig,
}
/// Configuration for the energy segmenter.
///
/// Window and hop sizes are in samples; durations are in milliseconds.
#[derive(Debug, Clone)]
pub struct SegmenterConfig {
    /// Window size for energy calculation in samples.
    pub window_size: usize,
    /// Hop size between windows in samples.
    pub hop_size: usize,
    /// Minimum energy ratio above noise floor for detection.
    pub energy_threshold_ratio: f32,
    /// Minimum segment duration in milliseconds.
    pub min_segment_ms: u64,
    /// Maximum segment duration in milliseconds (longer candidates are dropped).
    pub max_segment_ms: u64,
    /// Minimum gap between segments in milliseconds.
    pub min_gap_ms: u64,
    /// Number of frames to use for noise floor estimation.
    pub noise_floor_frames: usize,
    /// Smoothing factor for energy envelope (0.0 to 1.0).
    ///
    /// Used as the EMA weight of the newest frame: higher values track
    /// the signal faster, lower values smooth more.
    pub smoothing: f32,
}
impl Default for SegmenterConfig {
    fn default() -> Self {
        Self {
            window_size: 1024, // ~32ms at 32kHz
            hop_size: 256, // ~8ms hop
            energy_threshold_ratio: 3.0, // 3x noise floor
            min_segment_ms: 100, // Minimum 100ms
            max_segment_ms: 10_000, // Maximum 10s
            min_gap_ms: 50, // 50ms minimum gap
            noise_floor_frames: 10, // Use 10 quietest frames
            smoothing: 0.3, // Light smoothing
        }
    }
}
impl EnergySegmenter {
    /// Creates a new EnergySegmenter with default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: SegmenterConfig::default(),
        }
    }
    /// Creates an EnergySegmenter with custom configuration.
    #[must_use]
    pub fn with_config(config: SegmenterConfig) -> Self {
        Self { config }
    }
    /// Calculates the RMS (root-mean-square) energy of a window.
    ///
    /// Returns 0.0 for an empty window.
    fn calculate_rms(samples: &[f32]) -> f32 {
        if samples.is_empty() {
            return 0.0;
        }
        let sum_squares: f32 = samples.iter().map(|s| s * s).sum();
        (sum_squares / samples.len() as f32).sqrt()
    }
    /// Calculates the peak absolute amplitude in a window.
    fn calculate_peak(samples: &[f32]) -> f32 {
        samples
            .iter()
            .map(|s| s.abs())
            .fold(0.0f32, |a, b| a.max(b))
    }
    /// Calculates zero-crossing rate: the fraction of adjacent sample
    /// pairs whose signs differ (0.0 for fewer than two samples).
    fn calculate_zcr(samples: &[f32]) -> f32 {
        if samples.len() < 2 {
            return 0.0;
        }
        let crossings: usize = samples
            .windows(2)
            .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
            .count();
        crossings as f32 / (samples.len() - 1) as f32
    }
    /// Estimates the noise floor as the mean of the quietest frames.
    ///
    /// At most a quarter of all frames (and at least one) contribute; the
    /// result is clamped to 1e-10 so later divisions are safe.
    fn estimate_noise_floor(&self, energies: &[f32]) -> f32 {
        let mut sorted = energies.to_vec();
        // NaN-tolerant sort: incomparable pairs compare as equal.
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let noise_frames = self.config.noise_floor_frames.min(sorted.len() / 4).max(1);
        let noise_sum: f32 = sorted.iter().take(noise_frames).sum();
        (noise_sum / noise_frames as f32).max(1e-10)
    }
    /// Smooths the energy envelope using an exponential moving average
    /// with weight `config.smoothing` on the newest frame.
    fn smooth_envelope(&self, energies: &[f32]) -> Vec<f32> {
        let alpha = self.config.smoothing;
        let mut smoothed = Vec::with_capacity(energies.len());
        if energies.is_empty() {
            return smoothed;
        }
        smoothed.push(energies[0]);
        for &energy in &energies[1..] {
            // `last()` cannot be None: the vector is seeded above.
            let prev = *smoothed.last().unwrap();
            smoothed.push(alpha * energy + (1.0 - alpha) * prev);
        }
        smoothed
    }
    /// Finds (start_ms, end_ms) boundaries from the binary activity signal.
    ///
    /// Gaps shorter than `min_gap_ms` are bridged; segments outside the
    /// [`min_segment_ms`, `max_segment_ms`] duration range are discarded.
    fn find_segments(
        &self,
        activity: &[bool],
        sample_rate: u32,
    ) -> Vec<(u64, u64)> {
        // Clamp to 1ms: for sample rates above hop_size * 1000 the integer
        // division yields 0, which previously caused a divide-by-zero panic
        // in the frame-count conversions below.
        let hop_ms = ((self.config.hop_size as u64 * 1000) / u64::from(sample_rate)).max(1);
        let min_frames = (self.config.min_segment_ms / hop_ms).max(1) as usize;
        let max_frames = (self.config.max_segment_ms / hop_ms) as usize;
        let min_gap_frames = (self.config.min_gap_ms / hop_ms).max(1) as usize;
        let mut segments = Vec::new();
        let mut in_segment = false;
        let mut start_frame = 0;
        let mut gap_count = 0;
        for (i, &active) in activity.iter().enumerate() {
            if active {
                if !in_segment {
                    // Start new segment
                    start_frame = i;
                    in_segment = true;
                }
                gap_count = 0;
            } else if in_segment {
                gap_count += 1;
                if gap_count >= min_gap_frames {
                    // End the segment just after the last active frame.
                    let end_frame = i - gap_count + 1;
                    let duration = end_frame - start_frame;
                    if duration >= min_frames && duration <= max_frames {
                        let start_ms = start_frame as u64 * hop_ms;
                        let end_ms = end_frame as u64 * hop_ms;
                        segments.push((start_ms, end_ms));
                    }
                    in_segment = false;
                    gap_count = 0;
                }
            }
        }
        // Handle a segment still open at the end of the recording.
        if in_segment {
            let end_frame = activity.len();
            let duration = end_frame - start_frame;
            if duration >= min_frames && duration <= max_frames {
                let start_ms = start_frame as u64 * hop_ms;
                let end_ms = end_frame as u64 * hop_ms;
                segments.push((start_ms, end_ms));
            }
        }
        segments
    }
    /// Assesses signal quality for a segment.
    ///
    /// Returns `(quality, peak, rms, zero_crossing_rate)`. SNR is estimated
    /// as `20 * log10(rms / noise_floor)`.
    fn assess_quality(
        &self,
        samples: &[f32],
        noise_floor: f32,
    ) -> (SignalQuality, f32, f32, f32) {
        let rms = Self::calculate_rms(samples);
        let peak = Self::calculate_peak(samples);
        let zcr = Self::calculate_zcr(samples);
        // Estimate SNR relative to the recording-wide noise floor.
        let snr_db = if noise_floor > 0.0 {
            20.0 * (rms / noise_floor).log10()
        } else {
            0.0
        };
        let quality = SignalQuality::from_metrics(snr_db, rms, zcr);
        (quality, peak, rms, zcr)
    }
}
impl Default for EnergySegmenter {
fn default() -> Self {
Self::new()
}
}
impl AudioSegmenter for EnergySegmenter {
    /// Pipeline: windowed RMS energies -> EMA smoothing -> adaptive
    /// threshold (noise floor x ratio) -> boundary detection -> per-segment
    /// quality assessment.
    #[instrument(skip(self, samples), fields(samples_len = samples.len(), sample_rate = sample_rate))]
    fn segment(
        &self,
        samples: &[f32],
        sample_rate: u32,
        recording_id: RecordingId,
    ) -> Result<Vec<CallSegment>, AudioError> {
        if samples.is_empty() {
            return Ok(Vec::new());
        }
        // saturating_sub keeps this valid for inputs shorter than one
        // window; the "+ 1" guarantees at least one window, so the zero
        // check below is purely defensive.
        let num_windows = (samples.len().saturating_sub(self.config.window_size))
            / self.config.hop_size
            + 1;
        if num_windows == 0 {
            return Ok(Vec::new());
        }
        debug!(num_windows = num_windows, "Starting energy analysis");
        // Calculate energy for each window (parallel)
        let energies: Vec<f32> = (0..num_windows)
            .into_par_iter()
            .map(|i| {
                let start = i * self.config.hop_size;
                // Clamp the final window to the end of the signal.
                let end = (start + self.config.window_size).min(samples.len());
                Self::calculate_rms(&samples[start..end])
            })
            .collect();
        // Smooth the energy envelope
        let smoothed = self.smooth_envelope(&energies);
        // Estimate noise floor from the quietest smoothed frames
        let noise_floor = self.estimate_noise_floor(&smoothed);
        let threshold = noise_floor * self.config.energy_threshold_ratio;
        debug!(
            noise_floor = noise_floor,
            threshold = threshold,
            "Adaptive threshold calculated"
        );
        // Create binary activity signal
        let activity: Vec<bool> = smoothed.iter().map(|&e| e > threshold).collect();
        // Find segment boundaries
        let boundaries = self.find_segments(&activity, sample_rate);
        debug!(candidate_segments = boundaries.len(), "Found segment candidates");
        // Create CallSegment entities with quality assessment
        let segments: Vec<CallSegment> = boundaries
            .into_par_iter()
            .filter_map(|(start_ms, end_ms)| {
                // Convert millisecond boundaries back to sample indices,
                // clamped to the signal length; degenerate spans are dropped.
                let start_sample = (start_ms as usize * sample_rate as usize) / 1000;
                let end_sample = (end_ms as usize * sample_rate as usize) / 1000;
                let end_sample = end_sample.min(samples.len());
                if start_sample >= end_sample {
                    return None;
                }
                let segment_samples = &samples[start_sample..end_sample];
                let (quality, peak, rms, zcr) =
                    self.assess_quality(segment_samples, noise_floor);
                Some(
                    CallSegment::new(recording_id, start_ms, end_ms, peak, rms, quality)
                        .with_zero_crossing_rate(zcr),
                )
            })
            .collect();
        debug!(
            final_segments = segments.len(),
            high_quality = segments.iter().filter(|s| matches!(s.signal_quality, SignalQuality::High)).count(),
            "Segmentation complete"
        );
        Ok(segments)
    }
}
/// Spectral-based segmenter for more sophisticated detection.
///
/// Uses spectral features in addition to energy for detection.
pub struct SpectralSegmenter {
    energy_segmenter: EnergySegmenter,
    /// Frequency range of interest in Hz.
    freq_range: (f32, f32),
}
impl SpectralSegmenter {
    /// Creates a new SpectralSegmenter focused on bird frequencies.
    ///
    /// Defaults to the 1-10 kHz band.
    #[must_use]
    pub fn new() -> Self {
        Self {
            freq_range: (1000.0, 10000.0), // Bird vocalization range
            energy_segmenter: EnergySegmenter::default(),
        }
    }
    /// Sets the frequency range of interest.
    #[must_use]
    pub fn with_freq_range(self, min_hz: f32, max_hz: f32) -> Self {
        Self {
            freq_range: (min_hz, max_hz),
            ..self
        }
    }
}
impl Default for SpectralSegmenter {
    fn default() -> Self {
        Self::new()
    }
}
impl AudioSegmenter for SpectralSegmenter {
    fn segment(
        &self,
        samples: &[f32],
        sample_rate: u32,
        recording_id: RecordingId,
    ) -> Result<Vec<CallSegment>, AudioError> {
        // `freq_range` is not applied yet: delegate to the energy-based
        // segmenter. Future: bandpass-filter to the configured range first.
        self.energy_segmenter
            .segment(samples, sample_rate, recording_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Generates a pure sine tone at `freq` Hz lasting `duration_s` seconds.
    fn generate_sine_wave(freq: f32, duration_s: f32, sample_rate: u32) -> Vec<f32> {
        let num_samples = (duration_s * sample_rate as f32) as usize;
        (0..num_samples)
            .map(|i| {
                let t = i as f32 / sample_rate as f32;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect()
    }
    /// Builds a fixture with two tones separated by near-silence:
    /// a 0.5s tone starting at 1.0s and a 0.8s tone starting at 1.8s.
    fn generate_test_signal(sample_rate: u32) -> Vec<f32> {
        let mut samples = Vec::new();
        // 1s silence (small non-zero value models a realistic noise floor)
        samples.extend(vec![0.001f32; sample_rate as usize]);
        // 0.5s tone
        samples.extend(generate_sine_wave(1000.0, 0.5, sample_rate));
        // 0.3s silence
        samples.extend(vec![0.001f32; (sample_rate as f32 * 0.3) as usize]);
        // 0.8s tone
        samples.extend(generate_sine_wave(2000.0, 0.8, sample_rate));
        // 0.5s silence
        samples.extend(vec![0.001f32; (sample_rate as f32 * 0.5) as usize]);
        samples
    }
    /// RMS of a +-0.5 square-like sequence is exactly 0.5.
    #[test]
    fn test_rms_calculation() {
        let samples = vec![0.5, -0.5, 0.5, -0.5];
        let rms = EnergySegmenter::calculate_rms(&samples);
        assert!((rms - 0.5).abs() < 0.001);
    }
    /// Peak is the largest absolute amplitude regardless of sign.
    #[test]
    fn test_peak_calculation() {
        let samples = vec![0.3, -0.8, 0.5, -0.2];
        let peak = EnergySegmenter::calculate_peak(&samples);
        assert!((peak - 0.8).abs() < 0.001);
    }
    #[test]
    fn test_zcr_calculation() {
        // Pure sine wave has high ZCR
        let sine: Vec<f32> = (0..100)
            .map(|i| (i as f32 * 0.5).sin())
            .collect();
        let zcr = EnergySegmenter::calculate_zcr(&sine);
        assert!(zcr > 0.0);
        assert!(zcr < 1.0);
    }
    /// The boundary windows below are tolerance bands around the fixture's
    /// tone positions; smoothing and hop quantization shift edges slightly.
    #[test]
    fn test_segmentation() {
        let segmenter = EnergySegmenter::new();
        let samples = generate_test_signal(32000);
        let recording_id = RecordingId::new();
        let segments = segmenter.segment(&samples, 32000, recording_id).unwrap();
        // Should detect 2 segments
        assert_eq!(segments.len(), 2);
        // First segment should be around 1000-1500ms
        assert!(segments[0].start_ms >= 900 && segments[0].start_ms <= 1100);
        // Second segment should be around 1800-2600ms
        assert!(segments[1].start_ms >= 1700);
    }
    /// Empty input yields no segments (and no error).
    #[test]
    fn test_empty_input() {
        let segmenter = EnergySegmenter::new();
        let recording_id = RecordingId::new();
        let segments = segmenter.segment(&[], 32000, recording_id).unwrap();
        assert!(segments.is_empty());
    }
    /// Pure silence never crosses the adaptive threshold.
    #[test]
    fn test_silent_input() {
        let segmenter = EnergySegmenter::new();
        let recording_id = RecordingId::new();
        let silence = vec![0.0f32; 32000];
        let segments = segmenter.segment(&silence, 32000, recording_id).unwrap();
        assert!(segments.is_empty());
    }
    /// Struct-update syntax overrides only the listed fields.
    #[test]
    fn test_config_customization() {
        let config = SegmenterConfig {
            min_segment_ms: 200,
            max_segment_ms: 5000,
            energy_threshold_ratio: 2.0,
            ..Default::default()
        };
        let segmenter = EnergySegmenter::with_config(config);
        assert_eq!(segmenter.config.min_segment_ms, 200);
    }
    #[test]
    fn test_signal_quality_assessment() {
        // High quality signal (good SNR)
        let high_snr = vec![0.5f32; 1000];
        let noise_floor = 0.01;
        let segmenter = EnergySegmenter::new();
        let (quality, _, _, _) = segmenter.assess_quality(&high_snr, noise_floor);
        assert!(matches!(quality, SignalQuality::High | SignalQuality::Medium));
        // Low quality signal (low SNR)
        let low_snr = vec![0.02f32; 1000];
        let (quality, _, _, _) = segmenter.assess_quality(&low_snr, noise_floor);
        assert!(matches!(quality, SignalQuality::Low | SignalQuality::Noise | SignalQuality::Medium));
    }
}

View File

@@ -0,0 +1,70 @@
//! # sevensense-audio
//!
//! Audio processing and segmentation for the 7sense bioacoustics platform.
//!
//! This crate provides:
//! - Audio file decoding (WAV, FLAC, MP3, Ogg)
//! - Sample rate conversion and normalization
//! - Spectrogram generation
//! - Segment detection and extraction
//! - Audio quality analysis
//!
//! ## Architecture
//!
//! The crate follows Domain-Driven Design with clean architecture:
//! - **Domain Layer**: Core entities (Recording, CallSegment) and repository traits
//! - **Application Layer**: Use cases and services (AudioIngestionService)
//! - **Infrastructure Layer**: Technical implementations (file readers, resamplers)
//!
//! ## Example Usage
//!
//! ```rust,no_run
//! use sevensense_audio::application::AudioIngestionService;
//! use sevensense_audio::infrastructure::{SymphoniaFileReader, RubatoResampler, EnergySegmenter};
//! use std::path::Path;
//! use std::sync::Arc;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Create infrastructure components
//! let reader = Arc::new(SymphoniaFileReader::new());
//! let resampler = Arc::new(RubatoResampler::new(32000)?);
//! let segmenter = Arc::new(EnergySegmenter::default());
//!
//! // Create the service
//! let service = AudioIngestionService::new(reader, resampler, segmenter);
//!
//! // Ingest an audio file
//! let mut recording = service.ingest_file(Path::new("recording.wav")).await?;
//!
//! // Segment the recording to find calls
//! let segments = service.segment_recording(&mut recording).await?;
//! println!("Found {} call segments", segments.len());
//! # Ok(())
//! # }
//! ```
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
pub mod domain;
pub mod application;
pub mod infrastructure;
pub mod spectrogram;
// Re-export main types
pub use domain::entities::{Recording, CallSegment, SignalQuality};
pub use domain::repository::RecordingRepository;
pub use application::services::AudioIngestionService;
pub use application::error::{AudioError, AudioResult};
pub use spectrogram::{MelSpectrogram, SpectrogramConfig};
/// Standard target sample rate for all processing (32 kHz).
///
/// Matches the default `sample_rate` used by the spectrogram config.
pub const TARGET_SAMPLE_RATE: u32 = 32_000;
/// Standard segment duration for analysis (5 seconds).
pub const STANDARD_SEGMENT_DURATION_MS: u64 = 5_000;
/// Crate version information.
///
/// Populated at compile time from `Cargo.toml`.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

View File

@@ -0,0 +1,459 @@
//! Mel spectrogram computation for audio feature extraction.
//!
//! This module provides efficient spectrogram computation using FFT
//! and mel-scale filterbanks, producing features suitable for ML models.
use ndarray::{Array2, Axis};
use rayon::prelude::*;
use realfft::RealFftPlanner;
use std::f32::consts::PI;
use tracing::{debug, instrument};
use crate::AudioError;
/// Configuration for spectrogram computation.
#[derive(Debug, Clone)]
pub struct SpectrogramConfig {
    /// Number of mel frequency bands.
    pub n_mels: usize,
    /// FFT window size in samples.
    pub n_fft: usize,
    /// Hop size between frames in samples.
    pub hop_length: usize,
    /// Sample rate of the input audio.
    pub sample_rate: u32,
    /// Minimum frequency for mel filterbank (Hz).
    pub f_min: f32,
    /// Maximum frequency for mel filterbank (Hz).
    pub f_max: f32,
    /// Whether to apply log scaling.
    pub log_scale: bool,
    /// Reference value for dB conversion.
    pub ref_db: f32,
    /// Minimum value for log scaling (avoids log(0)).
    pub min_value: f32,
}
impl Default for SpectrogramConfig {
    fn default() -> Self {
        Self {
            n_mels: 128,
            n_fft: 2048,
            hop_length: 512,
            sample_rate: 32_000,
            f_min: 0.0,
            f_max: 16_000.0, // Nyquist for 32kHz
            log_scale: true,
            ref_db: 1.0,
            min_value: 1e-10,
        }
    }
}
impl SpectrogramConfig {
    /// Creates a config optimized for 5-second segments producing 500 frames.
    ///
    /// For 32kHz audio:
    /// - 5s = 160,000 samples
    /// - hop_length = 320 gives ~500 frames
    #[must_use]
    pub fn for_5s_segment() -> Self {
        Self {
            n_mels: 128,
            n_fft: 2048,
            hop_length: 320, // 160000 / 320 = 500 frames
            sample_rate: 32_000,
            f_min: 500.0, // Filter out very low frequencies
            f_max: 15_000.0, // Most bird calls below 15kHz
            log_scale: true,
            ref_db: 1.0,
            min_value: 1e-10,
        }
    }
    /// Creates a config for variable-length audio.
    ///
    /// Chooses a hop length so `duration_ms` of audio at `sample_rate`
    /// yields roughly `target_frames` frames; remaining fields take their
    /// default values.
    #[must_use]
    pub fn with_target_frames(target_frames: usize, duration_ms: u64, sample_rate: u32) -> Self {
        let total_samples = (duration_ms as usize * sample_rate as usize) / 1000;
        // Clamp to 1: `target_frames == 0` previously divided by zero.
        let hop_length = total_samples / target_frames.max(1);
        Self {
            hop_length: hop_length.max(1),
            sample_rate,
            ..Self::default()
        }
    }
}
/// A computed mel spectrogram.
///
/// Rows are mel bands and columns are time frames; values are in dB when
/// `config.log_scale` was set, linear mel energies otherwise.
#[derive(Debug, Clone)]
pub struct MelSpectrogram {
    /// Spectrogram data (n_mels x n_frames).
    pub data: Array2<f32>,
    /// Configuration used to compute this spectrogram.
    pub config: SpectrogramConfig,
    /// Duration of the source audio in milliseconds.
    pub duration_ms: u64,
}
impl MelSpectrogram {
    /// Computes a mel spectrogram from audio samples.
    ///
    /// # Arguments
    /// * `samples` - Mono audio samples
    /// * `config` - Spectrogram configuration
    ///
    /// # Returns
    /// A MelSpectrogram with shape (n_mels, n_frames).
    ///
    /// # Errors
    /// Returns `AudioError::InvalidData` if `samples` is empty or shorter
    /// than one FFT window.
    #[instrument(skip(samples), fields(samples_len = samples.len()))]
    pub fn compute(samples: &[f32], config: SpectrogramConfig) -> Result<Self, AudioError> {
        if samples.is_empty() {
            return Err(AudioError::invalid_data("Cannot compute spectrogram of empty audio"));
        }
        let duration_ms = (samples.len() as u64 * 1000) / u64::from(config.sample_rate);
        // Compute the magnitude STFT (n_bins x n_frames).
        let stft = Self::stft(samples, config.n_fft, config.hop_length)?;
        // Build the triangular mel filterbank (n_mels x n_bins).
        let mel_filterbank = Self::create_mel_filterbank(
            config.n_mels,
            config.n_fft,
            config.sample_rate,
            config.f_min,
            config.f_max,
        );
        // Project each STFT frame onto the mel bands.
        let n_frames = stft.ncols();
        let mut mel_spec = Array2::zeros((config.n_mels, n_frames));
        for (frame_idx, frame) in stft.axis_iter(Axis(1)).enumerate() {
            for (mel_idx, filter) in mel_filterbank.axis_iter(Axis(0)).enumerate() {
                let energy: f32 = frame
                    .iter()
                    .zip(filter.iter())
                    .map(|(s, f)| s * f)
                    .sum();
                // Clamp so the log below never sees zero.
                mel_spec[[mel_idx, frame_idx]] = energy.max(config.min_value);
            }
        }
        // Apply log (dB) scaling if requested.
        if config.log_scale {
            mel_spec.mapv_inplace(|x| 10.0 * (x / config.ref_db).log10());
        }
        debug!(
            n_mels = config.n_mels,
            n_frames = n_frames,
            duration_ms = duration_ms,
            "Spectrogram computed"
        );
        Ok(Self {
            data: mel_spec,
            config,
            duration_ms,
        })
    }
    /// Returns the shape as (n_mels, n_frames).
    #[must_use]
    pub fn shape(&self) -> (usize, usize) {
        (self.data.nrows(), self.data.ncols())
    }
    /// Returns the number of mel bands.
    #[must_use]
    pub fn n_mels(&self) -> usize {
        self.data.nrows()
    }
    /// Returns the number of time frames.
    #[must_use]
    pub fn n_frames(&self) -> usize {
        self.data.ncols()
    }
    /// Extracts a time slice of the spectrogram.
    ///
    /// Bounds are clamped, so out-of-range requests return a narrower
    /// (possibly empty) slice instead of panicking.
    #[must_use]
    pub fn slice_frames(&self, start: usize, end: usize) -> Array2<f32> {
        let end = end.min(self.n_frames());
        let start = start.min(end);
        self.data.slice(ndarray::s![.., start..end]).to_owned()
    }
    /// Normalizes the spectrogram to zero mean and unit variance per mel band.
    ///
    /// Near-constant bands (std <= 1e-6) are only mean-centered to avoid
    /// dividing by a vanishing standard deviation.
    pub fn normalize(&mut self) {
        for mut row in self.data.axis_iter_mut(Axis(0)) {
            let mean = row.mean().unwrap_or(0.0);
            let std = row.std(0.0);
            if std > 1e-6 {
                row.mapv_inplace(|x| (x - mean) / std);
            } else {
                row.mapv_inplace(|x| x - mean);
            }
        }
    }
    /// Returns the raw data as a flat vector (row-major order).
    #[must_use]
    pub fn to_vec(&self) -> Vec<f32> {
        self.data.iter().copied().collect()
    }
    /// Computes the magnitude Short-Time Fourier Transform.
    ///
    /// Returns an (n_bins x n_frames) array of magnitudes where
    /// n_bins = n_fft / 2 + 1.
    fn stft(
        samples: &[f32],
        n_fft: usize,
        hop_length: usize,
    ) -> Result<Array2<f32>, AudioError> {
        let n_frames = (samples.len().saturating_sub(n_fft)) / hop_length + 1;
        if n_frames == 0 {
            return Err(AudioError::invalid_data(
                "Audio too short for FFT window size",
            ));
        }
        let n_bins = n_fft / 2 + 1;
        // Plan the FFT once and share it across workers: realfft plans are
        // `Send + Sync`. The previous version rebuilt a planner and plan
        // inside every parallel iteration, which dominated the runtime.
        let mut planner = RealFftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(n_fft);
        // Pre-compute Hann window
        let window: Vec<f32> = (0..n_fft)
            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n_fft as f32).cos()))
            .collect();
        // Compute STFT frames in parallel
        let frames: Vec<Vec<f32>> = (0..n_frames)
            .into_par_iter()
            .map(|frame_idx| {
                let start = frame_idx * hop_length;
                let mut input = vec![0.0f32; n_fft];
                // Copy and window the input; samples past the end stay zero.
                for (i, &w) in window.iter().enumerate() {
                    if start + i < samples.len() {
                        input[i] = samples[start + i] * w;
                    }
                }
                // Per-iteration scratch buffers; the plan itself is shared.
                let mut spectrum = fft.make_output_vec();
                let mut scratch = fft.make_scratch_vec();
                // Buffer sizes come from the plan, so this cannot fail for
                // these inputs; errors are ignored (best effort) as before.
                fft.process_with_scratch(&mut input, &mut spectrum, &mut scratch)
                    .ok();
                // Compute magnitude spectrum
                spectrum
                    .iter()
                    .take(n_bins)
                    .map(|c| (c.re * c.re + c.im * c.im).sqrt())
                    .collect()
            })
            .collect();
        // Assemble into 2D array
        let mut stft = Array2::zeros((n_bins, n_frames));
        for (frame_idx, frame) in frames.into_iter().enumerate() {
            for (bin_idx, &value) in frame.iter().enumerate() {
                stft[[bin_idx, frame_idx]] = value;
            }
        }
        Ok(stft)
    }
    /// Creates a triangular mel filterbank matrix (n_mels x n_bins).
    fn create_mel_filterbank(
        n_mels: usize,
        n_fft: usize,
        sample_rate: u32,
        f_min: f32,
        f_max: f32,
    ) -> Array2<f32> {
        let n_bins = n_fft / 2 + 1;
        // Convert frequency to mel scale
        let mel_min = Self::hz_to_mel(f_min);
        let mel_max = Self::hz_to_mel(f_max);
        // Create mel points equally spaced in mel scale
        let mel_points: Vec<f32> = (0..=n_mels + 1)
            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
            .collect();
        // Convert back to Hz
        let hz_points: Vec<f32> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();
        // Convert to FFT bin indices, clamped to the valid bin range
        let bin_points: Vec<usize> = hz_points
            .iter()
            .map(|&f| {
                let bin = (f * n_fft as f32 / sample_rate as f32).round() as usize;
                bin.min(n_bins - 1)
            })
            .collect();
        // Create filterbank matrix; coincident bin points (left == center
        // or center == right) leave that slope zero.
        let mut filterbank = Array2::zeros((n_mels, n_bins));
        for m in 0..n_mels {
            let left = bin_points[m];
            let center = bin_points[m + 1];
            let right = bin_points[m + 2];
            // Rising slope
            for k in left..center {
                if center != left {
                    filterbank[[m, k]] = (k - left) as f32 / (center - left) as f32;
                }
            }
            // Falling slope
            for k in center..=right {
                if right != center {
                    filterbank[[m, k]] = (right - k) as f32 / (right - center) as f32;
                }
            }
        }
        filterbank
    }
    /// Converts frequency from Hz to mel scale (HTK-style formula).
    fn hz_to_mel(hz: f32) -> f32 {
        2595.0 * (1.0 + hz / 700.0).log10()
    }
    /// Converts frequency from mel scale to Hz (inverse of `hz_to_mel`).
    fn mel_to_hz(mel: f32) -> f32 {
        700.0 * (10.0f32.powf(mel / 2595.0) - 1.0)
    }
}
/// Batch spectrogram computation for multiple segments.
pub struct SpectrogramBatch;

impl SpectrogramBatch {
    /// Computes spectrograms for multiple audio segments in parallel.
    ///
    /// # Errors
    ///
    /// Fails with the first [`AudioError`] produced by any segment.
    pub fn compute_batch(
        segments: &[Vec<f32>],
        config: &SpectrogramConfig,
    ) -> Result<Vec<MelSpectrogram>, AudioError> {
        // `compute` takes its config by value, so each worker gets a clone.
        segments
            .par_iter()
            .map(|segment| MelSpectrogram::compute(segment, config.clone()))
            .collect::<Result<Vec<_>, _>>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a pure sine tone as a deterministic test signal.
    fn generate_sine_wave(freq: f32, duration_s: f32, sample_rate: u32) -> Vec<f32> {
        let total = (duration_s * sample_rate as f32) as usize;
        let mut tone = Vec::with_capacity(total);
        for i in 0..total {
            let t = i as f32 / sample_rate as f32;
            tone.push((2.0 * PI * freq * t).sin());
        }
        tone
    }

    #[test]
    fn test_spectrogram_config_default() {
        let cfg = SpectrogramConfig::default();
        assert_eq!(cfg.n_mels, 128);
        assert_eq!(cfg.n_fft, 2048);
    }

    #[test]
    fn test_spectrogram_5s_config() {
        let cfg = SpectrogramConfig::for_5s_segment();
        assert_eq!(cfg.hop_length, 320);
    }

    #[test]
    fn test_mel_conversion() {
        // Hz -> mel -> Hz must round-trip (the formulas are inverses).
        let hz = 1000.0;
        let round_trip = MelSpectrogram::mel_to_hz(MelSpectrogram::hz_to_mel(hz));
        assert!((hz - round_trip).abs() < 0.01);
    }

    #[test]
    fn test_spectrogram_computation() {
        let tone = generate_sine_wave(1000.0, 1.0, 32000);
        let spectrogram = MelSpectrogram::compute(&tone, SpectrogramConfig::default()).unwrap();
        assert_eq!(spectrogram.n_mels(), 128);
        assert!(spectrogram.n_frames() > 0);
    }

    #[test]
    fn test_spectrogram_5s_segment() {
        // 5 seconds at 32kHz = 160,000 samples
        let tone = generate_sine_wave(2000.0, 5.0, 32000);
        let spectrogram =
            MelSpectrogram::compute(&tone, SpectrogramConfig::for_5s_segment()).unwrap();
        assert_eq!(spectrogram.n_mels(), 128);
        // Should be approximately 500 frames
        assert!((spectrogram.n_frames() as i32 - 500).abs() < 10);
    }

    #[test]
    fn test_spectrogram_normalization() {
        let tone = generate_sine_wave(1000.0, 1.0, 32000);
        let mut spectrogram =
            MelSpectrogram::compute(&tone, SpectrogramConfig::default()).unwrap();
        spectrogram.normalize();
        // After normalization the first band should be roughly zero-mean.
        let mean = spectrogram.data.row(0).mean().unwrap_or(1.0);
        assert!(mean.abs() < 0.1);
    }

    #[test]
    fn test_spectrogram_slice() {
        let tone = generate_sine_wave(1000.0, 2.0, 32000);
        let spectrogram = MelSpectrogram::compute(&tone, SpectrogramConfig::default()).unwrap();
        let window = spectrogram.slice_frames(0, 10);
        assert_eq!(window.ncols(), 10);
        assert_eq!(window.nrows(), spectrogram.n_mels());
    }

    #[test]
    fn test_empty_input_error() {
        // An empty signal cannot fill a single FFT window.
        assert!(MelSpectrogram::compute(&[], SpectrogramConfig::default()).is_err());
    }

    #[test]
    fn test_batch_computation() {
        let segments = vec![
            generate_sine_wave(1000.0, 1.0, 32000),
            generate_sine_wave(2000.0, 1.0, 32000),
        ];
        let cfg = SpectrogramConfig::default();
        let specs = SpectrogramBatch::compute_batch(&segments, &cfg).unwrap();
        assert_eq!(specs.len(), 2);
    }
}