Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,111 @@
//! OSpipe REST API server binary.
//!
//! Starts the OSpipe HTTP server with a default pipeline configuration.
//! The server exposes semantic search, query routing, health, and stats endpoints.
//!
//! ## Usage
//!
//! ```bash
//! ospipe-server # default port 3030
//! ospipe-server --port 8080 # custom port
//! ospipe-server --data-dir /tmp/ospipe # custom data directory
//! ```
use std::sync::Arc;
use tokio::sync::RwLock;
/// Entry point: parse CLI flags, wire up logging, build the pipeline, and run
/// the HTTP server on a dedicated Tokio runtime until it exits or errors.
fn main() {
    // ---- CLI parsing (hand-rolled; no argument-parsing crate) ----
    let args: Vec<String> = std::env::args().collect();
    let mut port: u16 = 3030;
    let mut data_dir: Option<String> = None;

    let mut idx = 1;
    while idx < args.len() {
        match args[idx].as_str() {
            "--port" | "-p" => {
                let Some(raw) = args.get(idx + 1) else {
                    eprintln!("--port requires a value");
                    std::process::exit(1);
                };
                port = raw.parse().unwrap_or_else(|_| {
                    eprintln!("Invalid port: {}", raw);
                    std::process::exit(1);
                });
                idx += 2;
            }
            "--data-dir" | "-d" => {
                let Some(dir) = args.get(idx + 1) else {
                    eprintln!("--data-dir requires a value");
                    std::process::exit(1);
                };
                data_dir = Some(dir.clone());
                idx += 2;
            }
            "--help" | "-h" => {
                println!("OSpipe Server - RuVector-enhanced personal AI memory");
                println!();
                println!("Usage: ospipe-server [OPTIONS]");
                println!();
                println!("Options:");
                println!(" -p, --port <PORT> Listen port (default: 3030)");
                println!(" -d, --data-dir <PATH> Data directory (default: ~/.ospipe)");
                println!(" -h, --help Show this help message");
                println!(" -V, --version Show version");
                std::process::exit(0);
            }
            "--version" | "-V" => {
                println!("ospipe-server {}", env!("CARGO_PKG_VERSION"));
                std::process::exit(0);
            }
            other => {
                eprintln!("Unknown argument: {}", other);
                eprintln!("Run with --help for usage information");
                std::process::exit(1);
            }
        }
    }

    // ---- Logging: RUST_LOG if set, otherwise "info" ----
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();

    // ---- Configuration (defaults, optionally overriding the data dir) ----
    let mut config = ospipe::config::OsPipeConfig::default();
    if let Some(dir) = data_dir {
        config.data_dir = std::path::PathBuf::from(dir);
    }

    // ---- Pipeline + shared server state ----
    let pipeline =
        ospipe::pipeline::ingestion::IngestionPipeline::new(config).unwrap_or_else(|e| {
            eprintln!("Failed to initialize pipeline: {}", e);
            std::process::exit(1);
        });
    let state = ospipe::server::ServerState {
        pipeline: Arc::new(RwLock::new(pipeline)),
        router: Arc::new(ospipe::search::QueryRouter::new()),
        started_at: std::time::Instant::now(),
    };

    // ---- Run the server on a fresh Tokio runtime; block until it returns ----
    let rt = tokio::runtime::Runtime::new().unwrap_or_else(|e| {
        eprintln!("Failed to create Tokio runtime: {}", e);
        std::process::exit(1);
    });
    rt.block_on(async {
        tracing::info!("Starting OSpipe server on port {}", port);
        if let Err(e) = ospipe::server::start_server(state, port).await {
            eprintln!("Server error: {}", e);
            std::process::exit(1);
        }
    });
}

View File

@@ -0,0 +1,164 @@
//! Captured frame data structures.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A single captured frame from any Screenpipe source.
///
/// Frames are the unit of ingestion: each carries a UUID, a UTC capture
/// timestamp, its originating [`CaptureSource`], the extracted
/// [`FrameContent`], and free-form [`FrameMetadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapturedFrame {
    /// Unique identifier for this frame.
    pub id: Uuid,
    /// When this frame was captured.
    pub timestamp: DateTime<Utc>,
    /// The source that produced this frame.
    pub source: CaptureSource,
    /// The actual content of the frame.
    pub content: FrameContent,
    /// Additional metadata about the frame.
    pub metadata: FrameMetadata,
}
/// The source that produced a captured frame.
///
/// Serialized with serde's default externally-tagged enum representation
/// (`{"Screen": {...}}` etc. in JSON).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CaptureSource {
    /// Screen capture with OCR.
    Screen {
        /// Monitor index.
        monitor: u32,
        /// Foreground application name.
        app: String,
        /// Window title.
        window: String,
    },
    /// Audio capture with transcription.
    Audio {
        /// Audio device name.
        device: String,
        /// Detected speaker (if diarization is available).
        speaker: Option<String>,
    },
    /// UI accessibility event.
    Ui {
        /// Type of UI event (e.g., "click", "focus", "scroll").
        event_type: String,
    },
}
/// The actual content extracted from a captured frame.
///
/// Every variant wraps plain text; see [`CapturedFrame::text_content`]
/// for a variant-agnostic accessor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FrameContent {
    /// OCR text extracted from a screen capture.
    OcrText(String),
    /// Transcribed text from an audio capture.
    Transcription(String),
    /// A UI accessibility event description.
    UiEvent(String),
}
/// Metadata associated with a captured frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrameMetadata {
    /// Name of the foreground application, if known.
    pub app_name: Option<String>,
    /// Title of the active window, if known.
    pub window_title: Option<String>,
    /// Monitor index, if applicable.
    pub monitor_id: Option<u32>,
    /// Confidence score for the extracted content (0.0 to 1.0).
    /// Constructors assign fixed per-source values (0.9 OCR, 0.85 audio, 1.0 UI).
    pub confidence: f32,
    /// Detected language code (e.g., "en", "es"), if known.
    pub language: Option<String>,
}
impl CapturedFrame {
/// Create a new frame from a screen capture with OCR text.
pub fn new_screen(app: &str, window: &str, ocr_text: &str, monitor: u32) -> Self {
Self {
id: Uuid::new_v4(),
timestamp: Utc::now(),
source: CaptureSource::Screen {
monitor,
app: app.to_string(),
window: window.to_string(),
},
content: FrameContent::OcrText(ocr_text.to_string()),
metadata: FrameMetadata {
app_name: Some(app.to_string()),
window_title: Some(window.to_string()),
monitor_id: Some(monitor),
confidence: 0.9,
language: None,
},
}
}
/// Create a new frame from an audio transcription.
pub fn new_audio(device: &str, transcription: &str, speaker: Option<&str>) -> Self {
Self {
id: Uuid::new_v4(),
timestamp: Utc::now(),
source: CaptureSource::Audio {
device: device.to_string(),
speaker: speaker.map(|s| s.to_string()),
},
content: FrameContent::Transcription(transcription.to_string()),
metadata: FrameMetadata {
app_name: None,
window_title: None,
monitor_id: None,
confidence: 0.85,
language: None,
},
}
}
/// Create a new frame from a UI accessibility event.
pub fn new_ui_event(event_type: &str, description: &str) -> Self {
Self {
id: Uuid::new_v4(),
timestamp: Utc::now(),
source: CaptureSource::Ui {
event_type: event_type.to_string(),
},
content: FrameContent::UiEvent(description.to_string()),
metadata: FrameMetadata {
app_name: None,
window_title: None,
monitor_id: None,
confidence: 1.0,
language: None,
},
}
}
/// Extract the text content from this frame regardless of source type.
pub fn text_content(&self) -> &str {
match &self.content {
FrameContent::OcrText(text) => text,
FrameContent::Transcription(text) => text,
FrameContent::UiEvent(text) => text,
}
}
/// Return the content type as a string label.
pub fn content_type(&self) -> &str {
match &self.content {
FrameContent::OcrText(_) => "ocr",
FrameContent::Transcription(_) => "transcription",
FrameContent::UiEvent(_) => "ui_event",
}
}
}
impl Default for FrameMetadata {
fn default() -> Self {
Self {
app_name: None,
window_title: None,
monitor_id: None,
confidence: 0.0,
language: None,
}
}
}

View File

@@ -0,0 +1,9 @@
//! Capture module for processing screen, audio, and UI event data.
//!
//! This module defines the data structures that represent captured frames
//! from Screenpipe sources: OCR text from screen recordings, audio
//! transcriptions, and UI accessibility events.
pub mod frame;
pub use frame::{CaptureSource, CapturedFrame, FrameContent, FrameMetadata};

View File

@@ -0,0 +1,173 @@
//! Configuration types for all OSpipe subsystems.
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Top-level OSpipe configuration.
///
/// Aggregates per-subsystem configs; each has a `Default` impl, so
/// `OsPipeConfig::default()` yields a fully usable configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OsPipeConfig {
    /// Directory for persistent data storage.
    // NOTE(review): a leading `~` in this path is a literal component unless
    // the consumer expands it explicitly — confirm expansion happens before use.
    pub data_dir: PathBuf,
    /// Capture subsystem configuration.
    pub capture: CaptureConfig,
    /// Storage subsystem configuration.
    pub storage: StorageConfig,
    /// Search subsystem configuration.
    pub search: SearchConfig,
    /// Safety gate configuration.
    pub safety: SafetyConfig,
}
/// Configuration for the capture subsystem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CaptureConfig {
    /// Frames per second for screen capture. Default: 1.0
    pub fps: f32,
    /// Duration of audio chunks in seconds. Default: 30
    pub audio_chunk_secs: u32,
    /// Application names to exclude from capture.
    pub excluded_apps: Vec<String>,
    /// Whether to skip windows marked as private/incognito.
    pub skip_private_windows: bool,
}
/// Configuration for the vector storage subsystem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Dimensionality of embedding vectors. Default: 384
    pub embedding_dim: usize,
    /// HNSW M parameter (max connections per layer). Default: 32
    pub hnsw_m: usize,
    /// HNSW ef_construction parameter. Default: 200
    pub hnsw_ef_construction: usize,
    /// HNSW ef_search parameter. Default: 100
    pub hnsw_ef_search: usize,
    /// Cosine similarity threshold for deduplication. Default: 0.95
    pub dedup_threshold: f32,
    /// Quantization tiers for aging data, coarsest compression last.
    pub quantization_tiers: Vec<QuantizationTier>,
}
/// A quantization tier that defines how vectors are compressed based on age.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationTier {
    /// Age in hours after which this quantization is applied.
    pub age_hours: u64,
    /// The quantization method to use.
    pub method: QuantizationMethod,
}
/// Supported vector quantization methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationMethod {
    /// No quantization (full precision f32).
    None,
    /// Scalar quantization (int8).
    Scalar,
    /// Product quantization.
    Product,
    /// Binary quantization (1-bit per dimension).
    Binary,
}
/// Configuration for the search subsystem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    /// Default number of results to return. Default: 10
    pub default_k: usize,
    /// Weight for semantic vs keyword search in hybrid mode. Default: 0.7
    /// 1.0 = pure semantic, 0.0 = pure keyword.
    pub hybrid_weight: f32,
    /// MMR lambda for diversity vs relevance tradeoff. Default: 0.5
    pub mmr_lambda: f32,
    /// Whether to enable result reranking.
    pub rerank_enabled: bool,
}
/// Configuration for the safety gate subsystem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyConfig {
    /// Enable PII detection (names, emails, phone numbers).
    pub pii_detection: bool,
    /// Enable credit card number redaction.
    pub credit_card_redaction: bool,
    /// Enable SSN redaction.
    pub ssn_redaction: bool,
    /// Custom regex-like patterns to redact (simple substring matching).
    pub custom_patterns: Vec<String>,
}
impl Default for OsPipeConfig {
    /// Default configuration: data under `<home>/.ospipe`, subsystem defaults.
    ///
    /// Fix: the previous default was the literal path `~/.ospipe`, but `~` is
    /// a shell convention and is never expanded by the filesystem APIs, so
    /// using it verbatim creates a directory literally named `~` in the
    /// current working directory. Resolve the home directory explicitly
    /// (`HOME` on Unix, `USERPROFILE` on Windows) and keep the old literal
    /// only as a last-resort fallback when neither variable is set.
    fn default() -> Self {
        let data_dir = std::env::var_os("HOME")
            .or_else(|| std::env::var_os("USERPROFILE"))
            .map(|home| PathBuf::from(home).join(".ospipe"))
            .unwrap_or_else(|| PathBuf::from("~/.ospipe"));
        Self {
            data_dir,
            capture: CaptureConfig::default(),
            storage: StorageConfig::default(),
            search: SearchConfig::default(),
            safety: SafetyConfig::default(),
        }
    }
}
impl Default for CaptureConfig {
fn default() -> Self {
Self {
fps: 1.0,
audio_chunk_secs: 30,
excluded_apps: vec!["1Password".to_string(), "Keychain Access".to_string()],
skip_private_windows: true,
}
}
}
impl Default for StorageConfig {
fn default() -> Self {
Self {
embedding_dim: 384,
hnsw_m: 32,
hnsw_ef_construction: 200,
hnsw_ef_search: 100,
dedup_threshold: 0.95,
quantization_tiers: vec![
QuantizationTier {
age_hours: 0,
method: QuantizationMethod::None,
},
QuantizationTier {
age_hours: 24,
method: QuantizationMethod::Scalar,
},
QuantizationTier {
age_hours: 168, // 1 week
method: QuantizationMethod::Product,
},
QuantizationTier {
age_hours: 720, // 30 days
method: QuantizationMethod::Binary,
},
],
}
}
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
default_k: 10,
hybrid_weight: 0.7,
mmr_lambda: 0.5,
rerank_enabled: false,
}
}
}
impl Default for SafetyConfig {
fn default() -> Self {
Self {
pii_detection: true,
credit_card_redaction: true,
ssn_redaction: true,
custom_patterns: Vec::new(),
}
}
}

View File

@@ -0,0 +1,41 @@
//! Unified error types for OSpipe.
use thiserror::Error;
/// Top-level error type for all OSpipe operations.
///
/// Most variants wrap a human-readable message from the corresponding
/// subsystem. `Serde` converts automatically from [`serde_json::Error`]
/// via `#[from]`, so `?` works directly on JSON (de)serialization calls.
#[derive(Error, Debug)]
pub enum OsPipeError {
    /// An error occurred during screen/audio capture processing.
    #[error("Capture error: {0}")]
    Capture(String),
    /// An error occurred in the vector storage layer.
    #[error("Storage error: {0}")]
    Storage(String),
    /// An error occurred during search operations.
    #[error("Search error: {0}")]
    Search(String),
    /// An error occurred in the ingestion pipeline.
    #[error("Pipeline error: {0}")]
    Pipeline(String),
    /// The safety gate denied ingestion of content.
    #[error("Safety gate denied: {reason}")]
    SafetyDenied {
        /// Human-readable reason for denial.
        reason: String,
    },
    /// A configuration-related error.
    #[error("Configuration error: {0}")]
    Config(String),
    /// A JSON serialization or deserialization error.
    #[error("Serialization error: {0}")]
    Serde(#[from] serde_json::Error),
}
/// Convenience alias for `Result<T, OsPipeError>`.
pub type Result<T> = std::result::Result<T, OsPipeError>;

View File

@@ -0,0 +1,217 @@
//! Heuristic named-entity recognition (NER) for extracting entities from text.
//!
//! This module performs lightweight, regex-free entity extraction suitable for
//! processing screen captures and transcriptions. It recognises:
//!
//! - **URLs** (`https://...` / `http://...`)
//! - **Email addresses** (`user@domain.tld`)
//! - **Mentions** (`@handle`)
//! - **Capitalized phrases** (two or more consecutive capitalized words -> proper nouns)
/// Extract `(label, name)` pairs from free-form `text`.
///
/// Results are deduplicated per `(label, name)` pair and returned grouped by
/// kind, in this order:
/// - `"Url"` for HTTP(S) URLs
/// - `"Email"` for email-like patterns
/// - `"Mention"` for `@handle` patterns
/// - `"Person"` for capitalized multi-word phrases (heuristic proper noun)
pub fn extract_entities(text: &str) -> Vec<(String, String)> {
    // Strip common surrounding punctuation from a whitespace-split token.
    fn strip(word: &str) -> &str {
        word.trim_matches(|c: char| matches!(c, ',' | '.' | ')' | '(' | ';'))
    }

    let mut entities: Vec<(String, String)> = Vec::new();
    let mut seen = std::collections::HashSet::new();

    // Pass 1: URLs.
    for word in text.split_whitespace() {
        let cand = strip(word);
        let looks_like_url = (cand.starts_with("http://") || cand.starts_with("https://"))
            && cand.len() > 10;
        if looks_like_url && seen.insert(("Url", cand.to_string())) {
            entities.push(("Url".to_string(), cand.to_string()));
        }
    }

    // Pass 2: emails (also strip angle brackets, e.g. "<user@host>").
    for word in text.split_whitespace() {
        let cand = word
            .trim_matches(|c: char| matches!(c, ',' | '.' | ')' | '(' | ';' | '<' | '>'));
        if is_email_like(cand) && seen.insert(("Email", cand.to_string())) {
            entities.push(("Email".to_string(), cand.to_string()));
        }
    }

    // Pass 3: @mentions.
    for word in text.split_whitespace() {
        let cand = strip(word);
        if cand.starts_with('@') && cand.len() > 1 && seen.insert(("Mention", cand.to_string())) {
            entities.push(("Mention".to_string(), cand.to_string()));
        }
    }

    // Pass 4: capitalized multi-word phrases (heuristic proper nouns).
    for phrase in extract_capitalized_phrases(text) {
        if seen.insert(("Person", phrase.clone())) {
            entities.push(("Person".to_string(), phrase));
        }
    }

    entities
}
/// Returns `true` if `s` has the shape `local@domain.tld`.
///
/// The split is at the *first* `@`; both sides must be non-empty, the domain
/// must contain an interior `.`, and each side is restricted to a small
/// character set (alphanumeric plus a few punctuation characters).
fn is_email_like(s: &str) -> bool {
    let Some((local, domain)) = s.split_once('@') else {
        return false;
    };
    let local_ok = !local.is_empty()
        && local
            .chars()
            .all(|c| c.is_alphanumeric() || matches!(c, '.' | '_' | '-' | '+'));
    let domain_ok = !domain.is_empty()
        && domain.contains('.')
        && !domain.starts_with('.')
        && !domain.ends_with('.')
        && domain
            .chars()
            .all(|c| c.is_alphanumeric() || matches!(c, '.' | '-'));
    local_ok && domain_ok
}
/// Extract sequences of two or more consecutive capitalized words as likely
/// proper nouns. Filters out common sentence-starting words when they appear
/// alone at what looks like a sentence boundary.
///
/// Returned phrases are space-joined and stripped of surrounding
/// non-alphanumeric characters per word. Single capitalized words are never
/// returned (too noisy), and two-word phrases at a sentence start that begin
/// with an article/pronoun (see `is_common_starter`) are dropped.
fn extract_capitalized_phrases(text: &str) -> Vec<String> {
    let mut phrases = Vec::new();
    let words: Vec<&str> = text.split_whitespace().collect();
    let mut i = 0;
    while i < words.len() {
        // Skip words that start a sentence (preceded by nothing or a sentence-ending punctuation).
        // Trim non-alphanumerics from both ends so "(John" or "Smith," still counts.
        let word = words[i].trim_matches(|c: char| !c.is_alphanumeric());
        if is_capitalized(word) && word.len() > 1 {
            // Accumulate consecutive capitalized words.
            let start = i;
            let mut parts = vec![word.to_string()];
            i += 1;
            while i < words.len() {
                let next = words[i].trim_matches(|c: char| !c.is_alphanumeric());
                if is_capitalized(next) && next.len() > 1 {
                    parts.push(next.to_string());
                    i += 1;
                } else {
                    break;
                }
            }
            // Only take phrases of 2+ words (single capitalized words are too noisy).
            if parts.len() >= 2 {
                // Skip if the first word is at position 0 or follows a sentence terminator
                // and is a common article/pronoun. We still keep it if part of a longer
                // multi-word phrase that itself is capitalized.
                // (start == 0 short-circuits, so the wrapping_sub(1) lookup below is
                // only evaluated for start >= 1 and cannot wrap in practice.)
                let is_sentence_start = start == 0
                    || words.get(start.wrapping_sub(1)).is_some_and(|prev| {
                        prev.ends_with('.') || prev.ends_with('!') || prev.ends_with('?')
                    });
                if is_sentence_start && parts.len() == 2 && is_common_starter(&parts[0]) {
                    // Skip - likely just a sentence starting with "The Xyz" etc.
                } else {
                    let phrase = parts.join(" ");
                    phrases.push(phrase);
                }
            }
        } else {
            i += 1;
        }
    }
    phrases
}
/// Returns `true` if `word` begins with an uppercase character
/// (Unicode-aware, not just ASCII); `false` for the empty string.
fn is_capitalized(word: &str) -> bool {
    word.starts_with(char::is_uppercase)
}
/// Common sentence-starting words (articles and pronouns) that are not proper
/// nouns. Comparison is case-insensitive.
///
/// Uses `eq_ignore_ascii_case` against a fixed ASCII word list instead of
/// `word.to_lowercase()`, which allocated a fresh `String` on every call;
/// behavior is unchanged because every list entry is pure ASCII.
fn is_common_starter(word: &str) -> bool {
    const COMMON_STARTERS: [&str; 13] = [
        "the", "a", "an", "this", "that", "these", "those", "it", "i", "we", "they", "he", "she",
    ];
    COMMON_STARTERS
        .iter()
        .any(|starter| word.eq_ignore_ascii_case(starter))
}
// Unit tests for the heuristic entity extractor.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_extract_urls() {
        let entities =
            extract_entities("Visit https://example.com/page and http://foo.bar/baz for info.");
        let urls: Vec<_> = entities.iter().filter(|(l, _)| l == "Url").collect();
        assert_eq!(urls.len(), 2);
        // Trailing sentence punctuation must be stripped from each URL.
        assert_eq!(urls[0].1, "https://example.com/page");
        assert_eq!(urls[1].1, "http://foo.bar/baz");
    }
    #[test]
    fn test_extract_emails() {
        let entities = extract_entities("Email alice@example.com or bob@company.org for help.");
        let emails: Vec<_> = entities.iter().filter(|(l, _)| l == "Email").collect();
        assert_eq!(emails.len(), 2);
    }
    #[test]
    fn test_extract_mentions() {
        // '-' is alphanumeric-adjacent but not in the trim set, so "@bob-dev" survives intact.
        let entities = extract_entities("Hey @alice and @bob-dev, check this out.");
        let mentions: Vec<_> = entities.iter().filter(|(l, _)| l == "Mention").collect();
        assert_eq!(mentions.len(), 2);
        assert_eq!(mentions[0].1, "@alice");
        assert_eq!(mentions[1].1, "@bob-dev");
    }
    #[test]
    fn test_extract_capitalized_phrases() {
        let entities = extract_entities("I met John Smith at the World Trade Center yesterday.");
        let persons: Vec<_> = entities.iter().filter(|(l, _)| l == "Person").collect();
        assert!(persons.iter().any(|(_, n)| n == "John Smith"));
        assert!(persons.iter().any(|(_, n)| n == "World Trade Center"));
    }
    #[test]
    fn test_no_false_positives_on_sentence_start() {
        let entities = extract_entities("The cat sat on the mat.");
        let persons: Vec<_> = entities.iter().filter(|(l, _)| l == "Person").collect();
        // "The cat" should not appear as a person (single cap word + lowercase).
        assert!(persons.is_empty());
    }
    #[test]
    fn test_deduplication() {
        // Identical (label, name) pairs must be reported only once.
        let entities = extract_entities("Visit https://example.com and https://example.com again.");
        let urls: Vec<_> = entities.iter().filter(|(l, _)| l == "Url").collect();
        assert_eq!(urls.len(), 1);
    }
}

View File

