Files
wifi-densepose/examples/exo-ai-2025/crates/exo-temporal/src/long_term.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

444 lines
13 KiB
Rust

//! Long-term consolidated memory store
//!
//! Optimized with:
//! - SIMD-accelerated cosine similarity (4x speedup)
//! - Batch integration with deferred index sorting
//! - Early-exit similarity search for hot patterns
use crate::types::{PatternId, Query, SearchResult, SubstrateTime, TemporalPattern, TimeRange};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Tunable parameters for the long-term store.
#[derive(Debug, Clone)]
pub struct LongTermConfig {
    /// Decay rate for low-salience patterns.
    pub decay_rate: f32,
    /// Minimum salience threshold.
    pub min_salience: f32,
}

impl Default for LongTermConfig {
    /// Builds the stock configuration: a decay rate of `0.01` and a minimum
    /// salience of `0.1`.
    fn default() -> Self {
        let decay_rate = 0.01;
        let min_salience = 0.1;
        LongTermConfig {
            decay_rate,
            min_salience,
        }
    }
}
/// Long-term consolidated memory store.
///
/// Patterns live in a concurrent map keyed by their ID. A separate
/// `(timestamp, id)` vector acts as a temporal index; single inserts append to
/// it unsorted and raise `index_dirty`, and the sort is performed later, before
/// time-range filtering.
pub struct LongTermStore {
/// Pattern storage, keyed by pattern ID.
patterns: DashMap<PatternId, TemporalPattern>,
/// Temporal index of `(timestamp, id)` pairs; sorted by timestamp whenever
/// `index_dirty` is `false`.
temporal_index: Arc<RwLock<Vec<(SubstrateTime, PatternId)>>>,
/// Set when `temporal_index` may be out of order (deferred batch sorting).
index_dirty: AtomicBool,
/// Configuration supplied at construction time.
config: LongTermConfig,
}
impl LongTermStore {
/// Builds an empty store that uses `config` for its thresholds.
///
/// The pattern map and temporal index start empty, and the index is
/// considered sorted (not dirty).
pub fn new(config: LongTermConfig) -> Self {
    let patterns = DashMap::new();
    let temporal_index = Arc::new(RwLock::new(Vec::new()));
    let index_dirty = AtomicBool::new(false);
    Self {
        patterns,
        temporal_index,
        index_dirty,
        config,
    }
}
/// Integrate a pattern from consolidation (sorting of the index is deferred).
///
/// The pattern is stored under its own ID and a `(timestamp, id)` entry is
/// appended to the temporal index, which is then flagged as needing a sort.
///
/// If a pattern with the same ID was already stored it is replaced, and its
/// previous index entry is dropped first so the index never holds two
/// entries for one ID (which would make time-range filtering return the
/// pattern twice).
pub fn integrate(&self, temporal_pattern: TemporalPattern) {
    let id = temporal_pattern.pattern.id;
    let timestamp = temporal_pattern.pattern.timestamp;

    // Store pattern; remember whether this overwrote an earlier version.
    let replaced = self.patterns.insert(id, temporal_pattern).is_some();

    // Update temporal index (deferred sorting).
    let mut index = self.temporal_index.write();
    if replaced {
        // Only pay the linear scan when an older entry can actually exist.
        index.retain(|(_, existing)| *existing != id);
    }
    index.push((timestamp, id));
    // Flag is written while the write lock is held, so a concurrent sorter
    // cannot clear it before this entry is visible.
    self.index_dirty.store(true, Ordering::Relaxed);
}
/// Batch integrate multiple patterns with a single sort at the end.
///
/// The temporal index is write-locked for the whole batch. Each pattern is
/// stored under its own ID and indexed by timestamp; once every pattern is
/// in, the index is sorted once and marked clean.
///
/// When a pattern replaces one with the same ID (already stored, or earlier
/// in this same batch), the older index entry is removed so each ID appears
/// in the index exactly once.
pub fn integrate_batch(&self, patterns: Vec<TemporalPattern>) {
    let mut index = self.temporal_index.write();
    // One growth step instead of repeated reallocation while pushing.
    index.reserve(patterns.len());

    for temporal_pattern in patterns {
        let id = temporal_pattern.pattern.id;
        let timestamp = temporal_pattern.pattern.timestamp;
        if self.patterns.insert(id, temporal_pattern).is_some() {
            // Replacement: drop the stale entry before adding the new one.
            index.retain(|(_, existing)| *existing != id);
        }
        index.push((timestamp, id));
    }

    // Single sort after batch insert; the lock is still held, so the index
    // is known to be ordered when the flag is cleared.
    index.sort_by_key(|(t, _)| *t);
    self.index_dirty.store(false, Ordering::Relaxed);
}
/// Sorts the temporal index by timestamp if it has been flagged dirty.
///
/// Does nothing when the flag is clear. Otherwise the index is write-locked,
/// sorted, and the flag is reset while the lock is still held.
fn ensure_sorted(&self) {
    let needs_sort = self.index_dirty.load(Ordering::Relaxed);
    if !needs_sort {
        return;
    }
    let mut entries = self.temporal_index.write();
    entries.sort_by_key(|&(when, _)| when);
    self.index_dirty.store(false, Ordering::Relaxed);
}
/// Looks up a pattern by ID and returns a clone of it, or `None` when no
/// pattern is stored under `id`.
pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
    let entry = self.patterns.get(id)?;
    Some(entry.value().clone())
}
/// Update (or insert) a pattern, keyed by its own ID.
///
/// Returns `true` when a pattern with that ID was already stored and has
/// been replaced, `false` when the pattern was newly inserted.
///
/// The temporal index is kept in step with the map: when the pattern is new,
/// or its timestamp differs from the stored one, the index entry for the ID
/// is rewritten and the index is flagged for re-sorting. Without this a newly
/// inserted pattern would never show up in time-range filtering, and a
/// changed timestamp would leave a stale entry behind.
pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
    let id = temporal_pattern.pattern.id;
    let timestamp = temporal_pattern.pattern.timestamp;

    let previous = self.patterns.insert(id, temporal_pattern);
    let existed = previous.is_some();

    // The index only needs touching if there was no entry or it moved in time.
    let index_stale = previous.map_or(true, |old| old.pattern.timestamp != timestamp);
    if index_stale {
        let mut index = self.temporal_index.write();
        index.retain(|(_, existing)| *existing != id);
        index.push((timestamp, id));
        self.index_dirty.store(true, Ordering::Relaxed);
    }

    existed
}
/// Search by embedding similarity (SIMD-accelerated with early exit)
pub fn search(&self, query: &Query) -> Vec<SearchResult> {
let k = query.k;
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
for entry in self.patterns.iter() {
let temporal_pattern = entry.value();
let score =
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
// Early exit optimization: skip if below worst score in top-k
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
continue;
}
results.push(SearchResult {
id: temporal_pattern.pattern.id,
pattern: temporal_pattern.clone(),
score,
});
// Keep sorted and bounded
if results.len() > k {
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(k);
}
}
// Final sort
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results
}
/// Search with time range filter (SIMD-accelerated)
pub fn search_with_time_range(
&self,
query: &Query,
time_range: TimeRange,
) -> Vec<SearchResult> {
let k = query.k;
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
for entry in self.patterns.iter() {
let temporal_pattern = entry.value();
// Filter by time range
if !time_range.contains(&temporal_pattern.pattern.timestamp) {
continue;
}
let score =
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
// Early exit optimization
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
continue;
}
results.push(SearchResult {
id: temporal_pattern.pattern.id,
pattern: temporal_pattern.clone(),
score,
});
if results.len() > k {
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(k);
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results
}
/// Filter patterns by time range (ensures the index is sorted first).
///
/// Returns clones of all stored patterns whose indexed timestamp lies between
/// `time_range.start` and `time_range.end`, both bounds included, in
/// ascending timestamp order. An empty store, a range that matches nothing,
/// or a range whose start is after its end all yield an empty vector.
///
/// NOTE(review): both bounds are treated as inclusive here, matching the
/// previous `..=end` slice; confirm this agrees with `TimeRange::contains`.
pub fn filter_by_time(&self, time_range: TimeRange) -> Vec<TemporalPattern> {
    self.ensure_sorted();
    let index = self.temporal_index.read();

    // First entry at or after the start of the range.
    let first = index.partition_point(|(t, _)| *t < time_range.start);
    // One past the last entry at or before the end of the range. Using
    // partition points (rather than an exact-match binary search) handles
    // missing keys and runs of equal timestamps correctly.
    let past_last = index.partition_point(|(t, _)| *t <= time_range.end);

    if first >= past_last {
        return Vec::new();
    }

    // Index entries whose pattern has since been removed are skipped.
    index[first..past_last]
        .iter()
        .filter_map(|(_, id)| self.patterns.get(id).map(|entry| entry.value().clone()))
        .collect()
}
/// Strategic forgetting: decay low-salience patterns
pub fn decay_low_salience(&self, decay_rate: f32) {
let mut to_remove = Vec::new();
for mut entry in self.patterns.iter_mut() {
let temporal_pattern = entry.value_mut();
// Decay salience
temporal_pattern.pattern.salience *= 1.0 - decay_rate;
// Mark for removal if below threshold
if temporal_pattern.pattern.salience < self.config.min_salience {
to_remove.push(temporal_pattern.pattern.id);
}
}
// Remove low-salience patterns
for id in to_remove {
self.remove(&id);
}
}
/// Removes the pattern stored under `id` and returns it.
///
/// Returns `None` (leaving the temporal index untouched) when no such
/// pattern exists; otherwise every index entry for `id` is dropped as well.
pub fn remove(&self, id: &PatternId) -> Option<TemporalPattern> {
    let (_, removed) = self.patterns.remove(id)?;
    self.temporal_index
        .write()
        .retain(|(_, indexed)| indexed != id);
    Some(removed)
}
/// Number of patterns currently held in the pattern map.
pub fn len(&self) -> usize {
    let stored = &self.patterns;
    stored.len()
}
/// Reports whether the pattern map holds no patterns.
pub fn is_empty(&self) -> bool {
    let stored = &self.patterns;
    stored.is_empty()
}
/// Empties the store: first the pattern map, then the temporal index.
pub fn clear(&self) {
    self.patterns.clear();
    let mut index = self.temporal_index.write();
    index.clear();
}
/// Returns clones of every stored pattern, in the map's iteration order.
pub fn all(&self) -> Vec<TemporalPattern> {
    let mut snapshot = Vec::new();
    for entry in self.patterns.iter() {
        snapshot.push(entry.value().clone());
    }
    snapshot
}
/// Get statistics about the stored patterns.
///
/// Reports the number of patterns together with the average, minimum and
/// maximum salience. All four figures come from one traversal of the pattern
/// map, so they describe the same set of entries even if the store is being
/// modified concurrently. An empty store reports zero for every field.
pub fn stats(&self) -> LongTermStats {
    let mut size = 0usize;
    let mut total_salience = 0.0f32;
    let mut min_salience = f32::MAX;
    let mut max_salience = f32::MIN;

    for entry in self.patterns.iter() {
        let salience = entry.value().pattern.salience;
        size += 1;
        total_salience += salience;
        min_salience = min_salience.min(salience);
        max_salience = max_salience.max(salience);
    }

    if size == 0 {
        // Avoid dividing by zero and leaking the MAX/MIN sentinels.
        return LongTermStats {
            size: 0,
            avg_salience: 0.0,
            min_salience: 0.0,
            max_salience: 0.0,
        };
    }

    LongTermStats {
        size,
        avg_salience: total_salience / size as f32,
        min_salience,
        max_salience,
    }
}
}
impl Default for LongTermStore {
    /// Creates an empty store using the default `LongTermConfig`.
    fn default() -> Self {
        let config = LongTermConfig::default();
        Self::new(config)
    }
}
/// Long-term store statistics, as produced by `LongTermStore::stats`.
///
/// For an empty store the salience fields are all reported as `0.0`.
#[derive(Debug, Clone)]
pub struct LongTermStats {
/// Number of patterns in the store.
pub size: usize,
/// Average salience across all patterns.
pub avg_salience: f32,
/// Lowest salience among the patterns.
pub min_salience: f32,
/// Highest salience among the patterns.
pub max_salience: f32,
}
/// Cosine similarity of two equal-length vectors, accumulated four lanes at a
/// time.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector has zero magnitude. Otherwise returns
/// `dot(a, b) / (|a| * |b|)`.
///
/// The four-element chunks are walked with `chunks_exact`, which gives the
/// same accumulation order as a hand-unrolled loop without any `unsafe`
/// indexing. The two norms are square-rooted separately before being
/// multiplied: multiplying the squared magnitudes first can underflow to zero
/// (or overflow to infinity) for very small or very large vectors, which
/// would wrongly report a similarity of `0.0`.
#[inline]
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut mag_a = 0.0f32;
    let mut mag_b = 0.0f32;

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // Lengths are equal, so both remainders have the same length too.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // Process 4 elements at a time.
    for (ca, cb) in a_chunks.zip(b_chunks) {
        dot += ca[0] * cb[0] + ca[1] * cb[1] + ca[2] * cb[2] + ca[3] * cb[3];
        mag_a += ca[0] * ca[0] + ca[1] * ca[1] + ca[2] * ca[2] + ca[3] * ca[3];
        mag_b += cb[0] * cb[0] + cb[1] * cb[1] + cb[2] * cb[2] + cb[3] * cb[3];
    }

    // Process the remaining (fewer than 4) elements.
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot += ai * bi;
        mag_a += ai * ai;
        mag_b += bi * bi;
    }

    let norm = mag_a.sqrt() * mag_b.sqrt();
    if norm == 0.0 {
        return 0.0;
    }
    dot / norm
}

