Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
Commit: d803bfe2b1
Author: ruv
Date: 2026-02-28 14:39:40 -05:00

7854 changed files with 3,522,914 additions and 0 deletions


@@ -0,0 +1,831 @@
//! AgentDB adapter for pattern-aware tiering.
//!
//! Provides a bridge between the TieredStore and an external HNSW
//! vector index. When connected, tiering decisions can be influenced
//! by semantic similarity to frequently-accessed patterns.
//!
//! # Overview
//!
//! Block metadata is converted into a compact 4-dimensional embedding
//! via [`pattern_from_meta`], then stored in a [`PatternIndex`]. The
//! [`AdaptiveTiering`] struct combines the index with a
//! [`TierConfig`](crate::tiering::TierConfig) to produce tier
//! suggestions based on weighted neighbor voting.
//!
//! The default [`InMemoryPatternIndex`] uses brute-force linear scan
//! with cosine similarity, suitable for up to ~10K blocks. A real
//! deployment would swap in an HNSW-backed implementation.
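//!
//! # Example
//!
//! A minimal sketch of the index path (marked `ignore` since it assumes
//! this module is mounted as `crate::agentdb_adapter`; adapt paths as
//! needed):
//!
//! ```ignore
//! use crate::agentdb_adapter::{InMemoryPatternIndex, PatternIndex, PatternVector};
//! use crate::store::BlockKey;
//!
//! let mut index = InMemoryPatternIndex::new();
//! // Embeddings would normally come from `pattern_from_meta`.
//! index.insert(&PatternVector {
//!     key: BlockKey { tensor_id: 1, block_index: 0 },
//!     embedding: vec![0.9, 1.0, 0.5, 0.2],
//!     score: 0.9,
//! });
//! // Results are (key, cosine similarity) pairs, best first.
//! let neighbors = index.search_nearest(&[0.85, 1.0, 0.45, 0.25], 5);
//! assert_eq!(neighbors.len(), 1);
//! ```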
use crate::store::{BlockKey, BlockMeta, Tier};
use crate::tiering::TierConfig;
use std::collections::HashMap;
// ---------------------------------------------------------------------------
// PatternVector
// ---------------------------------------------------------------------------
/// A block's access-pattern embedding for similarity search.
#[derive(Clone, Debug)]
pub struct PatternVector {
/// The block this vector represents.
pub key: BlockKey,
/// Access-pattern embedding (typically 4 dimensions).
pub embedding: Vec<f32>,
/// Tiering score at the time of insertion.
pub score: f32,
}
// ---------------------------------------------------------------------------
// PatternIndex trait
// ---------------------------------------------------------------------------
/// Trait for a vector index over access-pattern embeddings.
///
/// Implementations range from a simple brute-force scan
/// ([`InMemoryPatternIndex`]) to an HNSW-backed production index.
pub trait PatternIndex {
/// Insert (or replace) a pattern vector.
fn insert(&mut self, vec: &PatternVector);
/// Return the `k` nearest neighbors to `query`, sorted by
/// descending cosine similarity. Each result is `(key, similarity)`.
fn search_nearest(&self, query: &[f32], k: usize) -> Vec<(BlockKey, f32)>;
/// Remove the pattern for `key`, if present.
fn remove(&mut self, key: BlockKey);
/// Number of pattern vectors currently stored.
fn len(&self) -> usize;
/// Returns `true` if the index contains no vectors.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
// ---------------------------------------------------------------------------
// Cosine similarity
// ---------------------------------------------------------------------------
/// Compute the cosine similarity between two vectors.
///
/// Returns 0.0 if either vector has zero magnitude or they differ in length.
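///
/// Computed as `dot(a, b) / (|a| * |b|)`.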
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let mut dot = 0.0f32;
let mut norm_a_sq = 0.0f32;
let mut norm_b_sq = 0.0f32;
for (&x, &y) in a.iter().zip(b.iter()) {
dot += x * y;
norm_a_sq += x * x;
norm_b_sq += y * y;
}
let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
if denom == 0.0 {
0.0
} else {
dot / denom
}
}
// ---------------------------------------------------------------------------
// InMemoryPatternIndex
// ---------------------------------------------------------------------------
/// Brute-force in-memory implementation of [`PatternIndex`].
///
/// Uses a `Vec<PatternVector>` with linear-scan cosine similarity.
/// Adequate for small collections (<10K blocks); a real AgentDB
/// deployment would use HNSW for sub-linear search.
pub struct InMemoryPatternIndex {
vectors: Vec<PatternVector>,
}
impl InMemoryPatternIndex {
/// Create a new empty index.
pub fn new() -> Self {
Self {
vectors: Vec::new(),
}
}
}
impl Default for InMemoryPatternIndex {
fn default() -> Self {
Self::new()
}
}
impl PatternIndex for InMemoryPatternIndex {
fn insert(&mut self, vec: &PatternVector) {
// Remove any existing entry for the same key, then append.
self.vectors.retain(|v| v.key != vec.key);
self.vectors.push(vec.clone());
}
fn search_nearest(&self, query: &[f32], k: usize) -> Vec<(BlockKey, f32)> {
if k == 0 || self.vectors.is_empty() {
return Vec::new();
}
let mut scored: Vec<(BlockKey, f32)> = self
.vectors
.iter()
.map(|v| (v.key, cosine_similarity(query, &v.embedding)))
.collect();
// Sort by descending similarity.
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
scored.truncate(k);
scored
}
fn remove(&mut self, key: BlockKey) {
self.vectors.retain(|v| v.key != key);
}
fn len(&self) -> usize {
self.vectors.len()
}
}
// ---------------------------------------------------------------------------
// pattern_from_meta
// ---------------------------------------------------------------------------
/// Convert block metadata into a 4-dimensional pattern vector.
///
/// The dimensions encode access-pattern features that are useful for
/// clustering blocks with similar tiering behaviour:
///
/// | Index | Feature | Range | Description |
/// |-------|------------------|---------|------------------------------------------|
/// | 0 | `ema_rate` | [0, 1] | Exponential moving average of access rate|
/// | 1 | `popcount/64` | [0, 1] | Fraction of recent ticks with access |
/// | 2 | `recency_decay` | (0, 1] | `1 / (1 + tier_age)` -- inverse staleness|
/// | 3 | `access_count_log` | [0, 1] | `log2(1 + count) / 32` -- normalized log |
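///
/// For example, a block with `ema_rate = 0.5`, 32 of 64 window bits set,
/// `tier_age = 9`, and `access_count = 1` maps to
/// `[0.5, 0.5, 0.1, 0.03125]`.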
pub fn pattern_from_meta(meta: &BlockMeta) -> Vec<f32> {
let ema = meta.ema_rate.clamp(0.0, 1.0);
let pop = meta.window.count_ones() as f32 / 64.0;
let recency = 1.0 / (1.0 + meta.tier_age as f32);
let count_log = ((1.0 + meta.access_count as f32).log2() / 32.0).clamp(0.0, 1.0);
vec![ema, pop, recency, count_log]
}
// ---------------------------------------------------------------------------
// AdaptiveTiering
// ---------------------------------------------------------------------------
/// Pattern-aware tiering advisor.
///
/// Combines a [`PatternIndex`] with a [`TierConfig`] to suggest tier
/// assignments based on the tiers of semantically similar blocks.
///
/// # Algorithm
///
/// Given a block's metadata and a set of nearest neighbors (from the
/// pattern index), each neighbor's known tier contributes a weighted
/// vote proportional to its cosine similarity. The tier with the
/// highest cumulative vote is suggested, unless it matches the block's
/// current tier (in which case `None` is returned).
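///
/// A hedged usage sketch (marked `ignore`; `meta` and `neighbor_key` are
/// placeholders for a real [`BlockMeta`](crate::store::BlockMeta) and a
/// neighbor's [`BlockKey`](crate::store::BlockKey)):
///
/// ```ignore
/// let mut at = AdaptiveTiering::new(InMemoryPatternIndex::new(), TierConfig::default());
/// at.register_block(neighbor_key, Tier::Tier3);
/// // `neighbors` would normally come from PatternIndex::search_nearest
/// // on the candidate block's pattern_from_meta embedding.
/// let suggestion = at.suggest_tier(&meta, &[(neighbor_key, 0.9)]);
/// // Some(Tier::Tier3), unless the block is already in Tier3.
/// ```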
pub struct AdaptiveTiering<I: PatternIndex> {
/// The underlying pattern vector index.
pub index: I,
/// Tiering configuration (thresholds, hysteresis, etc.).
pub config: TierConfig,
/// Known tier for each block, updated via [`register_block`](Self::register_block).
block_tiers: HashMap<BlockKey, Tier>,
}
impl<I: PatternIndex> AdaptiveTiering<I> {
/// Create a new `AdaptiveTiering` with the given index and config.
pub fn new(index: I, config: TierConfig) -> Self {
Self {
index,
config,
block_tiers: HashMap::new(),
}
}
/// Register (or update) the known tier for a block.
///
/// This must be called whenever a block changes tier so that
/// [`suggest_tier`](Self::suggest_tier) can use accurate neighbor
/// tier information for voting.
pub fn register_block(&mut self, key: BlockKey, tier: Tier) {
self.block_tiers.insert(key, tier);
}
/// Remove a block from the tier registry and the pattern index.
pub fn remove_block(&mut self, key: BlockKey) {
self.block_tiers.remove(&key);
self.index.remove(key);
}
/// Number of blocks registered in the tier map.
pub fn registered_count(&self) -> usize {
self.block_tiers.len()
}
/// Suggest a tier for `meta` based on its nearest neighbors.
///
/// `neighbors` should be the output of
/// [`PatternIndex::search_nearest`]: a list of `(BlockKey, similarity)`
/// pairs. Each neighbor whose tier is known contributes a weighted
/// vote. The tier with the highest total vote is returned, unless it
/// matches the block's current tier.
///
/// Returns `None` if:
/// - `neighbors` is empty,
/// - no neighbors have known tiers, or
/// - the consensus tier matches the block's current tier.
pub fn suggest_tier(&self, meta: &BlockMeta, neighbors: &[(BlockKey, f32)]) -> Option<Tier> {
if neighbors.is_empty() {
return None;
}
// Accumulate weighted votes per tier.
// Index 0 = Tier0, 1 = Tier1, 2 = Tier2, 3 = Tier3.
let mut votes = [0.0f32; 4];
let mut total_weight = 0.0f32;
for &(key, similarity) in neighbors {
if let Some(&tier) = self.block_tiers.get(&key) {
let weight = similarity.max(0.0);
votes[tier as usize] += weight;
total_weight += weight;
}
}
if total_weight == 0.0 {
return None;
}
// Find the tier with the highest vote. On ties, the hotter tier
// (lower index) wins, because the strict `>` never replaces an equal vote.
let mut best_idx = 0usize;
let mut best_vote = votes[0];
for i in 1..4 {
if votes[i] > best_vote {
best_vote = votes[i];
best_idx = i;
}
}
let suggested = match best_idx {
0 => Tier::Tier0,
1 => Tier::Tier1,
2 => Tier::Tier2,
3 => Tier::Tier3,
_ => unreachable!(),
};
if suggested == meta.tier {
None
} else {
Some(suggested)
}
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::store::{DType, ReconstructPolicy};
fn make_key(tid: u128, idx: u32) -> BlockKey {
BlockKey {
tensor_id: tid,
block_index: idx,
}
}
fn make_store_meta(
key: BlockKey,
tier: Tier,
ema_rate: f32,
window: u64,
access_count: u32,
tier_age: u32,
) -> BlockMeta {
BlockMeta {
key,
dtype: DType::F32,
tier,
bits: 8,
scale: 1.0,
zero_point: 0,
created_at: 0,
last_access_at: 100,
access_count,
ema_rate,
window,
checksum: 0,
reconstruct: ReconstructPolicy::None,
tier_age,
lineage_parent: None,
block_bytes: 1024,
}
}
// -- cosine_similarity -------------------------------------------------
#[test]
fn cosine_identical_vectors() {
let v = vec![1.0, 2.0, 3.0, 4.0];
let sim = cosine_similarity(&v, &v);
assert!((sim - 1.0).abs() < 1e-6, "sim={sim}");
}
#[test]
fn cosine_orthogonal_vectors() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let sim = cosine_similarity(&a, &b);
assert!(sim.abs() < 1e-6, "sim={sim}");
}
#[test]
fn cosine_opposite_vectors() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![-1.0, 0.0, 0.0];
let sim = cosine_similarity(&a, &b);
assert!((sim - (-1.0)).abs() < 1e-6, "sim={sim}");
}
#[test]
fn cosine_zero_vector() {
let a = vec![1.0, 2.0];
let b = vec![0.0, 0.0];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn cosine_different_lengths() {
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn cosine_empty() {
let a: Vec<f32> = vec![];
let b: Vec<f32> = vec![];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn cosine_known_value() {
// cos([1,1], [1,0]) = 1/sqrt(2) ~ 0.7071
let a = vec![1.0, 1.0];
let b = vec![1.0, 0.0];
let sim = cosine_similarity(&a, &b);
let expected = 1.0 / 2.0f32.sqrt();
assert!(
(sim - expected).abs() < 1e-6,
"sim={sim}, expected={expected}"
);
}
// -- InMemoryPatternIndex ----------------------------------------------
#[test]
fn index_insert_and_len() {
let mut idx = InMemoryPatternIndex::new();
assert!(idx.is_empty());
idx.insert(&PatternVector {
key: make_key(1, 0),
embedding: vec![1.0, 0.0, 0.0, 0.0],
score: 0.5,
});
assert_eq!(idx.len(), 1);
assert!(!idx.is_empty());
}
#[test]
fn index_insert_replaces_duplicate_key() {
let mut idx = InMemoryPatternIndex::new();
let key = make_key(1, 0);
idx.insert(&PatternVector {
key,
embedding: vec![1.0, 0.0, 0.0, 0.0],
score: 0.5,
});
idx.insert(&PatternVector {
key,
embedding: vec![0.0, 1.0, 0.0, 0.0],
score: 0.8,
});
assert_eq!(idx.len(), 1);
// The search should find the updated embedding.
let results = idx.search_nearest(&[0.0, 1.0, 0.0, 0.0], 1);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, key);
// Similarity should be ~1.0 since embeddings match.
assert!((results[0].1 - 1.0).abs() < 1e-6);
}
#[test]
fn index_remove() {
let mut idx = InMemoryPatternIndex::new();
let key = make_key(1, 0);
idx.insert(&PatternVector {
key,
embedding: vec![1.0, 0.0, 0.0, 0.0],
score: 0.5,
});
assert_eq!(idx.len(), 1);
idx.remove(key);
assert_eq!(idx.len(), 0);
}
#[test]
fn index_remove_nonexistent() {
let mut idx = InMemoryPatternIndex::new();
idx.remove(make_key(99, 0)); // should not panic
assert_eq!(idx.len(), 0);
}
#[test]
fn index_search_nearest_ordering() {
let mut idx = InMemoryPatternIndex::new();
// Insert three vectors with known geometry.
idx.insert(&PatternVector {
key: make_key(1, 0),
embedding: vec![1.0, 0.0, 0.0, 0.0],
score: 0.0,
});
idx.insert(&PatternVector {
key: make_key(2, 0),
embedding: vec![0.7, 0.7, 0.0, 0.0],
score: 0.0,
});
idx.insert(&PatternVector {
key: make_key(3, 0),
embedding: vec![0.0, 1.0, 0.0, 0.0],
score: 0.0,
});
// Query close to [1, 0, 0, 0].
let results = idx.search_nearest(&[1.0, 0.1, 0.0, 0.0], 3);
assert_eq!(results.len(), 3);
// Closest should be key 1 (nearly identical direction).
assert_eq!(results[0].0, make_key(1, 0));
// Second should be key 2 (partial overlap).
assert_eq!(results[1].0, make_key(2, 0));
// Third should be key 3 (mostly orthogonal).
assert_eq!(results[2].0, make_key(3, 0));
// Similarities should be descending.
assert!(results[0].1 >= results[1].1);
assert!(results[1].1 >= results[2].1);
}
#[test]
fn index_search_nearest_k_larger_than_size() {
let mut idx = InMemoryPatternIndex::new();
idx.insert(&PatternVector {
key: make_key(1, 0),
embedding: vec![1.0, 0.0],
score: 0.0,
});
let results = idx.search_nearest(&[1.0, 0.0], 10);
assert_eq!(results.len(), 1);
}
#[test]
fn index_search_nearest_k_zero() {
let mut idx = InMemoryPatternIndex::new();
idx.insert(&PatternVector {
key: make_key(1, 0),
embedding: vec![1.0],
score: 0.0,
});
let results = idx.search_nearest(&[1.0], 0);
assert!(results.is_empty());
}
#[test]
fn index_search_nearest_empty() {
let idx = InMemoryPatternIndex::new();
let results = idx.search_nearest(&[1.0, 0.0], 5);
assert!(results.is_empty());
}
// -- pattern_from_meta -------------------------------------------------
#[test]
fn pattern_from_meta_dimensions() {
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.5, 0xFFFF, 100, 10);
let pat = pattern_from_meta(&meta);
assert_eq!(pat.len(), 4);
}
#[test]
fn pattern_from_meta_ema_component() {
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.8, 0, 0, 0);
let pat = pattern_from_meta(&meta);
assert!((pat[0] - 0.8).abs() < 1e-6, "ema={}", pat[0]);
}
#[test]
fn pattern_from_meta_popcount_component() {
// All 64 bits set.
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.0, u64::MAX, 0, 0);
let pat = pattern_from_meta(&meta);
assert!((pat[1] - 1.0).abs() < 1e-6, "pop={}", pat[1]);
// No bits set.
let meta2 = make_store_meta(make_key(1, 0), Tier::Tier1, 0.0, 0, 0, 0);
let pat2 = pattern_from_meta(&meta2);
assert!((pat2[1]).abs() < 1e-6, "pop={}", pat2[1]);
// 32 bits set.
let meta3 = make_store_meta(make_key(1, 0), Tier::Tier1, 0.0, 0xFFFF_FFFF, 0, 0);
let pat3 = pattern_from_meta(&meta3);
assert!((pat3[1] - 0.5).abs() < 1e-6, "pop={}", pat3[1]);
}
#[test]
fn pattern_from_meta_recency_component() {
// tier_age = 0 => recency = 1.0 / (1.0 + 0) = 1.0
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.0, 0, 0, 0);
let pat = pattern_from_meta(&meta);
assert!((pat[2] - 1.0).abs() < 1e-6, "recency={}", pat[2]);
// tier_age = 9 => recency = 1.0 / 10.0 = 0.1
let meta2 = make_store_meta(make_key(1, 0), Tier::Tier1, 0.0, 0, 0, 9);
let pat2 = pattern_from_meta(&meta2);
assert!((pat2[2] - 0.1).abs() < 1e-6, "recency={}", pat2[2]);
}
#[test]
fn pattern_from_meta_access_count_log_component() {
// access_count = 0 => log2(1) / 32 = 0
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.0, 0, 0, 0);
let pat = pattern_from_meta(&meta);
assert!(pat[3].abs() < 1e-6, "count_log={}", pat[3]);
// access_count = 1 => log2(2) / 32 = 1/32 ~ 0.03125
let meta2 = make_store_meta(make_key(1, 0), Tier::Tier1, 0.0, 0, 1, 0);
let pat2 = pattern_from_meta(&meta2);
assert!((pat2[3] - 1.0 / 32.0).abs() < 1e-4, "count_log={}", pat2[3]);
}
#[test]
fn pattern_from_meta_values_in_unit_range() {
// Use extreme values to verify clamping.
let meta = make_store_meta(
make_key(1, 0),
Tier::Tier1,
2.0, // ema > 1, should be clamped
u64::MAX, // all bits set
u32::MAX, // max access count
u32::MAX, // max tier age
);
let pat = pattern_from_meta(&meta);
for (i, &v) in pat.iter().enumerate() {
assert!(v >= 0.0 && v <= 1.0, "dim {i} out of [0,1]: {v}");
}
}
// -- AdaptiveTiering ---------------------------------------------------
#[test]
fn adaptive_new_and_register() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let mut at = AdaptiveTiering::new(idx, config);
assert_eq!(at.registered_count(), 0);
at.register_block(make_key(1, 0), Tier::Tier1);
assert_eq!(at.registered_count(), 1);
at.register_block(make_key(1, 0), Tier::Tier2);
assert_eq!(at.registered_count(), 1); // same key, updated
}
#[test]
fn adaptive_remove_block() {
let mut idx = InMemoryPatternIndex::new();
let key = make_key(1, 0);
idx.insert(&PatternVector {
key,
embedding: vec![1.0, 0.0, 0.0, 0.0],
score: 0.5,
});
let config = TierConfig::default();
let mut at = AdaptiveTiering::new(idx, config);
at.register_block(key, Tier::Tier1);
assert_eq!(at.registered_count(), 1);
assert_eq!(at.index.len(), 1);
at.remove_block(key);
assert_eq!(at.registered_count(), 0);
assert_eq!(at.index.len(), 0);
}
#[test]
fn suggest_tier_empty_neighbors() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let at = AdaptiveTiering::new(idx, config);
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.5, 0, 10, 5);
let result = at.suggest_tier(&meta, &[]);
assert_eq!(result, None);
}
#[test]
fn suggest_tier_no_known_neighbors() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let at = AdaptiveTiering::new(idx, config);
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.5, 0, 10, 5);
// Neighbors exist but their tiers are not registered.
let neighbors = vec![(make_key(2, 0), 0.9), (make_key(3, 0), 0.8)];
let result = at.suggest_tier(&meta, &neighbors);
assert_eq!(result, None);
}
#[test]
fn suggest_tier_unanimous_vote() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let mut at = AdaptiveTiering::new(idx, config);
// Register three neighbors all in Tier3.
at.register_block(make_key(2, 0), Tier::Tier3);
at.register_block(make_key(3, 0), Tier::Tier3);
at.register_block(make_key(4, 0), Tier::Tier3);
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.5, 0, 10, 5);
let neighbors = vec![
(make_key(2, 0), 0.9),
(make_key(3, 0), 0.8),
(make_key(4, 0), 0.7),
];
let result = at.suggest_tier(&meta, &neighbors);
assert_eq!(result, Some(Tier::Tier3));
}
#[test]
fn suggest_tier_same_as_current_returns_none() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let mut at = AdaptiveTiering::new(idx, config);
// Neighbors all in Tier1, same as the block.
at.register_block(make_key(2, 0), Tier::Tier1);
at.register_block(make_key(3, 0), Tier::Tier1);
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.5, 0, 10, 5);
let neighbors = vec![(make_key(2, 0), 0.9), (make_key(3, 0), 0.8)];
let result = at.suggest_tier(&meta, &neighbors);
assert_eq!(result, None);
}
#[test]
fn suggest_tier_weighted_majority() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let mut at = AdaptiveTiering::new(idx, config);
// Two neighbors in Tier1 with moderate similarity.
at.register_block(make_key(2, 0), Tier::Tier1);
at.register_block(make_key(3, 0), Tier::Tier1);
// One neighbor in Tier3 with very high similarity.
at.register_block(make_key(4, 0), Tier::Tier3);
let meta = make_store_meta(make_key(1, 0), Tier::Tier2, 0.5, 0, 10, 5);
let neighbors = vec![
(make_key(2, 0), 0.3), // votes Tier1 with weight 0.3
(make_key(3, 0), 0.3), // votes Tier1 with weight 0.3
(make_key(4, 0), 0.9), // votes Tier3 with weight 0.9
];
// Tier1 total = 0.6, Tier3 total = 0.9. Tier3 wins.
let result = at.suggest_tier(&meta, &neighbors);
assert_eq!(result, Some(Tier::Tier3));
}
#[test]
fn suggest_tier_negative_similarity_ignored() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let mut at = AdaptiveTiering::new(idx, config);
at.register_block(make_key(2, 0), Tier::Tier3);
at.register_block(make_key(3, 0), Tier::Tier1);
let meta = make_store_meta(make_key(1, 0), Tier::Tier2, 0.5, 0, 10, 5);
let neighbors = vec![
(make_key(2, 0), -0.5), // negative similarity, weight clamped to 0
(make_key(3, 0), 0.5), // positive similarity, votes Tier1
];
// Tier3 gets 0 weight (clamped), Tier1 gets 0.5. Tier1 wins.
let result = at.suggest_tier(&meta, &neighbors);
assert_eq!(result, Some(Tier::Tier1));
}
#[test]
fn suggest_tier_zero_similarity_all() {
let idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
let mut at = AdaptiveTiering::new(idx, config);
at.register_block(make_key(2, 0), Tier::Tier3);
let meta = make_store_meta(make_key(1, 0), Tier::Tier1, 0.5, 0, 10, 5);
let neighbors = vec![(make_key(2, 0), 0.0)];
// Zero similarity means zero weight => total_weight == 0 => None.
let result = at.suggest_tier(&meta, &neighbors);
assert_eq!(result, None);
}
// -- Integration: pattern_from_meta + index + adaptive -----------------
#[test]
fn integration_end_to_end() {
let mut idx = InMemoryPatternIndex::new();
let config = TierConfig::default();
// Create several blocks with different access patterns.
let hot_key = make_key(1, 0);
let warm_key = make_key(2, 0);
let cold_key = make_key(3, 0);
let hot_meta = make_store_meta(hot_key, Tier::Tier1, 0.9, u64::MAX, 1000, 2);
let warm_meta = make_store_meta(warm_key, Tier::Tier2, 0.5, 0xFFFF_FFFF, 100, 10);
let cold_meta = make_store_meta(cold_key, Tier::Tier3, 0.05, 0x0F, 5, 100);
// Build embeddings and insert into index.
let hot_emb = pattern_from_meta(&hot_meta);
let warm_emb = pattern_from_meta(&warm_meta);
let cold_emb = pattern_from_meta(&cold_meta);
idx.insert(&PatternVector {
key: hot_key,
embedding: hot_emb.clone(),
score: 0.9,
});
idx.insert(&PatternVector {
key: warm_key,
embedding: warm_emb.clone(),
score: 0.5,
});
idx.insert(&PatternVector {
key: cold_key,
embedding: cold_emb.clone(),
score: 0.1,
});
let mut at = AdaptiveTiering::new(idx, config);
at.register_block(hot_key, Tier::Tier1);
at.register_block(warm_key, Tier::Tier2);
at.register_block(cold_key, Tier::Tier3);
// Query: a new block with a hot-like pattern.
let new_key = make_key(4, 0);
let new_meta = make_store_meta(new_key, Tier::Tier3, 0.85, u64::MAX, 800, 3);
let new_emb = pattern_from_meta(&new_meta);
let neighbors = at.index.search_nearest(&new_emb, 3);
assert!(!neighbors.is_empty());
let suggestion = at.suggest_tier(&new_meta, &neighbors);
// The new block's pattern is closest to the hot block, so
// the suggestion should be to promote it (away from Tier3).
assert!(
suggestion.is_some(),
"expected a tier suggestion for a hot-like pattern in Tier3"
);
let suggested = suggestion.unwrap();
assert_ne!(suggested, Tier::Tier3, "should not stay cold");
}
}


@@ -0,0 +1,158 @@
//! Bitstream packer/unpacker for arbitrary bit widths (1-8).
//!
//! Uses a 64-bit accumulator for sub-byte codes with no alignment padding.
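//!
//! # Example
//!
//! A pack/unpack round trip at 3 bits (marked `ignore` since the doctest
//! path depends on how this module is exported):
//!
//! ```ignore
//! let codes: Vec<u32> = (0..8).collect(); // all values fit in 3 bits
//! let mut packed = Vec::new();
//! pack(&codes, 3, &mut packed);
//! assert_eq!(packed.len(), 3); // 8 codes * 3 bits = 24 bits = 3 bytes
//! let mut unpacked = Vec::new();
//! unpack(&packed, 3, codes.len(), &mut unpacked);
//! assert_eq!(unpacked, codes);
//! ```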
/// Pack unsigned codes of `bits` width into a byte stream.
///
/// Each code occupies exactly `bits` bits in the output with no alignment
/// padding between codes. A trailing partial byte is emitted if needed.
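///
/// Codes are not masked to `bits`: a code wider than `bits` bits will
/// bleed into (and corrupt) subsequent codes in the stream.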
///
/// For 8-bit codes, writes bytes directly without bit accumulation.
#[inline]
pub fn pack(codes: &[u32], bits: u32, out: &mut Vec<u8>) {
// Fast path: 8-bit codes map 1:1 to bytes.
if bits == 8 {
out.extend(codes.iter().map(|&c| c as u8));
return;
}
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
for &code in codes {
acc |= (code as u64) << acc_bits;
acc_bits += bits;
while acc_bits >= 8 {
out.push((acc & 0xFF) as u8);
acc >>= 8;
acc_bits -= 8;
}
}
if acc_bits > 0 {
out.push((acc & 0xFF) as u8);
}
}
/// Unpack `count` unsigned codes of `bits` width from a byte stream.
///
/// Stops early if the data is exhausted before `count` codes are extracted.
///
/// For 8-bit codes, reads bytes directly without bit accumulation.
#[inline]
pub fn unpack(data: &[u8], bits: u32, count: usize, out: &mut Vec<u32>) {
// Fast path: 8-bit codes map 1:1 from bytes.
if bits == 8 {
let n = count.min(data.len());
out.extend(data[..n].iter().map(|&b| b as u32));
return;
}
let mask = (1u64 << bits) - 1;
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
let mut byte_idx = 0usize;
let mut decoded = 0usize;
while decoded < count {
while acc_bits < bits && byte_idx < data.len() {
acc |= (data[byte_idx] as u64) << acc_bits;
acc_bits += 8;
byte_idx += 1;
}
if acc_bits < bits {
break;
}
out.push((acc & mask) as u32);
acc >>= bits;
acc_bits -= bits;
decoded += 1;
}
}
/// Compute qmax for a given bit width: `2^(bits-1) - 1`.
///
/// Returns 0 for invalid bit widths (0 or >8).
///
/// | bits | qmax |
/// |------|------|
/// | 8 | 127 |
/// | 7 | 63 |
/// | 5 | 15 |
/// | 3 | 3 |
#[inline]
pub fn qmax_from_bits(bits: u8) -> i32 {
if bits == 0 || bits > 8 {
return 0;
}
(1i32 << (bits - 1)) - 1
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_roundtrip_8bit() {
let codes: Vec<u32> = (0..256).collect();
let mut packed = Vec::new();
pack(&codes, 8, &mut packed);
assert_eq!(packed.len(), 256);
let mut unpacked = Vec::new();
unpack(&packed, 8, 256, &mut unpacked);
assert_eq!(codes, unpacked);
}
#[test]
fn test_roundtrip_3bit() {
let codes: Vec<u32> = (0..7).collect();
let mut packed = Vec::new();
pack(&codes, 3, &mut packed);
let mut unpacked = Vec::new();
unpack(&packed, 3, 7, &mut unpacked);
assert_eq!(codes, unpacked);
}
#[test]
fn test_roundtrip_5bit() {
let codes: Vec<u32> = (0..31).collect();
let mut packed = Vec::new();
pack(&codes, 5, &mut packed);
let mut unpacked = Vec::new();
unpack(&packed, 5, 31, &mut unpacked);
assert_eq!(codes, unpacked);
}
#[test]
fn test_roundtrip_7bit() {
let codes: Vec<u32> = (0..127).collect();
let mut packed = Vec::new();
pack(&codes, 7, &mut packed);
let mut unpacked = Vec::new();
unpack(&packed, 7, 127, &mut unpacked);
assert_eq!(codes, unpacked);
}
#[test]
fn test_packing_density() {
let codes = vec![5u32; 100];
let mut packed = Vec::new();
pack(&codes, 3, &mut packed);
assert_eq!(packed.len(), 38); // ceil(300/8) = 38
}
#[test]
fn test_qmax() {
assert_eq!(qmax_from_bits(8), 127);
assert_eq!(qmax_from_bits(7), 63);
assert_eq!(qmax_from_bits(5), 15);
assert_eq!(qmax_from_bits(3), 3);
assert_eq!(qmax_from_bits(1), 0);
assert_eq!(qmax_from_bits(0), 0);
}
}


@@ -0,0 +1,552 @@
//! Coherence gate: read-after-write validation for the temporal tensor store.
//!
//! Ensures data integrity by verifying that a `get()` immediately after `put()`
//! returns data within the expected quantization error bounds for the tier.
//!
//! # Overview
//!
//! Quantization is lossy -- the error introduced depends on the tier's bit
//! width (8-bit for Tier1, 7-bit for Tier2, 3-bit for Tier3). The coherence
//! gate validates that the round-trip error stays within configurable
//! per-tier bounds, catching silent corruption or encoding bugs.
//!
//! # Epoch Tracking
//!
//! [`EpochTracker`] provides a lightweight write-epoch mechanism so that
//! readers can detect stale data (i.e. data that was overwritten between
//! the time it was read and the time it was consumed).
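//!
//! # Example
//!
//! A minimal sketch of staleness detection (marked `ignore`; assumes
//! in-crate paths):
//!
//! ```ignore
//! use crate::store::BlockKey;
//!
//! let mut epochs = EpochTracker::new();
//! let key = BlockKey { tensor_id: 1, block_index: 0 };
//! let read_epoch = epochs.record_write(key); // writer records, reader snapshots
//! assert!(!epochs.is_stale(key, read_epoch));
//! epochs.record_write(key); // a later overwrite...
//! assert!(epochs.is_stale(key, read_epoch)); // ...makes the snapshot stale
//! ```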
use std::collections::HashMap;
use crate::store::{BlockKey, StoreError, Tier, TieredStore};
// ---------------------------------------------------------------------------
// CoherenceResult
// ---------------------------------------------------------------------------
/// Outcome of a coherence check.
#[derive(Clone, Debug, PartialEq)]
pub struct CoherenceResult {
/// Maximum relative error observed across all elements.
pub max_error: f32,
/// The tier at which the block is stored.
pub tier: Tier,
/// Whether the observed error is within the configured bound for this tier.
pub passed: bool,
}
// ---------------------------------------------------------------------------
// CoherenceCheck
// ---------------------------------------------------------------------------
/// Per-tier maximum relative error bounds for read-after-write validation.
///
/// After a `put()`, the block is immediately read back and the maximum
/// relative error (per-element `|orig - decoded| / |orig|`) is compared
/// against the bound for the block's current tier.
#[derive(Clone, Debug)]
pub struct CoherenceCheck {
/// Maximum acceptable relative error for each tier, indexed by
/// `Tier as usize`: `[Tier0, Tier1, Tier2, Tier3]`.
///
/// Tier0 (evicted) has no payload, so any read will fail before the
/// error comparison is reached. The bound is set to `f32::MAX` as a
/// sentinel.
pub max_relative_errors: [f32; 4],
}
impl Default for CoherenceCheck {
fn default() -> Self {
Self {
// Tier0: evicted, reads always fail (sentinel value).
// Tier1: 8-bit, very tight bound.
// Tier2: 7-bit, slightly looser.
// Tier3: 3-bit, aggressive quantization allows up to 35% error.
max_relative_errors: [f32::MAX, 0.01, 0.02, 0.35],
}
}
}
impl CoherenceCheck {
/// Create a `CoherenceCheck` with custom per-tier error bounds.
pub fn new(max_relative_errors: [f32; 4]) -> Self {
Self {
max_relative_errors,
}
}
/// Validate read-after-write coherence for a block that was just written.
///
/// Reads the block back from `store`, computes the maximum relative
/// error against `original_data`, and checks whether it falls within
/// the configured bound for the block's tier.
///
/// # Errors
///
/// Returns [`StoreError::BlockNotFound`] if the key does not exist,
/// [`StoreError::TensorEvicted`] if the block is in Tier0, or any
/// other `StoreError` from the underlying read.
pub fn check_coherence(
&self,
store: &mut TieredStore,
key: BlockKey,
original_data: &[f32],
now: u64,
) -> Result<CoherenceResult, StoreError> {
// Look up the tier before reading (needed for the error bound).
let tier = store.meta(key).ok_or(StoreError::BlockNotFound)?.tier;
// Read back the block.
let mut buf = vec![0.0f32; original_data.len()];
let n = store.get(key, &mut buf, now)?;
// Compute the maximum relative error.
let max_error = compute_max_relative_error(original_data, &buf[..n]);
let tier_idx = tier as usize;
let bound = if tier_idx < self.max_relative_errors.len() {
self.max_relative_errors[tier_idx]
} else {
f32::MAX
};
Ok(CoherenceResult {
max_error,
tier,
passed: max_error <= bound,
})
}
/// Convenience: `put` followed by `check_coherence` in one call.
///
/// Stores the data at the given tier, then immediately reads it back
/// and validates the round-trip error. Returns the coherence result
/// so the caller can decide whether to retry at a higher-fidelity tier.
///
/// # Errors
///
/// Propagates errors from both `put` and the subsequent `get`.
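///
/// # Example
///
/// A hedged sketch (marked `ignore`; assumes an in-crate [`TieredStore`]):
///
/// ```ignore
/// let mut store = TieredStore::new(4096);
/// let cc = CoherenceCheck::default();
/// let key = BlockKey { tensor_id: 1, block_index: 0 };
/// let result = cc.verify_put(&mut store, key, &[1.0f32; 64], Tier::Tier1, 0).unwrap();
/// if !result.passed {
///     // Retry at a higher-fidelity tier, or flag the block for audit.
/// }
/// ```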
pub fn verify_put(
&self,
store: &mut TieredStore,
key: BlockKey,
data: &[f32],
tier: Tier,
now: u64,
) -> Result<CoherenceResult, StoreError> {
store.put(key, data, tier, now)?;
self.check_coherence(store, key, data, now)
}
}
// ---------------------------------------------------------------------------
// Helper: relative error computation
// ---------------------------------------------------------------------------
/// Compute the maximum element-wise relative error between `original` and
/// `decoded`.
///
/// For elements where `|original| < epsilon` (near-zero), the absolute
/// error is used directly to avoid division-by-zero amplification.
fn compute_max_relative_error(original: &[f32], decoded: &[f32]) -> f32 {
const EPSILON: f32 = 1e-6;
let len = original.len().min(decoded.len());
let mut max_err: f32 = 0.0;
for i in 0..len {
let orig = original[i];
let dec = decoded[i];
let abs_err = (orig - dec).abs();
let rel_err = if orig.abs() > EPSILON {
abs_err / orig.abs()
} else {
abs_err
};
if rel_err > max_err {
max_err = rel_err;
}
}
max_err
}
// ---------------------------------------------------------------------------
// EpochTracker
// ---------------------------------------------------------------------------
/// Monotonic write-epoch tracker keyed by [`BlockKey`].
///
/// Each call to [`record_write`](EpochTracker::record_write) increments a
/// global counter and associates the new epoch with the given key. Readers
/// can later check whether their snapshot is stale via
/// [`is_stale`](EpochTracker::is_stale).
#[derive(Clone, Debug)]
pub struct EpochTracker {
/// Global monotonically increasing write counter.
next_epoch: u64,
/// Per-key latest write epoch.
epochs: HashMap<BlockKey, u64>,
}
impl EpochTracker {
/// Create a new tracker with epoch starting at 1.
pub fn new() -> Self {
Self {
next_epoch: 1,
epochs: HashMap::new(),
}
}
/// Record a write for `key`, returning the new epoch number.
///
/// The epoch is strictly monotonically increasing across all keys.
pub fn record_write(&mut self, key: BlockKey) -> u64 {
let epoch = self.next_epoch;
self.next_epoch += 1;
self.epochs.insert(key, epoch);
epoch
}
/// Return the latest write epoch for `key`, if any write has been recorded.
pub fn check_epoch(&self, key: BlockKey) -> Option<u64> {
self.epochs.get(&key).copied()
}
/// Returns `true` if the block identified by `key` has been written
/// after `read_epoch`, meaning the reader's snapshot is stale.
///
/// Returns `false` if no write has been recorded for `key` (the key
/// does not exist in the tracker).
pub fn is_stale(&self, key: BlockKey, read_epoch: u64) -> bool {
match self.epochs.get(&key) {
Some(&write_epoch) => write_epoch > read_epoch,
None => false,
}
}
}
impl Default for EpochTracker {
fn default() -> Self {
Self::new()
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::store::{BlockKey, Tier, TieredStore};
fn make_key(tid: u128, idx: u32) -> BlockKey {
BlockKey {
tensor_id: tid,
block_index: idx,
}
}
// -- CoherenceCheck -----------------------------------------------------
#[test]
fn test_coherence_check_default_bounds() {
let cc = CoherenceCheck::default();
assert_eq!(cc.max_relative_errors[0], f32::MAX);
assert!((cc.max_relative_errors[1] - 0.01).abs() < 1e-9);
assert!((cc.max_relative_errors[2] - 0.02).abs() < 1e-9);
assert!((cc.max_relative_errors[3] - 0.35).abs() < 1e-9);
}
#[test]
fn test_coherence_check_custom_bounds() {
let bounds = [0.0, 0.05, 0.10, 0.50];
let cc = CoherenceCheck::new(bounds);
assert_eq!(cc.max_relative_errors, bounds);
}
#[test]
fn test_check_coherence_tier1_passes() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.25).collect();
store.put(key, &data, Tier::Tier1, 0).unwrap();
let cc = CoherenceCheck::default();
let result = cc.check_coherence(&mut store, key, &data, 1).unwrap();
assert_eq!(result.tier, Tier::Tier1);
assert!(
result.passed,
"Tier1 coherence should pass; max_error={}, bound={}",
result.max_error, cc.max_relative_errors[1],
);
assert!(
result.max_error < cc.max_relative_errors[1],
"max_error {} should be < bound {}",
result.max_error,
cc.max_relative_errors[1],
);
}
#[test]
fn test_check_coherence_tier3_passes() {
let mut store = TieredStore::new(4096);
let key = make_key(2, 0);
// Use values with large magnitude to keep relative error low under
// 3-bit quantization (only 7 levels). Avoid near-zero values where
// even small absolute error produces large relative error.
let data: Vec<f32> = (0..32).map(|i| 10.0 + (i as f32) * 0.1).collect();
store.put(key, &data, Tier::Tier3, 0).unwrap();
let cc = CoherenceCheck::default();
let result = cc.check_coherence(&mut store, key, &data, 1).unwrap();
assert_eq!(result.tier, Tier::Tier3);
assert!(
result.passed,
"Tier3 coherence should pass with default 0.35 bound; max_error={}",
result.max_error,
);
}
#[test]
fn test_check_coherence_missing_block() {
let mut store = TieredStore::new(4096);
let key = make_key(99, 0);
let data = vec![1.0f32; 8];
let cc = CoherenceCheck::default();
let err = cc.check_coherence(&mut store, key, &data, 0);
assert_eq!(err, Err(StoreError::BlockNotFound));
}
#[test]
fn test_check_coherence_evicted_block() {
use crate::store::ReconstructPolicy;
let mut store = TieredStore::new(4096);
let key = make_key(3, 0);
let data = vec![1.0f32; 16];
store.put(key, &data, Tier::Tier1, 0).unwrap();
store.evict(key, ReconstructPolicy::None).unwrap();
let cc = CoherenceCheck::default();
let err = cc.check_coherence(&mut store, key, &data, 1);
assert_eq!(err, Err(StoreError::TensorEvicted));
}
#[test]
fn test_check_coherence_tight_bound_fails() {
let mut store = TieredStore::new(4096);
let key = make_key(4, 0);
// Data with large dynamic range to maximize quantization error.
let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 10.0).collect();
// Store at Tier3 (3-bit) for maximum quantization error.
store.put(key, &data, Tier::Tier3, 0).unwrap();
// Use an extremely tight bound that 3-bit quantization cannot meet.
let cc = CoherenceCheck::new([f32::MAX, 0.001, 0.001, 0.001]);
let result = cc.check_coherence(&mut store, key, &data, 1).unwrap();
assert_eq!(result.tier, Tier::Tier3);
assert!(
!result.passed,
"Tier3 with 0.001 bound should fail; max_error={}",
result.max_error,
);
}
// -- verify_put ---------------------------------------------------------
#[test]
fn test_verify_put_tier1() {
let mut store = TieredStore::new(4096);
let key = make_key(10, 0);
let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.1).collect();
let cc = CoherenceCheck::default();
let result = cc
.verify_put(&mut store, key, &data, Tier::Tier1, 0)
.unwrap();
assert_eq!(result.tier, Tier::Tier1);
assert!(result.passed, "verify_put Tier1 should pass");
assert_eq!(store.block_count(), 1);
}
#[test]
fn test_verify_put_tier0_rejected() {
let mut store = TieredStore::new(4096);
let key = make_key(11, 0);
let data = vec![1.0f32; 16];
let cc = CoherenceCheck::default();
let err = cc.verify_put(&mut store, key, &data, Tier::Tier0, 0);
assert_eq!(err, Err(StoreError::InvalidBlock));
}
#[test]
fn test_verify_put_tier2() {
let mut store = TieredStore::new(4096);
let key = make_key(12, 0);
let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.3).collect();
let cc = CoherenceCheck::default();
let result = cc
.verify_put(&mut store, key, &data, Tier::Tier2, 0)
.unwrap();
assert_eq!(result.tier, Tier::Tier2);
assert!(
result.passed,
"verify_put Tier2 should pass; max_error={}",
result.max_error
);
}
// -- compute_max_relative_error -----------------------------------------
#[test]
fn test_relative_error_identical() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0, 3.0];
assert_eq!(compute_max_relative_error(&a, &b), 0.0);
}
#[test]
fn test_relative_error_known() {
let original = vec![10.0, 20.0, 50.0];
let decoded = vec![10.5, 20.0, 48.0];
let err = compute_max_relative_error(&original, &decoded);
// Element 0: |0.5| / 10.0 = 0.05
// Element 1: 0.0
// Element 2: |2.0| / 50.0 = 0.04
assert!((err - 0.05).abs() < 1e-6, "expected 0.05, got {err}");
}
#[test]
fn test_relative_error_near_zero() {
// Near-zero original values should use absolute error.
let original = vec![0.0, 1e-8, 1.0];
let decoded = vec![0.001, 0.0, 1.0];
let err = compute_max_relative_error(&original, &decoded);
// Element 0: |0.001| (absolute, since orig < epsilon)
// Element 1: |1e-8| (absolute, since orig < epsilon)
// Element 2: 0.0
assert!((err - 0.001).abs() < 1e-6, "expected ~0.001, got {err}");
}
#[test]
fn test_relative_error_empty() {
assert_eq!(compute_max_relative_error(&[], &[]), 0.0);
}
#[test]
fn test_relative_error_mismatched_lengths() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0];
// Should only compare up to min(len(a), len(b)) = 2 elements.
let err = compute_max_relative_error(&a, &b);
assert_eq!(err, 0.0);
}
// -- EpochTracker -------------------------------------------------------
#[test]
fn test_epoch_tracker_new() {
let tracker = EpochTracker::new();
let key = make_key(1, 0);
assert_eq!(tracker.check_epoch(key), None);
assert!(!tracker.is_stale(key, 0));
}
#[test]
fn test_epoch_tracker_record_write() {
let mut tracker = EpochTracker::new();
let key = make_key(1, 0);
let e1 = tracker.record_write(key);
assert_eq!(e1, 1);
assert_eq!(tracker.check_epoch(key), Some(1));
let e2 = tracker.record_write(key);
assert_eq!(e2, 2);
assert_eq!(tracker.check_epoch(key), Some(2));
}
#[test]
fn test_epoch_tracker_monotonic_across_keys() {
let mut tracker = EpochTracker::new();
let key_a = make_key(1, 0);
let key_b = make_key(2, 0);
let e1 = tracker.record_write(key_a);
let e2 = tracker.record_write(key_b);
let e3 = tracker.record_write(key_a);
assert_eq!(e1, 1);
assert_eq!(e2, 2);
assert_eq!(e3, 3);
assert_eq!(tracker.check_epoch(key_a), Some(3));
assert_eq!(tracker.check_epoch(key_b), Some(2));
}
#[test]
fn test_epoch_tracker_is_stale() {
let mut tracker = EpochTracker::new();
let key = make_key(1, 0);
let epoch = tracker.record_write(key);
assert!(
!tracker.is_stale(key, epoch),
"same epoch should not be stale"
);
assert!(
!tracker.is_stale(key, epoch + 1),
"future epoch should not be stale"
);
// Write again -> epoch advances.
let _e2 = tracker.record_write(key);
assert!(
tracker.is_stale(key, epoch),
"old epoch should now be stale after a new write"
);
}
#[test]
fn test_epoch_tracker_unknown_key_not_stale() {
let tracker = EpochTracker::new();
let key = make_key(99, 0);
assert!(!tracker.is_stale(key, 0));
assert!(!tracker.is_stale(key, u64::MAX));
}
#[test]
fn test_epoch_tracker_multiple_keys_independent() {
let mut tracker = EpochTracker::new();
let key_a = make_key(1, 0);
let key_b = make_key(2, 0);
let ea = tracker.record_write(key_a);
let _eb = tracker.record_write(key_b);
// Writing key_b should not make key_a stale at its own epoch.
assert!(!tracker.is_stale(key_a, ea));
}
#[test]
fn test_epoch_tracker_default_trait() {
let tracker = EpochTracker::default();
assert_eq!(tracker.check_epoch(make_key(1, 0)), None);
}
}


@@ -0,0 +1,342 @@
//! TemporalTensorCompressor: the main entry point.
//!
//! Manages temporal segments, drift detection, and tier transitions.
//! Caches f32-converted scales to avoid repeated f16 conversion in hot paths.
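//!
//! # Example
//!
//! A minimal sketch of the push/flush cycle (marked `ignore`; assumes
//! in-crate paths):
//!
//! ```ignore
//! let mut comp = TemporalTensorCompressor::new(TierPolicy::default(), 64, 0);
//! let frame = vec![1.0f32; 64];
//! let mut seg = Vec::new();
//! comp.push_frame(&frame, 0, &mut seg); // buffered in the active segment
//! assert!(seg.is_empty());              // no boundary crossed yet
//! comp.flush(&mut seg);                 // force-complete the segment
//! assert!(!seg.is_empty());             // encoded segment bytes
//! ```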
use crate::quantizer;
use crate::segment;
use crate::tier_policy::TierPolicy;
pub struct TemporalTensorCompressor {
policy: TierPolicy,
len: u32,
access_count: u32,
last_access_ts: u32,
active_bits: u8,
active_group_len: usize,
active_scales_f16: Vec<u16>,
active_scales_f32: Vec<f32>, // Cached f32 conversion of scales
active_frames: u32,
active_data: Vec<u8>,
}
impl TemporalTensorCompressor {
/// Create a new compressor for tensors of the given length.
pub fn new(policy: TierPolicy, len: u32, now_ts: u32) -> Self {
let bits = policy.select_bits(0, now_ts, now_ts);
Self {
policy,
len,
access_count: 0,
last_access_ts: now_ts,
active_bits: bits,
active_group_len: policy.group_len.max(1) as usize,
active_scales_f16: Vec::new(),
active_scales_f32: Vec::new(),
active_frames: 0,
active_data: Vec::new(),
}
}
/// Record an access (increments count, updates timestamp).
pub fn touch(&mut self, now_ts: u32) {
self.access_count = self.access_count.wrapping_add(1);
self.last_access_ts = now_ts;
}
/// Set access stats directly (for restoring state).
pub fn set_access(&mut self, access_count: u32, last_access_ts: u32) {
self.access_count = access_count;
self.last_access_ts = last_access_ts;
}
/// Current tier bits.
pub fn active_bits(&self) -> u8 {
self.active_bits
}
/// Number of frames in the current segment.
pub fn active_frame_count(&self) -> u32 {
self.active_frames
}
/// Current policy.
pub fn policy(&self) -> &TierPolicy {
&self.policy
}
/// Tensor length.
pub fn len(&self) -> u32 {
self.len
}
/// Returns `true` if the tensor length is zero.
pub fn is_empty(&self) -> bool {
self.len == 0
}
/// Bytes currently buffered in the active segment data.
pub fn active_data_bytes(&self) -> usize {
self.active_data.len()
}
/// Push a frame. If a segment boundary is crossed, the completed segment
/// bytes are written to `out_segment`; otherwise `out_segment` is left
/// empty. Frames whose length differs from the tensor length are
/// silently ignored.
pub fn push_frame(&mut self, frame: &[f32], now_ts: u32, out_segment: &mut Vec<u8>) {
out_segment.clear();
if frame.len() != self.len as usize {
return;
}
let desired_bits = self
.policy
.select_bits(self.access_count, self.last_access_ts, now_ts);
let drift_factor = self.policy.drift_factor();
// Use cached f32 scales for drift check (avoids f16 conversion per group)
let need_new_segment = self.active_frames == 0
|| desired_bits != self.active_bits
|| !quantizer::frame_fits_scales_f32(
frame,
&self.active_scales_f32,
self.active_group_len,
self.active_bits,
drift_factor,
);
if need_new_segment {
self.flush(out_segment);
self.active_bits = desired_bits;
self.active_group_len = self.policy.group_len.max(1) as usize;
self.active_scales_f16 =
quantizer::compute_scales(frame, self.active_group_len, self.active_bits);
self.active_scales_f32 = quantizer::scales_to_f32(&self.active_scales_f16);
}
// Use cached f32 scales for quantization (avoids f16 conversion per group)
quantizer::quantize_and_pack_f32(
frame,
&self.active_scales_f32,
self.active_group_len,
self.active_bits,
&mut self.active_data,
);
self.active_frames = self.active_frames.wrapping_add(1);
}
/// Flush the current segment. Writes segment bytes to `out_segment`.
/// Resets internal state for the next segment.
pub fn flush(&mut self, out_segment: &mut Vec<u8>) {
if self.active_frames == 0 {
return;
}
segment::encode(
self.active_bits,
self.active_group_len as u32,
self.len,
self.active_frames,
&self.active_scales_f16,
&self.active_data,
out_segment,
);
self.active_frames = 0;
self.active_scales_f16.clear();
self.active_scales_f32.clear();
self.active_data.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
fn default_policy() -> TierPolicy {
TierPolicy::default()
}
#[test]
fn test_create_and_push() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
let frame = vec![1.0f32; 64];
let mut seg = Vec::new();
comp.push_frame(&frame, 0, &mut seg);
assert!(seg.is_empty()); // First frame, no completed segment
assert_eq!(comp.active_frame_count(), 1);
}
#[test]
fn test_flush_produces_segment() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
let frame = vec![1.0f32; 64];
let mut seg = Vec::new();
comp.push_frame(&frame, 0, &mut seg);
comp.flush(&mut seg);
assert!(!seg.is_empty());
let mut decoded = Vec::new();
segment::decode(&seg, &mut decoded);
assert_eq!(decoded.len(), 64);
}
#[test]
fn test_tier_transition_flushes() {
let policy = TierPolicy {
hot_min_score: 512,
warm_min_score: 64,
warm_bits: 7,
drift_pct_q8: 26,
group_len: 64,
};
let mut comp = TemporalTensorCompressor::new(policy, 64, 0);
comp.set_access(100, 0); // Hot
let frame = vec![1.0f32; 64];
let mut seg = Vec::new();
comp.push_frame(&frame, 1, &mut seg);
assert_eq!(comp.active_bits(), 8);
// Make it cold
comp.set_access(1, 0);
comp.push_frame(&frame, 10000, &mut seg);
assert!(!seg.is_empty());
assert_eq!(comp.active_bits(), 3);
}
#[test]
fn test_drift_triggers_new_segment() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
let mut seg = Vec::new();
let frame1 = vec![1.0f32; 64];
comp.push_frame(&frame1, 0, &mut seg);
let frame2 = vec![5.0f32; 64];
comp.push_frame(&frame2, 0, &mut seg);
assert!(!seg.is_empty());
}
#[test]
fn test_multi_frame_same_segment() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
let mut seg = Vec::new();
let frame = vec![1.0f32; 64];
comp.push_frame(&frame, 0, &mut seg);
assert!(seg.is_empty());
let frame2 = vec![1.05f32; 64];
comp.push_frame(&frame2, 0, &mut seg);
assert!(seg.is_empty());
assert_eq!(comp.active_frame_count(), 2);
}
#[test]
fn test_full_roundtrip_hot() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 128, 0);
comp.set_access(100, 0);
let frame: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();
let mut seg = Vec::new();
for _ in 0..10 {
comp.push_frame(&frame, 1, &mut seg);
}
comp.flush(&mut seg);
let mut decoded = Vec::new();
segment::decode(&seg, &mut decoded);
assert_eq!(decoded.len(), 128 * 10);
let max_abs = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
for i in 0..128 {
let err = (decoded[i] - frame[i]).abs();
assert!(
err < max_abs * 0.02,
"i={i} orig={} dec={} err={err}",
frame[i],
decoded[i]
);
}
}
#[test]
fn test_full_roundtrip_cold() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
// Default: access_count=0, cold -> 3-bit
let frame: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
let mut seg = Vec::new();
comp.push_frame(&frame, 0, &mut seg);
comp.flush(&mut seg);
let header = segment::parse_header(&seg).unwrap();
assert_eq!(header.bits, 3);
let mut decoded = Vec::new();
segment::decode(&seg, &mut decoded);
assert_eq!(decoded.len(), 64);
let max_abs = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
let err = (orig - dec).abs();
// 3-bit: qmax=3, max relative error ~33%
assert!(err < max_abs * 0.4, "i={i} orig={orig} dec={dec} err={err}");
}
}
#[test]
fn test_wrong_length_frame_rejected() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
let frame = vec![1.0f32; 32];
let mut seg = Vec::new();
comp.push_frame(&frame, 0, &mut seg);
assert_eq!(comp.active_frame_count(), 0);
}
#[test]
fn test_accessor_methods() {
let policy = TierPolicy::default();
let comp = TemporalTensorCompressor::new(policy, 256, 42);
assert_eq!(comp.len(), 256);
assert_eq!(comp.active_frame_count(), 0);
assert_eq!(comp.active_data_bytes(), 0);
assert_eq!(comp.policy().group_len, 64);
}
#[test]
fn test_large_tensor_multi_group() {
let mut comp = TemporalTensorCompressor::new(default_policy(), 512, 0);
comp.set_access(100, 0); // hot -> 8-bit
let frame: Vec<f32> = (0..512).map(|i| ((i as f32) * 0.731).sin()).collect();
let mut seg = Vec::new();
for _ in 0..50 {
comp.push_frame(&frame, 1, &mut seg);
}
comp.flush(&mut seg);
let header = segment::parse_header(&seg).unwrap();
assert_eq!(header.bits, 8);
assert_eq!(header.tensor_len, 512);
assert_eq!(header.frame_count, 50);
assert_eq!(header.scale_count, 8); // 512/64 = 8 groups
let mut decoded = Vec::new();
segment::decode(&seg, &mut decoded);
assert_eq!(decoded.len(), 512 * 50);
// Verify compression ratio
let raw = 512 * 4 * 50;
let compressed = seg.len();
let ratio = raw as f32 / compressed as f32;
assert!(ratio > 3.5, "ratio={ratio:.2}x, expected >3.5x");
}
}


@@ -0,0 +1,531 @@
//! Abstract trait interface for tensor block storage.
//!
//! Defines [`TensorStore`] so that other crates can depend on a thin
//! abstraction rather than the concrete [`crate::store::TieredStore`].
//! An extension trait [`TensorStoreExt`] provides convenience helpers
//! via a blanket implementation for all `TensorStore` implementors.
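//!
//! # Example
//!
//! A minimal sketch of trait-based access (marked `ignore`; assumes
//! in-crate paths):
//!
//! ```ignore
//! use crate::store::{BlockKey, Tier, TieredStore};
//!
//! // Generic over any TensorStore; TensorStoreExt adds put_tier1 et al.
//! fn store_hot<S: TensorStore>(store: &mut S, key: BlockKey, data: &[f32]) {
//!     store.put_tier1(key, data, 0).unwrap();
//! }
//!
//! let mut store = TieredStore::new(4096);
//! store_hot(&mut store, BlockKey { tensor_id: 1, block_index: 0 }, &[1.0; 64]);
//! assert_eq!(store.block_count(), 1);
//! ```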
#![allow(dead_code)]
use crate::store::{BlockKey, BlockMeta, ReconstructPolicy, StoreError, Tier, TieredStore};
// ---------------------------------------------------------------------------
// TensorStore trait
// ---------------------------------------------------------------------------
/// Abstract interface for a tiered tensor block store.
///
/// All methods mirror the public API of [`TieredStore`] so that higher-level
/// crates can interact with the store without depending on the concrete type.
pub trait TensorStore {
/// Quantize `data` at the bit width for `tier` and store the block.
///
/// Replaces any existing block with the same `key`.
fn put(&mut self, key: BlockKey, data: &[f32], tier: Tier, now: u64) -> Result<(), StoreError>;
/// Dequantize the block identified by `key` into `out`.
///
/// Returns the number of f32 elements written.
fn get(&mut self, key: BlockKey, out: &mut [f32], now: u64) -> Result<usize, StoreError>;
/// Update access statistics for `key` at tick `now`.
fn touch(&mut self, key: BlockKey, now: u64);
/// Evict a block to Tier0, preserving metadata with the given policy.
fn evict(&mut self, key: BlockKey, policy: ReconstructPolicy) -> Result<(), StoreError>;
/// Return a reference to the metadata for `key`, if it exists.
fn meta(&self, key: BlockKey) -> Option<&BlockMeta>;
/// Total number of blocks tracked (including Tier0 evicted blocks).
fn block_count(&self) -> usize;
/// Number of blocks currently in the given tier.
fn tier_count(&self, tier: Tier) -> usize;
/// Total bytes of quantized data stored across all active tiers.
fn total_bytes(&self) -> usize;
/// Whether a block with the given key exists in the store.
fn contains(&self, key: BlockKey) -> bool;
/// Capture a read-only snapshot of the store's current state.
fn snapshot(&self) -> TensorStoreSnapshot;
}
// ---------------------------------------------------------------------------
// TensorStore impl for TieredStore
// ---------------------------------------------------------------------------
impl TensorStore for TieredStore {
fn put(&mut self, key: BlockKey, data: &[f32], tier: Tier, now: u64) -> Result<(), StoreError> {
TieredStore::put(self, key, data, tier, now)
}
fn get(&mut self, key: BlockKey, out: &mut [f32], now: u64) -> Result<usize, StoreError> {
TieredStore::get(self, key, out, now)
}
fn touch(&mut self, key: BlockKey, now: u64) {
TieredStore::touch(self, key, now);
}
fn evict(&mut self, key: BlockKey, policy: ReconstructPolicy) -> Result<(), StoreError> {
TieredStore::evict(self, key, policy)
}
fn meta(&self, key: BlockKey) -> Option<&BlockMeta> {
TieredStore::meta(self, key)
}
fn block_count(&self) -> usize {
TieredStore::block_count(self)
}
fn tier_count(&self, tier: Tier) -> usize {
TieredStore::tier_count(self, tier)
}
fn total_bytes(&self) -> usize {
TieredStore::total_bytes(self)
}
fn contains(&self, key: BlockKey) -> bool {
TieredStore::meta(self, key).is_some()
}
fn snapshot(&self) -> TensorStoreSnapshot {
let tier_counts = [
TieredStore::tier_count(self, Tier::Tier0),
TieredStore::tier_count(self, Tier::Tier1),
TieredStore::tier_count(self, Tier::Tier2),
TieredStore::tier_count(self, Tier::Tier3),
];
// Compute per-tier byte totals from the store metrics.
let metrics = TieredStore::metrics(self);
let tier_bytes = [
0, // Tier0 holds no payload data
metrics.tier1_bytes as usize,
metrics.tier2_bytes as usize,
metrics.tier3_bytes as usize,
];
TensorStoreSnapshot {
block_count: TieredStore::block_count(self),
tier_counts,
total_bytes: TieredStore::total_bytes(self),
tier_bytes,
}
}
}
// ---------------------------------------------------------------------------
// TensorStoreSnapshot
// ---------------------------------------------------------------------------
/// Read-only snapshot of the store's current state.
///
/// Captures block counts, byte totals, and per-tier breakdowns at a single
/// point in time. Useful for monitoring, dashboards, and tiering decisions
/// that need a consistent view without holding a borrow on the store.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TensorStoreSnapshot {
/// Total number of blocks tracked (including evicted Tier0 blocks).
pub block_count: usize,
/// Number of blocks in each tier, indexed as `[Tier0, Tier1, Tier2, Tier3]`.
pub tier_counts: [usize; 4],
/// Total bytes of quantized data across all active tiers.
pub total_bytes: usize,
/// Bytes of quantized data per tier, indexed as `[Tier0, Tier1, Tier2, Tier3]`.
pub tier_bytes: [usize; 4],
}
impl TensorStoreSnapshot {
/// Fraction of total blocks that reside in the given tier.
///
/// Returns 0.0 if the store is empty.
pub fn tier_fraction(&self, tier: Tier) -> f64 {
if self.block_count == 0 {
return 0.0;
}
self.tier_counts[tier as usize] as f64 / self.block_count as f64
}
/// Fraction of total bytes stored in the given tier.
///
/// Returns 0.0 if the store holds no data.
pub fn byte_fraction(&self, tier: Tier) -> f64 {
if self.total_bytes == 0 {
return 0.0;
}
self.tier_bytes[tier as usize] as f64 / self.total_bytes as f64
}
}
// ---------------------------------------------------------------------------
// TensorStoreExt extension trait
// ---------------------------------------------------------------------------
/// Convenience methods available on every [`TensorStore`] implementor.
pub trait TensorStoreExt: TensorStore {
/// Allocate a `Vec<f32>` of length `len` and read the block into it.
///
/// This is a convenience wrapper around [`TensorStore::get`] for callers
/// that do not want to manage the output buffer themselves.
fn get_vec(&mut self, key: BlockKey, len: usize, now: u64) -> Result<Vec<f32>, StoreError>;
/// Store a block in Tier1 (hot, 8-bit quantization).
///
/// Shorthand for `put(key, data, Tier::Tier1, now)`.
fn put_tier1(&mut self, key: BlockKey, data: &[f32], now: u64) -> Result<(), StoreError>;
/// Check whether a block has been evicted to Tier0.
///
/// Returns `false` if the block does not exist.
fn is_evicted(&self, key: BlockKey) -> bool;
}
/// Blanket implementation of [`TensorStoreExt`] for all `TensorStore` types.
impl<T: TensorStore> TensorStoreExt for T {
fn get_vec(&mut self, key: BlockKey, len: usize, now: u64) -> Result<Vec<f32>, StoreError> {
let mut buf = vec![0.0f32; len];
let n = self.get(key, &mut buf, now)?;
buf.truncate(n);
Ok(buf)
}
fn put_tier1(&mut self, key: BlockKey, data: &[f32], now: u64) -> Result<(), StoreError> {
self.put(key, data, Tier::Tier1, now)
}
fn is_evicted(&self, key: BlockKey) -> bool {
self.meta(key)
.map(|m| m.tier == Tier::Tier0)
.unwrap_or(false)
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::store::{BlockKey, Tier, TieredStore};
fn make_key(tid: u128, idx: u32) -> BlockKey {
BlockKey {
tensor_id: tid,
block_index: idx,
}
}
// -- TensorStore trait delegation ----------------------------------------
#[test]
fn test_trait_put_get_roundtrip() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.25).collect();
// Use trait method
TensorStore::put(&mut store, key, &data, Tier::Tier1, 0).unwrap();
assert_eq!(TensorStore::block_count(&store), 1);
assert!(TensorStore::contains(&store, key));
let mut out = vec![0.0f32; 64];
let n = TensorStore::get(&mut store, key, &mut out, 1).unwrap();
assert_eq!(n, 64);
for (i, (&orig, &dec)) in data.iter().zip(out.iter()).enumerate() {
let err = (orig - dec).abs();
let tol = if orig.abs() > 0.01 {
orig.abs() * 0.02
} else {
0.15
};
assert!(err < tol, "i={i} orig={orig} dec={dec} err={err}");
}
}
#[test]
fn test_trait_touch_updates_access() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
TensorStore::put(&mut store, key, &[1.0; 16], Tier::Tier1, 0).unwrap();
let meta = TensorStore::meta(&store, key).unwrap();
assert_eq!(meta.access_count, 1);
TensorStore::touch(&mut store, key, 10);
let meta = TensorStore::meta(&store, key).unwrap();
assert_eq!(meta.access_count, 2);
assert_eq!(meta.last_access_at, 10);
}
#[test]
fn test_trait_evict() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
TensorStore::put(&mut store, key, &[1.0; 32], Tier::Tier1, 0).unwrap();
assert_eq!(TensorStore::tier_count(&store, Tier::Tier1), 1);
TensorStore::evict(&mut store, key, ReconstructPolicy::Delta).unwrap();
let meta = TensorStore::meta(&store, key).unwrap();
assert_eq!(meta.tier, Tier::Tier0);
assert_eq!(meta.reconstruct, ReconstructPolicy::Delta);
assert_eq!(TensorStore::tier_count(&store, Tier::Tier0), 1);
assert_eq!(TensorStore::tier_count(&store, Tier::Tier1), 0);
}
#[test]
fn test_trait_contains_false_for_missing() {
let store = TieredStore::new(4096);
assert!(!TensorStore::contains(&store, make_key(99, 0)));
}
#[test]
fn test_trait_total_bytes() {
let mut store = TieredStore::new(4096);
assert_eq!(TensorStore::total_bytes(&store), 0);
TensorStore::put(&mut store, make_key(1, 0), &[1.0; 64], Tier::Tier1, 0).unwrap();
assert!(TensorStore::total_bytes(&store) > 0);
}
// -- TensorStoreSnapshot -------------------------------------------------
#[test]
fn test_snapshot_empty_store() {
let store = TieredStore::new(4096);
let snap = TensorStore::snapshot(&store);
assert_eq!(snap.block_count, 0);
assert_eq!(snap.tier_counts, [0, 0, 0, 0]);
assert_eq!(snap.total_bytes, 0);
assert_eq!(snap.tier_bytes, [0, 0, 0, 0]);
}
#[test]
fn test_snapshot_populated_store() {
let mut store = TieredStore::new(4096);
let data = vec![1.0f32; 32];
TensorStore::put(&mut store, make_key(1, 0), &data, Tier::Tier1, 0).unwrap();
TensorStore::put(&mut store, make_key(2, 0), &data, Tier::Tier1, 0).unwrap();
TensorStore::put(&mut store, make_key(3, 0), &data, Tier::Tier2, 0).unwrap();
TensorStore::put(&mut store, make_key(4, 0), &data, Tier::Tier3, 0).unwrap();
let snap = TensorStore::snapshot(&store);
assert_eq!(snap.block_count, 4);
assert_eq!(snap.tier_counts[0], 0); // Tier0
assert_eq!(snap.tier_counts[1], 2); // Tier1
assert_eq!(snap.tier_counts[2], 1); // Tier2
assert_eq!(snap.tier_counts[3], 1); // Tier3
assert!(snap.total_bytes > 0);
assert!(snap.tier_bytes[1] > 0); // Tier1 bytes
assert!(snap.tier_bytes[2] > 0); // Tier2 bytes
assert!(snap.tier_bytes[3] > 0); // Tier3 bytes
assert_eq!(snap.tier_bytes[0], 0); // Tier0 holds no data
}
#[test]
fn test_snapshot_tier_fraction() {
let mut store = TieredStore::new(4096);
let data = vec![1.0f32; 16];
TensorStore::put(&mut store, make_key(1, 0), &data, Tier::Tier1, 0).unwrap();
TensorStore::put(&mut store, make_key(2, 0), &data, Tier::Tier1, 0).unwrap();
TensorStore::put(&mut store, make_key(3, 0), &data, Tier::Tier2, 0).unwrap();
TensorStore::put(&mut store, make_key(4, 0), &data, Tier::Tier3, 0).unwrap();
let snap = TensorStore::snapshot(&store);
assert!((snap.tier_fraction(Tier::Tier1) - 0.5).abs() < 1e-10);
assert!((snap.tier_fraction(Tier::Tier2) - 0.25).abs() < 1e-10);
assert!((snap.tier_fraction(Tier::Tier3) - 0.25).abs() < 1e-10);
assert!((snap.tier_fraction(Tier::Tier0) - 0.0).abs() < 1e-10);
}
#[test]
fn test_snapshot_tier_fraction_empty() {
let snap = TensorStoreSnapshot {
block_count: 0,
tier_counts: [0; 4],
total_bytes: 0,
tier_bytes: [0; 4],
};
assert_eq!(snap.tier_fraction(Tier::Tier1), 0.0);
}
#[test]
fn test_snapshot_byte_fraction_empty() {
let snap = TensorStoreSnapshot {
block_count: 0,
tier_counts: [0; 4],
total_bytes: 0,
tier_bytes: [0; 4],
};
assert_eq!(snap.byte_fraction(Tier::Tier1), 0.0);
}
#[test]
fn test_snapshot_after_eviction() {
let mut store = TieredStore::new(4096);
let data = vec![1.0f32; 32];
TensorStore::put(&mut store, make_key(1, 0), &data, Tier::Tier1, 0).unwrap();
TensorStore::put(&mut store, make_key(2, 0), &data, Tier::Tier2, 0).unwrap();
TensorStore::evict(&mut store, make_key(1, 0), ReconstructPolicy::None).unwrap();
let snap = TensorStore::snapshot(&store);
assert_eq!(snap.block_count, 2); // metadata preserved
assert_eq!(snap.tier_counts[0], 1); // one evicted
assert_eq!(snap.tier_counts[1], 0); // tier1 now empty
assert_eq!(snap.tier_counts[2], 1); // tier2 still has one
assert_eq!(snap.tier_bytes[0], 0); // evicted holds no data
assert_eq!(snap.tier_bytes[1], 0); // tier1 bytes gone
assert!(snap.tier_bytes[2] > 0); // tier2 bytes remain
}
// -- TensorStoreExt convenience methods ----------------------------------
#[test]
fn test_ext_get_vec() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
let data: Vec<f32> = (0..32).map(|i| i as f32 * 0.5).collect();
TensorStore::put(&mut store, key, &data, Tier::Tier1, 0).unwrap();
let result = TensorStoreExt::get_vec(&mut store, key, 32, 1).unwrap();
assert_eq!(result.len(), 32);
for (i, (&orig, &dec)) in data.iter().zip(result.iter()).enumerate() {
let err = (orig - dec).abs();
let tol = if orig.abs() > 0.01 {
orig.abs() * 0.05
} else {
0.15
};
assert!(err < tol, "i={i} orig={orig} dec={dec} err={err}");
}
}
#[test]
fn test_ext_get_vec_truncates_to_actual() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
TensorStore::put(&mut store, key, &[1.0; 16], Tier::Tier1, 0).unwrap();
// Request a larger buffer than the block contains; vec should be truncated.
let result = TensorStoreExt::get_vec(&mut store, key, 64, 1).unwrap();
assert_eq!(result.len(), 16);
}
#[test]
fn test_ext_get_vec_not_found() {
let mut store = TieredStore::new(4096);
let result = TensorStoreExt::get_vec(&mut store, make_key(99, 0), 16, 0);
assert_eq!(result, Err(StoreError::BlockNotFound));
}
#[test]
fn test_ext_put_tier1() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
let data = vec![2.0f32; 16];
TensorStoreExt::put_tier1(&mut store, key, &data, 0).unwrap();
let meta = TensorStore::meta(&store, key).unwrap();
assert_eq!(meta.tier, Tier::Tier1);
assert_eq!(meta.bits, 8);
}
#[test]
fn test_ext_is_evicted_false_when_active() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
TensorStore::put(&mut store, key, &[1.0; 8], Tier::Tier1, 0).unwrap();
assert!(!TensorStoreExt::is_evicted(&store, key));
}
#[test]
fn test_ext_is_evicted_true_after_evict() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
TensorStore::put(&mut store, key, &[1.0; 8], Tier::Tier1, 0).unwrap();
TensorStore::evict(&mut store, key, ReconstructPolicy::None).unwrap();
assert!(TensorStoreExt::is_evicted(&store, key));
}
#[test]
fn test_ext_is_evicted_false_when_missing() {
let store = TieredStore::new(4096);
assert!(!TensorStoreExt::is_evicted(&store, make_key(99, 0)));
}
// -- Trait object safety check -------------------------------------------
#[test]
fn test_trait_object_usable() {
let mut store = TieredStore::new(4096);
let key = make_key(1, 0);
// Ensure TensorStore is usable as a trait object. Every method is
// object-safe (including meta(), which returns a borrow), so a simple
// counting method is enough to demonstrate dyn dispatch here.
fn use_store(s: &mut dyn TensorStore) -> usize {
s.block_count()
}
TensorStore::put(&mut store, key, &[1.0; 8], Tier::Tier1, 0).unwrap();
assert_eq!(use_store(&mut store), 1);
}
// -- Integration: mixed trait + ext usage --------------------------------
#[test]
fn test_integration_mixed_usage() {
let mut store = TieredStore::new(4096);
let k1 = make_key(1, 0);
let k2 = make_key(2, 0);
let k3 = make_key(3, 0);
// Insert via ext shorthand and trait method.
TensorStoreExt::put_tier1(&mut store, k1, &[1.0; 32], 0).unwrap();
TensorStore::put(&mut store, k2, &[2.0; 32], Tier::Tier2, 0).unwrap();
TensorStore::put(&mut store, k3, &[3.0; 32], Tier::Tier3, 0).unwrap();
assert_eq!(TensorStore::block_count(&store), 3);
assert!(TensorStore::contains(&store, k1));
assert!(TensorStore::contains(&store, k2));
assert!(TensorStore::contains(&store, k3));
// Evict k3 and verify via ext method.
TensorStore::evict(&mut store, k3, ReconstructPolicy::Delta).unwrap();
assert!(TensorStoreExt::is_evicted(&store, k3));
assert!(!TensorStoreExt::is_evicted(&store, k1));
// Read back via ext.
let v1 = TensorStoreExt::get_vec(&mut store, k1, 32, 10).unwrap();
assert_eq!(v1.len(), 32);
// Snapshot should reflect the current state.
let snap = TensorStore::snapshot(&store);
assert_eq!(snap.block_count, 3);
assert_eq!(snap.tier_counts[0], 1); // k3 evicted
assert_eq!(snap.tier_counts[1], 1); // k1
assert_eq!(snap.tier_counts[2], 1); // k2
assert_eq!(snap.tier_counts[3], 0); // k3 was here but evicted
}
}

View File

@@ -0,0 +1,824 @@
//! Delta compression, delta chains, and reconstruction policies (ADR-021).
//!
//! Sparse delta encoding for incremental tensor updates, bounded-depth delta
//! chain management with caller-driven compaction, and SVD-based low-rank factor
//! reconstruction. All structures are WASM-safe (no `f64` in hot paths).
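//!
//! # Example
//!
//! A minimal delta roundtrip (crate and module paths as declared in `lib.rs`):
//!
//! ```
//! use ruvector_temporal_tensor::delta::{apply_delta, compute_delta};
//!
//! let old = vec![1.0f32, 2.0, 3.0, 4.0];
//! let mut new = old.clone();
//! new[2] = 3.5;
//!
//! // Only index 2 changed by at least the 0.01 threshold; the changed
//! // fraction (0.25) stays below the 0.5 cutoff, so a delta is produced.
//! let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
//! assert_eq!(d.entries.len(), 1);
//!
//! // Applying the delta to the old data recovers the new data.
//! let mut recon = old.clone();
//! apply_delta(&mut recon, &d);
//! assert!((recon[2] - 3.5).abs() < 0.01);
//! ```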
use crate::store::StoreError;
#[allow(unused_imports)]
use crate::store::{BlockKey, ReconstructPolicy};
/// Size of the fixed portion of a serialized delta (header + scale).
const DELTA_HEADER_BYTES: usize = 34;
/// Size of a single serialized sparse entry (index: u16 + value: i16).
const DELTA_ENTRY_BYTES: usize = 4;
/// Maximum power-iteration steps per singular component.
const POWER_ITER_MAX: usize = 30;
/// Convergence threshold for power iteration.
const POWER_ITER_EPS: f32 = 1e-10;
/// Header for a delta record.
#[derive(Clone, Debug)]
pub struct DeltaHeader {
pub tensor_id: u128,
pub block_index: u32,
pub base_epoch: u64,
pub nnz: u16,
}
/// A single sparse delta entry: index + quantized value.
#[derive(Clone, Copy, Debug)]
pub struct SparseEntry {
pub index: u16,
pub value: i16,
}
/// Complete delta record: header + sparse entries + scale.
///
/// Actual diff = `entry.value as f32 * delta_scale`.
#[derive(Clone, Debug)]
pub struct DeltaRecord {
pub header: DeltaHeader,
pub delta_scale: f32,
pub entries: Vec<SparseEntry>,
}
/// Compute a sparse delta between `old` and `new` data.
///
/// Keeps entries whose absolute change meets or exceeds `threshold`. Returns
/// `None` if the changed fraction meets or exceeds `max_change_fraction`.
/// Entry indices are `u16`, so blocks longer than 65 536 elements are not
/// supported.
///
/// # Panics
///
/// Panics if `old.len() != new.len()`.
pub fn compute_delta(
old: &[f32],
new: &[f32],
tensor_id: u128,
block_index: u32,
base_epoch: u64,
threshold: f32,
max_change_fraction: f32,
) -> Option<DeltaRecord> {
assert_eq!(old.len(), new.len(), "old and new must have equal length");
let n = old.len();
if n == 0 {
return Some(DeltaRecord {
header: DeltaHeader {
tensor_id,
block_index,
base_epoch,
nnz: 0,
},
delta_scale: 0.0,
entries: Vec::new(),
});
}
let mut changed: Vec<(u16, f32)> = Vec::new();
let mut max_abs = 0.0f32;
for i in 0..n {
let diff = new[i] - old[i];
if diff.abs() >= threshold {
changed.push((i as u16, diff));
if diff.abs() > max_abs {
max_abs = diff.abs();
}
}
}
if changed.len() as f32 / n as f32 >= max_change_fraction {
return None;
}
let delta_scale = if max_abs == 0.0 {
1.0
} else {
max_abs / i16::MAX as f32
};
let inv_scale = 1.0 / delta_scale;
let entries: Vec<SparseEntry> = changed
.iter()
.map(|&(idx, diff)| {
let q = (diff * inv_scale).round() as i32;
SparseEntry {
index: idx,
value: q.clamp(i16::MIN as i32, i16::MAX as i32) as i16,
}
})
.collect();
Some(DeltaRecord {
header: DeltaHeader {
tensor_id,
block_index,
base_epoch,
nnz: entries.len() as u16,
},
delta_scale,
entries,
})
}
/// Apply a delta to a base data vector in-place.
///
/// Entries whose indices exceed the base length are silently skipped.
pub fn apply_delta(base: &mut [f32], delta: &DeltaRecord) {
let scale = delta.delta_scale;
for entry in &delta.entries {
let idx = entry.index as usize;
if idx < base.len() {
base[idx] += entry.value as f32 * scale;
}
}
}
/// A chain of deltas applied to a base block.
/// Invariant: `deltas.len() <= max_chain_len`.
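///
/// A minimal usage sketch:
///
/// ```
/// use ruvector_temporal_tensor::delta::{compute_delta, DeltaChain};
///
/// let base = vec![0.0f32; 8];
/// let mut chain = DeltaChain::new(base.clone(), 4);
///
/// let mut v1 = base.clone();
/// v1[3] = 1.0;
/// chain.append(compute_delta(&base, &v1, 1, 0, 0, 0.01, 0.9).unwrap()).unwrap();
///
/// let recon = chain.reconstruct();
/// assert!((recon[3] - 1.0).abs() < 0.01);
/// assert_eq!(chain.chain_len(), 1);
/// ```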
#[derive(Clone, Debug)]
pub struct DeltaChain {
base_data: Vec<f32>,
deltas: Vec<DeltaRecord>,
max_chain_len: u8,
}
impl DeltaChain {
/// Create a new chain with a base block.
pub fn new(base_data: Vec<f32>, max_chain_len: u8) -> Self {
Self {
base_data,
deltas: Vec::new(),
max_chain_len,
}
}
/// Append a delta. Returns `Err(StoreError::DeltaChainTooLong)` at max length.
pub fn append(&mut self, delta: DeltaRecord) -> Result<(), StoreError> {
if self.deltas.len() >= self.max_chain_len as usize {
return Err(StoreError::DeltaChainTooLong);
}
self.deltas.push(delta);
Ok(())
}
/// Reconstruct the current state by applying all deltas to the base.
pub fn reconstruct(&self) -> Vec<f32> {
let mut result = self.base_data.clone();
for delta in &self.deltas {
apply_delta(&mut result, delta);
}
result
}
/// Compact the chain: apply all deltas to base, clear delta list.
pub fn compact(&mut self) {
if self.deltas.is_empty() {
return;
}
for delta in &self.deltas {
apply_delta(&mut self.base_data, delta);
}
self.deltas.clear();
}
/// Number of deltas in the chain.
#[inline]
pub fn chain_len(&self) -> usize {
self.deltas.len()
}
/// Whether the chain needs compaction (at max length).
#[inline]
pub fn needs_compaction(&self) -> bool {
self.deltas.len() >= self.max_chain_len as usize
}
/// Total storage bytes: base + serialized size of all deltas.
pub fn total_bytes(&self) -> usize {
let base_bytes = self.base_data.len() * 4;
let delta_bytes: usize = self
.deltas
.iter()
.map(|d| DELTA_HEADER_BYTES + d.entries.len() * DELTA_ENTRY_BYTES)
.sum();
base_bytes + delta_bytes
}
}
/// Low-rank factor representation for reconstruction.
///
/// Stores U (m x k), S (k), V (k x n) such that data ~ U * diag(S) * V.
/// All matrices are row-major.
#[derive(Clone, Debug)]
pub struct FactorSet {
pub m: usize,
pub n: usize,
pub k: usize,
pub u_data: Vec<f32>, // m * k elements
pub s_data: Vec<f32>, // k elements
pub v_data: Vec<f32>, // k * n elements
}
impl FactorSet {
/// Reconstruct the full data from factors: U * diag(S) * V.
pub fn reconstruct(&self) -> Vec<f32> {
let mut out = vec![0.0f32; self.m * self.n];
for r in 0..self.k {
let s_r = self.s_data[r];
for i in 0..self.m {
let u_s = self.u_data[i * self.k + r] * s_r;
let row = i * self.n;
let v_off = r * self.n;
for j in 0..self.n {
out[row + j] += u_s * self.v_data[v_off + j];
}
}
}
out
}
/// Compute storage size in bytes: (m*k + k + k*n) * 4.
pub fn storage_bytes(&self) -> usize {
(self.m * self.k + self.k + self.k * self.n) * 4
}
/// Create from a flat data vector using truncated SVD via power iteration.
///
/// Simplified implementation suitable for moderate-sized matrices.
/// Extracts top-`rank` singular triplets with successive deflation.
///
/// # Panics
///
/// Panics if `data.len() != rows * cols`.
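///
/// # Example
///
/// A rank-1 matrix is recovered almost exactly by rank-1 factors:
///
/// ```
/// use ruvector_temporal_tensor::delta::FactorSet;
///
/// // data[i][j] = (i + 1) * (j + 1) is exactly rank 1.
/// let (m, n) = (4, 3);
/// let data: Vec<f32> = (0..m * n)
///     .map(|idx| ((idx / n + 1) * (idx % n + 1)) as f32)
///     .collect();
/// let f = FactorSet::from_data(&data, m, n, 1);
/// assert!(f.reconstruction_error(&data) < 0.01);
/// ```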
pub fn from_data(data: &[f32], rows: usize, cols: usize, rank: usize) -> Self {
assert_eq!(
data.len(),
rows * cols,
"data length must equal rows * cols"
);
let (m, n) = (rows, cols);
let k = rank.min(m).min(n);
let mut work = data.to_vec();
let mut u_data = vec![0.0f32; m * k];
let mut s_data = vec![0.0f32; k];
let mut v_data = vec![0.0f32; k * n];
for r in 0..k {
// Deterministic initial vector: Fibonacci-hash sign pattern.
let inv_sqrt_n = 1.0 / (n as f32).sqrt();
let mut v = vec![0.0f32; n];
for j in 0..n {
let seed = (j as u32)
.wrapping_mul(2_654_435_761)
.wrapping_add((r as u32).wrapping_mul(0x9E37_79B9));
v[j] = if seed & 1 == 0 {
inv_sqrt_n
} else {
-inv_sqrt_n
};
}
let mut u = vec![0.0f32; m];
let mut sigma = 0.0f32;
for _ in 0..POWER_ITER_MAX {
// u = work * v
for i in 0..m {
let mut acc = 0.0f32;
let row = i * n;
for j in 0..n {
acc += work[row + j] * v[j];
}
u[i] = acc;
}
let su: f32 = u.iter().map(|x| x * x).sum::<f32>().sqrt();
if su < POWER_ITER_EPS {
sigma = 0.0;
break;
}
let inv = 1.0 / su;
for x in u.iter_mut() {
*x *= inv;
}
// v = work^T * u
for j in 0..n {
let mut acc = 0.0f32;
for i in 0..m {
acc += work[i * n + j] * u[i];
}
v[j] = acc;
}
let sv: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if sv < POWER_ITER_EPS {
sigma = su;
break;
}
sigma = sv;
let inv = 1.0 / sv;
for x in v.iter_mut() {
*x *= inv;
}
}
s_data[r] = sigma;
for i in 0..m {
u_data[i * k + r] = u[i];
}
for j in 0..n {
v_data[r * n + j] = v[j];
}
// Deflate: work -= sigma * u * v^T
if sigma > POWER_ITER_EPS {
for i in 0..m {
let us = u[i] * sigma;
let row = i * n;
for j in 0..n {
work[row + j] -= us * v[j];
}
}
}
}
Self {
m,
n,
k,
u_data,
s_data,
v_data,
}
}
/// Compute the relative reconstruction error (Frobenius norm).
///
/// Returns `||original - reconstructed|| / ||original||`.
/// Returns 0.0 if the original has zero norm.
pub fn reconstruction_error(&self, original: &[f32]) -> f32 {
let reconstructed = self.reconstruct();
let mut diff_sq = 0.0f32;
let mut orig_sq = 0.0f32;
for (i, &o) in original.iter().enumerate() {
let r = if i < reconstructed.len() {
reconstructed[i]
} else {
0.0
};
diff_sq += (o - r) * (o - r);
orig_sq += o * o;
}
if orig_sq < 1e-30 {
return 0.0;
}
(diff_sq / orig_sq).sqrt()
}
/// Estimate the fraction of total energy (Frobenius norm) captured by factors.
///
/// Uses `sum(s_i^2)` as captured energy. Requires the original data to compute
/// total energy as `||data||_F^2`. Returns 1.0 if total energy is near zero.
pub fn energy_captured(&self, original: &[f32]) -> f32 {
let total_energy: f32 = original.iter().map(|x| x * x).sum();
if total_energy < 1e-30 {
return 1.0;
}
let captured: f32 = self.s_data.iter().map(|s| s * s).sum();
(captured / total_energy).min(1.0)
}
/// Compression ratio: original_elements * 4 bytes / storage_bytes.
///
/// Returns 0.0 if storage_bytes is zero.
pub fn compression_ratio(&self, original_elements: usize) -> f32 {
let raw = original_elements * 4;
let stored = self.storage_bytes();
if stored == 0 {
return 0.0;
}
raw as f32 / stored as f32
}
/// Create factors with adaptive rank selection.
///
/// Starts with rank 1 and increases until either `max_rank` is reached or
/// the reconstruction error falls below `target_error`.
pub fn from_data_adaptive(
data: &[f32],
rows: usize,
cols: usize,
max_rank: usize,
target_error: f32,
) -> Self {
let max_k = max_rank.min(rows).min(cols);
let mut best = Self::from_data(data, rows, cols, 1);
for rank in 2..=max_k {
let err = best.reconstruction_error(data);
if err <= target_error {
break;
}
best = Self::from_data(data, rows, cols, rank);
}
best
}
}
/// Encode a [`DeltaRecord`] to bytes (little-endian, ADR-021 section 4.1).
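///
/// Roundtrips with [`decode_delta`]:
///
/// ```
/// use ruvector_temporal_tensor::delta::{compute_delta, decode_delta, encode_delta};
///
/// let d = compute_delta(&[0.0f32; 4], &[0.0, 0.5, 0.0, 0.0], 1, 0, 0, 0.01, 0.9).unwrap();
/// let bytes = encode_delta(&d);
/// let back = decode_delta(&bytes).unwrap();
/// assert_eq!(back.entries.len(), d.entries.len());
/// ```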
pub fn encode_delta(delta: &DeltaRecord) -> Vec<u8> {
let mut buf = Vec::with_capacity(DELTA_HEADER_BYTES + delta.entries.len() * DELTA_ENTRY_BYTES);
buf.extend_from_slice(&delta.header.tensor_id.to_le_bytes());
buf.extend_from_slice(&delta.header.block_index.to_le_bytes());
buf.extend_from_slice(&delta.header.base_epoch.to_le_bytes());
buf.extend_from_slice(&delta.header.nnz.to_le_bytes());
buf.extend_from_slice(&delta.delta_scale.to_le_bytes());
for entry in &delta.entries {
buf.extend_from_slice(&entry.index.to_le_bytes());
buf.extend_from_slice(&entry.value.to_le_bytes());
}
buf
}
/// Decode a [`DeltaRecord`] from bytes.
///
/// Returns `Err(StoreError::InvalidBlock)` on truncated or malformed input.
pub fn decode_delta(data: &[u8]) -> Result<DeltaRecord, StoreError> {
if data.len() < DELTA_HEADER_BYTES {
return Err(StoreError::InvalidBlock);
}
let tensor_id = u128::from_le_bytes(
data[0..16]
.try_into()
.map_err(|_| StoreError::InvalidBlock)?,
);
let block_index = u32::from_le_bytes(
data[16..20]
.try_into()
.map_err(|_| StoreError::InvalidBlock)?,
);
let base_epoch = u64::from_le_bytes(
data[20..28]
.try_into()
.map_err(|_| StoreError::InvalidBlock)?,
);
let nnz = u16::from_le_bytes(
data[28..30]
.try_into()
.map_err(|_| StoreError::InvalidBlock)?,
);
let delta_scale = f32::from_le_bytes(
data[30..34]
.try_into()
.map_err(|_| StoreError::InvalidBlock)?,
);
if data.len() < DELTA_HEADER_BYTES + (nnz as usize) * DELTA_ENTRY_BYTES {
return Err(StoreError::InvalidBlock);
}
let mut entries = Vec::with_capacity(nnz as usize);
let mut off = DELTA_HEADER_BYTES;
for _ in 0..nnz {
let index = u16::from_le_bytes(
data[off..off + 2]
.try_into()
.map_err(|_| StoreError::InvalidBlock)?,
);
let value = i16::from_le_bytes(
data[off + 2..off + 4]
.try_into()
.map_err(|_| StoreError::InvalidBlock)?,
);
entries.push(SparseEntry { index, value });
off += DELTA_ENTRY_BYTES;
}
Ok(DeltaRecord {
header: DeltaHeader {
tensor_id,
block_index,
base_epoch,
nnz,
},
delta_scale,
entries,
})
}
#[cfg(test)]
mod tests {
use super::*;
fn make_delta(entries: Vec<(u16, i16)>, scale: f32) -> DeltaRecord {
let sparse: Vec<SparseEntry> = entries
.iter()
.map(|&(i, v)| SparseEntry { index: i, value: v })
.collect();
DeltaRecord {
header: DeltaHeader {
tensor_id: 42,
block_index: 0,
base_epoch: 1,
nnz: sparse.len() as u16,
},
delta_scale: scale,
entries: sparse,
}
}
#[test]
fn test_compute_delta_small_change() {
let old = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mut new = old.clone();
new[2] = 3.5;
let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
assert_eq!(d.entries.len(), 1);
assert_eq!(d.entries[0].index, 2);
assert!(d.delta_scale > 0.0);
}
#[test]
fn test_compute_delta_large_change_returns_none() {
let old = vec![1.0; 10];
let new = vec![5.0; 10];
assert!(compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).is_none());
}
#[test]
fn test_apply_delta_modifies_base() {
let mut base = vec![1.0, 2.0, 3.0, 4.0];
apply_delta(&mut base, &make_delta(vec![(1, 100), (3, -50)], 0.01));
assert!((base[0] - 1.0).abs() < 1e-6);
assert!((base[1] - 3.0).abs() < 1e-6); // 2.0 + 100*0.01
assert!((base[2] - 3.0).abs() < 1e-6);
assert!((base[3] - 3.5).abs() < 1e-6); // 4.0 - 50*0.01
}
#[test]
fn test_chain_append_and_reconstruct() {
let mut chain = DeltaChain::new(vec![1.0, 2.0, 3.0, 4.0], 4);
chain.append(make_delta(vec![(0, 1000)], 0.001)).unwrap(); // +1.0
assert_eq!(chain.chain_len(), 1);
let r = chain.reconstruct();
assert!((r[0] - 2.0).abs() < 1e-3);
assert!((r[1] - 2.0).abs() < 1e-6);
}
#[test]
fn test_chain_compact_preserves_state() {
let mut chain = DeltaChain::new(vec![0.0; 4], 8);
chain.append(make_delta(vec![(0, 100)], 0.1)).unwrap(); // +10.0
chain.append(make_delta(vec![(1, 200)], 0.1)).unwrap(); // +20.0
let before = chain.reconstruct();
chain.compact();
assert_eq!(chain.chain_len(), 0);
let after = chain.reconstruct();
for (a, b) in before.iter().zip(after.iter()) {
assert!((a - b).abs() < 1e-6);
}
}
#[test]
fn test_chain_max_length_enforcement() {
let mut chain = DeltaChain::new(vec![1.0; 4], 2);
assert!(chain.append(make_delta(vec![(0, 1)], 0.1)).is_ok());
assert!(chain.append(make_delta(vec![(1, 1)], 0.1)).is_ok());
assert!(chain.append(make_delta(vec![(2, 1)], 0.1)).is_err());
}
#[test]
fn test_chain_needs_compaction() {
let mut chain = DeltaChain::new(vec![1.0; 4], 2);
assert!(!chain.needs_compaction());
chain.append(make_delta(vec![(0, 1)], 0.1)).unwrap();
assert!(!chain.needs_compaction());
chain.append(make_delta(vec![(1, 1)], 0.1)).unwrap();
assert!(chain.needs_compaction());
}
#[test]
fn test_factor_reconstruct() {
let (u, v, s) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0], 2.0);
let f = FactorSet {
m: 3,
n: 2,
k: 1,
u_data: u.clone(),
s_data: vec![s],
v_data: v.clone(),
};
let r = f.reconstruct();
assert_eq!(r.len(), 6);
for i in 0..3 {
for j in 0..2 {
assert!((r[i * 2 + j] - u[i] * s * v[j]).abs() < 1e-6);
}
}
}
#[test]
fn test_factor_from_data_approximation() {
let (m, n) = (8, 6);
let data: Vec<f32> = (0..m * n)
.map(|idx| {
let (i, j) = (idx / n, idx % n);
(i as f32 + 1.0) * (j as f32 + 1.0)
})
.collect();
let reconstructed = FactorSet::from_data(&data, m, n, 1).reconstruct();
let max_err = data
.iter()
.zip(reconstructed.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
assert!(
max_err < 0.5,
"max error {max_err} too large for rank-1 input"
);
}
#[test]
fn test_encode_decode_roundtrip() {
let orig = DeltaRecord {
header: DeltaHeader {
tensor_id: 0xDEADBEEFCAFEBABE,
block_index: 42,
base_epoch: 100,
nnz: 3,
},
delta_scale: 0.001,
entries: vec![
SparseEntry {
index: 10,
value: 500,
},
SparseEntry {
index: 20,
value: -300,
},
SparseEntry {
index: 30,
value: 1,
},
],
};
let bytes = encode_delta(&orig);
assert_eq!(bytes.len(), DELTA_HEADER_BYTES + 3 * DELTA_ENTRY_BYTES);
let dec = decode_delta(&bytes).unwrap();
assert_eq!(dec.header.tensor_id, orig.header.tensor_id);
assert_eq!(dec.header.block_index, orig.header.block_index);
assert_eq!(dec.header.nnz, orig.header.nnz);
assert!((dec.delta_scale - orig.delta_scale).abs() < 1e-10);
for (a, b) in dec.entries.iter().zip(orig.entries.iter()) {
assert_eq!(a.index, b.index);
assert_eq!(a.value, b.value);
}
}
#[test]
fn test_decode_truncated_header() {
assert!(decode_delta(&vec![0u8; 20]).is_err());
}
#[test]
fn test_decode_truncated_entries() {
let mut bytes = encode_delta(&make_delta(vec![(0, 1), (1, 2)], 1.0));
bytes[28] = 5;
bytes[29] = 0; // claim 5 entries, only 2 present
assert!(decode_delta(&bytes).is_err());
}
#[test]
fn test_empty_delta_roundtrip() {
let d = DeltaRecord {
header: DeltaHeader {
tensor_id: 99,
block_index: 7,
base_epoch: 50,
nnz: 0,
},
delta_scale: 0.0,
entries: Vec::new(),
};
let dec = decode_delta(&encode_delta(&d)).unwrap();
assert_eq!(dec.entries.len(), 0);
}
#[test]
fn test_single_entry_delta() {
let old = vec![1.0; 100];
let mut new = old.clone();
new[50] = 2.0;
let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
assert_eq!(d.entries.len(), 1);
assert_eq!(d.entries[0].index, 50);
let mut base = old.clone();
apply_delta(&mut base, &d);
assert!((base[50] - 2.0).abs() < 0.01);
}
#[test]
fn test_full_density_delta() {
let old = vec![0.0; 4];
let new = vec![0.1, 0.2, 0.3, 0.4];
let d = compute_delta(&old, &new, 1, 0, 0, 0.001, 1.1).unwrap();
assert_eq!(d.entries.len(), 4);
let mut base = old.clone();
apply_delta(&mut base, &d);
for i in 0..4 {
assert!((base[i] - new[i]).abs() < 0.01, "index {i}");
}
}
#[test]
fn test_compute_apply_roundtrip_64() {
let old: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
let mut new = old.clone();
new[5] += 0.5;
new[10] -= 0.3;
new[60] += 1.0;
let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
let mut recon = old.clone();
apply_delta(&mut recon, &d);
for i in 0..64 {
assert!((recon[i] - new[i]).abs() < 0.01, "index {i}");
}
}
#[test]
fn test_reconstruction_error_zero_for_exact() {
// Rank-1 data should be exactly reconstructed with rank-1 factors
let (m, n) = (4, 3);
let data: Vec<f32> = (0..m * n)
.map(|idx| {
let (i, j) = (idx / n, idx % n);
(i as f32 + 1.0) * (j as f32 + 1.0)
})
.collect();
let factors = FactorSet::from_data(&data, m, n, 1);
let err = factors.reconstruction_error(&data);
assert!(err < 0.01, "err={err} too large for rank-1 data");
}
#[test]
fn test_reconstruction_error_decreases_with_rank() {
let (m, n) = (8, 6);
let data: Vec<f32> = (0..m * n).map(|i| (i as f32 * 0.7).sin()).collect();
let err1 = FactorSet::from_data(&data, m, n, 1).reconstruction_error(&data);
let err3 = FactorSet::from_data(&data, m, n, 3).reconstruction_error(&data);
assert!(err3 <= err1 + 1e-6, "err3={err3} > err1={err1}");
}
#[test]
fn test_energy_captured_rank1_data() {
let (m, n) = (4, 3);
let data: Vec<f32> = (0..m * n)
.map(|idx| {
let (i, j) = (idx / n, idx % n);
(i as f32 + 1.0) * (j as f32 + 1.0)
})
.collect();
let factors = FactorSet::from_data(&data, m, n, 1);
let energy = factors.energy_captured(&data);
assert!(energy > 0.95, "energy={energy} too low for rank-1 data");
}
#[test]
fn test_compression_ratio_meaningful() {
let (m, n) = (16, 16);
let data: Vec<f32> = (0..m * n).map(|i| i as f32).collect();
let factors = FactorSet::from_data(&data, m, n, 2);
let ratio = factors.compression_ratio(m * n);
// rank-2 storage: (16*2 + 2 + 2*16) * 4 = 264 bytes vs 16*16*4 = 1024 bytes
assert!(ratio > 1.0, "ratio={ratio} should be > 1");
}
#[test]
fn test_from_data_adaptive_stops_early() {
let (m, n) = (4, 3);
// Rank-1 data: adaptive should stop at rank 1
let data: Vec<f32> = (0..m * n)
.map(|idx| {
let (i, j) = (idx / n, idx % n);
(i as f32 + 1.0) * (j as f32 + 1.0)
})
.collect();
let factors = FactorSet::from_data_adaptive(&data, m, n, 5, 0.05);
// Should use rank 1 since data is rank 1
assert!(
factors.k <= 2,
"k={} should be small for rank-1 data",
factors.k
);
}
#[test]
fn test_from_data_adaptive_increases_rank() {
let (m, n) = (8, 6);
// Multi-rank data
let data: Vec<f32> = (0..m * n)
.map(|i| (i as f32 * 0.3).sin() + (i as f32 * 0.7).cos())
.collect();
let factors = FactorSet::from_data_adaptive(&data, m, n, 6, 0.01);
let err = factors.reconstruction_error(&data);
// Should achieve close to target error or use max rank
assert!(err < 0.1 || factors.k == 6, "err={err}, k={}", factors.k);
}
}

View File

@@ -0,0 +1,150 @@
//! Software IEEE 754 half-precision (f16) conversion.
//!
//! No external crate dependencies. Handles normals, denormals, infinity, and NaN.
//! Round-to-nearest for normal values, with ties rounded up (away from zero).
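//!
//! # Example
//!
//! A roundtrip sketch (1.5 is exactly representable in f16):
//!
//! ```
//! use ruvector_temporal_tensor::f16::{f16_bits_to_f32, f32_to_f16_bits};
//!
//! let h = f32_to_f16_bits(1.5);
//! assert_eq!(f16_bits_to_f32(h), 1.5);
//! ```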
/// Convert f32 to f16 bit representation.
///
/// Handles all IEEE 754 special cases: infinity, NaN, denormals, and zero (both signs).
/// Values outside f16 range saturate to infinity. Values too small for f16 denormals
/// flush to zero.
#[inline]
pub fn f32_to_f16_bits(x: f32) -> u16 {
let b = x.to_bits();
let sign = ((b >> 16) & 0x8000) as u16;
let exp = ((b >> 23) & 0xFF) as i32;
let mant = b & 0x7F_FFFF;
// Infinity or NaN
if exp == 255 {
if mant == 0 {
return sign | 0x7C00;
}
let nan_m = (mant >> 13) as u16;
return sign | 0x7C00 | nan_m | 1;
}
let exp16 = exp - 127 + 15;
// Overflow -> Infinity
if exp16 >= 31 {
return sign | 0x7C00;
}
// Underflow -> denormal or zero
if exp16 <= 0 {
if exp16 < -10 {
return sign;
}
let shift = (14 - exp16) as u32;
let mut mant32 = mant | 0x80_0000;
let round_bit = 1u32.wrapping_shl(shift.wrapping_sub(1));
mant32 = mant32.wrapping_add(round_bit);
let sub = (mant32 >> shift) as u16;
return sign | sub;
}
// Normal case
let mant16 = (mant >> 13) as u16;
let round = (mant >> 12) & 1;
let mut res = sign | ((exp16 as u16) << 10) | mant16;
if round != 0 {
res = res.wrapping_add(1);
}
res
}
/// Convert f16 bit representation to f32.
///
/// Exactly reconstructs the f32 value represented by the f16 bit pattern.
/// Handles denormals by normalizing the mantissa before constructing the f32 bits.
#[inline]
pub fn f16_bits_to_f32(h: u16) -> f32 {
let sign = ((h & 0x8000) as u32) << 16;
let exp = ((h >> 10) & 0x1F) as i32;
let mant = (h & 0x03FF) as u32;
// Zero or denormal
if exp == 0 {
if mant == 0 {
return f32::from_bits(sign);
}
let mut e = 0i32; // counts normalization shifts
let mut m = mant;
while (m & 0x0400) == 0 {
m <<= 1;
e += 1;
}
m &= 0x03FF;
let exp32 = 127 - 15 - e + 1;
let mant32 = m << 13;
return f32::from_bits(sign | ((exp32 as u32) << 23) | mant32);
}
// Infinity or NaN
if exp == 31 {
return f32::from_bits(sign | 0x7F80_0000 | (mant << 13));
}
// Normal
let exp32 = exp - 15 + 127;
let mant32 = mant << 13;
f32::from_bits(sign | ((exp32 as u32) << 23) | mant32)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_roundtrip_normal() {
for &v in &[0.0f32, 1.0, -1.0, 0.5, 65504.0, -65504.0, 0.0001] {
let h = f32_to_f16_bits(v);
let back = f16_bits_to_f32(h);
if v == 0.0 {
assert_eq!(back, 0.0);
} else {
let rel_err = ((back - v) / v).abs();
assert!(rel_err < 0.01, "v={v}, back={back}, rel_err={rel_err}");
}
}
}
#[test]
fn test_infinity() {
let h = f32_to_f16_bits(f32::INFINITY);
assert_eq!(h, 0x7C00);
assert!(f16_bits_to_f32(h).is_infinite());
}
#[test]
fn test_neg_infinity() {
let h = f32_to_f16_bits(f32::NEG_INFINITY);
assert_eq!(h, 0xFC00);
let back = f16_bits_to_f32(h);
assert!(back.is_infinite() && back < 0.0);
}
#[test]
fn test_nan() {
let h = f32_to_f16_bits(f32::NAN);
assert!(f16_bits_to_f32(h).is_nan());
}
#[test]
fn test_zero_signs() {
assert_eq!(f32_to_f16_bits(0.0f32), 0x0000);
assert_eq!(f32_to_f16_bits(-0.0f32), 0x8000);
}
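// Denormal roundtrip: 2^-24 is the smallest positive f16 value and
// should survive encode/decode exactly.
#[test]
fn test_denormal_roundtrip() {
let v = 2.0f32.powi(-24);
let h = f32_to_f16_bits(v);
assert_eq!(h, 0x0001);
assert_eq!(f16_bits_to_f32(h), v);
}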
#[test]
fn test_scale_range_accuracy() {
for exp in -4..=4i32 {
let v = 10.0f32.powi(exp);
let h = f32_to_f16_bits(v);
let back = f16_bits_to_f32(h);
let rel_err = ((back - v) / v).abs();
assert!(rel_err < 0.002, "v={v}, back={back}, rel_err={rel_err}");
}
}
}

View File

@@ -0,0 +1,250 @@
//! WASM/C FFI interface with handle-based resource management.
//!
//! Exports `extern "C"` functions for:
//! - Compressor lifecycle (`ttc_create`, `ttc_free`, `ttc_touch`, `ttc_set_access`)
//! - Frame compression (`ttc_push_frame`, `ttc_flush`)
//! - Segment decoding (`ttc_decode_segment`)
//! - Memory management (`ttc_alloc`, `ttc_dealloc`)
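//!
//! A host-side usage sketch (hypothetical call sequence; requires the `ffi`
//! feature, and handles are opaque `u32` slot indices):
//!
//! ```ignore
//! let mut handle = 0u32;
//! ttc_create(128, 0, &mut handle); // 128-element tensors, t = 0
//! ttc_touch(handle, 1);            // record an access at t = 1
//! ttc_free(handle);                // release the slot
//! ```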
use crate::compressor::TemporalTensorCompressor;
use crate::segment;
use crate::tier_policy::TierPolicy;
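// SAFETY: the global handle table below is unsynchronized `static mut`
// state. This assumes a single-threaded host (the usual WASM embedding);
// a multi-threaded embedder would need to guard access with a lock.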
static mut STORE: Option<Vec<Option<TemporalTensorCompressor>>> = None;
fn store_init() {
unsafe {
if STORE.is_none() {
STORE = Some(Vec::new());
}
}
}
fn with_store<F, R>(f: F) -> R
where
F: FnOnce(&mut Vec<Option<TemporalTensorCompressor>>) -> R,
{
store_init();
unsafe { f(STORE.as_mut().unwrap()) }
}
fn with_compressor<F>(handle: u32, f: F)
where
F: FnOnce(&mut TemporalTensorCompressor),
{
with_store(|store| {
let idx = handle as usize;
if idx < store.len() {
if let Some(comp) = store[idx].as_mut() {
f(comp);
}
}
});
}
/// Create a new compressor and return its handle via `out_handle`.
/// A null `out_handle` is ignored, but the compressor still occupies a slot.
#[no_mangle]
pub extern "C" fn ttc_create(len: u32, now_ts: u32, out_handle: *mut u32) {
let policy = TierPolicy::default();
let comp = TemporalTensorCompressor::new(policy, len, now_ts);
with_store(|store| {
// Find a free slot
for (i, slot) in store.iter_mut().enumerate() {
if slot.is_none() {
*slot = Some(comp);
if !out_handle.is_null() {
unsafe { *out_handle = i as u32 };
}
return;
}
}
// No free slot, push
let idx = store.len();
store.push(Some(comp));
if !out_handle.is_null() {
unsafe { *out_handle = idx as u32 };
}
});
}
/// Create a compressor with custom policy parameters.
#[no_mangle]
pub extern "C" fn ttc_create_with_policy(
len: u32,
now_ts: u32,
hot_min_score: u32,
warm_min_score: u32,
warm_bits: u8,
drift_pct_q8: u32,
group_len: u32,
out_handle: *mut u32,
) {
let policy = TierPolicy {
hot_min_score,
warm_min_score,
warm_bits,
drift_pct_q8,
group_len,
};
let comp = TemporalTensorCompressor::new(policy, len, now_ts);
with_store(|store| {
for (i, slot) in store.iter_mut().enumerate() {
if slot.is_none() {
*slot = Some(comp);
if !out_handle.is_null() {
unsafe { *out_handle = i as u32 };
}
return;
}
}
let idx = store.len();
store.push(Some(comp));
if !out_handle.is_null() {
unsafe { *out_handle = idx as u32 };
}
});
}
/// Free a compressor.
#[no_mangle]
pub extern "C" fn ttc_free(handle: u32) {
with_store(|store| {
let idx = handle as usize;
if idx < store.len() {
store[idx] = None;
}
});
}
/// Record an access event.
#[no_mangle]
pub extern "C" fn ttc_touch(handle: u32, now_ts: u32) {
with_compressor(handle, |comp| comp.touch(now_ts));
}
/// Set access stats directly.
#[no_mangle]
pub extern "C" fn ttc_set_access(handle: u32, access_count: u32, last_access_ts: u32) {
with_compressor(handle, |comp| comp.set_access(access_count, last_access_ts));
}
/// Push a frame. If a segment boundary is crossed, the completed segment
/// is written to `out_ptr`/`out_cap` and `out_written` is set to the byte
/// count. A segment larger than `out_cap` is dropped and `out_written`
/// stays 0.
#[no_mangle]
pub extern "C" fn ttc_push_frame(
handle: u32,
now_ts: u32,
in_ptr: *const f32,
len: u32,
out_ptr: *mut u8,
out_cap: u32,
out_written: *mut u32,
) {
if out_written.is_null() {
return;
}
unsafe { *out_written = 0 };
if in_ptr.is_null() || out_ptr.is_null() {
return;
}
let frame = unsafe { std::slice::from_raw_parts(in_ptr, len as usize) };
let mut seg = Vec::new();
with_compressor(handle, |comp| {
comp.push_frame(frame, now_ts, &mut seg);
});
if seg.is_empty() || (seg.len() as u32) > out_cap {
return;
}
unsafe {
let out = std::slice::from_raw_parts_mut(out_ptr, out_cap as usize);
out[..seg.len()].copy_from_slice(&seg);
*out_written = seg.len() as u32;
}
}
/// Flush the current segment. Same output contract as [`ttc_push_frame`]:
/// a segment larger than `out_cap` is dropped and `out_written` stays 0.
#[no_mangle]
pub extern "C" fn ttc_flush(handle: u32, out_ptr: *mut u8, out_cap: u32, out_written: *mut u32) {
if out_written.is_null() {
return;
}
unsafe { *out_written = 0 };
let mut seg = Vec::new();
with_compressor(handle, |comp| {
comp.flush(&mut seg);
});
if seg.is_empty() || out_ptr.is_null() || (seg.len() as u32) > out_cap {
return;
}
unsafe {
let out = std::slice::from_raw_parts_mut(out_ptr, out_cap as usize);
out[..seg.len()].copy_from_slice(&seg);
*out_written = seg.len() as u32;
}
}
/// Decode a segment into f32 values.
#[no_mangle]
pub extern "C" fn ttc_decode_segment(
seg_ptr: *const u8,
seg_len: u32,
out_ptr: *mut f32,
out_cap_f32: u32,
out_written_f32: *mut u32,
) {
if out_written_f32.is_null() {
return;
}
unsafe { *out_written_f32 = 0 };
if seg_ptr.is_null() || out_ptr.is_null() {
return;
}
let seg = unsafe { std::slice::from_raw_parts(seg_ptr, seg_len as usize) };
let mut values = Vec::new();
segment::decode(seg, &mut values);
if values.is_empty() || (values.len() as u32) > out_cap_f32 {
return;
}
unsafe {
let out = std::slice::from_raw_parts_mut(out_ptr, out_cap_f32 as usize);
out[..values.len()].copy_from_slice(&values);
*out_written_f32 = values.len() as u32;
}
}
/// Allocate a buffer in WASM linear memory.
#[no_mangle]
pub extern "C" fn ttc_alloc(size: u32, out_ptr: *mut u32) {
if out_ptr.is_null() {
return;
}
let mut v: Vec<u8> = Vec::with_capacity(size as usize);
let p = v.as_mut_ptr();
std::mem::forget(v);
unsafe {
*out_ptr = p as u32;
}
}
/// Free a buffer previously allocated with ttc_alloc.
#[no_mangle]
pub extern "C" fn ttc_dealloc(ptr: u32, cap: u32) {
if ptr == 0 || cap == 0 {
return;
}
unsafe {
let _ = Vec::<u8>::from_raw_parts(ptr as *mut u8, 0, cap as usize);
}
}

View File

@@ -0,0 +1,99 @@
//! Temporal Tensor Compression with Tiered Quantization
//!
//! Implements ADR-017: groupwise symmetric quantization with temporal segment
//! reuse and access-pattern-driven tier selection (8/7/5/3 bit).
//!
//! # Architecture
//!
//! ```text
//! f32 frame → tier_policy → quantizer → bitpack → segment
//! segment → bitpack → quantizer → f32 output
//! ```
//!
//! # Compression Ratios
//!
//! | Tier | Bits | Ratio vs f32 | Use Case |
//! |------|------|-------------|----------|
//! | Hot | 8 | ~4.0x | Frequently accessed tensors |
//! | Warm | 7 | ~4.57x | Moderately accessed |
//! | Warm | 5 | ~6.4x | Aggressively compressed warm |
//! | Cold | 3 | ~10.67x | Rarely accessed |
//!
//! # Zero Dependencies
//!
//! This crate has no external dependencies, making it fully WASM-compatible.
//!
//! # Quick Start
//!
//! ```rust
//! use ruvector_temporal_tensor::{TemporalTensorCompressor, TierPolicy};
//!
//! // Create a compressor for 128-element tensors
//! let mut comp = TemporalTensorCompressor::new(TierPolicy::default(), 128, 0);
//! comp.set_access(100, 0); // hot tensor -> 8-bit quantization
//!
//! let frame = vec![1.0f32; 128];
//! let mut segment = Vec::new();
//!
//! // Push frames; segment is populated when a boundary is crossed
//! comp.push_frame(&frame, 1, &mut segment);
//! comp.flush(&mut segment); // force-emit the current segment
//!
//! // Decode the segment back to f32
//! let mut decoded = Vec::new();
//! ruvector_temporal_tensor::segment::decode(&segment, &mut decoded);
//! assert_eq!(decoded.len(), 128);
//! ```
//!
//! # Random-Access Decode
//!
//! ```rust
//! # use ruvector_temporal_tensor::{TemporalTensorCompressor, TierPolicy};
//! # let mut comp = TemporalTensorCompressor::new(TierPolicy::default(), 64, 0);
//! # let frame = vec![1.0f32; 64];
//! # let mut seg = Vec::new();
//! # comp.push_frame(&frame, 0, &mut seg);
//! # comp.flush(&mut seg);
//! // Decode only frame 0 without decoding the entire segment
//! let single = ruvector_temporal_tensor::segment::decode_single_frame(&seg, 0);
//! assert!(single.is_some());
//! ```
//!
//! # Compression Ratio Inspection
//!
//! ```rust
//! # use ruvector_temporal_tensor::{TemporalTensorCompressor, TierPolicy};
//! # let mut comp = TemporalTensorCompressor::new(TierPolicy::default(), 64, 0);
//! # let frame = vec![1.0f32; 64];
//! # let mut seg = Vec::new();
//! # comp.push_frame(&frame, 0, &mut seg);
//! # comp.flush(&mut seg);
//! let ratio = ruvector_temporal_tensor::segment::compression_ratio(&seg);
//! assert!(ratio > 1.0);
//! ```
pub mod bitpack;
pub mod compressor;
pub mod delta;
pub mod f16;
pub mod metrics;
pub mod quantizer;
pub mod segment;
pub mod store;
pub mod tier_policy;
pub mod tiering;
pub mod agentdb;
pub mod coherence;
pub mod core_trait;
#[cfg(feature = "persistence")]
pub mod persistence;
#[cfg(feature = "ffi")]
pub mod ffi;
#[cfg(feature = "ffi")]
pub mod store_ffi;
pub use compressor::TemporalTensorCompressor;
pub use tier_policy::TierPolicy;

File diff suppressed because it is too large

View File

@@ -0,0 +1,859 @@
//! Disk-backed BlockIO and MetaLog implementations.
//!
//! Gated behind the `persistence` feature flag. Uses raw file I/O
//! with a simple binary format. No external dependencies.
#![cfg(feature = "persistence")]
use crate::store::{
BlockIO, BlockKey, BlockMeta, DType, MetaLog, ReconstructPolicy, StoreError, Tier,
};
use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
/// Fixed size of a single encoded [`BlockMeta`] record in bytes.
///
/// Layout (all little-endian):
///
/// | Offset | Size | Field |
/// |--------|------|-----------------|
/// | 0 | 16 | tensor_id |
/// | 16 | 4 | block_index |
/// | 20 | 1 | dtype |
/// | 21 | 1 | tier |
/// | 22 | 1 | bits |
/// | 23 | 4 | scale |
/// | 27 | 2 | zero_point |
/// | 29 | 8 | created_at |
/// | 37 | 8 | last_access_at |
/// | 45 | 4 | access_count |
/// | 49 | 4 | ema_rate |
/// | 53 | 8 | window |
/// | 61 | 4 | checksum |
/// | 65 | 1 | reconstruct |
/// | 66 | 4 | tier_age |
/// | 70 | 1 | has_lineage |
/// | 71 | 16 | lineage_parent |
/// | 87 | 4 | block_bytes |
const RECORD_SIZE: usize = 91;
// ---------------------------------------------------------------------------
// Serialization helpers
// ---------------------------------------------------------------------------
/// Serialize a [`BlockMeta`] into a fixed-size byte vector.
///
/// The encoding uses little-endian byte order for all multi-byte fields
/// and occupies exactly [`RECORD_SIZE`] bytes.
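///
/// Roundtrips with [`decode_meta`]; a sketch (requires the `persistence`
/// feature; `meta` stands for any [`BlockMeta`] value):
///
/// ```ignore
/// let bytes = encode_meta(&meta);
/// assert_eq!(bytes.len(), RECORD_SIZE);
/// let back = decode_meta(&bytes).unwrap();
/// assert_eq!(back.key, meta.key);
/// ```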
pub fn encode_meta(meta: &BlockMeta) -> Vec<u8> {
let mut buf = Vec::with_capacity(RECORD_SIZE);
// key
buf.extend_from_slice(&meta.key.tensor_id.to_le_bytes());
buf.extend_from_slice(&meta.key.block_index.to_le_bytes());
// scalar metadata
buf.push(meta.dtype as u8);
buf.push(meta.tier as u8);
buf.push(meta.bits);
buf.extend_from_slice(&meta.scale.to_le_bytes());
buf.extend_from_slice(&meta.zero_point.to_le_bytes());
// timestamps and counters
buf.extend_from_slice(&meta.created_at.to_le_bytes());
buf.extend_from_slice(&meta.last_access_at.to_le_bytes());
buf.extend_from_slice(&meta.access_count.to_le_bytes());
buf.extend_from_slice(&meta.ema_rate.to_le_bytes());
buf.extend_from_slice(&meta.window.to_le_bytes());
buf.extend_from_slice(&meta.checksum.to_le_bytes());
// policy and age
buf.push(meta.reconstruct as u8);
buf.extend_from_slice(&meta.tier_age.to_le_bytes());
// optional lineage parent
match meta.lineage_parent {
Some(parent) => {
buf.push(1);
buf.extend_from_slice(&parent.to_le_bytes());
}
None => {
buf.push(0);
buf.extend_from_slice(&0u128.to_le_bytes());
}
}
// payload size
buf.extend_from_slice(&meta.block_bytes.to_le_bytes());
debug_assert_eq!(buf.len(), RECORD_SIZE);
buf
}
/// Deserialize a [`BlockMeta`] from a byte slice of at least [`RECORD_SIZE`] bytes.
///
/// Returns [`StoreError::InvalidData`] if the slice is too short or
/// contains invalid enum discriminants.
pub fn decode_meta(bytes: &[u8]) -> Result<BlockMeta, StoreError> {
if bytes.len() < RECORD_SIZE {
return Err(StoreError::InvalidData);
}
let tensor_id = u128::from_le_bytes(
bytes[0..16]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let block_index = u32::from_le_bytes(
bytes[16..20]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let dtype = match bytes[20] {
0 => DType::F32,
1 => DType::F16,
2 => DType::BF16,
_ => return Err(StoreError::InvalidData),
};
let tier = match bytes[21] {
0 => Tier::Tier0,
1 => Tier::Tier1,
2 => Tier::Tier2,
3 => Tier::Tier3,
_ => return Err(StoreError::InvalidData),
};
let bits = bytes[22];
let scale = f32::from_le_bytes(
bytes[23..27]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let zero_point = i16::from_le_bytes(
bytes[27..29]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let created_at = u64::from_le_bytes(
bytes[29..37]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let last_access_at = u64::from_le_bytes(
bytes[37..45]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let access_count = u32::from_le_bytes(
bytes[45..49]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let ema_rate = f32::from_le_bytes(
bytes[49..53]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let window = u64::from_le_bytes(
bytes[53..61]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let checksum = u32::from_le_bytes(
bytes[61..65]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let reconstruct = match bytes[65] {
0 => ReconstructPolicy::None,
1 => ReconstructPolicy::Delta,
2 => ReconstructPolicy::Factor,
_ => return Err(StoreError::InvalidData),
};
let tier_age = u32::from_le_bytes(
bytes[66..70]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let has_lineage = bytes[70];
let lineage_value = u128::from_le_bytes(
bytes[71..87]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
let lineage_parent = if has_lineage != 0 {
Some(lineage_value)
} else {
None
};
let block_bytes = u32::from_le_bytes(
bytes[87..91]
.try_into()
.map_err(|_| StoreError::InvalidData)?,
);
Ok(BlockMeta {
key: BlockKey {
tensor_id,
block_index,
},
dtype,
tier,
bits,
scale,
zero_point,
created_at,
last_access_at,
access_count,
ema_rate,
window,
checksum,
reconstruct,
tier_age,
lineage_parent,
block_bytes,
})
}
// ---------------------------------------------------------------------------
// FileBlockIO
// ---------------------------------------------------------------------------
/// Disk-backed [`BlockIO`] that stores each block as a separate file.
///
/// Directory layout:
/// ```text
/// {base_dir}/
/// tier0/
/// tier1/
/// tier2/
/// tier3/
/// ```
///
/// Each block file is named `{tensor_id_hex}_{block_index}.bin`.
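///
/// A usage sketch (requires the `persistence` feature; the directory path
/// is illustrative):
///
/// ```ignore
/// use ruvector_temporal_tensor::persistence::FileBlockIO;
/// use ruvector_temporal_tensor::store::{BlockIO, BlockKey, Tier};
///
/// let mut io = FileBlockIO::new("/tmp/ruvector_blocks").unwrap();
/// let key = BlockKey { tensor_id: 1, block_index: 0 };
/// io.write_block(Tier::Tier1, key, &[1, 2, 3]).unwrap();
/// let mut buf = [0u8; 3];
/// assert_eq!(io.read_block(Tier::Tier1, key, &mut buf).unwrap(), 3);
/// ```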
pub struct FileBlockIO {
base_dir: PathBuf,
}
impl FileBlockIO {
/// Create a new `FileBlockIO` rooted at `base_dir`.
///
/// Creates the tier subdirectories if they do not already exist.
pub fn new(base_dir: impl Into<PathBuf>) -> Result<Self, StoreError> {
let base_dir = base_dir.into();
for tier_num in 0..=3u8 {
let tier_dir = base_dir.join(format!("tier{}", tier_num));
fs::create_dir_all(&tier_dir).map_err(|_| StoreError::IOError)?;
}
Ok(Self { base_dir })
}
/// Return the filesystem path for a given block.
fn block_path(&self, tier: Tier, key: BlockKey) -> PathBuf {
self.base_dir
.join(format!("tier{}", tier as u8))
.join(format!("{:032x}_{}.bin", key.tensor_id, key.block_index))
}
/// Return the base directory.
pub fn base_dir(&self) -> &Path {
&self.base_dir
}
}
impl BlockIO for FileBlockIO {
fn read_block(&self, tier: Tier, key: BlockKey, dst: &mut [u8]) -> Result<usize, StoreError> {
let path = self.block_path(tier, key);
let data = fs::read(&path).map_err(|_| StoreError::BlockNotFound)?;
let n = data.len().min(dst.len());
dst[..n].copy_from_slice(&data[..n]);
Ok(n)
}
fn write_block(&mut self, tier: Tier, key: BlockKey, src: &[u8]) -> Result<(), StoreError> {
if tier == Tier::Tier0 {
return Err(StoreError::InvalidBlock);
}
let path = self.block_path(tier, key);
fs::write(&path, src).map_err(|_| StoreError::IOError)
}
fn delete_block(&mut self, tier: Tier, key: BlockKey) -> Result<(), StoreError> {
let path = self.block_path(tier, key);
fs::remove_file(&path).map_err(|_| StoreError::BlockNotFound)
}
}
// ---------------------------------------------------------------------------
// FileMetaLog
// ---------------------------------------------------------------------------
/// Append-only file-backed [`MetaLog`].
///
/// Each [`append`](MetaLog::append) call writes a fixed-size binary record
/// to `{base_dir}/meta.log`. On construction the log is replayed into an
/// in-memory [`HashMap`] so that [`get`](MetaLog::get) is a simple lookup.
///
/// Because the log is append-only, multiple records for the same key may
/// exist on disk. The last record wins when the log is replayed.
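///
/// A replay sketch (requires the `persistence` feature; `meta` stands for
/// any [`BlockMeta`] value):
///
/// ```ignore
/// use ruvector_temporal_tensor::persistence::FileMetaLog;
/// use ruvector_temporal_tensor::store::MetaLog;
///
/// let mut log = FileMetaLog::new("/tmp/ruvector_meta").unwrap();
/// log.append(&meta).unwrap();
/// drop(log);
/// let reopened = FileMetaLog::new("/tmp/ruvector_meta").unwrap();
/// assert!(reopened.get(meta.key).is_some());
/// ```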
pub struct FileMetaLog {
log_path: PathBuf,
index: HashMap<BlockKey, BlockMeta>,
}
impl FileMetaLog {
/// Open (or create) a `FileMetaLog` rooted at `base_dir`.
///
/// If `{base_dir}/meta.log` already exists it is replayed to populate
/// the in-memory index.
pub fn new(base_dir: impl Into<PathBuf>) -> Result<Self, StoreError> {
let base_dir = base_dir.into();
fs::create_dir_all(&base_dir).map_err(|_| StoreError::IOError)?;
let log_path = base_dir.join("meta.log");
let mut index = HashMap::new();
if log_path.exists() {
let data = fs::read(&log_path).map_err(|_| StoreError::IOError)?;
let mut offset = 0;
while offset + RECORD_SIZE <= data.len() {
if let Ok(meta) = decode_meta(&data[offset..offset + RECORD_SIZE]) {
index.insert(meta.key, meta);
}
offset += RECORD_SIZE;
}
}
Ok(Self { log_path, index })
}
/// Return the path to the underlying log file.
pub fn log_path(&self) -> &Path {
&self.log_path
}
/// Number of unique blocks tracked in the in-memory index.
pub fn len(&self) -> usize {
self.index.len()
}
/// Returns `true` if no metadata records are tracked.
pub fn is_empty(&self) -> bool {
self.index.is_empty()
}
}
impl MetaLog for FileMetaLog {
fn append(&mut self, rec: &BlockMeta) -> Result<(), StoreError> {
let encoded = encode_meta(rec);
let mut file = fs::OpenOptions::new()
.create(true)
.append(true)
.open(&self.log_path)
.map_err(|_| StoreError::IOError)?;
file.write_all(&encoded).map_err(|_| StoreError::IOError)?;
file.flush().map_err(|_| StoreError::IOError)?;
self.index.insert(rec.key, rec.clone());
Ok(())
}
fn get(&self, key: BlockKey) -> Option<&BlockMeta> {
self.index.get(&key)
}
fn iter(&self) -> Box<dyn Iterator<Item = &BlockMeta> + '_> {
Box::new(self.index.values())
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
/// Monotonic counter for unique test directory names.
static TEST_ID: AtomicU32 = AtomicU32::new(0);
/// Create a unique temporary directory for a test.
fn test_dir(prefix: &str) -> PathBuf {
let id = TEST_ID.fetch_add(1, Ordering::SeqCst);
let pid = std::process::id();
let dir =
std::env::temp_dir().join(format!("ruvector_persistence_{}_{}_{}", prefix, pid, id));
let _ = fs::remove_dir_all(&dir);
fs::create_dir_all(&dir).unwrap();
dir
}
/// Clean up a test directory (best-effort).
fn cleanup(dir: &Path) {
let _ = fs::remove_dir_all(dir);
}
fn make_key(tid: u128, idx: u32) -> BlockKey {
BlockKey {
tensor_id: tid,
block_index: idx,
}
}
fn sample_meta(key: BlockKey) -> BlockMeta {
BlockMeta {
key,
dtype: DType::F32,
tier: Tier::Tier1,
bits: 8,
scale: 0.03125,
zero_point: 0,
created_at: 1000,
last_access_at: 2000,
access_count: 42,
ema_rate: 0.75,
window: 0xAAAA_BBBB_CCCC_DDDD,
checksum: 0xDEAD_BEEF,
reconstruct: ReconstructPolicy::None,
tier_age: 15,
lineage_parent: None,
block_bytes: 512,
}
}
// -- encode / decode roundtrip -----------------------------------------
#[test]
fn encode_decode_roundtrip_basic() {
let key = make_key(0x0123_4567_89AB_CDEF_FEDC_BA98_7654_3210, 7);
let meta = sample_meta(key);
let encoded = encode_meta(&meta);
assert_eq!(encoded.len(), RECORD_SIZE);
let decoded = decode_meta(&encoded).unwrap();
assert_eq!(decoded.key, meta.key);
assert_eq!(decoded.dtype, meta.dtype);
assert_eq!(decoded.tier, meta.tier);
assert_eq!(decoded.bits, meta.bits);
assert!((decoded.scale - meta.scale).abs() < 1e-10);
assert_eq!(decoded.zero_point, meta.zero_point);
assert_eq!(decoded.created_at, meta.created_at);
assert_eq!(decoded.last_access_at, meta.last_access_at);
assert_eq!(decoded.access_count, meta.access_count);
assert!((decoded.ema_rate - meta.ema_rate).abs() < 1e-6);
assert_eq!(decoded.window, meta.window);
assert_eq!(decoded.checksum, meta.checksum);
assert_eq!(decoded.reconstruct, meta.reconstruct);
assert_eq!(decoded.tier_age, meta.tier_age);
assert_eq!(decoded.lineage_parent, meta.lineage_parent);
assert_eq!(decoded.block_bytes, meta.block_bytes);
}
#[test]
fn encode_decode_with_lineage() {
let key = make_key(1, 0);
let mut meta = sample_meta(key);
meta.lineage_parent = Some(0xFFFF_FFFF_FFFF_FFFF_0000_0000_0000_0001);
let encoded = encode_meta(&meta);
let decoded = decode_meta(&encoded).unwrap();
assert_eq!(
decoded.lineage_parent,
Some(0xFFFF_FFFF_FFFF_FFFF_0000_0000_0000_0001)
);
}
#[test]
fn encode_decode_all_dtypes() {
for (dtype_val, expected) in [(0u8, DType::F32), (1, DType::F16), (2, DType::BF16)] {
let key = make_key(dtype_val as u128, 0);
let mut meta = sample_meta(key);
meta.dtype = expected;
let decoded = decode_meta(&encode_meta(&meta)).unwrap();
assert_eq!(decoded.dtype, expected);
}
}
#[test]
fn encode_decode_all_tiers() {
for (tier_val, expected) in [
(0u8, Tier::Tier0),
(1, Tier::Tier1),
(2, Tier::Tier2),
(3, Tier::Tier3),
] {
let key = make_key(tier_val as u128, 0);
let mut meta = sample_meta(key);
meta.tier = expected;
let decoded = decode_meta(&encode_meta(&meta)).unwrap();
assert_eq!(decoded.tier, expected);
}
}
#[test]
fn encode_decode_all_reconstruct_policies() {
for (_, expected) in [
(0u8, ReconstructPolicy::None),
(1, ReconstructPolicy::Delta),
(2, ReconstructPolicy::Factor),
] {
let key = make_key(1, 0);
let mut meta = sample_meta(key);
meta.reconstruct = expected;
let decoded = decode_meta(&encode_meta(&meta)).unwrap();
assert_eq!(decoded.reconstruct, expected);
}
}
#[test]
fn decode_too_short() {
let result = decode_meta(&[0u8; RECORD_SIZE - 1]);
assert!(
matches!(result, Err(StoreError::InvalidData)),
"expected InvalidData, got {:?}",
result.err()
);
}
#[test]
fn decode_invalid_dtype() {
let key = make_key(1, 0);
let mut encoded = encode_meta(&sample_meta(key));
encoded[20] = 255; // invalid dtype
assert!(
matches!(decode_meta(&encoded), Err(StoreError::InvalidData)),
"expected InvalidData for bad dtype"
);
}
#[test]
fn decode_invalid_tier() {
let key = make_key(1, 0);
let mut encoded = encode_meta(&sample_meta(key));
encoded[21] = 99; // invalid tier
assert!(
matches!(decode_meta(&encoded), Err(StoreError::InvalidData)),
"expected InvalidData for bad tier"
);
}
#[test]
fn decode_invalid_reconstruct() {
let key = make_key(1, 0);
let mut encoded = encode_meta(&sample_meta(key));
encoded[65] = 77; // invalid reconstruct policy
assert!(
matches!(decode_meta(&encoded), Err(StoreError::InvalidData)),
"expected InvalidData for bad reconstruct"
);
}
// -- FileBlockIO -------------------------------------------------------
#[test]
fn file_block_io_write_read() {
let dir = test_dir("bio_wr");
let mut io = FileBlockIO::new(&dir).unwrap();
let key = make_key(0xABCD, 3);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
io.write_block(Tier::Tier1, key, &data).unwrap();
let mut dst = vec![0u8; 16];
let n = io.read_block(Tier::Tier1, key, &mut dst).unwrap();
assert_eq!(n, 8);
assert_eq!(&dst[..8], &data);
cleanup(&dir);
}
#[test]
fn file_block_io_write_tier0_rejected() {
let dir = test_dir("bio_t0");
let mut io = FileBlockIO::new(&dir).unwrap();
let key = make_key(1, 0);
assert_eq!(
io.write_block(Tier::Tier0, key, &[1]),
Err(StoreError::InvalidBlock)
);
cleanup(&dir);
}
#[test]
fn file_block_io_read_not_found() {
let dir = test_dir("bio_nf");
let io = FileBlockIO::new(&dir).unwrap();
let key = make_key(99, 99);
let mut dst = vec![0u8; 4];
assert_eq!(
io.read_block(Tier::Tier2, key, &mut dst),
Err(StoreError::BlockNotFound)
);
cleanup(&dir);
}
#[test]
fn file_block_io_delete() {
let dir = test_dir("bio_del");
let mut io = FileBlockIO::new(&dir).unwrap();
let key = make_key(5, 0);
io.write_block(Tier::Tier2, key, &[10, 20, 30]).unwrap();
io.delete_block(Tier::Tier2, key).unwrap();
let mut dst = vec![0u8; 4];
assert_eq!(
io.read_block(Tier::Tier2, key, &mut dst),
Err(StoreError::BlockNotFound)
);
cleanup(&dir);
}
#[test]
fn file_block_io_delete_not_found() {
let dir = test_dir("bio_del_nf");
let mut io = FileBlockIO::new(&dir).unwrap();
let key = make_key(1, 0);
assert_eq!(
io.delete_block(Tier::Tier1, key),
Err(StoreError::BlockNotFound)
);
cleanup(&dir);
}
#[test]
fn file_block_io_overwrite() {
let dir = test_dir("bio_ow");
let mut io = FileBlockIO::new(&dir).unwrap();
let key = make_key(1, 0);
io.write_block(Tier::Tier1, key, &[1, 2, 3]).unwrap();
io.write_block(Tier::Tier1, key, &[4, 5, 6, 7]).unwrap();
let mut dst = vec![0u8; 8];
let n = io.read_block(Tier::Tier1, key, &mut dst).unwrap();
assert_eq!(n, 4);
assert_eq!(&dst[..4], &[4, 5, 6, 7]);
cleanup(&dir);
}
#[test]
fn file_block_io_multiple_tiers() {
let dir = test_dir("bio_mt");
let mut io = FileBlockIO::new(&dir).unwrap();
let key = make_key(1, 0);
io.write_block(Tier::Tier1, key, &[1]).unwrap();
io.write_block(Tier::Tier2, key, &[2]).unwrap();
io.write_block(Tier::Tier3, key, &[3]).unwrap();
let mut dst = [0u8; 1];
let n = io.read_block(Tier::Tier1, key, &mut dst).unwrap();
assert_eq!(n, 1);
assert_eq!(dst[0], 1);
let n = io.read_block(Tier::Tier2, key, &mut dst).unwrap();
assert_eq!(n, 1);
assert_eq!(dst[0], 2);
let n = io.read_block(Tier::Tier3, key, &mut dst).unwrap();
assert_eq!(n, 1);
assert_eq!(dst[0], 3);
cleanup(&dir);
}
#[test]
fn file_block_io_path_format() {
let dir = test_dir("bio_path");
let io = FileBlockIO::new(&dir).unwrap();
let key = make_key(0xFF, 42);
let path = io.block_path(Tier::Tier1, key);
let expected = dir
.join("tier1")
.join("000000000000000000000000000000ff_42.bin");
assert_eq!(path, expected);
cleanup(&dir);
}
// -- FileMetaLog -------------------------------------------------------
#[test]
fn file_meta_log_append_get() {
let dir = test_dir("ml_ag");
let mut log = FileMetaLog::new(&dir).unwrap();
let key = make_key(1, 0);
let meta = sample_meta(key);
log.append(&meta).unwrap();
let retrieved = log.get(key).unwrap();
assert_eq!(retrieved.key, key);
assert_eq!(retrieved.created_at, 1000);
assert_eq!(log.len(), 1);
cleanup(&dir);
}
#[test]
fn file_meta_log_get_missing() {
let dir = test_dir("ml_miss");
let log = FileMetaLog::new(&dir).unwrap();
assert!(log.get(make_key(99, 0)).is_none());
cleanup(&dir);
}
#[test]
fn file_meta_log_upsert() {
let dir = test_dir("ml_ups");
let mut log = FileMetaLog::new(&dir).unwrap();
let key = make_key(1, 0);
let mut meta = sample_meta(key);
meta.access_count = 10;
log.append(&meta).unwrap();
meta.access_count = 20;
log.append(&meta).unwrap();
// In-memory should reflect the latest write.
let retrieved = log.get(key).unwrap();
assert_eq!(retrieved.access_count, 20);
assert_eq!(log.len(), 1);
cleanup(&dir);
}
#[test]
fn file_meta_log_iter() {
let dir = test_dir("ml_iter");
let mut log = FileMetaLog::new(&dir).unwrap();
for i in 0..5u32 {
let key = make_key(i as u128, 0);
log.append(&sample_meta(key)).unwrap();
}
let entries: Vec<_> = log.iter().collect();
assert_eq!(entries.len(), 5);
cleanup(&dir);
}
#[test]
fn file_meta_log_persistence_across_opens() {
let dir = test_dir("ml_persist");
let key1 = make_key(1, 0);
let key2 = make_key(2, 5);
// First open: write two records.
{
let mut log = FileMetaLog::new(&dir).unwrap();
log.append(&sample_meta(key1)).unwrap();
let mut meta2 = sample_meta(key2);
meta2.tier = Tier::Tier3;
meta2.bits = 3;
meta2.lineage_parent = Some(0x42);
log.append(&meta2).unwrap();
assert_eq!(log.len(), 2);
}
// Second open: records should be recovered from disk.
{
let log = FileMetaLog::new(&dir).unwrap();
assert_eq!(log.len(), 2);
let r1 = log.get(key1).unwrap();
assert_eq!(r1.tier, Tier::Tier1);
let r2 = log.get(key2).unwrap();
assert_eq!(r2.tier, Tier::Tier3);
assert_eq!(r2.lineage_parent, Some(0x42));
}
cleanup(&dir);
}
#[test]
fn file_meta_log_replay_last_wins() {
let dir = test_dir("ml_lw");
let key = make_key(1, 0);
// Write two versions of the same key.
{
let mut log = FileMetaLog::new(&dir).unwrap();
let mut meta = sample_meta(key);
meta.access_count = 100;
log.append(&meta).unwrap();
meta.access_count = 200;
log.append(&meta).unwrap();
}
// Reopen: last record should win during replay.
{
let log = FileMetaLog::new(&dir).unwrap();
assert_eq!(log.len(), 1);
let retrieved = log.get(key).unwrap();
assert_eq!(retrieved.access_count, 200);
}
cleanup(&dir);
}
#[test]
fn file_meta_log_empty_on_fresh_dir() {
let dir = test_dir("ml_empty");
let log = FileMetaLog::new(&dir).unwrap();
assert!(log.is_empty());
assert_eq!(log.len(), 0);
assert_eq!(log.iter().count(), 0);
cleanup(&dir);
}
// -- Integration: FileBlockIO + FileMetaLog ----------------------------
#[test]
fn integration_block_io_and_meta_log() {
let dir = test_dir("integ");
let mut io = FileBlockIO::new(&dir).unwrap();
let mut log = FileMetaLog::new(&dir).unwrap();
let key = make_key(0x1234, 0);
let block_data = vec![0xFFu8; 256];
// Write block and metadata.
io.write_block(Tier::Tier1, key, &block_data).unwrap();
let mut meta = sample_meta(key);
meta.block_bytes = 256;
log.append(&meta).unwrap();
// Read back and verify.
let mut dst = vec![0u8; 512];
let n = io.read_block(Tier::Tier1, key, &mut dst).unwrap();
assert_eq!(n, 256);
assert!(dst[..256].iter().all(|&b| b == 0xFF));
let retrieved = log.get(key).unwrap();
assert_eq!(retrieved.block_bytes, 256);
cleanup(&dir);
}
#[test]
fn record_size_constant_matches() {
// Verify that RECORD_SIZE matches the actual encoded size.
let meta = sample_meta(make_key(0, 0));
let encoded = encode_meta(&meta);
assert_eq!(encoded.len(), RECORD_SIZE);
}
}

View File

@@ -0,0 +1,834 @@
//! Groupwise symmetric quantization with f16 scales.
//!
//! For each group of `group_len` values:
//! - `scale = max(|v_i|) / qmax`
//! - `q_i = round(v_i / scale)`, clamped to `[-qmax, +qmax]`
//! - `u_i = q_i + qmax` (bias to unsigned for packing)
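//!
//! # Example
//!
//! A minimal round-trip sketch using this module's public API (marked
//! `ignore` rather than run as a doctest; mirrors the unit tests below):
//!
//! ```ignore
//! let frame: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.1).collect();
//! let scales = compute_scales(&frame, 64, 8); // one f16 scale per 64-value group
//! let mut packed = Vec::new();
//! quantize_and_pack(&frame, &scales, 64, 8, &mut packed);
//! let mut decoded = Vec::new();
//! dequantize(&packed, &scales, 64, 8, frame.len(), 1, &mut decoded);
//! assert_eq!(decoded.len(), frame.len()); // lossy values, same shape
//! ```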
use crate::bitpack::qmax_from_bits;
use crate::f16;
/// Compute f16 group scales for a frame.
///
/// Returns one f16-encoded scale per group of `group_len` elements.
/// Each scale is `max(|v|) / qmax` for that group, stored as IEEE 754 half-precision.
#[inline]
pub fn compute_scales(frame: &[f32], group_len: usize, bits: u8) -> Vec<u16> {
let qmax = qmax_from_bits(bits);
if qmax == 0 {
return Vec::new();
}
let qmax_f = qmax as f32;
let num_groups = frame.len().div_ceil(group_len);
let mut scales = Vec::with_capacity(num_groups);
for chunk in frame.chunks(group_len) {
let mut max_abs = 0.0f32;
for &v in chunk {
if v.is_finite() {
let a = v.abs();
if a > max_abs {
max_abs = a;
}
}
}
let scale = if max_abs == 0.0 {
0.0
} else {
max_abs / qmax_f
};
scales.push(f16::f32_to_f16_bits(scale));
}
scales
}
/// Pre-convert f16 scales to f32 for hot-path use.
#[inline]
pub fn scales_to_f32(scales_f16: &[u16]) -> Vec<f32> {
scales_f16
.iter()
.map(|&s| f16::f16_bits_to_f32(s))
.collect()
}
/// Check if a frame fits within existing scales (within drift tolerance).
///
/// Uses pre-converted f32 scales to avoid repeated f16 conversion.
/// Returns `false` if any group's max absolute value exceeds
/// `scale * qmax * drift_factor`.
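///
/// For example, with 8-bit scales fitted to an all-1.0 frame and the default
/// drift factor `1 + 26/256 ≈ 1.10`, a frame of all-1.05 values still fits
/// while a frame of all-2.0 values does not (see `test_drift_detection`).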
pub fn frame_fits_scales_f32(
frame: &[f32],
scales_f32: &[f32],
group_len: usize,
bits: u8,
drift_factor: f32,
) -> bool {
let qmax = qmax_from_bits(bits);
if qmax == 0 || scales_f32.is_empty() {
return false;
}
let qmax_f = qmax as f32;
for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
if group_idx >= scales_f32.len() {
return false;
}
let allowed = scales_f32[group_idx] * qmax_f * drift_factor;
for &v in chunk {
if v.is_finite() && v.abs() > allowed {
return false;
}
}
}
true
}
/// Quantize a frame using pre-computed f32 scales and pack into bitstream.
///
/// Appends packed bytes to `out`. Pre-reserves the expected output size
/// to avoid reallocations.
///
/// For 8-bit quantization, writes bytes directly without bit accumulation
/// since each quantized value maps 1:1 to a u8.
#[inline]
pub fn quantize_and_pack_f32(
frame: &[f32],
scales_f32: &[f32],
group_len: usize,
bits: u8,
out: &mut Vec<u8>,
) {
let qmax = qmax_from_bits(bits);
if qmax == 0 {
return;
}
// Fast path: 8-bit quantization writes bytes directly, no bit accumulator.
if bits == 8 {
out.reserve(frame.len());
for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
for &v in chunk {
let mut q: i32 = 0;
if v.is_finite() {
let scaled = v * inv_scale;
q = if scaled >= 0.0 {
(scaled + 0.5) as i32
} else {
(scaled - 0.5) as i32
};
q = q.clamp(-127, 127);
}
out.push((q + 127) as u8);
}
}
return;
}
// Fast path: 5-bit quantization packs 8 values into 5 bytes.
// 8 values * 5 bits = 40 bits = 5 bytes exactly, avoiding the bit accumulator.
// LSB-first packing layout for 8 values in 5 bytes:
// byte0 = v0 | (v1 << 5)
// byte1 = (v1 >> 3) | (v2 << 2) | (v3 << 7)
// byte2 = (v3 >> 1) | (v4 << 4)
// byte3 = (v4 >> 4) | (v5 << 1) | (v6 << 6)
// byte4 = (v6 >> 2) | (v7 << 3)
#[inline]
fn pack_5bit_group(chunk: &[f32], inv_scale: f32, out: &mut Vec<u8>) {
let quantize = |v: f32| -> u32 {
let mut q: i32 = 0;
if v.is_finite() {
let scaled = v * inv_scale;
q = if scaled >= 0.0 {
(scaled + 0.5) as i32
} else {
(scaled - 0.5) as i32
};
q = q.clamp(-15, 15);
}
(q + 15) as u32
};
let v0 = quantize(chunk[0]);
let v1 = quantize(chunk[1]);
let v2 = quantize(chunk[2]);
let v3 = quantize(chunk[3]);
let v4 = quantize(chunk[4]);
let v5 = quantize(chunk[5]);
let v6 = quantize(chunk[6]);
let v7 = quantize(chunk[7]);
out.push((v0 | (v1 << 5)) as u8);
out.push(((v1 >> 3) | (v2 << 2) | (v3 << 7)) as u8);
out.push(((v3 >> 1) | (v4 << 4)) as u8);
out.push(((v4 >> 4) | (v5 << 1) | (v6 << 6)) as u8);
out.push(((v6 >> 2) | (v7 << 3)) as u8);
}
if bits == 5 {
let needed_bytes = (frame.len() * 5).div_ceil(8);
out.reserve(needed_bytes);
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
let mut i = 0;
// Process 8 values at a time into 5 bytes when byte-aligned
while acc_bits == 0 && i + 8 <= chunk.len() {
pack_5bit_group(&chunk[i..i + 8], inv_scale, out);
i += 8;
}
// Remainder (or misaligned) with bit accumulator
while i < chunk.len() {
let mut q: i32 = 0;
if chunk[i].is_finite() {
let scaled = chunk[i] * inv_scale;
q = if scaled >= 0.0 {
(scaled + 0.5) as i32
} else {
(scaled - 0.5) as i32
};
q = q.clamp(-15, 15);
}
let u = (q + 15) as u32;
acc |= (u as u64) << acc_bits;
acc_bits += 5;
while acc_bits >= 8 {
out.push((acc & 0xFF) as u8);
acc >>= 8;
acc_bits -= 8;
}
i += 1;
}
}
if acc_bits > 0 {
out.push((acc & 0xFF) as u8);
}
return;
}
// Generic path for sub-byte bit widths.
let qmax_i = qmax;
let bias = qmax;
let bits_u32 = bits as u32;
let needed_bytes = (frame.len() * bits as usize).div_ceil(8);
out.reserve(needed_bytes);
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
for &v in chunk {
let mut q: i32 = 0;
if v.is_finite() {
let scaled = v * inv_scale;
q = if scaled >= 0.0 {
(scaled + 0.5) as i32
} else {
(scaled - 0.5) as i32
};
q = q.clamp(-qmax_i, qmax_i);
}
let u = (q + bias) as u32;
acc |= (u as u64) << acc_bits;
acc_bits += bits_u32;
while acc_bits >= 8 {
out.push((acc & 0xFF) as u8);
acc >>= 8;
acc_bits -= 8;
}
}
}
if acc_bits > 0 {
out.push((acc & 0xFF) as u8);
}
}
/// Dequantize packed codes using f32 scales, writing f32 values.
///
/// Iterates by frame then by group to avoid per-value modulo/division
/// and caches the f32 scale per group.
///
/// For 8-bit data, reads bytes directly without bit accumulation.
#[inline]
pub fn dequantize_f32(
data: &[u8],
scales_f32: &[f32],
group_len: usize,
bits: u8,
tensor_len: usize,
frame_count: usize,
out: &mut Vec<f32>,
) {
let qmax = qmax_from_bits(bits);
if qmax == 0 {
return;
}
let total = tensor_len * frame_count;
out.resize(total, 0.0);
// Fast path: 8-bit dequantization reads bytes directly, no bit accumulator.
if bits == 8 {
let mut out_idx = 0usize;
let mut byte_idx = 0usize;
for _frame in 0..frame_count {
let mut pos = 0usize;
let mut group_idx = 0usize;
while pos < tensor_len {
let group_end = (pos + group_len).min(tensor_len);
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
while pos < group_end && byte_idx < data.len() {
let u = data[byte_idx] as i32;
let q = u - 127;
out[out_idx] = (q as f32) * scale;
out_idx += 1;
byte_idx += 1;
pos += 1;
}
group_idx += 1;
}
}
return;
}
// Fast path: 3-bit dequantization processes 8 values from 3 bytes.
// 8 values * 3 bits = 24 bits = 3 bytes exactly, avoiding the bit accumulator.
// LSB-first packing layout for 8 values in 3 bytes:
// byte0 = v0 | (v1 << 3) | ((v2 & 0x3) << 6)
// byte1 = (v2 >> 2) | (v3 << 1) | (v4 << 4) | ((v5 & 0x1) << 7)
// byte2 = (v5 >> 1) | (v6 << 2) | (v7 << 5)
if bits == 3 {
let bias = 3i32; // qmax for 3-bit
let mut out_idx = 0usize;
let mut byte_idx = 0usize;
for _frame in 0..frame_count {
let mut pos = 0usize;
let mut group_idx = 0usize;
while pos < tensor_len {
let group_end = (pos + group_len).min(tensor_len);
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
// Process 8 values at a time from 3 bytes
while pos + 8 <= group_end && byte_idx + 3 <= data.len() {
let b0 = data[byte_idx] as u32;
let b1 = data[byte_idx + 1] as u32;
let b2 = data[byte_idx + 2] as u32;
byte_idx += 3;
out[out_idx] = ((b0 & 0x7) as i32 - bias) as f32 * scale;
out[out_idx + 1] = (((b0 >> 3) & 0x7) as i32 - bias) as f32 * scale;
out[out_idx + 2] =
((((b0 >> 6) | (b1 << 2)) & 0x7) as i32 - bias) as f32 * scale;
out[out_idx + 3] = (((b1 >> 1) & 0x7) as i32 - bias) as f32 * scale;
out[out_idx + 4] = (((b1 >> 4) & 0x7) as i32 - bias) as f32 * scale;
out[out_idx + 5] =
((((b1 >> 7) | (b2 << 1)) & 0x7) as i32 - bias) as f32 * scale;
out[out_idx + 6] = (((b2 >> 2) & 0x7) as i32 - bias) as f32 * scale;
out[out_idx + 7] = (((b2 >> 5) & 0x7) as i32 - bias) as f32 * scale;
out_idx += 8;
pos += 8;
}
// Handle remaining values (< 8) with a local bit accumulator
if pos < group_end {
let remaining = group_end - pos;
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
while acc_bits < (remaining as u32) * 3 && byte_idx < data.len() {
acc |= (data[byte_idx] as u64) << acc_bits;
acc_bits += 8;
byte_idx += 1;
}
for _ in 0..remaining {
if acc_bits < 3 {
break;
}
let u = (acc & 0x7) as i32;
acc >>= 3;
acc_bits -= 3;
out[out_idx] = (u - bias) as f32 * scale;
out_idx += 1;
pos += 1;
}
}
group_idx += 1;
}
}
return;
}
// Fast path: 7-bit dequantization processes 8 values from 7 bytes.
// 8 values * 7 bits = 56 bits = 7 bytes exactly, avoiding the bit accumulator.
// LSB-first packing layout for 8 values in 7 bytes:
// v0 = b0 & 0x7F
// v1 = ((b0 >> 7) | (b1 << 1)) & 0x7F
// v2 = ((b1 >> 6) | (b2 << 2)) & 0x7F
// v3 = ((b2 >> 5) | (b3 << 3)) & 0x7F
// v4 = ((b3 >> 4) | (b4 << 4)) & 0x7F
// v5 = ((b4 >> 3) | (b5 << 5)) & 0x7F
// v6 = ((b5 >> 2) | (b6 << 6)) & 0x7F
// v7 = (b6 >> 1) & 0x7F
if bits == 7 {
let bias = 63i32; // qmax for 7-bit
let mut out_idx = 0usize;
let mut byte_idx = 0usize;
for _frame in 0..frame_count {
let mut pos = 0usize;
let mut group_idx = 0usize;
while pos < tensor_len {
let group_end = (pos + group_len).min(tensor_len);
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
// Process 8 values at a time from 7 bytes
#[inline]
fn unpack_7bit(
out: &mut [f32],
out_idx: usize,
data: &[u8],
byte_idx: usize,
bias: i32,
scale: f32,
) {
let b0 = data[byte_idx] as u32;
let b1 = data[byte_idx + 1] as u32;
let b2 = data[byte_idx + 2] as u32;
let b3 = data[byte_idx + 3] as u32;
let b4 = data[byte_idx + 4] as u32;
let b5 = data[byte_idx + 5] as u32;
let b6 = data[byte_idx + 6] as u32;
out[out_idx] = ((b0 & 0x7F) as i32 - bias) as f32 * scale;
out[out_idx + 1] =
((((b0 >> 7) | (b1 << 1)) & 0x7F) as i32 - bias) as f32 * scale;
out[out_idx + 2] =
((((b1 >> 6) | (b2 << 2)) & 0x7F) as i32 - bias) as f32 * scale;
out[out_idx + 3] =
((((b2 >> 5) | (b3 << 3)) & 0x7F) as i32 - bias) as f32 * scale;
out[out_idx + 4] =
((((b3 >> 4) | (b4 << 4)) & 0x7F) as i32 - bias) as f32 * scale;
out[out_idx + 5] =
((((b4 >> 3) | (b5 << 5)) & 0x7F) as i32 - bias) as f32 * scale;
out[out_idx + 6] =
((((b5 >> 2) | (b6 << 6)) & 0x7F) as i32 - bias) as f32 * scale;
out[out_idx + 7] = (((b6 >> 1) & 0x7F) as i32 - bias) as f32 * scale;
}
while pos + 8 <= group_end && byte_idx + 7 <= data.len() {
unpack_7bit(out, out_idx, data, byte_idx, bias, scale);
byte_idx += 7;
out_idx += 8;
pos += 8;
}
// Handle remaining values (< 8) with a local bit accumulator
if pos < group_end {
let remaining = group_end - pos;
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
while acc_bits < (remaining as u32) * 7 && byte_idx < data.len() {
acc |= (data[byte_idx] as u64) << acc_bits;
acc_bits += 8;
byte_idx += 1;
}
for _ in 0..remaining {
if acc_bits < 7 {
break;
}
let u = (acc & 0x7F) as i32;
acc >>= 7;
acc_bits -= 7;
out[out_idx] = (u - bias) as f32 * scale;
out_idx += 1;
pos += 1;
}
}
group_idx += 1;
}
}
return;
}
// Fast path: 5-bit dequantization processes 8 values from 5 bytes.
// 8 values * 5 bits = 40 bits = 5 bytes exactly, avoiding the bit accumulator.
// LSB-first packing layout for 8 values in 5 bytes:
// v0 = b0 & 0x1F
// v1 = ((b0 >> 5) | (b1 << 3)) & 0x1F
// v2 = (b1 >> 2) & 0x1F
// v3 = ((b1 >> 7) | (b2 << 1)) & 0x1F
// v4 = ((b2 >> 4) | (b3 << 4)) & 0x1F
// v5 = (b3 >> 1) & 0x1F
// v6 = ((b3 >> 6) | (b4 << 2)) & 0x1F
// v7 = (b4 >> 3) & 0x1F
if bits == 5 {
let bias = 15i32; // qmax for 5-bit
let mut out_idx = 0usize;
let mut byte_idx = 0usize;
for _frame in 0..frame_count {
let mut pos = 0usize;
let mut group_idx = 0usize;
while pos < tensor_len {
let group_end = (pos + group_len).min(tensor_len);
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
// Process 8 values at a time from 5 bytes
#[inline]
fn unpack_5bit(
out: &mut [f32],
out_idx: usize,
data: &[u8],
byte_idx: usize,
bias: i32,
scale: f32,
) {
let b0 = data[byte_idx] as u32;
let b1 = data[byte_idx + 1] as u32;
let b2 = data[byte_idx + 2] as u32;
let b3 = data[byte_idx + 3] as u32;
let b4 = data[byte_idx + 4] as u32;
out[out_idx] = ((b0 & 0x1F) as i32 - bias) as f32 * scale;
out[out_idx + 1] =
((((b0 >> 5) | (b1 << 3)) & 0x1F) as i32 - bias) as f32 * scale;
out[out_idx + 2] = (((b1 >> 2) & 0x1F) as i32 - bias) as f32 * scale;
out[out_idx + 3] =
((((b1 >> 7) | (b2 << 1)) & 0x1F) as i32 - bias) as f32 * scale;
out[out_idx + 4] =
((((b2 >> 4) | (b3 << 4)) & 0x1F) as i32 - bias) as f32 * scale;
out[out_idx + 5] = (((b3 >> 1) & 0x1F) as i32 - bias) as f32 * scale;
out[out_idx + 6] =
((((b3 >> 6) | (b4 << 2)) & 0x1F) as i32 - bias) as f32 * scale;
out[out_idx + 7] = (((b4 >> 3) & 0x1F) as i32 - bias) as f32 * scale;
}
while pos + 8 <= group_end && byte_idx + 5 <= data.len() {
unpack_5bit(out, out_idx, data, byte_idx, bias, scale);
byte_idx += 5;
out_idx += 8;
pos += 8;
}
// Handle remaining values (< 8) with a local bit accumulator
if pos < group_end {
let remaining = group_end - pos;
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
while acc_bits < (remaining as u32) * 5 && byte_idx < data.len() {
acc |= (data[byte_idx] as u64) << acc_bits;
acc_bits += 8;
byte_idx += 1;
}
for _ in 0..remaining {
if acc_bits < 5 {
break;
}
let u = (acc & 0x1F) as i32;
acc >>= 5;
acc_bits -= 5;
out[out_idx] = (u - bias) as f32 * scale;
out_idx += 1;
pos += 1;
}
}
group_idx += 1;
}
}
return;
}
// Generic path for sub-byte bit widths.
let bias = qmax;
let bits_u32 = bits as u32;
let mask = (1u64 << bits_u32) - 1;
let mut acc: u64 = 0;
let mut acc_bits: u32 = 0;
let mut byte_idx = 0usize;
let mut out_idx = 0usize;
for _frame in 0..frame_count {
let mut pos = 0usize;
let mut group_idx = 0usize;
while pos < tensor_len {
let group_end = (pos + group_len).min(tensor_len);
let scale = if group_idx < scales_f32.len() {
scales_f32[group_idx]
} else {
0.0
};
while pos < group_end {
while acc_bits < bits_u32 && byte_idx < data.len() {
acc |= (data[byte_idx] as u64) << acc_bits;
acc_bits += 8;
byte_idx += 1;
}
if acc_bits < bits_u32 {
return;
}
let u = (acc & mask) as u32;
acc >>= bits_u32;
acc_bits -= bits_u32;
let q = (u as i32) - bias;
out[out_idx] = (q as f32) * scale;
out_idx += 1;
pos += 1;
}
group_idx += 1;
}
}
}
// --- Legacy API (delegates to f32 variants) ---
/// Check if a frame fits within existing f16 scales (within drift tolerance).
pub fn frame_fits_scales(
frame: &[f32],
scales: &[u16],
group_len: usize,
bits: u8,
drift_factor: f32,
) -> bool {
let scales_f32 = scales_to_f32(scales);
frame_fits_scales_f32(frame, &scales_f32, group_len, bits, drift_factor)
}
/// Quantize a frame using pre-computed f16 scales and pack into bitstream.
pub fn quantize_and_pack(
frame: &[f32],
scales: &[u16],
group_len: usize,
bits: u8,
out: &mut Vec<u8>,
) {
let scales_f32 = scales_to_f32(scales);
quantize_and_pack_f32(frame, &scales_f32, group_len, bits, out)
}
/// Dequantize packed codes using f16 scales, writing f32 values.
pub fn dequantize(
data: &[u8],
scales: &[u16],
group_len: usize,
bits: u8,
tensor_len: usize,
frame_count: usize,
out: &mut Vec<f32>,
) {
let scales_f32 = scales_to_f32(scales);
dequantize_f32(
data,
&scales_f32,
group_len,
bits,
tensor_len,
frame_count,
out,
)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_quantize_roundtrip_8bit() {
let frame: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.1).collect();
let scales = compute_scales(&frame, 64, 8);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 64, 8, &mut packed);
let mut decoded = Vec::new();
dequantize(&packed, &scales, 64, 8, frame.len(), 1, &mut decoded);
assert_eq!(decoded.len(), frame.len());
for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
let err = (orig - dec).abs();
let max_err = if orig.abs() > 0.01 {
orig.abs() * 0.02
} else {
0.1
};
assert!(err < max_err, "i={i}, orig={orig}, dec={dec}, err={err}");
}
}
#[test]
fn test_quantize_roundtrip_3bit() {
let frame: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.5).collect();
let scales = compute_scales(&frame, 64, 3);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 64, 3, &mut packed);
let mut decoded = Vec::new();
dequantize(&packed, &scales, 64, 3, frame.len(), 1, &mut decoded);
let max_val = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
for (&orig, &dec) in frame.iter().zip(decoded.iter()) {
let err = (orig - dec).abs();
assert!(err < max_val * 0.35, "orig={orig}, dec={dec}, err={err}");
}
}
#[test]
fn test_quantize_roundtrip_5bit() {
let frame: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
let scales = compute_scales(&frame, 64, 5);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 64, 5, &mut packed);
let mut decoded = Vec::new();
dequantize(&packed, &scales, 64, 5, frame.len(), 1, &mut decoded);
let max_val = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
for (&orig, &dec) in frame.iter().zip(decoded.iter()) {
let err = (orig - dec).abs();
assert!(err < max_val * 0.08, "orig={orig}, dec={dec}, err={err}");
}
}
#[test]
fn test_quantize_roundtrip_7bit() {
let frame: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
let scales = compute_scales(&frame, 64, 7);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 64, 7, &mut packed);
let mut decoded = Vec::new();
dequantize(&packed, &scales, 64, 7, frame.len(), 1, &mut decoded);
for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
let err = (orig - dec).abs();
let max_err = if orig.abs() > 0.01 {
orig.abs() * 0.02
} else {
0.1
};
assert!(err < max_err, "i={i}, orig={orig}, dec={dec}, err={err}");
}
}
#[test]
fn test_drift_detection() {
let frame1: Vec<f32> = vec![1.0; 64];
let frame2: Vec<f32> = vec![1.05; 64];
let frame3: Vec<f32> = vec![2.0; 64];
let scales = compute_scales(&frame1, 64, 8);
let drift_factor = 1.0 + 26.0 / 256.0;
assert!(frame_fits_scales(&frame2, &scales, 64, 8, drift_factor));
assert!(!frame_fits_scales(&frame3, &scales, 64, 8, drift_factor));
}
#[test]
fn test_zero_frame() {
let frame = vec![0.0f32; 128];
let scales = compute_scales(&frame, 64, 8);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 64, 8, &mut packed);
let mut decoded = Vec::new();
dequantize(&packed, &scales, 64, 8, 128, 1, &mut decoded);
for &v in &decoded {
assert_eq!(v, 0.0);
}
}
#[test]
fn test_non_finite_values() {
let mut frame = vec![1.0f32; 64];
frame[10] = f32::NAN;
frame[20] = f32::INFINITY;
frame[30] = f32::NEG_INFINITY;
let scales = compute_scales(&frame, 64, 8);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 64, 8, &mut packed);
let mut decoded = Vec::new();
dequantize(&packed, &scales, 64, 8, 64, 1, &mut decoded);
assert_eq!(decoded[10], 0.0);
assert_eq!(decoded[20], 0.0);
assert_eq!(decoded[30], 0.0);
assert!((decoded[0] - 1.0).abs() < 0.02);
}
#[test]
fn test_single_element_group() {
let frame = vec![3.14f32; 16];
let scales = compute_scales(&frame, 1, 8);
assert_eq!(scales.len(), 16);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 1, 8, &mut packed);
let mut decoded = Vec::new();
dequantize(&packed, &scales, 1, 8, 16, 1, &mut decoded);
for (i, &v) in decoded.iter().enumerate() {
let err = (v - 3.14).abs();
assert!(err < 0.03, "i={i} v={v} err={err}");
}
}
#[test]
fn test_compression_ratio() {
let frame = vec![1.0f32; 512];
for &(bits, min_ratio) in &[(8u8, 3.5f32), (7, 4.0), (5, 5.5), (3, 8.5)] {
let scales = compute_scales(&frame, 64, bits);
let mut packed = Vec::new();
quantize_and_pack(&frame, &scales, 64, bits, &mut packed);
let raw_bytes = frame.len() * 4;
let compressed = packed.len() + scales.len() * 2;
let ratio = raw_bytes as f32 / compressed as f32;
assert!(
ratio >= min_ratio,
"bits={bits}: ratio {ratio:.2}x < expected {min_ratio}x"
);
}
}
}

View File

@@ -0,0 +1,335 @@
//! Segment binary format: encode and decode.
//!
//! Format (little-endian):
//!
//! ```text
//! [magic:4][version:1][bits:1][group_len:4][tensor_len:4][frames:4]
//! [scale_count:4][scales:2*S][data_len:4][data:D]
//! ```
//!
//! Magic: `0x43545154` (the ASCII bytes `"TQTC"` read as a little-endian u32).
//! The fixed fields before the scales total 22 bytes; `HEADER_SIZE` (26) also
//! counts `data_len` and is the minimum valid segment length.
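//!
//! # Example
//!
//! A minimal encode/inspect sketch (marked `ignore`; uses this module plus
//! the sibling `quantizer` module, mirroring the unit tests below):
//!
//! ```ignore
//! let frame = vec![1.0f32; 64];
//! let scales = quantizer::compute_scales(&frame, 64, 7);
//! let mut packed = Vec::new();
//! quantizer::quantize_and_pack(&frame, &scales, 64, 7, &mut packed);
//! let mut seg = Vec::new();
//! encode(7, 64, 64, 1, &scales, &packed, &mut seg);
//! let header = parse_header(&seg).unwrap();
//! assert_eq!((header.bits, header.frame_count), (7, 1));
//! ```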
use crate::quantizer;
/// Segment magic number: `"TQTC"` in little-endian.
pub const MAGIC: u32 = 0x4354_5154;
/// Current segment format version.
pub const VERSION: u8 = 1;
/// Minimum valid segment size in bytes (header fields + data_len, no scales/data).
pub const HEADER_SIZE: usize = 26;
/// Encode a segment from metadata, scales, and packed data.
pub fn encode(
bits: u8,
group_len: u32,
tensor_len: u32,
frame_count: u32,
scales: &[u16],
data: &[u8],
out: &mut Vec<u8>,
) {
out.clear();
let estimated = HEADER_SIZE + scales.len() * 2 + data.len();
out.reserve(estimated);
// Header
out.extend_from_slice(&MAGIC.to_le_bytes());
out.push(VERSION);
out.push(bits);
out.extend_from_slice(&group_len.to_le_bytes());
out.extend_from_slice(&tensor_len.to_le_bytes());
out.extend_from_slice(&frame_count.to_le_bytes());
// Scales
let scale_count = scales.len() as u32;
out.extend_from_slice(&scale_count.to_le_bytes());
for &s in scales {
out.extend_from_slice(&s.to_le_bytes());
}
// Data
let data_len = data.len() as u32;
out.extend_from_slice(&data_len.to_le_bytes());
out.extend_from_slice(data);
}
/// Decoded segment header.
#[derive(Debug, Clone)]
pub struct SegmentHeader {
pub bits: u8,
pub group_len: u32,
pub tensor_len: u32,
pub frame_count: u32,
pub scale_count: u32,
}
/// Decode a segment, returning all frames as f32 values.
pub fn decode(segment: &[u8], out: &mut Vec<f32>) {
out.clear();
if segment.len() < HEADER_SIZE {
return;
}
let mut off = 0;
let magic = read_u32_le(segment, &mut off);
if magic != MAGIC {
return;
}
let version = segment[off];
off += 1;
if version != VERSION {
return;
}
let bits = segment[off];
off += 1;
let group_len = read_u32_le(segment, &mut off);
let tensor_len = read_u32_le(segment, &mut off);
let frame_count = read_u32_le(segment, &mut off);
let scale_count = read_u32_le(segment, &mut off);
// Read scales
let scales_end = off + (scale_count as usize) * 2;
if scales_end > segment.len() {
return;
}
let mut scales = Vec::with_capacity(scale_count as usize);
for _ in 0..scale_count {
scales.push(read_u16_le(segment, &mut off));
}
// Read data
if off + 4 > segment.len() {
return;
}
let data_len = read_u32_le(segment, &mut off) as usize;
if off + data_len > segment.len() {
return;
}
let data = &segment[off..off + data_len];
// Convert scales to f32 once, then dequantize via the optimized path
let scales_f32 = quantizer::scales_to_f32(&scales);
quantizer::dequantize_f32(
data,
&scales_f32,
group_len as usize,
bits,
tensor_len as usize,
frame_count as usize,
out,
);
}
/// Parse only the segment header (no data decoding).
pub fn parse_header(segment: &[u8]) -> Option<SegmentHeader> {
if segment.len() < HEADER_SIZE {
return None;
}
let mut off = 0;
let magic = read_u32_le(segment, &mut off);
if magic != MAGIC {
return None;
}
let version = segment[off];
off += 1;
if version != VERSION {
return None;
}
let bits = segment[off];
off += 1;
let group_len = read_u32_le(segment, &mut off);
let tensor_len = read_u32_le(segment, &mut off);
let frame_count = read_u32_le(segment, &mut off);
let scale_count = read_u32_le(segment, &mut off);
Some(SegmentHeader {
bits,
group_len,
tensor_len,
frame_count,
scale_count,
})
}
/// Compute the compression ratio for a segment: raw f32 bytes / segment bytes.
///
/// Returns `0.0` if the segment is empty or has no frames.
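///
/// Worked example: 512 f32 values (2048 raw bytes) at 3 bits with
/// `group_len = 64` pack into 192 data bytes + 16 scale bytes + 26 header
/// bytes = 234 bytes, a ratio of roughly 8.75x.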
pub fn compression_ratio(segment: &[u8]) -> f32 {
match parse_header(segment) {
Some(h) if h.frame_count > 0 => {
let raw = h.tensor_len as usize * h.frame_count as usize * 4;
raw as f32 / segment.len() as f32
}
_ => 0.0,
}
}
/// Decode a single frame by index from a segment.
///
/// Returns `None` if the segment is invalid or `frame_idx` is out of range.
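///
/// Frames are assumed byte-aligned, each occupying
/// `ceil(tensor_len * bits / 8)` bytes: e.g. `tensor_len = 64` at 5 bits is
/// 320 bits = 40 bytes per frame, so frame 2 starts at byte offset 80.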
pub fn decode_single_frame(segment: &[u8], frame_idx: usize) -> Option<Vec<f32>> {
let header = parse_header(segment)?;
if frame_idx >= header.frame_count as usize {
return None;
}
// Skip past the fixed header fields (magic + version + bits + group_len +
// tensor_len + frame_count + scale_count = 4+1+1+4+4+4+4 = 22 bytes).
let mut off = 22usize;
let scale_count = header.scale_count as usize;
// Read scales
let scales_end = off + scale_count * 2;
if scales_end > segment.len() {
return None;
}
let mut scales_f16 = Vec::with_capacity(scale_count);
for _ in 0..scale_count {
scales_f16.push(read_u16_le(segment, &mut off));
}
let scales_f32 = quantizer::scales_to_f32(&scales_f16);
// Read data section
if off + 4 > segment.len() {
return None;
}
let data_len = read_u32_le(segment, &mut off) as usize;
if off + data_len > segment.len() {
return None;
}
let data = &segment[off..off + data_len];
// Compute byte offset for the requested frame
let tensor_len = header.tensor_len as usize;
let bits = header.bits;
let bits_per_frame = tensor_len * bits as usize;
let bytes_per_frame = bits_per_frame.div_ceil(8);
let frame_start = frame_idx * bytes_per_frame;
if frame_start + bytes_per_frame > data.len() {
return None;
}
let frame_data = &data[frame_start..frame_start + bytes_per_frame];
let mut out = Vec::new();
quantizer::dequantize_f32(
frame_data,
&scales_f32,
header.group_len as usize,
bits,
tensor_len,
1,
&mut out,
);
Some(out)
}
#[inline]
fn read_u32_le(bytes: &[u8], offset: &mut usize) -> u32 {
let o = *offset;
let arr = [bytes[o], bytes[o + 1], bytes[o + 2], bytes[o + 3]];
*offset = o + 4;
u32::from_le_bytes(arr)
}
#[inline]
fn read_u16_le(bytes: &[u8], offset: &mut usize) -> u16 {
let o = *offset;
let arr = [bytes[o], bytes[o + 1]];
*offset = o + 2;
u16::from_le_bytes(arr)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::quantizer;
#[test]
fn test_encode_decode_roundtrip() {
let frame: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.1).collect();
let group_len = 64usize;
let bits = 8u8;
let scales = quantizer::compute_scales(&frame, group_len, bits);
let mut packed = Vec::new();
quantizer::quantize_and_pack(&frame, &scales, group_len, bits, &mut packed);
let mut seg = Vec::new();
encode(
bits,
group_len as u32,
frame.len() as u32,
1,
&scales,
&packed,
&mut seg,
);
let mut decoded = Vec::new();
decode(&seg, &mut decoded);
assert_eq!(decoded.len(), frame.len());
for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
let err = (orig - dec).abs();
assert!(err < 0.1, "i={i} orig={orig} dec={dec} err={err}");
}
}
#[test]
fn test_magic_validation() {
let mut decoded = Vec::new();
decode(&[0, 0, 0, 0], &mut decoded);
        assert!(decoded.is_empty()); // Too short for a header (and wrong magic).
}
#[test]
fn test_parse_header() {
let frame = vec![1.0f32; 64];
let scales = quantizer::compute_scales(&frame, 64, 7);
let mut packed = Vec::new();
quantizer::quantize_and_pack(&frame, &scales, 64, 7, &mut packed);
let mut seg = Vec::new();
encode(7, 64, 64, 1, &scales, &packed, &mut seg);
let header = parse_header(&seg).unwrap();
assert_eq!(header.bits, 7);
assert_eq!(header.group_len, 64);
assert_eq!(header.tensor_len, 64);
assert_eq!(header.frame_count, 1);
}
#[test]
fn test_multi_frame_roundtrip() {
let group_len = 32usize;
let bits = 5u8;
let tensor_len = 64;
let frame1: Vec<f32> = (0..tensor_len).map(|i| (i as f32) * 0.1).collect();
let frame2: Vec<f32> = (0..tensor_len).map(|i| (i as f32) * 0.09).collect();
let scales = quantizer::compute_scales(&frame1, group_len, bits);
let mut packed = Vec::new();
quantizer::quantize_and_pack(&frame1, &scales, group_len, bits, &mut packed);
quantizer::quantize_and_pack(&frame2, &scales, group_len, bits, &mut packed);
let mut seg = Vec::new();
encode(
bits,
group_len as u32,
tensor_len as u32,
2,
&scales,
&packed,
&mut seg,
);
let mut decoded = Vec::new();
decode(&seg, &mut decoded);
assert_eq!(decoded.len(), tensor_len * 2);
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,877 @@
//! WASM/C FFI for the block-based temporal tensor store (ADR-022).
//!
//! Exports `extern "C"` functions prefixed with `tts_` for:
//! - Store lifecycle (`tts_init`)
//! - Block ingest and read (`tts_put`, `tts_get`)
//! - Access tracking (`tts_touch`)
//! - Maintenance (`tts_tick`, `tts_evict`)
//! - Statistics (`tts_stats`, `tts_block_count`, `tts_tier_count`)
//!
//! Coexists with `ffi.rs` which exports `ttc_*` functions for the
//! frame-based compressor.
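//!
//! # Example
//!
//! The exports are plain Rust functions, so an in-process sketch looks like
//! this (marked `ignore`; error handling elided):
//!
//! ```ignore
//! assert_eq!(tts_init(std::ptr::null(), 0), 0); // default TierConfig
//! let data = vec![1.0f32; 64];
//! assert_eq!(tts_put(0, 1, 0, data.as_ptr(), data.len()), 0);
//! let mut out = vec![0.0f32; 64];
//! assert_eq!(tts_get(0, 1, 0, out.as_mut_ptr(), out.len()), 64);
//! ```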
use std::collections::HashMap;
use crate::quantizer;
use crate::segment;
// ── Error codes ──────────────────────────────────────────────────────
#[allow(dead_code)]
const ERR_NOT_INITIALIZED: i32 = -1;
const ERR_NULL_POINTER: i32 = -2;
const ERR_INVALID_CONFIG: i32 = -3;
const ERR_BLOCK_NOT_FOUND: i32 = -4;
const ERR_BUFFER_TOO_SMALL: i32 = -5;
const ERR_EMPTY_DATA: i32 = -6;
// ── Types ────────────────────────────────────────────────────────────
// These mirror the types defined in store.rs and tiering.rs which are
// being written in parallel. Once those modules land, these can be
// replaced with `use crate::store::*` / `use crate::tiering::*`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct BlockKey {
tensor_id: u128,
block_index: u32,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum Tier {
Hot = 0,
Warm = 1,
Cool = 2,
Cold = 3,
}
impl Tier {
fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(Tier::Hot),
1 => Some(Tier::Warm),
2 => Some(Tier::Cool),
3 => Some(Tier::Cold),
_ => None,
}
}
/// Quantization bit-width for this tier.
fn bits(self) -> u8 {
match self {
Tier::Hot => 8,
Tier::Warm => 7,
Tier::Cool => 5,
Tier::Cold => 3,
}
}
}
#[derive(Clone, Debug)]
struct BlockMeta {
tier: Tier,
access_count: u32,
last_access_ts: u64,
ema_score: f32,
    /// Original f32 count; kept for future re-tiering use. Currently unused:
    /// `tts_tick` sizes its decode buffer from the segment header instead.
#[allow(dead_code)]
element_count: usize,
}
/// Binary config layout (little-endian, 45 bytes):
/// ```text
/// [block_bytes:u32][alpha:f32][tau:f32][w_ema:f32][w_pop:f32][w_rec:f32]
/// [t1:f32][t2:f32][t3:f32][hysteresis:f32][min_residency:u32][max_delta_chain:u8]
/// ```
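/// Total: 4 + 9 * 4 + 4 + 1 = 45 bytes (`CONFIG_BINARY_LEN`).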
#[derive(Clone, Debug)]
struct TierConfig {
block_bytes: u32,
alpha: f32,
tau: f32,
w_ema: f32,
w_pop: f32,
w_rec: f32,
t1: f32,
t2: f32,
t3: f32,
hysteresis: f32,
min_residency: u32,
max_delta_chain: u8,
}
const CONFIG_BINARY_LEN: usize = 45;
impl Default for TierConfig {
fn default() -> Self {
Self {
block_bytes: 4096,
alpha: 0.3,
tau: 100.0,
w_ema: 0.5,
w_pop: 0.3,
w_rec: 0.2,
t1: 0.8,
t2: 0.5,
t3: 0.2,
hysteresis: 0.05,
min_residency: 10,
max_delta_chain: 4,
}
}
}
impl TierConfig {
fn from_bytes(bytes: &[u8]) -> Option<Self> {
if bytes.len() < CONFIG_BINARY_LEN {
return None;
}
let mut off = 0usize;
let block_bytes = read_u32_le(bytes, &mut off);
let alpha = read_f32_le(bytes, &mut off);
let tau = read_f32_le(bytes, &mut off);
let w_ema = read_f32_le(bytes, &mut off);
let w_pop = read_f32_le(bytes, &mut off);
let w_rec = read_f32_le(bytes, &mut off);
let t1 = read_f32_le(bytes, &mut off);
let t2 = read_f32_le(bytes, &mut off);
let t3 = read_f32_le(bytes, &mut off);
let hysteresis = read_f32_le(bytes, &mut off);
let min_residency = read_u32_le(bytes, &mut off);
let max_delta_chain = bytes[off];
if ![alpha, tau, w_ema, w_pop, w_rec, t1, t2, t3, hysteresis]
.iter()
.all(|v| v.is_finite())
{
return None;
}
Some(Self {
block_bytes,
alpha,
tau,
w_ema,
w_pop,
w_rec,
t1,
t2,
t3,
hysteresis,
min_residency,
max_delta_chain,
})
}
}
// ── Store ────────────────────────────────────────────────────────────
struct TieredStore {
blocks: HashMap<BlockKey, (BlockMeta, Vec<u8>)>,
}
impl TieredStore {
fn new() -> Self {
Self {
blocks: HashMap::new(),
}
}
fn block_count(&self) -> usize {
self.blocks.len()
}
fn tier_count(&self, tier: Tier) -> usize {
self.blocks.values().filter(|(m, _)| m.tier == tier).count()
}
fn total_bytes(&self) -> usize {
self.blocks.values().map(|(_, d)| d.len()).sum()
}
}
// ── Global state ─────────────────────────────────────────────────────
struct StoreState {
store: TieredStore,
config: TierConfig,
tick_count: u64,
}
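// SAFETY assumption: this FFI surface is treated as single-threaded (e.g.
// wasm32 without threads), so unsynchronized access to this global is
// tolerable; a threaded host would need a `Mutex` or `OnceLock` instead.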
static mut STORE_STATE: Option<StoreState> = None;
// ── Helpers ──────────────────────────────────────────────────────────
/// Combine hi/lo u64 into u128 tensor_id.
#[inline]
fn make_tensor_id(hi: u64, lo: u64) -> u128 {
((hi as u128) << 64) | (lo as u128)
}
/// Access the global store state, initializing with defaults if needed.
fn with_state<F, R>(f: F) -> R
where
F: FnOnce(&mut StoreState) -> R,
{
unsafe {
if STORE_STATE.is_none() {
STORE_STATE = Some(StoreState {
store: TieredStore::new(),
config: TierConfig::default(),
tick_count: 0,
});
}
f(STORE_STATE.as_mut().unwrap())
}
}
const DEFAULT_GROUP_LEN: usize = 64;
/// Composite access score used for tier selection.
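///
/// `score = w_ema * ema_score + w_pop * ln(1 + access_count) + w_rec * recency`
/// where `recency = exp(-(tick - last_access_ts) / tau)`.
///
/// Worked example with the default config (weights 0.5/0.3/0.2): a block with
/// `ema_score = 1.0` and one access on the current tick scores
/// `0.5 + 0.3 * ln(2) + 0.2 ≈ 0.91`, which `choose_tier` maps to Hot (t1 = 0.8).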
fn compute_score(config: &TierConfig, meta: &BlockMeta, tick: u64) -> f32 {
let recency = if tick > meta.last_access_ts {
(-((tick - meta.last_access_ts) as f32) / config.tau).exp()
} else {
1.0
};
let popularity = (meta.access_count as f32).ln_1p();
config.w_ema * meta.ema_score + config.w_pop * popularity + config.w_rec * recency
}
/// Map a score to a tier using the config thresholds.
fn choose_tier(config: &TierConfig, score: f32) -> Tier {
if score >= config.t1 {
Tier::Hot
} else if score >= config.t2 {
Tier::Warm
} else if score >= config.t3 {
Tier::Cool
} else {
Tier::Cold
}
}
/// Quantize f32 data and encode into a compressed segment.
fn encode_block(data: &[f32], tier: Tier) -> Vec<u8> {
let bits = tier.bits();
let group_len = DEFAULT_GROUP_LEN;
let scales = quantizer::compute_scales(data, group_len, bits);
let mut packed = Vec::new();
quantizer::quantize_and_pack(data, &scales, group_len, bits, &mut packed);
let mut seg = Vec::new();
segment::encode(
bits,
group_len as u32,
data.len() as u32,
1,
&scales,
&packed,
&mut seg,
);
seg
}
/// Decode a compressed segment back to f32.
fn decode_block(seg: &[u8]) -> Vec<f32> {
let mut out = Vec::new();
segment::decode(seg, &mut out);
out
}
#[inline]
fn read_u32_le(bytes: &[u8], off: &mut usize) -> u32 {
let o = *off;
let arr = [bytes[o], bytes[o + 1], bytes[o + 2], bytes[o + 3]];
*off = o + 4;
u32::from_le_bytes(arr)
}
#[inline]
fn read_f32_le(bytes: &[u8], off: &mut usize) -> f32 {
f32::from_bits(read_u32_le(bytes, off))
}
#[inline]
fn write_u32_le(buf: &mut [u8], off: &mut usize, v: u32) {
buf[*off..*off + 4].copy_from_slice(&v.to_le_bytes());
*off += 4;
}
#[inline]
fn write_u64_le(buf: &mut [u8], off: &mut usize, v: u64) {
buf[*off..*off + 8].copy_from_slice(&v.to_le_bytes());
*off += 8;
}
/// Stats binary layout (36 bytes, little-endian):
/// ```text
/// [block_count:u32][hot:u32][warm:u32][cool:u32][cold:u32]
/// [total_bytes:u64][tick_count:u64]
/// ```
const STATS_SIZE: usize = 5 * 4 + 2 * 8;
// ── FFI exports ──────────────────────────────────────────────────────
/// Initialize the temporal tensor store with a serialized config.
/// If `policy_ptr` is null or `policy_len` is 0, uses `TierConfig::default()`.
/// Returns 0 on success, negative on error.
#[no_mangle]
pub extern "C" fn tts_init(policy_ptr: *const u8, policy_len: usize) -> i32 {
let config = if policy_ptr.is_null() || policy_len == 0 {
TierConfig::default()
} else {
let bytes = unsafe { std::slice::from_raw_parts(policy_ptr, policy_len) };
match TierConfig::from_bytes(bytes) {
Some(c) => c,
None => return ERR_INVALID_CONFIG,
}
};
unsafe {
STORE_STATE = Some(StoreState {
store: TieredStore::new(),
config,
tick_count: 0,
});
}
0
}
/// Store a tensor block. Quantizes according to the block's current tier
/// (or Hot for new blocks). `tensor_id` is split into hi/lo because WASM
/// does not support u128.
/// Returns 0 on success, negative on error.
#[no_mangle]
pub extern "C" fn tts_put(
tensor_id_hi: u64,
tensor_id_lo: u64,
block_index: u32,
data_ptr: *const f32,
data_len: usize,
) -> i32 {
if data_ptr.is_null() {
return ERR_NULL_POINTER;
}
if data_len == 0 {
return ERR_EMPTY_DATA;
}
let data = unsafe { std::slice::from_raw_parts(data_ptr, data_len) };
let key = BlockKey {
tensor_id: make_tensor_id(tensor_id_hi, tensor_id_lo),
block_index,
};
with_state(|state| {
let tier = state
.store
.blocks
.get(&key)
.map(|(m, _)| m.tier)
.unwrap_or(Tier::Hot);
let seg = encode_block(data, tier);
let meta = BlockMeta {
tier,
access_count: 1,
last_access_ts: state.tick_count,
ema_score: 1.0,
element_count: data_len,
};
state.store.blocks.insert(key, (meta, seg));
0
})
}
/// Read a tensor block, dequantized to f32.
/// Returns the number of f32 elements written, or negative on error.
#[no_mangle]
pub extern "C" fn tts_get(
tensor_id_hi: u64,
tensor_id_lo: u64,
block_index: u32,
out_ptr: *mut f32,
out_len: usize,
) -> i32 {
if out_ptr.is_null() {
return ERR_NULL_POINTER;
}
let key = BlockKey {
tensor_id: make_tensor_id(tensor_id_hi, tensor_id_lo),
block_index,
};
with_state(|state| match state.store.blocks.get(&key) {
None => ERR_BLOCK_NOT_FOUND,
Some((_meta, seg)) => {
let decoded = decode_block(seg);
if decoded.len() > out_len {
return ERR_BUFFER_TOO_SMALL;
}
let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
out[..decoded.len()].copy_from_slice(&decoded);
decoded.len() as i32
}
})
}
/// Run a maintenance tick with byte and operation budgets.
/// Re-scores every block and migrates those whose tier has changed,
/// subject to hysteresis.
/// Returns number of migration operations performed, or negative on error.
#[no_mangle]
pub extern "C" fn tts_tick(budget_bytes: u32, budget_ops: u32) -> i32 {
with_state(|state| {
state.tick_count += 1;
let tick = state.tick_count;
// Snapshot keys and scores so we can mutate blocks afterwards.
let entries: Vec<(BlockKey, f32)> = state
.store
.blocks
.iter()
.map(|(k, (m, _))| (*k, compute_score(&state.config, m, tick)))
.collect();
let mut ops = 0u32;
let mut bytes_used = 0u32;
for (key, score) in entries {
if ops >= budget_ops || bytes_used >= budget_bytes {
break;
}
if let Some((meta, seg)) = state.store.blocks.get_mut(&key) {
let new_tier = choose_tier(&state.config, score);
let current_threshold = match meta.tier {
Tier::Hot => state.config.t1,
Tier::Warm => state.config.t2,
Tier::Cool => state.config.t3,
Tier::Cold => 0.0,
};
let needs_change = new_tier != meta.tier
&& (score - current_threshold).abs() > state.config.hysteresis;
if needs_change {
let decoded = decode_block(seg);
if !decoded.is_empty() {
let new_seg = encode_block(&decoded, new_tier);
bytes_used = bytes_used.saturating_add(new_seg.len() as u32);
*seg = new_seg;
meta.tier = new_tier;
ops += 1;
}
}
// Update EMA for every block regardless of migration.
meta.ema_score =
state.config.alpha * score + (1.0 - state.config.alpha) * meta.ema_score;
}
}
ops as i32
})
}
/// Write a statistics snapshot to `out_ptr`.
/// Returns number of bytes written, or negative on error.
#[no_mangle]
pub extern "C" fn tts_stats(out_ptr: *mut u8, out_len: usize) -> i32 {
if out_ptr.is_null() {
return ERR_NULL_POINTER;
}
if out_len < STATS_SIZE {
return ERR_BUFFER_TOO_SMALL;
}
with_state(|state| {
let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
let mut off = 0usize;
write_u32_le(out, &mut off, state.store.block_count() as u32);
write_u32_le(out, &mut off, state.store.tier_count(Tier::Hot) as u32);
write_u32_le(out, &mut off, state.store.tier_count(Tier::Warm) as u32);
write_u32_le(out, &mut off, state.store.tier_count(Tier::Cool) as u32);
write_u32_le(out, &mut off, state.store.tier_count(Tier::Cold) as u32);
write_u64_le(out, &mut off, state.store.total_bytes() as u64);
write_u64_le(out, &mut off, state.tick_count);
STATS_SIZE as i32
})
}
/// Record an access event for a block (increments count, updates timestamp).
/// Returns 0 on success, negative on error.
#[no_mangle]
pub extern "C" fn tts_touch(tensor_id_hi: u64, tensor_id_lo: u64, block_index: u32) -> i32 {
let key = BlockKey {
tensor_id: make_tensor_id(tensor_id_hi, tensor_id_lo),
block_index,
};
with_state(|state| match state.store.blocks.get_mut(&key) {
None => ERR_BLOCK_NOT_FOUND,
Some((meta, _)) => {
meta.access_count = meta.access_count.saturating_add(1);
meta.last_access_ts = state.tick_count;
0
}
})
}
/// Evict a block, removing it from the store entirely.
/// Returns 0 on success, negative on error.
#[no_mangle]
pub extern "C" fn tts_evict(tensor_id_hi: u64, tensor_id_lo: u64, block_index: u32) -> i32 {
let key = BlockKey {
tensor_id: make_tensor_id(tensor_id_hi, tensor_id_lo),
block_index,
};
with_state(|state| match state.store.blocks.remove(&key) {
None => ERR_BLOCK_NOT_FOUND,
Some(_) => 0,
})
}
/// Get total number of blocks in the store.
#[no_mangle]
pub extern "C" fn tts_block_count() -> i32 {
with_state(|state| state.store.block_count() as i32)
}
/// Get number of blocks in a specific tier (0=Hot, 1=Warm, 2=Cool, 3=Cold).
#[no_mangle]
pub extern "C" fn tts_tier_count(tier: u8) -> i32 {
match Tier::from_u8(tier) {
Some(t) => with_state(|state| state.store.tier_count(t) as i32),
None => ERR_INVALID_CONFIG,
}
}
// ── Tests ────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
/// Reset global state before each test.
fn reset() {
unsafe {
STORE_STATE = None;
}
}
/// Build a binary config buffer from the default TierConfig.
fn default_config_bytes() -> Vec<u8> {
let c = TierConfig::default();
let mut buf = Vec::with_capacity(CONFIG_BINARY_LEN);
buf.extend_from_slice(&c.block_bytes.to_le_bytes());
buf.extend_from_slice(&c.alpha.to_bits().to_le_bytes());
buf.extend_from_slice(&c.tau.to_bits().to_le_bytes());
buf.extend_from_slice(&c.w_ema.to_bits().to_le_bytes());
buf.extend_from_slice(&c.w_pop.to_bits().to_le_bytes());
buf.extend_from_slice(&c.w_rec.to_bits().to_le_bytes());
buf.extend_from_slice(&c.t1.to_bits().to_le_bytes());
buf.extend_from_slice(&c.t2.to_bits().to_le_bytes());
buf.extend_from_slice(&c.t3.to_bits().to_le_bytes());
buf.extend_from_slice(&c.hysteresis.to_bits().to_le_bytes());
buf.extend_from_slice(&c.min_residency.to_le_bytes());
buf.push(c.max_delta_chain);
buf
}
#[test]
fn test_init_default() {
reset();
let rc = tts_init(std::ptr::null(), 0);
assert_eq!(rc, 0);
assert_eq!(tts_block_count(), 0);
}
#[test]
fn test_init_with_config() {
reset();
let cfg = default_config_bytes();
let rc = tts_init(cfg.as_ptr(), cfg.len());
assert_eq!(rc, 0);
assert_eq!(tts_block_count(), 0);
}
#[test]
fn test_init_invalid_config_too_short() {
reset();
let buf = [0u8; 10];
let rc = tts_init(buf.as_ptr(), buf.len());
assert_eq!(rc, ERR_INVALID_CONFIG);
}
#[test]
fn test_put_get_roundtrip() {
reset();
tts_init(std::ptr::null(), 0);
let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
let rc = tts_put(0, 1, 0, data.as_ptr(), data.len());
assert_eq!(rc, 0);
let mut out = vec![0.0f32; 64];
let n = tts_get(0, 1, 0, out.as_mut_ptr(), out.len());
assert_eq!(n, 64);
// 8-bit quantization: expect low error.
let max_abs = data.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
for (i, (&orig, &dec)) in data.iter().zip(out.iter()).enumerate() {
let err = (orig - dec).abs();
assert!(
err < max_abs * 0.05,
"i={i} orig={orig} dec={dec} err={err}"
);
}
}
#[test]
fn test_put_null_pointer() {
reset();
tts_init(std::ptr::null(), 0);
let rc = tts_put(0, 1, 0, std::ptr::null(), 64);
assert_eq!(rc, ERR_NULL_POINTER);
}
#[test]
fn test_put_empty_data() {
reset();
tts_init(std::ptr::null(), 0);
let data = [1.0f32; 1];
let rc = tts_put(0, 1, 0, data.as_ptr(), 0);
assert_eq!(rc, ERR_EMPTY_DATA);
}
#[test]
fn test_get_not_found() {
reset();
tts_init(std::ptr::null(), 0);
let mut out = vec![0.0f32; 64];
let rc = tts_get(0, 99, 0, out.as_mut_ptr(), out.len());
assert_eq!(rc, ERR_BLOCK_NOT_FOUND);
}
#[test]
fn test_get_null_pointer() {
reset();
tts_init(std::ptr::null(), 0);
let rc = tts_get(0, 1, 0, std::ptr::null_mut(), 64);
assert_eq!(rc, ERR_NULL_POINTER);
}
#[test]
fn test_get_buffer_too_small() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
let mut out = vec![0.0f32; 2]; // too small
let rc = tts_get(0, 1, 0, out.as_mut_ptr(), out.len());
assert_eq!(rc, ERR_BUFFER_TOO_SMALL);
}
#[test]
fn test_block_count_after_puts() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
tts_put(0, 1, 1, data.as_ptr(), data.len());
tts_put(0, 2, 0, data.as_ptr(), data.len());
assert_eq!(tts_block_count(), 3);
}
#[test]
fn test_tier_count_initial() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
tts_put(0, 1, 1, data.as_ptr(), data.len());
// New blocks default to Hot.
assert_eq!(tts_tier_count(0), 2); // Hot
assert_eq!(tts_tier_count(1), 0); // Warm
assert_eq!(tts_tier_count(2), 0); // Cool
assert_eq!(tts_tier_count(3), 0); // Cold
}
#[test]
fn test_tier_count_invalid_tier() {
reset();
tts_init(std::ptr::null(), 0);
assert_eq!(tts_tier_count(99), ERR_INVALID_CONFIG);
}
#[test]
fn test_touch() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
let rc = tts_touch(0, 1, 0);
assert_eq!(rc, 0);
// Touch a non-existent block.
let rc = tts_touch(0, 99, 0);
assert_eq!(rc, ERR_BLOCK_NOT_FOUND);
}
#[test]
fn test_evict() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
assert_eq!(tts_block_count(), 1);
let rc = tts_evict(0, 1, 0);
assert_eq!(rc, 0);
assert_eq!(tts_block_count(), 0);
// Evict again should fail.
let rc = tts_evict(0, 1, 0);
assert_eq!(rc, ERR_BLOCK_NOT_FOUND);
}
#[test]
fn test_tick_does_not_crash() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
tts_put(0, 1, 1, data.as_ptr(), data.len());
// Run several ticks with generous budgets.
for _ in 0..10 {
let ops = tts_tick(1_000_000, 1000);
assert!(ops >= 0);
}
// Blocks should still be readable.
let mut out = vec![0.0f32; 64];
let n = tts_get(0, 1, 0, out.as_mut_ptr(), out.len());
assert!(n > 0);
}
#[test]
fn test_tick_with_zero_budget() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
let ops = tts_tick(0, 0);
assert_eq!(ops, 0);
}
#[test]
fn test_stats_returns_valid_data() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
tts_put(0, 1, 1, data.as_ptr(), data.len());
let mut buf = vec![0u8; STATS_SIZE];
let written = tts_stats(buf.as_mut_ptr(), buf.len());
assert_eq!(written, STATS_SIZE as i32);
// Parse the stats back.
let mut off = 0usize;
let block_count = read_u32_le(&buf, &mut off);
let hot = read_u32_le(&buf, &mut off);
let warm = read_u32_le(&buf, &mut off);
let cool = read_u32_le(&buf, &mut off);
let cold = read_u32_le(&buf, &mut off);
assert_eq!(block_count, 2);
assert_eq!(hot, 2);
assert_eq!(warm, 0);
assert_eq!(cool, 0);
assert_eq!(cold, 0);
}
#[test]
fn test_stats_null_pointer() {
reset();
tts_init(std::ptr::null(), 0);
let rc = tts_stats(std::ptr::null_mut(), 64);
assert_eq!(rc, ERR_NULL_POINTER);
}
#[test]
fn test_stats_buffer_too_small() {
reset();
tts_init(std::ptr::null(), 0);
let mut buf = vec![0u8; 4]; // too small
let rc = tts_stats(buf.as_mut_ptr(), buf.len());
assert_eq!(rc, ERR_BUFFER_TOO_SMALL);
}
#[test]
fn test_make_tensor_id() {
assert_eq!(make_tensor_id(0, 0), 0u128);
assert_eq!(make_tensor_id(0, 1), 1u128);
assert_eq!(make_tensor_id(1, 0), 1u128 << 64);
        assert_eq!(make_tensor_id(u64::MAX, u64::MAX), u128::MAX);
}
#[test]
fn test_multiple_tensor_ids() {
reset();
tts_init(std::ptr::null(), 0);
let data = vec![1.0f32; 64];
tts_put(0, 1, 0, data.as_ptr(), data.len());
tts_put(0, 2, 0, data.as_ptr(), data.len());
tts_put(1, 0, 0, data.as_ptr(), data.len());
assert_eq!(tts_block_count(), 3);
// Each should be independently readable.
let mut out = vec![0.0f32; 64];
assert!(tts_get(0, 1, 0, out.as_mut_ptr(), out.len()) > 0);
assert!(tts_get(0, 2, 0, out.as_mut_ptr(), out.len()) > 0);
assert!(tts_get(1, 0, 0, out.as_mut_ptr(), out.len()) > 0);
}
#[test]
fn test_overwrite_block() {
reset();
tts_init(std::ptr::null(), 0);
let data1 = vec![1.0f32; 64];
tts_put(0, 1, 0, data1.as_ptr(), data1.len());
let data2 = vec![2.0f32; 64];
tts_put(0, 1, 0, data2.as_ptr(), data2.len());
assert_eq!(tts_block_count(), 1);
// Should read back the second write.
let mut out = vec![0.0f32; 64];
let n = tts_get(0, 1, 0, out.as_mut_ptr(), out.len());
assert_eq!(n, 64);
for &v in &out {
assert!((v - 2.0).abs() < 0.1);
}
}
}

View File

@@ -0,0 +1,104 @@
//! Tier policy for access-pattern-driven bit-width selection.
//!
//! Score = `access_count * 1024 / (now_ts - last_access_ts + 1)`
//!
//! | Tier | Condition | Bits |
//! |------|-----------|------|
//! | Hot | score >= hot_min_score | 8 |
//! | Warm | score >= warm_min_score | warm_bits (7 or 5) |
//! | Cold | otherwise | 3 |
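//!
//! # Example
//!
//! A sketch with the default policy (marked `ignore`; mirrors the unit tests
//! below):
//!
//! ```ignore
//! let p = TierPolicy::default();
//! assert_eq!(p.select_bits(100, 0, 9), 8); // score 100*1024/10 = 10240 -> Hot
//! assert_eq!(p.select_bits(10, 0, 99), 7); // score 10*1024/100 = 102   -> Warm
//! assert_eq!(p.select_bits(1, 0, 999), 3); // score 1024/1000   = 1     -> Cold
//! ```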
#[derive(Clone, Copy, Debug)]
pub struct TierPolicy {
pub hot_min_score: u32,
pub warm_min_score: u32,
pub warm_bits: u8,
/// Drift tolerance as Q8 fixed-point. 26 means ~10.2% (26/256).
pub drift_pct_q8: u32,
pub group_len: u32,
}
impl Default for TierPolicy {
fn default() -> Self {
Self {
hot_min_score: 512,
warm_min_score: 64,
warm_bits: 7,
drift_pct_q8: 26,
group_len: 64,
}
}
}
impl TierPolicy {
/// Select bit width based on access pattern.
pub fn select_bits(&self, access_count: u32, last_access_ts: u32, now_ts: u32) -> u8 {
        // Saturate so `age` can never be zero: `wrapping_add(1)` would wrap
        // to 0 when the timestamp delta is `u32::MAX`, and dividing by zero
        // panics.
        let age = now_ts.wrapping_sub(last_access_ts).saturating_add(1);
        let score = access_count.saturating_mul(1024) / age;
if score >= self.hot_min_score {
8
} else if score >= self.warm_min_score {
self.warm_bits
} else {
3
}
}
/// Compute the drift factor as 1.0 + drift_pct_q8/256.
pub fn drift_factor(&self) -> f32 {
1.0 + (self.drift_pct_q8 as f32) / 256.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_policy() {
let p = TierPolicy::default();
assert_eq!(p.hot_min_score, 512);
assert_eq!(p.warm_min_score, 64);
assert_eq!(p.warm_bits, 7);
assert_eq!(p.drift_pct_q8, 26);
assert_eq!(p.group_len, 64);
}
#[test]
fn test_tier_selection_hot() {
let p = TierPolicy::default();
// 100 accesses, age=10 -> score = 100*1024/10 = 10240 >= 512
assert_eq!(p.select_bits(100, 0, 9), 8);
}
#[test]
fn test_tier_selection_warm() {
let p = TierPolicy::default();
// 10 accesses, age=100 -> score = 10*1024/100 = 102 >= 64, < 512
assert_eq!(p.select_bits(10, 0, 99), 7);
}
#[test]
fn test_tier_selection_cold() {
let p = TierPolicy::default();
// 1 access, age=1000 -> score = 1024/1000 = 1 < 64
assert_eq!(p.select_bits(1, 0, 999), 3);
}
#[test]
fn test_drift_factor() {
let p = TierPolicy::default();
let df = p.drift_factor();
assert!((df - 1.1015625).abs() < 1e-6);
}
#[test]
fn test_warm_bits_5() {
let p = TierPolicy {
warm_bits: 5,
..Default::default()
};
assert_eq!(p.select_bits(10, 0, 99), 5);
}
}

File diff suppressed because it is too large