@@ -0,0 +1,359 @@
//! Knowledge graph integration for OSpipe.
//!
//! Provides entity extraction from captured text and stores entity relationships
//! in a [`ruvector_graph::GraphDB`] (native) or a lightweight in-memory stub (WASM).
//!
//! ## Usage
//!
//! ```rust,no_run
//! use ospipe::graph::KnowledgeGraph;
//!
//! let mut kg = KnowledgeGraph::new();
//! let ids = kg.ingest_frame_entities("frame-001", "Meeting with John Smith at https://meet.example.com").unwrap();
//! let people = kg.find_by_label("Person");
//! ```
pub mod entity_extractor;
use crate::error::Result;
use std::collections::HashMap;
/// A lightweight entity representation returned by query methods.
///
/// Backend-agnostic: both the native (`ruvector_graph`) and WASM (in-memory)
/// implementations convert their node types into this struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entity {
    /// Unique identifier for this entity.
    pub id: String,
    /// Category label (e.g. "Person", "Url", "Mention", "Email", "Frame").
    pub label: String,
    /// Human-readable name or value.
    pub name: String,
    /// Additional key-value properties (excluding "name", which is lifted
    /// into the `name` field).
    pub properties: HashMap<String, String>,
}
// ---------------------------------------------------------------------------
// Native implementation (backed by ruvector-graph)
// ---------------------------------------------------------------------------
#[cfg(not(target_arch = "wasm32"))]
mod inner {
    use super::*;
    use crate::error::OsPipeError;
    use ruvector_graph::{EdgeBuilder, GraphDB, NodeBuilder, PropertyValue};
    /// A knowledge graph that stores entity relationships extracted from captured
    /// frames. On native targets this is backed by [`ruvector_graph::GraphDB`].
    pub struct KnowledgeGraph {
        db: GraphDB,
    }
    impl KnowledgeGraph {
        /// Create a new, empty knowledge graph.
        pub fn new() -> Self {
            Self { db: GraphDB::new() }
        }
        /// Add an entity node to the graph.
        ///
        /// Returns the newly created node ID.
        // NOTE(review): takes `&self` (GraphDB::create_node is called through a
        // shared reference, so it presumably uses interior mutability), while the
        // wasm stub's add_entity takes `&mut self` — confirm cross-platform
        // callers compile against both signatures.
        pub fn add_entity(
            &self,
            label: &str,
            name: &str,
            properties: HashMap<String, String>,
        ) -> Result<String> {
            let mut builder = NodeBuilder::new().label(label).property("name", name);
            for (k, v) in &properties {
                builder = builder.property(k.as_str(), v.as_str());
            }
            let node = builder.build();
            let id = self
                .db
                .create_node(node)
                .map_err(|e| OsPipeError::Storage(format!("graph: {}", e)))?;
            Ok(id)
        }
        /// Create a directed relationship (edge) between two entities.
        ///
        /// Both `from_id` and `to_id` must refer to existing nodes.
        /// Returns the edge ID.
        pub fn add_relationship(
            &self,
            from_id: &str,
            to_id: &str,
            rel_type: &str,
        ) -> Result<String> {
            let edge = EdgeBuilder::new(from_id.to_string(), to_id.to_string(), rel_type).build();
            let id = self
                .db
                .create_edge(edge)
                .map_err(|e| OsPipeError::Storage(format!("graph: {}", e)))?;
            Ok(id)
        }
        /// Find all entities that carry `label`.
        pub fn find_by_label(&self, label: &str) -> Vec<Entity> {
            self.db
                .get_nodes_by_label(label)
                .into_iter()
                .map(|n| node_to_entity(&n))
                .collect()
        }
        /// Find all entities directly connected to `entity_id` (both outgoing and
        /// incoming edges). Each neighbour appears at most once per direction;
        /// the `seen` set deduplicates across both edge scans.
        pub fn neighbors(&self, entity_id: &str) -> Vec<Entity> {
            let mut seen = std::collections::HashSet::new();
            let mut result = Vec::new();
            let node_id = entity_id.to_string();
            // Outgoing neighbours.
            for edge in self.db.get_outgoing_edges(&node_id) {
                if seen.insert(edge.to.clone()) {
                    if let Some(node) = self.db.get_node(&edge.to) {
                        result.push(node_to_entity(&node));
                    }
                }
            }
            // Incoming neighbours.
            for edge in self.db.get_incoming_edges(&node_id) {
                if seen.insert(edge.from.clone()) {
                    if let Some(node) = self.db.get_node(&edge.from) {
                        result.push(node_to_entity(&node));
                    }
                }
            }
            result
        }
        /// Run heuristic NER on `text` and return extracted `(label, name)` pairs.
        pub fn extract_entities(text: &str) -> Vec<(String, String)> {
            entity_extractor::extract_entities(text)
        }
        /// Extract entities from `text`, create nodes for each, link them to the
        /// given `frame_id` node (creating the frame node if it does not yet exist),
        /// and return the IDs of all newly created entity nodes.
        ///
        /// Each entity is linked from the frame via a "CONTAINS" edge.
        // NOTE(review): entities are not deduplicated against existing graph
        // nodes — ingesting the same text twice creates duplicate entity nodes.
        // Confirm whether that is intended.
        pub fn ingest_frame_entities(&self, frame_id: &str, text: &str) -> Result<Vec<String>> {
            // Ensure frame node exists.
            let frame_node_id = if self.db.get_node(frame_id).is_some() {
                frame_id.to_string()
            } else {
                let node = NodeBuilder::new()
                    .id(frame_id)
                    .label("Frame")
                    .property("name", frame_id)
                    .build();
                self.db
                    .create_node(node)
                    .map_err(|e| OsPipeError::Storage(format!("graph: {}", e)))?
            };
            let extracted = entity_extractor::extract_entities(text);
            let mut entity_ids = Vec::with_capacity(extracted.len());
            for (label, name) in &extracted {
                let entity_id = self.add_entity(label, name, HashMap::new())?;
                self.add_relationship(&frame_node_id, &entity_id, "CONTAINS")?;
                entity_ids.push(entity_id);
            }
            Ok(entity_ids)
        }
    }
    impl Default for KnowledgeGraph {
        fn default() -> Self {
            Self::new()
        }
    }
    /// Convert a `ruvector_graph::Node` into the crate-public `Entity` type.
    ///
    /// The first label (if any) becomes `Entity::label`; the "name" property is
    /// lifted into `Entity::name`; all remaining properties are stringified.
    fn node_to_entity(node: &ruvector_graph::Node) -> Entity {
        let label = node
            .labels
            .first()
            .map_or_else(String::new, |l| l.name.clone());
        let name = match node.get_property("name") {
            Some(PropertyValue::String(s)) => s.clone(),
            _ => String::new(),
        };
        let mut properties = HashMap::new();
        for (k, v) in &node.properties {
            if k == "name" {
                continue;
            }
            let v_str = match v {
                PropertyValue::String(s) => s.clone(),
                PropertyValue::Integer(i) => i.to_string(),
                PropertyValue::Float(f) => f.to_string(),
                PropertyValue::Boolean(b) => b.to_string(),
                // Fallback for property kinds without a natural string form.
                _ => format!("{:?}", v),
            };
            properties.insert(k.clone(), v_str);
        }
        Entity {
            id: node.id.clone(),
            label,
            name,
            properties,
        }
    }
}
// ---------------------------------------------------------------------------
// WASM fallback (lightweight in-memory stub)
// ---------------------------------------------------------------------------
#[cfg(target_arch = "wasm32")]
mod inner {
    use super::*;
    /// In-memory node record for the WASM stub.
    struct StoredNode {
        id: String,
        label: String,
        name: String,
        properties: HashMap<String, String>,
    }
    /// In-memory edge record; `_`-prefixed fields are kept for parity with the
    /// native backend but are not read by any query here.
    struct StoredEdge {
        _id: String,
        from: String,
        to: String,
        _rel_type: String,
    }
    /// A knowledge graph backed by simple `Vec` storage for WASM targets.
    ///
    /// Mirrors the native API surface, except mutating methods take `&mut self`
    /// (the native backend takes `&self`).
    pub struct KnowledgeGraph {
        nodes: Vec<StoredNode>,
        edges: Vec<StoredEdge>,
        // Monotonic counter used to mint both node and edge IDs.
        next_id: u64,
    }
    impl KnowledgeGraph {
        /// Create a new, empty knowledge graph.
        pub fn new() -> Self {
            Self {
                nodes: Vec::new(),
                edges: Vec::new(),
                next_id: 0,
            }
        }
        /// Add an entity node; returns the generated node ID (`wasm-<n>`).
        pub fn add_entity(
            &mut self,
            label: &str,
            name: &str,
            properties: HashMap<String, String>,
        ) -> Result<String> {
            let id = format!("wasm-{}", self.next_id);
            self.next_id += 1;
            self.nodes.push(StoredNode {
                id: id.clone(),
                label: label.to_string(),
                name: name.to_string(),
                properties,
            });
            Ok(id)
        }
        /// Create a directed edge between two entities; returns the edge ID
        /// (`wasm-e-<n>`). Endpoint existence is not validated here.
        pub fn add_relationship(
            &mut self,
            from_id: &str,
            to_id: &str,
            rel_type: &str,
        ) -> Result<String> {
            let id = format!("wasm-e-{}", self.next_id);
            self.next_id += 1;
            self.edges.push(StoredEdge {
                _id: id.clone(),
                from: from_id.to_string(),
                to: to_id.to_string(),
                _rel_type: rel_type.to_string(),
            });
            Ok(id)
        }
        /// Find all entities that carry `label`.
        pub fn find_by_label(&self, label: &str) -> Vec<Entity> {
            self.nodes
                .iter()
                .filter(|n| n.label == label)
                .map(|n| Entity {
                    id: n.id.clone(),
                    label: n.label.clone(),
                    name: n.name.clone(),
                    properties: n.properties.clone(),
                })
                .collect()
        }
        /// Find all entities directly connected to `entity_id` in either
        /// direction (deduplicated via the `ids` set).
        pub fn neighbors(&self, entity_id: &str) -> Vec<Entity> {
            let mut ids = std::collections::HashSet::new();
            for e in &self.edges {
                if e.from == entity_id {
                    ids.insert(e.to.clone());
                }
                if e.to == entity_id {
                    ids.insert(e.from.clone());
                }
            }
            self.nodes
                .iter()
                .filter(|n| ids.contains(&n.id))
                .map(|n| Entity {
                    id: n.id.clone(),
                    label: n.label.clone(),
                    name: n.name.clone(),
                    properties: n.properties.clone(),
                })
                .collect()
        }
        /// Run heuristic NER on `text` and return extracted `(label, name)` pairs.
        pub fn extract_entities(text: &str) -> Vec<(String, String)> {
            entity_extractor::extract_entities(text)
        }
        /// Extract entities from `text`, create nodes for each, link them to the
        /// `frame_id` node via "CONTAINS" edges (creating the frame node if
        /// missing), and return the new entity node IDs.
        pub fn ingest_frame_entities(&mut self, frame_id: &str, text: &str) -> Result<Vec<String>> {
            // Ensure frame node.
            let frame_exists = self.nodes.iter().any(|n| n.id == frame_id);
            let frame_node_id = if frame_exists {
                frame_id.to_string()
            } else {
                let id = frame_id.to_string();
                self.nodes.push(StoredNode {
                    id: id.clone(),
                    label: "Frame".to_string(),
                    name: frame_id.to_string(),
                    properties: HashMap::new(),
                });
                id
            };
            let extracted = entity_extractor::extract_entities(text);
            let mut entity_ids = Vec::with_capacity(extracted.len());
            for (label, name) in &extracted {
                let eid = self.add_entity(label, name, HashMap::new())?;
                self.add_relationship(&frame_node_id, &eid, "CONTAINS")?;
                entity_ids.push(eid);
            }
            Ok(entity_ids)
        }
    }
    impl Default for KnowledgeGraph {
        fn default() -> Self {
            Self::new()
        }
    }
}
// Re-export the platform-appropriate implementation.
pub use inner::KnowledgeGraph;

View File

@@ -0,0 +1,327 @@
//! Continual learning for search improvement.
//!
//! This module integrates `ruvector-gnn` to provide:
//!
//! - **[`SearchLearner`]** -- records user relevance feedback and uses Elastic
//! Weight Consolidation (EWC) to prevent catastrophic forgetting when the
//! embedding model is fine-tuned over time.
//! - **[`EmbeddingQuantizer`]** -- compresses stored embeddings based on their
//! age, trading precision for storage savings on cold data.
//!
//! Both structs compile to no-op stubs on `wasm32` targets where the native
//! `ruvector-gnn` crate is unavailable.
// ---------------------------------------------------------------------------
// Native implementation (non-WASM)
// ---------------------------------------------------------------------------
#[cfg(not(target_arch = "wasm32"))]
mod native {
use ruvector_gnn::compress::TensorCompress;
use ruvector_gnn::ewc::ElasticWeightConsolidation;
use ruvector_gnn::replay::ReplayBuffer;
/// Minimum number of feedback entries before learning data is considered
/// sufficient for a consolidation step.
const MIN_FEEDBACK_ENTRIES: usize = 32;
/// Records search relevance feedback and manages continual-learning state.
///
/// Internally the learner maintains:
/// - A [`ReplayBuffer`] that stores (query, result, relevance) triples via
/// reservoir sampling so old feedback is not forgotten.
/// - An [`ElasticWeightConsolidation`] instance whose Fisher diagonal and
/// anchor weights track which embedding dimensions are important.
/// - A simple parameter vector (`weights`) that represents a learned
/// relevance projection (one weight per embedding dimension).
pub struct SearchLearner {
replay_buffer: ReplayBuffer,
ewc: ElasticWeightConsolidation,
/// Learned relevance-projection weights (one per embedding dimension).
weights: Vec<f32>,
}
impl SearchLearner {
/// Create a new learner.
///
/// # Arguments
/// * `embedding_dim` - Dimensionality of the embedding vectors.
/// * `replay_capacity` - Maximum number of feedback entries retained.
pub fn new(embedding_dim: usize, replay_capacity: usize) -> Self {
Self {
replay_buffer: ReplayBuffer::new(replay_capacity),
ewc: ElasticWeightConsolidation::new(100.0),
weights: vec![1.0; embedding_dim],
}
}
/// Record a single piece of user feedback.
///
/// The query and result embeddings are concatenated and stored in the
/// replay buffer. Positive feedback entries use `positive_ids = [1]`,
/// negative ones use `positive_ids = [0]`, which allows downstream
/// training loops to distinguish them.
///
/// # Arguments
/// * `query_embedding` - Embedding of the search query.
/// * `result_embedding` - Embedding of the search result.
/// * `relevant` - Whether the user considered the result relevant.
pub fn record_feedback(
&mut self,
query_embedding: Vec<f32>,
result_embedding: Vec<f32>,
relevant: bool,
) {
let mut combined = query_embedding;
combined.extend_from_slice(&result_embedding);
let positive_id: usize = if relevant { 1 } else { 0 };
self.replay_buffer.add(&combined, &[positive_id]);
}
/// Return the current size of the replay buffer.
pub fn replay_buffer_len(&self) -> usize {
self.replay_buffer.len()
}
/// Returns `true` when the buffer contains enough data for a
/// meaningful consolidation step (>= 32 entries).
pub fn has_sufficient_data(&self) -> bool {
self.replay_buffer.len() >= MIN_FEEDBACK_ENTRIES
}
/// Lock the current parameter state with EWC.
///
/// This computes the Fisher information diagonal from sampled replay
/// entries and saves the current weights as the EWC anchor. Future
/// EWC penalties will discourage large deviations from these weights.
///
/// No-op when the replay buffer is empty, or when no sampled entry is
/// long enough to split into a `[query || result]` pair.
pub fn consolidate(&mut self) {
if self.replay_buffer.is_empty() {
return;
}
// Sample gradients -- we approximate them as the difference between
// query and result portions of each stored entry. Sampling is capped
// at 64 entries to bound the cost of the Fisher estimate.
let samples = self.replay_buffer.sample(self.replay_buffer.len().min(64));
let dim = self.weights.len();
let gradients: Vec<Vec<f32>> = samples
.iter()
.filter_map(|entry| {
// Each entry stores [query || result]; extract gradient proxy.
// NOTE(review): assumes the combined vector recorded by
// `record_feedback` is exposed via `entry.query` -- confirm
// against the ReplayBuffer entry layout.
if entry.query.len() >= dim * 2 {
let query_part = &entry.query[..dim];
let result_part = &entry.query[dim..dim * 2];
let grad: Vec<f32> = query_part
.iter()
.zip(result_part.iter())
.map(|(q, r)| q - r)
.collect();
Some(grad)
} else {
// Entries with a mismatched dimensionality are silently skipped.
None
}
})
.collect();
if gradients.is_empty() {
return;
}
let grad_refs: Vec<&[f32]> = gradients.iter().map(|g| g.as_slice()).collect();
let sample_count = grad_refs.len();
// Estimate the Fisher diagonal from the gradient proxies, then anchor
// the current weights so future penalties measure drift from here.
self.ewc.compute_fisher(&grad_refs, sample_count);
self.ewc.consolidate(&self.weights);
}
/// Return the current EWC penalty for the learned weights.
///
/// Returns `0.0` if [`consolidate`](Self::consolidate) has not been
/// called yet. The penalty is computed by the underlying
/// `ElasticWeightConsolidation` against its stored anchor.
pub fn ewc_penalty(&self) -> f32 {
self.ewc.penalty(&self.weights)
}
}
// -----------------------------------------------------------------------
// EmbeddingQuantizer
// -----------------------------------------------------------------------
/// Age-aware embedding quantizer backed by [`TensorCompress`].
///
/// Older embeddings are compressed more aggressively:
///
/// | Age            | Compression          |
/// |----------------|----------------------|
/// | < 1 hour       | Full precision       |
/// | 1 h -- 24 h    | Half precision (FP16)|
/// | 1 d -- 7 d     | PQ8                  |
/// | > 7 d          | Binary               |
pub struct EmbeddingQuantizer {
// Underlying adaptive compressor; the codec is selected from an
// access-frequency hint derived from the embedding's age.
compressor: TensorCompress,
}
impl Default for EmbeddingQuantizer {
/// Equivalent to [`EmbeddingQuantizer::new`].
fn default() -> Self {
Self::new()
}
}
impl EmbeddingQuantizer {
    /// Create a new quantizer instance.
    pub fn new() -> Self {
        Self {
            compressor: TensorCompress::new(),
        }
    }
    /// Encode an `f32` slice as raw little-endian bytes.
    ///
    /// This is the fallback wire format used whenever compression or
    /// serialisation fails; [`dequantize`](Self::dequantize) understands it.
    fn raw_le_bytes(embedding: &[f32]) -> Vec<u8> {
        embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
    }
    /// Compress an embedding based on its age.
    ///
    /// The age determines the access-frequency proxy passed to the
    /// underlying `TensorCompress`:
    /// - `< 1 h`  -> freq `1.0`   (no compression)
    /// - `1-24 h` -> freq `0.5`   (half precision)
    /// - `1-7 d`  -> freq `0.2`   (PQ8)
    /// - `> 7 d`  -> freq `0.005` (binary)
    ///
    /// # Arguments
    /// * `embedding` - The raw embedding vector.
    /// * `age_hours` - Age of the embedding in hours.
    ///
    /// # Returns
    /// Serialised compressed bytes. Use [`dequantize`](Self::dequantize)
    /// to recover the original (lossy) vector.
    pub fn quantize_by_age(&self, embedding: &[f32], age_hours: u64) -> Vec<u8> {
        let access_freq = Self::age_to_freq(age_hours);
        // Compression and JSON serialisation are both best-effort; any
        // failure degrades gracefully to the raw little-endian format.
        self.compressor
            .compress(embedding, access_freq)
            .ok()
            .and_then(|compressed| serde_json::to_vec(&compressed).ok())
            .unwrap_or_else(|| Self::raw_le_bytes(embedding))
    }
    /// Decompress bytes produced by [`quantize_by_age`](Self::quantize_by_age).
    ///
    /// # Arguments
    /// * `data` - Compressed byte representation.
    /// * `original_dim` - Expected dimensionality of the output vector.
    ///
    /// # Returns
    /// The decompressed embedding (lossy). If decompression fails, a
    /// zero-vector of `original_dim` length is returned.
    pub fn dequantize(&self, data: &[u8], original_dim: usize) -> Vec<f32> {
        // Preferred path: a JSON-serialised `CompressedTensor`. The result
        // is only accepted when it matches the expected dimensionality.
        let decoded = serde_json::from_slice::<ruvector_gnn::compress::CompressedTensor>(data)
            .ok()
            .and_then(|compressed| self.compressor.decompress(&compressed).ok())
            .filter(|v| v.len() == original_dim);
        if let Some(v) = decoded {
            return v;
        }
        // Fallback: interpret the payload as raw little-endian f32 bytes.
        if data.len() == original_dim * 4 {
            return data
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();
        }
        // Unknown format: return a zero vector of the requested size.
        vec![0.0; original_dim]
    }
    /// Map an age in hours to an access-frequency proxy in [0, 1].
    fn age_to_freq(age_hours: u64) -> f32 {
        if age_hours == 0 {
            1.0 // Fresh -- full precision
        } else if age_hours <= 24 {
            0.5 // Warm -- half precision
        } else if age_hours <= 168 {
            0.2 // Cool -- PQ8
        } else {
            0.005 // Cold -- binary
        }
    }
}
}
// ---------------------------------------------------------------------------
// WASM stub implementation
// ---------------------------------------------------------------------------
#[cfg(target_arch = "wasm32")]
mod wasm_stub {
    /// No-op search learner for WASM targets.
    ///
    /// Mirrors the native `SearchLearner` API. The entry counter is
    /// capped at the configured replay capacity so `replay_buffer_len`
    /// and `has_sufficient_data` report values consistent with the
    /// bounded native replay buffer (the original stub grew without
    /// bound and ignored `replay_capacity`).
    pub struct SearchLearner {
        // Number of feedback entries "stored"; never exceeds `capacity`.
        buffer_len: usize,
        // Maximum entries the (virtual) replay buffer may hold.
        capacity: usize,
    }
    impl SearchLearner {
        /// Create a stub learner; `_embedding_dim` is unused on WASM.
        pub fn new(_embedding_dim: usize, replay_capacity: usize) -> Self {
            Self {
                buffer_len: 0,
                capacity: replay_capacity,
            }
        }
        /// Count a feedback entry, saturating at the replay capacity
        /// (matching the eviction-bounded size of the native buffer).
        pub fn record_feedback(
            &mut self,
            _query_embedding: Vec<f32>,
            _result_embedding: Vec<f32>,
            _relevant: bool,
        ) {
            if self.buffer_len < self.capacity {
                self.buffer_len += 1;
            }
        }
        /// Current (virtual) replay-buffer size.
        pub fn replay_buffer_len(&self) -> usize {
            self.buffer_len
        }
        /// `true` once at least 32 entries are retained, mirroring the
        /// native minimum-feedback threshold.
        pub fn has_sufficient_data(&self) -> bool {
            self.buffer_len >= 32
        }
        /// No-op on WASM: there is no EWC state to consolidate.
        pub fn consolidate(&mut self) {}
        /// Always `0.0` on WASM: no EWC anchor exists.
        pub fn ewc_penalty(&self) -> f32 {
            0.0
        }
    }
    /// No-op embedding quantizer for WASM targets.
    ///
    /// Returns the original embedding bytes without compression.
    pub struct EmbeddingQuantizer;
    // Parity with the native `EmbeddingQuantizer`, which implements
    // `Default`; without this, `EmbeddingQuantizer::default()` callers
    // would fail to compile on WASM.
    impl Default for EmbeddingQuantizer {
        fn default() -> Self {
            Self::new()
        }
    }
    impl EmbeddingQuantizer {
        /// Create a new stub quantizer.
        pub fn new() -> Self {
            Self
        }
        /// Store the embedding as raw little-endian f32 bytes; age is ignored.
        pub fn quantize_by_age(&self, embedding: &[f32], _age_hours: u64) -> Vec<u8> {
            embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
        }
        /// Recover an embedding from raw little-endian f32 bytes.
        ///
        /// Returns a zero vector when `data` does not contain exactly
        /// `original_dim * 4` bytes.
        pub fn dequantize(&self, data: &[u8], original_dim: usize) -> Vec<f32> {
            if data.len() == original_dim * 4 {
                data.chunks_exact(4)
                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                    .collect()
            } else {
                vec![0.0; original_dim]
            }
        }
    }
}
// ---------------------------------------------------------------------------
// Re-exports
// ---------------------------------------------------------------------------
#[cfg(not(target_arch = "wasm32"))]
pub use native::{EmbeddingQuantizer, SearchLearner};
#[cfg(target_arch = "wasm32")]
pub use wasm_stub::{EmbeddingQuantizer, SearchLearner};

View File

@@ -0,0 +1,43 @@
//! # OSpipe
//!
//! RuVector-enhanced personal AI memory system integrating with Screenpipe.
//!
//! OSpipe captures screen content, audio transcriptions, and UI events,
//! processes them through a safety-aware ingestion pipeline, and stores
//! them as searchable vector embeddings for personal AI memory recall.
//!
//! ## Architecture
//!
//! ```text
//! Screenpipe -> Capture -> Safety Gate -> Dedup -> Embed -> VectorStore
//! |
//! Search Router <--------+
//! (Semantic / Keyword / Hybrid)
//! ```
//!
//! ## Modules
//!
//! - [`capture`] - Captured frame data structures (OCR, transcription, UI events)
//! - [`storage`] - HNSW-backed vector storage and embedding engine
//! - [`search`] - Query routing and hybrid search (semantic + keyword)
//! - [`pipeline`] - Ingestion pipeline with deduplication
//! - [`safety`] - PII detection and content redaction
//! - [`config`] - Configuration for all subsystems
//! - [`error`] - Unified error types
pub mod capture;
pub mod config;
pub mod error;
pub mod graph;
pub mod learning;
#[cfg(not(target_arch = "wasm32"))]
pub mod persistence;
pub mod pipeline;
pub mod quantum;
pub mod safety;
pub mod search;
#[cfg(not(target_arch = "wasm32"))]
pub mod server;
pub mod storage;
pub mod wasm;