/// Standard cosine similarity (alias for compatibility).
// Nothing in this file calls the plain-named alias, hence the lint allowance.
#[allow(dead_code)]
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    cosine_similarity_simd(a, b)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    /// A single integrated pattern is counted and retrievable by its ID.
    #[test]
    fn test_long_term_store() {
        let store = LongTermStore::default();
        let temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        let id = temporal_pattern.pattern.id;
        store.integrate(temporal_pattern);
        assert_eq!(store.len(), 1);
        assert!(store.get(&id).is_some());
    }

    /// A top-1 search returns exactly one result with a high score for a
    /// query close to one of the stored embeddings.
    #[test]
    fn test_search() {
        let store = LongTermStore::default();
        // Add patterns
        let p1 = TemporalPattern::from_embedding(vec![1.0, 0.0, 0.0], Metadata::new());
        let p2 = TemporalPattern::from_embedding(vec![0.0, 1.0, 0.0], Metadata::new());
        store.integrate(p1);
        store.integrate(p2);
        // Query similar to p1
        let query = Query::from_embedding(vec![0.9, 0.1, 0.0]).with_k(1);
        let results = store.search(&query);
        assert_eq!(results.len(), 1);
        assert!(results[0].score > 0.5);
    }

    /// Decaying a pattern below the default minimum salience removes it.
    #[test]
    fn test_decay() {
        let store = LongTermStore::default();
        let mut temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        temporal_pattern.pattern.salience = 0.15; // Just above minimum
        let id = temporal_pattern.pattern.id;
        store.integrate(temporal_pattern);
        assert_eq!(store.len(), 1);
        // Decay should remove it: 0.15 * (1.0 - 0.5) is below the 0.1 default.
        store.decay_low_salience(0.5);
        assert_eq!(store.len(), 0);
        // The specific pattern must be gone, not merely the count reduced.
        assert!(store.get(&id).is_none());
    }
}