Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
270
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/error.rs
vendored
Normal file
270
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/error.rs
vendored
Normal file
@@ -0,0 +1,270 @@
|
||||
//! API error types and HTTP response handling.
|
||||
//!
|
||||
//! This module provides a unified error type for all API endpoints with
|
||||
//! proper HTTP status code mapping and JSON error responses.
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::services::{AnalysisError, AudioError, EmbeddingError, InterpretationError, VectorError};
|
||||
|
||||
/// Unified API error type.
///
/// Every variant maps to a specific HTTP status code via [`ApiError::status_code`]
/// and a stable machine-readable identifier via [`ApiError::error_type`]. The
/// last five variants wrap errors from the service layer (`crate::services`)
/// via `#[from]`, so `?` converts them automatically.
#[derive(Debug, Error)]
pub enum ApiError {
    /// Resource not found (404)
    #[error("Resource not found: {0}")]
    NotFound(String),

    /// Bad request with validation errors (400)
    #[error("Bad request: {0}")]
    BadRequest(String),

    /// Unauthorized access (401)
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    /// Forbidden access (403)
    #[error("Forbidden: {0}")]
    Forbidden(String),

    /// Conflict with existing resource (409)
    #[error("Conflict: {0}")]
    Conflict(String),

    /// Payload too large (413)
    #[error("Payload too large: {0}")]
    PayloadTooLarge(String),

    /// Unsupported media type (415)
    #[error("Unsupported media type: {0}")]
    UnsupportedMediaType(String),

    /// Rate limit exceeded (429)
    #[error("Rate limit exceeded")]
    RateLimitExceeded,

    /// Internal server error (500)
    #[error("Internal error: {0}")]
    Internal(String),