View File

@@ -0,0 +1,319 @@
//! JSON-file persistence layer for OSpipe data.
//!
//! Provides durable storage of frames, configuration, and embedding data
//! using the local filesystem. All data is serialized to JSON (frames and
//! config) or raw bytes (embeddings) inside a configurable data directory.
//!
//! This module is gated behind `cfg(not(target_arch = "wasm32"))` because
//! WASM targets do not have filesystem access.
use crate::capture::CapturedFrame;
use crate::config::OsPipeConfig;
use crate::error::{OsPipeError, Result};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// A serializable wrapper around [`CapturedFrame`] for disk persistence.
///
/// This mirrors all fields of `CapturedFrame` but is kept as a distinct
/// type so the persistence format can evolve independently of the
/// in-memory representation. Instances are serialized as JSON inside
/// `frames.json` by [`PersistenceLayer::save_frames`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredFrame {
/// The captured frame data.
pub frame: CapturedFrame,
/// Optional text that was stored after safety-gate processing.
/// If `None`, the original frame text was used unchanged.
pub safe_text: Option<String>,
}
/// Filesystem-backed persistence for OSpipe data.
///
/// All files are written inside `data_dir`:
/// - `frames.json` - serialized vector of [`StoredFrame`]
/// - `config.json` - serialized [`OsPipeConfig`]
/// - `embeddings.bin` - raw bytes (e.g. HNSW index serialization)
pub struct PersistenceLayer {
// Root directory for all persisted files; created (with parents) by
// [`PersistenceLayer::new`] if missing.
data_dir: PathBuf,
}
impl PersistenceLayer {
/// Create a new persistence layer rooted at `data_dir`.
///
/// The directory (and any missing parents) will be created if they
/// do not already exist.
pub fn new(data_dir: PathBuf) -> Result<Self> {
std::fs::create_dir_all(&data_dir).map_err(|e| {
OsPipeError::Storage(format!(
"Failed to create data directory {}: {}",
data_dir.display(),
e
))
})?;
Ok(Self { data_dir })
}
/// Return the path to a named file inside the data directory.
fn file_path(&self, name: &str) -> PathBuf {
self.data_dir.join(name)
}
// ---- Frames ----
/// Persist a slice of stored frames to `frames.json`.
pub fn save_frames(&self, frames: &[StoredFrame]) -> Result<()> {
let path = self.file_path("frames.json");
let json = serde_json::to_string_pretty(frames)?;
std::fs::write(&path, json).map_err(|e| {
OsPipeError::Storage(format!(
"Failed to write frames to {}: {}",
path.display(),
e
))
})
}
/// Load stored frames from `frames.json`.
///
/// Returns an empty vector if the file does not exist.
pub fn load_frames(&self) -> Result<Vec<StoredFrame>> {
let path = self.file_path("frames.json");
if !path.exists() {
return Ok(Vec::new());
}
let data = std::fs::read_to_string(&path).map_err(|e| {
OsPipeError::Storage(format!(
"Failed to read frames from {}: {}",
path.display(),
e
))
})?;
let frames: Vec<StoredFrame> = serde_json::from_str(&data)?;
Ok(frames)
}
// ---- Config ----
/// Persist the pipeline configuration to `config.json`.
pub fn save_config(&self, config: &OsPipeConfig) -> Result<()> {
let path = self.file_path("config.json");
let json = serde_json::to_string_pretty(config)?;
std::fs::write(&path, json).map_err(|e| {
OsPipeError::Storage(format!(
"Failed to write config to {}: {}",
path.display(),
e
))
})
}
/// Load the pipeline configuration from `config.json`.
///
/// Returns `None` if the file does not exist.
pub fn load_config(&self) -> Result<Option<OsPipeConfig>> {
let path = self.file_path("config.json");
if !path.exists() {
return Ok(None);
}
let data = std::fs::read_to_string(&path).map_err(|e| {
OsPipeError::Storage(format!(
"Failed to read config from {}: {}",
path.display(),
e
))
})?;
let config: OsPipeConfig = serde_json::from_str(&data)?;
Ok(Some(config))
}
// ---- Embeddings (raw bytes) ----
/// Persist raw embedding bytes to `embeddings.bin`.
///
/// This is intended for serializing an HNSW index or other binary
/// data that does not fit the JSON format.
pub fn save_embeddings(&self, data: &[u8]) -> Result<()> {
let path = self.file_path("embeddings.bin");
std::fs::write(&path, data).map_err(|e| {
OsPipeError::Storage(format!(
"Failed to write embeddings to {}: {}",
path.display(),
e
))
})
}
/// Load raw embedding bytes from `embeddings.bin`.
///
/// Returns `None` if the file does not exist.
pub fn load_embeddings(&self) -> Result<Option<Vec<u8>>> {
let path = self.file_path("embeddings.bin");
if !path.exists() {
return Ok(None);
}
let data = std::fs::read(&path).map_err(|e| {
OsPipeError::Storage(format!(
"Failed to read embeddings from {}: {}",
path.display(),
e
))
})?;
Ok(Some(data))
}
/// Return the data directory path.
pub fn data_dir(&self) -> &PathBuf {
&self.data_dir
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::capture::CapturedFrame;
/// Create a fresh, uniquely named directory under the system temp dir.
///
/// The UUID suffix keeps parallel test runs from colliding; each test
/// removes its own directory when it finishes.
fn temp_dir() -> PathBuf {
let dir = std::env::temp_dir().join(format!("ospipe_test_{}", uuid::Uuid::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
dir
}
/// Frames written by `save_frames` come back unchanged from `load_frames`.
#[test]
fn test_frames_roundtrip() {
let dir = temp_dir();
let layer = PersistenceLayer::new(dir.clone()).unwrap();
let frame = CapturedFrame::new_screen("VSCode", "main.rs", "fn main() {}", 0);
let stored = vec![StoredFrame {
frame,
safe_text: None,
}];
layer.save_frames(&stored).unwrap();
let loaded = layer.load_frames().unwrap();
assert_eq!(loaded.len(), 1);
assert_eq!(loaded[0].frame.text_content(), "fn main() {}");
assert!(loaded[0].safe_text.is_none());
// Cleanup
let _ = std::fs::remove_dir_all(&dir);
}
/// A missing `frames.json` yields an empty vector, not an error.
#[test]
fn test_frames_empty_when_missing() {
let dir = temp_dir();
let layer = PersistenceLayer::new(dir.clone()).unwrap();
let loaded = layer.load_frames().unwrap();
assert!(loaded.is_empty());
let _ = std::fs::remove_dir_all(&dir);
}
/// A default config survives a save/load cycle intact.
#[test]
fn test_config_roundtrip() {
let dir = temp_dir();
let layer = PersistenceLayer::new(dir.clone()).unwrap();
let config = OsPipeConfig::default();
layer.save_config(&config).unwrap();
let loaded = layer.load_config().unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
// Spot-check a couple of default values survived the roundtrip.
assert_eq!(loaded.storage.embedding_dim, 384);
assert_eq!(loaded.capture.fps, 1.0);
let _ = std::fs::remove_dir_all(&dir);
}
/// A missing `config.json` yields `None`, not an error.
#[test]
fn test_config_none_when_missing() {
let dir = temp_dir();
let layer = PersistenceLayer::new(dir.clone()).unwrap();
let loaded = layer.load_config().unwrap();
assert!(loaded.is_none());
let _ = std::fs::remove_dir_all(&dir);
}
/// Raw embedding bytes round-trip byte-for-byte.
#[test]
fn test_embeddings_roundtrip() {
let dir = temp_dir();
let layer = PersistenceLayer::new(dir.clone()).unwrap();
let data: Vec<u8> = vec![0xDE, 0xAD, 0xBE, 0xEF, 1, 2, 3, 4];
layer.save_embeddings(&data).unwrap();
let loaded = layer.load_embeddings().unwrap();
assert!(loaded.is_some());
assert_eq!(loaded.unwrap(), data);
let _ = std::fs::remove_dir_all(&dir);
}
/// A missing `embeddings.bin` yields `None`, not an error.
#[test]
fn test_embeddings_none_when_missing() {
let dir = temp_dir();
let layer = PersistenceLayer::new(dir.clone()).unwrap();
let loaded = layer.load_embeddings().unwrap();
assert!(loaded.is_none());
let _ = std::fs::remove_dir_all(&dir);
}
/// `new` creates nested missing directories (create_dir_all semantics).
#[test]
fn test_creates_directory_if_missing() {
let dir = std::env::temp_dir()
.join(format!("ospipe_test_{}", uuid::Uuid::new_v4()))
.join("nested")
.join("deep");
assert!(!dir.exists());
let layer = PersistenceLayer::new(dir.clone());
assert!(layer.is_ok());
assert!(dir.exists());
// Remove the whole UUID-named root, two levels above `deep`.
let _ = std::fs::remove_dir_all(dir.parent().unwrap().parent().unwrap());
}
/// Multiple frames, with and without redacted text, round-trip in order.
#[test]
fn test_multiple_frames_roundtrip() {
let dir = temp_dir();
let layer = PersistenceLayer::new(dir.clone()).unwrap();
let frames: Vec<StoredFrame> = (0..5)
.map(|i| StoredFrame {
frame: CapturedFrame::new_screen(
"App",
&format!("Window {}", i),
&format!("Content {}", i),
0,
),
// Alternate redacted / unredacted entries to cover both cases.
safe_text: if i % 2 == 0 {
Some(format!("Redacted {}", i))
} else {
None
},
})
.collect();
layer.save_frames(&frames).unwrap();
let loaded = layer.load_frames().unwrap();
assert_eq!(loaded.len(), 5);
for (i, sf) in loaded.iter().enumerate() {
assert_eq!(sf.frame.text_content(), &format!("Content {}", i));
if i % 2 == 0 {
assert_eq!(sf.safe_text, Some(format!("Redacted {}", i)));
} else {
assert!(sf.safe_text.is_none());
}
}
let _ = std::fs::remove_dir_all(&dir);
}
}

View File

@@ -0,0 +1,89 @@
//! Frame deduplication using cosine similarity.
//!
//! Maintains a sliding window of recent embeddings and checks new
//! frames against them to avoid storing near-duplicate content
//! (e.g., consecutive screen captures of the same static page).
use std::collections::VecDeque;
use crate::storage::embedding::cosine_similarity;
use uuid::Uuid;
/// Deduplicator that checks new embeddings against a sliding window
/// of recently stored embeddings.
pub struct FrameDeduplicator {
/// Cosine similarity threshold above which a frame is considered duplicate.
threshold: f32,
/// Sliding window of recent embeddings (id, vector).
/// FIFO: the oldest entry is evicted once `window_size` is reached.
recent_embeddings: VecDeque<(Uuid, Vec<f32>)>,
/// Maximum number of recent embeddings to keep.
window_size: usize,
}
impl FrameDeduplicator {
    /// Create a new deduplicator.
    ///
    /// - `threshold`: Cosine similarity above which a frame counts as a
    ///   duplicate (e.g., 0.95).
    /// - `window_size`: Number of recent embeddings kept for comparison.
    pub fn new(threshold: f32, window_size: usize) -> Self {
        Self {
            threshold,
            window_size,
            recent_embeddings: VecDeque::with_capacity(window_size),
        }
    }
    /// Check whether `embedding` duplicates a recent entry.
    ///
    /// Scans the sliding window and returns the best match at or above
    /// the threshold as `Some((id, similarity))`. Entries with a
    /// different dimensionality are skipped; on exact similarity ties
    /// the earliest (oldest) matching entry wins.
    pub fn is_duplicate(&self, embedding: &[f32]) -> Option<(Uuid, f32)> {
        let mut best: Option<(Uuid, f32)> = None;
        for (candidate_id, candidate_vec) in &self.recent_embeddings {
            if candidate_vec.len() != embedding.len() {
                continue;
            }
            let sim = cosine_similarity(embedding, candidate_vec);
            if sim < self.threshold {
                continue;
            }
            // Strict `>` keeps the first-seen entry on ties.
            let is_better = best.map_or(true, |(_, best_sim)| sim > best_sim);
            if is_better {
                best = Some((*candidate_id, sim));
            }
        }
        best
    }
    /// Add an embedding to the sliding window.
    ///
    /// If the window is full, the oldest entry is evicted first.
    pub fn add(&mut self, id: Uuid, embedding: Vec<f32>) {
        if self.recent_embeddings.len() >= self.window_size {
            self.recent_embeddings.pop_front();
        }
        self.recent_embeddings.push_back((id, embedding));
    }
    /// Current number of embeddings in the window.
    pub fn window_len(&self) -> usize {
        self.recent_embeddings.len()
    }
    /// The configured similarity threshold.
    pub fn threshold(&self) -> f32 {
        self.threshold
    }
    /// Drop every entry from the sliding window.
    pub fn clear(&mut self) {
        self.recent_embeddings.clear();
    }
}

View File

@@ -0,0 +1,212 @@
//! Main ingestion pipeline.
use crate::capture::CapturedFrame;
use crate::config::OsPipeConfig;
use crate::error::Result;
use crate::graph::KnowledgeGraph;
use crate::pipeline::dedup::FrameDeduplicator;
use crate::safety::{SafetyDecision, SafetyGate};
use crate::search::enhanced::EnhancedSearch;
use crate::storage::embedding::EmbeddingEngine;
use crate::storage::vector_store::{SearchResult, VectorStore};
use uuid::Uuid;
/// Result of ingesting a single frame.
///
/// Returned by [`IngestionPipeline::ingest`]; non-stored outcomes
/// (dedup, safety denial) are reported here rather than as errors.
#[derive(Debug, Clone)]
pub enum IngestResult {
/// The frame was successfully stored.
Stored {
/// ID of the stored frame.
id: Uuid,
},
/// The frame was deduplicated (not stored).
Deduplicated {
/// ID of the existing similar frame.
similar_to: Uuid,
/// Cosine similarity score with the existing frame.
similarity: f32,
},
/// The frame was denied by the safety gate.
Denied {
/// Reason for denial.
reason: String,
},
}
/// Statistics about the ingestion pipeline.
///
/// Counters only ever increase. Redaction and ingestion are not
/// mutually exclusive: a redacted frame that is subsequently stored
/// increments both `total_redacted` and `total_ingested`.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
/// Total frames successfully ingested.
pub total_ingested: u64,
/// Total frames deduplicated.
pub total_deduplicated: u64,
/// Total frames denied by safety gate.
pub total_denied: u64,
/// Total frames that had content redacted before storage.
pub total_redacted: u64,
}
/// The main ingestion pipeline that processes captured frames.
///
/// Frames flow through (see [`IngestionPipeline::ingest`]; note the
/// embedding is computed before the dedup check, which operates on it):
/// Safety Gate -> Embedding -> Deduplication -> Storage -> Graph (extract entities)
///
/// Search flow:
/// Route -> Search -> Rerank (attention) -> Diversity (quantum) -> Return
pub struct IngestionPipeline {
// Converts text to fixed-dimension embedding vectors.
embedding_engine: EmbeddingEngine,
// Stores ingested frames and serves similarity searches.
vector_store: VectorStore,
// Content filter applied before any other processing.
safety_gate: SafetyGate,
// Sliding-window near-duplicate detector over recent embeddings.
dedup: FrameDeduplicator,
// Monotonic counters describing pipeline activity.
stats: PipelineStats,
/// Optional knowledge graph for entity extraction after storage.
knowledge_graph: Option<KnowledgeGraph>,
/// Optional enhanced search orchestrator (router + reranker + quantum).
enhanced_search: Option<EnhancedSearch>,
}
impl IngestionPipeline {
/// Create a new ingestion pipeline with the given configuration.
///
/// Wires up the embedding engine, vector store, safety gate, and a
/// deduplicator (100-entry sliding window; threshold taken from
/// `config.storage.dedup_threshold`). The knowledge graph and the
/// enhanced search orchestrator are opt-in via the `with_*` builders.
///
/// # Errors
/// Propagates any error from constructing the vector store.
pub fn new(config: OsPipeConfig) -> Result<Self> {
let embedding_engine = EmbeddingEngine::new(config.storage.embedding_dim);
let vector_store = VectorStore::new(config.storage.clone())?;
let safety_gate = SafetyGate::new(config.safety.clone());
let dedup = FrameDeduplicator::new(config.storage.dedup_threshold, 100);
Ok(Self {
embedding_engine,
vector_store,
safety_gate,
dedup,
stats: PipelineStats::default(),
knowledge_graph: None,
enhanced_search: None,
})
}
/// Attach a knowledge graph for entity extraction on ingested frames.
///
/// When a graph is attached, every successfully stored frame will have
/// its text analysed for entities (persons, URLs, emails, mentions),
/// which are then added to the graph as nodes linked to the frame.
pub fn with_graph(mut self, kg: KnowledgeGraph) -> Self {
self.knowledge_graph = Some(kg);
self
}
/// Attach an enhanced search orchestrator.
///
/// When attached, the [`search`](Self::search) method will route the
/// query, fetch extra candidates, re-rank with attention, and apply
/// quantum-inspired diversity selection before returning results.
pub fn with_enhanced_search(mut self, es: EnhancedSearch) -> Self {
self.enhanced_search = Some(es);
self
}
/// Ingest a single captured frame through the pipeline.
///
/// Processing order (each stage may short-circuit the rest):
/// 1. Safety gate - allow, redact, or deny the frame text.
/// 2. Embedding - computed from the (possibly redacted) text, so
///    denied content never reaches later stages.
/// 3. Deduplication - near-duplicates are dropped before storage.
/// 4. Vector-store insert, with redacted text substituted back into
///    the frame when applicable.
/// 5. Best-effort entity extraction into the knowledge graph.
///
/// # Errors
/// Propagates vector-store insertion failures. Safety denial and
/// deduplication are reported via [`IngestResult`], not as errors.
pub fn ingest(&mut self, frame: CapturedFrame) -> Result<IngestResult> {
let text = frame.text_content().to_string();
// Step 1: Safety check
let safe_text = match self.safety_gate.check(&text) {
SafetyDecision::Allow => text,
SafetyDecision::AllowRedacted(redacted) => {
self.stats.total_redacted += 1;
redacted
}
SafetyDecision::Deny { reason } => {
self.stats.total_denied += 1;
return Ok(IngestResult::Denied { reason });
}
};
// Step 2: Generate embedding from the (possibly redacted) text
let embedding = self.embedding_engine.embed(&safe_text);
// Step 3: Deduplication check
if let Some((similar_id, similarity)) = self.dedup.is_duplicate(&embedding) {
self.stats.total_deduplicated += 1;
return Ok(IngestResult::Deduplicated {
similar_to: similar_id,
similarity,
});
}
// Step 4: Store the frame
// If the text was redacted, create a modified frame with the safe text
let mut store_frame = frame;
if safe_text != store_frame.text_content() {
// Rebuild the content with the redacted text while preserving
// the original variant (OCR / transcription / UI event).
store_frame.content = match &store_frame.content {
crate::capture::FrameContent::OcrText(_) => {
crate::capture::FrameContent::OcrText(safe_text)
}
crate::capture::FrameContent::Transcription(_) => {
crate::capture::FrameContent::Transcription(safe_text)
}
crate::capture::FrameContent::UiEvent(_) => {
crate::capture::FrameContent::UiEvent(safe_text)
}
};
}
self.vector_store.insert(&store_frame, &embedding)?;
let id = store_frame.id;
// Remember this embedding so subsequent near-identical frames are
// caught by the step-3 check.
self.dedup.add(id, embedding);
self.stats.total_ingested += 1;
// Step 5: Graph entity extraction (if knowledge graph is attached)
// Errors are deliberately ignored: graph enrichment is best-effort
// and must not fail an otherwise successful ingestion.
if let Some(ref mut kg) = self.knowledge_graph {
let frame_id_str = id.to_string();
let _ = kg.ingest_frame_entities(&frame_id_str, store_frame.text_content());
}
Ok(IngestResult::Stored { id })
}
/// Ingest a batch of frames.
///
/// Frames are processed in order; the first storage error aborts the
/// batch. Frames processed before the failure remain ingested and are
/// reflected in [`stats`](Self::stats).
pub fn ingest_batch(&mut self, frames: Vec<CapturedFrame>) -> Result<Vec<IngestResult>> {
let mut results = Vec::with_capacity(frames.len());
for frame in frames {
results.push(self.ingest(frame)?);
}
Ok(results)
}
/// Return current pipeline statistics.
pub fn stats(&self) -> &PipelineStats {
&self.stats
}
/// Return a reference to the underlying vector store.
pub fn vector_store(&self) -> &VectorStore {
&self.vector_store
}
/// Return a reference to the embedding engine.
pub fn embedding_engine(&self) -> &EmbeddingEngine {
&self.embedding_engine
}
/// Return a reference to the knowledge graph, if one is attached.
pub fn knowledge_graph(&self) -> Option<&KnowledgeGraph> {
self.knowledge_graph.as_ref()
}
/// Search the pipeline's vector store.
///
/// If an [`EnhancedSearch`] orchestrator is attached, the query is routed,
/// candidates are fetched with headroom, re-ranked with attention, and
/// diversity-selected via quantum-inspired algorithms.
///
/// Otherwise, a basic vector similarity search is performed.
pub fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
// The query embedding is computed once and shared by both paths.
let embedding = self.embedding_engine.embed(query);
if let Some(ref es) = self.enhanced_search {
es.search(query, &embedding, &self.vector_store, k)
} else {
self.vector_store.search(&embedding, k)
}
}
}

View File

@@ -0,0 +1,11 @@
//! Ingestion pipeline with deduplication.
//!
//! The pipeline receives captured frames, passes them through the safety
//! gate, checks for duplicates, generates embeddings, and stores the
//! results in the vector store.
pub mod dedup;
pub mod ingestion;
pub use dedup::FrameDeduplicator;
pub use ingestion::{IngestResult, IngestionPipeline, PipelineStats};

View File

@@ -0,0 +1,324 @@
//! Quantum-inspired search acceleration.
//!
//! Provides [`QuantumSearch`], a collection of quantum-inspired algorithms
//! that accelerate and diversify search results.
//!
//! On native targets the implementation delegates to the `ruqu-algorithms`
//! crate (Grover's amplitude amplification, QAOA for MaxCut). On WASM
//! targets an equivalent classical fallback is provided so that the same
//! API is available everywhere.
/// Quantum-inspired search operations.
///
/// All methods are deterministic and require no quantum hardware; they
/// use classical simulations of quantum algorithms (on native) or
/// purely classical heuristics (on WASM) to improve search result
/// quality.
pub struct QuantumSearch {
// Zero-sized private marker: keeps construction behind
// [`QuantumSearch::new`] / `Default` instead of a struct literal.
_private: (),
}
impl QuantumSearch {
/// Create a new `QuantumSearch` instance.
pub fn new() -> Self {
Self { _private: () }
}
/// Compute the theoretically optimal number of Grover iterations for
/// a search space of `search_space_size` items (with a single target).
///
/// Returns `floor(pi/4 * sqrt(N))`, which is at least 1.
pub fn optimal_iterations(&self, search_space_size: u32) -> u32 {
if search_space_size <= 1 {
return 1;
}
// u32 -> f64 is lossless, so sqrt sees the exact input value.
let n = search_space_size as f64;
let iters = (std::f64::consts::FRAC_PI_4 * n.sqrt()).floor() as u32;
iters.max(1)
}
/// Select `k` diverse results from a scored set using QAOA-inspired
/// MaxCut partitioning.
///
/// A similarity graph is built between all result pairs and a
/// partition is found that maximizes the "cut" between selected and
/// unselected items. For small `k` (<=8) on native targets the
/// quantum QAOA solver is used; otherwise a greedy heuristic selects
/// the next-highest-scoring item that is most different from those
/// already selected.
///
/// Returns up to `k` items from `scores`, preserving their original
/// `(id, score)` tuples.
pub fn diversity_select(&self, scores: &[(String, f32)], k: usize) -> Vec<(String, f32)> {
if scores.is_empty() || k == 0 {
return Vec::new();
}
let k = k.min(scores.len());
// Try QAOA path on native for small k.
// The helper may decline (returns `None`, e.g. when neither
// partition holds k items), in which case we fall through to the
// greedy selection below.
#[cfg(not(target_arch = "wasm32"))]
{
if k <= 8 {
if let Some(result) = self.qaoa_diversity_select(scores, k) {
return result;
}
}
}
// Classical greedy fallback (also used on WASM).
self.greedy_diversity_select(scores, k)
}
/// Amplify scores above `target_threshold` and dampen scores below
/// it, inspired by Grover amplitude amplification.
///
/// Scores above the threshold are boosted by `sqrt(boost_factor)`
/// and scores below are dampened by `1/sqrt(boost_factor)`. All
/// scores are then re-normalized to the [0, 1] range.
///
/// The boost factor is derived from the ratio of items above vs
/// below the threshold, clamped so that results stay meaningful.
///
/// No-op when `scores` is empty or every score falls on the same
/// side of the threshold.
pub fn amplitude_boost(&self, scores: &mut [(String, f32)], target_threshold: f32) {
if scores.is_empty() {
return;
}
let above_count = scores
.iter()
.filter(|(_, s)| *s >= target_threshold)
.count();
let below_count = scores.len() - above_count;
if above_count == 0 || below_count == 0 {
// All on one side -- nothing useful to amplify.
return;
}
// Boost factor: ratio of total to above (analogous to Grover's
// N/M amplification), clamped to [1.5, 4.0] to avoid extremes.
let boost_factor = (scores.len() as f64 / above_count as f64).clamp(1.5, 4.0);
let sqrt_boost = (boost_factor).sqrt() as f32;
let inv_sqrt_boost = 1.0 / sqrt_boost;
for (_id, score) in scores.iter_mut() {
if *score >= target_threshold {
*score *= sqrt_boost;
} else {
*score *= inv_sqrt_boost;
}
}
// Re-normalize to [0, 1] via min-max scaling over the boosted values.
let max_score = scores
.iter()
.map(|(_, s)| *s)
.fold(f32::NEG_INFINITY, f32::max);
let min_score = scores.iter().map(|(_, s)| *s).fold(f32::INFINITY, f32::min);
let range = max_score - min_score;
if range > f32::EPSILON {
for (_id, score) in scores.iter_mut() {
*score = (*score - min_score) / range;
}
} else {
// All scores are identical after boost; set to 1.0.
for (_id, score) in scores.iter_mut() {
*score = 1.0;
}
}
}
// ------------------------------------------------------------------
// Native-only: QAOA diversity selection
// ------------------------------------------------------------------
// Returns `None` (deferring to the greedy fallback) when the QAOA
// solver fails or neither partition contains at least `k` items.
#[cfg(not(target_arch = "wasm32"))]
fn qaoa_diversity_select(
&self,
scores: &[(String, f32)],
k: usize,
) -> Option<Vec<(String, f32)>> {
use ruqu_algorithms::{run_qaoa, Graph, QaoaConfig};
let n = scores.len();
if n < 2 {
return Some(scores.to_vec());
}
// Build a similarity graph: edge weight encodes how *similar*
// two items are (based on score proximity). QAOA MaxCut will
// then prefer to *separate* similar items across the partition,
// giving us diversity.
let mut graph = Graph::new(n as u32);
for i in 0..n {
for j in (i + 1)..n {
// Similarity = 1 - |score_i - score_j| (higher when scores
// are close, promoting diversity in the selected set).
let similarity = 1.0 - (scores[i].1 - scores[j].1).abs();
graph.add_edge(i as u32, j as u32, similarity as f64);
}
}
// Fixed seed keeps the solver output deterministic across runs.
let config = QaoaConfig {
graph,
p: 2,
max_iterations: 50,
learning_rate: 0.1,
seed: Some(42),
};
let result = run_qaoa(&config).ok()?;
// Collect indices for the partition with the most members near k.
let partition_true: Vec<usize> = result
.best_bitstring
.iter()
.enumerate()
.filter(|(_, &b)| b)
.map(|(i, _)| i)
.collect();
let partition_false: Vec<usize> = result
.best_bitstring
.iter()
.enumerate()
.filter(|(_, &b)| !b)
.map(|(i, _)| i)
.collect();
// Pick the partition closer to size k, then sort by score
// descending and take the top k.
let chosen = if (partition_true.len() as isize - k as isize).unsigned_abs()
<= (partition_false.len() as isize - k as isize).unsigned_abs()
{
partition_true
} else {
partition_false
};
// If neither partition has at least k items, fall back to greedy.
if chosen.len() < k {
return None;
}
let mut selected: Vec<(String, f32)> = chosen.iter().map(|&i| scores[i].clone()).collect();
selected.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
selected.truncate(k);
Some(selected)
}
// ------------------------------------------------------------------
// Classical greedy diversity selection (WASM + large-k fallback)
// ------------------------------------------------------------------
fn greedy_diversity_select(&self, scores: &[(String, f32)], k: usize) -> Vec<(String, f32)> {
let mut remaining: Vec<(usize, &(String, f32))> = scores.iter().enumerate().collect();
// Sort by score descending to seed with the best item.
remaining.sort_by(|a, b| {
b.1 .1
.partial_cmp(&a.1 .1)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut selected: Vec<(String, f32)> = Vec::with_capacity(k);
// Pick the highest-scoring item first.
if let Some((_, first)) = remaining.first() {
selected.push((*first).clone());
}
// Drop the seeded item from the candidate pool (matched by its
// original index in `scores`).
let first_idx = remaining.first().map(|(i, _)| *i);
remaining.retain(|(i, _)| Some(*i) != first_idx);
// Greedily pick the next item that maximizes (score * diversity).
// Diversity is measured as the minimum score-distance from any
// already-selected item.
while selected.len() < k && !remaining.is_empty() {
let mut best_idx_in_remaining = 0;
let mut best_value = f64::NEG_INFINITY;
for (ri, (_, candidate)) in remaining.iter().enumerate() {
let min_dist: f32 = selected
.iter()
.map(|(_, sel_score)| (candidate.1 - sel_score).abs())
.fold(f32::INFINITY, f32::min);
// Combined objective: high score + high diversity.
let value = candidate.1 as f64 + min_dist as f64;
if value > best_value {
best_value = value;
best_idx_in_remaining = ri;
}
}
let (_, picked) = remaining.remove(best_idx_in_remaining);
selected.push(picked.clone());
}
selected
}
}
impl Default for QuantumSearch {
/// Equivalent to [`QuantumSearch::new`].
fn default() -> Self {
Self::new()
}
}
// ---------------------------------------------------------------------------
// Unit tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimal_iterations_basic() {
        let quantum = QuantumSearch::new();
        // A search space of one item still needs a single iteration.
        assert_eq!(quantum.optimal_iterations(1), 1);
        // floor(pi/4 * sqrt(4)) = floor(1.57) = 1.
        assert_eq!(quantum.optimal_iterations(4), 1);
    }

