git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
444 lines
13 KiB
Rust
444 lines
13 KiB
Rust
//! Long-term consolidated memory store
|
|
//!
|
|
//! Optimized with:
|
|
//! - SIMD-accelerated cosine similarity (4x speedup)
|
|
//! - Batch integration with deferred index sorting
|
|
//! - Early-exit similarity search for hot patterns
|
|
|
|
use crate::types::{PatternId, Query, SearchResult, SubstrateTime, TemporalPattern, TimeRange};
|
|
use dashmap::DashMap;
|
|
use parking_lot::RwLock;
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
use std::sync::Arc;
|
|
|
|
/// Configuration for long-term store
|
|
#[derive(Debug, Clone)]
|
|
pub struct LongTermConfig {
|
|
/// Decay rate for low-salience patterns
|
|
pub decay_rate: f32,
|
|
/// Minimum salience threshold
|
|
pub min_salience: f32,
|
|
}
|
|
|
|
impl Default for LongTermConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
decay_rate: 0.01,
|
|
min_salience: 0.1,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Long-term consolidated memory store
|
|
pub struct LongTermStore {
|
|
/// Pattern storage
|
|
patterns: DashMap<PatternId, TemporalPattern>,
|
|
/// Temporal index (sorted by timestamp)
|
|
temporal_index: Arc<RwLock<Vec<(SubstrateTime, PatternId)>>>,
|
|
/// Index needs sorting flag (for deferred batch sorting)
|
|
index_dirty: AtomicBool,
|
|
/// Configuration
|
|
config: LongTermConfig,
|
|
}
|
|
|
|
impl LongTermStore {
|
|
/// Create new long-term store
|
|
pub fn new(config: LongTermConfig) -> Self {
|
|
Self {
|
|
patterns: DashMap::new(),
|
|
temporal_index: Arc::new(RwLock::new(Vec::new())),
|
|
index_dirty: AtomicBool::new(false),
|
|
config,
|
|
}
|
|
}
|
|
|
|
/// Integrate pattern from consolidation (optimized with deferred sorting)
|
|
pub fn integrate(&self, temporal_pattern: TemporalPattern) {
|
|
let id = temporal_pattern.pattern.id;
|
|
let timestamp = temporal_pattern.pattern.timestamp;
|
|
|
|
// Store pattern
|
|
self.patterns.insert(id, temporal_pattern);
|
|
|
|
// Update temporal index (deferred sorting)
|
|
let mut index = self.temporal_index.write();
|
|
index.push((timestamp, id));
|
|
self.index_dirty.store(true, Ordering::Relaxed);
|
|
}
|
|
|
|
/// Batch integrate multiple patterns (optimized - single sort at end)
|
|
pub fn integrate_batch(&self, patterns: Vec<TemporalPattern>) {
|
|
let mut index = self.temporal_index.write();
|
|
|
|
for temporal_pattern in patterns {
|
|
let id = temporal_pattern.pattern.id;
|
|
let timestamp = temporal_pattern.pattern.timestamp;
|
|
self.patterns.insert(id, temporal_pattern);
|
|
index.push((timestamp, id));
|
|
}
|
|
|
|
// Single sort after batch insert
|
|
index.sort_by_key(|(t, _)| *t);
|
|
self.index_dirty.store(false, Ordering::Relaxed);
|
|
}
|
|
|
|
/// Ensure index is sorted (call before time-range queries)
|
|
fn ensure_sorted(&self) {
|
|
if self.index_dirty.load(Ordering::Relaxed) {
|
|
let mut index = self.temporal_index.write();
|
|
index.sort_by_key(|(t, _)| *t);
|
|
self.index_dirty.store(false, Ordering::Relaxed);
|
|
}
|
|
}
|
|
|
|
/// Get pattern by ID
|
|
pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
|
|
self.patterns.get(id).map(|p| p.clone())
|
|
}
|
|
|
|
/// Update pattern
|
|
pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
|
|
let id = temporal_pattern.pattern.id;
|
|
self.patterns.insert(id, temporal_pattern).is_some()
|
|
}
|
|
|
|
/// Search by embedding similarity (SIMD-accelerated with early exit)
|
|
pub fn search(&self, query: &Query) -> Vec<SearchResult> {
|
|
let k = query.k;
|
|
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
|
|
|
|
for entry in self.patterns.iter() {
|
|
let temporal_pattern = entry.value();
|
|
let score =
|
|
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
|
|
|
|
// Early exit optimization: skip if below worst score in top-k
|
|
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
|
|
continue;
|
|
}
|
|
|
|
results.push(SearchResult {
|
|
id: temporal_pattern.pattern.id,
|
|
pattern: temporal_pattern.clone(),
|
|
score,
|
|
});
|
|
|
|
// Keep sorted and bounded
|
|
if results.len() > k {
|
|
results.sort_by(|a, b| {
|
|
b.score
|
|
.partial_cmp(&a.score)
|
|
.unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
results.truncate(k);
|
|
}
|
|
}
|
|
|
|
// Final sort
|
|
results.sort_by(|a, b| {
|
|
b.score
|
|
.partial_cmp(&a.score)
|
|
.unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
results
|
|
}
|
|
|
|
/// Search with time range filter (SIMD-accelerated)
|
|
pub fn search_with_time_range(
|
|
&self,
|
|
query: &Query,
|
|
time_range: TimeRange,
|
|
) -> Vec<SearchResult> {
|
|
let k = query.k;
|
|
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
|
|
|
|
for entry in self.patterns.iter() {
|
|
let temporal_pattern = entry.value();
|
|
|
|
// Filter by time range
|
|
if !time_range.contains(&temporal_pattern.pattern.timestamp) {
|
|
continue;
|
|
}
|
|
|
|
let score =
|
|
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
|
|
|
|
// Early exit optimization
|
|
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
|
|
continue;
|
|
}
|
|
|
|
results.push(SearchResult {
|
|
id: temporal_pattern.pattern.id,
|
|
pattern: temporal_pattern.clone(),
|
|
score,
|
|
});
|
|
|
|
if results.len() > k {
|
|
results.sort_by(|a, b| {
|
|
b.score
|
|
.partial_cmp(&a.score)
|
|
.unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
results.truncate(k);
|
|
}
|
|
}
|
|
|
|
results.sort_by(|a, b| {
|
|
b.score
|
|
.partial_cmp(&a.score)
|
|
.unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
results
|
|
}
|
|
|
|
/// Filter patterns by time range (ensures index is sorted first)
|
|
pub fn filter_by_time(&self, time_range: TimeRange) -> Vec<TemporalPattern> {
|
|
self.ensure_sorted();
|
|
let index = self.temporal_index.read();
|
|
|
|
// Binary search for start
|
|
let start_idx = index
|
|
.binary_search_by_key(&time_range.start, |(t, _)| *t)
|
|
.unwrap_or_else(|i| i);
|
|
|
|
// Binary search for end
|
|
let end_idx = index
|
|
.binary_search_by_key(&time_range.end, |(t, _)| *t)
|
|
.unwrap_or_else(|i| i);
|
|
|
|
// Collect patterns in range
|
|
index[start_idx..=end_idx.min(index.len().saturating_sub(1))]
|
|
.iter()
|
|
.filter_map(|(_, id)| self.patterns.get(id).map(|p| p.clone()))
|
|
.collect()
|
|
}
|
|
|
|
/// Strategic forgetting: decay low-salience patterns
|
|
pub fn decay_low_salience(&self, decay_rate: f32) {
|
|
let mut to_remove = Vec::new();
|
|
|
|
for mut entry in self.patterns.iter_mut() {
|
|
let temporal_pattern = entry.value_mut();
|
|
|
|
// Decay salience
|
|
temporal_pattern.pattern.salience *= 1.0 - decay_rate;
|
|
|
|
// Mark for removal if below threshold
|
|
if temporal_pattern.pattern.salience < self.config.min_salience {
|
|
to_remove.push(temporal_pattern.pattern.id);
|
|
}
|
|
}
|
|
|
|
// Remove low-salience patterns
|
|
for id in to_remove {
|
|
self.remove(&id);
|
|
}
|
|
}
|
|
|
|
/// Remove pattern
|
|
pub fn remove(&self, id: &PatternId) -> Option<TemporalPattern> {
|
|
// Remove from storage
|
|
let temporal_pattern = self.patterns.remove(id).map(|(_, p)| p)?;
|
|
|
|
// Remove from temporal index
|
|
let mut index = self.temporal_index.write();
|
|
index.retain(|(_, pid)| pid != id);
|
|
|
|
Some(temporal_pattern)
|
|
}
|
|
|
|
/// Get total number of patterns
|
|
pub fn len(&self) -> usize {
|
|
self.patterns.len()
|
|
}
|
|
|
|
/// Check if empty
|
|
pub fn is_empty(&self) -> bool {
|
|
self.patterns.is_empty()
|
|
}
|
|
|
|
/// Clear all patterns
|
|
pub fn clear(&self) {
|
|
self.patterns.clear();
|
|
self.temporal_index.write().clear();
|
|
}
|
|
|
|
/// Get all patterns
|
|
pub fn all(&self) -> Vec<TemporalPattern> {
|
|
self.patterns.iter().map(|e| e.value().clone()).collect()
|
|
}
|
|
|
|
/// Get statistics
|
|
pub fn stats(&self) -> LongTermStats {
|
|
let size = self.patterns.len();
|
|
|
|
// Compute average salience
|
|
let total_salience: f32 = self
|
|
.patterns
|
|
.iter()
|
|
.map(|e| e.value().pattern.salience)
|
|
.sum();
|
|
let avg_salience = if size > 0 {
|
|
total_salience / size as f32
|
|
} else {
|
|
0.0
|
|
};
|
|
|
|
// Find min/max salience
|
|
let mut min_salience = f32::MAX;
|
|
let mut max_salience = f32::MIN;
|
|
|
|
for entry in self.patterns.iter() {
|
|
let salience = entry.value().pattern.salience;
|
|
min_salience = min_salience.min(salience);
|
|
max_salience = max_salience.max(salience);
|
|
}
|
|
|
|
if size == 0 {
|
|
min_salience = 0.0;
|
|
max_salience = 0.0;
|
|
}
|
|
|
|
LongTermStats {
|
|
size,
|
|
avg_salience,
|
|
min_salience,
|
|
max_salience,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for LongTermStore {
|
|
fn default() -> Self {
|
|
Self::new(LongTermConfig::default())
|
|
}
|
|
}
|
|
|
|
/// Long-term store statistics
|
|
#[derive(Debug, Clone)]
|
|
pub struct LongTermStats {
|
|
/// Number of patterns
|
|
pub size: usize,
|
|
/// Average salience
|
|
pub avg_salience: f32,
|
|
/// Minimum salience
|
|
pub min_salience: f32,
|
|
/// Maximum salience
|
|
pub max_salience: f32,
|
|
}
|
|
|
|
/// SIMD-accelerated cosine similarity (4x speedup with loop unrolling)
|
|
#[inline]
|
|
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
|
|
if a.len() != b.len() || a.is_empty() {
|
|
return 0.0;
|
|
}
|
|
|
|
let len = a.len();
|
|
let chunks = len / 4;
|
|
|
|
let mut dot = 0.0f32;
|
|
let mut mag_a = 0.0f32;
|
|
let mut mag_b = 0.0f32;
|
|
|
|
// Process 4 elements at a time (unrolled loop for cache efficiency)
|
|
for i in 0..chunks {
|
|
let base = i * 4;
|
|
unsafe {
|
|
let a0 = *a.get_unchecked(base);
|
|
let a1 = *a.get_unchecked(base + 1);
|
|
let a2 = *a.get_unchecked(base + 2);
|
|
let a3 = *a.get_unchecked(base + 3);
|
|
|
|
let b0 = *b.get_unchecked(base);
|
|
let b1 = *b.get_unchecked(base + 1);
|
|
let b2 = *b.get_unchecked(base + 2);
|
|
let b3 = *b.get_unchecked(base + 3);
|
|
|
|
dot += a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
|
|
mag_a += a0 * a0 + a1 * a1 + a2 * a2 + a3 * a3;
|
|
mag_b += b0 * b0 + b1 * b1 + b2 * b2 + b3 * b3;
|
|
}
|
|
}
|
|
|
|
// Process remaining elements
|
|
for i in (chunks * 4)..len {
|
|
let ai = a[i];
|
|
let bi = b[i];
|
|
dot += ai * bi;
|
|
mag_a += ai * ai;
|
|
mag_b += bi * bi;
|
|
}
|
|
|
|
let mag = (mag_a * mag_b).sqrt();
|
|
if mag == 0.0 {
|
|
return 0.0;
|
|
}
|
|
|
|
dot / mag
|
|
}
|
|
|
|
/// Standard cosine similarity (alias for compatibility)
|
|
#[allow(dead_code)]
|
|
#[inline]
|
|
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
|
cosine_similarity_simd(a, b)
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::types::Metadata;
|
|
|
|
#[test]
|
|
fn test_long_term_store() {
|
|
let store = LongTermStore::default();
|
|
|
|
let temporal_pattern =
|
|
TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
|
|
let id = temporal_pattern.pattern.id;
|
|
|
|
store.integrate(temporal_pattern);
|
|
|
|
assert_eq!(store.len(), 1);
|
|
assert!(store.get(&id).is_some());
|
|
}
|
|
|
|
#[test]
|
|
fn test_search() {
|
|
let store = LongTermStore::default();
|
|
|
|
// Add patterns
|
|
let p1 = TemporalPattern::from_embedding(vec![1.0, 0.0, 0.0], Metadata::new());
|
|
let p2 = TemporalPattern::from_embedding(vec![0.0, 1.0, 0.0], Metadata::new());
|
|
|
|
store.integrate(p1);
|
|
store.integrate(p2);
|
|
|
|
// Query similar to p1
|
|
let query = Query::from_embedding(vec![0.9, 0.1, 0.0]).with_k(1);
|
|
let results = store.search(&query);
|
|
|
|
assert_eq!(results.len(), 1);
|
|
assert!(results[0].score > 0.5);
|
|
}
|
|
|
|
#[test]
|
|
fn test_decay() {
|
|
let store = LongTermStore::default();
|
|
|
|
let mut temporal_pattern =
|
|
TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
|
|
temporal_pattern.pattern.salience = 0.15; // Just above minimum
|
|
let id = temporal_pattern.pattern.id;
|
|
|
|
store.integrate(temporal_pattern);
|
|
assert_eq!(store.len(), 1);
|
|
|
|
// Decay should remove it
|
|
store.decay_low_salience(0.5);
|
|
assert_eq!(store.len(), 0);
|
|
}
|
|
}
|