    /// Service unavailable (503)
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),

    /// Audio processing error (mapped to 422 Unprocessable Entity)
    #[error("Audio processing error: {0}")]
    AudioProcessing(#[from] AudioError),

    /// Embedding error (mapped to 500)
    #[error("Embedding error: {0}")]
    Embedding(#[from] EmbeddingError),

    /// Vector index error (mapped to 500)
    #[error("Vector index error: {0}")]
    VectorIndex(#[from] VectorError),

    /// Analysis error (mapped to 500)
    #[error("Analysis error: {0}")]
    Analysis(#[from] AnalysisError),

    /// Interpretation error (mapped to 500)
    #[error("Interpretation error: {0}")]
    Interpretation(#[from] InterpretationError),

    /// Generic anyhow error; message passes through unchanged (transparent)
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
|
||||
|
||||
/// JSON error response body returned by every failing endpoint.
///
/// Serialized as the response payload in [`ApiError::into_response`];
/// `details` and `request_id` are omitted from the JSON when `None`.
#[derive(Debug, Serialize, ToSchema)]
pub struct ErrorResponse {
    /// Stable, machine-readable error type identifier
    #[schema(example = "not_found")]
    pub error: String,
    /// Human-readable error message
    #[schema(example = "Recording with ID xyz not found")]
    pub message: String,
    /// Optional structured error details (omitted when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
    /// Request ID for tracing (omitted when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
|
||||
|
||||
impl ApiError {
|
||||
/// Get the HTTP status code for this error.
|
||||
#[must_use]
|
||||
pub fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
Self::BadRequest(_) => StatusCode::BAD_REQUEST,
|
||||
Self::Unauthorized(_) => StatusCode::UNAUTHORIZED,
|
||||
Self::Forbidden(_) => StatusCode::FORBIDDEN,
|
||||
Self::Conflict(_) => StatusCode::CONFLICT,
|
||||
Self::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
Self::UnsupportedMediaType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE,
|
||||
Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::ServiceUnavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
|
||||
Self::AudioProcessing(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Self::Embedding(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::VectorIndex(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::Analysis(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::Interpretation(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::Other(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the error type identifier.
|
||||
#[must_use]
|
||||
pub fn error_type(&self) -> &'static str {
|
||||
match self {
|
||||
Self::NotFound(_) => "not_found",
|
||||
Self::BadRequest(_) => "bad_request",
|
||||
Self::Unauthorized(_) => "unauthorized",
|
||||
Self::Forbidden(_) => "forbidden",
|
||||
Self::Conflict(_) => "conflict",
|
||||
Self::PayloadTooLarge(_) => "payload_too_large",
|
||||
Self::UnsupportedMediaType(_) => "unsupported_media_type",
|
||||
Self::RateLimitExceeded => "rate_limit_exceeded",
|
||||
Self::Internal(_) => "internal_error",
|
||||
Self::ServiceUnavailable(_) => "service_unavailable",
|
||||
Self::AudioProcessing(_) => "audio_processing_error",
|
||||
Self::Embedding(_) => "embedding_error",
|
||||
Self::VectorIndex(_) => "vector_index_error",
|
||||
Self::Analysis(_) => "analysis_error",
|
||||
Self::Interpretation(_) => "interpretation_error",
|
||||
Self::Other(_) => "internal_error",
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a not found error for a specific resource type.
|
||||
#[must_use]
|
||||
pub fn not_found<T: std::fmt::Display>(resource: &str, id: T) -> Self {
|
||||
Self::NotFound(format!("{resource} with ID {id} not found"))
|
||||
}
|
||||
|
||||
/// Create a validation error with details.
|
||||
#[must_use]
|
||||
pub fn validation<T: Serialize>(message: &str, details: T) -> Self {
|
||||
Self::BadRequest(format!(
|
||||
"{}: {}",
|
||||
message,
|
||||
serde_json::to_string(&details).unwrap_or_default()
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status_code();
|
||||
let error_type = self.error_type();
|
||||
let message = self.to_string();
|
||||
|
||||
// Log internal errors
|
||||
match &self {
|
||||
Self::Internal(_)
|
||||
| Self::Other(_)
|
||||
| Self::Embedding(_)
|
||||
| Self::VectorIndex(_)
|
||||
| Self::Analysis(_)
|
||||
| Self::Interpretation(_) => {
|
||||
tracing::error!(error = %self, "Internal API error");
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!(error = %self, "API error response");
|
||||
}
|
||||
}
|
||||
|
||||
let body = ErrorResponse {
|
||||
error: error_type.to_string(),
|
||||
message,
|
||||
details: None,
|
||||
request_id: None,
|
||||
};
|
||||
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type alias for API handlers: `Ok(T)` or an [`ApiError`].
pub type ApiResult<T> = Result<T, ApiError>;
|
||||
|
||||
/// Extension trait for adding context to errors.
///
/// Implemented for both `Result<T, E>` and `Option<T>` so handlers can
/// uniformly convert failures and missing values into [`ApiError`].
pub trait ResultExt<T> {
    /// Convert error to `ApiError::Internal` with a context prefix.
    fn api_context(self, context: &str) -> ApiResult<T>;

    /// Convert a missing value into a 404 `ApiError::NotFound` for the
    /// given resource type and ID.
    fn or_not_found(self, resource: &str, id: &str) -> ApiResult<T>;
}
|
||||
|
||||
impl<T, E: std::error::Error + Send + Sync + 'static> ResultExt<T> for Result<T, E> {
|
||||
fn api_context(self, context: &str) -> ApiResult<T> {
|
||||
self.map_err(|e| ApiError::Internal(format!("{context}: {e}")))
|
||||
}
|
||||
|
||||
fn or_not_found(self, _resource: &str, _id: &str) -> ApiResult<T> {
|
||||
self.map_err(|e| ApiError::Internal(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ResultExt<T> for Option<T> {
|
||||
fn api_context(self, context: &str) -> ApiResult<T> {
|
||||
self.ok_or_else(|| ApiError::Internal(format!("{context}: value was None")))
|
||||
}
|
||||
|
||||
fn or_not_found(self, resource: &str, id: &str) -> ApiResult<T> {
|
||||
self.ok_or_else(|| ApiError::not_found(resource, id))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-check the variant → HTTP status mapping with a case table.
    #[test]
    fn test_error_status_codes() {
        let cases = [
            (ApiError::NotFound("test".into()), StatusCode::NOT_FOUND),
            (ApiError::BadRequest("test".into()), StatusCode::BAD_REQUEST),
            (ApiError::RateLimitExceeded, StatusCode::TOO_MANY_REQUESTS),
        ];
        for (err, expected) in cases {
            assert_eq!(err.status_code(), expected);
        }
    }

    /// The not_found helper embeds both the resource type and the ID.
    #[test]
    fn test_not_found_helper() {
        let rendered = ApiError::not_found("Recording", "abc-123").to_string();
        assert!(rendered.contains("Recording"));
        assert!(rendered.contains("abc-123"));
    }

    /// ErrorResponse serializes its populated optional fields.
    #[test]
    fn test_error_response_serialization() {
        let response = ErrorResponse {
            error: "not_found".into(),
            message: "Resource not found".into(),
            details: None,
            request_id: Some("req-123".into()),
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("not_found"));
        assert!(json.contains("req-123"));
    }
}
|
||||
120
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/graphql/mod.rs
vendored
Normal file
120
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/graphql/mod.rs
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
//! GraphQL API module for 7sense.
|
||||
//!
|
||||
//! This module provides a flexible GraphQL API with:
|
||||
//! - Query operations for recordings, segments, clusters, and evidence
|
||||
//! - Mutations for ingestion and labeling
|
||||
//! - Subscriptions for real-time processing updates
|
||||
//!
|
||||
//! ## Schema
|
||||
//!
|
||||
//! The schema is defined using `async-graphql` with automatic type generation.
|
||||
//! Access the GraphQL playground at `/graphql` when enabled.
|
||||
|
||||
pub mod schema;
|
||||
pub mod types;
|
||||
|
||||
use async_graphql::Schema;
|
||||
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{Html, IntoResponse},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
|
||||
use crate::AppContext;
|
||||
use schema::{MutationRoot, QueryRoot, SubscriptionRoot};
|
||||
|
||||
/// GraphQL schema type alias combining the three root operation types.
pub type ApiSchema = Schema<QueryRoot, MutationRoot, SubscriptionRoot>;
|
||||
|
||||
/// Build the GraphQL schema with the application context.
|
||||
#[must_use]
|
||||
pub fn build_schema(ctx: AppContext) -> ApiSchema {
|
||||
Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
|
||||
.data(ctx)
|
||||
.finish()
|
||||
}
|
||||
|
||||
/// Create the GraphQL router.
|
||||
#[must_use]
|
||||
pub fn create_router(ctx: AppContext) -> Router<AppContext> {
|
||||
let schema = build_schema(ctx.clone());
|
||||
|
||||
Router::new()
|
||||
.route("/", get(graphql_playground).post(graphql_handler))
|
||||
.with_state(schema)
|
||||
}
|
||||
|
||||
/// GraphQL request handler.
|
||||
async fn graphql_handler(State(schema): State<ApiSchema>, req: GraphQLRequest) -> GraphQLResponse {
|
||||
schema.execute(req.into_inner()).await.into()
|
||||
}
|
||||
|
||||
/// GraphQL Playground HTML page.
|
||||
#[allow(clippy::unused_async)]
|
||||
async fn graphql_playground() -> impl IntoResponse {
|
||||
Html(PLAYGROUND_HTML)
|
||||
}
|
||||
|
||||
const PLAYGROUND_HTML: &str = r#"
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>7sense GraphQL Playground</title>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/graphql-playground-react@1.7.26/build/static/css/index.css" />
|
||||
<script src="https://cdn.jsdelivr.net/npm/graphql-playground-react@1.7.26/build/static/js/middleware.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script>
|
||||
window.addEventListener('load', function (event) {
|
||||
GraphQLPlayground.init(document.getElementById('root'), {
|
||||
endpoint: '/graphql',
|
||||
settings: {
|
||||
'editor.theme': 'dark',
|
||||
'editor.fontSize': 14,
|
||||
'request.credentials': 'include',
|
||||
},
|
||||
tabs: [
|
||||
{
|
||||
name: 'Example Queries',
|
||||
endpoint: '/graphql',
|
||||
query: `# 7sense GraphQL API
|
||||
|
||||
# Get all clusters
|
||||
query ListClusters {
|
||||
clusters {
|
||||
id
|
||||
label
|
||||
size
|
||||
density
|
||||
}
|
||||
}
|
||||
|
||||
# Find similar segments
|
||||
query FindNeighbors($segmentId: UUID!, $k: Int) {
|
||||
neighbors(segmentId: $segmentId, k: $k) {
|
||||
segmentId
|
||||
similarity
|
||||
startTime
|
||||
endTime
|
||||
}
|
||||
}
|
||||
`
|
||||
}
|
||||
]
|
||||
});
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"#;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Schema construction requires a fully built AppContext (pipeline,
    // embedding model, index, engines); tests would go here with a mock
    // context once one is available.
}
|
||||
238
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/graphql/schema.rs
vendored
Normal file
238
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/graphql/schema.rs
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
//! GraphQL schema definitions for 7sense API.
|
||||
//!
|
||||
//! This module defines the Query, Mutation, and Subscription roots
|
||||
//! for the GraphQL API.
|
||||
|
||||
use async_graphql::*;
|
||||
use futures::Stream;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::types::*;
|
||||
use crate::{AppContext, ProcessingStatus};
|
||||
|
||||
/// Root query type for GraphQL API.
|
||||
pub struct QueryRoot;
|
||||
|
||||
#[Object]
|
||||
impl QueryRoot {
|
||||
/// Find similar segments (neighbors) for a given segment.
|
||||
async fn neighbors(
|
||||
&self,
|
||||
ctx: &Context<'_>,
|
||||
segment_id: ID,
|
||||
#[graphql(default = 10)] k: i32,
|
||||
#[graphql(default)] min_similarity: f32,
|
||||
) -> Result<Vec<Neighbor>> {
|
||||
let app_ctx = ctx.data::<AppContext>()?;
|
||||
|
||||
let segment_uuid = Uuid::parse_str(segment_id.as_str())
|
||||
.map_err(|_| Error::new("Invalid segment ID"))?;
|
||||
|
||||
// Get segment embedding
|
||||
let embedding = app_ctx
|
||||
.vector_index
|
||||
.get_embedding(&segment_uuid)
|
||||
.map_err(|e| Error::new(format!("Vector index error: {e}")))?
|
||||
.ok_or_else(|| Error::new(format!("Segment {} not found", segment_id.as_str())))?;
|
||||
|
||||
// Search for neighbors
|
||||
let results = app_ctx
|
||||
.vector_index
|
||||
.search(&embedding, k as usize, min_similarity)
|
||||
.map_err(|e| Error::new(format!("Search error: {e}")))?;
|
||||
|
||||
// Convert to GraphQL types
|
||||
let neighbors: Vec<Neighbor> = results
|
||||
.into_iter()
|
||||
.filter(|r| r.id != segment_uuid)
|
||||
.map(|r| Neighbor {
|
||||
segment_id: ID::from(r.id.to_string()),
|
||||
recording_id: ID::from(r.recording_id.to_string()),
|
||||
similarity: 1.0 - r.distance,
|
||||
distance: r.distance,
|
||||
start_time: r.start_time,
|
||||
end_time: r.end_time,
|
||||
species: r.species.map(|s| Species {
|
||||
common_name: s.common_name,
|
||||
scientific_name: s.scientific_name,
|
||||
confidence: s.confidence,
|
||||
}),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(neighbors)
|
||||
}
|
||||
|
||||
/// List all discovered clusters.
|
||||
async fn clusters(&self, ctx: &Context<'_>) -> Result<Vec<Cluster>> {
|
||||
let app_ctx = ctx.data::<AppContext>()?;
|
||||
|
||||
let cluster_data = app_ctx
|
||||
.cluster_engine
|
||||
.get_all_clusters()
|
||||
.map_err(|e| Error::new(format!("Analysis error: {e}")))?;
|
||||
|
||||
let clusters: Vec<Cluster> = cluster_data
|
||||
.into_iter()
|
||||
.map(|c| Cluster {
|
||||
id: ID::from(c.id.to_string()),
|
||||
label: c.label,
|
||||
size: c.size as i32,
|
||||
density: c.density,
|
||||
exemplar_ids: c.exemplar_ids.into_iter().map(|id| ID::from(id.to_string())).collect(),
|
||||
species_distribution: c
|
||||
.species_distribution
|
||||
.into_iter()
|
||||
.map(|(name, count, percentage)| SpeciesCount {
|
||||
name,
|
||||
scientific_name: None,
|
||||
count: count as i32,
|
||||
percentage,
|
||||
})
|
||||
.collect(),
|
||||
created_at: c.created_at,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(clusters)
|
||||
}
|
||||
|
||||
/// Get a specific cluster by ID.
|
||||
async fn cluster(&self, ctx: &Context<'_>, id: ID) -> Result<Option<Cluster>> {
|
||||
let app_ctx = ctx.data::<AppContext>()?;
|
||||
|
||||
let cluster_uuid =
|
||||
Uuid::parse_str(id.as_str()).map_err(|_| Error::new("Invalid cluster ID"))?;
|
||||
|
||||
let cluster_data = app_ctx
|
||||
.cluster_engine
|
||||
.get_cluster(&cluster_uuid)
|
||||
.map_err(|e| Error::new(format!("Analysis error: {e}")))?;
|
||||
|
||||
Ok(cluster_data.map(|c| Cluster {
|
||||
id: ID::from(c.id.to_string()),
|
||||
label: c.label,
|
||||
size: c.size as i32,
|
||||
density: c.density,
|
||||
exemplar_ids: c.exemplar_ids.into_iter().map(|id| ID::from(id.to_string())).collect(),
|
||||
species_distribution: c
|
||||
.species_distribution
|
||||
.into_iter()
|
||||
.map(|(name, count, percentage)| SpeciesCount {
|
||||
name,
|
||||
scientific_name: None,
|
||||
count: count as i32,
|
||||
percentage,
|
||||
})
|
||||
.collect(),
|
||||
created_at: c.created_at,
|
||||
}))
|
||||
}
|
||||
|
||||
/// System health check.
|
||||
async fn health(&self) -> HealthStatus {
|
||||
HealthStatus {
|
||||
status: "healthy".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Root mutation type for GraphQL API.
pub struct MutationRoot;

#[Object]
impl MutationRoot {
    /// Assign a human-readable label to a cluster.
    ///
    /// Returns the updated cluster, or an error when the ID is not a valid
    /// UUID, the cluster engine fails, or no cluster with that ID exists.
    async fn assign_label(
        &self,
        ctx: &Context<'_>,
        cluster_id: ID,
        label: String,
    ) -> Result<Cluster> {
        let app_ctx = ctx.data::<AppContext>()?;

        let cluster_uuid =
            Uuid::parse_str(cluster_id.as_str()).map_err(|_| Error::new("Invalid cluster ID"))?;

        // Persist the label; `None` from the engine means the cluster is unknown.
        let cluster_data = app_ctx
            .cluster_engine
            .assign_label(&cluster_uuid, &label)
            .map_err(|e| Error::new(format!("Analysis error: {e}")))?
            .ok_or_else(|| Error::new(format!("Cluster {} not found", cluster_id.as_str())))?;

        // Convert the engine's cluster record into the GraphQL type.
        Ok(Cluster {
            id: ID::from(cluster_data.id.to_string()),
            label: cluster_data.label,
            size: cluster_data.size as i32,
            density: cluster_data.density,
            exemplar_ids: cluster_data.exemplar_ids.into_iter().map(|id| ID::from(id.to_string())).collect(),
            species_distribution: cluster_data
                .species_distribution
                .into_iter()
                .map(|(name, count, percentage)| SpeciesCount {
                    name,
                    scientific_name: None,
                    count: count as i32,
                    percentage,
                })
                .collect(),
            created_at: cluster_data.created_at,
        })
    }
}
|
||||
|
||||
/// Root subscription type for GraphQL API.
pub struct SubscriptionRoot;

#[Subscription]
impl SubscriptionRoot {
    /// Subscribe to processing status updates for a single recording.
    ///
    /// Events broadcast for other recordings are filtered out of the stream.
    async fn processing_status(
        &self,
        ctx: &Context<'_>,
        recording_id: ID,
    ) -> Result<impl Stream<Item = ProcessingUpdate>> {
        let app_ctx = ctx.data::<AppContext>()?;
        let recording_uuid = Uuid::parse_str(recording_id.as_str())
            .map_err(|_| Error::new("Invalid recording ID"))?;

        // Each subscriber gets its own receiver on the shared broadcast channel.
        let mut rx = app_ctx.subscribe_events();

        // NOTE(review): `recv()` also returns Err on receiver lag, so a slow
        // consumer's subscription ends silently rather than skipping ahead —
        // confirm this is the intended behavior.
        Ok(async_stream::stream! {
            while let Ok(event) = rx.recv().await {
                if event.recording_id == recording_uuid {
                    yield ProcessingUpdate {
                        recording_id: ID::from(event.recording_id.to_string()),
                        // Map the internal status enum onto its GraphQL mirror.
                        status: match event.status {
                            ProcessingStatus::Queued => ProcessingStatusGql::Queued,
                            ProcessingStatus::Loading => ProcessingStatusGql::Loading,
                            ProcessingStatus::Segmenting => ProcessingStatusGql::Segmenting,
                            ProcessingStatus::Embedding => ProcessingStatusGql::Embedding,
                            ProcessingStatus::Indexing => ProcessingStatusGql::Indexing,
                            ProcessingStatus::Analyzing => ProcessingStatusGql::Analyzing,
                            ProcessingStatus::Complete => ProcessingStatusGql::Complete,
                            ProcessingStatus::Failed => ProcessingStatusGql::Failed,
                        },
                        progress: event.progress,
                        message: event.message,
                    };
                }
            }
        })
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// HealthStatus is a plain data type; constructing it needs no context.
    #[test]
    fn test_health_status() {
        let status = HealthStatus {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        assert_eq!(status.status, "healthy");
    }
}
|
||||
131
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/graphql/types.rs
vendored
Normal file
131
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/graphql/types.rs
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
//! GraphQL type definitions for 7sense API.
|
||||
//!
|
||||
//! This module contains all GraphQL object types, input types, and enums
|
||||
//! used by the schema.
|
||||
|
||||
use async_graphql::*;
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
// ============================================================================
|
||||
// Object Types
|
||||
// ============================================================================
|
||||
|
||||
/// A similar segment found through vector search.
#[derive(Debug, Clone, SimpleObject)]
pub struct Neighbor {
    /// Segment identifier
    pub segment_id: ID,
    /// Parent recording ID
    pub recording_id: ID,
    /// Similarity score (0.0 to 1.0; higher is more similar)
    pub similarity: f32,
    /// Distance in embedding space (lower is more similar)
    pub distance: f32,
    /// Segment start time (presumably seconds from recording start — TODO confirm)
    pub start_time: f64,
    /// Segment end time (same unit as `start_time`)
    pub end_time: f64,
    /// Detected species, if any was identified
    pub species: Option<Species>,
}

/// Species information.
#[derive(Debug, Clone, SimpleObject)]
pub struct Species {
    /// Common name
    pub common_name: String,
    /// Scientific name (binomial), when known
    pub scientific_name: Option<String>,
    /// Detection confidence
    pub confidence: f32,
}

/// A cluster of similar calls.
#[derive(Debug, Clone, SimpleObject)]
pub struct Cluster {
    /// Cluster identifier
    pub id: ID,
    /// Human-assigned label (set via the `assignLabel` mutation)
    pub label: Option<String>,
    /// Number of segments in cluster
    pub size: i32,
    /// Cluster density/compactness
    pub density: f32,
    /// Representative segment IDs
    pub exemplar_ids: Vec<ID>,
    /// Species distribution across the cluster's segments
    pub species_distribution: Vec<SpeciesCount>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
}

/// Species count within a cluster.
#[derive(Debug, Clone, SimpleObject)]
pub struct SpeciesCount {
    /// Species common name
    pub name: String,
    /// Scientific name, when known
    pub scientific_name: Option<String>,
    /// Count of segments attributed to this species
    pub count: i32,
    /// Percentage of cluster
    pub percentage: f64,
}

/// Processing status update delivered over the subscription stream.
#[derive(Debug, Clone, SimpleObject)]
pub struct ProcessingUpdate {
    /// Recording ID
    pub recording_id: ID,
    /// Current status
    pub status: ProcessingStatusGql,
    /// Progress (0.0 to 1.0)
    pub progress: f32,
    /// Optional human-readable status message
    pub message: Option<String>,
}

/// Health status response.
#[derive(Debug, Clone, SimpleObject)]
pub struct HealthStatus {
    /// Service status string (e.g. "healthy")
    pub status: String,
    /// Crate version
    pub version: String,
}
|
||||
|
||||
// ============================================================================
|
||||
// Enums
|
||||
// ============================================================================
|
||||
|
||||
/// Processing status stages (GraphQL mirror of the crate-level
/// `ProcessingStatus` enum; kept separate so the GraphQL surface can evolve
/// independently of the internal type).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Enum)]
pub enum ProcessingStatusGql {
    /// Queued for processing
    Queued,
    /// Loading audio
    Loading,
    /// Segmenting
    Segmenting,
    /// Generating embeddings
    Embedding,
    /// Indexing vectors
    Indexing,
    /// Analyzing clusters
    Analyzing,
    /// Complete
    Complete,
    /// Failed
    Failed,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The enum derives Copy/PartialEq, so values compare by variant.
    #[test]
    fn test_processing_status_enum() {
        let status = ProcessingStatusGql::Embedding;
        assert_eq!(status, ProcessingStatusGql::Embedding);
    }
}
|
||||
316
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/lib.rs
vendored
Normal file
316
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/lib.rs
vendored
Normal file
@@ -0,0 +1,316 @@
|
||||
//! # sevensense-api
|
||||
//!
|
||||
//! REST, GraphQL, and WebSocket API layer for 7sense bioacoustic analysis.
|
||||
//!
|
||||
//! This crate provides a comprehensive API for:
|
||||
//! - Audio recording upload and processing
|
||||
//! - Segment similarity search via vector embeddings
|
||||
//! - Cluster discovery and labeling
|
||||
//! - Evidence pack generation for interpretability
|
||||
//! - Real-time processing status via WebSocket
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The API follows a layered architecture:
|
||||
//! - **REST API** (`/api/v1/*`) - RESTful endpoints for CRUD operations
|
||||
//! - **GraphQL** (`/graphql`) - Flexible query interface with subscriptions
|
||||
//! - **WebSocket** (`/ws`) - Real-time updates for long-running operations
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use sevensense_api::{AppBuilder, Config};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let config = Config::from_env()?;
|
||||
//! let app = AppBuilder::new(config).build().await?;
|
||||
//!
|
||||
//! axum::serve(listener, app).await?;
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![warn(clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
#![allow(clippy::missing_errors_doc)]
|
||||
#![allow(clippy::missing_panics_doc)]
|
||||
|
||||
pub mod error;
|
||||
pub mod graphql;
|
||||
pub mod openapi;
|
||||
pub mod rest;
|
||||
pub mod services;
|
||||
pub mod websocket;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use tokio::sync::broadcast;
|
||||
use tower_http::{
|
||||
compression::CompressionLayer,
|
||||
trace::TraceLayer,
|
||||
};
|
||||
|
||||
pub use services::{
|
||||
AudioPipeline, ClusterEngine, EmbeddingModel, InterpretationEngine, VectorIndex,
|
||||
};
|
||||
|
||||
/// Crate version string, taken from Cargo.toml at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Application configuration loaded from environment or config file.
///
/// Construct via [`Config::from_env`] (reads `SEVENSENSE_*` variables) or
/// [`Config::default`] for development defaults.
#[derive(Debug, Clone)]
pub struct Config {
    /// Server host address
    pub host: String,
    /// Server port
    pub port: u16,
    /// CORS allowed origins (`"*"` allows all)
    pub cors_origins: Vec<String>,
    /// Rate limit requests per second
    pub rate_limit_rps: u32,
    /// Maximum upload size in bytes
    pub max_upload_size: usize,
    /// Enable GraphQL playground
    pub enable_playground: bool,
    /// API key for authentication (optional; `None` disables auth)
    pub api_key: Option<String>,
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
cors_origins: vec!["*".to_string()],
|
||||
rate_limit_rps: 100,
|
||||
max_upload_size: 100 * 1024 * 1024, // 100MB
|
||||
enable_playground: true,
|
||||
api_key: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load configuration from environment variables.
|
||||
pub fn from_env() -> anyhow::Result<Self> {
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
Ok(Self {
|
||||
host: std::env::var("SEVENSENSE_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
|
||||
port: std::env::var("SEVENSENSE_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(8080),
|
||||
cors_origins: std::env::var("SEVENSENSE_CORS_ORIGINS")
|
||||
.map(|s| s.split(',').map(String::from).collect())
|
||||
.unwrap_or_else(|_| vec!["*".to_string()]),
|
||||
rate_limit_rps: std::env::var("SEVENSENSE_RATE_LIMIT")
|
||||
.ok()
|
||||
.and_then(|r| r.parse().ok())
|
||||
.unwrap_or(100),
|
||||
max_upload_size: std::env::var("SEVENSENSE_MAX_UPLOAD")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(100 * 1024 * 1024),
|
||||
enable_playground: std::env::var("SEVENSENSE_ENABLE_PLAYGROUND")
|
||||
.map(|s| s == "true" || s == "1")
|
||||
.unwrap_or(true),
|
||||
api_key: std::env::var("SEVENSENSE_API_KEY").ok(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Processing status event for WebSocket broadcasts.
///
/// Published via [`AppContext::publish_event`] and fanned out to every
/// subscriber of the context's broadcast channel.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ProcessingEvent {
    /// Recording identifier
    pub recording_id: uuid::Uuid,
    /// Current processing status
    pub status: ProcessingStatus,
    /// Progress percentage (0.0 to 1.0)
    pub progress: f32,
    /// Optional status message
    pub message: Option<String>,
}
|
||||
|
||||
/// Processing status stages for the ingestion pipeline.
///
/// Serialized in snake_case (e.g. `"segmenting"`) for JSON/WebSocket payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum ProcessingStatus {
    /// Recording queued for processing
    Queued,
    /// Loading audio file
    Loading,
    /// Detecting segments
    Segmenting,
    /// Generating embeddings
    Embedding,
    /// Adding to vector index
    Indexing,
    /// Running cluster analysis
    Analyzing,
    /// Processing complete
    Complete,
    /// Processing failed
    Failed,
}
|
||||
|
||||
/// Shared application context accessible from all handlers.
///
/// Cheap to clone: every service is behind an `Arc`, so clones share the
/// same underlying state and broadcast channel.
#[derive(Clone)]
pub struct AppContext {
    /// Audio processing pipeline
    pub audio_pipeline: Arc<AudioPipeline>,
    /// Embedding model for segment vectorization
    pub embedding_model: Arc<EmbeddingModel>,
    /// Vector index for similarity search
    pub vector_index: Arc<VectorIndex>,
    /// Cluster analysis engine
    pub cluster_engine: Arc<ClusterEngine>,
    /// Interpretation engine for evidence packs
    pub interpretation_engine: Arc<InterpretationEngine>,
    /// Broadcast channel for processing events
    pub event_tx: broadcast::Sender<ProcessingEvent>,
    /// Application configuration
    pub config: Arc<Config>,
}
|
||||
|
||||
impl AppContext {
    /// Create a new application context, constructing every service with
    /// its default configuration.
    ///
    /// # Errors
    /// Propagates any initialization error from the underlying services.
    pub async fn new(config: Config) -> anyhow::Result<Self> {
        // Initialize audio pipeline
        let audio_pipeline = Arc::new(AudioPipeline::new(Default::default())?);

        // Initialize embedding model (async — presumably loads model weights;
        // TODO confirm)
        let embedding_model = Arc::new(EmbeddingModel::new(Default::default()).await?);

        // Initialize vector index
        let vector_index = Arc::new(VectorIndex::new(Default::default())?);

        // Initialize cluster engine
        let cluster_engine = Arc::new(ClusterEngine::new(Default::default())?);

        // Initialize interpretation engine
        let interpretation_engine = Arc::new(InterpretationEngine::new(Default::default())?);

        // Create broadcast channel for events (capacity of 1024; slow
        // receivers beyond that lag and drop events)
        let (event_tx, _) = broadcast::channel(1024);

        Ok(Self {
            audio_pipeline,
            embedding_model,
            vector_index,
            cluster_engine,
            interpretation_engine,
            event_tx,
            config: Arc::new(config),
        })
    }

    /// Get a fresh receiver for processing events; each receiver sees only
    /// events sent after it was created.
    #[must_use]
    pub fn subscribe_events(&self) -> broadcast::Receiver<ProcessingEvent> {
        self.event_tx.subscribe()
    }

    /// Publish a processing event to all current subscribers.
    pub fn publish_event(&self, event: ProcessingEvent) {
        // Ignore send errors (they only mean there are no receivers)
        let _ = self.event_tx.send(event);
    }
}
|
||||
|
||||
/// Builder for constructing the application router.
///
/// Use [`AppBuilder::with_context`] to inject a pre-built context (useful
/// in tests); otherwise `build` constructs one from the configuration.
pub struct AppBuilder {
    // Used both to construct the context and to configure the CORS layer.
    config: Config,
    // Optional pre-built context; `None` means build one from `config`.
    context: Option<AppContext>,
}
|
||||
|
||||
impl AppBuilder {
|
||||
/// Create a new app builder with configuration.
|
||||
#[must_use]
|
||||
pub fn new(config: Config) -> Self {
|
||||
Self {
|
||||
config,
|
||||
context: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a pre-built context (useful for testing).
|
||||
#[must_use]
|
||||
pub fn with_context(mut self, context: AppContext) -> Self {
|
||||
self.context = Some(context);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the complete application router.
|
||||
pub async fn build(self) -> anyhow::Result<Router> {
|
||||
// Initialize context if not provided
|
||||
let context = match self.context {
|
||||
Some(ctx) => ctx,
|
||||
None => AppContext::new(self.config.clone()).await?,
|
||||
};
|
||||
|
||||
// Build REST routes
|
||||
let rest_router = rest::routes::create_router(context.clone());
|
||||
|
||||
// Build GraphQL routes
|
||||
let graphql_router = graphql::create_router(context.clone());
|
||||
|
||||
// Build WebSocket routes
|
||||
let ws_router = websocket::create_router(context.clone());
|
||||
|
||||
// Build OpenAPI documentation routes
|
||||
let openapi_router = openapi::create_router();
|
||||
|
||||
// Combine all routers
|
||||
let app = Router::new()
|
||||
.nest("/api/v1", rest_router)
|
||||
.nest("/graphql", graphql_router)
|
||||
.nest("/ws", ws_router)
|
||||
.nest("/docs", openapi_router)
|
||||
.layer(
|
||||
tower::ServiceBuilder::new()
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CompressionLayer::new())
|
||||
.layer(rest::middleware::cors_layer(&self.config)),
|
||||
)
|
||||
.with_state(context);
|
||||
|
||||
Ok(app)
|
||||
}
|
||||
}
|
||||
|
||||
/// Health check response.
///
/// Returned by the `/health` endpoint.
#[derive(Debug, serde::Serialize, utoipa::ToSchema)]
pub struct HealthResponse {
    /// Service status
    pub status: String,
    /// API version
    pub version: String,
    /// Server uptime in seconds
    pub uptime_secs: u64,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-check the default configuration values the server relies on.
    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert!(config.enable_playground);
    }

    // `ProcessingStatus` must serialize as snake_case per the serde attribute.
    #[test]
    fn test_processing_status_serialize() {
        let status = ProcessingStatus::Embedding;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, "\"embedding\"");
    }
}
|
||||
168
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/main.rs
vendored
Normal file
168
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/main.rs
vendored
Normal file
@@ -0,0 +1,168 @@
|
||||
//! 7sense API Server
|
||||
//!
|
||||
//! This is the main entry point for the 7sense bioacoustic analysis API server.
|
||||
//! It provides REST, GraphQL, and WebSocket endpoints for audio processing,
|
||||
//! similarity search, and cluster discovery.
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Run with default settings
|
||||
//! cargo run --release
|
||||
//!
|
||||
//! # With environment configuration
|
||||
//! SEVENSENSE_PORT=3000 SEVENSENSE_API_KEY=secret cargo run --release
|
||||
//! ```
|
||||
//!
|
||||
//! ## Endpoints
|
||||
//!
|
||||
//! - REST API: `http://localhost:8080/api/v1/`
|
||||
//! - GraphQL: `http://localhost:8080/graphql`
|
||||
//! - GraphQL Playground: `http://localhost:8080/graphql` (GET)
|
||||
//! - WebSocket: `ws://localhost:8080/ws/`
|
||||
//! - OpenAPI/Swagger: `http://localhost:8080/docs/swagger`
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use anyhow::Result;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
use sevensense_api::{AppBuilder, Config};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing/logging
|
||||
init_tracing();
|
||||
|
||||
// Load configuration
|
||||
let config = Config::from_env()?;
|
||||
|
||||
tracing::info!(
|
||||
host = %config.host,
|
||||
port = %config.port,
|
||||
"Starting 7sense API server"
|
||||
);
|
||||
|
||||
// Build application
|
||||
let app = AppBuilder::new(config.clone()).build().await?;
|
||||
|
||||
// Bind to address
|
||||
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
tracing::info!(
|
||||
address = %addr,
|
||||
"7sense API server listening"
|
||||
);
|
||||
|
||||
// Print startup banner
|
||||
print_banner(&config);
|
||||
|
||||
// Run server with graceful shutdown
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await?;
|
||||
|
||||
tracing::info!("7sense API server shut down gracefully");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize tracing subscriber with environment filter.
|
||||
fn init_tracing() {
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("sevensense_api=info,tower_http=info,axum=info"));
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_target(true)
|
||||
.with_thread_ids(false)
|
||||
.with_file(false)
|
||||
.with_line_number(false),
|
||||
)
|
||||
.init();
|
||||
}
|
||||
|
||||
/// Print startup banner with endpoint information.
|
||||
fn print_banner(config: &Config) {
|
||||
let base_url = format!("http://{}:{}", config.host, config.port);
|
||||
|
||||
println!();
|
||||
println!("========================================");
|
||||
println!(" 7sense Bioacoustic API");
|
||||
println!("========================================");
|
||||
println!();
|
||||
println!(" REST API: {base_url}/api/v1/");
|
||||
println!(" GraphQL: {base_url}/graphql");
|
||||
println!(" WebSocket: ws://{}:{}/ws/", config.host, config.port);
|
||||
println!(" Swagger UI: {base_url}/docs/swagger");
|
||||
println!(" OpenAPI: {base_url}/docs/openapi.json");
|
||||
println!();
|
||||
println!(" Health: {base_url}/api/v1/health");
|
||||
println!();
|
||||
|
||||
if config.api_key.is_some() {
|
||||
println!(" Auth: API key required (Bearer token)");
|
||||
} else {
|
||||
println!(" Auth: No authentication (development mode)");
|
||||
}
|
||||
|
||||
if config.enable_playground {
|
||||
println!(" Playground: Enabled");
|
||||
}
|
||||
|
||||
println!();
|
||||
println!(" Rate limit: {} req/sec", config.rate_limit_rps);
|
||||
println!(
|
||||
" Max upload: {} MB",
|
||||
config.max_upload_size / 1024 / 1024
|
||||
);
|
||||
println!();
|
||||
println!("========================================");
|
||||
println!();
|
||||
}
|
||||
|
||||
/// Create shutdown signal handler for graceful shutdown.
///
/// Resolves when either Ctrl+C (all platforms) or SIGTERM (Unix only) is
/// received, letting `axum::serve` drain in-flight requests before exiting.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    // SIGTERM is what container orchestrators send on stop; Unix-only.
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("Failed to install signal handler")
            .recv()
            .await;
    };

    // No SIGTERM on non-Unix platforms: use a future that never resolves so
    // the select below effectively waits on Ctrl+C alone.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        () = ctrl_c => {
            tracing::info!("Received Ctrl+C, initiating shutdown");
        }
        () = terminate => {
            tracing::info!("Received SIGTERM, initiating shutdown");
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-check the default configuration the binary starts with.
    #[test]
    fn test_config_parsing() {
        let config = Config::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
    }
}
|
||||
148
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/openapi.rs
vendored
Normal file
148
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/openapi.rs
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
//! OpenAPI/Swagger documentation generation.
|
||||
//!
|
||||
//! This module generates OpenAPI 3.0 documentation for the REST API
|
||||
//! and serves a Swagger UI at `/docs`.
|
||||
|
||||
use axum::{routing::get, Json, Router};
|
||||
use utoipa::{
|
||||
openapi::{
|
||||
security::{HttpAuthScheme, HttpBuilder, SecurityScheme},
|
||||
OpenApi as OpenApiDoc,
|
||||
},
|
||||
Modify, OpenApi,
|
||||
};
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
use crate::error::ErrorResponse;
|
||||
use crate::rest::handlers::*;
|
||||
use crate::AppContext;
|
||||
use crate::HealthResponse;
|
||||
|
||||
/// OpenAPI documentation struct.
///
/// Aggregates path handlers, component schemas, the security scheme addon,
/// and tag metadata into a single OpenAPI document. Served as raw JSON at
/// `openapi.json` and rendered by the Swagger UI (see `create_router`).
#[derive(OpenApi)]
#[openapi(
    info(
        title = "7sense Bioacoustic Analysis API",
        version = "1.0.0",
        description = "REST API for bioacoustic recording analysis, similarity search, and cluster discovery.",
        contact(
            name = "7sense Team",
            url = "https://github.com/vibecast/vibecast"
        ),
        license(
            name = "MIT OR Apache-2.0",
            url = "https://opensource.org/licenses/MIT"
        )
    ),
    servers(
        (url = "/api/v1", description = "API v1")
    ),
    paths(
        upload_recording,
        get_recording,
        get_neighbors,
        list_clusters,
        get_cluster,
        assign_cluster_label,
        get_evidence_pack,
        generate_evidence_pack,
        search,
        health_check,
    ),
    components(
        schemas(
            Recording,
            UploadResponse,
            Neighbor,
            NeighborParams,
            Cluster,
            SpeciesCount,
            EvidencePack,
            SegmentSummary,
            NeighborEvidence,
            FeatureContribution,
            AcousticFeature,
            EvidenceVisualizations,
            SearchQuery,
            SearchResults,
            SearchQueryEcho,
            SearchResult,
            AssignLabelRequest,
            GenerateEvidenceRequest,
            ErrorResponse,
            HealthResponse,
        )
    ),
    modifiers(&SecurityAddon),
    tags(
        (name = "recordings", description = "Recording upload and management"),
        (name = "segments", description = "Segment analysis and similarity search"),
        (name = "clusters", description = "Cluster discovery and labeling"),
        (name = "evidence", description = "Evidence packs for interpretability"),
        (name = "search", description = "Semantic search"),
        (name = "system", description = "System health and status"),
    )
)]
pub struct ApiDoc;
|
||||
|
||||
/// Security scheme addon.
///
/// Registers the `bearer_auth` HTTP Bearer scheme on the generated OpenAPI
/// document so protected endpoints can reference it.
struct SecurityAddon;

impl Modify for SecurityAddon {
    fn modify(&self, openapi: &mut OpenApiDoc) {
        // Components may be absent if the document declares no schemas.
        if let Some(components) = openapi.components.as_mut() {
            components.add_security_scheme(
                "bearer_auth",
                SecurityScheme::Http(
                    HttpBuilder::new()
                        .scheme(HttpAuthScheme::Bearer)
                        .bearer_format("JWT")
                        .description(Some("API key authentication"))
                        .build(),
                ),
            );
        }
    }
}
|
||||
|
||||
/// Create the OpenAPI documentation router.
///
/// Serves the raw OpenAPI JSON plus the bundled Swagger UI.
///
/// NOTE(review): the app builder nests this router under `/docs`, yet the
/// absolute paths handed to `SwaggerUi` also begin with `/docs` — confirm
/// the effective URLs are not doubled (`/docs/docs/...`). The startup
/// banner advertises `/docs/swagger`, which matches neither path.
#[must_use]
pub fn create_router() -> Router<AppContext> {
    Router::new()
        // Raw OpenAPI JSON
        .route("/openapi.json", get(openapi_json))
        // Swagger UI - merge directly
        .merge(SwaggerUi::new("/docs/swagger-ui")
            .url("/docs/openapi.json", ApiDoc::openapi()))
}
|
||||
|
||||
/// Get raw OpenAPI JSON.
|
||||
#[allow(clippy::unused_async)]
|
||||
async fn openapi_json() -> Json<utoipa::openapi::OpenApi> {
|
||||
Json(ApiDoc::openapi())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // The derived document should carry the configured title and at least
    // one registered path.
    #[test]
    fn test_openapi_generation() {
        let doc = ApiDoc::openapi();
        assert_eq!(doc.info.title, "7sense Bioacoustic Analysis API");
        assert!(!doc.paths.paths.is_empty());
    }

    // Every route listed in `paths(...)` must appear in the document.
    #[test]
    fn test_openapi_has_required_paths() {
        let doc = ApiDoc::openapi();
        let paths: Vec<&str> = doc.paths.paths.keys().map(std::string::String::as_str).collect();

        assert!(paths.contains(&"/recordings"));
        assert!(paths.contains(&"/segments/{id}/neighbors"));
        assert!(paths.contains(&"/clusters"));
        assert!(paths.contains(&"/evidence/{id}"));
        assert!(paths.contains(&"/search"));
        assert!(paths.contains(&"/health"));
    }
}
|
||||
1023
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/handlers.rs
vendored
Normal file
1023
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/handlers.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
266
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/middleware.rs
vendored
Normal file
266
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/middleware.rs
vendored
Normal file
@@ -0,0 +1,266 @@
|
||||
//! REST API middleware for cross-cutting concerns.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - CORS configuration
|
||||
//! - Rate limiting
|
||||
//! - API key authentication
|
||||
//! - Request logging
|
||||
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{ConnectInfo, State},
|
||||
http::{header, HeaderMap, Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use governor::{
|
||||
clock::DefaultClock,
|
||||
state::{InMemoryState, NotKeyed},
|
||||
Quota, RateLimiter,
|
||||
};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
|
||||
use crate::{error::ErrorResponse, AppContext, Config};
|
||||
|
||||
/// Create CORS layer based on configuration.
|
||||
pub fn cors_layer(config: &Config) -> CorsLayer {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
.max_age(Duration::from_secs(3600));
|
||||
|
||||
if config.cors_origins.contains(&"*".to_string()) {
|
||||
cors.allow_origin(Any)
|
||||
} else {
|
||||
// Parse origins - in production, validate these
|
||||
cors.allow_origin(Any) // Simplified for now
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limiter type alias.
|
||||
pub type SharedRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;
|
||||
|
||||
/// Create a rate limiter with the configured limit.
|
||||
pub fn create_rate_limiter(rps: u32) -> SharedRateLimiter {
|
||||
let quota = Quota::per_second(std::num::NonZeroU32::new(rps).unwrap());
|
||||
Arc::new(RateLimiter::direct(quota))
|
||||
}
|
||||
|
||||
/// Rate limiting middleware.
|
||||
pub async fn rate_limit_middleware(
|
||||
State(limiter): State<SharedRateLimiter>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
match limiter.check() {
|
||||
Ok(_) => next.run(request).await,
|
||||
Err(_) => {
|
||||
let body = ErrorResponse {
|
||||
error: "rate_limit_exceeded".into(),
|
||||
message: "Too many requests. Please slow down.".into(),
|
||||
details: None,
|
||||
request_id: None,
|
||||
};
|
||||
(StatusCode::TOO_MANY_REQUESTS, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// API key authentication middleware.
|
||||
pub async fn auth_middleware(
|
||||
State(ctx): State<AppContext>,
|
||||
headers: HeaderMap,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
// If no API key configured, allow all requests
|
||||
let Some(expected_key) = &ctx.config.api_key else {
|
||||
return next.run(request).await;
|
||||
};
|
||||
|
||||
// Check Authorization header
|
||||
let auth_header = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok());
|
||||
|
||||
match auth_header {
|
||||
Some(auth) if auth.starts_with("Bearer ") => {
|
||||
let provided_key = auth.trim_start_matches("Bearer ").trim();
|
||||
if provided_key == expected_key {
|
||||
next.run(request).await
|
||||
} else {
|
||||
unauthorized_response("Invalid API key")
|
||||
}
|
||||
}
|
||||
Some(_) => unauthorized_response("Invalid authorization format. Use 'Bearer <api_key>'"),
|
||||
None => unauthorized_response("Missing Authorization header"),
|
||||
}
|
||||
}
|
||||
|
||||
fn unauthorized_response(message: &str) -> Response {
|
||||
let body = ErrorResponse {
|
||||
error: "unauthorized".into(),
|
||||
message: message.into(),
|
||||
details: None,
|
||||
request_id: None,
|
||||
};
|
||||
(StatusCode::UNAUTHORIZED, Json(body)).into_response()
|
||||
}
|
||||
|
||||
/// Request logging middleware that adds structured logging.
|
||||
pub async fn logging_middleware(
|
||||
headers: HeaderMap,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let method = request.method().clone();
|
||||
let uri = request.uri().clone();
|
||||
let request_id = headers
|
||||
.get("x-request-id")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let response = next.run(request).await;
|
||||
|
||||
let latency = start.elapsed();
|
||||
let status = response.status();
|
||||
|
||||
tracing::info!(
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
status = %status.as_u16(),
|
||||
latency_ms = %latency.as_millis(),
|
||||
request_id = ?request_id,
|
||||
"HTTP request"
|
||||
);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// Content type validation middleware for JSON endpoints.
|
||||
pub async fn json_content_type_middleware(
|
||||
headers: HeaderMap,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
// Only check POST/PUT/PATCH requests
|
||||
if matches!(
|
||||
request.method().as_str(),
|
||||
"POST" | "PUT" | "PATCH"
|
||||
) {
|
||||
// Skip multipart endpoints
|
||||
let path = request.uri().path();
|
||||
if path.contains("/recordings") {
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
// Check content type
|
||||
let content_type = headers
|
||||
.get(header::CONTENT_TYPE)
|
||||
.and_then(|h| h.to_str().ok());
|
||||
|
||||
match content_type {
|
||||
Some(ct) if ct.contains("application/json") => next.run(request).await,
|
||||
Some(ct) => {
|
||||
let body = ErrorResponse {
|
||||
error: "unsupported_media_type".into(),
|
||||
message: format!("Expected application/json, got {}", ct),
|
||||
details: None,
|
||||
request_id: None,
|
||||
};
|
||||
(StatusCode::UNSUPPORTED_MEDIA_TYPE, Json(body)).into_response()
|
||||
}
|
||||
None => {
|
||||
let body = ErrorResponse {
|
||||
error: "unsupported_media_type".into(),
|
||||
message: "Missing Content-Type header".into(),
|
||||
details: None,
|
||||
request_id: None,
|
||||
};
|
||||
(StatusCode::UNSUPPORTED_MEDIA_TYPE, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
next.run(request).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Request body size limit middleware.
///
/// NOTE(review): this type is currently a bare holder for `max_size` with no
/// `Layer`/`Service` implementation in view — confirm whether it is wired up
/// elsewhere or is a placeholder awaiting implementation.
pub struct BodyLimitMiddleware {
    // Maximum allowed request body size in bytes.
    max_size: usize,
}

impl BodyLimitMiddleware {
    /// Create a middleware holder with the given byte limit.
    pub fn new(max_size: usize) -> Self {
        Self { max_size }
    }
}
|
||||
|
||||
/// Extract client IP from request.
|
||||
pub fn extract_client_ip(headers: &HeaderMap, connect_info: Option<&ConnectInfo<SocketAddr>>) -> Option<String> {
|
||||
// Try X-Forwarded-For first (for proxied requests)
|
||||
if let Some(forwarded) = headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
{
|
||||
// Take the first IP in the chain
|
||||
if let Some(ip) = forwarded.split(',').next() {
|
||||
return Some(ip.trim().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Try X-Real-IP
|
||||
if let Some(real_ip) = headers
|
||||
.get("x-real-ip")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
{
|
||||
return Some(real_ip.to_string());
|
||||
}
|
||||
|
||||
// Fall back to connection info
|
||||
connect_info.map(|ci| ci.0.ip().to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Layer construction should not panic for the default config.
    #[test]
    fn test_cors_layer_creation() {
        let config = Config::default();
        let _layer = cors_layer(&config);
    }

    // A fresh limiter must admit at least the first request.
    #[test]
    fn test_rate_limiter_creation() {
        let limiter = create_rate_limiter(100);
        assert!(limiter.check().is_ok());
    }

    // X-Forwarded-For: the first hop in the chain wins.
    #[test]
    fn test_extract_client_ip_x_forwarded() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "1.2.3.4, 5.6.7.8".parse().unwrap());

        let ip = extract_client_ip(&headers, None);
        assert_eq!(ip, Some("1.2.3.4".to_string()));
    }

    // X-Real-IP is the fallback when no forwarded chain is present.
    #[test]
    fn test_extract_client_ip_x_real() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", "10.0.0.1".parse().unwrap());

        let ip = extract_client_ip(&headers, None);
        assert_eq!(ip, Some("10.0.0.1".to_string()));
    }
}
|
||||
24
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/mod.rs
vendored
Normal file
24
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/mod.rs
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
//! REST API module for 7sense.
|
||||
//!
|
||||
//! This module provides RESTful endpoints for:
|
||||
//! - Recording upload and management
|
||||
//! - Segment similarity search
|
||||
//! - Cluster discovery and labeling
|
||||
//! - Evidence pack retrieval
|
||||
//!
|
||||
//! ## API Versioning
|
||||
//!
|
||||
//! All endpoints are versioned under `/api/v1/`. Breaking changes will
|
||||
//! result in a new API version (e.g., `/api/v2/`).
|
||||
//!
|
||||
//! ## Authentication
|
||||
//!
|
||||
//! If `SEVENSENSE_API_KEY` is set, all requests must include an
|
||||
//! `Authorization: Bearer <api_key>` header.
|
||||
|
||||
pub mod handlers;
|
||||
pub mod middleware;
|
||||
pub mod routes;
|
||||
|
||||
pub use handlers::*;
|
||||
pub use routes::create_router;
|
||||
74
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/routes.rs
vendored
Normal file
74
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/rest/routes.rs
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
//! REST API route definitions with versioning.
|
||||
//!
|
||||
//! Routes are organized by resource type and versioned under `/api/v1/`.
|
||||
|
||||
use axum::{
|
||||
routing::{get, post, put},
|
||||
Router,
|
||||
};
|
||||
|
||||
use super::handlers;
|
||||
use crate::AppContext;
|
||||
|
||||
/// Create the REST API router with all endpoints.
|
||||
pub fn create_router(_ctx: AppContext) -> Router<AppContext> {
|
||||
Router::new()
|
||||
// Health check
|
||||
.route("/health", get(handlers::health_check))
|
||||
// Recordings
|
||||
.nest("/recordings", recordings_router())
|
||||
// Segments
|
||||
.nest("/segments", segments_router())
|
||||
// Clusters
|
||||
.nest("/clusters", clusters_router())
|
||||
// Evidence
|
||||
.nest("/evidence", evidence_router())
|
||||
// Search
|
||||
.route("/search", post(handlers::search))
|
||||
}
|
||||
|
||||
/// Recording management routes.
|
||||
fn recordings_router() -> Router<AppContext> {
|
||||
Router::new()
|
||||
// POST /recordings - Upload new recording
|
||||
.route("/", post(handlers::upload_recording))
|
||||
// GET /recordings/:id - Get recording by ID
|
||||
.route("/:id", get(handlers::get_recording))
|
||||
}
|
||||
|
||||
/// Segment analysis routes.
|
||||
fn segments_router() -> Router<AppContext> {
|
||||
Router::new()
|
||||
// GET /segments/:id/neighbors - Find similar segments
|
||||
.route("/:id/neighbors", get(handlers::get_neighbors))
|
||||
}
|
||||
|
||||
/// Cluster management routes.
|
||||
fn clusters_router() -> Router<AppContext> {
|
||||
Router::new()
|
||||
// GET /clusters - List all clusters
|
||||
.route("/", get(handlers::list_clusters))
|
||||
// GET /clusters/:id - Get specific cluster
|
||||
.route("/:id", get(handlers::get_cluster))
|
||||
// PUT /clusters/:id/label - Assign label to cluster
|
||||
.route("/:id/label", put(handlers::assign_cluster_label))
|
||||
}
|
||||
|
||||
/// Evidence pack routes.
|
||||
fn evidence_router() -> Router<AppContext> {
|
||||
Router::new()
|
||||
// POST /evidence - Generate evidence pack
|
||||
.route("/", post(handlers::generate_evidence_pack))
|
||||
// GET /evidence/:id - Get evidence pack by ID
|
||||
.route("/:id", get(handlers::get_evidence_pack))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Integration tests would go here with a mock AppContext.
    //
    // The previous `use` statements (axum Body/Request/StatusCode,
    // tower::ServiceExt, super::*) were all unused and produced compiler
    // warnings; reintroduce them alongside real tests.
}
|
||||
149
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/audio.rs
vendored
Normal file
149
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/audio.rs
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
//! Audio processing service.
|
||||
//!
|
||||
//! This module provides the `AudioPipeline` service for loading and
|
||||
//! segmenting audio recordings.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use super::{Audio, Segment};
|
||||
|
||||
/// Audio processing error.
///
/// Returned by `AudioPipeline` operations.
#[derive(Debug, Error)]
pub enum AudioError {
    /// Invalid audio format
    #[error("Invalid audio format: {0}")]
    InvalidFormat(String),

    /// Decoding error
    #[error("Failed to decode audio: {0}")]
    DecodingError(String),

    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Unsupported sample rate
    #[error("Unsupported sample rate: {0}")]
    UnsupportedSampleRate(u32),

    /// Empty audio
    #[error("Audio file is empty or too short")]
    EmptyAudio,
}
|
||||
|
||||
/// Audio pipeline configuration.
#[derive(Debug, Clone)]
pub struct AudioPipelineConfig {
    /// Target sample rate for processing
    pub target_sample_rate: u32,
    /// Minimum segment duration in seconds
    pub min_segment_duration: f64,
    /// Maximum segment duration in seconds
    pub max_segment_duration: f64,
    /// Energy threshold for segmentation
    pub energy_threshold: f32,
}

impl Default for AudioPipelineConfig {
    // Defaults: 32 kHz processing rate, 0.5–10 s segments, 0.01 energy gate.
    fn default() -> Self {
        Self {
            target_sample_rate: 32000,
            min_segment_duration: 0.5,
            max_segment_duration: 10.0,
            energy_threshold: 0.01,
        }
    }
}
|
||||
|
||||
/// Audio processing pipeline.
///
/// Handles audio loading, resampling, and segmentation.
pub struct AudioPipeline {
    // Immutable settings applied to every recording this pipeline processes.
    config: AudioPipelineConfig,
}
|
||||
|
||||
impl AudioPipeline {
|
||||
/// Create a new audio pipeline with the given configuration.
|
||||
pub fn new(config: AudioPipelineConfig) -> Result<Self, AudioError> {
|
||||
Ok(Self { config })
|
||||
}
|
||||
|
||||
/// Get metadata from audio data without full decoding.
|
||||
pub fn get_metadata(&self, data: &[u8]) -> Result<(f64, u32, u16), AudioError> {
|
||||
// In a real implementation, this would parse the audio header
|
||||
// For now, return reasonable defaults
|
||||
if data.len() < 44 {
|
||||
return Err(AudioError::EmptyAudio);
|
||||
}
|
||||
|
||||
// Parse WAV header (simplified)
|
||||
// Real implementation would use symphonia or hound
|
||||
let sample_rate = 44100u32;
|
||||
let channels = 1u16;
|
||||
let duration = data.len() as f64 / (sample_rate as f64 * channels as f64 * 2.0);
|
||||
|
||||
Ok((duration, sample_rate, channels))
|
||||
}
|
||||
|
||||
/// Load audio from raw bytes.
|
||||
pub fn load_audio(&self, data: &[u8]) -> Result<Audio, AudioError> {
|
||||
if data.is_empty() {
|
||||
return Err(AudioError::EmptyAudio);
|
||||
}
|
||||
|
||||
// In a real implementation, this would:
|
||||
// 1. Detect format (WAV, FLAC, MP3, etc.)
|
||||
// 2. Decode to samples
|
||||
// 3. Convert to mono if stereo
|
||||
// 4. Resample to target rate
|
||||
// 5. Normalize to -1.0 to 1.0
|
||||
|
||||
let (duration, _sample_rate, _) = self.get_metadata(data)?;
|
||||
|
||||
// Generate placeholder samples
|
||||
let num_samples = (duration * self.config.target_sample_rate as f64) as usize;
|
||||
let samples = vec![0.0f32; num_samples];
|
||||
|
||||
Ok(Audio {
|
||||
samples,
|
||||
sample_rate: self.config.target_sample_rate,
|
||||
duration_secs: duration,
|
||||
})
|
||||
}
|
||||
|
||||
/// Segment audio into individual calls/vocalizations.
|
||||
pub fn segment(&self, _audio: &Audio) -> Result<Vec<Segment>, AudioError> {
|
||||
// In a real implementation, this would:
|
||||
// 1. Compute spectrogram
|
||||
// 2. Detect energy regions above threshold
|
||||
// 3. Apply minimum/maximum duration constraints
|
||||
// 4. Extract segment audio
|
||||
|
||||
// For now, return empty segments (placeholder)
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
/// Get the target sample rate.
|
||||
#[must_use]
|
||||
pub fn target_sample_rate(&self) -> u32 {
|
||||
self.config.target_sample_rate
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with the default config should always succeed.
    #[test]
    fn test_audio_pipeline_creation() {
        let pipeline = AudioPipeline::new(Default::default());
        assert!(pipeline.is_ok());
    }

    // Empty input must be rejected with the dedicated error variant.
    #[test]
    fn test_empty_audio_error() {
        let pipeline = AudioPipeline::new(Default::default()).unwrap();
        let result = pipeline.load_audio(&[]);
        assert!(matches!(result, Err(AudioError::EmptyAudio)));
    }
}
|
||||
247
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/cluster.rs
vendored
Normal file
247
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/cluster.rs
vendored
Normal file
@@ -0,0 +1,247 @@
|
||||
//! Cluster analysis service.
|
||||
//!
|
||||
//! This module provides the `ClusterEngine` service for discovering
|
||||
//! and managing clusters of similar segments.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use chrono::Utc;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{ClusterData, SegmentEmbedding};
|
||||
|
||||
/// Cluster analysis error.
///
/// Returned by `ClusterEngine` operations.
#[derive(Debug, Error)]
pub enum AnalysisError {
    /// Clustering error
    #[error("Clustering failed: {0}")]
    ClusteringError(String),

    /// Invalid parameters
    #[error("Invalid parameters: {0}")]
    InvalidParameters(String),

    /// Not found
    #[error("Not found: {0}")]
    NotFound(String),

    /// Internal error
    #[error("Internal error: {0}")]
    Internal(String),
}
|
||||
|
||||
/// Cluster engine configuration.
#[derive(Debug, Clone)]
pub struct ClusterEngineConfig {
    /// Minimum cluster size
    pub min_cluster_size: usize,
    /// HDBSCAN min_samples
    pub min_samples: usize,
    /// Distance threshold for merging
    pub merge_threshold: f32,
}

impl Default for ClusterEngineConfig {
    // Defaults: clusters of at least 5 members, min_samples 3, merge at 0.15.
    fn default() -> Self {
        Self {
            min_cluster_size: 5,
            min_samples: 3,
            merge_threshold: 0.15,
        }
    }
}
|
||||
|
||||
/// Cluster analysis engine.
|
||||
///
|
||||
/// Manages cluster discovery, labeling, and updates.
|
||||
/// Cluster analysis engine.
///
/// Manages cluster discovery, labeling, and updates. Interior mutability
/// via `RwLock` lets it be shared behind `&self` across handlers.
pub struct ClusterEngine {
    // Engine parameters; stored for use by a real clustering backend.
    config: ClusterEngineConfig,
    // In-memory cluster storage for stub implementation
    clusters: RwLock<HashMap<Uuid, ClusterData>>,
}
|
||||
|
||||
impl ClusterEngine {
|
||||
/// Create a new cluster engine with the given configuration.
|
||||
pub fn new(config: ClusterEngineConfig) -> Result<Self, AnalysisError> {
|
||||
Ok(Self {
|
||||
config,
|
||||
clusters: RwLock::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Update clusters with new embeddings.
|
||||
pub fn update_clusters(&self, _embeddings: &[SegmentEmbedding]) -> Result<(), AnalysisError> {
|
||||
// In a real implementation, this would:
|
||||
// 1. Run HDBSCAN or similar clustering
|
||||
// 2. Merge with existing clusters if similar
|
||||
// 3. Update cluster centroids and metadata
|
||||
|
||||
// For the stub, we don't create clusters automatically
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get all clusters.
|
||||
pub fn get_all_clusters(&self) -> Result<Vec<ClusterData>, AnalysisError> {
|
||||
let clusters = self
|
||||
.clusters
|
||||
.read()
|
||||
.map_err(|e| AnalysisError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(clusters.values().cloned().collect())
|
||||
}
|
||||
|
||||
/// Get a specific cluster by ID.
|
||||
pub fn get_cluster(&self, id: &Uuid) -> Result<Option<ClusterData>, AnalysisError> {
|
||||
let clusters = self
|
||||
.clusters
|
||||
.read()
|
||||
.map_err(|e| AnalysisError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(clusters.get(id).cloned())
|
||||
}
|
||||
|
||||
/// Assign a label to a cluster.
|
||||
pub fn assign_label(
|
||||
&self,
|
||||
cluster_id: &Uuid,
|
||||
label: &str,
|
||||
) -> Result<Option<ClusterData>, AnalysisError> {
|
||||
let mut clusters = self
|
||||
.clusters
|
||||
.write()
|
||||
.map_err(|e| AnalysisError::Internal(e.to_string()))?;
|
||||
|
||||
if let Some(cluster) = clusters.get_mut(cluster_id) {
|
||||
cluster.label = Some(label.to_string());
|
||||
Ok(Some(cluster.clone()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new cluster manually.
|
||||
pub fn create_cluster(
|
||||
&self,
|
||||
centroid: Vec<f32>,
|
||||
exemplar_ids: Vec<Uuid>,
|
||||
) -> Result<ClusterData, AnalysisError> {
|
||||
let cluster = ClusterData {
|
||||
id: Uuid::new_v4(),
|
||||
label: None,
|
||||
size: exemplar_ids.len(),
|
||||
centroid,
|
||||
density: 0.0,
|
||||
exemplar_ids,
|
||||
species_distribution: vec![],
|
||||
created_at: Utc::now(),
|
||||
};
|
||||
|
||||
let mut clusters = self
|
||||
.clusters
|
||||
.write()
|
||||
.map_err(|e| AnalysisError::Internal(e.to_string()))?;
|
||||
|
||||
clusters.insert(cluster.id, cluster.clone());
|
||||
|
||||
Ok(cluster)
|
||||
}
|
||||
|
||||
/// Delete a cluster.
|
||||
pub fn delete_cluster(&self, id: &Uuid) -> Result<bool, AnalysisError> {
|
||||
let mut clusters = self
|
||||
.clusters
|
||||
.write()
|
||||
.map_err(|e| AnalysisError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(clusters.remove(id).is_some())
|
||||
}
|
||||
|
||||
/// Merge two clusters.
|
||||
pub fn merge_clusters(
|
||||
&self,
|
||||
cluster_a: &Uuid,
|
||||
cluster_b: &Uuid,
|
||||
) -> Result<ClusterData, AnalysisError> {
|
||||
let mut clusters = self
|
||||
.clusters
|
||||
.write()
|
||||
.map_err(|e| AnalysisError::Internal(e.to_string()))?;
|
||||
|
||||
let a = clusters
|
||||
.remove(cluster_a)
|
||||
.ok_or_else(|| AnalysisError::NotFound(format!("Cluster {} not found", cluster_a)))?;
|
||||
|
||||
let b = clusters
|
||||
.remove(cluster_b)
|
||||
.ok_or_else(|| AnalysisError::NotFound(format!("Cluster {} not found", cluster_b)))?;
|
||||
|
||||
// Merge exemplar IDs
|
||||
let mut merged_exemplars = a.exemplar_ids;
|
||||
merged_exemplars.extend(b.exemplar_ids);
|
||||
|
||||
// Average centroids (simplified)
|
||||
let merged_centroid: Vec<f32> = a
|
||||
.centroid
|
||||
.iter()
|
||||
.zip(b.centroid.iter())
|
||||
.map(|(x, y)| (x + y) / 2.0)
|
||||
.collect();
|
||||
|
||||
let merged = ClusterData {
|
||||
id: Uuid::new_v4(),
|
||||
label: a.label.or(b.label),
|
||||
size: a.size + b.size,
|
||||
centroid: merged_centroid,
|
||||
density: (a.density + b.density) / 2.0,
|
||||
exemplar_ids: merged_exemplars,
|
||||
species_distribution: vec![], // Would recompute
|
||||
created_at: Utc::now(),
|
||||
};
|
||||
|
||||
clusters.insert(merged.id, merged.clone());
|
||||
|
||||
Ok(merged)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Engine construction with default config should succeed.
    #[test]
    fn test_cluster_engine_creation() {
        let engine = ClusterEngine::new(Default::default());
        assert!(engine.is_ok());
    }

    // A manually created cluster must be retrievable by its ID.
    #[test]
    fn test_create_and_get_cluster() {
        let engine = ClusterEngine::new(Default::default()).unwrap();

        let cluster = engine
            .create_cluster(vec![0.0; 1024], vec![Uuid::new_v4()])
            .unwrap();

        let retrieved = engine.get_cluster(&cluster.id).unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, cluster.id);
    }

    // Assigning a label returns the updated cluster carrying the label.
    #[test]
    fn test_assign_label() {
        let engine = ClusterEngine::new(Default::default()).unwrap();

        let cluster = engine
            .create_cluster(vec![0.0; 1024], vec![Uuid::new_v4()])
            .unwrap();

        let updated = engine
            .assign_label(&cluster.id, "Test Label")
            .unwrap()
            .unwrap();

        assert_eq!(updated.label, Some("Test Label".to_string()));
    }
}
|
||||
132
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/embedding.rs
vendored
Normal file
132
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/embedding.rs
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
//! Embedding model service.
|
||||
//!
|
||||
//! This module provides the `EmbeddingModel` service for generating
|
||||
//! vector embeddings from audio segments.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use super::{Segment, SegmentEmbedding};
|
||||
|
||||
/// Embedding model error.
|
||||
/// Embedding model error returned by [`EmbeddingModel`] operations.
#[derive(Debug, Error)]
pub enum EmbeddingError {
    /// The model could not be loaded (path/format problems).
    #[error("Failed to load model: {0}")]
    ModelLoadError(String),

    /// Inference execution failed.
    #[error("Inference failed: {0}")]
    InferenceError(String),

    /// The caller supplied invalid input (e.g. empty text).
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// The model was used before initialization completed.
    #[error("Model not initialized")]
    NotInitialized,
}
|
||||
|
||||
/// Embedding model configuration.
|
||||
/// Embedding model configuration.
#[derive(Debug, Clone)]
pub struct EmbeddingModelConfig {
    /// Model path or identifier (e.g. "birdnet-v2.4").
    pub model_id: String,
    /// Output embedding dimension (length of each produced vector).
    pub embedding_dim: usize,
    /// Number of segments processed per inference batch.
    pub batch_size: usize,
    /// Prefer GPU execution when available.
    pub use_gpu: bool,
}
|
||||
|
||||
impl Default for EmbeddingModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_id: "birdnet-v2.4".to_string(),
|
||||
embedding_dim: 1024,
|
||||
batch_size: 32,
|
||||
use_gpu: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding model for generating audio embeddings.
|
||||
///
|
||||
/// Wraps ONNX model inference for generating fixed-size vector
|
||||
/// representations of audio segments.
|
||||
/// Embedding model for generating audio embeddings.
///
/// Wraps ONNX model inference for generating fixed-size vector
/// representations of audio segments.
pub struct EmbeddingModel {
    // Model parameters; only `embedding_dim` is consulted by the stub.
    config: EmbeddingModelConfig,
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
/// Create a new embedding model with the given configuration.
|
||||
pub async fn new(config: EmbeddingModelConfig) -> Result<Self, EmbeddingError> {
|
||||
// In a real implementation, this would:
|
||||
// 1. Load ONNX model from path
|
||||
// 2. Initialize ONNX runtime session
|
||||
// 3. Configure GPU/CPU execution providers
|
||||
|
||||
Ok(Self { config })
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of segments.
|
||||
pub async fn embed_batch(
|
||||
&self,
|
||||
segments: &[Segment],
|
||||
) -> Result<Vec<SegmentEmbedding>, EmbeddingError> {
|
||||
// In a real implementation, this would:
|
||||
// 1. Preprocess segments (mel spectrogram)
|
||||
// 2. Batch and run inference
|
||||
// 3. L2 normalize embeddings
|
||||
|
||||
let embeddings = segments
|
||||
.iter()
|
||||
.map(|seg| SegmentEmbedding {
|
||||
id: seg.id,
|
||||
recording_id: seg.recording_id,
|
||||
embedding: vec![0.0; self.config.embedding_dim],
|
||||
start_time: seg.start_time,
|
||||
end_time: seg.end_time,
|
||||
species: seg.species.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
/// Generate embedding for text (for text-to-audio search).
|
||||
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||
if text.is_empty() {
|
||||
return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
|
||||
}
|
||||
|
||||
// In a real implementation, this would use a text encoder
|
||||
// For now, return a zero vector
|
||||
Ok(vec![0.0; self.config.embedding_dim])
|
||||
}
|
||||
|
||||
/// Get the embedding dimension.
|
||||
#[must_use]
|
||||
pub fn embedding_dim(&self) -> usize {
|
||||
self.config.embedding_dim
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Model construction with default config should succeed.
    #[tokio::test]
    async fn test_embedding_model_creation() {
        let model = EmbeddingModel::new(Default::default()).await;
        assert!(model.is_ok());
    }

    // Empty text input must be rejected with `InvalidInput`.
    #[tokio::test]
    async fn test_embed_text_empty() {
        let model = EmbeddingModel::new(Default::default()).await.unwrap();
        let result = model.embed_text("").await;
        assert!(matches!(result, Err(EmbeddingError::InvalidInput(_))));
    }
}
|
||||
246
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/interpretation.rs
vendored
Normal file
246
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/interpretation.rs
vendored
Normal file
@@ -0,0 +1,246 @@
|
||||
//! Interpretation service.
|
||||
//!
|
||||
//! This module provides the `InterpretationEngine` service for generating
|
||||
//! evidence packs that explain similarity relationships.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use chrono::Utc;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
EvidencePackData, EvidenceSegment, FeatureContributionData, NeighborEvidenceData,
|
||||
SearchResult, SharedFeature, VisualizationUrls,
|
||||
};
|
||||
|
||||
/// Interpretation error.
|
||||
/// Interpretation error returned by [`InterpretationEngine`] operations.
#[derive(Debug, Error)]
pub enum InterpretationError {
    /// Evidence pack generation failed.
    #[error("Evidence generation failed: {0}")]
    GenerationError(String),

    /// The requested resource (e.g. a cached pack) does not exist.
    #[error("Not found: {0}")]
    NotFound(String),

    /// Unexpected internal failure (e.g. a poisoned lock).
    #[error("Internal error: {0}")]
    Internal(String),
}
|
||||
|
||||
/// Interpretation engine configuration.
|
||||
/// Interpretation engine configuration.
#[derive(Debug, Clone)]
pub struct InterpretationEngineConfig {
    /// Maximum number of top contributing features to include per neighbor.
    pub top_features: usize,
    /// Whether to emit spectrogram comparison/grid URLs.
    pub generate_spectrograms: bool,
    /// Whether to emit UMAP projection URLs.
    pub generate_umap: bool,
}
|
||||
|
||||
impl Default for InterpretationEngineConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
top_features: 5,
|
||||
generate_spectrograms: true,
|
||||
generate_umap: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Interpretation engine for generating evidence packs.
|
||||
///
|
||||
/// Creates interpretable explanations for similarity relationships.
|
||||
/// Interpretation engine for generating evidence packs.
///
/// Creates interpretable explanations for similarity relationships.
/// Interior mutability via `RwLock` lets it be shared behind `&self`.
pub struct InterpretationEngine {
    // Generation parameters (feature cap, visualization toggles).
    config: InterpretationEngineConfig,
    // Cache for generated evidence packs, keyed by query ID.
    cache: RwLock<HashMap<Uuid, EvidencePackData>>,
}
|
||||
|
||||
impl InterpretationEngine {
|
||||
/// Create a new interpretation engine with the given configuration.
|
||||
pub fn new(config: InterpretationEngineConfig) -> Result<Self, InterpretationError> {
|
||||
Ok(Self {
|
||||
config,
|
||||
cache: RwLock::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a cached evidence pack by query ID.
|
||||
pub fn get_evidence_pack(
|
||||
&self,
|
||||
query_id: &Uuid,
|
||||
) -> Result<Option<EvidencePackData>, InterpretationError> {
|
||||
let cache = self
|
||||
.cache
|
||||
.read()
|
||||
.map_err(|e| InterpretationError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(cache.get(query_id).cloned())
|
||||
}
|
||||
|
||||
/// Generate an evidence pack for a query segment and its neighbors.
|
||||
pub async fn generate_evidence_pack(
|
||||
&self,
|
||||
segment_id: &Uuid,
|
||||
neighbors: &[SearchResult],
|
||||
) -> Result<EvidencePackData, InterpretationError> {
|
||||
let query_id = Uuid::new_v4();
|
||||
|
||||
// Create query segment info
|
||||
let query_segment = EvidenceSegment {
|
||||
id: *segment_id,
|
||||
recording_id: Uuid::new_v4(), // Would be looked up
|
||||
start_time: 0.0,
|
||||
end_time: 1.0,
|
||||
species: None,
|
||||
};
|
||||
|
||||
// Generate evidence for each neighbor
|
||||
let neighbor_evidence: Vec<NeighborEvidenceData> = neighbors
|
||||
.iter()
|
||||
.map(|n| {
|
||||
// In a real implementation, this would:
|
||||
// 1. Analyze embedding dimensions
|
||||
// 2. Identify contributing features
|
||||
// 3. Generate spectrogram comparisons
|
||||
|
||||
let contributing_features = vec![
|
||||
FeatureContributionData {
|
||||
name: "fundamental_frequency".to_string(),
|
||||
weight: 0.25,
|
||||
query_value: 2500.0,
|
||||
neighbor_value: 2480.0,
|
||||
},
|
||||
FeatureContributionData {
|
||||
name: "duration".to_string(),
|
||||
weight: 0.15,
|
||||
query_value: 0.5,
|
||||
neighbor_value: 0.48,
|
||||
},
|
||||
FeatureContributionData {
|
||||
name: "bandwidth".to_string(),
|
||||
weight: 0.12,
|
||||
query_value: 1500.0,
|
||||
neighbor_value: 1520.0,
|
||||
},
|
||||
];
|
||||
|
||||
NeighborEvidenceData {
|
||||
segment: EvidenceSegment {
|
||||
id: n.id,
|
||||
recording_id: n.recording_id,
|
||||
start_time: n.start_time,
|
||||
end_time: n.end_time,
|
||||
species: n.species.clone(),
|
||||
},
|
||||
similarity: 1.0 - n.distance,
|
||||
contributing_features,
|
||||
spectrogram_comparison_url: if self.config.generate_spectrograms {
|
||||
Some(format!("/api/v1/evidence/{}/spectrograms/{}", query_id, n.id))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Identify shared features across neighbors
|
||||
let shared_features = vec![
|
||||
SharedFeature {
|
||||
name: "frequency_modulation".to_string(),
|
||||
description: "Rapid upward frequency sweep in 100-200ms range".to_string(),
|
||||
confidence: 0.92,
|
||||
},
|
||||
SharedFeature {
|
||||
name: "harmonic_structure".to_string(),
|
||||
description: "Clear harmonic overtones at 2x and 3x fundamental".to_string(),
|
||||
confidence: 0.87,
|
||||
},
|
||||
];
|
||||
|
||||
// Generate visualization URLs
|
||||
let visualizations = VisualizationUrls {
|
||||
umap_url: if self.config.generate_umap {
|
||||
Some(format!("/api/v1/evidence/{}/umap", query_id))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
spectrogram_grid_url: if self.config.generate_spectrograms {
|
||||
Some(format!("/api/v1/evidence/{}/grid", query_id))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
feature_importance_url: Some(format!("/api/v1/evidence/{}/features", query_id)),
|
||||
};
|
||||
|
||||
let evidence_pack = EvidencePackData {
|
||||
query_id,
|
||||
query_segment,
|
||||
neighbors: neighbor_evidence,
|
||||
shared_features,
|
||||
visualizations,
|
||||
generated_at: Utc::now(),
|
||||
};
|
||||
|
||||
// Cache the evidence pack
|
||||
{
|
||||
let mut cache = self
|
||||
.cache
|
||||
.write()
|
||||
.map_err(|e| InterpretationError::Internal(e.to_string()))?;
|
||||
|
||||
cache.insert(query_id, evidence_pack.clone());
|
||||
}
|
||||
|
||||
Ok(evidence_pack)
|
||||
}
|
||||
|
||||
/// Clear the evidence pack cache.
|
||||
pub fn clear_cache(&self) -> Result<(), InterpretationError> {
|
||||
let mut cache = self
|
||||
.cache
|
||||
.write()
|
||||
.map_err(|e| InterpretationError::Internal(e.to_string()))?;
|
||||
|
||||
cache.clear();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Engine construction with default config should succeed.
    #[test]
    fn test_interpretation_engine_creation() {
        let engine = InterpretationEngine::new(Default::default());
        assert!(engine.is_ok());
    }

    // Generating a pack for one neighbor should yield non-empty neighbor
    // evidence and shared features.
    #[tokio::test]
    async fn test_generate_evidence_pack() {
        let engine = InterpretationEngine::new(Default::default()).unwrap();

        let segment_id = Uuid::new_v4();
        let neighbors = vec![SearchResult {
            id: Uuid::new_v4(),
            recording_id: Uuid::new_v4(),
            distance: 0.1,
            start_time: 0.0,
            end_time: 1.0,
            species: None,
        }];

        let result = engine.generate_evidence_pack(&segment_id, &neighbors).await;
        assert!(result.is_ok());

        let pack = result.unwrap();
        assert!(!pack.neighbors.is_empty());
        assert!(!pack.shared_features.is_empty());
    }
}
|
||||
218
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/mod.rs
vendored
Normal file
218
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/mod.rs
vendored
Normal file
@@ -0,0 +1,218 @@
|
||||
//! Service layer abstractions for 7sense API.
|
||||
//!
|
||||
//! This module defines the interfaces and implementations for core services:
|
||||
//! - `AudioPipeline` - Audio loading and segmentation
|
||||
//! - `EmbeddingModel` - Segment embedding generation
|
||||
//! - `VectorIndex` - Similarity search
|
||||
//! - `ClusterEngine` - Cluster analysis
|
||||
//! - `InterpretationEngine` - Evidence pack generation
|
||||
//!
|
||||
//! These services wrap the underlying crate implementations and provide
|
||||
//! API-specific functionality.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
// Re-export service types
|
||||
pub use audio::*;
|
||||
pub use cluster::*;
|
||||
pub use embedding::*;
|
||||
pub use interpretation::*;
|
||||
pub use vector::*;
|
||||
|
||||
mod audio;
|
||||
mod cluster;
|
||||
mod embedding;
|
||||
mod interpretation;
|
||||
mod vector;
|
||||
|
||||
/// Species information attached to a segment.
|
||||
/// Species information attached to a segment.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct SpeciesInfo {
    /// Common name (e.g. vernacular species name).
    pub common_name: String,
    /// Scientific name (binomial nomenclature), when known.
    pub scientific_name: Option<String>,
    /// Detection confidence (0.0 to 1.0).
    pub confidence: f32,
}
|
||||
|
||||
/// A detected audio segment.
|
||||
/// A detected audio segment, including its raw samples.
#[derive(Debug, Clone)]
pub struct Segment {
    /// Unique identifier
    pub id: Uuid,
    /// Parent recording ID
    pub recording_id: Uuid,
    /// Start time in seconds (relative to the recording)
    pub start_time: f64,
    /// End time in seconds (relative to the recording)
    pub end_time: f64,
    /// Audio samples (mono, normalized)
    pub samples: Vec<f32>,
    /// Sample rate in Hz
    pub sample_rate: u32,
    /// Detected species, if classification succeeded
    pub species: Option<SpeciesInfo>,
    /// Quality score
    // NOTE(review): range presumably 0.0..=1.0 like `SpeciesInfo::confidence`
    // — confirm with the producer.
    pub quality_score: f32,
}
|
||||
|
||||
/// Audio metadata.
|
||||
/// Audio metadata describing a source file/stream.
#[derive(Debug, Clone)]
pub struct AudioMetadata {
    /// Duration in seconds
    pub duration_secs: f64,
    /// Sample rate in Hz
    pub sample_rate: u32,
    /// Number of channels (1 = mono, 2 = stereo, ...)
    pub channels: u16,
}
|
||||
|
||||
/// Loaded audio data.
|
||||
/// Loaded audio data, already downmixed to mono.
#[derive(Debug, Clone)]
pub struct Audio {
    /// Mono samples (normalized to -1.0 to 1.0)
    pub samples: Vec<f32>,
    /// Sample rate in Hz
    pub sample_rate: u32,
    /// Original duration in seconds
    pub duration_secs: f64,
}
|
||||
|
||||
/// Segment embedding for vector storage.
|
||||
/// Segment embedding for vector storage.
#[derive(Debug, Clone)]
pub struct SegmentEmbedding {
    /// Segment ID
    pub id: Uuid,
    /// Recording ID
    pub recording_id: Uuid,
    /// Embedding vector (length set by the embedding model's dimension)
    pub embedding: Vec<f32>,
    /// Start time in seconds
    pub start_time: f64,
    /// End time in seconds
    pub end_time: f64,
    /// Detected species, if any
    pub species: Option<SpeciesInfo>,
}
|
||||
|
||||
/// Search result from vector index.
|
||||
/// Search result from vector index.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Segment ID
    pub id: Uuid,
    /// Recording ID
    pub recording_id: Uuid,
    /// Distance to query (cosine distance in the stub index; lower is
    /// more similar)
    pub distance: f32,
    /// Start time in seconds
    pub start_time: f64,
    /// End time in seconds
    pub end_time: f64,
    /// Detected species, if any
    pub species: Option<SpeciesInfo>,
}
|
||||
|
||||
/// Cluster data.
|
||||
/// Cluster data produced/managed by the cluster engine.
#[derive(Debug, Clone)]
pub struct ClusterData {
    /// Cluster ID
    pub id: Uuid,
    /// Human-assigned label, if any
    pub label: Option<String>,
    /// Number of segments in the cluster
    pub size: usize,
    /// Centroid embedding
    pub centroid: Vec<f32>,
    /// Cluster density
    // NOTE(review): density semantics/units not defined here — the stub
    // only averages it on merge; confirm with the real clustering backend.
    pub density: f32,
    /// Representative segment IDs
    pub exemplar_ids: Vec<Uuid>,
    /// Species distribution: (name, count, percentage)
    pub species_distribution: Vec<(String, usize, f64)>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
}
|
||||
|
||||
/// Evidence pack data.
|
||||
/// Evidence pack data explaining a similarity query.
#[derive(Debug, Clone)]
pub struct EvidencePackData {
    /// Query ID (also the cache key in the interpretation engine)
    pub query_id: Uuid,
    /// The segment the query was made for
    pub query_segment: EvidenceSegment,
    /// Per-neighbor evidence entries
    pub neighbors: Vec<NeighborEvidenceData>,
    /// Features shared across the neighbors
    pub shared_features: Vec<SharedFeature>,
    /// Visualization URLs for this pack
    pub visualizations: VisualizationUrls,
    /// Generation timestamp
    pub generated_at: DateTime<Utc>,
}
|
||||
|
||||
/// Segment for evidence pack.
|
||||
/// Segment reference as it appears in an evidence pack (no raw samples).
#[derive(Debug, Clone)]
pub struct EvidenceSegment {
    /// Segment ID
    pub id: Uuid,
    /// Recording ID
    pub recording_id: Uuid,
    /// Start time in seconds
    pub start_time: f64,
    /// End time in seconds
    pub end_time: f64,
    /// Species info, if any
    pub species: Option<SpeciesInfo>,
}
|
||||
|
||||
/// Neighbor evidence data.
|
||||
/// Evidence for a single neighbor of the query segment.
#[derive(Debug, Clone)]
pub struct NeighborEvidenceData {
    /// Neighbor segment
    pub segment: EvidenceSegment,
    /// Similarity score (computed as 1.0 - distance by the generator)
    pub similarity: f32,
    /// Features contributing to the similarity
    pub contributing_features: Vec<FeatureContributionData>,
    /// Spectrogram comparison URL (present only when spectrogram
    /// generation is enabled)
    pub spectrogram_comparison_url: Option<String>,
}
|
||||
|
||||
/// Feature contribution data.
|
||||
/// A single feature's contribution to a similarity match.
#[derive(Debug, Clone)]
pub struct FeatureContributionData {
    /// Feature name (e.g. "fundamental_frequency")
    pub name: String,
    /// Contribution weight
    pub weight: f32,
    /// Feature value for the query segment
    pub query_value: f64,
    /// Feature value for the neighbor segment
    pub neighbor_value: f64,
}
|
||||
|
||||
/// Shared acoustic feature.
|
||||
/// An acoustic feature shared across the neighbors of a query.
#[derive(Debug, Clone)]
pub struct SharedFeature {
    /// Feature name
    pub name: String,
    /// Human-readable description
    pub description: String,
    /// Confidence score
    pub confidence: f32,
}
|
||||
|
||||
/// Visualization URLs.
|
||||
/// URLs for evidence-pack visualizations; each is `None` when that
/// visualization type is disabled in the engine configuration.
#[derive(Debug, Clone)]
pub struct VisualizationUrls {
    /// UMAP projection URL
    pub umap_url: Option<String>,
    /// Spectrogram grid URL
    pub spectrogram_grid_url: Option<String>,
    /// Feature importance URL
    pub feature_importance_url: Option<String>,
}
|
||||
235
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/vector.rs
vendored
Normal file
235
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/services/vector.rs
vendored
Normal file
@@ -0,0 +1,235 @@
|
||||
//! Vector index service.
|
||||
//!
|
||||
//! This module provides the `VectorIndex` service for similarity search
|
||||
//! using vector embeddings.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{SearchResult, SegmentEmbedding, SpeciesInfo};
|
||||
|
||||
/// Vector index error.
|
||||
/// Vector index error returned by [`VectorIndex`] operations.
#[derive(Debug, Error)]
pub enum VectorError {
    /// Could not connect to the vector database backend.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// A read/search query failed (includes poisoned read locks in the stub).
    #[error("Query error: {0}")]
    QueryError(String),

    /// An index mutation failed (includes poisoned write locks in the stub).
    #[error("Index error: {0}")]
    IndexError(String),

    /// The requested item does not exist.
    #[error("Not found: {0}")]
    NotFound(String),
}
|
||||
|
||||
/// Vector index configuration.
|
||||
/// Vector index configuration.
#[derive(Debug, Clone)]
pub struct VectorIndexConfig {
    /// Collection name in the vector database.
    pub collection_name: String,
    /// Embedding dimension expected by the index.
    pub embedding_dim: usize,
    /// HNSW `M` parameter (graph connectivity).
    pub hnsw_m: usize,
    /// HNSW `ef_construct` parameter (build-time search breadth).
    pub hnsw_ef_construct: usize,
}
|
||||
|
||||
impl Default for VectorIndexConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
collection_name: "sevensense_segments".to_string(),
|
||||
embedding_dim: 1024,
|
||||
hnsw_m: 16,
|
||||
hnsw_ef_construct: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory segment storage for the stub implementation.
|
||||
/// In-memory segment storage for the stub implementation.
struct StoredSegment {
    // Parent recording, used for bulk deletion by recording.
    recording_id: Uuid,
    // The segment's embedding vector.
    embedding: Vec<f32>,
    // Segment bounds in seconds.
    start_time: f64,
    end_time: f64,
    // Detected species, if any.
    species: Option<SpeciesInfo>,
}
|
||||
|
||||
/// Vector index for similarity search.
|
||||
///
|
||||
/// Wraps vector database (Qdrant) for efficient nearest neighbor search.
|
||||
/// Vector index for similarity search.
///
/// Wraps vector database (Qdrant) for efficient nearest neighbor search.
/// The stub keeps everything in memory and scans linearly.
pub struct VectorIndex {
    // Index parameters; used by a real backend, stored here for parity.
    config: VectorIndexConfig,
    // In-memory storage for stub implementation
    storage: RwLock<HashMap<Uuid, StoredSegment>>,
}
|
||||
|
||||
impl VectorIndex {
|
||||
/// Create a new vector index with the given configuration.
|
||||
pub fn new(config: VectorIndexConfig) -> Result<Self, VectorError> {
|
||||
// In a real implementation, this would:
|
||||
// 1. Connect to Qdrant
|
||||
// 2. Create/verify collection
|
||||
// 3. Configure HNSW index
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
storage: RwLock::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a batch of embeddings to the index.
|
||||
pub fn add_batch(&self, embeddings: &[SegmentEmbedding]) -> Result<(), VectorError> {
|
||||
let mut storage = self
|
||||
.storage
|
||||
.write()
|
||||
.map_err(|e| VectorError::IndexError(e.to_string()))?;
|
||||
|
||||
for emb in embeddings {
|
||||
storage.insert(
|
||||
emb.id,
|
||||
StoredSegment {
|
||||
recording_id: emb.recording_id,
|
||||
embedding: emb.embedding.clone(),
|
||||
start_time: emb.start_time,
|
||||
end_time: emb.end_time,
|
||||
species: emb.species.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get embedding for a segment.
|
||||
pub fn get_embedding(&self, segment_id: &Uuid) -> Result<Option<Vec<f32>>, VectorError> {
|
||||
let storage = self
|
||||
.storage
|
||||
.read()
|
||||
.map_err(|e| VectorError::QueryError(e.to_string()))?;
|
||||
|
||||
Ok(storage.get(segment_id).map(|s| s.embedding.clone()))
|
||||
}
|
||||
|
||||
/// Search for similar segments.
|
||||
pub fn search(
|
||||
&self,
|
||||
query: &[f32],
|
||||
k: usize,
|
||||
min_similarity: f32,
|
||||
) -> Result<Vec<SearchResult>, VectorError> {
|
||||
let storage = self
|
||||
.storage
|
||||
.read()
|
||||
.map_err(|e| VectorError::QueryError(e.to_string()))?;
|
||||
|
||||
// Compute distances to all stored embeddings
|
||||
let mut results: Vec<(Uuid, f32, &StoredSegment)> = storage
|
||||
.iter()
|
||||
.map(|(id, seg)| {
|
||||
let distance = cosine_distance(query, &seg.embedding);
|
||||
(*id, distance, seg)
|
||||
})
|
||||
.filter(|(_, dist, _)| (1.0 - *dist) >= min_similarity)
|
||||
.collect();
|
||||
|
||||
// Sort by distance (ascending)
|
||||
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Take top k
|
||||
let results: Vec<SearchResult> = results
|
||||
.into_iter()
|
||||
.take(k)
|
||||
.map(|(id, distance, seg)| SearchResult {
|
||||
id,
|
||||
recording_id: seg.recording_id,
|
||||
distance,
|
||||
start_time: seg.start_time,
|
||||
end_time: seg.end_time,
|
||||
species: seg.species.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Delete embeddings for a recording.
|
||||
pub fn delete_recording(&self, recording_id: &Uuid) -> Result<usize, VectorError> {
|
||||
let mut storage = self
|
||||
.storage
|
||||
.write()
|
||||
.map_err(|e| VectorError::IndexError(e.to_string()))?;
|
||||
|
||||
let to_remove: Vec<Uuid> = storage
|
||||
.iter()
|
||||
.filter(|(_, seg)| seg.recording_id == *recording_id)
|
||||
.map(|(id, _)| *id)
|
||||
.collect();
|
||||
|
||||
let count = to_remove.len();
|
||||
for id in to_remove {
|
||||
storage.remove(&id);
|
||||
}
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// Get total number of indexed segments.
|
||||
pub fn count(&self) -> Result<usize, VectorError> {
|
||||
let storage = self
|
||||
.storage
|
||||
.read()
|
||||
.map_err(|e| VectorError::QueryError(e.to_string()))?;
|
||||
|
||||
Ok(storage.len())
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the cosine distance (1 - cosine similarity) between two vectors.
///
/// Degenerate inputs — mismatched lengths, empty slices, or a zero-magnitude
/// vector — yield 1.0 (treated as maximally dissimilar by convention).
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 1.0;
    }

    // Single pass: accumulate dot product and both squared magnitudes.
    let (mut dot, mut sq_a, mut sq_b) = (0.0_f32, 0.0_f32, 0.0_f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }

    1.0 - dot / (norm_a * norm_b)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Index construction with default config should succeed.
    #[test]
    fn test_vector_index_creation() {
        let index = VectorIndex::new(Default::default());
        assert!(index.is_ok());
    }

    // Identical vectors have distance ~0; orthogonal vectors have
    // distance ~1.
    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let dist = cosine_distance(&a, &b);
        assert!((dist - 0.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        let dist = cosine_distance(&a, &c);
        assert!((dist - 1.0).abs() < 0.001);
    }
}
}
|
||||
286
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/websocket/handlers.rs
vendored
Normal file
286
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/websocket/handlers.rs
vendored
Normal file
@@ -0,0 +1,286 @@
|
||||
//! WebSocket handlers for real-time updates.
|
||||
//!
|
||||
//! These handlers manage WebSocket connections for streaming
|
||||
//! processing status, cluster updates, and other real-time data.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
Path, State,
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{AppContext, ProcessingEvent, ProcessingStatus};
|
||||
|
||||
/// WebSocket message types for client communication.
///
/// Serialized with serde's adjacently-tagged representation, i.e.
/// `{"type": "...", "data": ...}`; unit variants such as `Ping`/`Pong`
/// carry no `data` payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum WsMessage {
    /// Processing status update
    #[serde(rename = "status")]
    Status(StatusUpdate),
    /// Error message
    #[serde(rename = "error")]
    Error(ErrorMessage),
    /// Ping/keepalive (sent by the server every 30s; see the handlers)
    #[serde(rename = "ping")]
    Ping,
    /// Pong response
    #[serde(rename = "pong")]
    Pong,
    /// Subscription confirmation, sent once right after connect
    #[serde(rename = "subscribed")]
    Subscribed {
        /// Channel subscribed to (e.g. "recordings/{id}" or "events")
        channel: String,
    },
}
|
||||
|
||||
/// Processing status update message.
///
/// Wire form of a [`ProcessingEvent`]; the timestamp is stamped
/// server-side at conversion time (see the `From` impl below).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusUpdate {
    /// Recording ID
    pub recording_id: Uuid,
    /// Status string ("queued", "loading", "segmenting", "embedding",
    /// "indexing", "analyzing", "complete", or "failed")
    pub status: String,
    /// Progress (0.0 to 1.0)
    pub progress: f32,
    /// Optional message
    pub message: Option<String>,
    /// Timestamp (Unix epoch milliseconds, UTC)
    pub timestamp: i64,
}
|
||||
|
||||
/// Error message.
///
/// Payload of [`WsMessage::Error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorMessage {
    /// Error code (machine-readable, e.g. "not_found")
    pub code: String,
    /// Error message (human-readable)
    pub message: String,
}
|
||||
|
||||
impl From<ProcessingEvent> for StatusUpdate {
|
||||
fn from(event: ProcessingEvent) -> Self {
|
||||
Self {
|
||||
recording_id: event.recording_id,
|
||||
status: match event.status {
|
||||
ProcessingStatus::Queued => "queued",
|
||||
ProcessingStatus::Loading => "loading",
|
||||
ProcessingStatus::Segmenting => "segmenting",
|
||||
ProcessingStatus::Embedding => "embedding",
|
||||
ProcessingStatus::Indexing => "indexing",
|
||||
ProcessingStatus::Analyzing => "analyzing",
|
||||
ProcessingStatus::Complete => "complete",
|
||||
ProcessingStatus::Failed => "failed",
|
||||
}
|
||||
.to_string(),
|
||||
progress: event.progress,
|
||||
message: event.message,
|
||||
timestamp: chrono::Utc::now().timestamp_millis(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket handler for recording status updates.
///
/// Clients connect to `/ws/recordings/{id}` to receive real-time
/// status updates for a specific recording.
///
/// The HTTP upgrade is accepted unconditionally; all per-connection
/// work happens in `handle_recording_status`.
pub async fn recording_status_ws(
    ws: WebSocketUpgrade,
    Path(recording_id): Path<Uuid>,
    State(ctx): State<AppContext>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| handle_recording_status(socket, recording_id, ctx))
}
|
||||
|
||||
/// Drives a single recording-status WebSocket connection.
///
/// Splits the socket into sender/receiver halves, subscribes to the
/// application's broadcast event stream, confirms the subscription,
/// and then runs two tasks:
/// - a send task that forwards events matching `recording_id` as
///   `Status` messages and emits a JSON `Ping` every 30 seconds;
/// - a receive task that drains client frames until `Close` or error.
///
/// The connection ends when either task finishes; the other is aborted.
async fn handle_recording_status(socket: WebSocket, recording_id: Uuid, ctx: AppContext) {
    let (mut sender, mut receiver) = socket.split();

    // Subscribe to events
    let mut event_rx = ctx.subscribe_events();

    // Send subscription confirmation
    let confirm = WsMessage::Subscribed {
        channel: format!("recordings/{recording_id}"),
    };
    if let Ok(json) = serde_json::to_string(&confirm) {
        // Best-effort: a send failure here just means the tasks below
        // will observe the closed socket and exit.
        let _ = sender.send(Message::Text(json.into())).await;
    }

    // Spawn task to handle incoming messages (pings, etc.)
    let mut recv_task = tokio::spawn(async move {
        while let Some(msg) = receiver.next().await {
            match msg {
                Ok(Message::Close(_)) | Err(_) => break,
                // Other frame types (text/binary/ping/pong) are ignored.
                _ => {}
            }
        }
    });

    // Main event loop
    let mut send_task = tokio::spawn(async move {
        // Keepalive interval
        let mut keepalive = tokio::time::interval(Duration::from_secs(30));

        loop {
            tokio::select! {
                // Handle processing events
                event = event_rx.recv() => {
                    match event {
                        Ok(event) if event.recording_id == recording_id => {
                            let update: StatusUpdate = event.into();
                            let msg = WsMessage::Status(update);
                            if let Ok(json) = serde_json::to_string(&msg) {
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(broadcast::error::RecvError::Closed) => break,
                        // Events for other recordings, and Lagged receive
                        // errors, are silently skipped.
                        _ => {}
                    }
                }
                // Keepalive ping
                _ = keepalive.tick() => {
                    let msg = WsMessage::Ping;
                    if let Ok(json) = serde_json::to_string(&msg) {
                        if sender.send(Message::Text(json.into())).await.is_err() {
                            break;
                        }
                    }
                }
            }
        }
    });

    // Wait for either task to complete
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }

    tracing::debug!(recording_id = %recording_id, "WebSocket connection closed");
}
|
||||
|
||||
/// WebSocket handler for all events stream.
///
/// Admin endpoint that streams all processing events. Unlike
/// `recording_status_ws`, no per-recording filtering is applied.
pub async fn events_ws(ws: WebSocketUpgrade, State(ctx): State<AppContext>) -> impl IntoResponse {
    ws.on_upgrade(move |socket| handle_all_events(socket, ctx))
}
|
||||
|
||||
/// Drives an admin "firehose" WebSocket connection.
///
/// Forwards every `ProcessingEvent` from the broadcast channel to the
/// client (no filtering), sends a JSON `Ping` every 30 seconds, and
/// drains incoming frames until the client closes or errors. Either
/// task finishing terminates the connection; the other is aborted.
async fn handle_all_events(socket: WebSocket, ctx: AppContext) {
    let (mut sender, mut receiver) = socket.split();

    let mut event_rx = ctx.subscribe_events();

    // Send subscription confirmation
    let confirm = WsMessage::Subscribed {
        channel: "events".to_string(),
    };
    if let Ok(json) = serde_json::to_string(&confirm) {
        // Best-effort; a closed socket is detected by the tasks below.
        let _ = sender.send(Message::Text(json.into())).await;
    }

    // Spawn receiver task
    let mut recv_task = tokio::spawn(async move {
        while let Some(msg) = receiver.next().await {
            match msg {
                Ok(Message::Close(_)) | Err(_) => break,
                // All other frame types are ignored.
                _ => {}
            }
        }
    });

    // Main send loop
    let mut send_task = tokio::spawn(async move {
        let mut keepalive = tokio::time::interval(Duration::from_secs(30));

        loop {
            tokio::select! {
                event = event_rx.recv() => {
                    match event {
                        Ok(event) => {
                            let update: StatusUpdate = event.into();
                            let msg = WsMessage::Status(update);
                            if let Ok(json) = serde_json::to_string(&msg) {
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(broadcast::error::RecvError::Closed) => break,
                        Err(_) => {} // Lagged, skip
                    }
                }
                _ = keepalive.tick() => {
                    if let Ok(json) = serde_json::to_string(&WsMessage::Ping) {
                        if sender.send(Message::Text(json.into())).await.is_err() {
                            break;
                        }
                    }
                }
            }
        }
    });

    // Whichever task finishes first wins; abort the other.
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }

    tracing::debug!("Events WebSocket connection closed");
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Status messages must serialize with the adjacently-tagged layout
    // ({"type": "status", "data": {...}}) that clients depend on.
    #[test]
    fn test_ws_message_serialization() {
        let msg = WsMessage::Status(StatusUpdate {
            recording_id: Uuid::new_v4(),
            status: "processing".to_string(),
            progress: 0.5,
            message: Some("Halfway done".to_string()),
            timestamp: 1_234_567_890,
        });

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("status"));
        assert!(json.contains("processing"));
    }

    // The From<ProcessingEvent> conversion maps each status variant to
    // its lowercase string form and preserves progress unchanged.
    #[test]
    fn test_status_update_from_event() {
        let event = ProcessingEvent {
            recording_id: Uuid::new_v4(),
            status: ProcessingStatus::Embedding,
            progress: 0.5,
            message: Some("Generating embeddings".to_string()),
        };

        let update: StatusUpdate = event.into();
        assert_eq!(update.status, "embedding");
        assert!((update.progress - 0.5).abs() < f32::EPSILON);
    }

    // Error messages round-trip their machine-readable code through JSON.
    #[test]
    fn test_error_message() {
        let error = WsMessage::Error(ErrorMessage {
            code: "not_found".to_string(),
            message: "Recording not found".to_string(),
        });

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("not_found"));
    }
}
|
||||
22
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/websocket/mod.rs
vendored
Normal file
22
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-api/src/websocket/mod.rs
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
//! WebSocket module for real-time updates.
|
||||
//!
|
||||
//! This module provides WebSocket endpoints for:
|
||||
//! - Real-time processing status updates
|
||||
//! - Live cluster updates
|
||||
//! - Streaming search results
|
||||
|
||||
pub mod handlers;
|
||||
|
||||
use axum::{routing::get, Router};
|
||||
|
||||
use crate::AppContext;
|
||||
|
||||
/// Create the WebSocket router.
///
/// Routes:
/// - `GET /recordings/:id` — per-recording status stream
/// - `GET /events` — stream of all processing events (admin)
///
/// The context argument is currently unused; shared state is supplied
/// by the parent router via `Router<AppContext>`.
#[must_use]
pub fn create_router(_ctx: AppContext) -> Router<AppContext> {
    Router::new()
        // Recording status updates
        .route("/recordings/:id", get(handlers::recording_status_ws))
        // All events stream (admin)
        .route("/events", get(handlers::events_ws))
}
|
||||
Reference in New Issue
Block a user