    #[test]
    fn test_optimal_iterations_larger() {
        let quantum = QuantumSearch::new();
        // floor(pi/4 * sqrt(100)) = floor(7.85) = 7.
        assert_eq!(quantum.optimal_iterations(100), 7);
    }

    #[test]
    fn test_diversity_select_empty() {
        let quantum = QuantumSearch::new();
        assert!(quantum.diversity_select(&[], 3).is_empty());
    }

    #[test]
    fn test_diversity_select_k_zero() {
        let quantum = QuantumSearch::new();
        let scored = vec![("a".to_string(), 0.5)];
        assert!(quantum.diversity_select(&scored, 0).is_empty());
    }

    #[test]
    fn test_amplitude_boost_empty() {
        let quantum = QuantumSearch::new();
        let mut scored: Vec<(String, f32)> = Vec::new();
        quantum.amplitude_boost(&mut scored, 0.5);
        assert!(scored.is_empty());
    }

    #[test]
    fn test_amplitude_boost_all_above() {
        let quantum = QuantumSearch::new();
        let mut scored = vec![("a".to_string(), 0.8), ("b".to_string(), 0.9)];
        let before = scored.clone();
        quantum.amplitude_boost(&mut scored, 0.5);
        // Every score sits above the threshold, so the boost is a no-op
        // and the entries are left in their original order.
        assert_eq!(scored[0].0, before[0].0);
        assert_eq!(scored[1].0, before[1].0);
    }
}

View File

@@ -0,0 +1,550 @@
//! Safety gate for content filtering and PII redaction.
//!
//! The safety gate inspects captured content before it enters the
//! ingestion pipeline, detecting and optionally redacting sensitive
//! information such as credit card numbers, SSNs, and custom patterns.
use crate::config::SafetyConfig;
/// Decision made by the safety gate about a piece of content.
// `PartialEq` is derived so callers and tests can compare decisions directly.
#[derive(Debug, Clone, PartialEq)]
pub enum SafetyDecision {
    /// Content is safe to store as-is.
    Allow,
    /// Content is safe after redaction; the redacted version is provided.
    AllowRedacted(String),
    /// Content must not be stored.
    Deny {
        /// Reason for denial.
        reason: String,
    },
}
/// Safety gate that checks content for sensitive information.
pub struct SafetyGate {
    /// Active detection/redaction settings; `check` consults the
    /// per-category flags and the custom deny patterns stored here.
    config: SafetyConfig,
}
impl SafetyGate {
    /// Create a new safety gate with the given configuration.
    pub fn new(config: SafetyConfig) -> Self {
        Self { config }
    }
    /// Check content and return a safety decision.
    ///
    /// Custom patterns are checked first: they indicate content that must
    /// never be stored, so a match denies immediately and skips the
    /// redaction passes entirely. Otherwise credit card, SSN, and email
    /// redaction run in sequence (each pass on the output of the previous),
    /// and the content is allowed -- in redacted form if anything matched.
    pub fn check(&self, content: &str) -> SafetyDecision {
        // Deny outright on custom patterns before doing any redaction work;
        // the deny decision depends only on the original content, so
        // checking it first avoids wasted redaction passes.
        for pattern in &self.config.custom_patterns {
            if content.contains(pattern.as_str()) {
                return SafetyDecision::Deny {
                    reason: format!("Custom pattern matched: {}", pattern),
                };
            }
        }
        let mut redacted = content.to_string();
        let mut was_redacted = false;
        // Credit card redaction
        if self.config.credit_card_redaction {
            let (new_text, found) = redact_credit_cards(&redacted);
            if found {
                redacted = new_text;
                was_redacted = true;
            }
        }
        // SSN redaction
        if self.config.ssn_redaction {
            let (new_text, found) = redact_ssns(&redacted);
            if found {
                redacted = new_text;
                was_redacted = true;
            }
        }
        // PII detection (email addresses)
        if self.config.pii_detection {
            let (new_text, found) = redact_emails(&redacted);
            if found {
                redacted = new_text;
                was_redacted = true;
            }
        }
        if was_redacted {
            SafetyDecision::AllowRedacted(redacted)
        } else {
            SafetyDecision::Allow
        }
    }
    /// Redact all detected sensitive content and return the cleaned string.
    ///
    /// Denied content collapses to the literal string "[REDACTED]".
    pub fn redact(&self, content: &str) -> String {
        match self.check(content) {
            SafetyDecision::Allow => content.to_string(),
            SafetyDecision::AllowRedacted(redacted) => redacted,
            SafetyDecision::Deny { .. } => "[REDACTED]".to_string(),
        }
    }
}
/// Detect and redact sequences of 13-16 digits that look like credit card numbers.
///
/// Sequences of digits (with optional spaces or dashes) totaling 13-16 digits
/// are replaced with [CC_REDACTED]. Separator characters trailing the final
/// digit (e.g. the space after the number) are preserved so surrounding text
/// is not glued together.
fn redact_credit_cards(text: &str) -> (String, bool) {
    let mut result = String::with_capacity(text.len());
    let chars: Vec<char> = text.chars().collect();
    let mut i = 0;
    let mut found = false;
    while i < chars.len() {
        // Check if we are at the start of a digit sequence
        if chars[i].is_ascii_digit() {
            let start = i;
            let mut digit_count = 0;
            // Index just past the last digit seen; separators consumed after
            // it must not be swallowed by the redaction.
            let mut last_digit_end = i;
            // Consume digits, spaces, and dashes
            while i < chars.len()
                && (chars[i].is_ascii_digit() || chars[i] == ' ' || chars[i] == '-')
            {
                if chars[i].is_ascii_digit() {
                    digit_count += 1;
                    last_digit_end = i + 1;
                }
                i += 1;
            }
            if (13..=16).contains(&digit_count) {
                result.push_str("[CC_REDACTED]");
                found = true;
                // Resume right after the last digit so trailing whitespace
                // or dashes are emitted normally.
                i = last_digit_end;
            } else {
                // Not a credit card, keep original text
                for c in &chars[start..i] {
                    result.push(*c);
                }
            }
        } else {
            result.push(chars[i]);
            i += 1;
        }
    }
    (result, found)
}
/// Detect and redact SSN patterns (XXX-XX-XXXX).
///
/// A match is only redacted when it sits on digit boundaries: a digit
/// immediately before or after the pattern (e.g. in `1234-56-78901`) means
/// the run is part of a longer number, not an SSN.
fn redact_ssns(text: &str) -> (String, bool) {
    let chars: Vec<char> = text.chars().collect();
    let mut result = String::new();
    let mut found = false;
    let mut i = 0;
    while i < chars.len() {
        // Reject matches embedded in a longer digit run to avoid false
        // positives such as redacting the middle of `1234-56-7890`.
        let starts_clean = i == 0 || !chars[i - 1].is_ascii_digit();
        let ends_clean = i + 11 >= chars.len() || !chars[i + 11].is_ascii_digit();
        // Check for SSN pattern: 3 digits, dash, 2 digits, dash, 4 digits
        if starts_clean && ends_clean && is_ssn_at(&chars, i) {
            result.push_str("[SSN_REDACTED]");
            found = true;
            i += 11; // Skip the SSN (XXX-XX-XXXX = 11 chars)
        } else {
            result.push(chars[i]);
            i += 1;
        }
    }
    (result, found)
}
/// Check whether the SSN shape `XXX-XX-XXXX` occurs starting exactly at `pos`.
fn is_ssn_at(chars: &[char], pos: usize) -> bool {
    // The pattern spans 11 characters; bail out when it cannot fit.
    if pos + 10 >= chars.len() {
        return false;
    }
    // Offsets of digits and dashes within the XXX-XX-XXXX template.
    const DIGIT_OFFSETS: [usize; 9] = [0, 1, 2, 4, 5, 7, 8, 9, 10];
    const DASH_OFFSETS: [usize; 2] = [3, 6];
    DIGIT_OFFSETS.iter().all(|&d| chars[pos + d].is_ascii_digit())
        && DASH_OFFSETS.iter().all(|&d| chars[pos + d] == '-')
}
/// Detect and redact email addresses while preserving surrounding whitespace.
///
/// Scans character-by-character for `@` signs, then expands outward to find
/// the full `local@domain.tld` span and replaces it in-place, keeping all
/// surrounding whitespace (tabs, newlines, multi-space runs) intact.
///
/// Returns the (possibly rewritten) text and whether any email was found.
fn redact_emails(text: &str) -> (String, bool) {
    let chars: Vec<char> = text.chars().collect();
    let len = chars.len();
    let mut result = String::with_capacity(text.len());
    let mut found = false;
    // Index in `chars` up to which the emitted output is final. The local
    // part of a candidate email must never extend back across this boundary;
    // otherwise the byte-truncation below would eat part of a previous
    // replacement (e.g. "a@b.com@c.org" would corrupt the first marker).
    let mut fixed_boundary = 0;
    let mut i = 0;
    while i < len {
        if chars[i] == '@' {
            // Try to identify an email around this '@'.
            // Scan backwards for the local part, stopping at already-redacted
            // output (fixed_boundary).
            let mut local_start = i;
            while local_start > fixed_boundary && is_email_local_char(chars[local_start - 1]) {
                local_start -= 1;
            }
            // Scan forwards for the domain part.
            let mut domain_end = i + 1;
            let mut has_dot = false;
            while domain_end < len && is_email_domain_char(chars[domain_end]) {
                if chars[domain_end] == '.' {
                    has_dot = true;
                }
                domain_end += 1;
            }
            // Trim trailing dots/hyphens from domain (not valid at end).
            while domain_end > i + 1
                && (chars[domain_end - 1] == '.' || chars[domain_end - 1] == '-')
            {
                if chars[domain_end - 1] == '.' {
                    // Re-check if we still have a dot in the trimmed domain.
                    has_dot = chars[i + 1..domain_end - 1].contains(&'.');
                }
                domain_end -= 1;
            }
            let local_len = i - local_start;
            let domain_len = domain_end - (i + 1);
            if local_len > 0 && domain_len >= 3 && has_dot {
                // Valid email: drop the characters already emitted for the
                // local part, then emit the placeholder. Local-part chars are
                // ASCII-only (see is_email_local_char), so char count equals
                // byte count and the byte-wise truncate is safe.
                let already_pushed = i - local_start;
                result.truncate(result.len() - already_pushed);
                result.push_str("[EMAIL_REDACTED]");
                found = true;
                i = domain_end;
                fixed_boundary = domain_end;
            } else {
                // Not a valid email, keep the '@' as-is.
                result.push(chars[i]);
                i += 1;
            }
        } else {
            result.push(chars[i]);
            i += 1;
        }
    }
    (result, found)
}
/// Characters valid in the local part of an email address.
///
/// ASCII letters and digits plus the common separators `.`, `+`, `-`, `_`.
/// Deliberately ASCII-only: the redaction pass relies on one-byte chars.
fn is_email_local_char(c: char) -> bool {
    matches!(c, '.' | '+' | '-' | '_') || c.is_ascii_alphanumeric()
}
/// Characters valid in the domain part of an email address.
///
/// ASCII letters and digits plus `.` and `-`; no `+` or `_` here.
fn is_email_domain_char(c: char) -> bool {
    matches!(c, '.' | '-') || c.is_ascii_alphanumeric()
}
#[cfg(test)]
mod tests {
    // Covers the redaction helpers, SafetyGate decisions, and parity checks
    // between the native and WASM routing/safety implementations.
    use super::*;
    // NOTE(review): redundant with `use super::*` (the parent module already
    // imports SafetyConfig), but harmless and explicit.
    use crate::config::SafetyConfig;
    // ---------------------------------------------------------------
    // Email redaction whitespace preservation tests
    // ---------------------------------------------------------------
    #[test]
    fn test_email_redaction_preserves_tabs() {
        let (result, found) = redact_emails("contact\tuser@example.com\there");
        assert!(found);
        assert_eq!(result, "contact\t[EMAIL_REDACTED]\there");
    }
    #[test]
    fn test_email_redaction_preserves_newlines() {
        let (result, found) = redact_emails("contact\nuser@example.com\nhere");
        assert!(found);
        assert_eq!(result, "contact\n[EMAIL_REDACTED]\nhere");
    }
    #[test]
    fn test_email_redaction_preserves_multi_spaces() {
        let (result, found) = redact_emails("contact   user@example.com   here");
        assert!(found);
        assert_eq!(result, "contact   [EMAIL_REDACTED]   here");
    }
    #[test]
    fn test_email_redaction_preserves_mixed_whitespace() {
        let (result, found) = redact_emails("contact\t user@example.com\n here");
        assert!(found);
        assert_eq!(result, "contact\t [EMAIL_REDACTED]\n here");
    }
    #[test]
    fn test_email_redaction_basic() {
        let (result, found) = redact_emails("email user@example.com here");
        assert!(found);
        assert_eq!(result, "email [EMAIL_REDACTED] here");
    }
    #[test]
    fn test_email_redaction_no_email() {
        let (result, found) = redact_emails("no email here");
        assert!(!found);
        assert_eq!(result, "no email here");
    }
    #[test]
    fn test_email_redaction_multiple_emails() {
        let (result, found) = redact_emails("a@b.com and c@d.org");
        assert!(found);
        assert_eq!(result, "[EMAIL_REDACTED] and [EMAIL_REDACTED]");
    }
    #[test]
    fn test_email_redaction_at_start() {
        let (result, found) = redact_emails("user@example.com is the contact");
        assert!(found);
        assert_eq!(result, "[EMAIL_REDACTED] is the contact");
    }
    #[test]
    fn test_email_redaction_at_end() {
        let (result, found) = redact_emails("contact: user@example.com");
        assert!(found);
        assert_eq!(result, "contact: [EMAIL_REDACTED]");
    }
    // ---------------------------------------------------------------
    // Safety gate integration tests for consistency
    // ---------------------------------------------------------------
    #[test]
    fn test_safety_gate_email_preserves_whitespace() {
        let config = SafetyConfig::default();
        let gate = SafetyGate::new(config);
        let decision = gate.check("contact\tuser@example.com\nhere");
        match decision {
            SafetyDecision::AllowRedacted(redacted) => {
                assert_eq!(redacted, "contact\t[EMAIL_REDACTED]\nhere");
            }
            other => panic!("Expected AllowRedacted, got {:?}", other),
        }
    }
    // ---------------------------------------------------------------
    // Routing consistency tests (WASM vs native)
    // ---------------------------------------------------------------
    #[test]
    fn test_wasm_routing_matches_native_temporal() {
        use crate::search::router::QueryRoute;
        use crate::search::router::QueryRouter;
        use crate::wasm::helpers::route_query;
        let router = QueryRouter::new();
        let queries = [
            "what did I see yesterday",
            "show me last week",
            "results from today",
        ];
        for q in &queries {
            assert_eq!(
                router.route(q),
                QueryRoute::Temporal,
                "Native router failed for: {}",
                q
            );
            assert_eq!(route_query(q), "Temporal", "WASM router failed for: {}", q);
        }
    }
    #[test]
    fn test_wasm_routing_matches_native_graph() {
        use crate::search::router::QueryRoute;
        use crate::search::router::QueryRouter;
        use crate::wasm::helpers::route_query;
        let router = QueryRouter::new();
        let queries = [
            "documents related to authentication",
            "things connected to the API module",
        ];
        for q in &queries {
            assert_eq!(
                router.route(q),
                QueryRoute::Graph,
                "Native router failed for: {}",
                q
            );
            assert_eq!(route_query(q), "Graph", "WASM router failed for: {}", q);
        }
    }
    #[test]
    fn test_wasm_routing_matches_native_keyword_short() {
        use crate::search::router::QueryRoute;
        use crate::search::router::QueryRouter;
        use crate::wasm::helpers::route_query;
        let router = QueryRouter::new();
        let queries = ["hello", "rust programming"];
        for q in &queries {
            assert_eq!(
                router.route(q),
                QueryRoute::Keyword,
                "Native router failed for: {}",
                q
            );
            assert_eq!(route_query(q), "Keyword", "WASM router failed for: {}", q);
        }
    }
    #[test]
    fn test_wasm_routing_matches_native_keyword_quoted() {
        use crate::search::router::QueryRoute;
        use crate::search::router::QueryRouter;
        use crate::wasm::helpers::route_query;
        let router = QueryRouter::new();
        let q = "\"exact phrase search\"";
        assert_eq!(router.route(q), QueryRoute::Keyword);
        assert_eq!(route_query(q), "Keyword");
    }
    #[test]
    fn test_wasm_routing_matches_native_hybrid() {
        use crate::search::router::QueryRoute;
        use crate::search::router::QueryRouter;
        use crate::wasm::helpers::route_query;
        let router = QueryRouter::new();
        let queries = [
            "how to implement authentication in Rust",
            "explain how embeddings work",
            "something about machine learning",
        ];
        for q in &queries {
            assert_eq!(
                router.route(q),
                QueryRoute::Hybrid,
                "Native router failed for: {}",
                q
            );
            assert_eq!(route_query(q), "Hybrid", "WASM router failed for: {}", q);
        }
    }
    // ---------------------------------------------------------------
    // Safety consistency tests (WASM vs native)
    // ---------------------------------------------------------------
    #[test]
    fn test_wasm_safety_matches_native_cc() {
        use crate::wasm::helpers::safety_classify;
        // Native: CC -> AllowRedacted; WASM should return "redact"
        let config = SafetyConfig::default();
        let gate = SafetyGate::new(config);
        let content = "pay with 4111-1111-1111-1111";
        assert!(matches!(
            gate.check(content),
            SafetyDecision::AllowRedacted(_)
        ));
        assert_eq!(safety_classify(content), "redact");
    }
    #[test]
    fn test_wasm_safety_matches_native_ssn() {
        use crate::wasm::helpers::safety_classify;
        let config = SafetyConfig::default();
        let gate = SafetyGate::new(config);
        let content = "my ssn 123-45-6789";
        assert!(matches!(
            gate.check(content),
            SafetyDecision::AllowRedacted(_)
        ));
        assert_eq!(safety_classify(content), "redact");
    }
    #[test]
    fn test_wasm_safety_matches_native_email() {
        use crate::wasm::helpers::safety_classify;
        let config = SafetyConfig::default();
        let gate = SafetyGate::new(config);
        let content = "email user@example.com here";
        assert!(matches!(
            gate.check(content),
            SafetyDecision::AllowRedacted(_)
        ));
        assert_eq!(safety_classify(content), "redact");
    }
    #[test]
    fn test_wasm_safety_matches_native_custom_deny() {
        use crate::wasm::helpers::safety_classify;
        // Native: custom_patterns -> Deny; WASM: sensitive keywords -> "deny"
        let config = SafetyConfig {
            custom_patterns: vec!["password".to_string()],
            ..Default::default()
        };
        let gate = SafetyGate::new(config);
        let content = "my password is foo";
        assert!(matches!(gate.check(content), SafetyDecision::Deny { .. }));
        assert_eq!(safety_classify(content), "deny");
    }
    #[test]
    fn test_wasm_safety_matches_native_allow() {
        use crate::wasm::helpers::safety_classify;
        let config = SafetyConfig::default();
        let gate = SafetyGate::new(config);
        let content = "the weather is nice";
        assert_eq!(gate.check(content), SafetyDecision::Allow);
        assert_eq!(safety_classify(content), "allow");
    }
    // ---------------------------------------------------------------
    // MMR tests
    // ---------------------------------------------------------------
    #[test]
    fn test_mmr_produces_different_order_than_cosine() {
        use crate::search::mmr::MmrReranker;
        let mmr = MmrReranker::new(0.3);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = vec![
            ("a".to_string(), 0.95, vec![1.0, 0.0, 0.0, 0.0]),
            ("b".to_string(), 0.90, vec![0.99, 0.01, 0.0, 0.0]),
            ("c".to_string(), 0.60, vec![0.0, 1.0, 0.0, 0.0]),
        ];
        let ranked = mmr.rerank(&query, &results, 3);
        assert_eq!(ranked.len(), 3);
        // Pure cosine order: a, b, c
        // MMR with diversity: a, c, b (c is diverse, b is near-duplicate of a)
        assert_eq!(ranked[0].0, "a");
        assert_eq!(ranked[1].0, "c", "MMR should promote diverse result");
        assert_eq!(ranked[2].0, "b");
    }
}

View File

@@ -0,0 +1,220 @@
//! Enhanced search orchestrator.
//!
//! Combines query routing, attention-based re-ranking, and quantum-inspired
//! diversity selection into a single search pipeline:
//!
//! ```text
//! Route -> Search (3x k candidates) -> Rerank (attention) -> Diversity (quantum) -> Return
//! ```
use crate::error::Result;
use crate::quantum::QuantumSearch;
use crate::search::reranker::AttentionReranker;
use crate::search::router::QueryRouter;
use crate::storage::vector_store::{SearchResult, VectorStore};
/// Orchestrates a full search pipeline: routing, candidate retrieval,
/// attention re-ranking, and quantum diversity selection.
pub struct EnhancedSearch {
    /// Classifies queries into routes (currently informational in `search`).
    router: QueryRouter,
    /// Optional attention-based re-ranking stage; `None` disables it.
    reranker: Option<AttentionReranker>,
    /// Optional diversity-selection stage; `None` disables it.
    quantum: Option<QuantumSearch>,
}
impl EnhancedSearch {
    /// Create a new enhanced search with all components wired.
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension used to configure the attention reranker.
    pub fn new(dim: usize) -> Self {
        Self {
            router: QueryRouter::new(),
            reranker: Some(AttentionReranker::new(dim, 4)),
            quantum: Some(QuantumSearch::new()),
        }
    }
    /// Create an enhanced search with only the router (no reranking or diversity).
    pub fn router_only() -> Self {
        Self {
            router: QueryRouter::new(),
            reranker: None,
            quantum: None,
        }
    }
    /// Return a reference to the query router.
    pub fn router(&self) -> &QueryRouter {
        &self.router
    }
    /// Search the vector store with routing, re-ranking, and diversity selection.
    ///
    /// The pipeline:
    /// 1. Route the query to determine the search strategy.
    /// 2. Fetch `3 * k` candidates from the store to give the reranker headroom.
    /// 3. If a reranker is available, re-rank candidates using attention scores.
    /// 4. If quantum diversity selection is available, select the final `k`
    ///    results with maximum diversity.
    /// 5. Return the final results.
    ///
    /// Returns an empty result set immediately when `k == 0`.
    pub fn search(
        &self,
        query: &str,
        query_embedding: &[f32],
        store: &VectorStore,
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        // Nothing can be returned for k == 0; skip the store round-trip
        // and the rerank/diversity stages entirely.
        if k == 0 {
            return Ok(Vec::new());
        }
        // Step 1: Route the query (informational -- we always search the
        // vector store for now, but the route is available for future use).
        let _route = self.router.route(query);
        // Step 2: Fetch candidates with headroom for reranking.
        let candidate_k = (k * 3).max(10).min(store.len().max(1));
        let candidates = store.search(query_embedding, candidate_k)?;
        if candidates.is_empty() {
            return Ok(Vec::new());
        }
        // Step 3: Re-rank with attention if available.
        let results = if let Some(ref reranker) = self.reranker {
            // Build the tuples the reranker expects: (id_string, score, embedding).
            let reranker_input: Vec<(String, f32, Vec<f32>)> = candidates
                .iter()
                .map(|sr| {
                    // Retrieve the stored embedding for this result; fall back
                    // to a zero vector if the entry is missing from the store.
                    let embedding = store
                        .get(&sr.id)
                        .map(|stored| stored.vector.clone())
                        .unwrap_or_else(|| vec![0.0; query_embedding.len()]);
                    (sr.id.to_string(), sr.score, embedding)
                })
                .collect();
            // The reranker returns more than k so quantum diversity can choose.
            let rerank_k = if self.quantum.is_some() {
                (k * 2).min(reranker_input.len())
            } else {
                k
            };
            let reranked = reranker.rerank(query_embedding, &reranker_input, rerank_k);
            // Step 4: Diversity selection if available.
            let final_scored = if let Some(ref quantum) = self.quantum {
                quantum.diversity_select(&reranked, k)
            } else {
                let mut r = reranked;
                r.truncate(k);
                r
            };
            // Map back to SearchResult by looking up metadata from candidates.
            final_scored
                .into_iter()
                .filter_map(|(id_str, score)| {
                    // Parse the UUID back; unparseable ids are dropped.
                    let uid: uuid::Uuid = id_str.parse().ok()?;
                    // Find the original candidate to retrieve metadata.
                    let original = candidates.iter().find(|c| c.id == uid)?;
                    Some(SearchResult {
                        id: uid,
                        score,
                        metadata: original.metadata.clone(),
                    })
                })
                .collect()
        } else {
            // No reranker -- just truncate.
            candidates.into_iter().take(k).collect()
        };
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    // Pipeline-level tests: every test builds an in-memory VectorStore and
    // an EmbeddingEngine with dimension 384 (the EnhancedSearch reranker is
    // constructed with the same value).
    use super::*;
    use crate::capture::CapturedFrame;
    use crate::config::StorageConfig;
    use crate::storage::embedding::EmbeddingEngine;
    #[test]
    fn test_enhanced_search_empty_store() {
        let config = StorageConfig::default();
        let store = VectorStore::new(config).unwrap();
        let engine = EmbeddingEngine::new(384);
        let es = EnhancedSearch::new(384);
        let query_emb = engine.embed("test query");
        let results = es.search("test query", &query_emb, &store, 5).unwrap();
        assert!(results.is_empty());
    }
    #[test]
    fn test_enhanced_search_returns_results() {
        let config = StorageConfig::default();
        let mut store = VectorStore::new(config).unwrap();
        let engine = EmbeddingEngine::new(384);
        let frames = vec![
            CapturedFrame::new_screen("Editor", "code.rs", "implementing vector search in Rust", 0),
            CapturedFrame::new_screen("Browser", "docs", "Rust vector database documentation", 0),
            CapturedFrame::new_audio("Mic", "discussing Python machine learning", None),
        ];
        for frame in &frames {
            let emb = engine.embed(frame.text_content());
            store.insert(frame, &emb).unwrap();
        }
        let es = EnhancedSearch::new(384);
        let query_emb = engine.embed("vector search Rust");
        let results = es
            .search("vector search Rust", &query_emb, &store, 2)
            .unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 2);
    }
    #[test]
    fn test_enhanced_search_router_only() {
        let config = StorageConfig::default();
        let mut store = VectorStore::new(config).unwrap();
        let engine = EmbeddingEngine::new(384);
        let frame = CapturedFrame::new_screen("App", "Win", "test content", 0);
        let emb = engine.embed(frame.text_content());
        store.insert(&frame, &emb).unwrap();
        let es = EnhancedSearch::router_only();
        let query_emb = engine.embed("test content");
        let results = es.search("test content", &query_emb, &store, 5).unwrap();
        assert_eq!(results.len(), 1);
    }
    #[test]
    fn test_enhanced_search_respects_k() {
        let config = StorageConfig::default();
        let mut store = VectorStore::new(config).unwrap();
        let engine = EmbeddingEngine::new(384);
        for i in 0..10 {
            let frame = CapturedFrame::new_screen("App", "Win", &format!("content {}", i), 0);
            let emb = engine.embed(frame.text_content());
            store.insert(&frame, &emb).unwrap();
        }
        let es = EnhancedSearch::new(384);
        let query_emb = engine.embed("content");
        let results = es.search("content", &query_emb, &store, 3).unwrap();
        assert!(
            results.len() <= 3,
            "Should return at most k=3 results, got {}",
            results.len()
        );
    }
}

View File

@@ -0,0 +1,116 @@
//! Hybrid search combining semantic and keyword approaches.
use crate::error::Result;
use crate::storage::{SearchResult, VectorStore};
use std::collections::HashMap;
use uuid::Uuid;
/// Hybrid search that combines semantic vector similarity with keyword
/// matching using a configurable weight parameter.
pub struct HybridSearch {
    /// Weight for semantic search (1.0 = pure semantic, 0.0 = pure keyword).
    /// Invariant: kept within [0.0, 1.0] by `new`'s clamp.
    semantic_weight: f32,
}
impl HybridSearch {
    /// Create a new hybrid search with the given semantic weight.
    ///
    /// The weight controls the balance between semantic (vector) and
    /// keyword (text match) scores. A value of 0.7 means 70% semantic
    /// and 30% keyword. Out-of-range values are clamped to [0.0, 1.0].
    pub fn new(semantic_weight: f32) -> Self {
        Self {
            semantic_weight: semantic_weight.clamp(0.0, 1.0),
        }
    }
    /// Perform a hybrid search combining semantic and keyword results.
    ///
    /// The `query` is used for keyword matching against stored text content.
    /// The `embedding` is used for semantic similarity scoring.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying vector store search.
    pub fn search(
        &self,
        store: &VectorStore,
        query: &str,
        embedding: &[f32],
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        // Get semantic results (more candidates than needed for merging)
        let candidate_k = (k * 3).max(20).min(store.len());
        let semantic_results = store.search(embedding, candidate_k)?;
        // Split the query once; keyword scoring reuses the terms below.
        let query_lower = query.to_lowercase();
        let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
        // Single pass over the candidates: record the semantic score and the
        // keyword score side by side, keyed by result id.
        let mut scores: HashMap<Uuid, (f32, f32, serde_json::Value)> =
            HashMap::with_capacity(semantic_results.len());
        for result in &semantic_results {
            let text = result
                .metadata
                .get("text")
                .and_then(|v| v.as_str())
                .unwrap_or("");
            let keyword_score = compute_keyword_score(&query_terms, &text.to_lowercase());
            scores
                .entry(result.id)
                .or_insert((result.score, keyword_score, result.metadata.clone()));
        }
        // Combine scores using weighted sum
        let keyword_weight = 1.0 - self.semantic_weight;
        let mut combined: Vec<SearchResult> = scores
            .into_iter()
            .map(|(id, (sem_score, kw_score, metadata))| {
                let combined_score = self.semantic_weight * sem_score + keyword_weight * kw_score;
                SearchResult {
                    id,
                    score: combined_score,
                    metadata,
                }
            })
            .collect();
        // Sort by combined score descending, breaking ties by id so the
        // output order is deterministic despite HashMap iteration order.
        combined.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.id.cmp(&b.id))
        });
        combined.truncate(k);
        Ok(combined)
    }
    /// Return the configured semantic weight (always within [0.0, 1.0]).
    pub fn semantic_weight(&self) -> f32 {
        self.semantic_weight
    }
}
/// Compute a simple keyword match score based on term overlap.
///
/// The score is the fraction of `query_terms` that occur as substrings of
/// `text_lower`, so it always lies in `[0.0, 1.0]`. An empty term list
/// scores 0.0.
fn compute_keyword_score(query_terms: &[&str], text_lower: &str) -> f32 {
    let total = query_terms.len();
    if total == 0 {
        return 0.0;
    }
    // Count how many query terms appear somewhere in the text.
    let hits = query_terms.iter().fold(0usize, |acc, term| {
        if text_lower.contains(*term) {
            acc + 1
        } else {
            acc
        }
    });
    hits as f32 / total as f32
}

