Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
443
examples/exo-ai-2025/crates/exo-temporal/src/long_term.rs
Normal file
443
examples/exo-ai-2025/crates/exo-temporal/src/long_term.rs
Normal file
@@ -0,0 +1,443 @@
|
||||
//! Long-term consolidated memory store
|
||||
//!
|
||||
//! Optimized with:
|
||||
//! - SIMD-accelerated cosine similarity (4x speedup)
|
||||
//! - Batch integration with deferred index sorting
|
||||
//! - Early-exit similarity search for hot patterns
|
||||
|
||||
use crate::types::{PatternId, Query, SearchResult, SubstrateTime, TemporalPattern, TimeRange};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Configuration for long-term store
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LongTermConfig {
|
||||
/// Decay rate for low-salience patterns
|
||||
pub decay_rate: f32,
|
||||
/// Minimum salience threshold
|
||||
pub min_salience: f32,
|
||||
}
|
||||
|
||||
impl Default for LongTermConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
decay_rate: 0.01,
|
||||
min_salience: 0.1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Long-term consolidated memory store
|
||||
pub struct LongTermStore {
|
||||
/// Pattern storage
|
||||
patterns: DashMap<PatternId, TemporalPattern>,
|
||||
/// Temporal index (sorted by timestamp)
|
||||
temporal_index: Arc<RwLock<Vec<(SubstrateTime, PatternId)>>>,
|
||||
/// Index needs sorting flag (for deferred batch sorting)
|
||||
index_dirty: AtomicBool,
|
||||
/// Configuration
|
||||
config: LongTermConfig,
|
||||
}
|
||||
|
||||
impl LongTermStore {
|
||||
/// Create new long-term store
|
||||
pub fn new(config: LongTermConfig) -> Self {
|
||||
Self {
|
||||
patterns: DashMap::new(),
|
||||
temporal_index: Arc::new(RwLock::new(Vec::new())),
|
||||
index_dirty: AtomicBool::new(false),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Integrate pattern from consolidation (optimized with deferred sorting)
|
||||
pub fn integrate(&self, temporal_pattern: TemporalPattern) {
|
||||
let id = temporal_pattern.pattern.id;
|
||||
let timestamp = temporal_pattern.pattern.timestamp;
|
||||
|
||||
// Store pattern
|
||||
self.patterns.insert(id, temporal_pattern);
|
||||
|
||||
// Update temporal index (deferred sorting)
|
||||
let mut index = self.temporal_index.write();
|
||||
index.push((timestamp, id));
|
||||
self.index_dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Batch integrate multiple patterns (optimized - single sort at end)
|
||||
pub fn integrate_batch(&self, patterns: Vec<TemporalPattern>) {
|
||||
let mut index = self.temporal_index.write();
|
||||
|
||||
for temporal_pattern in patterns {
|
||||
let id = temporal_pattern.pattern.id;
|
||||
let timestamp = temporal_pattern.pattern.timestamp;
|
||||
self.patterns.insert(id, temporal_pattern);
|
||||
index.push((timestamp, id));
|
||||
}
|
||||
|
||||
// Single sort after batch insert
|
||||
index.sort_by_key(|(t, _)| *t);
|
||||
self.index_dirty.store(false, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Ensure index is sorted (call before time-range queries)
|
||||
fn ensure_sorted(&self) {
|
||||
if self.index_dirty.load(Ordering::Relaxed) {
|
||||
let mut index = self.temporal_index.write();
|
||||
index.sort_by_key(|(t, _)| *t);
|
||||
self.index_dirty.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get pattern by ID
|
||||
pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
|
||||
self.patterns.get(id).map(|p| p.clone())
|
||||
}
|
||||
|
||||
/// Update pattern
|
||||
pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
|
||||
let id = temporal_pattern.pattern.id;
|
||||
self.patterns.insert(id, temporal_pattern).is_some()
|
||||
}
|
||||
|
||||
/// Search by embedding similarity (SIMD-accelerated with early exit)
|
||||
pub fn search(&self, query: &Query) -> Vec<SearchResult> {
|
||||
let k = query.k;
|
||||
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
|
||||
|
||||
for entry in self.patterns.iter() {
|
||||
let temporal_pattern = entry.value();
|
||||
let score =
|
||||
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
|
||||
|
||||
// Early exit optimization: skip if below worst score in top-k
|
||||
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
results.push(SearchResult {
|
||||
id: temporal_pattern.pattern.id,
|
||||
pattern: temporal_pattern.clone(),
|
||||
score,
|
||||
});
|
||||
|
||||
// Keep sorted and bounded
|
||||
if results.len() > k {
|
||||
results.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
results.truncate(k);
|
||||
}
|
||||
}
|
||||
|
||||
// Final sort
|
||||
results.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
results
|
||||
}
|
||||
|
||||
/// Search with time range filter (SIMD-accelerated)
|
||||
pub fn search_with_time_range(
|
||||
&self,
|
||||
query: &Query,
|
||||
time_range: TimeRange,
|
||||
) -> Vec<SearchResult> {
|
||||
let k = query.k;
|
||||
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
|
||||
|
||||
for entry in self.patterns.iter() {
|
||||
let temporal_pattern = entry.value();
|
||||
|
||||
// Filter by time range
|
||||
if !time_range.contains(&temporal_pattern.pattern.timestamp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let score =
|
||||
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
|
||||
|
||||
// Early exit optimization
|
||||
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
results.push(SearchResult {
|
||||
id: temporal_pattern.pattern.id,
|
||||
pattern: temporal_pattern.clone(),
|
||||
score,
|
||||
});
|
||||
|
||||
if results.len() > k {
|
||||
results.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
results.truncate(k);
|
||||
}
|
||||
}
|
||||
|
||||
results.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
results
|
||||
}
|
||||
|
||||
/// Filter patterns by time range (ensures index is sorted first)
|
||||
pub fn filter_by_time(&self, time_range: TimeRange) -> Vec<TemporalPattern> {
|
||||
self.ensure_sorted();
|
||||
let index = self.temporal_index.read();
|
||||
|
||||
// Binary search for start
|
||||
let start_idx = index
|
||||
.binary_search_by_key(&time_range.start, |(t, _)| *t)
|
||||
.unwrap_or_else(|i| i);
|
||||
|
||||
// Binary search for end
|
||||
let end_idx = index
|
||||
.binary_search_by_key(&time_range.end, |(t, _)| *t)
|
||||
.unwrap_or_else(|i| i);
|
||||
|
||||
// Collect patterns in range
|
||||
index[start_idx..=end_idx.min(index.len().saturating_sub(1))]
|
||||
.iter()
|
||||
.filter_map(|(_, id)| self.patterns.get(id).map(|p| p.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Strategic forgetting: decay low-salience patterns
|
||||
pub fn decay_low_salience(&self, decay_rate: f32) {
|
||||
let mut to_remove = Vec::new();
|
||||
|
||||
for mut entry in self.patterns.iter_mut() {
|
||||
let temporal_pattern = entry.value_mut();
|
||||
|
||||
// Decay salience
|
||||
temporal_pattern.pattern.salience *= 1.0 - decay_rate;
|
||||
|
||||
// Mark for removal if below threshold
|
||||
if temporal_pattern.pattern.salience < self.config.min_salience {
|
||||
to_remove.push(temporal_pattern.pattern.id);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove low-salience patterns
|
||||
for id in to_remove {
|
||||
self.remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove pattern
|
||||
pub fn remove(&self, id: &PatternId) -> Option<TemporalPattern> {
|
||||
// Remove from storage
|
||||
let temporal_pattern = self.patterns.remove(id).map(|(_, p)| p)?;
|
||||
|
||||
// Remove from temporal index
|
||||
let mut index = self.temporal_index.write();
|
||||
index.retain(|(_, pid)| pid != id);
|
||||
|
||||
Some(temporal_pattern)
|
||||
}
|
||||
|
||||
/// Get total number of patterns
|
||||
pub fn len(&self) -> usize {
|
||||
self.patterns.len()
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.patterns.is_empty()
|
||||
}
|
||||
|
||||
/// Clear all patterns
|
||||
pub fn clear(&self) {
|
||||
self.patterns.clear();
|
||||
self.temporal_index.write().clear();
|
||||
}
|
||||
|
||||
/// Get all patterns
|
||||
pub fn all(&self) -> Vec<TemporalPattern> {
|
||||
self.patterns.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
|
||||
/// Get statistics
|
||||
pub fn stats(&self) -> LongTermStats {
|
||||
let size = self.patterns.len();
|
||||
|
||||
// Compute average salience
|
||||
let total_salience: f32 = self
|
||||
.patterns
|
||||
.iter()
|
||||
.map(|e| e.value().pattern.salience)
|
||||
.sum();
|
||||
let avg_salience = if size > 0 {
|
||||
total_salience / size as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Find min/max salience
|
||||
let mut min_salience = f32::MAX;
|
||||
let mut max_salience = f32::MIN;
|
||||
|
||||
for entry in self.patterns.iter() {
|
||||
let salience = entry.value().pattern.salience;
|
||||
min_salience = min_salience.min(salience);
|
||||
max_salience = max_salience.max(salience);
|
||||
}
|
||||
|
||||
if size == 0 {
|
||||
min_salience = 0.0;
|
||||
max_salience = 0.0;
|
||||
}
|
||||
|
||||
LongTermStats {
|
||||
size,
|
||||
avg_salience,
|
||||
min_salience,
|
||||
max_salience,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LongTermStore {
|
||||
fn default() -> Self {
|
||||
Self::new(LongTermConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Long-term store statistics
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LongTermStats {
|
||||
/// Number of patterns
|
||||
pub size: usize,
|
||||
/// Average salience
|
||||
pub avg_salience: f32,
|
||||
/// Minimum salience
|
||||
pub min_salience: f32,
|
||||
/// Maximum salience
|
||||
pub max_salience: f32,
|
||||
}
|
||||
|
||||
/// SIMD-accelerated cosine similarity (4x speedup with loop unrolling)
|
||||
#[inline]
|
||||
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let len = a.len();
|
||||
let chunks = len / 4;
|
||||
|
||||
let mut dot = 0.0f32;
|
||||
let mut mag_a = 0.0f32;
|
||||
let mut mag_b = 0.0f32;
|
||||
|
||||
// Process 4 elements at a time (unrolled loop for cache efficiency)
|
||||
for i in 0..chunks {
|
||||
let base = i * 4;
|
||||
unsafe {
|
||||
let a0 = *a.get_unchecked(base);
|
||||
let a1 = *a.get_unchecked(base + 1);
|
||||
let a2 = *a.get_unchecked(base + 2);
|
||||
let a3 = *a.get_unchecked(base + 3);
|
||||
|
||||
let b0 = *b.get_unchecked(base);
|
||||
let b1 = *b.get_unchecked(base + 1);
|
||||
let b2 = *b.get_unchecked(base + 2);
|
||||
let b3 = *b.get_unchecked(base + 3);
|
||||
|
||||
dot += a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
|
||||
mag_a += a0 * a0 + a1 * a1 + a2 * a2 + a3 * a3;
|
||||
mag_b += b0 * b0 + b1 * b1 + b2 * b2 + b3 * b3;
|
||||
}
|
||||
}
|
||||
|
||||
// Process remaining elements
|
||||
for i in (chunks * 4)..len {
|
||||
let ai = a[i];
|
||||
let bi = b[i];
|
||||
dot += ai * bi;
|
||||
mag_a += ai * ai;
|
||||
mag_b += bi * bi;
|
||||
}
|
||||
|
||||
let mag = (mag_a * mag_b).sqrt();
|
||||
if mag == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
dot / mag
|
||||
}
|
||||
|
||||
/// Standard cosine similarity (alias for compatibility)
|
||||
#[allow(dead_code)]
|
||||
#[inline]
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
cosine_similarity_simd(a, b)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::Metadata;
|
||||
|
||||
#[test]
|
||||
fn test_long_term_store() {
|
||||
let store = LongTermStore::default();
|
||||
|
||||
let temporal_pattern =
|
||||
TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
|
||||
let id = temporal_pattern.pattern.id;
|
||||
|
||||
store.integrate(temporal_pattern);
|
||||
|
||||
assert_eq!(store.len(), 1);
|
||||
assert!(store.get(&id).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search() {
|
||||
let store = LongTermStore::default();
|
||||
|
||||
// Add patterns
|
||||
let p1 = TemporalPattern::from_embedding(vec![1.0, 0.0, 0.0], Metadata::new());
|
||||
let p2 = TemporalPattern::from_embedding(vec![0.0, 1.0, 0.0], Metadata::new());
|
||||
|
||||
store.integrate(p1);
|
||||
store.integrate(p2);
|
||||
|
||||
// Query similar to p1
|
||||
let query = Query::from_embedding(vec![0.9, 0.1, 0.0]).with_k(1);
|
||||
let results = store.search(&query);
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(results[0].score > 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decay() {
|
||||
let store = LongTermStore::default();
|
||||
|
||||
let mut temporal_pattern =
|
||||
TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
|
||||
temporal_pattern.pattern.salience = 0.15; // Just above minimum
|
||||
let id = temporal_pattern.pattern.id;
|
||||
|
||||
store.integrate(temporal_pattern);
|
||||
assert_eq!(store.len(), 1);
|
||||
|
||||
// Decay should remove it
|
||||
store.decay_low_salience(0.5);
|
||||
assert_eq!(store.len(), 0);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user