Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,270 @@
//! API error types and HTTP response handling.
//!
//! This module provides a unified error type for all API endpoints with
//! proper HTTP status code mapping and JSON error responses.
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
use crate::services::{AnalysisError, AudioError, EmbeddingError, InterpretationError, VectorError};
/// Unified API error type.
///
/// Every variant maps to a fixed HTTP status code and stable machine-readable
/// error identifier (see `status_code` / `error_type`), and is rendered to
/// clients as a JSON [`ErrorResponse`] via the `IntoResponse` impl.
/// Service-layer errors are pulled in with `#[from]` so handlers can use `?`.
#[derive(Debug, Error)]
pub enum ApiError {
    /// Resource not found (404)
    #[error("Resource not found: {0}")]
    NotFound(String),
    /// Bad request with validation errors (400)
    #[error("Bad request: {0}")]
    BadRequest(String),
    /// Unauthorized access (401)
    #[error("Unauthorized: {0}")]
    Unauthorized(String),
    /// Forbidden access (403)
    #[error("Forbidden: {0}")]
    Forbidden(String),
    /// Conflict with existing resource (409)
    #[error("Conflict: {0}")]
    Conflict(String),
    /// Payload too large (413)
    #[error("Payload too large: {0}")]
    PayloadTooLarge(String),
    /// Unsupported media type (415)
    #[error("Unsupported media type: {0}")]
    UnsupportedMediaType(String),
    /// Rate limit exceeded (429)
    #[error("Rate limit exceeded")]
    RateLimitExceeded,
    /// Internal server error (500)
    #[error("Internal error: {0}")]
    Internal(String),
    /// Service unavailable (503)
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),
    /// Audio processing error (maps to 422)
    #[error("Audio processing error: {0}")]
    AudioProcessing(#[from] AudioError),
    /// Embedding error (maps to 500)
    #[error("Embedding error: {0}")]
    Embedding(#[from] EmbeddingError),
    /// Vector index error (maps to 500)
    #[error("Vector index error: {0}")]
    VectorIndex(#[from] VectorError),
    /// Analysis error (maps to 500)
    #[error("Analysis error: {0}")]
    Analysis(#[from] AnalysisError),
    /// Interpretation error (maps to 500)
    #[error("Interpretation error: {0}")]
    Interpretation(#[from] InterpretationError),
    /// Generic anyhow error (maps to 500, reported as "internal_error")
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
/// JSON error response body.
///
/// This is the wire format produced by `ApiError::into_response`; `details`
/// and `request_id` are omitted from the JSON when `None`.
#[derive(Debug, Serialize, ToSchema)]
pub struct ErrorResponse {
    /// Error type identifier
    #[schema(example = "not_found")]
    pub error: String,
    /// Human-readable error message
    #[schema(example = "Recording with ID xyz not found")]
    pub message: String,
    /// Optional error details
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
    /// Request ID for tracing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
impl ApiError {
    /// Map this error variant to its HTTP status code.
    #[must_use]
    pub fn status_code(&self) -> StatusCode {
        match self {
            Self::NotFound(_) => StatusCode::NOT_FOUND,
            Self::BadRequest(_) => StatusCode::BAD_REQUEST,
            Self::Unauthorized(_) => StatusCode::UNAUTHORIZED,
            Self::Forbidden(_) => StatusCode::FORBIDDEN,
            Self::Conflict(_) => StatusCode::CONFLICT,
            Self::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE,
            Self::UnsupportedMediaType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE,
            Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
            Self::ServiceUnavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
            Self::AudioProcessing(_) => StatusCode::UNPROCESSABLE_ENTITY,
            // Everything below is an opaque server-side failure.
            Self::Internal(_)
            | Self::Embedding(_)
            | Self::VectorIndex(_)
            | Self::Analysis(_)
            | Self::Interpretation(_)
            | Self::Other(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
    /// Stable, machine-readable identifier for this error variant,
    /// used as the `error` field of the JSON response body.
    #[must_use]
    pub fn error_type(&self) -> &'static str {
        match self {
            Self::NotFound(_) => "not_found",
            Self::BadRequest(_) => "bad_request",
            Self::Unauthorized(_) => "unauthorized",
            Self::Forbidden(_) => "forbidden",
            Self::Conflict(_) => "conflict",
            Self::PayloadTooLarge(_) => "payload_too_large",
            Self::UnsupportedMediaType(_) => "unsupported_media_type",
            Self::RateLimitExceeded => "rate_limit_exceeded",
            Self::ServiceUnavailable(_) => "service_unavailable",
            Self::AudioProcessing(_) => "audio_processing_error",
            Self::Embedding(_) => "embedding_error",
            Self::VectorIndex(_) => "vector_index_error",
            Self::Analysis(_) => "analysis_error",
            Self::Interpretation(_) => "interpretation_error",
            // Anyhow errors are reported the same way as explicit internal errors.
            Self::Internal(_) | Self::Other(_) => "internal_error",
        }
    }
    /// Build a 404 error with the conventional "{resource} with ID {id} not found" message.
    #[must_use]
    pub fn not_found<T: std::fmt::Display>(resource: &str, id: T) -> Self {
        Self::NotFound(format!("{} with ID {} not found", resource, id))
    }
    /// Build a 400 validation error whose message embeds `details` as JSON.
    /// Serialization failures degrade to an empty details string rather than erroring.
    #[must_use]
    pub fn validation<T: Serialize>(message: &str, details: T) -> Self {
        let rendered = serde_json::to_string(&details).unwrap_or_default();
        Self::BadRequest(format!("{message}: {rendered}"))
    }
}
impl IntoResponse for ApiError {
    /// Render the error as `(status, JSON body)`, logging it on the way out.
    fn into_response(self) -> Response {
        let status = self.status_code();
        let error_type = self.error_type();
        let message = self.to_string();
        // Server-side failures are logged at error level; expected client
        // errors (4xx) only at debug level to avoid log noise.
        let is_internal = matches!(
            &self,
            Self::Internal(_)
                | Self::Other(_)
                | Self::Embedding(_)
                | Self::VectorIndex(_)
                | Self::Analysis(_)
                | Self::Interpretation(_)
        );
        if is_internal {
            tracing::error!(error = %self, "Internal API error");
        } else {
            tracing::debug!(error = %self, "API error response");
        }
        let body = ErrorResponse {
            error: error_type.to_string(),
            message,
            details: None,
            request_id: None,
        };
        (status, Json(body)).into_response()
    }
}
/// Result type alias for API handlers.
pub type ApiResult<T> = Result<T, ApiError>;
/// Extension trait for adding context to errors.
///
/// Implemented for both `Result` and `Option` (see the impls below) so that
/// handler code can attach request-specific context with `?`-friendly chaining,
/// or convert a missing value into a 404-style error.
pub trait ResultExt<T> {
    /// Convert error to `ApiError` with context.
    fn api_context(self, context: &str) -> ApiResult<T>;
    /// Convert to not found error if None.
    fn or_not_found(self, resource: &str, id: &str) -> ApiResult<T>;
}
impl<T, E: std::error::Error + Send + Sync + 'static> ResultExt<T> for Result<T, E> {
    /// Wrap any error as `ApiError::Internal`, prefixed with `context`.
    fn api_context(self, context: &str) -> ApiResult<T> {
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(ApiError::Internal(format!("{context}: {err}"))),
        }
    }
    /// A `Result` error is not a missing resource, so this still maps to an
    /// internal error; the resource/id parameters are intentionally unused here.
    fn or_not_found(self, _resource: &str, _id: &str) -> ApiResult<T> {
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(ApiError::Internal(err.to_string())),
        }
    }
}
impl<T> ResultExt<T> for Option<T> {
    /// `None` becomes an internal error carrying `context`.
    fn api_context(self, context: &str) -> ApiResult<T> {
        match self {
            Some(value) => Ok(value),
            None => Err(ApiError::Internal(format!("{context}: value was None"))),
        }
    }
    /// `None` becomes a 404 not-found error built from `resource` and `id`.
    fn or_not_found(self, resource: &str, id: &str) -> ApiResult<T> {
        match self {
            Some(value) => Ok(value),
            None => Err(ApiError::not_found(resource, id)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Spot-checks the variant -> HTTP status mapping for representative variants.
    #[test]
    fn test_error_status_codes() {
        assert_eq!(
            ApiError::NotFound("test".into()).status_code(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(
            ApiError::BadRequest("test".into()).status_code(),
            StatusCode::BAD_REQUEST
        );
        assert_eq!(
            ApiError::RateLimitExceeded.status_code(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }
    // The not_found helper must embed both the resource name and the ID
    // in the human-readable message.
    #[test]
    fn test_not_found_helper() {
        let err = ApiError::not_found("Recording", "abc-123");
        assert!(err.to_string().contains("Recording"));
        assert!(err.to_string().contains("abc-123"));
    }
    // ErrorResponse serializes its populated optional fields (request_id here);
    // details is None and is skipped per its serde attribute.
    #[test]
    fn test_error_response_serialization() {
        let response = ErrorResponse {
            error: "not_found".into(),
            message: "Resource not found".into(),
            details: None,
            request_id: Some("req-123".into()),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("not_found"));
        assert!(json.contains("req-123"));
    }
}

View File

@@ -0,0 +1,120 @@
//! GraphQL API module for 7sense.
//!
//! This module provides a flexible GraphQL API with:
//! - Query operations for recordings, segments, clusters, and evidence
//! - Mutations for ingestion and labeling
//! - Subscriptions for real-time processing updates
//!
//! ## Schema
//!
//! The schema is defined using `async-graphql` with automatic type generation.
//! Access the GraphQL playground at `/graphql` when enabled.
pub mod schema;
pub mod types;
use async_graphql::Schema;
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use axum::{
extract::State,
response::{Html, IntoResponse},
routing::get,
Router,
};
use crate::AppContext;
use schema::{MutationRoot, QueryRoot, SubscriptionRoot};
/// GraphQL schema type alias.
pub type ApiSchema = Schema<QueryRoot, MutationRoot, SubscriptionRoot>;
/// Build the GraphQL schema with the application context.
///
/// The context is stored in the schema's data registry, where resolvers
/// retrieve it via `ctx.data::<AppContext>()`. Each call builds a fresh
/// schema instance.
#[must_use]
pub fn build_schema(ctx: AppContext) -> ApiSchema {
    Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
        .data(ctx)
        .finish()
}
/// Create the GraphQL router.
///
/// The schema is built once from the supplied context and stored as router
/// state; `/` serves the playground on GET and executes queries on POST.
#[must_use]
pub fn create_router(ctx: AppContext) -> Router<AppContext> {
    let schema = build_schema(ctx.clone());
    let routes = Router::new().route("/", get(graphql_playground).post(graphql_handler));
    routes.with_state(schema)
}
/// GraphQL request handler: executes the incoming request against the schema.
async fn graphql_handler(State(schema): State<ApiSchema>, req: GraphQLRequest) -> GraphQLResponse {
    let response = schema.execute(req.into_inner()).await;
    response.into()
}
/// Serve the static GraphQL Playground HTML page.
#[allow(clippy::unused_async)]
async fn graphql_playground() -> impl IntoResponse {
    Html(PLAYGROUND_HTML)
}
// Static GraphQL Playground page, loaded from the jsDelivr CDN.
// The example query below must validate against the schema in `schema.rs`:
// `neighbors(segment_id: ID, ...)` exposes an `ID` argument, so the variable
// is declared as `ID!` (a `UUID!` declaration would fail GraphQL validation).
const PLAYGROUND_HTML: &str = r#"
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>7sense GraphQL Playground</title>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/graphql-playground-react@1.7.26/build/static/css/index.css" />
<script src="https://cdn.jsdelivr.net/npm/graphql-playground-react@1.7.26/build/static/js/middleware.js"></script>
</head>
<body>
<div id="root"></div>
<script>
window.addEventListener('load', function (event) {
GraphQLPlayground.init(document.getElementById('root'), {
endpoint: '/graphql',
settings: {
'editor.theme': 'dark',
'editor.fontSize': 14,
'request.credentials': 'include',
},
tabs: [
{
name: 'Example Queries',
endpoint: '/graphql',
query: `# 7sense GraphQL API
# Get all clusters
query ListClusters {
clusters {
id
label
size
density
}
}
# Find similar segments
query FindNeighbors($segmentId: ID!, $k: Int) {
neighbors(segmentId: $segmentId, k: $k) {
segmentId
similarity
startTime
endTime
}
}
`
}
]
});
});
</script>
</body>
</html>
"#;
#[cfg(test)]
mod tests {
    use super::*;
    // Intentionally empty: building the schema requires a full AppContext
    // (pipelines, index, engines), which needs a mock harness to construct.
    // Schema tests would go here with mock context
}

View File

@@ -0,0 +1,238 @@
//! GraphQL schema definitions for 7sense API.
//!
//! This module defines the Query, Mutation, and Subscription roots
//! for the GraphQL API.
use async_graphql::*;
use futures::Stream;
use uuid::Uuid;
use super::types::*;
use crate::{AppContext, ProcessingStatus};
/// Root query type for GraphQL API.
pub struct QueryRoot;
#[Object]
impl QueryRoot {
    /// Find similar segments (neighbors) for a given segment.
    ///
    /// Looks up the stored embedding for `segment_id`, then runs a k-nearest
    /// search over the vector index. The queried segment itself is filtered
    /// out of the results, so at most `k - 1` neighbors are returned when the
    /// segment appears in its own search results.
    async fn neighbors(
        &self,
        ctx: &Context<'_>,
        segment_id: ID,
        #[graphql(default = 10)] k: i32,
        #[graphql(default)] min_similarity: f32,
    ) -> Result<Vec<Neighbor>> {
        let app_ctx = ctx.data::<AppContext>()?;
        let segment_uuid = Uuid::parse_str(segment_id.as_str())
            .map_err(|_| Error::new("Invalid segment ID"))?;
        // Clamp a negative client-supplied `k` to 0: the previous `k as usize`
        // cast would wrap a negative value into an enormous search count.
        let k = usize::try_from(k).unwrap_or(0);
        // Get segment embedding
        let embedding = app_ctx
            .vector_index
            .get_embedding(&segment_uuid)
            .map_err(|e| Error::new(format!("Vector index error: {e}")))?
            .ok_or_else(|| Error::new(format!("Segment {} not found", segment_id.as_str())))?;
        // Search for neighbors
        let results = app_ctx
            .vector_index
            .search(&embedding, k, min_similarity)
            .map_err(|e| Error::new(format!("Search error: {e}")))?;
        // Convert to GraphQL types, dropping the query segment itself.
        let neighbors: Vec<Neighbor> = results
            .into_iter()
            .filter(|r| r.id != segment_uuid)
            .map(|r| Neighbor {
                segment_id: ID::from(r.id.to_string()),
                recording_id: ID::from(r.recording_id.to_string()),
                // Assumes distances are normalized to [0, 1] so that
                // 1 - distance is a valid similarity — TODO confirm against
                // the vector index implementation.
                similarity: 1.0 - r.distance,
                distance: r.distance,
                start_time: r.start_time,
                end_time: r.end_time,
                species: r.species.map(|s| Species {
                    common_name: s.common_name,
                    scientific_name: s.scientific_name,
                    confidence: s.confidence,
                }),
            })
            .collect();
        Ok(neighbors)
    }
    /// List all discovered clusters.
    async fn clusters(&self, ctx: &Context<'_>) -> Result<Vec<Cluster>> {
        let app_ctx = ctx.data::<AppContext>()?;
        let cluster_data = app_ctx
            .cluster_engine
            .get_all_clusters()
            .map_err(|e| Error::new(format!("Analysis error: {e}")))?;
        let clusters: Vec<Cluster> = cluster_data
            .into_iter()
            .map(|c| Cluster {
                id: ID::from(c.id.to_string()),
                label: c.label,
                // NOTE(review): `as i32` wraps for counts above i32::MAX;
                // assumed out of range in practice.
                size: c.size as i32,
                density: c.density,
                exemplar_ids: c.exemplar_ids.into_iter().map(|id| ID::from(id.to_string())).collect(),
                species_distribution: c
                    .species_distribution
                    .into_iter()
                    .map(|(name, count, percentage)| SpeciesCount {
                        name,
                        scientific_name: None,
                        count: count as i32,
                        percentage,
                    })
                    .collect(),
                created_at: c.created_at,
            })
            .collect();
        Ok(clusters)
    }
    /// Get a specific cluster by ID, or `None` if it does not exist.
    async fn cluster(&self, ctx: &Context<'_>, id: ID) -> Result<Option<Cluster>> {
        let app_ctx = ctx.data::<AppContext>()?;
        let cluster_uuid =
            Uuid::parse_str(id.as_str()).map_err(|_| Error::new("Invalid cluster ID"))?;
        let cluster_data = app_ctx
            .cluster_engine
            .get_cluster(&cluster_uuid)
            .map_err(|e| Error::new(format!("Analysis error: {e}")))?;
        Ok(cluster_data.map(|c| Cluster {
            id: ID::from(c.id.to_string()),
            label: c.label,
            size: c.size as i32,
            density: c.density,
            exemplar_ids: c.exemplar_ids.into_iter().map(|id| ID::from(id.to_string())).collect(),
            species_distribution: c
                .species_distribution
                .into_iter()
                .map(|(name, count, percentage)| SpeciesCount {
                    name,
                    scientific_name: None,
                    count: count as i32,
                    percentage,
                })
                .collect(),
            created_at: c.created_at,
        }))
    }
    /// System health check. Always reports "healthy" with the crate version;
    /// it does not probe downstream services.
    async fn health(&self) -> HealthStatus {
        HealthStatus {
            status: "healthy".to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
        }
    }
}
/// Root mutation type for GraphQL API.
pub struct MutationRoot;
#[Object]
impl MutationRoot {
    /// Assign a label to a cluster.
    ///
    /// Returns the updated cluster, or an error if `cluster_id` is not a
    /// valid UUID or no cluster with that ID exists.
    /// NOTE(review): presumably this overwrites any existing label — confirm
    /// against `ClusterEngine::assign_label`.
    async fn assign_label(
        &self,
        ctx: &Context<'_>,
        cluster_id: ID,
        label: String,
    ) -> Result<Cluster> {
        let app_ctx = ctx.data::<AppContext>()?;
        let cluster_uuid =
            Uuid::parse_str(cluster_id.as_str()).map_err(|_| Error::new("Invalid cluster ID"))?;
        let cluster_data = app_ctx
            .cluster_engine
            .assign_label(&cluster_uuid, &label)
            .map_err(|e| Error::new(format!("Analysis error: {e}")))?
            .ok_or_else(|| Error::new(format!("Cluster {} not found", cluster_id.as_str())))?;
        // Map the engine's cluster record into the GraphQL object type.
        Ok(Cluster {
            id: ID::from(cluster_data.id.to_string()),
            label: cluster_data.label,
            size: cluster_data.size as i32,
            density: cluster_data.density,
            exemplar_ids: cluster_data.exemplar_ids.into_iter().map(|id| ID::from(id.to_string())).collect(),
            species_distribution: cluster_data
                .species_distribution
                .into_iter()
                .map(|(name, count, percentage)| SpeciesCount {
                    name,
                    scientific_name: None,
                    count: count as i32,
                    percentage,
                })
                .collect(),
            created_at: cluster_data.created_at,
        })
    }
}
/// Root subscription type for GraphQL API.
pub struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
    /// Subscribe to processing status updates for a recording.
    ///
    /// Subscribes to the application's broadcast event channel and yields
    /// only the events whose `recording_id` matches the requested recording.
    /// NOTE(review): the `while let Ok(..)` loop ends the stream on any
    /// `recv()` error — for a broadcast channel that includes a lagged
    /// receiver, not just channel closure. Confirm silent termination is the
    /// intended behavior for slow consumers.
    async fn processing_status(
        &self,
        ctx: &Context<'_>,
        recording_id: ID,
    ) -> Result<impl Stream<Item = ProcessingUpdate>> {
        let app_ctx = ctx.data::<AppContext>()?;
        let recording_uuid = Uuid::parse_str(recording_id.as_str())
            .map_err(|_| Error::new("Invalid recording ID"))?;
        let mut rx = app_ctx.subscribe_events();
        Ok(async_stream::stream! {
            while let Ok(event) = rx.recv().await {
                // Drop events for other recordings; this subscription is per-recording.
                if event.recording_id == recording_uuid {
                    yield ProcessingUpdate {
                        recording_id: ID::from(event.recording_id.to_string()),
                        // Translate the internal status enum into its GraphQL mirror.
                        status: match event.status {
                            ProcessingStatus::Queued => ProcessingStatusGql::Queued,
                            ProcessingStatus::Loading => ProcessingStatusGql::Loading,
                            ProcessingStatus::Segmenting => ProcessingStatusGql::Segmenting,
                            ProcessingStatus::Embedding => ProcessingStatusGql::Embedding,
                            ProcessingStatus::Indexing => ProcessingStatusGql::Indexing,
                            ProcessingStatus::Analyzing => ProcessingStatusGql::Analyzing,
                            ProcessingStatus::Complete => ProcessingStatusGql::Complete,
                            ProcessingStatus::Failed => ProcessingStatusGql::Failed,
                        },
                        progress: event.progress,
                        message: event.message,
                    };
                }
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Basic construction smoke test for the GraphQL health payload;
    // resolver tests would require a mocked AppContext.
    #[test]
    fn test_health_status() {
        let status = HealthStatus {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        assert_eq!(status.status, "healthy");
    }
}

View File

@@ -0,0 +1,131 @@
//! GraphQL type definitions for 7sense API.
//!
//! This module contains all GraphQL object types, input types, and enums
//! used by the schema.
use async_graphql::*;
use chrono::{DateTime, Utc};
// ============================================================================
// Object Types
// ============================================================================
/// A similar segment found through vector search.
#[derive(Debug, Clone, SimpleObject)]
pub struct Neighbor {
    /// Segment identifier
    pub segment_id: ID,
    /// Parent recording ID
    pub recording_id: ID,
    /// Similarity score (0.0 to 1.0); the query layer derives this from
    /// distance (higher means more similar)
    pub similarity: f32,
    /// Distance in embedding space
    pub distance: f32,
    /// Segment start time (units not stated here — presumably seconds; confirm)
    pub start_time: f64,
    /// Segment end time
    pub end_time: f64,
    /// Detected species, if any detector result is available
    pub species: Option<Species>,
}
/// Species information.
#[derive(Debug, Clone, SimpleObject)]
pub struct Species {
    /// Common name
    pub common_name: String,
    /// Scientific name (binomial)
    pub scientific_name: Option<String>,
    /// Detection confidence
    pub confidence: f32,
}
/// A cluster of similar calls.
#[derive(Debug, Clone, SimpleObject)]
pub struct Cluster {
    /// Cluster identifier
    pub id: ID,
    /// Human-assigned label; `None` until a label is assigned via mutation
    pub label: Option<String>,
    /// Number of segments in cluster
    pub size: i32,
    /// Cluster density/compactness
    pub density: f32,
    /// Representative segment IDs
    pub exemplar_ids: Vec<ID>,
    /// Species distribution
    pub species_distribution: Vec<SpeciesCount>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
}
/// Species count within a cluster.
#[derive(Debug, Clone, SimpleObject)]
pub struct SpeciesCount {
    /// Species common name
    pub name: String,
    /// Scientific name
    pub scientific_name: Option<String>,
    /// Count of segments
    pub count: i32,
    /// Percentage of cluster
    pub percentage: f64,
}
/// Processing status update.
///
/// Emitted by the `processing_status` subscription for each matching
/// broadcast event.
#[derive(Debug, Clone, SimpleObject)]
pub struct ProcessingUpdate {
    /// Recording ID
    pub recording_id: ID,
    /// Current status
    pub status: ProcessingStatusGql,
    /// Progress (0.0 to 1.0)
    pub progress: f32,
    /// Status message
    pub message: Option<String>,
}
/// Health status response.
#[derive(Debug, Clone, SimpleObject)]
pub struct HealthStatus {
    /// Service status
    pub status: String,
    /// Version
    pub version: String,
}
// ============================================================================
// Enums
// ============================================================================
/// Processing status stages.
///
/// GraphQL mirror of the crate-level `ProcessingStatus` enum; the
/// subscription resolver maps between the two variant-for-variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Enum)]
pub enum ProcessingStatusGql {
    /// Queued for processing
    Queued,
    /// Loading audio
    Loading,
    /// Segmenting
    Segmenting,
    /// Generating embeddings
    Embedding,
    /// Indexing vectors
    Indexing,
    /// Analyzing clusters
    Analyzing,
    /// Complete
    Complete,
    /// Failed
    Failed,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Sanity check that the enum derives Eq/PartialEq as expected.
    #[test]
    fn test_processing_status_enum() {
        let status = ProcessingStatusGql::Embedding;
        assert_eq!(status, ProcessingStatusGql::Embedding);
    }
}

View File

@@ -0,0 +1,316 @@
//! # sevensense-api
//!
//! REST, GraphQL, and WebSocket API layer for 7sense bioacoustic analysis.
//!
//! This crate provides a comprehensive API for:
//! - Audio recording upload and processing
//! - Segment similarity search via vector embeddings
//! - Cluster discovery and labeling
//! - Evidence pack generation for interpretability
//! - Real-time processing status via WebSocket
//!
//! ## Architecture
//!
//! The API follows a layered architecture:
//! - **REST API** (`/api/v1/*`) - RESTful endpoints for CRUD operations
//! - **GraphQL** (`/graphql`) - Flexible query interface with subscriptions
//! - **WebSocket** (`/ws`) - Real-time updates for long-running operations
//!
//! ## Example
//!
//! ```rust,ignore
//! use sevensense_api::{AppBuilder, Config};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let config = Config::from_env()?;
//! let app = AppBuilder::new(config).build().await?;
//!
//! axum::serve(listener, app).await?;
//! Ok(())
//! }
//! ```
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
pub mod error;
pub mod graphql;
pub mod openapi;
pub mod rest;
pub mod services;
pub mod websocket;
use std::sync::Arc;
use axum::Router;
use tokio::sync::broadcast;
use tower_http::{
compression::CompressionLayer,
trace::TraceLayer,
};
pub use services::{
AudioPipeline, ClusterEngine, EmbeddingModel, InterpretationEngine, VectorIndex,
};
/// Crate version information (taken from Cargo.toml at compile time)
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Application configuration loaded from environment or config file.
///
/// Defaults come from [`Default`]; `from_env` overrides each field from a
/// `SEVENSENSE_*` environment variable when set.
#[derive(Debug, Clone)]
pub struct Config {
    /// Server host address
    pub host: String,
    /// Server port
    pub port: u16,
    /// CORS allowed origins (a single "*" entry means allow-all)
    pub cors_origins: Vec<String>,
    /// Rate limit requests per second
    pub rate_limit_rps: u32,
    /// Maximum upload size in bytes
    pub max_upload_size: usize,
    /// Enable GraphQL playground
    pub enable_playground: bool,
    /// API key for authentication (optional; `None` disables auth)
    pub api_key: Option<String>,
}
impl Default for Config {
    /// Development-friendly defaults: listen on every interface, allow all
    /// CORS origins, enable the playground, and require no API key.
    fn default() -> Self {
        // 100 MB upload ceiling.
        let max_upload_size = 100 * 1024 * 1024;
        Self {
            host: String::from("0.0.0.0"),
            port: 8080,
            cors_origins: vec![String::from("*")],
            rate_limit_rps: 100,
            max_upload_size,
            enable_playground: true,
            api_key: None,
        }
    }
}
impl Config {
    /// Load configuration from environment variables.
    ///
    /// Recognized variables (all optional, falling back to the same values as
    /// [`Config::default`]): `SEVENSENSE_HOST`, `SEVENSENSE_PORT`,
    /// `SEVENSENSE_CORS_ORIGINS` (comma-separated), `SEVENSENSE_RATE_LIMIT`,
    /// `SEVENSENSE_MAX_UPLOAD` (bytes), `SEVENSENSE_ENABLE_PLAYGROUND`
    /// ("true" or "1"), and `SEVENSENSE_API_KEY`.
    ///
    /// # Errors
    ///
    /// Currently never fails — unparseable values fall back to defaults; the
    /// `Result` return is kept for future stricter validation.
    pub fn from_env() -> anyhow::Result<Self> {
        // Best-effort .env loading; a missing file is not an error.
        dotenvy::dotenv().ok();
        Ok(Self {
            host: std::env::var("SEVENSENSE_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
            port: std::env::var("SEVENSENSE_PORT")
                .ok()
                .and_then(|p| p.parse().ok())
                .unwrap_or(8080),
            // Trim whitespace around each origin so "a, b" parses as
            // ["a", "b"] rather than ["a", " b"], and drop empty entries
            // produced by trailing commas.
            cors_origins: std::env::var("SEVENSENSE_CORS_ORIGINS")
                .map(|s| {
                    s.split(',')
                        .map(|origin| origin.trim().to_string())
                        .filter(|origin| !origin.is_empty())
                        .collect()
                })
                .unwrap_or_else(|_| vec!["*".to_string()]),
            rate_limit_rps: std::env::var("SEVENSENSE_RATE_LIMIT")
                .ok()
                .and_then(|r| r.parse().ok())
                .unwrap_or(100),
            max_upload_size: std::env::var("SEVENSENSE_MAX_UPLOAD")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(100 * 1024 * 1024),
            enable_playground: std::env::var("SEVENSENSE_ENABLE_PLAYGROUND")
                .map(|s| s == "true" || s == "1")
                .unwrap_or(true),
            api_key: std::env::var("SEVENSENSE_API_KEY").ok(),
        })
    }
}
/// Processing status event for WebSocket broadcasts.
///
/// Published through `AppContext::publish_event` and fanned out to GraphQL
/// subscriptions and WebSocket clients via the broadcast channel.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ProcessingEvent {
    /// Recording identifier
    pub recording_id: uuid::Uuid,
    /// Current processing status
    pub status: ProcessingStatus,
    /// Progress percentage (0.0 to 1.0)
    pub progress: f32,
    /// Optional status message
    pub message: Option<String>,
}
/// Processing status stages.
///
/// Serialized in snake_case (e.g. `"embedding"`); mirrored by the GraphQL
/// `ProcessingStatusGql` enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum ProcessingStatus {
    /// Recording queued for processing
    Queued,
    /// Loading audio file
    Loading,
    /// Detecting segments
    Segmenting,
    /// Generating embeddings
    Embedding,
    /// Adding to vector index
    Indexing,
    /// Running cluster analysis
    Analyzing,
    /// Processing complete
    Complete,
    /// Processing failed
    Failed,
}
/// Shared application context accessible from all handlers.
///
/// Cheap to clone: every service is behind an `Arc`, and the broadcast
/// sender is itself clonable.
#[derive(Clone)]
pub struct AppContext {
    /// Audio processing pipeline
    pub audio_pipeline: Arc<AudioPipeline>,
    /// Embedding model for segment vectorization
    pub embedding_model: Arc<EmbeddingModel>,
    /// Vector index for similarity search
    pub vector_index: Arc<VectorIndex>,
    /// Cluster analysis engine
    pub cluster_engine: Arc<ClusterEngine>,
    /// Interpretation engine for evidence packs
    pub interpretation_engine: Arc<InterpretationEngine>,
    /// Broadcast channel for processing events
    pub event_tx: broadcast::Sender<ProcessingEvent>,
    /// Application configuration
    pub config: Arc<Config>,
}
impl AppContext {
    /// Create a new application context with all required services.
    ///
    /// Each service is constructed with its `Default` configuration;
    /// NOTE(review): the supplied `config` is stored but not forwarded to the
    /// individual services — confirm this is intentional.
    pub async fn new(config: Config) -> anyhow::Result<Self> {
        // Initialize audio pipeline
        let audio_pipeline = Arc::new(AudioPipeline::new(Default::default())?);
        // Initialize embedding model
        let embedding_model = Arc::new(EmbeddingModel::new(Default::default()).await?);
        // Initialize vector index
        let vector_index = Arc::new(VectorIndex::new(Default::default())?);
        // Initialize cluster engine
        let cluster_engine = Arc::new(ClusterEngine::new(Default::default())?);
        // Initialize interpretation engine
        let interpretation_engine = Arc::new(InterpretationEngine::new(Default::default())?);
        // Create broadcast channel for events (capacity of 1024)
        let (event_tx, _) = broadcast::channel(1024);
        Ok(Self {
            audio_pipeline,
            embedding_model,
            vector_index,
            cluster_engine,
            interpretation_engine,
            event_tx,
            config: Arc::new(config),
        })
    }
    /// Get a receiver for processing events.
    /// Each subscriber only sees events published after it subscribes.
    #[must_use]
    pub fn subscribe_events(&self) -> broadcast::Receiver<ProcessingEvent> {
        self.event_tx.subscribe()
    }
    /// Publish a processing event.
    pub fn publish_event(&self, event: ProcessingEvent) {
        // Ignore send errors (no receivers)
        let _ = self.event_tx.send(event);
    }
}
/// Builder for constructing the application router.
pub struct AppBuilder {
    config: Config,
    context: Option<AppContext>,
}
impl AppBuilder {
    /// Create a new app builder with configuration.
    #[must_use]
    pub fn new(config: Config) -> Self {
        Self { config, context: None }
    }
    /// Set a pre-built context (useful for testing).
    #[must_use]
    pub fn with_context(mut self, context: AppContext) -> Self {
        self.context = Some(context);
        self
    }
    /// Build the complete application router.
    ///
    /// Mounts REST under `/api/v1`, GraphQL under `/graphql`, WebSocket under
    /// `/ws`, and OpenAPI docs under `/docs`, then applies tracing,
    /// compression, and CORS middleware to the whole tree.
    pub async fn build(self) -> anyhow::Result<Router> {
        // Prefer an injected context (tests); otherwise construct one from
        // the builder's configuration.
        let context = if let Some(ctx) = self.context {
            ctx
        } else {
            AppContext::new(self.config.clone()).await?
        };
        // Middleware applies outside-in: tracing, then compression, then CORS.
        let middleware = tower::ServiceBuilder::new()
            .layer(TraceLayer::new_for_http())
            .layer(CompressionLayer::new())
            .layer(rest::middleware::cors_layer(&self.config));
        // Assemble each protocol surface under its URL prefix.
        let app = Router::new()
            .nest("/api/v1", rest::routes::create_router(context.clone()))
            .nest("/graphql", graphql::create_router(context.clone()))
            .nest("/ws", websocket::create_router(context.clone()))
            .nest("/docs", openapi::create_router())
            .layer(middleware)
            .with_state(context);
        Ok(app)
    }
}
/// Health check response.
///
/// Returned by the REST health endpoint (see the OpenAPI `health_check` path).
#[derive(Debug, serde::Serialize, utoipa::ToSchema)]
pub struct HealthResponse {
    /// Service status
    pub status: String,
    /// API version
    pub version: String,
    /// Server uptime in seconds
    pub uptime_secs: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Defaults must remain the documented development-mode values.
    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert!(config.enable_playground);
    }
    // ProcessingStatus serializes in snake_case per its serde attribute.
    #[test]
    fn test_processing_status_serialize() {
        let status = ProcessingStatus::Embedding;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, "\"embedding\"");
    }
}

View File

@@ -0,0 +1,168 @@
//! 7sense API Server
//!
//! This is the main entry point for the 7sense bioacoustic analysis API server.
//! It provides REST, GraphQL, and WebSocket endpoints for audio processing,
//! similarity search, and cluster discovery.
//!
//! ## Usage
//!
//! ```bash
//! # Run with default settings
//! cargo run --release
//!
//! # With environment configuration
//! SEVENSENSE_PORT=3000 SEVENSENSE_API_KEY=secret cargo run --release
//! ```
//!
//! ## Endpoints
//!
//! - REST API: `http://localhost:8080/api/v1/`
//! - GraphQL: `http://localhost:8080/graphql`
//! - GraphQL Playground: `http://localhost:8080/graphql` (GET)
//! - WebSocket: `ws://localhost:8080/ws/`
//! - OpenAPI/Swagger: `http://localhost:8080/docs/swagger`
use std::net::SocketAddr;
use anyhow::Result;
use tokio::net::TcpListener;
use tokio::signal;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use sevensense_api::{AppBuilder, Config};
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing/logging first so everything below is logged.
    init_tracing();
    // Load configuration from environment (and optional .env file).
    let config = Config::from_env()?;
    tracing::info!(
        host = %config.host,
        port = %config.port,
        "Starting 7sense API server"
    );
    // Build the full application router (REST + GraphQL + WebSocket + docs).
    let app = AppBuilder::new(config.clone()).build().await?;
    // Bind to the configured address.
    let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
    let listener = TcpListener::bind(addr).await?;
    tracing::info!(
        address = %addr,
        "7sense API server listening"
    );
    // Print startup banner with endpoint URLs for the operator.
    print_banner(&config);
    // Run server; shutdown_signal() resolves on Ctrl+C or SIGTERM.
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    tracing::info!("7sense API server shut down gracefully");
    Ok(())
}
/// Initialize tracing subscriber with environment filter.
///
/// `RUST_LOG` (the default env filter) takes precedence; otherwise a
/// moderate info-level filter for this crate, tower-http, and axum is used.
fn init_tracing() {
    let env_filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| EnvFilter::new("sevensense_api=info,tower_http=info,axum=info"));
    // Compact formatter: show targets, hide thread ids and source locations.
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_target(true)
        .with_thread_ids(false)
        .with_file(false)
        .with_line_number(false);
    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .init();
}
/// Print startup banner with endpoint information.
///
/// NOTE(review): URLs embed `config.host` verbatim, so binding to "0.0.0.0"
/// prints links that are not directly browsable — confirm acceptable.
fn print_banner(config: &Config) {
    let base_url = format!("http://{}:{}", config.host, config.port);
    println!();
    println!("========================================");
    println!("  7sense Bioacoustic API");
    println!("========================================");
    println!();
    println!("  REST API:    {base_url}/api/v1/");
    println!("  GraphQL:     {base_url}/graphql");
    println!("  WebSocket:   ws://{}:{}/ws/", config.host, config.port);
    println!("  Swagger UI:  {base_url}/docs/swagger");
    println!("  OpenAPI:     {base_url}/docs/openapi.json");
    println!();
    println!("  Health:      {base_url}/api/v1/health");
    println!();
    // Auth line reflects whether an API key was configured.
    if config.api_key.is_some() {
        println!("  Auth:        API key required (Bearer token)");
    } else {
        println!("  Auth:        No authentication (development mode)");
    }
    if config.enable_playground {
        println!("  Playground:  Enabled");
    }
    println!();
    println!("  Rate limit:  {} req/sec", config.rate_limit_rps);
    println!(
        "  Max upload:  {} MB",
        config.max_upload_size / 1024 / 1024
    );
    println!();
    println!("========================================");
    println!();
}
/// Create shutdown signal handler for graceful shutdown.
///
/// Resolves on Ctrl+C on all platforms, or on SIGTERM on Unix (the signal
/// container orchestrators send on stop). Panics only if a handler cannot
/// be installed, which would make graceful shutdown impossible anyway.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };
    // SIGTERM is Unix-only; elsewhere this branch never completes.
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("Failed to install signal handler")
            .recv()
            .await;
    };
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    // Whichever signal arrives first wins.
    tokio::select! {
        () = ctrl_c => {
            tracing::info!("Received Ctrl+C, initiating shutdown");
        }
        () = terminate => {
            tracing::info!("Received SIGTERM, initiating shutdown");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke test: the default config the server falls back to is sane.
    #[test]
    fn test_config_parsing() {
        let config = Config::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
    }
}

View File

@@ -0,0 +1,148 @@
//! OpenAPI/Swagger documentation generation.
//!
//! This module generates OpenAPI 3.0 documentation for the REST API
//! and serves a Swagger UI at `/docs`.
use axum::{routing::get, Json, Router};
use utoipa::{
openapi::{
security::{HttpAuthScheme, HttpBuilder, SecurityScheme},
OpenApi as OpenApiDoc,
},
Modify, OpenApi,
};
use utoipa_swagger_ui::SwaggerUi;
use crate::error::ErrorResponse;
use crate::rest::handlers::*;
use crate::AppContext;
use crate::HealthResponse;
/// OpenAPI documentation struct.
///
/// The `#[openapi(...)]` attribute is the single source of truth for the
/// generated spec: `paths` lists `#[utoipa::path]`-annotated handlers and
/// `schemas` lists `ToSchema` types. NOTE(review): keep these lists in sync
/// when adding handlers or DTOs — anything not listed silently disappears
/// from the generated spec.
#[derive(OpenApi)]
#[openapi(
    info(
        title = "7sense Bioacoustic Analysis API",
        version = "1.0.0",
        description = "REST API for bioacoustic recording analysis, similarity search, and cluster discovery.",
        contact(
            name = "7sense Team",
            url = "https://github.com/vibecast/vibecast"
        ),
        license(
            name = "MIT OR Apache-2.0",
            url = "https://opensource.org/licenses/MIT"
        )
    ),
    servers(
        (url = "/api/v1", description = "API v1")
    ),
    paths(
        upload_recording,
        get_recording,
        get_neighbors,
        list_clusters,
        get_cluster,
        assign_cluster_label,
        get_evidence_pack,
        generate_evidence_pack,
        search,
        health_check,
    ),
    components(
        schemas(
            Recording,
            UploadResponse,
            Neighbor,
            NeighborParams,
            Cluster,
            SpeciesCount,
            EvidencePack,
            SegmentSummary,
            NeighborEvidence,
            FeatureContribution,
            AcousticFeature,
            EvidenceVisualizations,
            SearchQuery,
            SearchResults,
            SearchQueryEcho,
            SearchResult,
            AssignLabelRequest,
            GenerateEvidenceRequest,
            ErrorResponse,
            HealthResponse,
        )
    ),
    modifiers(&SecurityAddon),
    tags(
        (name = "recordings", description = "Recording upload and management"),
        (name = "segments", description = "Segment analysis and similarity search"),
        (name = "clusters", description = "Cluster discovery and labeling"),
        (name = "evidence", description = "Evidence packs for interpretability"),
        (name = "search", description = "Semantic search"),
        (name = "system", description = "System health and status"),
    )
)]
pub struct ApiDoc;
/// Security scheme addon.
///
/// Registered via `modifiers(&SecurityAddon)` on [`ApiDoc`]; injects the
/// `bearer_auth` HTTP bearer scheme into the generated spec's components.
struct SecurityAddon;

impl Modify for SecurityAddon {
    fn modify(&self, openapi: &mut OpenApiDoc) {
        // The components section may be absent if no schemas are registered;
        // only add the scheme when it exists.
        if let Some(components) = openapi.components.as_mut() {
            components.add_security_scheme(
                "bearer_auth",
                SecurityScheme::Http(
                    HttpBuilder::new()
                        .scheme(HttpAuthScheme::Bearer)
                        // NOTE(review): advertised as "JWT" but the auth
                        // middleware compares a static API key — confirm the
                        // advertised format is intentional.
                        .bearer_format("JWT")
                        .description(Some("API key authentication"))
                        .build(),
                ),
            );
        }
    }
}
/// Create the OpenAPI documentation router.
///
/// Mounts:
/// - `GET /openapi.json` (relative to where this router is nested) — raw spec
/// - `GET /docs/swagger-ui` — interactive Swagger UI, configured to load the
///   spec from `/docs/openapi.json`
///
/// NOTE(review): the startup banner advertises `/docs/swagger` while the UI
/// is mounted at `/docs/swagger-ui` — confirm which path is intended.
#[must_use]
pub fn create_router() -> Router<AppContext> {
    Router::new()
        // Raw OpenAPI JSON
        .route("/openapi.json", get(openapi_json))
        // Swagger UI - merge directly
        .merge(SwaggerUi::new("/docs/swagger-ui")
            .url("/docs/openapi.json", ApiDoc::openapi()))
}
/// Get raw OpenAPI JSON.
#[allow(clippy::unused_async)]
async fn openapi_json() -> Json<utoipa::openapi::OpenApi> {
    // Regenerated per request; the derive-produced spec is cheap to build.
    let spec = ApiDoc::openapi();
    Json(spec)
}
#[cfg(test)]
mod tests {
    use super::*;

    // The derive must produce a spec with the configured title and at least
    // one registered path.
    #[test]
    fn test_openapi_generation() {
        let doc = ApiDoc::openapi();
        assert_eq!(doc.info.title, "7sense Bioacoustic Analysis API");
        assert!(!doc.paths.paths.is_empty());
    }

    // Guards against handlers being dropped from the `paths(...)` list in
    // the `#[openapi]` attribute.
    #[test]
    fn test_openapi_has_required_paths() {
        let doc = ApiDoc::openapi();
        let paths: Vec<&str> = doc.paths.paths.keys().map(std::string::String::as_str).collect();
        assert!(paths.contains(&"/recordings"));
        assert!(paths.contains(&"/segments/{id}/neighbors"));
        assert!(paths.contains(&"/clusters"));
        assert!(paths.contains(&"/evidence/{id}"));
        assert!(paths.contains(&"/search"));
        assert!(paths.contains(&"/health"));
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,266 @@
//! REST API middleware for cross-cutting concerns.
//!
//! This module provides:
//! - CORS configuration
//! - Rate limiting
//! - API key authentication
//! - Request logging
use std::{
net::SocketAddr,
sync::Arc,
time::Duration,
};
use axum::{
body::Body,
extract::{ConnectInfo, State},
http::{header, HeaderMap, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use tower_http::cors::{Any, CorsLayer};
use crate::{error::ErrorResponse, AppContext, Config};
/// Create CORS layer based on configuration.
pub fn cors_layer(config: &Config) -> CorsLayer {
let cors = CorsLayer::new()
.allow_methods(Any)
.allow_headers(Any)
.max_age(Duration::from_secs(3600));
if config.cors_origins.contains(&"*".to_string()) {
cors.allow_origin(Any)
} else {
// Parse origins - in production, validate these
cors.allow_origin(Any) // Simplified for now
}
}
/// Rate limiter type alias.
pub type SharedRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;

/// Create a rate limiter with the configured limit.
///
/// A `rps` of zero is unrepresentable as a quota, so it is clamped to one
/// request per second instead of panicking (the previous `unwrap()` would
/// abort startup on a misconfigured `rate_limit_rps = 0`).
pub fn create_rate_limiter(rps: u32) -> SharedRateLimiter {
    let per_second = std::num::NonZeroU32::new(rps).unwrap_or(std::num::NonZeroU32::MIN);
    Arc::new(RateLimiter::direct(Quota::per_second(per_second)))
}
/// Rate limiting middleware.
///
/// Applies a process-wide (not per-client) quota; requests over the limit
/// receive `429 Too Many Requests` with a JSON error body.
pub async fn rate_limit_middleware(
    State(limiter): State<SharedRateLimiter>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Fast path: within quota, forward the request untouched.
    if limiter.check().is_ok() {
        return next.run(request).await;
    }
    let body = ErrorResponse {
        error: "rate_limit_exceeded".into(),
        message: "Too many requests. Please slow down.".into(),
        details: None,
        request_id: None,
    };
    (StatusCode::TOO_MANY_REQUESTS, Json(body)).into_response()
}
/// API key authentication middleware.
///
/// When `config.api_key` is unset the middleware is a no-op (development
/// mode). Otherwise the request must carry
/// `Authorization: Bearer <api_key>` with a key equal to the configured one.
///
/// NOTE(review): the key comparison is a plain `==`, which is not
/// constant-time; consider a constant-time comparison if timing side
/// channels are a concern.
pub async fn auth_middleware(
    State(ctx): State<AppContext>,
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    // If no API key configured, allow all requests
    let Some(expected_key) = &ctx.config.api_key else {
        return next.run(request).await;
    };
    // Check Authorization header
    let auth_header = headers
        .get(header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok());
    match auth_header {
        // `strip_prefix` removes exactly one "Bearer " prefix; the previous
        // `trim_start_matches` would strip repeated prefixes and accept
        // malformed values like "Bearer Bearer <key>".
        Some(auth) => match auth.strip_prefix("Bearer ") {
            Some(provided_key) if provided_key.trim() == expected_key => next.run(request).await,
            Some(_) => unauthorized_response("Invalid API key"),
            None => unauthorized_response("Invalid authorization format. Use 'Bearer <api_key>'"),
        },
        None => unauthorized_response("Missing Authorization header"),
    }
}
/// Build a `401 Unauthorized` JSON response carrying the given message.
fn unauthorized_response(message: &str) -> Response {
    let payload = Json(ErrorResponse {
        error: "unauthorized".into(),
        message: message.into(),
        details: None,
        request_id: None,
    });
    (StatusCode::UNAUTHORIZED, payload).into_response()
}
/// Request logging middleware that adds structured logging.
///
/// Emits one `tracing` info event per request with method, URI, status,
/// wall-clock latency, and the client-supplied `x-request-id` (if any).
pub async fn logging_middleware(
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    let request_id = headers
        .get("x-request-id")
        .and_then(|h| h.to_str().ok())
        .map(String::from);
    // Method/URI are captured up front because `next.run` consumes the request.
    let method = request.method().clone();
    let uri = request.uri().clone();

    let started_at = std::time::Instant::now();
    let response = next.run(request).await;
    let elapsed = started_at.elapsed();

    tracing::info!(
        method = %method,
        uri = %uri,
        status = %response.status().as_u16(),
        latency_ms = %elapsed.as_millis(),
        request_id = ?request_id,
        "HTTP request"
    );
    response
}
/// Content type validation middleware for JSON endpoints.
///
/// For body-carrying methods (POST/PUT/PATCH) the request must declare
/// `Content-Type: application/json`, except on recording-upload paths which
/// use multipart. Violations get `415 Unsupported Media Type` with a JSON
/// error body; everything else is forwarded unchanged.
pub async fn json_content_type_middleware(
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Only POST/PUT/PATCH carry a JSON body worth validating.
    if !matches!(request.method().as_str(), "POST" | "PUT" | "PATCH") {
        return next.run(request).await;
    }
    // Skip multipart endpoints (recording uploads are not JSON).
    if request.uri().path().contains("/recordings") {
        return next.run(request).await;
    }
    let content_type = headers
        .get(header::CONTENT_TYPE)
        .and_then(|h| h.to_str().ok());
    // Single place to build the 415 response for both failure modes
    // (previously duplicated verbatim in two match arms).
    let reject = |message: String| {
        let body = ErrorResponse {
            error: "unsupported_media_type".into(),
            message,
            details: None,
            request_id: None,
        };
        (StatusCode::UNSUPPORTED_MEDIA_TYPE, Json(body)).into_response()
    };
    match content_type {
        Some(ct) if ct.contains("application/json") => next.run(request).await,
        Some(ct) => reject(format!("Expected application/json, got {ct}")),
        None => reject("Missing Content-Type header".to_string()),
    }
}
/// Request body size limit middleware.
///
/// NOTE(review): currently a bare holder for the configured limit — nothing
/// in this module reads `max_size` yet; presumably wiring into a tower layer
/// is still TODO. Confirm before relying on it for enforcement.
pub struct BodyLimitMiddleware {
    // Maximum allowed request body size in bytes.
    max_size: usize,
}
impl BodyLimitMiddleware {
    /// Create a middleware holder with the given byte limit.
    pub fn new(max_size: usize) -> Self {
        Self { max_size }
    }
}
/// Extract client IP from request.
///
/// Resolution order: first entry of `X-Forwarded-For` (the original client
/// when behind proxies), then `X-Real-IP`, then the peer address from the
/// connection info. Returns `None` when no source is available.
pub fn extract_client_ip(headers: &HeaderMap, connect_info: Option<&ConnectInfo<SocketAddr>>) -> Option<String> {
    // X-Forwarded-For is a comma-separated chain; the first hop is the client.
    let forwarded_for = headers
        .get("x-forwarded-for")
        .and_then(|h| h.to_str().ok())
        .and_then(|chain| chain.split(',').next())
        .map(|ip| ip.trim().to_string());

    forwarded_for
        .or_else(|| {
            headers
                .get("x-real-ip")
                .and_then(|h| h.to_str().ok())
                .map(ToString::to_string)
        })
        .or_else(|| connect_info.map(|ci| ci.0.ip().to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: building the CORS layer from defaults must not panic.
    #[test]
    fn test_cors_layer_creation() {
        let config = Config::default();
        let _layer = cors_layer(&config);
    }

    // A fresh limiter must admit at least the first request.
    #[test]
    fn test_rate_limiter_creation() {
        let limiter = create_rate_limiter(100);
        assert!(limiter.check().is_ok());
    }

    // X-Forwarded-For takes precedence and only the first hop is returned.
    #[test]
    fn test_extract_client_ip_x_forwarded() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "1.2.3.4, 5.6.7.8".parse().unwrap());
        let ip = extract_client_ip(&headers, None);
        assert_eq!(ip, Some("1.2.3.4".to_string()));
    }

    // X-Real-IP is used when no X-Forwarded-For header is present.
    #[test]
    fn test_extract_client_ip_x_real() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", "10.0.0.1".parse().unwrap());
        let ip = extract_client_ip(&headers, None);
        assert_eq!(ip, Some("10.0.0.1".to_string()));
    }
}

View File

@@ -0,0 +1,24 @@
//! REST API module for 7sense.
//!
//! This module provides RESTful endpoints for:
//! - Recording upload and management
//! - Segment similarity search
//! - Cluster discovery and labeling
//! - Evidence pack retrieval
//!
//! ## API Versioning
//!
//! All endpoints are versioned under `/api/v1/`. Breaking changes will
//! result in a new API version (e.g., `/api/v2/`).
//!
//! ## Authentication
//!
//! If `SEVENSENSE_API_KEY` is set, all requests must include an
//! `Authorization: Bearer <api_key>` header.
pub mod handlers;
pub mod middleware;
pub mod routes;
pub use handlers::*;
pub use routes::create_router;

View File

@@ -0,0 +1,74 @@
//! REST API route definitions with versioning.
//!
//! Routes are organized by resource type and versioned under `/api/v1/`.
use axum::{
routing::{get, post, put},
Router,
};
use super::handlers;
use crate::AppContext;
/// Create the REST API router with all endpoints.
///
/// The context argument is currently unused but kept in the signature so
/// call sites do not change once per-route state wiring is added.
pub fn create_router(_ctx: AppContext) -> Router<AppContext> {
    // Flat routes first, then one nested sub-router per resource.
    let root = Router::new()
        .route("/health", get(handlers::health_check))
        .route("/search", post(handlers::search));
    root.nest("/recordings", recordings_router())
        .nest("/segments", segments_router())
        .nest("/clusters", clusters_router())
        .nest("/evidence", evidence_router())
}
/// Recording management routes.
fn recordings_router() -> Router<AppContext> {
    // POST / uploads a new recording; GET /:id fetches one by ID.
    let upload = post(handlers::upload_recording);
    let fetch = get(handlers::get_recording);
    Router::new().route("/", upload).route("/:id", fetch)
}
/// Segment analysis routes.
fn segments_router() -> Router<AppContext> {
    // GET /:id/neighbors runs similarity search around a segment.
    let neighbors = get(handlers::get_neighbors);
    Router::new().route("/:id/neighbors", neighbors)
}
/// Cluster management routes.
fn clusters_router() -> Router<AppContext> {
    // Listing, detail lookup, and human labeling of discovered clusters.
    let listing = Router::new().route("/", get(handlers::list_clusters));
    listing
        .route("/:id", get(handlers::get_cluster))
        .route("/:id/label", put(handlers::assign_cluster_label))
}
/// Evidence pack routes.
fn evidence_router() -> Router<AppContext> {
    // POST / generates a pack; GET /:id retrieves a previously generated one.
    let generate = post(handlers::generate_evidence_pack);
    let fetch = get(handlers::get_evidence_pack);
    Router::new().route("/", generate).route("/:id", fetch)
}
#[cfg(test)]
mod tests {
    // Integration tests would go here with a mock AppContext.
    //
    // The previous `use` statements (axum `Body`/`Request`/`StatusCode` and
    // `tower::ServiceExt`) are intentionally omitted until such tests exist,
    // to avoid unused-import warnings under `-D warnings`.
}

View File

@@ -0,0 +1,149 @@
//! Audio processing service.
//!
//! This module provides the `AudioPipeline` service for loading and
//! segmenting audio recordings.
use thiserror::Error;
use super::{Audio, Segment};
/// Audio processing error.
#[derive(Debug, Error)]
pub enum AudioError {
    /// Invalid audio format
    #[error("Invalid audio format: {0}")]
    InvalidFormat(String),
    /// Decoding error
    #[error("Failed to decode audio: {0}")]
    DecodingError(String),
    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Unsupported sample rate, in Hz
    #[error("Unsupported sample rate: {0}")]
    UnsupportedSampleRate(u32),
    /// Empty audio
    #[error("Audio file is empty or too short")]
    EmptyAudio,
}
/// Audio pipeline configuration.
#[derive(Debug, Clone)]
pub struct AudioPipelineConfig {
    /// Target sample rate for processing, in Hz
    pub target_sample_rate: u32,
    /// Minimum segment duration in seconds
    pub min_segment_duration: f64,
    /// Maximum segment duration in seconds
    pub max_segment_duration: f64,
    /// Energy threshold for segmentation
    /// (relative amplitude units — TODO confirm scale against the segmenter)
    pub energy_threshold: f32,
}
/// Defaults: 32 kHz processing rate, 0.5–10 s segments, 0.01 energy threshold.
impl Default for AudioPipelineConfig {
    fn default() -> Self {
        Self {
            target_sample_rate: 32000,
            min_segment_duration: 0.5,
            max_segment_duration: 10.0,
            energy_threshold: 0.01,
        }
    }
}
/// Audio processing pipeline.
///
/// Handles audio loading, resampling, and segmentation.
pub struct AudioPipeline {
    // Immutable after construction; shared by all pipeline methods.
    config: AudioPipelineConfig,
}
impl AudioPipeline {
    /// Create a new audio pipeline with the given configuration.
    ///
    /// # Errors
    ///
    /// Currently infallible; `Result` is kept so real decoder setup can fail
    /// without changing the interface.
    pub fn new(config: AudioPipelineConfig) -> Result<Self, AudioError> {
        Ok(Self { config })
    }

    /// Get `(duration_secs, sample_rate, channels)` from audio data without
    /// fully decoding it.
    ///
    /// Reads the canonical 44-byte WAV header (RIFF/WAVE magic with the
    /// `fmt ` chunk at its standard offset) when present; any other payload
    /// falls back to 44.1 kHz mono 16-bit assumptions, as the stub did
    /// unconditionally before. A real implementation would use symphonia or
    /// hound. NOTE(review): WAV files with extra chunks before `fmt ` are
    /// misread by the fixed offsets — revisit when real decoding lands.
    ///
    /// # Errors
    ///
    /// - [`AudioError::EmptyAudio`] if `data` is shorter than a WAV header.
    /// - [`AudioError::UnsupportedSampleRate`] if the header declares 0 Hz.
    /// - [`AudioError::InvalidFormat`] if the header declares zero channels.
    pub fn get_metadata(&self, data: &[u8]) -> Result<(f64, u32, u16), AudioError> {
        const HEADER_LEN: usize = 44;
        if data.len() < HEADER_LEN {
            return Err(AudioError::EmptyAudio);
        }
        // Canonical PCM WAV layout: channel count at byte 22, sample rate at
        // byte 24, bits-per-sample at byte 34 (all little-endian).
        let (sample_rate, channels, bits_per_sample) =
            if &data[0..4] == b"RIFF" && &data[8..12] == b"WAVE" {
                (
                    u32::from_le_bytes([data[24], data[25], data[26], data[27]]),
                    u16::from_le_bytes([data[22], data[23]]),
                    u16::from_le_bytes([data[34], data[35]]),
                )
            } else {
                // Unknown container: keep the stub's historical defaults.
                (44_100, 1, 16)
            };
        if sample_rate == 0 {
            return Err(AudioError::UnsupportedSampleRate(0));
        }
        if channels == 0 {
            return Err(AudioError::InvalidFormat(
                "header declares zero channels".to_string(),
            ));
        }
        // Estimate duration from the payload bytes (header excluded — the
        // previous estimate counted the header as audio). Clamp the sample
        // width to at least 8 bits to avoid dividing by zero on a malformed
        // header.
        let bytes_per_sample = f64::from(bits_per_sample.max(8)) / 8.0;
        let payload_bytes = (data.len() - HEADER_LEN) as f64;
        let duration =
            payload_bytes / (f64::from(sample_rate) * f64::from(channels) * bytes_per_sample);
        Ok((duration, sample_rate, channels))
    }

    /// Load audio from raw bytes.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::EmptyAudio`] for empty input, plus any error
    /// from [`Self::get_metadata`].
    pub fn load_audio(&self, data: &[u8]) -> Result<Audio, AudioError> {
        if data.is_empty() {
            return Err(AudioError::EmptyAudio);
        }
        // In a real implementation, this would:
        // 1. Detect format (WAV, FLAC, MP3, etc.)
        // 2. Decode to samples
        // 3. Convert to mono if stereo
        // 4. Resample to target rate
        // 5. Normalize to -1.0 to 1.0
        let (duration, _sample_rate, _) = self.get_metadata(data)?;
        // Placeholder: a silent buffer sized for the target rate.
        let num_samples = (duration * f64::from(self.config.target_sample_rate)) as usize;
        let samples = vec![0.0f32; num_samples];
        Ok(Audio {
            samples,
            sample_rate: self.config.target_sample_rate,
            duration_secs: duration,
        })
    }

    /// Segment audio into individual calls/vocalizations.
    ///
    /// # Errors
    ///
    /// Never fails in the stub implementation.
    pub fn segment(&self, _audio: &Audio) -> Result<Vec<Segment>, AudioError> {
        // In a real implementation, this would:
        // 1. Compute spectrogram
        // 2. Detect energy regions above threshold
        // 3. Apply minimum/maximum duration constraints
        // 4. Extract segment audio
        // For now, return empty segments (placeholder)
        Ok(vec![])
    }

    /// Get the target sample rate in Hz.
    #[must_use]
    pub fn target_sample_rate(&self) -> u32 {
        self.config.target_sample_rate
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with default config must succeed (the stub is infallible).
    #[test]
    fn test_audio_pipeline_creation() {
        let pipeline = AudioPipeline::new(Default::default());
        assert!(pipeline.is_ok());
    }

    // Zero-length input must be rejected before any decoding is attempted.
    #[test]
    fn test_empty_audio_error() {
        let pipeline = AudioPipeline::new(Default::default()).unwrap();
        let result = pipeline.load_audio(&[]);
        assert!(matches!(result, Err(AudioError::EmptyAudio)));
    }
}

View File

@@ -0,0 +1,247 @@
//! Cluster analysis service.
//!
//! This module provides the `ClusterEngine` service for discovering
//! and managing clusters of similar segments.
use std::collections::HashMap;
use std::sync::RwLock;
use chrono::Utc;
use thiserror::Error;
use uuid::Uuid;
use super::{ClusterData, SegmentEmbedding};
/// Cluster analysis error.
#[derive(Debug, Error)]
pub enum AnalysisError {
    /// Clustering error
    #[error("Clustering failed: {0}")]
    ClusteringError(String),
    /// Invalid parameters
    #[error("Invalid parameters: {0}")]
    InvalidParameters(String),
    /// Not found
    #[error("Not found: {0}")]
    NotFound(String),
    /// Internal error (e.g. a poisoned store lock)
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Cluster engine configuration.
#[derive(Debug, Clone)]
pub struct ClusterEngineConfig {
    /// Minimum cluster size
    pub min_cluster_size: usize,
    /// HDBSCAN min_samples
    pub min_samples: usize,
    /// Distance threshold for merging
    pub merge_threshold: f32,
}
/// Defaults: clusters of at least 5 segments, HDBSCAN `min_samples` of 3,
/// merge threshold of 0.15.
impl Default for ClusterEngineConfig {
    fn default() -> Self {
        Self {
            min_cluster_size: 5,
            min_samples: 3,
            merge_threshold: 0.15,
        }
    }
}
/// Cluster analysis engine.
///
/// Manages cluster discovery, labeling, and updates.
pub struct ClusterEngine {
    // Retained for the real clustering implementation; the in-memory stub
    // never reads it.
    config: ClusterEngineConfig,
    // In-memory cluster storage for stub implementation
    clusters: RwLock<HashMap<Uuid, ClusterData>>,
}
impl ClusterEngine {
    /// Create a new cluster engine with the given configuration.
    ///
    /// # Errors
    ///
    /// Currently infallible; `Result` is kept so a real clustering backend
    /// can fail during initialization without an interface change.
    pub fn new(config: ClusterEngineConfig) -> Result<Self, AnalysisError> {
        Ok(Self {
            config,
            clusters: RwLock::new(HashMap::new()),
        })
    }

    /// Acquire the cluster store for reading, mapping lock poisoning to
    /// [`AnalysisError::Internal`].
    fn read_store(
        &self,
    ) -> Result<std::sync::RwLockReadGuard<'_, HashMap<Uuid, ClusterData>>, AnalysisError> {
        self.clusters
            .read()
            .map_err(|e| AnalysisError::Internal(e.to_string()))
    }

    /// Acquire the cluster store for writing, mapping lock poisoning to
    /// [`AnalysisError::Internal`].
    fn write_store(
        &self,
    ) -> Result<std::sync::RwLockWriteGuard<'_, HashMap<Uuid, ClusterData>>, AnalysisError> {
        self.clusters
            .write()
            .map_err(|e| AnalysisError::Internal(e.to_string()))
    }

    /// Update clusters with new embeddings.
    ///
    /// # Errors
    ///
    /// Never fails in the stub implementation.
    pub fn update_clusters(&self, _embeddings: &[SegmentEmbedding]) -> Result<(), AnalysisError> {
        // In a real implementation, this would:
        // 1. Run HDBSCAN or similar clustering
        // 2. Merge with existing clusters if similar
        // 3. Update cluster centroids and metadata
        // For the stub, we don't create clusters automatically
        Ok(())
    }

    /// Get all clusters.
    ///
    /// # Errors
    ///
    /// Returns [`AnalysisError::Internal`] if the store lock is poisoned.
    pub fn get_all_clusters(&self) -> Result<Vec<ClusterData>, AnalysisError> {
        Ok(self.read_store()?.values().cloned().collect())
    }

    /// Get a specific cluster by ID, or `Ok(None)` if it does not exist.
    ///
    /// # Errors
    ///
    /// Returns [`AnalysisError::Internal`] if the store lock is poisoned.
    pub fn get_cluster(&self, id: &Uuid) -> Result<Option<ClusterData>, AnalysisError> {
        Ok(self.read_store()?.get(id).cloned())
    }

    /// Assign a label to a cluster.
    ///
    /// Returns the updated cluster, or `Ok(None)` if no cluster has this ID.
    ///
    /// # Errors
    ///
    /// Returns [`AnalysisError::Internal`] if the store lock is poisoned.
    pub fn assign_label(
        &self,
        cluster_id: &Uuid,
        label: &str,
    ) -> Result<Option<ClusterData>, AnalysisError> {
        let mut clusters = self.write_store()?;
        Ok(clusters.get_mut(cluster_id).map(|cluster| {
            cluster.label = Some(label.to_string());
            cluster.clone()
        }))
    }

    /// Create a new cluster manually.
    ///
    /// The cluster's size is the number of exemplars; density starts at 0.0
    /// and the species distribution empty (both would be computed by a real
    /// clustering pass).
    ///
    /// # Errors
    ///
    /// Returns [`AnalysisError::Internal`] if the store lock is poisoned.
    pub fn create_cluster(
        &self,
        centroid: Vec<f32>,
        exemplar_ids: Vec<Uuid>,
    ) -> Result<ClusterData, AnalysisError> {
        let cluster = ClusterData {
            id: Uuid::new_v4(),
            label: None,
            size: exemplar_ids.len(),
            centroid,
            density: 0.0,
            exemplar_ids,
            species_distribution: vec![],
            created_at: Utc::now(),
        };
        self.write_store()?.insert(cluster.id, cluster.clone());
        Ok(cluster)
    }

    /// Delete a cluster. Returns `true` if a cluster was removed.
    ///
    /// # Errors
    ///
    /// Returns [`AnalysisError::Internal`] if the store lock is poisoned.
    pub fn delete_cluster(&self, id: &Uuid) -> Result<bool, AnalysisError> {
        Ok(self.write_store()?.remove(id).is_some())
    }

    /// Merge two clusters into a new one, removing both inputs.
    ///
    /// # Errors
    ///
    /// - [`AnalysisError::InvalidParameters`] if both IDs are equal
    ///   (previously this silently destroyed the cluster).
    /// - [`AnalysisError::NotFound`] if either cluster is missing; the store
    ///   is left untouched (previously cluster A was removed and lost when
    ///   B did not exist).
    /// - [`AnalysisError::Internal`] if the store lock is poisoned.
    pub fn merge_clusters(
        &self,
        cluster_a: &Uuid,
        cluster_b: &Uuid,
    ) -> Result<ClusterData, AnalysisError> {
        if cluster_a == cluster_b {
            return Err(AnalysisError::InvalidParameters(
                "cannot merge a cluster with itself".to_string(),
            ));
        }
        let mut clusters = self.write_store()?;
        // Validate both exist before removing either, so a missing cluster
        // cannot cause the other to be dropped from the store.
        if !clusters.contains_key(cluster_a) {
            return Err(AnalysisError::NotFound(format!(
                "Cluster {} not found",
                cluster_a
            )));
        }
        if !clusters.contains_key(cluster_b) {
            return Err(AnalysisError::NotFound(format!(
                "Cluster {} not found",
                cluster_b
            )));
        }
        let a = clusters.remove(cluster_a).expect("presence checked above");
        let b = clusters.remove(cluster_b).expect("presence checked above");
        // Merge exemplar IDs
        let mut merged_exemplars = a.exemplar_ids;
        merged_exemplars.extend(b.exemplar_ids);
        // Average centroids element-wise. NOTE(review): `zip` truncates to
        // the shorter centroid if dimensions ever disagree — confirm all
        // centroids share one embedding dimension.
        let merged_centroid: Vec<f32> = a
            .centroid
            .iter()
            .zip(b.centroid.iter())
            .map(|(x, y)| (x + y) / 2.0)
            .collect();
        let merged = ClusterData {
            id: Uuid::new_v4(),
            label: a.label.or(b.label),
            size: a.size + b.size,
            centroid: merged_centroid,
            density: (a.density + b.density) / 2.0,
            exemplar_ids: merged_exemplars,
            species_distribution: vec![], // Would recompute
            created_at: Utc::now(),
        };
        clusters.insert(merged.id, merged.clone());
        Ok(merged)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with default config must succeed (the stub is infallible).
    #[test]
    fn test_cluster_engine_creation() {
        let engine = ClusterEngine::new(Default::default());
        assert!(engine.is_ok());
    }

    // A manually created cluster must be retrievable under its ID.
    #[test]
    fn test_create_and_get_cluster() {
        let engine = ClusterEngine::new(Default::default()).unwrap();
        let cluster = engine
            .create_cluster(vec![0.0; 1024], vec![Uuid::new_v4()])
            .unwrap();
        let retrieved = engine.get_cluster(&cluster.id).unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, cluster.id);
    }

    // Labeling an existing cluster must return the updated copy.
    #[test]
    fn test_assign_label() {
        let engine = ClusterEngine::new(Default::default()).unwrap();
        let cluster = engine
            .create_cluster(vec![0.0; 1024], vec![Uuid::new_v4()])
            .unwrap();
        let updated = engine
            .assign_label(&cluster.id, "Test Label")
            .unwrap()
            .unwrap();
        assert_eq!(updated.label, Some("Test Label".to_string()));
    }
}

View File

@@ -0,0 +1,132 @@
//! Embedding model service.
//!
//! This module provides the `EmbeddingModel` service for generating
//! vector embeddings from audio segments.
use thiserror::Error;
use super::{Segment, SegmentEmbedding};
/// Embedding model error.
#[derive(Debug, Error)]
pub enum EmbeddingError {
    /// Model loading error
    #[error("Failed to load model: {0}")]
    ModelLoadError(String),
    /// Inference error
    #[error("Inference failed: {0}")]
    InferenceError(String),
    /// Invalid input
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// Model not initialized
    #[error("Model not initialized")]
    NotInitialized,
}
/// Embedding model configuration.
#[derive(Debug, Clone)]
pub struct EmbeddingModelConfig {
    /// Model path or identifier
    pub model_id: String,
    /// Embedding dimension (length of every produced vector)
    pub embedding_dim: usize,
    /// Batch size for inference
    pub batch_size: usize,
    /// Use GPU if available
    pub use_gpu: bool,
}
/// Defaults: BirdNET v2.4 identifier, 1024-dim embeddings, batches of 32,
/// CPU-only inference.
impl Default for EmbeddingModelConfig {
    fn default() -> Self {
        Self {
            model_id: "birdnet-v2.4".to_string(),
            embedding_dim: 1024,
            batch_size: 32,
            use_gpu: false,
        }
    }
}
/// Embedding model for generating audio embeddings.
///
/// Wraps ONNX model inference for generating fixed-size vector
/// representations of audio segments.
pub struct EmbeddingModel {
    // Immutable after construction; only `embedding_dim` is read by the stub.
    config: EmbeddingModelConfig,
}
impl EmbeddingModel {
    /// Create a new embedding model with the given configuration.
    ///
    /// Stub: in a real implementation this would load the ONNX model from
    /// its path, initialize the runtime session, and configure GPU/CPU
    /// execution providers.
    pub async fn new(config: EmbeddingModelConfig) -> Result<Self, EmbeddingError> {
        Ok(Self { config })
    }

    /// Generate embeddings for a batch of segments.
    ///
    /// Stub: a real implementation would preprocess segments into mel
    /// spectrograms, batch them through inference, and L2-normalize the
    /// outputs. Here every segment gets a zero vector of the configured
    /// dimension.
    pub async fn embed_batch(
        &self,
        segments: &[Segment],
    ) -> Result<Vec<SegmentEmbedding>, EmbeddingError> {
        let dim = self.config.embedding_dim;
        let mut embeddings = Vec::with_capacity(segments.len());
        for segment in segments {
            embeddings.push(SegmentEmbedding {
                id: segment.id,
                recording_id: segment.recording_id,
                embedding: vec![0.0; dim],
                start_time: segment.start_time,
                end_time: segment.end_time,
                species: segment.species.clone(),
            });
        }
        Ok(embeddings)
    }

    /// Generate embedding for text (for text-to-audio search).
    ///
    /// Rejects empty input; otherwise returns a zero vector until a real
    /// text encoder is wired in.
    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
        }
        Ok(vec![0.0; self.config.embedding_dim])
    }

    /// Get the embedding dimension.
    #[must_use]
    pub fn embedding_dim(&self) -> usize {
        self.config.embedding_dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with default config must succeed (the stub is infallible).
    #[tokio::test]
    async fn test_embedding_model_creation() {
        let model = EmbeddingModel::new(Default::default()).await;
        assert!(model.is_ok());
    }

    // Empty text must be rejected rather than encoded.
    #[tokio::test]
    async fn test_embed_text_empty() {
        let model = EmbeddingModel::new(Default::default()).await.unwrap();
        let result = model.embed_text("").await;
        assert!(matches!(result, Err(EmbeddingError::InvalidInput(_))));
    }
}

View File

@@ -0,0 +1,246 @@
//! Interpretation service.
//!
//! This module provides the `InterpretationEngine` service for generating
//! evidence packs that explain similarity relationships.
use std::collections::HashMap;
use std::sync::RwLock;
use chrono::Utc;
use thiserror::Error;
use uuid::Uuid;
use super::{
EvidencePackData, EvidenceSegment, FeatureContributionData, NeighborEvidenceData,
SearchResult, SharedFeature, VisualizationUrls,
};
/// Interpretation error.
#[derive(Debug, Error)]
pub enum InterpretationError {
    /// Evidence pack generation failed
    #[error("Evidence generation failed: {0}")]
    GenerationError(String),
    /// Not found
    #[error("Not found: {0}")]
    NotFound(String),
    /// Internal error (e.g. a poisoned cache lock)
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Interpretation engine configuration.
#[derive(Debug, Clone)]
pub struct InterpretationEngineConfig {
    /// Number of top features to include
    /// NOTE(review): not yet consulted by the stub generator — confirm
    /// before tuning.
    pub top_features: usize,
    /// Generate spectrograms
    pub generate_spectrograms: bool,
    /// Generate UMAP visualizations
    pub generate_umap: bool,
}
/// Defaults: 5 top features, spectrogram and UMAP generation enabled.
impl Default for InterpretationEngineConfig {
    fn default() -> Self {
        Self {
            top_features: 5,
            generate_spectrograms: true,
            generate_umap: true,
        }
    }
}
/// Interpretation engine for generating evidence packs.
///
/// Creates interpretable explanations for similarity relationships.
pub struct InterpretationEngine {
    config: InterpretationEngineConfig,
    // Cache for generated evidence packs, keyed by query ID.
    cache: RwLock<HashMap<Uuid, EvidencePackData>>,
}
impl InterpretationEngine {
    /// Create a new interpretation engine with the given configuration.
    ///
    /// Currently infallible; the `Result` keeps the interface stable for a
    /// real backend that may fail to initialize.
    pub fn new(config: InterpretationEngineConfig) -> Result<Self, InterpretationError> {
        Ok(Self {
            config,
            cache: RwLock::new(HashMap::new()),
        })
    }
    /// Get a cached evidence pack by query ID.
    ///
    /// Returns `Ok(None)` when no pack was generated under this ID; fails
    /// with `Internal` only on a poisoned cache lock.
    pub fn get_evidence_pack(
        &self,
        query_id: &Uuid,
    ) -> Result<Option<EvidencePackData>, InterpretationError> {
        let cache = self
            .cache
            .read()
            .map_err(|e| InterpretationError::Internal(e.to_string()))?;
        Ok(cache.get(query_id).cloned())
    }
    /// Generate an evidence pack for a query segment and its neighbors.
    ///
    /// The pack is cached under a freshly generated `query_id` (the key for
    /// [`Self::get_evidence_pack`]) before being returned.
    ///
    /// NOTE(review): everything below is placeholder data — the feature
    /// values, shared features, and the query segment's times are hardcoded,
    /// and `config.top_features` is not consulted. Revisit when real
    /// analysis lands.
    pub async fn generate_evidence_pack(
        &self,
        segment_id: &Uuid,
        neighbors: &[SearchResult],
    ) -> Result<EvidencePackData, InterpretationError> {
        let query_id = Uuid::new_v4();
        // Create query segment info
        let query_segment = EvidenceSegment {
            id: *segment_id,
            recording_id: Uuid::new_v4(), // Would be looked up
            start_time: 0.0,
            end_time: 1.0,
            species: None,
        };
        // Generate evidence for each neighbor
        let neighbor_evidence: Vec<NeighborEvidenceData> = neighbors
            .iter()
            .map(|n| {
                // In a real implementation, this would:
                // 1. Analyze embedding dimensions
                // 2. Identify contributing features
                // 3. Generate spectrogram comparisons
                let contributing_features = vec![
                    FeatureContributionData {
                        name: "fundamental_frequency".to_string(),
                        weight: 0.25,
                        query_value: 2500.0,
                        neighbor_value: 2480.0,
                    },
                    FeatureContributionData {
                        name: "duration".to_string(),
                        weight: 0.15,
                        query_value: 0.5,
                        neighbor_value: 0.48,
                    },
                    FeatureContributionData {
                        name: "bandwidth".to_string(),
                        weight: 0.12,
                        query_value: 1500.0,
                        neighbor_value: 1520.0,
                    },
                ];
                NeighborEvidenceData {
                    segment: EvidenceSegment {
                        id: n.id,
                        recording_id: n.recording_id,
                        start_time: n.start_time,
                        end_time: n.end_time,
                        species: n.species.clone(),
                    },
                    // Assumes `distance` lies in [0, 1] so that 1 - d is a
                    // similarity — TODO confirm against the index metric.
                    similarity: 1.0 - n.distance,
                    contributing_features,
                    spectrogram_comparison_url: if self.config.generate_spectrograms {
                        Some(format!("/api/v1/evidence/{}/spectrograms/{}", query_id, n.id))
                    } else {
                        None
                    },
                }
            })
            .collect();
        // Identify shared features across neighbors
        let shared_features = vec![
            SharedFeature {
                name: "frequency_modulation".to_string(),
                description: "Rapid upward frequency sweep in 100-200ms range".to_string(),
                confidence: 0.92,
            },
            SharedFeature {
                name: "harmonic_structure".to_string(),
                description: "Clear harmonic overtones at 2x and 3x fundamental".to_string(),
                confidence: 0.87,
            },
        ];
        // Generate visualization URLs; each is gated on its config flag.
        let visualizations = VisualizationUrls {
            umap_url: if self.config.generate_umap {
                Some(format!("/api/v1/evidence/{}/umap", query_id))
            } else {
                None
            },
            spectrogram_grid_url: if self.config.generate_spectrograms {
                Some(format!("/api/v1/evidence/{}/grid", query_id))
            } else {
                None
            },
            feature_importance_url: Some(format!("/api/v1/evidence/{}/features", query_id)),
        };
        let evidence_pack = EvidencePackData {
            query_id,
            query_segment,
            neighbors: neighbor_evidence,
            shared_features,
            visualizations,
            generated_at: Utc::now(),
        };
        // Cache the evidence pack; the block scopes the write guard so it is
        // released before returning.
        {
            let mut cache = self
                .cache
                .write()
                .map_err(|e| InterpretationError::Internal(e.to_string()))?;
            cache.insert(query_id, evidence_pack.clone());
        }
        Ok(evidence_pack)
    }
    /// Clear the evidence pack cache.
    ///
    /// Fails with `Internal` only on a poisoned cache lock.
    pub fn clear_cache(&self) -> Result<(), InterpretationError> {
        let mut cache = self
            .cache
            .write()
            .map_err(|e| InterpretationError::Internal(e.to_string()))?;
        cache.clear();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with default config must succeed (the stub is infallible).
    #[test]
    fn test_interpretation_engine_creation() {
        let engine = InterpretationEngine::new(Default::default());
        assert!(engine.is_ok());
    }

    // Generating a pack from one neighbor must populate both the neighbor
    // evidence and the (placeholder) shared-feature list.
    #[tokio::test]
    async fn test_generate_evidence_pack() {
        let engine = InterpretationEngine::new(Default::default()).unwrap();
        let segment_id = Uuid::new_v4();
        let neighbors = vec![SearchResult {
            id: Uuid::new_v4(),
            recording_id: Uuid::new_v4(),
            distance: 0.1,
            start_time: 0.0,
            end_time: 1.0,
            species: None,
        }];
        let result = engine.generate_evidence_pack(&segment_id, &neighbors).await;
        assert!(result.is_ok());
        let pack = result.unwrap();
        assert!(!pack.neighbors.is_empty());
        assert!(!pack.shared_features.is_empty());
    }
}

View File

@@ -0,0 +1,218 @@
//! Service layer abstractions for 7sense API.
//!
//! This module defines the interfaces and implementations for core services:
//! - `AudioPipeline` - Audio loading and segmentation
//! - `EmbeddingModel` - Segment embedding generation
//! - `VectorIndex` - Similarity search
//! - `ClusterEngine` - Cluster analysis
//! - `InterpretationEngine` - Evidence pack generation
//!
//! These services wrap the underlying crate implementations and provide
//! API-specific functionality.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use utoipa::ToSchema;
// Re-export service types
pub use audio::*;
pub use cluster::*;
pub use embedding::*;
pub use interpretation::*;
pub use vector::*;
mod audio;
mod cluster;
mod embedding;
mod interpretation;
mod vector;
/// Species information attached to a segment.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct SpeciesInfo {
/// Common name
pub common_name: String,
/// Scientific name (binomial nomenclature)
pub scientific_name: Option<String>,
/// Detection confidence (0.0 to 1.0)
pub confidence: f32,
}
/// A detected audio segment.
#[derive(Debug, Clone)]
pub struct Segment {
/// Unique identifier
pub id: Uuid,
/// Parent recording ID
pub recording_id: Uuid,
/// Start time in seconds
pub start_time: f64,
/// End time in seconds
pub end_time: f64,
/// Audio samples (mono, normalized)
pub samples: Vec<f32>,
/// Sample rate
pub sample_rate: u32,
/// Detected species
pub species: Option<SpeciesInfo>,
/// Quality score
pub quality_score: f32,
}
/// Audio metadata.
#[derive(Debug, Clone)]
pub struct AudioMetadata {
/// Duration in seconds
pub duration_secs: f64,
/// Sample rate in Hz
pub sample_rate: u32,
/// Number of channels
pub channels: u16,
}
/// Loaded audio data.
#[derive(Debug, Clone)]
pub struct Audio {
/// Mono samples (normalized to -1.0 to 1.0)
pub samples: Vec<f32>,
/// Sample rate in Hz
pub sample_rate: u32,
/// Original duration in seconds
pub duration_secs: f64,
}
/// Segment embedding for vector storage.
#[derive(Debug, Clone)]
pub struct SegmentEmbedding {
/// Segment ID
pub id: Uuid,
/// Recording ID
pub recording_id: Uuid,
/// Embedding vector
pub embedding: Vec<f32>,
/// Start time
pub start_time: f64,
/// End time
pub end_time: f64,
/// Detected species
pub species: Option<SpeciesInfo>,
}
/// Search result from vector index.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Segment ID
    pub id: Uuid,
    /// Recording ID
    pub recording_id: Uuid,
    /// Distance to query (cosine distance: 0.0 = identical direction;
    /// smaller is more similar)
    pub distance: f32,
    /// Start time in seconds
    pub start_time: f64,
    /// End time in seconds
    pub end_time: f64,
    /// Detected species, if any
    pub species: Option<SpeciesInfo>,
}
/// Cluster data.
#[derive(Debug, Clone)]
pub struct ClusterData {
    /// Cluster ID
    pub id: Uuid,
    /// Human-assigned label, if one has been provided
    pub label: Option<String>,
    /// Number of segments in the cluster
    pub size: usize,
    /// Centroid embedding (same dimension as member embeddings)
    pub centroid: Vec<f32>,
    /// Cluster density
    pub density: f32,
    /// Representative segment IDs
    pub exemplar_ids: Vec<Uuid>,
    /// Species distribution: (name, count, percentage).
    /// NOTE(review): percentage scale (0–1 vs 0–100) is not evident here —
    /// confirm against the producing cluster service.
    pub species_distribution: Vec<(String, usize, f64)>,
    /// Creation timestamp (UTC)
    pub created_at: DateTime<Utc>,
}
/// Evidence pack data.
///
/// Bundles a query segment with its nearest neighbors and the artifacts
/// explaining why they matched.
#[derive(Debug, Clone)]
pub struct EvidencePackData {
    /// Query ID
    pub query_id: Uuid,
    /// Query segment
    pub query_segment: EvidenceSegment,
    /// Neighbor evidence, one entry per retrieved neighbor
    pub neighbors: Vec<NeighborEvidenceData>,
    /// Acoustic features shared between the query and its neighbors
    pub shared_features: Vec<SharedFeature>,
    /// Visualization artifact URLs
    pub visualizations: VisualizationUrls,
    /// Generation timestamp (UTC)
    pub generated_at: DateTime<Utc>,
}
/// Segment for evidence pack.
///
/// A slim view of a segment: metadata only, no audio samples.
#[derive(Debug, Clone)]
pub struct EvidenceSegment {
    /// Segment ID
    pub id: Uuid,
    /// Recording ID
    pub recording_id: Uuid,
    /// Start time in seconds
    pub start_time: f64,
    /// End time in seconds
    pub end_time: f64,
    /// Species info, if detected
    pub species: Option<SpeciesInfo>,
}
/// Neighbor evidence data.
///
/// One neighbor of a query segment plus the explanation of the match.
#[derive(Debug, Clone)]
pub struct NeighborEvidenceData {
    /// Neighbor segment
    pub segment: EvidenceSegment,
    /// Similarity score (higher = more similar)
    pub similarity: f32,
    /// Features that contributed most to the match
    pub contributing_features: Vec<FeatureContributionData>,
    /// Side-by-side spectrogram comparison URL, if rendered
    pub spectrogram_comparison_url: Option<String>,
}
/// Feature contribution data.
///
/// How a single acoustic feature contributed to a query/neighbor match,
/// with the raw value on each side for comparison.
#[derive(Debug, Clone)]
pub struct FeatureContributionData {
    /// Feature name
    pub name: String,
    /// Contribution weight (relative importance of this feature)
    pub weight: f32,
    /// Feature value on the query segment
    pub query_value: f64,
    /// Feature value on the neighbor segment
    pub neighbor_value: f64,
}
/// Shared acoustic feature.
///
/// A feature found in common across a query and its neighbors.
#[derive(Debug, Clone)]
pub struct SharedFeature {
    /// Feature name
    pub name: String,
    /// Human-readable description
    pub description: String,
    /// Confidence score that the feature is genuinely shared
    pub confidence: f32,
}
/// Visualization URLs.
///
/// Each URL is `None` when the corresponding artifact was not generated.
#[derive(Debug, Clone)]
pub struct VisualizationUrls {
    /// UMAP projection URL
    pub umap_url: Option<String>,
    /// Spectrogram grid URL
    pub spectrogram_grid_url: Option<String>,
    /// Feature importance URL
    pub feature_importance_url: Option<String>,
}

View File

@@ -0,0 +1,235 @@
//! Vector index service.
//!
//! This module provides the `VectorIndex` service for similarity search
//! using vector embeddings.
use std::collections::HashMap;
use std::sync::RwLock;
use thiserror::Error;
use uuid::Uuid;
use super::{SearchResult, SegmentEmbedding, SpeciesInfo};
/// Vector index error.
#[derive(Debug, Error)]
pub enum VectorError {
    /// Connection error (e.g. the backing vector database is unreachable)
    #[error("Connection error: {0}")]
    ConnectionError(String),
    /// Query error — read-path failures (search, lookups, counts),
    /// including a poisoned read lock in the in-memory stub
    #[error("Query error: {0}")]
    QueryError(String),
    /// Index error — write-path failures (inserts, deletes),
    /// including a poisoned write lock in the in-memory stub
    #[error("Index error: {0}")]
    IndexError(String),
    /// Requested item does not exist in the index
    #[error("Not found: {0}")]
    NotFound(String),
}
/// Vector index configuration.
#[derive(Debug, Clone)]
pub struct VectorIndexConfig {
    /// Collection name in the backing vector store
    pub collection_name: String,
    /// Embedding dimension.
    /// NOTE(review): not enforced by the in-memory stub — wrong-length
    /// vectors are accepted and simply score as maximally distant.
    pub embedding_dim: usize,
    /// HNSW M parameter (graph connectivity)
    pub hnsw_m: usize,
    /// HNSW ef_construct parameter (build-time search width)
    pub hnsw_ef_construct: usize,
}
impl Default for VectorIndexConfig {
fn default() -> Self {
Self {
collection_name: "sevensense_segments".to_string(),
embedding_dim: 1024,
hnsw_m: 16,
hnsw_ef_construct: 100,
}
}
}
/// In-memory segment storage for the stub implementation.
///
/// Mirrors `SegmentEmbedding` minus the segment ID, which is the map key.
struct StoredSegment {
    /// Parent recording (used for bulk deletion by recording)
    recording_id: Uuid,
    /// Embedding vector compared against search queries
    embedding: Vec<f32>,
    /// Segment start time in seconds
    start_time: f64,
    /// Segment end time in seconds
    end_time: f64,
    /// Species detection carried through to search results
    species: Option<SpeciesInfo>,
}
/// Vector index for similarity search.
///
/// Wraps vector database (Qdrant) for efficient nearest neighbor search.
/// The current implementation is an in-memory stub: embeddings live in a
/// `HashMap` behind an `RwLock` and search is a linear scan.
pub struct VectorIndex {
    /// Index configuration.
    /// NOTE(review): retained for the future Qdrant-backed implementation;
    /// the in-memory stub does not read it.
    config: VectorIndexConfig,
    // In-memory storage for stub implementation
    storage: RwLock<HashMap<Uuid, StoredSegment>>,
}
impl VectorIndex {
/// Create a new vector index with the given configuration.
pub fn new(config: VectorIndexConfig) -> Result<Self, VectorError> {
// In a real implementation, this would:
// 1. Connect to Qdrant
// 2. Create/verify collection
// 3. Configure HNSW index
Ok(Self {
config,
storage: RwLock::new(HashMap::new()),
})
}
/// Add a batch of embeddings to the index.
pub fn add_batch(&self, embeddings: &[SegmentEmbedding]) -> Result<(), VectorError> {
let mut storage = self
.storage
.write()
.map_err(|e| VectorError::IndexError(e.to_string()))?;
for emb in embeddings {
storage.insert(
emb.id,
StoredSegment {
recording_id: emb.recording_id,
embedding: emb.embedding.clone(),
start_time: emb.start_time,
end_time: emb.end_time,
species: emb.species.clone(),
},
);
}
Ok(())
}
/// Get embedding for a segment.
pub fn get_embedding(&self, segment_id: &Uuid) -> Result<Option<Vec<f32>>, VectorError> {
let storage = self
.storage
.read()
.map_err(|e| VectorError::QueryError(e.to_string()))?;
Ok(storage.get(segment_id).map(|s| s.embedding.clone()))
}
/// Search for similar segments.
pub fn search(
&self,
query: &[f32],
k: usize,
min_similarity: f32,
) -> Result<Vec<SearchResult>, VectorError> {
let storage = self
.storage
.read()
.map_err(|e| VectorError::QueryError(e.to_string()))?;
// Compute distances to all stored embeddings
let mut results: Vec<(Uuid, f32, &StoredSegment)> = storage
.iter()
.map(|(id, seg)| {
let distance = cosine_distance(query, &seg.embedding);
(*id, distance, seg)
})
.filter(|(_, dist, _)| (1.0 - *dist) >= min_similarity)
.collect();
// Sort by distance (ascending)
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
// Take top k
let results: Vec<SearchResult> = results
.into_iter()
.take(k)
.map(|(id, distance, seg)| SearchResult {
id,
recording_id: seg.recording_id,
distance,
start_time: seg.start_time,
end_time: seg.end_time,
species: seg.species.clone(),
})
.collect();
Ok(results)
}
/// Delete embeddings for a recording.
pub fn delete_recording(&self, recording_id: &Uuid) -> Result<usize, VectorError> {
let mut storage = self
.storage
.write()
.map_err(|e| VectorError::IndexError(e.to_string()))?;
let to_remove: Vec<Uuid> = storage
.iter()
.filter(|(_, seg)| seg.recording_id == *recording_id)
.map(|(id, _)| *id)
.collect();
let count = to_remove.len();
for id in to_remove {
storage.remove(&id);
}
Ok(count)
}
/// Get total number of indexed segments.
pub fn count(&self) -> Result<usize, VectorError> {
let storage = self
.storage
.read()
.map_err(|e| VectorError::QueryError(e.to_string()))?;
Ok(storage.len())
}
}
/// Compute the cosine distance (`1 - cosine similarity`) between two vectors.
///
/// Returns `0.0` for identical directions and up to `2.0` for opposite
/// directions. When the inputs cannot be compared — mismatched lengths,
/// empty slices, or a zero-norm vector — the maximum-dissimilarity
/// sentinel `1.0` is returned instead of an error.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 1.0;
    }
    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }
    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }
    1.0 - dot / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stub constructor should succeed with default configuration.
    #[test]
    fn test_vector_index_creation() {
        assert!(VectorIndex::new(VectorIndexConfig::default()).is_ok());
    }

    /// Spot-check the two extremes of cosine distance.
    #[test]
    fn test_cosine_distance() {
        let unit_x = [1.0, 0.0, 0.0];
        // Identical vectors sit at distance zero.
        assert!(cosine_distance(&unit_x, &unit_x).abs() < 0.001);
        // Orthogonal vectors sit at distance one.
        let unit_y = [0.0, 1.0, 0.0];
        assert!((cosine_distance(&unit_x, &unit_y) - 1.0).abs() < 0.001);
    }
}

View File

@@ -0,0 +1,286 @@
//! WebSocket handlers for real-time updates.
//!
//! These handlers manage WebSocket connections for streaming
//! processing status, cluster updates, and other real-time data.
use std::time::Duration;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Path, State,
},
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use uuid::Uuid;
use crate::{AppContext, ProcessingEvent, ProcessingStatus};
/// WebSocket message types for client communication.
///
/// Serialized as adjacently-tagged JSON: `{"type": "...", "data": ...}`,
/// with `data` omitted for payload-less variants such as `Ping`/`Pong`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum WsMessage {
    /// Processing status update
    #[serde(rename = "status")]
    Status(StatusUpdate),
    /// Error message
    #[serde(rename = "error")]
    Error(ErrorMessage),
    /// Ping/keepalive (sent by the server every 30 s)
    #[serde(rename = "ping")]
    Ping,
    /// Pong response
    #[serde(rename = "pong")]
    Pong,
    /// Subscription confirmation, sent once immediately after upgrade
    #[serde(rename = "subscribed")]
    Subscribed {
        /// Channel subscribed to (e.g. `recordings/{id}` or `events`)
        channel: String,
    },
}
/// Processing status update message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusUpdate {
    /// Recording ID
    pub recording_id: Uuid,
    /// Status string (lowercase stage name, e.g. "embedding", "complete";
    /// see `From<ProcessingEvent>` for the full set)
    pub status: String,
    /// Progress (0.0 to 1.0)
    pub progress: f32,
    /// Optional human-readable message
    pub message: Option<String>,
    /// Timestamp in epoch milliseconds (UTC, stamped at conversion time)
    pub timestamp: i64,
}
/// Error message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorMessage {
    /// Machine-readable error code (e.g. "not_found")
    pub code: String,
    /// Human-readable error message
    pub message: String,
}
impl From<ProcessingEvent> for StatusUpdate {
fn from(event: ProcessingEvent) -> Self {
Self {
recording_id: event.recording_id,
status: match event.status {
ProcessingStatus::Queued => "queued",
ProcessingStatus::Loading => "loading",
ProcessingStatus::Segmenting => "segmenting",
ProcessingStatus::Embedding => "embedding",
ProcessingStatus::Indexing => "indexing",
ProcessingStatus::Analyzing => "analyzing",
ProcessingStatus::Complete => "complete",
ProcessingStatus::Failed => "failed",
}
.to_string(),
progress: event.progress,
message: event.message,
timestamp: chrono::Utc::now().timestamp_millis(),
}
}
}
/// WebSocket handler for recording status updates.
///
/// Clients connect to `/ws/recordings/{id}` to receive real-time
/// status updates for a specific recording.
///
/// Performs the HTTP → WebSocket upgrade and hands the socket off to
/// `handle_recording_status`, which owns the connection lifecycle.
pub async fn recording_status_ws(
    ws: WebSocketUpgrade,
    Path(recording_id): Path<Uuid>,
    State(ctx): State<AppContext>,
) -> impl IntoResponse {
    // NOTE(review): the recording ID is not validated against storage here;
    // an unknown ID yields a connection that only ever receives keepalives.
    ws.on_upgrade(move |socket| handle_recording_status(socket, recording_id, ctx))
}
/// Drive one recording-status WebSocket connection to completion.
///
/// Protocol: sends a `subscribed` confirmation, then forwards every
/// `ProcessingEvent` matching `recording_id` as a `status` message,
/// interleaved with JSON `ping` keepalives every 30 seconds. Returns when
/// either side closes the connection or the event channel shuts down.
async fn handle_recording_status(socket: WebSocket, recording_id: Uuid, ctx: AppContext) {
    let (mut sender, mut receiver) = socket.split();
    // Subscribe to events
    let mut event_rx = ctx.subscribe_events();
    // Send subscription confirmation (send errors ignored: if the socket is
    // already dead, the tasks below will notice and exit)
    let confirm = WsMessage::Subscribed {
        channel: format!("recordings/{recording_id}"),
    };
    if let Ok(json) = serde_json::to_string(&confirm) {
        let _ = sender.send(Message::Text(json.into())).await;
    }
    // Spawn task to drain incoming frames; its only job is to detect a
    // client close or transport error so the connection can be torn down.
    let mut recv_task = tokio::spawn(async move {
        while let Some(msg) = receiver.next().await {
            match msg {
                Ok(Message::Close(_)) | Err(_) => break,
                _ => {}
            }
        }
    });
    // Main event loop: forward matching events and emit keepalives.
    let mut send_task = tokio::spawn(async move {
        // Keepalive interval (tokio intervals tick once immediately, so the
        // first ping follows right after the confirmation — harmless)
        let mut keepalive = tokio::time::interval(Duration::from_secs(30));
        loop {
            tokio::select! {
                // Handle processing events
                event = event_rx.recv() => {
                    match event {
                        Ok(event) if event.recording_id == recording_id => {
                            let update: StatusUpdate = event.into();
                            let msg = WsMessage::Status(update);
                            if let Ok(json) = serde_json::to_string(&msg) {
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(broadcast::error::RecvError::Closed) => break,
                        // Events for other recordings, and Lagged errors
                        // (dropped broadcast messages), are silently skipped.
                        _ => {}
                    }
                }
                // Keepalive ping (application-level JSON, not a WS ping frame)
                _ = keepalive.tick() => {
                    let msg = WsMessage::Ping;
                    if let Ok(json) = serde_json::to_string(&msg) {
                        if sender.send(Message::Text(json.into())).await.is_err() {
                            break;
                        }
                    }
                }
            }
        }
    });
    // Whichever side finishes first aborts the other so neither task leaks.
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }
    tracing::debug!(recording_id = %recording_id, "WebSocket connection closed");
}
/// WebSocket handler for all events stream.
///
/// Admin endpoint that streams all processing events.
///
/// Upgrades the connection and delegates to `handle_all_events`.
/// NOTE(review): no authentication visible here — confirm access control
/// is enforced by middleware before this route.
pub async fn events_ws(ws: WebSocketUpgrade, State(ctx): State<AppContext>) -> impl IntoResponse {
    ws.on_upgrade(move |socket| handle_all_events(socket, ctx))
}
/// Drive one firehose WebSocket connection streaming every processing event.
///
/// Same structure as `handle_recording_status` but without the per-recording
/// filter: confirmation first, then all events as `status` messages plus a
/// 30-second JSON keepalive. Returns when either side closes or the event
/// channel shuts down.
async fn handle_all_events(socket: WebSocket, ctx: AppContext) {
    let (mut sender, mut receiver) = socket.split();
    let mut event_rx = ctx.subscribe_events();
    // Send subscription confirmation (errors ignored; a dead socket is
    // detected by the tasks below)
    let confirm = WsMessage::Subscribed {
        channel: "events".to_string(),
    };
    if let Ok(json) = serde_json::to_string(&confirm) {
        let _ = sender.send(Message::Text(json.into())).await;
    }
    // Spawn receiver task — drains client frames purely to detect close/error.
    let mut recv_task = tokio::spawn(async move {
        while let Some(msg) = receiver.next().await {
            match msg {
                Ok(Message::Close(_)) | Err(_) => break,
                _ => {}
            }
        }
    });
    // Main send loop: forward every event, emit periodic keepalives.
    let mut send_task = tokio::spawn(async move {
        let mut keepalive = tokio::time::interval(Duration::from_secs(30));
        loop {
            tokio::select! {
                event = event_rx.recv() => {
                    match event {
                        Ok(event) => {
                            let update: StatusUpdate = event.into();
                            let msg = WsMessage::Status(update);
                            if let Ok(json) = serde_json::to_string(&msg) {
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(broadcast::error::RecvError::Closed) => break,
                        Err(_) => {} // Lagged, skip
                    }
                }
                _ = keepalive.tick() => {
                    if let Ok(json) = serde_json::to_string(&WsMessage::Ping) {
                        if sender.send(Message::Text(json.into())).await.is_err() {
                            break;
                        }
                    }
                }
            }
        }
    });
    // Abort the surviving task so neither leaks past the connection.
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }
    tracing::debug!("Events WebSocket connection closed");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Status messages round-trip to JSON with the tag and payload visible.
    #[test]
    fn test_ws_message_serialization() {
        let update = StatusUpdate {
            recording_id: Uuid::new_v4(),
            status: String::from("processing"),
            progress: 0.5,
            message: Some(String::from("Halfway done")),
            timestamp: 1_234_567_890,
        };
        let json = serde_json::to_string(&WsMessage::Status(update)).unwrap();
        assert!(json.contains("status"));
        assert!(json.contains("processing"));
    }

    /// Events map to the expected lowercase status string and keep progress.
    #[test]
    fn test_status_update_from_event() {
        let event = ProcessingEvent {
            recording_id: Uuid::new_v4(),
            status: ProcessingStatus::Embedding,
            progress: 0.5,
            message: Some(String::from("Generating embeddings")),
        };
        let update = StatusUpdate::from(event);
        assert_eq!(update.status, "embedding");
        assert!((update.progress - 0.5).abs() < f32::EPSILON);
    }

    /// Error variants serialize with their machine-readable code.
    #[test]
    fn test_error_message() {
        let error = WsMessage::Error(ErrorMessage {
            code: String::from("not_found"),
            message: String::from("Recording not found"),
        });
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("not_found"));
    }
}

View File

@@ -0,0 +1,22 @@
//! WebSocket module for real-time updates.
//!
//! This module provides WebSocket endpoints for:
//! - Real-time processing status updates
//! - Live cluster updates
//! - Streaming search results
pub mod handlers;
use axum::{routing::get, Router};
use crate::AppContext;
/// Create the WebSocket router.
///
/// The context argument is currently unused: handlers obtain shared state
/// through the router's `State` extractor instead.
#[must_use]
pub fn create_router(_ctx: AppContext) -> Router<AppContext> {
    // Per-recording processing-status stream.
    let router = Router::new().route("/recordings/:id", get(handlers::recording_status_ws));
    // Firehose of every processing event (admin use).
    router.route("/events", get(handlers::events_ws))
}