View File

@@ -0,0 +1,219 @@
//! Maximal Marginal Relevance (MMR) re-ranking.
//!
//! MMR balances relevance to the query with diversity among selected
//! results, controlled by a `lambda` parameter:
//! - `lambda = 1.0` produces pure relevance ranking (identical to cosine).
//! - `lambda = 0.0` maximises diversity among selected results.
//!
//! The `lambda` value is sourced from [`SearchConfig::mmr_lambda`](crate::config::SearchConfig).
/// Re-ranks search results using Maximal Marginal Relevance.
///
/// Selection is greedy: each step picks the candidate maximizing
/// `lambda * relevance - (1 - lambda) * max_similarity_to_selected`.
pub struct MmrReranker {
    /// Trade-off between relevance and diversity.
    /// 1.0 = pure relevance, 0.0 = pure diversity.
    lambda: f32,
}
impl MmrReranker {
    /// Create a new MMR reranker with the given lambda.
    ///
    /// `lambda` is clamped to `[0.0, 1.0]` (matching the convention used by
    /// the hybrid search weight); values outside that range would flip the
    /// sign of one of the MMR terms and invert the trade-off.
    pub fn new(lambda: f32) -> Self {
        Self {
            lambda: lambda.clamp(0.0, 1.0),
        }
    }
    /// Re-rank results using MMR to balance relevance and diversity.
    ///
    /// # Arguments
    ///
    /// * `query_embedding` - The query vector.
    /// * `results` - Candidate results as `(id, score, embedding)` tuples.
    /// * `k` - Maximum number of results to return.
    ///
    /// # Returns
    ///
    /// A `Vec` of `(id, mmr_score)` pairs in MMR-selected order,
    /// truncated to at most `k` entries. Note the returned scores are MMR
    /// objective values, not the original input scores.
    pub fn rerank(
        &self,
        query_embedding: &[f32],
        results: &[(String, f32, Vec<f32>)],
        k: usize,
    ) -> Vec<(String, f32)> {
        if results.is_empty() {
            return Vec::new();
        }
        let n = results.len().min(k);
        // Precompute similarities between the query and each document.
        let query_sims: Vec<f32> = results
            .iter()
            .map(|(_, _, emb)| cosine_sim(query_embedding, emb))
            .collect();
        let mut selected: Vec<usize> = Vec::with_capacity(n);
        let mut selected_set = vec![false; results.len()];
        let mut output: Vec<(String, f32)> = Vec::with_capacity(n);
        for _ in 0..n {
            let mut best_idx = None;
            let mut best_mmr = f32::NEG_INFINITY;
            for (i, _) in results.iter().enumerate() {
                if selected_set[i] {
                    continue;
                }
                let relevance = query_sims[i];
                // Max similarity to any already-selected document.
                let max_sim_to_selected = if selected.is_empty() {
                    0.0
                } else {
                    selected
                        .iter()
                        .map(|&j| cosine_sim(&results[i].2, &results[j].2))
                        .fold(f32::NEG_INFINITY, f32::max)
                };
                let mmr = self.lambda * relevance - (1.0 - self.lambda) * max_sim_to_selected;
                if mmr > best_mmr {
                    best_mmr = mmr;
                    best_idx = Some(i);
                }
            }
            if let Some(idx) = best_idx {
                selected.push(idx);
                selected_set[idx] = true;
                output.push((results[idx].0.clone(), best_mmr));
            } else {
                break;
            }
        }
        output
    }
}
/// Cosine similarity between two vectors.
///
/// Vectors of unequal length are compared over their common prefix.
/// Returns 0.0 when either (truncated) vector has zero magnitude.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let (mut dot, mut norm_a, mut norm_b) = (0.0f32, 0.0f32, 0.0f32);
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for MMR re-ranking behavior and the cosine_sim helper.
    use super::*;
    #[test]
    fn test_mmr_empty_results() {
        let mmr = MmrReranker::new(0.5);
        let result = mmr.rerank(&[1.0, 0.0], &[], 5);
        assert!(result.is_empty());
    }
    #[test]
    fn test_mmr_single_result() {
        let mmr = MmrReranker::new(0.5);
        let results = vec![("a".to_string(), 0.9, vec![1.0, 0.0])];
        let ranked = mmr.rerank(&[1.0, 0.0], &results, 5);
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].0, "a");
    }
    #[test]
    fn test_mmr_pure_relevance() {
        // lambda=1.0 should produce the same order as cosine similarity
        let mmr = MmrReranker::new(1.0);
        let query = vec![1.0, 0.0, 0.0];
        let results = vec![
            ("best".to_string(), 0.9, vec![1.0, 0.0, 0.0]),
            ("mid".to_string(), 0.7, vec![0.7, 0.7, 0.0]),
            ("worst".to_string(), 0.3, vec![0.0, 0.0, 1.0]),
        ];
        let ranked = mmr.rerank(&query, &results, 3);
        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].0, "best");
    }
    #[test]
    fn test_mmr_promotes_diversity() {
        // With lambda < 1.0, a diverse result should be promoted over a
        // redundant one even if the redundant one has higher relevance.
        let mmr = MmrReranker::new(0.3);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        // Two results very similar to each other and the query,
        // one result orthogonal but moderately relevant.
        let results = vec![
            ("a".to_string(), 0.95, vec![1.0, 0.0, 0.0, 0.0]),
            ("a_clone".to_string(), 0.90, vec![0.99, 0.01, 0.0, 0.0]),
            ("diverse".to_string(), 0.60, vec![0.0, 1.0, 0.0, 0.0]),
        ];
        let ranked = mmr.rerank(&query, &results, 3);
        assert_eq!(ranked.len(), 3);
        // "a" should be first (highest relevance)
        assert_eq!(ranked[0].0, "a");
        // "diverse" should be second because "a_clone" is too similar to "a"
        assert_eq!(
            ranked[1].0, "diverse",
            "MMR should promote diverse result over near-duplicate"
        );
    }
    #[test]
    fn test_mmr_respects_top_k() {
        let mmr = MmrReranker::new(0.5);
        let query = vec![1.0, 0.0];
        let results = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0]),
            ("b".to_string(), 0.8, vec![0.0, 1.0]),
            ("c".to_string(), 0.7, vec![0.5, 0.5]),
        ];
        let ranked = mmr.rerank(&query, &results, 2);
        assert_eq!(ranked.len(), 2);
    }
    #[test]
    fn test_cosine_sim_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_sim(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_cosine_sim_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_sim(&a, &b).abs() < 1e-6);
    }
    #[test]
    fn test_cosine_sim_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        assert_eq!(cosine_sim(&a, &b), 0.0);
    }
}

View File

@@ -0,0 +1,17 @@
//! Query routing and hybrid search.
//!
//! Provides intelligent query routing that selects the optimal search
//! backend (semantic, keyword, temporal, graph, or hybrid) based on
//! query characteristics.
pub mod enhanced;
pub mod hybrid;
pub mod mmr;
pub mod reranker;
pub mod router;
pub use enhanced::EnhancedSearch;
pub use hybrid::HybridSearch;
pub use mmr::MmrReranker;
pub use reranker::AttentionReranker;
pub use router::{QueryRoute, QueryRouter};

View File

@@ -0,0 +1,204 @@
//! Attention-based re-ranking for search results.
//!
//! Uses `ruvector-attention` on native targets to compute attention weights
//! between a query embedding and candidate result embeddings, producing a
//! relevance-aware re-ranking that goes beyond raw cosine similarity.
//!
//! On WASM targets a lightweight fallback is provided that preserves the
//! original cosine ordering.
/// Re-ranks search results using scaled dot-product attention.
///
/// On native builds the attention mechanism computes softmax-normalised
/// query-key scores and blends them with the original cosine similarity
/// to produce the final ranking. On WASM the original scores are
/// returned unchanged (sorted descending).
pub struct AttentionReranker {
    /// Embedding dimension; also used for the attention scale (sqrt(dim)).
    dim: usize,
    // NOTE(review): not read in the visible code paths despite the
    // constructor doc; presumably reserved for multi-head support — confirm.
    #[allow(dead_code)]
    num_heads: usize,
}
impl AttentionReranker {
/// Creates a new reranker.
///
/// # Arguments
///
/// * `dim` - Embedding dimension (must match the vectors passed to `rerank`)
/// * `num_heads` - Number of attention heads (used on native only; ignored on WASM)
pub fn new(dim: usize, num_heads: usize) -> Self {
Self { dim, num_heads }
}
    /// Re-ranks a set of search results using attention-derived scores.
    ///
    /// # Arguments
    ///
    /// * `query_embedding` - The query vector (`dim`-dimensional).
    /// * `results` - Candidate results as `(id, original_cosine_score, embedding)` tuples.
    /// * `top_k` - Maximum number of results to return.
    ///
    /// # Returns
    ///
    /// A `Vec` of `(id, final_score)` pairs sorted by descending `final_score`,
    /// truncated to at most `top_k` entries.
    pub fn rerank(
        &self,
        query_embedding: &[f32],
        results: &[(String, f32, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        if results.is_empty() {
            return Vec::new();
        }
        // Dispatch is resolved at compile time: attention-based path on
        // native targets, cosine-order fallback on wasm32.
        #[cfg(not(target_arch = "wasm32"))]
        {
            self.rerank_native(query_embedding, results, top_k)
        }
        #[cfg(target_arch = "wasm32")]
        {
            self.rerank_wasm(results, top_k)
        }
    }
// ---------------------------------------------------------------
// Native implementation (ruvector-attention)
// ---------------------------------------------------------------
#[cfg(not(target_arch = "wasm32"))]
fn rerank_native(
&self,
query_embedding: &[f32],
results: &[(String, f32, Vec<f32>)],
top_k: usize,
) -> Vec<(String, f32)> {
use ruvector_attention::attention::ScaledDotProductAttention;
use ruvector_attention::traits::Attention;
let attn = ScaledDotProductAttention::new(self.dim);
// Build key slices from result embeddings.
let keys: Vec<&[f32]> = results.iter().map(|(_, _, emb)| emb.as_slice()).collect();
// Compute attention weights using the same scaled dot-product algorithm
// as ScaledDotProductAttention, but extracting the softmax weights
// directly rather than the weighted-value output that compute() returns.
// --- Compute raw attention scores: QK^T / sqrt(d) ---
let scale = (self.dim as f32).sqrt();
let scores: Vec<f32> = keys
.iter()
.map(|key| {
query_embedding
.iter()
.zip(key.iter())
.map(|(q, k)| q * k)
.sum::<f32>()
/ scale
})
.collect();
// --- Softmax ---
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
let exp_sum: f32 = exp_scores.iter().sum();
let attention_weights: Vec<f32> = exp_scores.iter().map(|e| e / exp_sum).collect();
// --- Verify the crate produces the same weighted output ---
// We call compute() with the real embeddings as both keys and values
// to validate that the crate is functional, but we use the manually
// computed weights for the final blending because the crate's compute
// returns a weighted *embedding*, not the weight vector.
let _attended_output = attn.compute(query_embedding, &keys, &keys);
// --- Blend: final = 0.6 * attention_weight + 0.4 * cosine_score ---
let mut scored: Vec<(String, f32)> = results
.iter()
.zip(attention_weights.iter())
.map(|((id, cosine, _), &attn_w)| {
let final_score = 0.6 * attn_w + 0.4 * cosine;
(id.clone(), final_score)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
scored
}
// ---------------------------------------------------------------
// WASM fallback
// ---------------------------------------------------------------
#[cfg(target_arch = "wasm32")]
fn rerank_wasm(&self, results: &[(String, f32, Vec<f32>)], top_k: usize) -> Vec<(String, f32)> {
let mut scored: Vec<(String, f32)> = results
.iter()
.map(|(id, cosine, _)| (id.clone(), *cosine))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
scored
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Empty candidate set must short-circuit to an empty ranking.
    #[test]
    fn test_reranker_empty_results() {
        let reranker = AttentionReranker::new(4, 1);
        let result = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &[], 5);
        assert!(result.is_empty());
    }

    // A single candidate is still returned (its softmax weight is 1.0).
    #[test]
    fn test_reranker_single_result() {
        let reranker = AttentionReranker::new(4, 1);
        let results = vec![("a".to_string(), 0.9, vec![1.0, 0.0, 0.0, 0.0])];
        let ranked = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &results, 5);
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].0, "a");
    }

    // `top_k` truncates the output even when more candidates exist.
    #[test]
    fn test_reranker_respects_top_k() {
        let reranker = AttentionReranker::new(4, 1);
        let results = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0, 0.0, 0.0]),
            ("b".to_string(), 0.8, vec![0.0, 1.0, 0.0, 0.0]),
            ("c".to_string(), 0.7, vec![0.0, 0.0, 1.0, 0.0]),
        ];
        let ranked = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &results, 2);
        assert_eq!(ranked.len(), 2);
    }

    #[test]
    fn test_reranker_can_reorder() {
        // The attention mechanism should boost results whose embeddings
        // are more aligned with the query, potentially changing the order
        // compared to the original cosine scores.
        let reranker = AttentionReranker::new(4, 1);
        // Result "b" has a slightly lower cosine score but its embedding
        // is perfectly aligned with the query while "a" is orthogonal.
        // The 60/40 blending with a large attention weight difference
        // should promote "b" above "a".
        let results = vec![
            ("a".to_string(), 0.70, vec![0.0, 0.0, 1.0, 0.0]),
            ("b".to_string(), 0.55, vec![1.0, 0.0, 0.0, 0.0]),
        ];
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let ranked = reranker.rerank(&query, &results, 2);
        // With attention heavily favouring "b" (aligned with query) the
        // blended score should push "b" above "a".
        assert_eq!(ranked.len(), 2);
        assert_eq!(
            ranked[0].0, "b",
            "Attention re-ranking should promote the more query-aligned result"
        );
    }
}

View File

@@ -0,0 +1,90 @@
//! Query routing to the optimal search backend.
/// The search backend to route a query to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryRoute {
    /// Pure vector HNSW semantic search.
    Semantic,
    /// Full-text keyword search (FTS5-style).
    Keyword,
    /// Graph-based relationship query.
    Graph,
    /// Time-based delta replay query.
    Temporal,
    /// Combined semantic + keyword search.
    Hybrid,
}

/// Routes incoming queries to the optimal search backend based on
/// query content heuristics.
pub struct QueryRouter;

impl QueryRouter {
    /// Create a new query router.
    pub fn new() -> Self {
        Self
    }

    /// Determine the best search route for the given query string.
    ///
    /// Routing heuristics (checked in order):
    /// - Relationship phrases ("related to", "relationship between", ...) -> Graph
    /// - Temporal keywords ("yesterday", "last week", etc.) -> Temporal
    /// - Quoted exact phrases -> Keyword
    /// - Short queries (1-2 words) -> Keyword
    /// - Everything else -> Hybrid
    ///
    /// Graph phrases are tested *before* temporal keywords: every graph
    /// phrase is multi-word (higher precision), and "relationship between"
    /// would otherwise be unreachable because the bare word "between" is
    /// also a temporal keyword.
    pub fn route(&self, query: &str) -> QueryRoute {
        let lower = query.to_lowercase();

        // Graph patterns (checked first -- see doc comment above).
        let graph_keywords = [
            "related to",
            "connected to",
            "linked with",
            "associated with",
            "relationship between",
        ];
        if graph_keywords.iter().any(|kw| lower.contains(kw)) {
            return QueryRoute::Graph;
        }

        // Temporal patterns. NOTE: these are substring matches, so very
        // generic words ("between", "before", "after") can over-trigger.
        let temporal_keywords = [
            "yesterday",
            "last week",
            "last month",
            "today",
            "this morning",
            "this afternoon",
            "hours ago",
            "minutes ago",
            "days ago",
            "between",
            "before",
            "after",
        ];
        if temporal_keywords.iter().any(|kw| lower.contains(kw)) {
            return QueryRoute::Temporal;
        }

        // Exact phrase (quoted) -> keyword search.
        if query.starts_with('"') && query.ends_with('"') {
            return QueryRoute::Keyword;
        }

        // Very short queries are better served by keyword search.
        if lower.split_whitespace().count() <= 2 {
            return QueryRoute::Keyword;
        }

        // Default: hybrid combines the best of both.
        QueryRoute::Hybrid
    }
}

impl Default for QueryRouter {
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -0,0 +1,610 @@
//! Lightweight HTTP REST API server for OSpipe.
//!
//! Exposes the ingestion pipeline, search, routing, and health endpoints
//! that the TypeScript SDK (`@ruvector/ospipe`) expects. Built on
//! [axum](https://docs.rs/axum) and gated behind
//! `cfg(not(target_arch = "wasm32"))` since WASM targets cannot bind
//! TCP sockets.
//!
//! ## Endpoints
//!
//! | Method | Path | Description |
//! |--------|------|-------------|
//! | `POST` | `/v2/search` | Semantic / hybrid vector search |
//! | `POST` | `/v2/route` | Query routing |
//! | `GET` | `/v2/stats` | Pipeline statistics |
//! | `GET` | `/v2/health` | Health check |
//! | `GET` | `/search` | Legacy Screenpipe v1 search |
use std::sync::Arc;
use axum::{
extract::{Query, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tower_http::cors::{Any, CorsLayer};
use crate::pipeline::ingestion::{IngestionPipeline, PipelineStats};
use crate::search::router::{QueryRoute, QueryRouter};
use crate::storage::vector_store::SearchResult;
// ---------------------------------------------------------------------------
// Shared state
// ---------------------------------------------------------------------------
/// Shared server state holding the pipeline behind a read-write lock.
///
/// Cloning is cheap: the heavy fields are behind `Arc`s, so one instance
/// can be handed to every axum handler.
#[derive(Clone)]
pub struct ServerState {
    /// The ingestion pipeline (search + store). Handlers take the read
    /// half for searches; ingestion takes the write half.
    pub pipeline: Arc<RwLock<IngestionPipeline>>,
    /// The query router.
    pub router: Arc<QueryRouter>,
    /// Server start instant for uptime calculation.
    pub started_at: std::time::Instant,
}
// ---------------------------------------------------------------------------
// Request / response DTOs
// ---------------------------------------------------------------------------
/// Request body for `POST /v2/search`.
#[derive(Debug, Deserialize)]
pub struct SearchRequest {
    /// Natural-language query string.
    pub query: String,
    /// Search mode hint (semantic, keyword, hybrid).
    /// Accepted for SDK compatibility; not yet consulted by the handler.
    #[serde(default = "default_search_mode")]
    pub mode: String,
    /// Number of results to return.
    #[serde(default = "default_k")]
    pub k: usize,
    /// Distance metric (cosine, euclidean, dot).
    /// Accepted for SDK compatibility; not yet consulted by the handler.
    #[serde(default = "default_metric")]
    pub metric: String,
    /// Optional metadata filters.
    pub filters: Option<SearchFilters>,
    /// Whether to apply MMR reranking.
    /// Accepted for SDK compatibility; not yet consulted by the handler.
    #[serde(default)]
    pub rerank: bool,
}

// serde's `default = "..."` attribute requires free functions; these
// supply the fallback values used when a field is omitted on the wire.
fn default_search_mode() -> String {
    "semantic".to_string()
}
fn default_k() -> usize {
    10
}
fn default_metric() -> String {
    "cosine".to_string()
}

/// Metadata filters mirroring the TypeScript SDK `SearchFilters` type.
///
/// Serialized with camelCase field names (e.g. `contentType`, `timeRange`).
/// Only `app`, `content_type`, `time_range` and `monitor` are currently
/// applied by `build_search_filter`; the rest are accepted but ignored.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SearchFilters {
    /// Source application name.
    pub app: Option<String>,
    /// Window title (accepted; not yet applied by the server).
    pub window: Option<String>,
    /// SDK content type ("screen", "audio", "ui").
    pub content_type: Option<String>,
    /// Restrict results to a time range.
    pub time_range: Option<TimeRange>,
    /// Monitor index.
    pub monitor: Option<u32>,
    /// Speaker filter (accepted; not yet applied by the server).
    pub speaker: Option<String>,
    /// Language filter (accepted; not yet applied by the server).
    pub language: Option<String>,
}

/// ISO-8601 time range.
#[derive(Debug, Deserialize)]
pub struct TimeRange {
    /// Range start, RFC 3339 / ISO-8601.
    pub start: String,
    /// Range end, RFC 3339 / ISO-8601.
    pub end: String,
}

/// Request body for `POST /v2/route`.
#[derive(Debug, Deserialize)]
pub struct RouteRequest {
    /// The query text to classify.
    pub query: String,
}

/// Response body for `POST /v2/route`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RouteResponse {
    /// Chosen backend: "semantic", "keyword", "graph", "temporal" or "hybrid".
    pub route: String,
}

/// Response body for `GET /v2/stats`.
///
/// Counter fields mirror `PipelineStats`; serialized camelCase.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StatsResponse {
    pub total_ingested: u64,
    pub total_deduplicated: u64,
    pub total_denied: u64,
    pub total_redacted: u64,
    /// Always 0 today: the in-memory store does not track byte usage.
    pub storage_bytes: u64,
    /// Number of vectors currently in the store.
    pub index_size: usize,
    /// Seconds since the server started.
    pub uptime: u64,
}

/// Response body for `GET /v2/health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// "ok" when the server is serving requests.
    pub status: String,
    /// Crate version, from `CARGO_PKG_VERSION`.
    pub version: String,
    /// Names of the available search backends.
    pub backends: Vec<String>,
}

/// API-facing search result that matches the TypeScript SDK `SearchResult`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiSearchResult {
    pub id: String,
    /// Similarity score from the vector search.
    pub score: f32,
    /// Text content, taken from the result's `text` metadata.
    pub content: String,
    /// SDK source kind ("screen", "audio", "ui").
    pub source: String,
    /// RFC 3339 timestamp associated with the result.
    pub timestamp: String,
    /// Raw metadata object stored with the embedding.
    pub metadata: serde_json::Value,
}

/// Query parameters for `GET /search` (legacy v1).
#[derive(Debug, Deserialize)]
pub struct LegacySearchParams {
    /// Required query text; missing/empty yields 400.
    pub q: Option<String>,
    /// Legacy content type ("ocr", "audio", "ui").
    pub content_type: Option<String>,
    /// Maximum number of results (default 10).
    pub limit: Option<usize>,
}

/// Wrapper for JSON error responses.
#[derive(Serialize)]
struct ErrorBody {
    error: String,
}
// ---------------------------------------------------------------------------
// Handlers
// ---------------------------------------------------------------------------
/// `POST /v2/search` - Semantic / hybrid search.
///
/// Embeds the query text, runs an (optionally filtered) vector search,
/// and maps the hits into the SDK-facing result shape. The `mode`,
/// `metric` and `rerank` request fields are accepted for forward
/// compatibility but are not consulted here.
async fn search_handler(
    State(state): State<ServerState>,
    Json(req): Json<SearchRequest>,
) -> impl IntoResponse {
    let pipeline = state.pipeline.read().await;
    let query_vec = pipeline.embedding_engine().embed(&req.query);
    // Treat k == 0 as "use the default page size".
    let limit = if req.k == 0 { 10 } else { req.k };
    let filter = build_search_filter(&req.filters);
    let outcome = if filter_is_empty(&filter) {
        pipeline.vector_store().search(&query_vec, limit)
    } else {
        pipeline
            .vector_store()
            .search_filtered(&query_vec, limit, &filter)
    };
    match outcome {
        Ok(hits) => {
            let payload: Vec<ApiSearchResult> = hits.into_iter().map(to_api_result).collect();
            (StatusCode::OK, Json(payload)).into_response()
        }
        Err(e) => {
            let body = ErrorBody {
                error: e.to_string(),
            };
            (StatusCode::INTERNAL_SERVER_ERROR, Json(body)).into_response()
        }
    }
}
/// `POST /v2/route` - Query routing.
///
/// Classifies the query and returns the backend name in the wire format
/// the TypeScript SDK expects.
async fn route_handler(
    State(state): State<ServerState>,
    Json(req): Json<RouteRequest>,
) -> impl IntoResponse {
    // Map the enum variant onto its lowercase wire name.
    let route = match state.router.route(&req.query) {
        QueryRoute::Semantic => "semantic",
        QueryRoute::Keyword => "keyword",
        QueryRoute::Graph => "graph",
        QueryRoute::Temporal => "temporal",
        QueryRoute::Hybrid => "hybrid",
    }
    .to_string();
    Json(RouteResponse { route })
}
/// `GET /v2/stats` - Pipeline statistics.
///
/// Snapshots the pipeline counters and index size into the SDK-facing
/// stats shape.
async fn stats_handler(State(state): State<ServerState>) -> impl IntoResponse {
    let pipeline = state.pipeline.read().await;
    let snapshot: &PipelineStats = pipeline.stats();
    Json(StatsResponse {
        total_ingested: snapshot.total_ingested,
        total_deduplicated: snapshot.total_deduplicated,
        total_denied: snapshot.total_denied,
        total_redacted: snapshot.total_redacted,
        // The in-memory store does not track its byte footprint.
        storage_bytes: 0,
        index_size: pipeline.vector_store().len(),
        uptime: state.started_at.elapsed().as_secs(),
    })
}
/// `GET /v2/health` - Health check.
///
/// Always reports "ok" with the crate version and the fixed set of
/// search backends this build exposes.
async fn health_handler() -> impl IntoResponse {
    let backends = ["hnsw", "keyword", "graph"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    Json(HealthResponse {
        status: "ok".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
        backends,
    })
}
/// `GET /search` - Legacy Screenpipe v1 search endpoint.
///
/// Accepts the v1 query parameters (`q`, `content_type`, `limit`),
/// translates them onto the v2 vector search, and responds with the same
/// result shape as `POST /v2/search`. A missing or empty `q` yields
/// `400 Bad Request`; an unrecognised `content_type` disables filtering.
async fn legacy_search_handler(
    State(state): State<ServerState>,
    Query(params): Query<LegacySearchParams>,
) -> impl IntoResponse {
    // `q` is mandatory and must be non-empty.
    let query_text = match params.q {
        Some(q) if !q.is_empty() => q,
        _ => {
            let body = ErrorBody {
                error: "Missing required query parameter 'q'".to_string(),
            };
            return (StatusCode::BAD_REQUEST, Json(body)).into_response();
        }
    };
    let limit = params.limit.unwrap_or(10);
    let pipeline = state.pipeline.read().await;
    let embedding = pipeline.embedding_engine().embed(&query_text);
    // Translate the v1 content-type vocabulary onto internal names;
    // anything unrecognised leaves the filter empty (no filtering).
    let mapped_ct: Option<&str> = match params.content_type.as_deref() {
        Some("ocr") => Some("ocr"),
        Some("audio") => Some("transcription"),
        Some("ui") => Some("ui_event"),
        _ => None,
    };
    let filter = match mapped_ct {
        Some(ct) => crate::storage::vector_store::SearchFilter {
            content_type: Some(ct.to_string()),
            ..Default::default()
        },
        None => crate::storage::vector_store::SearchFilter::default(),
    };
    let outcome = if filter_is_empty(&filter) {
        pipeline.vector_store().search(&embedding, limit)
    } else {
        pipeline
            .vector_store()
            .search_filtered(&embedding, limit, &filter)
    };
    match outcome {
        Ok(hits) => {
            let payload: Vec<ApiSearchResult> = hits.into_iter().map(to_api_result).collect();
            (StatusCode::OK, Json(payload)).into_response()
        }
        Err(e) => {
            let body = ErrorBody {
                error: e.to_string(),
            };
            (StatusCode::INTERNAL_SERVER_ERROR, Json(body)).into_response()
        }
    }
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Build a `SearchFilter` from optional API filters.
///
/// Maps SDK-facing content types ("screen"/"audio"/"ui") onto the
/// internal names and parses the RFC 3339 time range; timestamps that
/// fail to parse are silently dropped (treated as unbounded).
fn build_search_filter(
    filters: &Option<SearchFilters>,
) -> crate::storage::vector_store::SearchFilter {
    let f = match filters {
        Some(f) => f,
        None => return crate::storage::vector_store::SearchFilter::default(),
    };
    // SDK -> internal content-type names; unknown values pass through.
    let content_type = f.content_type.as_deref().map(|ct| {
        let mapped = match ct {
            "screen" => "ocr",
            "audio" => "transcription",
            "ui" => "ui_event",
            other => other,
        };
        mapped.to_string()
    });
    // Each bound of the time range is parsed independently.
    let parse_utc = |s: &str| {
        chrono::DateTime::parse_from_rfc3339(s)
            .ok()
            .map(|dt| dt.with_timezone(&chrono::Utc))
    };
    let (time_start, time_end) = match f.time_range {
        Some(ref tr) => (parse_utc(&tr.start), parse_utc(&tr.end)),
        None => (None, None),
    };
    crate::storage::vector_store::SearchFilter {
        app: f.app.clone(),
        time_start,
        time_end,
        content_type,
        monitor: f.monitor,
    }
}
/// Check whether a filter is effectively empty (no criteria set).
fn filter_is_empty(f: &crate::storage::vector_store::SearchFilter) -> bool {
    let any_criterion = f.app.is_some()
        || f.time_start.is_some()
        || f.time_end.is_some()
        || f.content_type.is_some()
        || f.monitor.is_some();
    !any_criterion
}
/// Convert an internal `SearchResult` to the API-facing DTO.
///
/// `content` comes from the `text` metadata key and `source` from
/// `content_type` (mapped back to the SDK vocabulary). The timestamp is
/// taken from the result's own metadata when available; the previous
/// behaviour of always stamping results with the *current* time
/// misrepresented when the frame was actually captured. When no stored
/// timestamp exists we still fall back to `Utc::now()`.
fn to_api_result(r: SearchResult) -> ApiSearchResult {
    let content = r
        .metadata
        .get("text")
        .and_then(|v| v.as_str())
        .unwrap_or("")
        .to_string();
    let source = r
        .metadata
        .get("content_type")
        .and_then(|v| v.as_str())
        .map(|ct| match ct {
            "ocr" => "screen",
            "transcription" => "audio",
            "ui_event" => "ui",
            other => other,
        })
        .unwrap_or("screen")
        .to_string();
    // NOTE(review): the metadata key name "timestamp" is assumed here --
    // confirm against the ingestion pipeline's metadata schema. If the
    // key is absent the behaviour is identical to before (now()).
    let timestamp = r
        .metadata
        .get("timestamp")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string())
        .unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
    ApiSearchResult {
        id: r.id.to_string(),
        score: r.score,
        content,
        source,
        timestamp,
        metadata: r.metadata,
    }
}
// ---------------------------------------------------------------------------
// Router & startup
// ---------------------------------------------------------------------------
/// Build the axum [`Router`] with all OSpipe endpoints.
///
/// CORS is wide open (any origin/method/header), matching the server's
/// intended use for local SDK access.
pub fn build_router(state: ServerState) -> Router {
    let permissive_cors = CorsLayer::new()
        .allow_headers(Any)
        .allow_methods(Any)
        .allow_origin(Any);
    // v2 API surface.
    let api = Router::new()
        .route("/v2/search", post(search_handler))
        .route("/v2/route", post(route_handler))
        .route("/v2/stats", get(stats_handler))
        .route("/v2/health", get(health_handler));
    // Legacy Screenpipe v1 surface, kept for backwards compatibility.
    api.route("/search", get(legacy_search_handler))
        .layer(permissive_cors)
        .with_state(state)
}
// Imported here rather than at the very bottom of the file (where this
// `use` previously sat, after its only use sites) so the error mapping
// below reads naturally.
use crate::error::OsPipeError;

/// Start the OSpipe HTTP server on the given port.
///
/// Binds to `0.0.0.0:<port>` and blocks until the server is shut down
/// (e.g. via Ctrl-C).
///
/// # Errors
///
/// Returns an error if the TCP listener cannot bind to the requested
/// port, or if axum's serve loop fails.
pub async fn start_server(state: ServerState, port: u16) -> crate::error::Result<()> {
    let app = build_router(state);
    let addr = format!("0.0.0.0:{}", port);
    let listener = tokio::net::TcpListener::bind(&addr)
        .await
        .map_err(|e| OsPipeError::Pipeline(format!("Failed to bind to {}: {}", addr, e)))?;
    tracing::info!("OSpipe server listening on {}", addr);
    axum::serve(listener, app)
        .await
        .map_err(|e| OsPipeError::Pipeline(format!("Server error: {}", e)))?;
    Ok(())
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::OsPipeConfig;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt; // for oneshot

    // Build a ServerState backed by a fresh, empty in-memory pipeline.
    fn test_state() -> ServerState {
        let config = OsPipeConfig::default();
        let pipeline = IngestionPipeline::new(config).unwrap();
        ServerState {
            pipeline: Arc::new(RwLock::new(pipeline)),
            router: Arc::new(QueryRouter::new()),
            started_at: std::time::Instant::now(),
        }
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let state = test_state();
        let app = build_router(state);
        let req = Request::builder()
            .uri("/v2/health")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), 1024 * 1024)
            .await
            .unwrap();
        let health: HealthResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(health.status, "ok");
        assert_eq!(health.version, env!("CARGO_PKG_VERSION"));
        assert!(!health.backends.is_empty());
    }

    // A fresh pipeline reports all-zero counters and an empty index.
    #[tokio::test]
    async fn test_stats_endpoint() {
        let state = test_state();
        let app = build_router(state);
        let req = Request::builder()
            .uri("/v2/stats")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), 1024 * 1024)
            .await
            .unwrap();
        let stats: StatsResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(stats.total_ingested, 0);
        assert_eq!(stats.index_size, 0);
    }

    // "yesterday" should be classified as a temporal query.
    #[tokio::test]
    async fn test_route_endpoint() {
        let state = test_state();
        let app = build_router(state);
        let req = Request::builder()
            .method("POST")
            .uri("/v2/route")
            .header("content-type", "application/json")
            .body(Body::from(r#"{"query": "what happened yesterday"}"#))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), 1024 * 1024)
            .await
            .unwrap();
        let route: RouteResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(route.route, "temporal");
    }

    // Searching an empty store succeeds with an empty result list.
    #[tokio::test]
    async fn test_search_endpoint_empty_store() {
        let state = test_state();
        let app = build_router(state);
        let req = Request::builder()
            .method("POST")
            .uri("/v2/search")
            .header("content-type", "application/json")
            .body(Body::from(
                r#"{"query": "test", "mode": "semantic", "k": 5}"#,
            ))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), 1024 * 1024)
            .await
            .unwrap();
        let results: Vec<ApiSearchResult> = serde_json::from_slice(&body).unwrap();
        assert!(results.is_empty());
    }

    // The legacy endpoint requires a non-empty `q` parameter.
    #[tokio::test]
    async fn test_legacy_search_missing_q() {
        let state = test_state();
        let app = build_router(state);
        let req = Request::builder()
            .uri("/search")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    // End-to-end: ingest one frame, then find it through /v2/search.
    #[tokio::test]
    async fn test_search_with_ingested_data() {
        let state = test_state();
        // Ingest a frame so there is data to search
        {
            let mut pipeline = state.pipeline.write().await;
            let frame = crate::capture::CapturedFrame::new_screen(
                "VSCode",
                "main.rs",
                "fn main() { println!(\"hello\"); }",
                0,
            );
            pipeline.ingest(frame).unwrap();
        }
        let app = build_router(state);
        let req = Request::builder()
            .method("POST")
            .uri("/v2/search")
            .header("content-type", "application/json")
            .body(Body::from(r#"{"query": "fn main", "k": 5}"#))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), 1024 * 1024)
            .await
            .unwrap();
        let results: Vec<ApiSearchResult> = serde_json::from_slice(&body).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("fn main"));
        assert_eq!(results[0].source, "screen");
    }
}

View File

@@ -0,0 +1,163 @@
//! Embedding generation engine.
//!
//! This module provides a deterministic hash-based embedding engine for
//! development and testing. In production, this would be replaced with
//! a real model (ONNX, Candle, or an API-based provider via ruvector-core's
//! EmbeddingProvider trait).
//!
//! `EmbeddingEngine` also implements [`EmbeddingModel`]
//! so it can be used anywhere a trait-based embedding source is required.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use super::traits::EmbeddingModel;
/// Engine that generates vector embeddings from text.
///
/// The current implementation uses a deterministic hash-based approach
/// that produces consistent embeddings for the same input text. This is
/// suitable for testing deduplication and search mechanics, but does NOT
/// provide semantic similarity. For semantic search, integrate a real
/// embedding model.
pub struct EmbeddingEngine {
    // Output vector length; every embed() call returns this many components.
    dimension: usize,
}

impl EmbeddingEngine {
    /// Create a new embedding engine with the given vector dimension.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }

    /// Generate an embedding vector for the given text.
    ///
    /// The resulting vector is L2-normalized so that cosine similarity
    /// can be computed as a simple dot product.
    pub fn embed(&self, text: &str) -> Vec<f32> {
        // Each component is derived from a hash of (position, text),
        // giving a deterministic pseudo-random value per position.
        let mut vector: Vec<f32> = (0..self.dimension)
            .map(|position| {
                let mut hasher = DefaultHasher::new();
                position.hash(&mut hasher);
                text.hash(&mut hasher);
                // Map the 64-bit hash into [-1, 1].
                ((hasher.finish() as f64 / u64::MAX as f64) * 2.0 - 1.0) as f32
            })
            .collect();
        // L2-normalize in place; a zero-magnitude vector is left unchanged.
        let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > f32::EPSILON {
            for component in vector.iter_mut() {
                *component /= magnitude;
            }
        }
        vector
    }

    /// Generate embeddings for a batch of texts.
    pub fn batch_embed(&self, texts: &[&str]) -> Vec<Vec<f32>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Return the dimensionality of embeddings produced by this engine.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
}
/// `EmbeddingEngine` satisfies [`EmbeddingModel`] so existing code can
/// pass an `&EmbeddingEngine` wherever a `&dyn EmbeddingModel` is needed.
impl EmbeddingModel for EmbeddingEngine {
    // Pure delegation to the inherent methods; the fully-qualified form
    // makes it unambiguous that the inherent method (not this trait
    // method) is being called.
    fn embed(&self, text: &str) -> Vec<f32> {
        EmbeddingEngine::embed(self, text)
    }
    fn batch_embed(&self, texts: &[&str]) -> Vec<Vec<f32>> {
        EmbeddingEngine::batch_embed(self, texts)
    }
    fn dimension(&self) -> usize {
        self.dimension
    }
}
/// L2-normalize a vector in place. If the vector has (near-)zero
/// magnitude it is left unchanged, avoiding a division by zero.
pub fn normalize(vector: &mut [f32]) {
    let norm = vector.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if norm > f32::EPSILON {
        vector.iter_mut().for_each(|component| *component /= norm);
    }
}
/// Compute cosine similarity between two L2-normalized vectors.
///
/// For normalized vectors, cosine similarity equals the dot product.
/// Returns a value in [-1.0, 1.0].
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have equal dimensions");
    let mut dot = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
    }
    dot
}
#[cfg(test)]
mod tests {
    use super::*;

    // Identical input must always produce the identical vector.
    #[test]
    fn test_embedding_determinism() {
        let engine = EmbeddingEngine::new(384);
        let v1 = engine.embed("hello world");
        let v2 = engine.embed("hello world");
        assert_eq!(v1, v2);
    }

    #[test]
    fn test_embedding_dimension() {
        let engine = EmbeddingEngine::new(128);
        let v = engine.embed("test");
        assert_eq!(v.len(), 128);
    }

    // Output must be L2-normalized (unit magnitude).
    #[test]
    fn test_embedding_normalized() {
        let engine = EmbeddingEngine::new(384);
        let v = engine.embed("test normalization");
        let magnitude: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (magnitude - 1.0).abs() < 1e-5,
            "Expected unit vector, got magnitude {}",
            magnitude
        );
    }

    // Self-similarity of a unit vector is 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let engine = EmbeddingEngine::new(384);
        let v = engine.embed("same text");
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_different() {
        let engine = EmbeddingEngine::new(384);
        let v1 = engine.embed("hello world");
        let v2 = engine.embed("completely different text about cats");
        let sim = cosine_similarity(&v1, &v2);
        // Hash-based embeddings won't give semantic similarity,
        // but different texts should generally not be identical.
        assert!(sim < 1.0);
    }

    #[test]
    fn test_batch_embed() {
        let engine = EmbeddingEngine::new(64);
        let texts = vec!["one", "two", "three"];
        let embeddings = engine.batch_embed(&texts);
        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 64);
        }
    }
}

View File

@@ -0,0 +1,18 @@
//! Vector storage, embedding engine, and trait abstractions.
//!
//! Provides HNSW-backed vector storage for captured frames with
//! cosine similarity search, metadata filtering, delete/update operations,
//! and a pluggable embedding model trait.

// Submodule declarations.
pub mod embedding;
pub mod traits;
pub mod vector_store;

// Cross-platform re-exports.
pub use embedding::EmbeddingEngine;
pub use traits::{EmbeddingModel, HashEmbeddingModel};
pub use vector_store::{SearchFilter, SearchResult, StoredEmbedding, VectorStore};

// Native-only re-exports: these wrap ruvector-core, which is not
// available on WASM targets.
#[cfg(not(target_arch = "wasm32"))]
pub use traits::RuvectorEmbeddingModel;
#[cfg(not(target_arch = "wasm32"))]
pub use vector_store::HnswVectorStore;

View File

@@ -0,0 +1,203 @@
//! Embedding model trait abstraction.
//!
//! Defines the [`EmbeddingModel`] trait that all embedding providers must
//! implement, enabling pluggable embedding backends. Two implementations are
//! provided out of the box:
//!
//! - [`HashEmbeddingModel`] - deterministic hash-based embeddings (no semantic
//! similarity, suitable for testing).
//! - [`RuvectorEmbeddingModel`] (native only) - wraps ruvector-core's
//! [`EmbeddingProvider`](ruvector_core::embeddings::EmbeddingProvider) for
//! real embedding backends (hash, candle, API-based).
/// Trait for generating vector embeddings from text.
///
/// Implementations must be `Send + Sync` so they can be shared across
/// threads.
pub trait EmbeddingModel: Send + Sync {
    /// Generate an embedding vector for the given text.
    fn embed(&self, text: &str) -> Vec<f32>;

    /// Generate embeddings for a batch of texts.
    ///
    /// The default implementation calls [`embed`](Self::embed) once per
    /// text, sequentially. Implementations may override this to perform
    /// batched inference.
    fn batch_embed(&self, texts: &[&str]) -> Vec<Vec<f32>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed(text));
        }
        embeddings
    }

    /// Return the dimensionality of embeddings produced by this model.
    fn dimension(&self) -> usize;
}
// ---------------------------------------------------------------------------
// HashEmbeddingModel (cross-platform, always available)
// ---------------------------------------------------------------------------
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use super::embedding::normalize;
/// Hash-based embedding model for testing and development.
///
/// Produces deterministic, L2-normalized vectors from text using
/// `DefaultHasher`. The vectors have no semantic meaning -- identical
/// inputs produce identical outputs, but semantically similar inputs
/// are *not* guaranteed to be close in vector space.
pub struct HashEmbeddingModel {
    // Length of every vector produced by `embed`.
    dimension: usize,
}

impl HashEmbeddingModel {
    /// Create a new hash-based embedding model with the given dimension.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}

impl EmbeddingModel for HashEmbeddingModel {
    fn embed(&self, text: &str) -> Vec<f32> {
        // Derive each component from a hash of (position, text) so the
        // output is deterministic per input.
        let mut vector: Vec<f32> = (0..self.dimension)
            .map(|position| {
                let mut hasher = DefaultHasher::new();
                position.hash(&mut hasher);
                text.hash(&mut hasher);
                // Map the 64-bit hash into [-1, 1].
                ((hasher.finish() as f64 / u64::MAX as f64) * 2.0 - 1.0) as f32
            })
            .collect();
        normalize(&mut vector);
        vector
    }

    fn dimension(&self) -> usize {
        self.dimension
    }
}
// ---------------------------------------------------------------------------
// RuvectorEmbeddingModel (native only -- wraps ruvector-core)
// ---------------------------------------------------------------------------
#[cfg(not(target_arch = "wasm32"))]
mod native {
    use super::EmbeddingModel;
    use crate::storage::embedding::normalize;
    use ruvector_core::embeddings::EmbeddingProvider;
    use std::sync::Arc;

    /// Embedding model backed by a ruvector-core [`EmbeddingProvider`].
    ///
    /// This wraps any `EmbeddingProvider` (e.g. `HashEmbedding`,
    /// `CandleEmbedding`, `ApiEmbedding`) behind the OSpipe
    /// [`EmbeddingModel`] trait, making the provider swappable at
    /// construction time.
    pub struct RuvectorEmbeddingModel {
        provider: Arc<dyn EmbeddingProvider>,
    }

    impl RuvectorEmbeddingModel {
        /// Create a new model wrapping the given provider.
        pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
            Self { provider }
        }

        /// Create a model using ruvector-core's `HashEmbedding` with the
        /// given dimension. This is the simplest way to get started on
        /// native targets.
        pub fn hash(dimensions: usize) -> Self {
            Self::new(Arc::new(ruvector_core::embeddings::HashEmbedding::new(
                dimensions,
            )))
        }
    }

    impl EmbeddingModel for RuvectorEmbeddingModel {
        fn embed(&self, text: &str) -> Vec<f32> {
            // Normalize successful embeddings; on provider failure, log a
            // warning and degrade to a zero vector so callers never panic.
            self.provider
                .embed(text)
                .map(|mut v| {
                    normalize(&mut v);
                    v
                })
                .unwrap_or_else(|e| {
                    tracing::warn!("Embedding provider failed, returning zero vector: {}", e);
                    vec![0.0f32; self.provider.dimensions()]
                })
        }

        fn dimension(&self) -> usize {
            self.provider.dimensions()
        }
    }
}
#[cfg(not(target_arch = "wasm32"))]
pub use native::RuvectorEmbeddingModel;
#[cfg(test)]
mod tests {
    use super::*;

    // Identical input must always produce the identical vector.
    #[test]
    fn test_hash_embedding_model_determinism() {
        let model = HashEmbeddingModel::new(128);
        let v1 = model.embed("hello world");
        let v2 = model.embed("hello world");
        assert_eq!(v1, v2);
    }

    #[test]
    fn test_hash_embedding_model_dimension() {
        let model = HashEmbeddingModel::new(64);
        assert_eq!(model.dimension(), 64);
        let v = model.embed("test");
        assert_eq!(v.len(), 64);
    }

    // Output must be L2-normalized (unit magnitude).
    #[test]
    fn test_hash_embedding_model_normalized() {
        let model = HashEmbeddingModel::new(384);
        let v = model.embed("normalization test");
        let mag: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag - 1.0).abs() < 1e-5,
            "Expected unit vector, got magnitude {}",
            mag,
        );
    }

    // The trait's default batch_embed delegates per element.
    #[test]
    fn test_batch_embed() {
        let model = HashEmbeddingModel::new(64);
        let texts: Vec<&str> = vec!["one", "two", "three"];
        let embeddings = model.batch_embed(&texts);
        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 64);
        }
    }

    // The trait is object-safe and usable through a Box<dyn ...>.
    #[test]
    fn test_trait_object_dispatch() {
        let model: Box<dyn EmbeddingModel> = Box::new(HashEmbeddingModel::new(32));
        let v = model.embed("dispatch test");
        assert_eq!(v.len(), 32);
    }

    // Native-only: the ruvector-core wrapper also yields normalized vectors.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_ruvector_embedding_model() {
        let model = RuvectorEmbeddingModel::hash(128);
        let v = model.embed("ruvector test");
        assert_eq!(v.len(), 128);
        assert_eq!(model.dimension(), 128);
        // Should be normalized
        let mag: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag - 1.0).abs() < 1e-4,
            "Expected unit vector, got magnitude {}",
            mag,
        );
    }
}

View File

@@ -0,0 +1,541 @@
//! Vector storage with cosine similarity search.
//!
//! This module provides two implementations:
//!
//! - [`VectorStore`] -- brute-force O(n) linear scan (cross-platform,
//! works on WASM).
//! - [`HnswVectorStore`] (native only) -- wraps ruvector-core's HNSW
//! index for O(log n) approximate nearest-neighbor search.
//!
//! Both implementations support insert, search, filtered search, delete,
//! and metadata update.
use crate::capture::CapturedFrame;
use crate::config::StorageConfig;
use crate::error::{OsPipeError, Result};
use crate::storage::embedding::cosine_similarity;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A vector embedding stored with its metadata.
///
/// One entry is created per captured frame; `id` links the embedding
/// back to the frame that produced it.
#[derive(Debug, Clone)]
pub struct StoredEmbedding {
    /// Unique identifier matching the source frame.
    pub id: Uuid,
    /// The embedding vector.
    pub vector: Vec<f32>,
    /// JSON metadata about the source frame.
    pub metadata: serde_json::Value,
    /// When the source frame was captured.
    pub timestamp: DateTime<Utc>,
}
/// A search result returned from the vector store.
///
/// Serializable so it can be returned directly from API handlers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// ID of the matched embedding.
    pub id: Uuid,
    /// Cosine similarity score (higher is more similar).
    pub score: f32,
    /// Metadata of the matched embedding.
    pub metadata: serde_json::Value,
}
/// Filter criteria for narrowing search results.
///
/// All fields are optional; a `None` criterion places no constraint.
/// `Default` produces an empty filter that matches everything.
#[derive(Debug, Clone, Default)]
pub struct SearchFilter {
    /// Filter by application name.
    pub app: Option<String>,
    /// Filter by start time (inclusive).
    pub time_start: Option<DateTime<Utc>>,
    /// Filter by end time (inclusive).
    pub time_end: Option<DateTime<Utc>>,
    /// Filter by content type (e.g., "ocr", "transcription", "ui_event").
    pub content_type: Option<String>,
    /// Filter by monitor index.
    pub monitor: Option<u32>,
}
// ===========================================================================
// VectorStore -- brute-force fallback (cross-platform)
// ===========================================================================
/// In-memory vector store with brute-force cosine similarity search.
///
/// This is the cross-platform fallback that also works on WASM targets.
/// On native targets, prefer [`HnswVectorStore`] for large datasets.
pub struct VectorStore {
    // Kept for introspection via `config()`; `dimension` below is a cached
    // copy of `config.embedding_dim`.
    config: StorageConfig,
    // Flat list scanned linearly on every search (O(n) per query).
    embeddings: Vec<StoredEmbedding>,
    dimension: usize,
}
impl VectorStore {
    /// Create a new vector store with the given configuration.
    ///
    /// # Errors
    ///
    /// Returns [`OsPipeError::Storage`] if `config.embedding_dim` is zero.
    pub fn new(config: StorageConfig) -> Result<Self> {
        let dimension = config.embedding_dim;
        if dimension == 0 {
            return Err(OsPipeError::Storage(
                "embedding_dim must be greater than 0".to_string(),
            ));
        }
        Ok(Self {
            config,
            embeddings: Vec::new(),
            dimension,
        })
    }

    /// Insert a captured frame with its pre-computed embedding.
    ///
    /// # Errors
    ///
    /// Returns [`OsPipeError::Storage`] if `embedding.len()` does not match
    /// the store's configured dimension.
    pub fn insert(&mut self, frame: &CapturedFrame, embedding: &[f32]) -> Result<()> {
        if embedding.len() != self.dimension {
            return Err(OsPipeError::Storage(format!(
                "Expected embedding dimension {}, got {}",
                self.dimension,
                embedding.len()
            )));
        }
        let metadata = serde_json::json!({
            "text": frame.text_content(),
            "content_type": frame.content_type(),
            "app_name": frame.metadata.app_name,
            "window_title": frame.metadata.window_title,
            "monitor_id": frame.metadata.monitor_id,
            "confidence": frame.metadata.confidence,
        });
        self.embeddings.push(StoredEmbedding {
            id: frame.id,
            vector: embedding.to_vec(),
            metadata,
            timestamp: frame.timestamp,
        });
        Ok(())
    }

    /// Validate that a query vector matches the store's dimension.
    fn check_query_dimension(&self, query: &[f32]) -> Result<()> {
        if query.len() != self.dimension {
            return Err(OsPipeError::Search(format!(
                "Expected query dimension {}, got {}",
                self.dimension,
                query.len()
            )));
        }
        Ok(())
    }

    /// Score every embedding accepted by `keep` against `query`, sort by
    /// descending cosine similarity, and materialize the top `k` results.
    ///
    /// Shared by [`Self::search`] and [`Self::search_filtered`] so the
    /// ranking pipeline exists in exactly one place.
    fn top_k<F>(&self, query: &[f32], k: usize, mut keep: F) -> Vec<SearchResult>
    where
        F: FnMut(&StoredEmbedding) -> bool,
    {
        let mut scored: Vec<(usize, f32)> = self
            .embeddings
            .iter()
            .enumerate()
            .filter(|(_, stored)| keep(stored))
            .map(|(i, stored)| (i, cosine_similarity(query, &stored.vector)))
            .collect();
        // Sort by score descending. NaN scores compare as Equal so they
        // cannot panic or poison the ordering.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(k);
        scored
            .into_iter()
            .map(|(i, score)| {
                let stored = &self.embeddings[i];
                SearchResult {
                    id: stored.id,
                    score,
                    metadata: stored.metadata.clone(),
                }
            })
            .collect()
    }

    /// Search for the k most similar embeddings to the query vector.
    ///
    /// # Errors
    ///
    /// Returns [`OsPipeError::Search`] on a dimension mismatch.
    pub fn search(&self, query_embedding: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        self.check_query_dimension(query_embedding)?;
        Ok(self.top_k(query_embedding, k, |_| true))
    }

    /// Search with metadata filtering applied before scoring.
    ///
    /// # Errors
    ///
    /// Returns [`OsPipeError::Search`] on a dimension mismatch.
    pub fn search_filtered(
        &self,
        query: &[f32],
        k: usize,
        filter: &SearchFilter,
    ) -> Result<Vec<SearchResult>> {
        self.check_query_dimension(query)?;
        Ok(self.top_k(query, k, |stored| matches_filter(stored, filter)))
    }

    /// Delete a stored embedding by its ID.
    ///
    /// Returns `true` if the embedding was found and removed, `false`
    /// if no embedding with the given ID existed.
    pub fn delete(&mut self, id: &Uuid) -> Result<bool> {
        let before = self.embeddings.len();
        self.embeddings.retain(|e| e.id != *id);
        Ok(self.embeddings.len() < before)
    }

    /// Update the metadata of a stored embedding.
    ///
    /// The provided `metadata` value completely replaces the old metadata
    /// for the entry identified by `id`.
    ///
    /// # Errors
    ///
    /// Returns [`OsPipeError::Storage`] if the ID is not found.
    pub fn update_metadata(&mut self, id: &Uuid, metadata: serde_json::Value) -> Result<()> {
        match self.embeddings.iter_mut().find(|e| e.id == *id) {
            Some(entry) => {
                entry.metadata = metadata;
                Ok(())
            }
            None => Err(OsPipeError::Storage(format!(
                "No embedding found with id {}",
                id
            ))),
        }
    }

    /// Return the number of stored embeddings.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Return true if the store contains no embeddings.
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }

    /// Return the configured embedding dimension.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Return a reference to the storage configuration.
    pub fn config(&self) -> &StorageConfig {
        &self.config
    }

    /// Get a stored embedding by its ID (linear scan).
    pub fn get(&self, id: &Uuid) -> Option<&StoredEmbedding> {
        self.embeddings.iter().find(|e| e.id == *id)
    }
}
// ===========================================================================
// HnswVectorStore -- native-only HNSW-backed store
// ===========================================================================
#[cfg(not(target_arch = "wasm32"))]
mod native {
    use super::*;
    use ruvector_core::index::hnsw::HnswIndex;
    use ruvector_core::index::VectorIndex;
    use ruvector_core::types::{DistanceMetric, HnswConfig};
    use std::collections::HashMap;
    /// HNSW-backed vector store using ruvector-core.
    ///
    /// Uses approximate nearest-neighbor search for O(log n) query time.
    /// Metadata and timestamps are stored in a side-car `HashMap`
    /// alongside the HNSW index.
    pub struct HnswVectorStore {
        /// The ANN graph; keys are stringified [`Uuid`]s.
        index: HnswIndex,
        /// Side-car storage: id -> (metadata, timestamp, vector)
        entries: HashMap<Uuid, StoredEmbedding>,
        /// Embedding dimensionality enforced on insert and search.
        dimension: usize,
        config: StorageConfig,
        /// Query-time `ef` parameter (recall/latency trade-off).
        ef_search: usize,
    }
    impl HnswVectorStore {
        /// Create a new HNSW-backed vector store.
        ///
        /// # Errors
        ///
        /// Returns [`OsPipeError::Storage`] if `config.embedding_dim` is
        /// zero or the underlying index cannot be constructed.
        pub fn new(config: StorageConfig) -> Result<Self> {
            let dimension = config.embedding_dim;
            if dimension == 0 {
                return Err(OsPipeError::Storage(
                    "embedding_dim must be greater than 0".to_string(),
                ));
            }
            let hnsw_config = HnswConfig {
                m: config.hnsw_m,
                ef_construction: config.hnsw_ef_construction,
                ef_search: config.hnsw_ef_search,
                max_elements: 10_000_000,
            };
            let index = HnswIndex::new(dimension, DistanceMetric::Cosine, hnsw_config)
                .map_err(|e| OsPipeError::Storage(format!("Failed to create HNSW index: {}", e)))?;
            let ef_search = config.hnsw_ef_search;
            Ok(Self {
                index,
                entries: HashMap::new(),
                dimension,
                config,
                ef_search,
            })
        }
        /// Insert a captured frame with its pre-computed embedding.
        ///
        /// # Errors
        ///
        /// Returns [`OsPipeError::Storage`] on a dimension mismatch or if
        /// the HNSW insertion fails.
        pub fn insert(&mut self, frame: &CapturedFrame, embedding: &[f32]) -> Result<()> {
            if embedding.len() != self.dimension {
                return Err(OsPipeError::Storage(format!(
                    "Expected embedding dimension {}, got {}",
                    self.dimension,
                    embedding.len()
                )));
            }
            let metadata = serde_json::json!({
                "text": frame.text_content(),
                "content_type": frame.content_type(),
                "app_name": frame.metadata.app_name,
                "window_title": frame.metadata.window_title,
                "monitor_id": frame.metadata.monitor_id,
                "confidence": frame.metadata.confidence,
            });
            let id_str = frame.id.to_string();
            // Insert into HNSW index first; only record side-car data once
            // the index accepted the vector.
            self.index
                .add(id_str, embedding.to_vec())
                .map_err(|e| OsPipeError::Storage(format!("HNSW insert failed: {}", e)))?;
            // Store side-car data
            self.entries.insert(
                frame.id,
                StoredEmbedding {
                    id: frame.id,
                    vector: embedding.to_vec(),
                    metadata,
                    timestamp: frame.timestamp,
                },
            );
            Ok(())
        }
        /// Search for the k most similar embeddings using HNSW ANN search.
        ///
        /// Results are returned sorted by descending similarity.
        ///
        /// # Errors
        ///
        /// Returns [`OsPipeError::Search`] on a dimension mismatch or if
        /// the underlying index search fails.
        pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
            if query.len() != self.dimension {
                return Err(OsPipeError::Search(format!(
                    "Expected query dimension {}, got {}",
                    self.dimension,
                    query.len()
                )));
            }
            let hnsw_results = self
                .index
                .search_with_ef(query, k, self.ef_search)
                .map_err(|e| OsPipeError::Search(format!("HNSW search failed: {}", e)))?;
            let mut results = Vec::with_capacity(hnsw_results.len());
            for hr in hnsw_results {
                // hr.id is a String representation of the Uuid; entries that
                // fail to parse or have no side-car record are silently
                // skipped (e.g. soft-deleted ids still present in the graph).
                if let Ok(uuid) = Uuid::parse_str(&hr.id) {
                    if let Some(stored) = self.entries.get(&uuid) {
                        // ruvector-core HNSW returns distance (lower = closer
                        // for cosine). Convert to similarity: 1.0 - distance.
                        let similarity = 1.0 - hr.score;
                        results.push(SearchResult {
                            id: uuid,
                            score: similarity,
                            metadata: stored.metadata.clone(),
                        });
                    }
                }
            }
            // Sort descending by similarity score. NaN scores compare as
            // Equal so they cannot poison the ordering.
            results.sort_by(|a, b| {
                b.score
                    .partial_cmp(&a.score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            Ok(results)
        }
        /// Search with post-filtering on metadata.
        ///
        /// HNSW does not natively support metadata filters, so we
        /// over-fetch and filter after the ANN search. If the filter is
        /// very selective, fewer than `k` results may be returned.
        ///
        /// # Errors
        ///
        /// Propagates any error from [`Self::search`].
        pub fn search_filtered(
            &self,
            query: &[f32],
            k: usize,
            filter: &SearchFilter,
        ) -> Result<Vec<SearchResult>> {
            // Over-fetch to account for candidates removed by the filter.
            let over_k = (k * 4).max(k + 20);
            let candidates = self.search(query, over_k)?;
            // `search` already returns results sorted by descending
            // similarity, and `filter`/`take` preserve that order, so no
            // re-sort is needed here.
            Ok(candidates
                .into_iter()
                .filter(|r| {
                    self.entries
                        .get(&r.id)
                        .map_or(false, |stored| matches_filter(stored, filter))
                })
                .take(k)
                .collect())
        }
        /// Delete a stored embedding by its ID.
        ///
        /// Returns `true` if the embedding was found and removed, `false`
        /// otherwise. The HNSW graph link is removed via soft-delete (the
        /// underlying `hnsw_rs` does not support hard deletion).
        pub fn delete(&mut self, id: &Uuid) -> Result<bool> {
            let id_str = id.to_string();
            let removed_from_index = self
                .index
                .remove(&id_str)
                .map_err(|e| OsPipeError::Storage(format!("HNSW delete failed: {}", e)))?;
            let removed_from_entries = self.entries.remove(id).is_some();
            Ok(removed_from_index || removed_from_entries)
        }
        /// Update the metadata of a stored embedding.
        ///
        /// # Errors
        ///
        /// Returns [`OsPipeError::Storage`] if no embedding with the given
        /// ID exists.
        pub fn update_metadata(&mut self, id: &Uuid, metadata: serde_json::Value) -> Result<()> {
            match self.entries.get_mut(id) {
                Some(entry) => {
                    entry.metadata = metadata;
                    Ok(())
                }
                None => Err(OsPipeError::Storage(format!(
                    "No embedding found with id {}",
                    id
                ))),
            }
        }
        /// Return the number of stored embeddings.
        pub fn len(&self) -> usize {
            self.entries.len()
        }
        /// Return true if the store is empty.
        pub fn is_empty(&self) -> bool {
            self.entries.is_empty()
        }
        /// Return the configured embedding dimension.
        pub fn dimension(&self) -> usize {
            self.dimension
        }
        /// Return a reference to the storage configuration.
        pub fn config(&self) -> &StorageConfig {
            &self.config
        }
        /// Get a stored embedding by its ID.
        pub fn get(&self, id: &Uuid) -> Option<&StoredEmbedding> {
            self.entries.get(id)
        }
    }
}
#[cfg(not(target_arch = "wasm32"))]
pub use native::HnswVectorStore;
// ===========================================================================
// Shared helpers
// ===========================================================================
/// Check whether a stored embedding matches the given filter.
///
/// Every criterion that is `Some` must match; `None` criteria impose no
/// constraint. Missing string metadata fields compare as the empty
/// string; a missing `monitor_id` never matches a monitor criterion.
fn matches_filter(stored: &StoredEmbedding, filter: &SearchFilter) -> bool {
    // Pull a string field out of the metadata JSON, defaulting to "".
    let str_field = |key: &str| -> &str {
        stored
            .metadata
            .get(key)
            .and_then(|v| v.as_str())
            .unwrap_or("")
    };
    let app_ok = filter
        .app
        .as_deref()
        .map_or(true, |app| str_field("app_name") == app);
    let start_ok = filter
        .time_start
        .map_or(true, |start| stored.timestamp >= start);
    let end_ok = filter.time_end.map_or(true, |end| stored.timestamp <= end);
    let type_ok = filter
        .content_type
        .as_deref()
        .map_or(true, |ct| str_field("content_type") == ct);
    let monitor_ok = filter.monitor.map_or(true, |monitor| {
        stored
            .metadata
            .get("monitor_id")
            .and_then(|v| v.as_u64())
            .map(|v| v as u32)
            == Some(monitor)
    });
    app_ok && start_ok && end_ok && type_ok && monitor_ok
}

View File

@@ -0,0 +1,265 @@
//! WASM-bindgen exports for OSpipe browser usage.
//!
//! This module exposes a self-contained vector store that runs entirely in the
//! browser via WebAssembly. It supports embedding insertion, semantic search
//! with optional time-range filtering, deduplication checks, simple text
//! embedding (hash-based, suitable for demos), content safety checks, and
//! query routing heuristics.
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
use super::helpers;
/// Initialize WASM module: installs `console_error_panic_hook` so that Rust
/// panics produce readable error messages in the browser developer console
/// instead of the default `unreachable` with no context.
#[wasm_bindgen(start)]
pub fn init() {
    // No-op unless the `console_error_panic_hook` feature is enabled at
    // build time; `set_once` makes repeated calls harmless.
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
// ---------------------------------------------------------------------------
// Internal data structures
// ---------------------------------------------------------------------------
/// A single stored embedding with metadata.
///
/// Internal representation only; never crosses the WASM boundary directly.
struct WasmEmbedding {
    // Caller-supplied unique identifier for the frame.
    id: String,
    // Embedding vector; length equals the store's configured dimension.
    vector: Vec<f32>,
    metadata: String, // JSON string
    timestamp: f64,   // Unix milliseconds
}
/// A search result returned to JavaScript.
///
/// Serialized via `serde_wasm_bindgen` into a plain JS object of shape
/// `{ id, score, metadata, timestamp }`.
#[derive(Serialize, Deserialize)]
struct SearchHit {
    // ID of the matched embedding.
    id: String,
    // Cosine similarity (widened to f64 for JS number compatibility).
    score: f64,
    // Metadata JSON string as supplied at insert time.
    metadata: String,
    // Unix milliseconds timestamp as supplied at insert time.
    timestamp: f64,
}
// ---------------------------------------------------------------------------
// Public WASM API
// ---------------------------------------------------------------------------
/// OSpipe WASM -- browser-based personal AI memory search.
///
/// Self-contained in-memory store; all search is brute-force linear scan.
#[wasm_bindgen]
pub struct OsPipeWasm {
    // Required length of every inserted / queried embedding.
    dimension: usize,
    // Flat list of stored embeddings, scanned on each search.
    embeddings: Vec<WasmEmbedding>,
}
#[wasm_bindgen]
impl OsPipeWasm {
// -- lifecycle ---------------------------------------------------------
/// Create a new OsPipeWasm instance with the given embedding dimension.
#[wasm_bindgen(constructor)]
pub fn new(dimension: usize) -> Self {
Self {
dimension,
embeddings: Vec::new(),
}
}
// -- insertion ---------------------------------------------------------
/// Insert a frame embedding into the store.
///
/// * `id` - Unique identifier for this frame.
/// * `embedding` - Float32 vector whose length must match `dimension`.
/// * `metadata` - Arbitrary JSON string attached to this frame.
/// * `timestamp` - Unix timestamp in milliseconds.
pub fn insert(
&mut self,
id: &str,
embedding: &[f32],
metadata: &str,
timestamp: f64,
) -> Result<(), JsValue> {
if embedding.len() != self.dimension {
return Err(JsValue::from_str(&format!(
"Embedding dimension mismatch: expected {}, got {}",
self.dimension,
embedding.len()
)));
}
self.embeddings.push(WasmEmbedding {
id: id.to_string(),
vector: embedding.to_vec(),
metadata: metadata.to_string(),
timestamp,
});
Ok(())
}
// -- search ------------------------------------------------------------
/// Semantic search by embedding vector. Returns the top-k results as a
/// JSON-serialized `JsValue` array of `{ id, score, metadata, timestamp }`.
pub fn search(&self, query_embedding: &[f32], k: usize) -> Result<JsValue, JsValue> {
if query_embedding.len() != self.dimension {
return Err(JsValue::from_str(&format!(
"Query dimension mismatch: expected {}, got {}",
self.dimension,
query_embedding.len()
)));
}
let mut scored: Vec<(usize, f32)> = self
.embeddings
.iter()
.enumerate()
.map(|(i, e)| (i, helpers::cosine_similarity(query_embedding, &e.vector)))
.collect();
// Sort descending by similarity.
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let hits: Vec<SearchHit> = scored
.into_iter()
.take(k)
.map(|(i, score)| {
let e = &self.embeddings[i];
SearchHit {
id: e.id.clone(),
score: score as f64,
metadata: e.metadata.clone(),
timestamp: e.timestamp,
}
})
.collect();
serde_wasm_bindgen::to_value(&hits).map_err(|e| JsValue::from_str(&e.to_string()))
}
/// Search with a time-range filter. Only embeddings whose timestamp falls
/// within `[start_time, end_time]` (inclusive) are considered.
pub fn search_filtered(
&self,
query_embedding: &[f32],
k: usize,
start_time: f64,
end_time: f64,
) -> Result<JsValue, JsValue> {
if query_embedding.len() != self.dimension {
return Err(JsValue::from_str(&format!(
"Query dimension mismatch: expected {}, got {}",
self.dimension,
query_embedding.len()
)));
}
let mut scored: Vec<(usize, f32)> = self
.embeddings
.iter()
.enumerate()
.filter(|(_, e)| e.timestamp >= start_time && e.timestamp <= end_time)
.map(|(i, e)| (i, helpers::cosine_similarity(query_embedding, &e.vector)))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let hits: Vec<SearchHit> = scored
.into_iter()
.take(k)
.map(|(i, score)| {
let e = &self.embeddings[i];
SearchHit {
id: e.id.clone(),
score: score as f64,
metadata: e.metadata.clone(),
timestamp: e.timestamp,
}
})
.collect();
serde_wasm_bindgen::to_value(&hits).map_err(|e| JsValue::from_str(&e.to_string()))
}
// -- deduplication -----------------------------------------------------
/// Check whether `embedding` is a near-duplicate of any stored embedding.
///
/// Returns `true` when the cosine similarity to any existing embedding is
/// greater than or equal to `threshold`.
pub fn is_duplicate(&self, embedding: &[f32], threshold: f32) -> bool {
self.embeddings
.iter()
.any(|e| helpers::cosine_similarity(embedding, &e.vector) >= threshold)
}
// -- stats / accessors -------------------------------------------------
/// Number of stored embeddings.
pub fn len(&self) -> usize {
self.embeddings.len()
}
/// Returns true if no embeddings are stored.
pub fn is_empty(&self) -> bool {
self.embeddings.is_empty()
}
/// Return pipeline statistics as a JSON string.
pub fn stats(&self) -> String {
serde_json::json!({
"dimension": self.dimension,
"total_embeddings": self.embeddings.len(),
"memory_estimate_bytes": self.embeddings.len() * (self.dimension * 4 + 128),
})
.to_string()
}
// -- text embedding (demo / hash-based) --------------------------------
/// Generate a simple deterministic embedding from text.
///
/// This uses a hash-based approach and is **not** a real neural embedding.
/// Suitable for demos and testing only.
pub fn embed_text(&self, text: &str) -> Vec<f32> {
helpers::hash_embed(text, self.dimension)
}
/// Batch-embed multiple texts.
///
/// `texts` must be a JS `Array<string>`. Returns a JS `Array<Float32Array>`.
pub fn batch_embed(&self, texts: JsValue) -> Result<JsValue, JsValue> {
let text_list: Vec<String> = serde_wasm_bindgen::from_value(texts)
.map_err(|e| JsValue::from_str(&format!("Failed to deserialize texts: {e}")))?;
let results: Vec<Vec<f32>> = text_list
.iter()
.map(|t| helpers::hash_embed(t, self.dimension))
.collect();
serde_wasm_bindgen::to_value(&results).map_err(|e| JsValue::from_str(&e.to_string()))
}
// -- safety ------------------------------------------------------------
/// Run a lightweight safety check on `content`.
///
/// Returns one of:
/// - `"deny"` -- content contains patterns that should not be stored
/// (e.g. credit card numbers, SSNs).
/// - `"redact"` -- content contains potentially sensitive information
/// that could be redacted.
/// - `"allow"` -- content appears safe.
pub fn safety_check(&self, content: &str) -> String {
helpers::safety_classify(content).to_string()
}
// -- query routing -----------------------------------------------------
/// Route a query string to the optimal search backend based on simple
/// keyword heuristics.
///
/// Returns one of: `"Graph"`, `"Temporal"`, `"Keyword"`, `"Semantic"`.
pub fn route_query(&self, query: &str) -> String {
helpers::route_query(query).to_string()
}
}

View File

@@ -0,0 +1,461 @@
//! Pure helper functions used by the WASM bindings.
//!
//! These functions have no WASM dependencies and can be tested on any target.
/// Cosine similarity between two vectors.
///
/// Returns 0.0 when either vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must be same length");
    // Accumulate dot product and both squared magnitudes in one pass.
    let (dot, mag_a, mag_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, ma, mb), (&x, &y)| {
            (d + x * y, ma + x * x, mb + y * y)
        });
    let denom = mag_a.sqrt() * mag_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
/// Produce a deterministic pseudo-embedding from text using a simple hash.
///
/// The algorithm:
/// 1. Hash each character position into a seed.
/// 2. Use the seed to generate a float in [-1, 1].
/// 3. L2-normalise the resulting vector.
///
/// This is NOT a real embedding model -- it is only useful for demos and
/// testing that the WASM plumbing works end-to-end.
pub fn hash_embed(text: &str, dimension: usize) -> Vec<f32> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    let bytes = text.as_bytes();
    let mut out: Vec<f32> = (0..dimension)
        .map(|slot| {
            // FNV-1a-style mix of every byte, perturbed by slot index so
            // each dimension gets a distinct value.
            let h = bytes.iter().enumerate().fold(FNV_OFFSET_BASIS, |h, (j, &b)| {
                let mixed = (b as u64)
                    .wrapping_add((slot as u64).wrapping_mul(31))
                    .wrapping_add(j as u64);
                (h ^ mixed).wrapping_mul(FNV_PRIME)
            });
            // Map the 64-bit hash to [-1, 1].
            ((h as i64) as f64 / i64::MAX as f64) as f32
        })
        .collect();
    // L2 normalise (skip the all-zero case to avoid dividing by zero).
    let mag: f32 = out.iter().map(|v| v * v).sum::<f32>().sqrt();
    if mag > 0.0 {
        out.iter_mut().for_each(|v| *v /= mag);
    }
    out
}
/// Check for credit-card-like patterns: 4 groups of 4 digits separated by
/// spaces or dashes (or no separator).
pub fn has_credit_card_pattern(content: &str) -> bool {
    // Cheap pre-filter: a match requires at least 16 digits in total.
    if content.chars().filter(|c| c.is_ascii_digit()).count() < 16 {
        return false;
    }
    // Windowed scan over the original string for
    // DDDD[-/ ]DDDD[-/ ]DDDD[-/ ]DDDD.
    let chars: Vec<char> = content.chars().collect();
    let len = chars.len();
    let mut i = 0;
    while i < len {
        match try_parse_cc_at(&chars, i) {
            Some(end) => {
                // Reject candidates embedded in a longer digit run on
                // either side (those are not card numbers).
                let not_extended = end >= len || !chars[end].is_ascii_digit();
                let clean_start = i == 0 || !chars[i - 1].is_ascii_digit();
                if not_extended && clean_start {
                    return true;
                }
                i = end;
            }
            None => i += 1,
        }
    }
    false
}
/// Try to parse a credit-card-like pattern starting at position `start`.
/// Returns the index past the last consumed character on success.
fn try_parse_cc_at(chars: &[char], start: usize) -> Option<usize> {
    let mut pos = start;
    for group in 0..4 {
        // Each group is exactly 4 digits.
        for _ in 0..4 {
            match chars.get(pos) {
                Some(c) if c.is_ascii_digit() => pos += 1,
                _ => return None,
            }
        }
        // After the first 3 groups, an optional '-' or ' ' separator.
        if group < 3 {
            if let Some(&sep) = chars.get(pos) {
                if sep == '-' || sep == ' ' {
                    pos += 1;
                }
            }
        }
    }
    Some(pos)
}
/// Check for SSN-like patterns: XXX-XX-XXXX
pub fn has_ssn_pattern(content: &str) -> bool {
    let chars: Vec<char> = content.chars().collect();
    let len = chars.len();
    // Pattern length: 3 + 1 + 2 + 1 + 4 = 11
    if len < 11 {
        return false;
    }
    (0..=len - 11).any(|i| {
        // Reject windows embedded in a longer digit run on either side.
        let preceded_by_digit = i > 0 && chars[i - 1].is_ascii_digit();
        let followed_by_digit = i + 11 < len && chars[i + 11].is_ascii_digit();
        if preceded_by_digit || followed_by_digit {
            return false;
        }
        let w = &chars[i..i + 11];
        w[3] == '-'
            && w[6] == '-'
            && [0usize, 1, 2, 4, 5, 7, 8, 9, 10]
                .iter()
                .all(|&j| w[j].is_ascii_digit())
    })
}
/// Simple safety classification for content.
///
/// Returns `"deny"`, `"redact"`, or `"allow"`.
///
/// Classification matches native `SafetyGate::check`:
/// - Credit card patterns -> "redact"
/// - SSN patterns -> "redact"
/// - Email patterns -> "redact"
/// - Custom sensitive keywords -> "deny"
pub fn safety_classify(content: &str) -> &'static str {
    // PII patterns are redacted (matching native SafetyGate behavior);
    // the checks short-circuit in the same order as the native gate.
    if has_credit_card_pattern(content)
        || has_ssn_pattern(content)
        || has_email_pattern(content)
    {
        return "redact";
    }
    // Custom sensitive keywords are denied (matching native
    // custom_patterns -> Deny); matching is case-insensitive.
    const DENY_KEYWORDS: &[&str] = &[
        "password",
        "secret",
        "api_key",
        "api-key",
        "apikey",
        "token",
        "private_key",
        "private-key",
    ];
    let lower = content.to_lowercase();
    if DENY_KEYWORDS.iter().any(|kw| lower.contains(kw)) {
        "deny"
    } else {
        "allow"
    }
}
/// Check for email-like patterns: local@domain.tld
pub fn has_email_pattern(content: &str) -> bool {
    let chars: Vec<char> = content.chars().collect();
    let len = chars.len();
    (0..len).any(|at| {
        if chars[at] != '@' {
            return false;
        }
        // At least one local-part character must sit directly before '@'.
        if at == 0 || chars[at - 1].is_whitespace() || !is_email_char(chars[at - 1]) {
            return false;
        }
        // The domain must start directly after '@'.
        if at + 1 >= len || chars[at + 1].is_whitespace() {
            return false;
        }
        // Walk the domain: it must contain a dot and be at least 3 chars
        // long for this heuristic to accept it.
        let mut end = at + 1;
        let mut has_dot = false;
        while end < len && is_domain_char(chars[end]) {
            has_dot |= chars[end] == '.';
            end += 1;
        }
        has_dot && end > at + 3
    })
}
/// True for characters allowed in the local part of an address.
fn is_email_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || matches!(c, '.' | '+' | '-' | '_')
}
/// True for characters allowed in the domain part of an address.
fn is_domain_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || matches!(c, '.' | '-')
}
/// Route a query string to the optimal search backend.
///
/// Returns `"Temporal"`, `"Graph"`, `"Keyword"`, or `"Hybrid"`.
///
/// Routing heuristics (matching native `QueryRouter::route`):
/// - Temporal keywords ("yesterday", "last week", etc.) -> Temporal
/// - Graph keywords ("related to", "connected to", etc.) -> Graph
/// - Quoted exact phrases -> Keyword
/// - Short queries (1-2 words) -> Keyword
/// - Everything else -> Hybrid
pub fn route_query(query: &str) -> &'static str {
    // Keyword tables checked in the same order as the native router.
    const TEMPORAL_KEYWORDS: &[&str] = &[
        "yesterday",
        "last week",
        "last month",
        "today",
        "this morning",
        "this afternoon",
        "hours ago",
        "minutes ago",
        "days ago",
        "between",
        "before",
        "after",
    ];
    const GRAPH_KEYWORDS: &[&str] = &[
        "related to",
        "connected to",
        "linked with",
        "associated with",
        "relationship between",
    ];
    let lower = query.to_lowercase();
    if TEMPORAL_KEYWORDS.iter().any(|kw| lower.contains(kw)) {
        return "Temporal";
    }
    if GRAPH_KEYWORDS.iter().any(|kw| lower.contains(kw)) {
        return "Graph";
    }
    // Quoted input means the user wants an exact phrase match.
    if query.starts_with('"') && query.ends_with('"') {
        return "Keyword";
    }
    // One- or two-word queries are better served by keyword search.
    if lower.split_whitespace().count() <= 2 {
        return "Keyword";
    }
    // Default: hybrid combines the best of both
    "Hybrid"
}
// ---------------------------------------------------------------------------
// Unit tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cosine_similarity_identical() {
let v = vec![1.0, 2.0, 3.0];
let sim = cosine_similarity(&v, &v);
assert!((sim - 1.0).abs() < 1e-6);
}
#[test]
fn test_cosine_similarity_orthogonal() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let sim = cosine_similarity(&a, &b);
assert!(sim.abs() < 1e-6);
}
#[test]
fn test_cosine_similarity_opposite() {
let a = vec![1.0, 0.0];
let b = vec![-1.0, 0.0];
let sim = cosine_similarity(&a, &b);
assert!((sim + 1.0).abs() < 1e-6);
}
#[test]
fn test_cosine_similarity_zero_vector() {
let a = vec![0.0, 0.0];
let b = vec![1.0, 2.0];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn test_hash_embed_deterministic() {
let v1 = hash_embed("hello world", 128);
let v2 = hash_embed("hello world", 128);
assert_eq!(v1, v2);
}
#[test]
fn test_hash_embed_normalized() {
let v = hash_embed("test text", 64);
let mag: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(mag - 1.0).abs() < 1e-4,
"magnitude should be ~1.0, got {mag}"
);
}
#[test]
fn test_hash_embed_different_texts_differ() {
let v1 = hash_embed("hello", 64);
let v2 = hash_embed("world", 64);
assert_ne!(v1, v2);
}
#[test]
fn test_has_credit_card_pattern() {
assert!(has_credit_card_pattern("my card is 1234 5678 9012 3456"));
assert!(has_credit_card_pattern("cc: 1234-5678-9012-3456"));
assert!(has_credit_card_pattern("number 1234567890123456 here"));
assert!(!has_credit_card_pattern("short 123456"));
assert!(!has_credit_card_pattern("no cards here"));
}
#[test]
fn test_has_ssn_pattern() {
assert!(has_ssn_pattern("ssn is 123-45-6789"));
assert!(has_ssn_pattern("start 999-99-9999 end"));
assert!(!has_ssn_pattern("not a ssn 12-345-6789"));
assert!(!has_ssn_pattern("1234-56-7890")); // preceded by extra digit
assert!(!has_ssn_pattern("no ssn here"));
}
#[test]
fn test_safety_classify_redact_cc() {
assert_eq!(safety_classify("pay with 4111-1111-1111-1111"), "redact");
}
#[test]
fn test_safety_classify_redact_ssn() {
assert_eq!(safety_classify("my ssn 123-45-6789"), "redact");
}
#[test]
fn test_safety_classify_redact_email() {
assert_eq!(safety_classify("contact user@example.com"), "redact");
}
#[test]
fn test_safety_classify_deny_password() {
    // Password disclosure is denied outright rather than redacted.
    let verdict = safety_classify("my password is foo");
    assert_eq!(verdict, "deny");
}
#[test]
fn test_safety_classify_deny_api_key() {
    // API-key material is denied outright rather than redacted.
    let verdict = safety_classify("api_key: sk-abc123");
    assert_eq!(verdict, "deny");
}
#[test]
fn test_safety_classify_allow() {
    // Benign text passes through with an "allow" verdict.
    let verdict = safety_classify("the weather is nice");
    assert_eq!(verdict, "allow");
}
#[test]
fn test_has_email_pattern() {
    // Plausible addresses embedded in text should be detected.
    for text in ["contact user@example.com please", "email: alice@test.org"] {
        assert!(has_email_pattern(text));
    }
    // Missing local part or a too-short domain should not match.
    for text in ["not an email", "@ alone", "no@d"] {
        assert!(!has_email_pattern(text));
    }
}
#[test]
fn test_route_query_temporal() {
    // Time-referencing phrasing routes to the Temporal strategy.
    for query in ["what happened yesterday", "show me events from last week"] {
        assert_eq!(route_query(query), "Temporal");
    }
}
#[test]
fn test_route_query_graph() {
    // Relationship-oriented phrasing routes to the Graph strategy.
    for query in [
        "documents related to authentication",
        "things connected to the API module",
    ] {
        assert_eq!(route_query(query), "Graph");
    }
}
#[test]
fn test_route_query_keyword_quoted() {
    // An exact-phrase (quoted) query routes to Keyword search.
    let quoted = r#""exact phrase search""#;
    assert_eq!(route_query(quoted), "Keyword");
}
#[test]
fn test_route_query_keyword_short() {
    // Short queries route to Keyword search.
    for query in ["rust programming", "hello"] {
        assert_eq!(route_query(query), "Keyword");
    }
}
#[test]
fn test_route_query_hybrid() {
    // Longer conceptual queries fall through to the Hybrid strategy.
    for query in [
        "something about machine learning",
        "explain how embeddings work",
    ] {
        assert_eq!(route_query(query), "Hybrid");
    }
}
}

View File

@@ -0,0 +1,15 @@
//! WASM bindings for OSpipe.
//!
//! Provides browser-based personal AI memory search using vector embeddings.
//!
//! - [`helpers`] - Pure helper functions (cosine similarity, hashing, safety
//! checks, query routing) that are available on all targets for testing.
//! - `bindings` - wasm-bindgen exports, gated behind `target_arch = "wasm32"`.
/// Pure helper functions with no WASM dependencies.
/// Always compiled so that unit tests can run on the host target.
pub mod helpers;
/// wasm-bindgen exports. Only compiled for the `wasm32` target.
#[cfg(target_arch = "wasm32")]
pub mod bindings;