# Feature 16: Predictive Prefetch Attention (PPA)

## Overview

### Problem Statement

Traditional attention mechanisms compute attention scores reactively, only after a query arrives, so every query pays the full computation latency. In production systems with sequential or temporal query patterns, this reactive approach wastes the opportunity to compute ahead of time. Users often issue semantically related queries in sequence, yet current systems treat each query independently.

### Proposed Solution

Predictive Prefetch Attention (PPA) uses a learned query predictor to anticipate future queries and pre-compute their attention scores before they are needed. The system maintains a cache of pre-computed attention results and learns continuously from observed query sequences. Because the predictor trains online, its accuracy is expected to improve with usage.

### Expected Benefits

- **Latency Reduction**: 60-80% reduction in p95 query latency for predictable patterns
- **Throughput Improvement**: 3-5x increase in queries per second
- **Self-Improvement**: Prediction accuracy improves from ~30% to 70-85% with usage
- **Cache Hit Rate**: 65-75% for typical workloads after warm-up period
- **Resource Efficiency**: Utilize idle CPU/GPU cycles for prefetch computation
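
As a rough sanity check on these targets, mean latency can be estimated as a hit-rate-weighted average of the hit and miss latencies. The sketch below is illustrative only: the function name `expected_latency_ms` is hypothetical, and the model ignores queueing, lock contention, and the cost of the prefetch work itself.

```rust
/// Mean latency under a simple two-outcome model:
/// E[latency] = hit_rate * t_hit + (1 - hit_rate) * t_miss.
/// `hit_rate` is clamped to [0, 1]; latencies are in milliseconds.
fn expected_latency_ms(hit_rate: f32, t_hit_ms: f32, t_miss_ms: f32) -> f32 {
    let h = hit_rate.clamp(0.0, 1.0);
    h * t_hit_ms + (1.0 - h) * t_miss_ms
}
```

For example, a 70% hit rate with a 0.1 ms hit and a 3.5 ms miss gives 0.7 × 0.1 + 0.3 × 3.5 = 1.12 ms. This models the mean only; tail percentiles such as p95 depend on how misses are distributed.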

### Novelty Claim

**Unique Contribution**: To our knowledge, the first GNN system with learned query prediction and adaptive prefetching for attention mechanisms. Unlike traditional caching (which stores past results) or static prefetching (which uses fixed patterns), PPA learns temporal and semantic query patterns dynamically and adapts its prefetching strategy to prediction confidence and system load.

**Differentiators**:

1. Online learning of query patterns (vs. static caching)
2. Confidence-based prefetch scheduling (vs. always-prefetch)
3. Multi-scale temporal modeling (short-term, session-level, long-term)
4. Adaptive cache management with reinforcement learning
5. Integration of query prediction with attention computation

## Technical Design

### Architecture Diagram

```
┌─────────────────────────────────────────────────────────────────┐
│                          Query Stream                           │
└────────────┬────────────────────────────────────────────────────┘
             │
             ▼
┌─────────────────────────────────────────────────────────────────┐
│                         Query Predictor                         │
│  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐         │
│  │  Short-term  │   │ Session-level│   │  Long-term   │         │
│  │     LSTM     │   │ Transformer  │   │   Pattern    │         │
│  │ (last 5-10)  │   │  (session)   │   │  Embedding   │         │
│  └──────┬───────┘   └──────┬───────┘   └──────┬───────┘         │
│         │                  │                  │                 │
│         └──────────────────┴──────────────────┘                 │
│                            │                                    │
│                   Ensemble Prediction                           │
│                            │                                    │
│                  ┌─────────▼─────────┐                          │
│                  │ Top-K Predictions │                          │
│                  │   + Confidence    │                          │
│                  └─────────┬─────────┘                          │
└────────────────────────────┼────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────────┐
│                       Prefetch Scheduler                        │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │  Priority = f(confidence, cache_space, system_load)       │  │
│  └────────────────────┬──────────────────────────────────────┘  │
│                       │                                         │
│         ┌─────────────┼─────────────┐                           │
│         ▼             ▼             ▼                           │
│   High Priority  Med Priority   Low Priority                    │
│   (conf > 0.8)     (0.5-0.8)     (0.3-0.5)                      │
│         │             │             │                           │
└─────────┼─────────────┼─────────────┼───────────────────────────┘
          │             │             │
          ▼             ▼             ▼
┌─────────────────────────────────────────────────────────────────┐
│                   Attention Computation Pool                    │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │
│  │ Worker 1 │  │ Worker 2 │  │ Worker 3 │  │ Worker 4 │         │
│  │ Prefetch │  │ Prefetch │  │ Real-time│  │ Real-time│         │
│  └──────┬───┘  └──────┬───┘  └──────┬───┘  └──────┬───┘         │
│         │             │             │             │             │
└─────────┼─────────────┼─────────────┼─────────────┼─────────────┘
          │             │             │             │
          ▼             ▼             ▼             ▼
┌─────────────────────────────────────────────────────────────────┐
│                         Attention Cache                         │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │  Key: Query Hash | Value: (Attention Scores, Timestamp)   │  │
│  │  Eviction: LRU + Prediction-Aware                         │  │
│  │  Size: Adaptive based on hit rate and memory              │  │
│  └───────────────────────────────────────────────────────────┘  │
│                                                                 │
│  Cache Hit? ──Yes──> Return Cached Results (< 0.1ms)            │
│       │                                                         │
│      No                                                         │
│       │                                                         │
│       ▼                                                         │
│  Compute Attention (blocking, 2-5ms)                            │
│       │                                                         │
│       ▼                                                         │
│  Store in Cache                                                 │
└─────────────────────────────────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────────┐
│                 Feedback Loop (Online Learning)                 │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │  Actual Query → Compare with Prediction → Update Weights  │  │
│  │  Hit/Miss → Adjust Cache Policy                           │  │
│  │  Latency → Tune Prefetch Aggressiveness                   │  │
│  └───────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────┘

Query Predictor Detail:
┌───────────────────────────────────────┐
│      Short-term LSTM (last 5-10)      │
│                                       │
│   q[t-5]   →    q[t-4] → ... → q[t-1] │
│      │             │              │   │
│      ▼             ▼              ▼   │
│ [LSTM Cell] → [LSTM Cell] →     ...   │
│                    │                  │
│                    ▼                  │
│             Prediction q[t]           │
└───────────────────────────────────────┘

┌───────────────────────────────────────┐
│       Session-level Transformer       │
│                                       │
│ [Session Start] ... [Recent Queries]  │
│                   │                   │
│                   ▼                   │
│           Position Encoding           │
│                   │                   │
│                   ▼                   │
│            Self-Attention             │
│                   │                   │
│                   ▼                   │
│            Prediction q[t]            │
└───────────────────────────────────────┘
```
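
The three scheduler tiers in the diagram can be expressed as a small classification function. This is a minimal sketch: the `PriorityTier` enum and `tier_for_confidence` function are hypothetical names, and the handling of the exact boundaries (0.8 falls into Medium, 0.5 into Medium, 0.3 into Low) is an assumption, since the diagram does not say which side is inclusive.

```rust
/// Scheduler tiers from the diagram (hypothetical type for illustration).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PriorityTier {
    High,   // conf > 0.8
    Medium, // 0.5 - 0.8
    Low,    // 0.3 - 0.5
}

/// Map a prediction confidence to a tier.
/// Returns `None` for NaN or for confidence below 0.3 (not prefetched).
fn tier_for_confidence(confidence: f32) -> Option<PriorityTier> {
    if confidence.is_nan() {
        None
    } else if confidence > 0.8 {
        Some(PriorityTier::High)
    } else if confidence >= 0.5 {
        Some(PriorityTier::Medium)
    } else if confidence >= 0.3 {
        Some(PriorityTier::Low)
    } else {
        None
    }
}
```

Predictions that map to `None` are simply dropped rather than queued, which keeps low-confidence guesses from competing for worker time.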

### Core Data Structures

```rust
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::sync::Arc;

use ndarray::{Array1, Array2};
use tokio::sync::RwLock; // async lock, matching the `.read().await` calls in Key Algorithms

// `AdamOptimizer`, `MultiHeadAttention`, `FeedForward`, `LayerNorm`, `ThreadPool`,
// `TaskHandle`, and `AttentionLayer` are assumed to be defined elsewhere.

/// Configuration for Predictive Prefetch Attention
#[derive(Debug, Clone)]
pub struct PPAConfig {
    /// Number of recent queries to track
    pub history_size: usize,

    /// Number of queries to prefetch
    pub prefetch_k: usize,

    /// Minimum confidence for prefetching
    pub min_confidence: f32,

    /// Maximum cache size (number of entries)
    pub max_cache_size: usize,

    /// Number of prefetch worker threads
    pub num_workers: usize,

    /// Enable online learning
    pub online_learning: bool,

    /// Learning rate for predictor updates
    pub learning_rate: f32,

    /// Predictor architecture
    pub predictor_type: PredictorType,

    /// Cache eviction policy
    pub eviction_policy: EvictionPolicy,
}

/// Query history and pattern tracking
#[derive(Debug, Clone)]
pub struct QueryHistory {
    /// Recent queries (circular buffer)
    queries: VecDeque<QueryRecord>,

    /// Maximum history size
    max_size: usize,

    /// Session ID for grouping related queries
    session_id: Option<String>,

    /// Session start time
    session_start: std::time::Instant,
}

#[derive(Debug, Clone)]
pub struct QueryRecord {
    /// Query embedding
    pub embedding: Vec<f32>,

    /// Timestamp
    pub timestamp: std::time::Instant,

    /// Query hash for cache lookup
    pub hash: u64,

    /// Session ID
    pub session_id: Option<String>,

    /// Metadata (user ID, query type, etc.)
    pub metadata: HashMap<String, String>,
}

/// Query prediction result
#[derive(Debug, Clone)]
pub struct QueryPrediction {
    /// Predicted query embedding
    pub predicted_query: Vec<f32>,

    /// Prediction confidence (0.0 - 1.0)
    pub confidence: f32,

    /// Predictor that made this prediction
    pub predictor_id: PredictorId,

    /// When this prediction was made
    pub timestamp: std::time::Instant,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PredictorId {
    ShortTermLSTM,
    SessionTransformer,
    LongTermPattern,
    Ensemble,
}

/// Query predictor trait
///
/// `Debug` is a supertrait so that `EnsemblePredictor` can derive `Debug`
/// over its `Vec<Box<dyn QueryPredictor>>`.
pub trait QueryPredictor: Send + Sync + std::fmt::Debug {
    /// Predict next k queries given history
    fn predict(
        &self,
        history: &QueryHistory,
        k: usize
    ) -> Vec<QueryPrediction>;

    /// Update predictor with observed query (online learning)
    fn update(&mut self, history: &QueryHistory, actual_query: &[f32]);

    /// Get predictor metrics
    fn get_metrics(&self) -> PredictorMetrics;
}

/// Short-term LSTM predictor
#[derive(Debug)]
pub struct ShortTermLSTM {
    /// LSTM parameters
    lstm_weights: LSTMWeights,

    /// Embedding dimension
    embed_dim: usize,

    /// Hidden state dimension
    hidden_dim: usize,

    /// Current hidden state
    hidden_state: Option<Vec<f32>>,

    /// Current cell state
    cell_state: Option<Vec<f32>>,

    /// Optimizer state
    optimizer: AdamOptimizer,

    /// Metrics
    metrics: PredictorMetrics,
}

#[derive(Debug, Clone)]
pub struct LSTMWeights {
    pub w_f: Array2<f32>, // Forget gate
    pub w_i: Array2<f32>, // Input gate
    pub w_c: Array2<f32>, // Cell gate
    pub w_o: Array2<f32>, // Output gate
    pub b_f: Array1<f32>,
    pub b_i: Array1<f32>,
    pub b_c: Array1<f32>,
    pub b_o: Array1<f32>,
}

/// Session-level transformer predictor
#[derive(Debug)]
pub struct SessionTransformer {
    /// Transformer parameters
    transformer_weights: TransformerWeights,

    /// Embedding dimension
    embed_dim: usize,

    /// Number of attention heads
    num_heads: usize,

    /// Number of layers
    num_layers: usize,

    /// Maximum sequence length
    max_seq_len: usize,

    /// Position encoding
    position_encoding: Array2<f32>,

    /// Optimizer
    optimizer: AdamOptimizer,

    /// Metrics
    metrics: PredictorMetrics,
}

#[derive(Debug, Clone)]
pub struct TransformerWeights {
    pub layers: Vec<TransformerLayer>,
    pub output_proj: Array2<f32>,
}

#[derive(Debug, Clone)]
pub struct TransformerLayer {
    pub self_attn: MultiHeadAttention,
    pub feed_forward: FeedForward,
    pub norm1: LayerNorm,
    pub norm2: LayerNorm,
}

/// Long-term pattern predictor
#[derive(Debug)]
pub struct LongTermPattern {
    /// Frequent pattern index
    pattern_index: HashMap<u64, PatternFrequency>,

    /// Temporal pattern index (hour of day, day of week)
    temporal_index: HashMap<TemporalKey, Vec<Vec<f32>>>,

    /// User-specific patterns
    user_patterns: HashMap<String, Vec<Vec<f32>>>,

    /// Embedding dimension
    embed_dim: usize,

    /// Metrics
    metrics: PredictorMetrics,
}

#[derive(Debug, Clone)]
pub struct PatternFrequency {
    /// Pattern (sequence of query hashes)
    pub pattern: Vec<u64>,

    /// Frequency count
    pub count: usize,

    /// Next query distribution
    pub next_queries: HashMap<u64, usize>,

    /// Last seen timestamp
    pub last_seen: std::time::Instant,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TemporalKey {
    pub hour: u8,        // 0-23
    pub day_of_week: u8, // 0-6
}

/// Ensemble predictor combining multiple predictors
#[derive(Debug)]
pub struct EnsemblePredictor {
    /// Component predictors
    predictors: Vec<Box<dyn QueryPredictor>>,

    /// Predictor weights (learned online)
    weights: Vec<f32>,

    /// Ensemble strategy
    strategy: EnsembleStrategy,

    /// Metrics
    metrics: PredictorMetrics,
}

#[derive(Debug, Clone)]
pub enum EnsembleStrategy {
    /// Weighted average by confidence
    WeightedAverage,

    /// Take prediction from most confident predictor
    MaxConfidence,

    /// Majority voting on predicted query hash
    MajorityVoting,

    /// Learned weighted combination
    LearnedWeights,
}

/// Predictor performance metrics
#[derive(Debug, Clone, Default)]
pub struct PredictorMetrics {
    /// Total predictions made
    pub total_predictions: usize,

    /// Correct predictions (within threshold)
    pub correct_predictions: usize,

    /// Average prediction confidence
    pub avg_confidence: f32,

    /// Prediction latency
    pub avg_latency_ms: f32,

    /// Confidence calibration (predicted vs actual accuracy)
    pub calibration_error: f32,
}

/// Attention cache with prefetched results
#[derive(Debug)]
pub struct AttentionCache {
    /// Cache storage: query_hash -> CacheEntry
    cache: HashMap<u64, CacheEntry>,

    /// Cache metadata for eviction
    metadata: CacheMetadata,

    /// Maximum cache size
    max_size: usize,

    /// Eviction policy
    eviction_policy: EvictionPolicy,

    /// Cache metrics
    metrics: CacheMetrics,
}

#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Attention scores
    pub scores: Vec<f32>,

    /// Top-k indices
    pub top_k_indices: Vec<usize>,

    /// When this was computed
    pub timestamp: std::time::Instant,

    /// How this entry was created
    pub source: EntrySource,

    /// Number of times this entry was hit
    pub hit_count: usize,

    /// Priority for eviction
    pub priority: f32,
}

#[derive(Debug, Clone, PartialEq)]
pub enum EntrySource {
    /// Computed on-demand (cache miss)
    OnDemand,

    /// Prefetched based on prediction
    Prefetched,

    /// Manually inserted
    Manual,
}

#[derive(Debug)]
pub struct CacheMetadata {
    /// LRU tracking
    lru_order: VecDeque<u64>,

    /// Access frequency tracking
    access_counts: HashMap<u64, usize>,

    /// Last access times
    last_access: HashMap<u64, std::time::Instant>,

    /// Predicted future access (from predictor)
    predicted_access: HashMap<u64, f32>,
}

#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    /// Least Recently Used
    LRU,

    /// Least Frequently Used
    LFU,

    /// Prediction-aware (least likely to be accessed)
    PredictionAware,

    /// Adaptive based on hit rate
    Adaptive,
}

#[derive(Debug, Clone, Default)]
pub struct CacheMetrics {
    /// Total cache hits
    pub hits: usize,

    /// Total cache misses
    pub misses: usize,

    /// Prefetch hits (predicted query was actually requested)
    pub prefetch_hits: usize,

    /// Prefetch misses (prefetched but never requested)
    pub prefetch_misses: usize,

    /// Average cache lookup latency
    pub avg_lookup_latency_ms: f32,

    /// Current cache size
    pub current_size: usize,

    /// Total evictions
    pub evictions: usize,
}

/// Prefetch scheduler
#[derive(Debug)]
pub struct PrefetchScheduler {
    /// Work queue sorted by priority
    work_queue: BinaryHeap<PrefetchTask>,

    /// Currently executing tasks
    active_tasks: HashMap<u64, TaskHandle>,

    /// Worker thread pool
    worker_pool: ThreadPool,

    /// Scheduler configuration
    config: SchedulerConfig,

    /// Metrics
    metrics: SchedulerMetrics,
}

#[derive(Debug, Clone)]
pub struct PrefetchTask {
    /// Predicted query
    pub query: Vec<f32>,

    /// Query hash
    pub query_hash: u64,

    /// Priority (higher = more urgent)
    pub priority: f32,

    /// Prediction confidence
    pub confidence: f32,

    /// When this task was created
    pub created_at: std::time::Instant,
}

impl Ord for PrefetchTask {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a total order over f32 (NaN included); ties fall
        // back to the query hash so that `cmp` stays consistent with `eq`
        self.priority
            .total_cmp(&other.priority)
            .then_with(|| self.query_hash.cmp(&other.query_hash))
    }
}

impl PartialOrd for PrefetchTask {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for PrefetchTask {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for PrefetchTask {}

#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Maximum concurrent prefetch tasks
    pub max_concurrent: usize,

    /// Minimum confidence to schedule prefetch
    pub min_confidence: f32,

    /// System load threshold (0.0-1.0)
    /// Don't prefetch if load > threshold
    pub max_system_load: f32,

    /// Priority function parameters
    pub priority_weights: PriorityWeights,
}

#[derive(Debug, Clone)]
pub struct PriorityWeights {
    pub confidence_weight: f32,
    pub cache_space_weight: f32,
    pub system_load_weight: f32,
    pub temporal_weight: f32,
}

#[derive(Debug, Default)]
pub struct SchedulerMetrics {
    pub tasks_scheduled: usize,
    pub tasks_completed: usize,
    pub tasks_cancelled: usize,
    pub avg_task_latency_ms: f32,
}

/// Complete Predictive Prefetch Attention system
pub struct PredictivePrefetchAttention {
    /// Configuration
    config: PPAConfig,

    /// Query history tracker
    history: Arc<RwLock<QueryHistory>>,

    /// Query predictor
    predictor: Arc<RwLock<EnsemblePredictor>>,

    /// Attention cache
    cache: Arc<RwLock<AttentionCache>>,

    /// Prefetch scheduler
    scheduler: Arc<RwLock<PrefetchScheduler>>,

    /// Underlying attention mechanism
    attention: Box<dyn AttentionLayer>,

    /// Candidate set (for prefetch computation)
    candidates: Arc<RwLock<Array2<f32>>>,

    /// Global metrics
    metrics: Arc<RwLock<PPAMetrics>>,
}

#[derive(Debug, Default)]
pub struct PPAMetrics {
    /// Total queries processed
    pub total_queries: usize,

    /// Cache hit rate
    pub cache_hit_rate: f32,

    /// Prefetch hit rate
    pub prefetch_hit_rate: f32,

    /// Average latency (cache hit)
    pub avg_latency_hit_ms: f32,

    /// Average latency (cache miss)
    pub avg_latency_miss_ms: f32,

    /// Predictor accuracy over time
    pub predictor_accuracy_history: VecDeque<f32>,

    /// System throughput (queries/second)
    pub throughput: f32,
}

#[derive(Debug, Clone)]
pub enum PredictorType {
    ShortTermLSTM,
    SessionTransformer,
    LongTermPattern,
    Ensemble,
}
```
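
`QueryHistory` is described as a circular buffer bounded by `max_size`, but the insertion logic is not shown. A minimal sketch of that behavior follows, using a simplified stand-in that stores only query hashes; `BoundedHistory` is a hypothetical type for illustration, not the document's `QueryHistory`.

```rust
use std::collections::VecDeque;

/// Simplified stand-in for `QueryHistory`: keeps the most recent
/// `max_size` query hashes (hypothetical type for illustration).
struct BoundedHistory {
    queries: VecDeque<u64>,
    max_size: usize,
}

impl BoundedHistory {
    fn new(max_size: usize) -> Self {
        Self {
            queries: VecDeque::with_capacity(max_size),
            max_size,
        }
    }

    /// Append a record, dropping the oldest one once the buffer is full.
    fn add_query(&mut self, hash: u64) {
        if self.max_size == 0 {
            return; // a zero-sized history stores nothing
        }
        if self.queries.len() == self.max_size {
            self.queries.pop_front();
        }
        self.queries.push_back(hash);
    }
}
```

Using `VecDeque` keeps both the eviction at the front and the append at the back O(1), which matters because this runs on every query.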

### Key Algorithms

#### 1. Main Query Processing with Prefetch

```rust
/// Process query with predictive prefetching
async fn query_with_prefetch(
    &mut self,
    query: &[f32],
    k: usize
) -> Result<(Vec<usize>, Vec<f32>), PPAError> {

    let start_time = Instant::now();
    let query_hash = hash_query(query);

    // Step 1: Check cache. Clone the result out of the guard so the read
    // lock is released before any `&mut self` call.
    let cached = {
        let cache = self.cache.read().await;
        cache
            .get(query_hash)
            .map(|entry| (entry.top_k_indices.clone(), entry.scores.clone()))
    };
    if let Some(result) = cached {
        // Cache hit!
        // NOTE: this early return skips Steps 4-6, so cache hits are not
        // recorded in the history and do not trigger new predictions.
        self.update_metrics_hit();
        return Ok(result);
    }

    // Step 2: Cache miss - compute attention
    let (indices, scores) = self.attention.forward(query, k)?;

    // Step 3: Store in cache
    {
        let mut cache = self.cache.write().await;
        cache.insert(query_hash, CacheEntry {
            scores: scores.clone(),
            top_k_indices: indices.clone(),
            timestamp: Instant::now(),
            source: EntrySource::OnDemand,
            hit_count: 1,
            priority: 1.0,
        });
    }

    // Step 4: Update query history
    {
        let mut history = self.history.write().await;
        // Read the session ID before the mutable `add_query` call
        let session_id = history.session_id.clone();
        history.add_query(QueryRecord {
            embedding: query.to_vec(),
            timestamp: Instant::now(),
            hash: query_hash,
            session_id,
            metadata: HashMap::new(),
        });
    }

    // Step 5: Predict next queries and schedule prefetch (async)
    tokio::spawn({
        let predictor = Arc::clone(&self.predictor);
        let history = Arc::clone(&self.history);
        let scheduler = Arc::clone(&self.scheduler);
        let config = self.config.clone();

        async move {
            // Get predictions
            let predictions = {
                let predictor = predictor.read().await;
                let history = history.read().await;
                predictor.predict(&history, config.prefetch_k)
            };

            // Schedule prefetch tasks
            let mut scheduler = scheduler.write().await;
            for prediction in predictions {
                if prediction.confidence >= config.min_confidence {
                    // Priority weights live in the scheduler's own config
                    let priority = compute_priority(
                        prediction.confidence,
                        &scheduler.config.priority_weights
                    );

                    // Hash before `predicted_query` is moved into the task
                    let query_hash = hash_query(&prediction.predicted_query);
                    scheduler.schedule(PrefetchTask {
                        query: prediction.predicted_query,
                        query_hash,
                        priority,
                        confidence: prediction.confidence,
                        created_at: Instant::now(),
                    });
                }
            }
        }
    });

    // Step 6: Online learning update (async)
    // NOTE: the history already contains `query` (added in Step 4) when
    // `update` runs.
    if self.config.online_learning {
        tokio::spawn({
            let predictor = Arc::clone(&self.predictor);
            let history = Arc::clone(&self.history);
            let query = query.to_vec();

            async move {
                let mut predictor = predictor.write().await;
                let history = history.read().await;
                predictor.update(&history, &query);
            }
        });
    }

    let latency = start_time.elapsed();
    self.update_metrics_miss(latency);

    Ok((indices, scores))
}

/// Compute priority for prefetch task
/// (`weights.temporal_weight` is not used by this formula)
fn compute_priority(
    confidence: f32,
    weights: &PriorityWeights
) -> f32 {
    let cache_space_available = get_cache_space_ratio();
    let system_load = get_system_load();

    let priority =
        confidence * weights.confidence_weight +
        cache_space_available * weights.cache_space_weight -
        system_load * weights.system_load_weight;

    priority.clamp(0.0, 1.0)
}
```
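
The code above calls `hash_query` without defining it. One possible implementation, shown here as an assumption rather than the project's actual method, quantizes each embedding component to a fixed grid before hashing so that numerically close embeddings map to the same cache key. The `SCALE` constant is hypothetical and would need tuning.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hash an embedding by quantizing each component to a grid of 1/SCALE.
/// Sketch only: `SCALE` is a hypothetical tuning parameter.
fn hash_query(query: &[f32]) -> u64 {
    const SCALE: f32 = 1000.0;
    let mut hasher = DefaultHasher::new();
    // Include the length so prefixes do not collide trivially
    query.len().hash(&mut hasher);
    for &x in query {
        // Float-to-int `as` casts saturate on overflow and map NaN to 0
        let quantized = (x * SCALE).round() as i64;
        quantized.hash(&mut hasher);
    }
    hasher.finish()
}
```

With exact-key lookup, a prefetched prediction only produces a hit when the predicted embedding falls into the same quantization cell as the real query in every dimension, so the grid size directly trades hit rate against result fidelity. `DefaultHasher`'s algorithm is not guaranteed to be stable across Rust releases, so these hashes should not be persisted.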

#### 2. LSTM Query Prediction

```rust
/// LSTM forward pass for query prediction
fn lstm_predict(
    &self,
    history: &QueryHistory,
    k: usize
) -> Vec<QueryPrediction> {

    if history.queries.len() < 2 {
        return Vec::new();
    }

    // Initialize hidden and cell states
    let mut h = self.hidden_state.clone()
        .unwrap_or_else(|| vec![0.0; self.hidden_dim]);
    let mut c = self.cell_state.clone()
        .unwrap_or_else(|| vec![0.0; self.hidden_dim]);

    // Process query sequence
    for query in history.queries.iter() {
        let x = &query.embedding;

        // LSTM cell computation
        let (h_new, c_new) = lstm_cell_forward(
            x,
            &h,
            &c,
            &self.lstm_weights
        );

        h = h_new;
        c = c_new;
    }

    // Predict next k queries
    let mut predictions = Vec::with_capacity(k);
    let mut h_pred = h.clone();
    let mut c_pred = c.clone();

    for i in 0..k {
        // Generate prediction from hidden state
        let predicted_query = self.output_projection(&h_pred);

        // Compute confidence based on hidden state entropy
        let confidence = compute_prediction_confidence(&h_pred, &c_pred);

        predictions.push(QueryPrediction {
            predicted_query: predicted_query.clone(),
            confidence: confidence * (0.9_f32.powi(i as i32)), // Decay confidence
            predictor_id: PredictorId::ShortTermLSTM,
            timestamp: Instant::now(),
        });

        // Continue LSTM for next prediction, feeding the prediction back in
        let (h_new, c_new) = lstm_cell_forward(
            &predicted_query,
            &h_pred,
            &c_pred,
            &self.lstm_weights
        );
        h_pred = h_new;
        c_pred = c_new;
    }

    predictions
}

/// LSTM cell forward pass
fn lstm_cell_forward(
    x: &[f32],
    h: &[f32],
    c: &[f32],
    weights: &LSTMWeights
) -> (Vec<f32>, Vec<f32>) {

    // Concatenate input and hidden state
    let mut xh = x.to_vec();
    xh.extend_from_slice(h);
    let xh = Array1::from(xh);

    // Compute gates
    let f = sigmoid(&(weights.w_f.dot(&xh) + &weights.b_f)); // Forget gate
    let i = sigmoid(&(weights.w_i.dot(&xh) + &weights.b_i)); // Input gate
    let g = tanh(&(weights.w_c.dot(&xh) + &weights.b_c));    // Cell gate
    let o = sigmoid(&(weights.w_o.dot(&xh) + &weights.b_o)); // Output gate

    // Update cell state
    let c_new = &f * &Array1::from(c.to_vec()) + &i * &g;

    // Compute new hidden state
    let h_new = &o * &tanh(&c_new);

    (h_new.to_vec(), c_new.to_vec())
}

/// Compute prediction confidence from LSTM hidden state
/// (the cell state is currently unused)
fn compute_prediction_confidence(h: &[f32], _c: &[f32]) -> f32 {
    if h.is_empty() {
        return 0.0;
    }

    // Higher confidence when hidden state has low entropy
    let h_entropy = -h.iter()
        .map(|&x| {
            let p = sigmoid_scalar(x);
            if p > 0.0 && p < 1.0 {
                p * p.ln() + (1.0 - p) * (1.0 - p).ln()
            } else {
                0.0
            }
        })
        .sum::<f32>();

    // Normalize entropy to confidence score.
    // Binary entropy peaks at ln 2 per unit (at p = 0.5), so the maximum
    // total is h.len() * ln 2. The previous bound, len * 2 * ln(0.5), was
    // negative, which made the confidence clamp to 1.0 for every input.
    let max_entropy = h.len() as f32 * std::f32::consts::LN_2;
    let confidence = 1.0 - (h_entropy / max_entropy).min(1.0);

    confidence.clamp(0.0, 1.0)
}
```
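
The entropy-based confidence can be checked in isolation with plain slices. The sketch below treats each hidden unit as a Bernoulli probability via a sigmoid and normalizes by ln 2 per unit, the maximum of binary entropy in nats; `binary_entropy` and `entropy_confidence` are hypothetical helper names for illustration.

```rust
/// Logistic sigmoid of a scalar.
fn sigmoid_scalar(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Binary entropy H(p) = -(p ln p + (1 - p) ln(1 - p)), in nats.
/// Defined as 0 at the endpoints p = 0 and p = 1.
fn binary_entropy(p: f32) -> f32 {
    if p <= 0.0 || p >= 1.0 {
        0.0
    } else {
        -(p * p.ln() + (1.0 - p) * (1.0 - p).ln())
    }
}

/// Confidence = 1 - (total entropy / maximum possible entropy),
/// where the maximum is ln 2 per hidden unit.
fn entropy_confidence(h: &[f32]) -> f32 {
    if h.is_empty() {
        return 0.0;
    }
    let total: f32 = h.iter().map(|&x| binary_entropy(sigmoid_scalar(x))).sum();
    let max_entropy = h.len() as f32 * std::f32::consts::LN_2;
    (1.0 - total / max_entropy).clamp(0.0, 1.0)
}
```

A hidden state of all zeros maps every unit to p = 0.5 and yields confidence near 0, while strongly saturated units (for example ±10) yield confidence near 1. This makes the measure a proxy for activation saturation, not a calibrated probability of a correct prediction.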

#### 3. Transformer Session Prediction

```rust
use ndarray::s;

/// Transformer-based session prediction
///
/// Returns a single next-query prediction; `k` is currently unused.
fn transformer_predict(
    &self,
    history: &QueryHistory,
    k: usize
) -> Vec<QueryPrediction> {

    // Use at most `max_seq_len` of the most recent queries so the
    // position-encoding slice below stays in bounds
    let total = history.queries.len();
    let seq_len = total.min(self.max_seq_len);
    if seq_len == 0 {
        return Vec::new();
    }

    // Prepare input sequence
    let mut input_seq = Array2::zeros((seq_len, self.embed_dim));
    for (i, query) in history.queries.iter().skip(total - seq_len).enumerate() {
        for (j, &val) in query.embedding.iter().take(self.embed_dim).enumerate() {
            input_seq[[i, j]] = val;
        }
    }

    // Add position encoding
    let pos_encoded = &input_seq + &self.position_encoding.slice(s![..seq_len, ..]);

    // Forward through transformer layers
    let mut hidden = pos_encoded;
    for layer in &self.transformer_weights.layers {
        hidden = transformer_layer_forward(hidden, layer);
    }

    // Use last hidden state for prediction
    let last_hidden = hidden.row(seq_len - 1);

    // Project to next query prediction
    let predicted_query = self.transformer_weights.output_proj.dot(&last_hidden);

    // Compute confidence from the final hidden states
    let confidence = compute_transformer_confidence(&hidden);

    vec![QueryPrediction {
        predicted_query: predicted_query.to_vec(),
        confidence,
        predictor_id: PredictorId::SessionTransformer,
        timestamp: Instant::now(),
    }]
}

/// Forward through transformer layer
fn transformer_layer_forward(
    input: Array2<f32>,
    layer: &TransformerLayer
) -> Array2<f32> {

    // Self-attention
    let attn_output = multi_head_attention_forward(
        &input,
        &input,
        &input,
        &layer.self_attn
    );

    // Add & norm
    let normed1 = layer_norm(&(input + attn_output), &layer.norm1);

    // Feed-forward
    let ff_output = feed_forward(&normed1, &layer.feed_forward);

    // Add & norm
    layer_norm(&(normed1 + ff_output), &layer.norm2)
}
```
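
The "Add & norm" step used twice in each layer is a residual connection followed by layer normalization. The sketch below shows it for a single row using plain slices; it omits the learned scale and shift parameters that a `LayerNorm` struct would normally carry, and `layer_norm_row` and `add_and_norm` are hypothetical names for illustration.

```rust
/// Layer normalization of one row: (x - mean) / sqrt(var + eps).
/// Learned scale and shift are omitted in this sketch.
fn layer_norm_row(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter().map(|v| (v - mean) * inv_std).collect()
}

/// "Add & norm": residual connection followed by layer normalization.
fn add_and_norm(input: &[f32], sublayer_out: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(input.len(), sublayer_out.len(), "residual shapes must match");
    let summed: Vec<f32> = input
        .iter()
        .zip(sublayer_out)
        .map(|(a, b)| a + b)
        .collect();
    layer_norm_row(&summed, eps)
}
```

The residual path lets each sublayer learn a correction to its input, and the normalization keeps activations at a stable scale from layer to layer.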

#### 4. Cache Management

```rust
/// Insert entry into cache with eviction if necessary
fn cache_insert(&mut self, query_hash: u64, entry: CacheEntry) {

    // Evict only when a new key would grow a full cache;
    // overwriting an existing key needs no eviction
    if !self.cache.contains_key(&query_hash) && self.cache.len() >= self.max_size {
        // Evict entry based on policy
        let victim_hash = match self.eviction_policy {
            // `front()` returns None instead of panicking on an empty queue
            EvictionPolicy::LRU => self.metadata.lru_order.front().copied(),
            EvictionPolicy::LFU => Some(self.find_lfu_victim()),
            EvictionPolicy::PredictionAware => Some(self.find_prediction_aware_victim()),
            EvictionPolicy::Adaptive => Some(self.find_adaptive_victim()),
        };

        if let Some(victim_hash) = victim_hash {
            self.cache.remove(&victim_hash);
            // Drop the victim's metadata so stale hashes are never selected again
            self.metadata.lru_order.retain(|&h| h != victim_hash);
            self.metadata.access_counts.remove(&victim_hash);
            self.metadata.last_access.remove(&victim_hash);
            self.metadata.predicted_access.remove(&victim_hash);
            self.metrics.evictions += 1;
        }
    }

    // Insert new entry
    self.cache.insert(query_hash, entry);
    self.metadata.lru_order.retain(|&h| h != query_hash); // no duplicates on re-insert
    self.metadata.lru_order.push_back(query_hash);
    self.metadata.last_access.insert(query_hash, Instant::now());
    self.metrics.current_size = self.cache.len();
}

/// Find victim for prediction-aware eviction
fn find_prediction_aware_victim(&self) -> u64 {
    // Evict entry with lowest predicted future access probability
    let mut min_score = f32::MAX;
    let mut victim = 0;

    for &hash in self.cache.keys() {
        // Score = predicted_access_prob * recency * frequency
        // (higher score = more worth keeping)
        let predicted_access = self.metadata.predicted_access
            .get(&hash)
            .copied()
            .unwrap_or(0.0);

        // Seconds since last access; unknown entries count as infinitely old
        let recency = self.metadata.last_access
            .get(&hash)
            .map(|t| t.elapsed().as_secs_f32())
            .unwrap_or(f32::MAX);

        let frequency = self.metadata.access_counts
            .get(&hash)
            .copied()
            .unwrap_or(0) as f32;

        let score = predicted_access * (1.0 / (1.0 + recency)) * frequency;

        if score < min_score {
            min_score = score;
            victim = hash;
        }
    }

    victim
}
```
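
The prediction-aware score can be worked through on a few concrete entries. The sketch below reimplements the same product (predicted access × recency factor × frequency) over a simplified stand-in for the cache metadata; `EntryStats`, `retention_score`, and `pick_victim` are hypothetical names for illustration.

```rust
/// Per-entry statistics (simplified stand-in for `CacheMetadata`).
struct EntryStats {
    hash: u64,
    predicted_access: f32,   // predicted probability of future access
    secs_since_access: f32,  // seconds since last access
    access_count: u32,       // number of recorded accesses
}

/// Higher score = more worth keeping.
fn retention_score(e: &EntryStats) -> f32 {
    e.predicted_access * (1.0 / (1.0 + e.secs_since_access)) * e.access_count as f32
}

/// Return the hash of the entry with the lowest retention score.
fn pick_victim(entries: &[EntryStats]) -> Option<u64> {
    entries
        .iter()
        .min_by(|a, b| retention_score(a).total_cmp(&retention_score(b)))
        .map(|e| e.hash)
}
```

For three entries with (predicted, age, count) of (0.9, 1 s, 5), (0.1, 10 s, 2), and (0.5, 0 s, 1), the scores are 2.25, about 0.018, and 0.5, so the second entry is evicted. Because the score is a product, any entry with a zero factor scores 0: a freshly prefetched entry with no recorded accesses would be evicted first, which suggests a floor on the frequency term may be worth considering.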

#### 5. Online Learning Update

```rust
/// Update predictor based on observed query
async fn update_predictor(
    &mut self,
    history: &QueryHistory,
    actual_query: &[f32]
) {

    // Get what we predicted last time
    // (assumes a `last_predictions: Vec<QueryPrediction>` field, not shown
    // in the struct definitions)
    let last_predictions = self.last_predictions.clone();

    // Compute loss (MSE between prediction and actual).
    // Iterate by reference so `last_predictions` is still available below.
    for prediction in &last_predictions {
        let mse = mean_squared_error(&prediction.predicted_query, actual_query);

        // Update predictor weights based on loss
        match prediction.predictor_id {
            PredictorId::ShortTermLSTM => {
                self.update_lstm(history, actual_query, mse);
            },
            PredictorId::SessionTransformer => {
                self.update_transformer(history, actual_query, mse);
            },
            PredictorId::LongTermPattern => {
                self.update_pattern_index(history, actual_query);
            },
            _ => {}
        }

        // Update ensemble weights
        self.update_ensemble_weights(prediction.predictor_id, mse);
    }

    // Update metrics
    self.update_prediction_metrics(last_predictions, actual_query);
}

/// Update LSTM weights via backpropagation
fn update_lstm(
    &mut self,
    history: &QueryHistory,
    actual_query: &[f32],
    loss: f32
) {

    // Compute gradients via BPTT
    let gradients = compute_lstm_gradients(
        &self.lstm_weights,
        history,
        actual_query
    );

    // Update weights with Adam optimizer
    self.optimizer.step(&mut self.lstm_weights, gradients);

    // Update metrics: exponential moving average of the loss
    // (assumes an `avg_loss: f32` field, which `PredictorMetrics` does not
    // declare in the struct definitions)
    self.metrics.avg_loss = 0.9 * self.metrics.avg_loss + 0.1 * loss;
}
```
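
Two helpers used above, `mean_squared_error` and `update_ensemble_weights`, are not defined in this document. The sketch below gives one plausible form of each. The ensemble update shown is a multiplicative-weights rule (scale by `exp(-eta * loss)`, then renormalize); that choice, the index-based signature, and the `eta` parameter are assumptions for illustration, since the document does not specify how ensemble weights are learned.

```rust
/// Mean squared error between two equal-length embeddings.
fn mean_squared_error(predicted: &[f32], actual: &[f32]) -> f32 {
    assert_eq!(predicted.len(), actual.len(), "embedding dimensions must match");
    if predicted.is_empty() {
        return 0.0;
    }
    let sum_sq: f32 = predicted
        .iter()
        .zip(actual)
        .map(|(p, a)| (p - a).powi(2))
        .sum();
    sum_sq / predicted.len() as f32
}

/// Multiplicative-weights update (assumed rule, not from the source):
/// scale the weight of predictor `idx` by exp(-eta * loss), then
/// renormalize so the weights sum to 1.
fn update_ensemble_weights(weights: &mut [f32], idx: usize, loss: f32, eta: f32) {
    weights[idx] *= (-eta * loss).exp();
    let total: f32 = weights.iter().sum();
    if total > 0.0 {
        for w in weights.iter_mut() {
            *w /= total;
        }
    }
}
```

Under this rule a predictor that keeps incurring high loss sees its weight decay exponentially, while a larger `eta` makes the ensemble react faster at the cost of more noise.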

### API Design

```rust
/// Public API for Predictive Prefetch Attention
pub trait PPALayer {
    /// Create new PPA layer
    fn new(
        config: PPAConfig,
        attention: Box<dyn AttentionLayer>
    ) -> Self;

    /// Process query with prefetching
    async fn query(
        &mut self,
        query: &[f32],
        k: usize
    ) -> Result<(Vec<usize>, Vec<f32>), PPAError>;

    /// Update candidate set for prefetch
    fn update_candidates(&mut self, candidates: Vec<Vec<f32>>);

    /// Start prefetch worker pool
    async fn start_prefetch_workers(&mut self) -> Result<(), PPAError>;

    /// Stop prefetch workers
    async fn stop_prefetch_workers(&mut self) -> Result<(), PPAError>;

    /// Get current metrics
    fn get_metrics(&self) -> PPAMetrics;

    /// Reset metrics
    fn reset_metrics(&mut self);

    /// Start new session
    fn start_session(&mut self, session_id: String);

    /// End current session
    fn end_session(&mut self);

    /// Save predictor state
    async fn save_state(&self, path: &str) -> Result<(), PPAError>;

    /// Load predictor state
    async fn load_state(&mut self, path: &str) -> Result<(), PPAError>;
}

#[derive(Debug, thiserror::Error)]
pub enum PPAError {
    #[error("Attention error: {0}")]
    AttentionError(String),

    #[error("Prediction error: {0}")]
    PredictionError(String),

    #[error("Cache error: {0}")]
    CacheError(String),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}

/// Builder for PPA configuration
pub struct PPAConfigBuilder {
    history_size: usize,
    prefetch_k: usize,
    min_confidence: f32,
    max_cache_size: usize,
    num_workers: usize,
    online_learning: bool,
    learning_rate: f32,
    predictor_type: PredictorType,
    eviction_policy: EvictionPolicy,
}

impl PPAConfigBuilder {
    pub fn new() -> Self {
        Self {
            history_size: 100,
            prefetch_k: 5,
            min_confidence: 0.5,
            max_cache_size: 10000,
            num_workers: 4,
            online_learning: true,
            learning_rate: 0.001,
            predictor_type: PredictorType::Ensemble,
            eviction_policy: EvictionPolicy::PredictionAware,
        }
    }

    pub fn history_size(mut self, size: usize) -> Self {
        self.history_size = size;
        self
    }

    pub fn prefetch_k(mut self, k: usize) -> Self {
        self.prefetch_k = k;
        self
    }

    pub fn min_confidence(mut self, conf: f32) -> Self {
        self.min_confidence = conf;
        self
    }

    pub fn build(self) -> PPAConfig {
        PPAConfig {
            history_size: self.history_size,
            prefetch_k: self.prefetch_k,
            min_confidence: self.min_confidence,
            max_cache_size: self.max_cache_size,
            num_workers: self.num_workers,
            online_learning: self.online_learning,
            learning_rate: self.learning_rate,
            predictor_type: self.predictor_type,
            eviction_policy: self.eviction_policy,
        }
    }
}
```
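
The `min_confidence` and `prefetch_k` settings in the builder suggest how the scheduler might pick what to prefetch: drop low-confidence predictions, then keep the best few. The sketch below shows one way this could work under that assumption. `Prediction` and `select_for_prefetch` are hypothetical names, and the actual scheduler may also weigh system load.

```rust
/// A predicted future query together with the predictor's confidence in it.
/// Hypothetical type for illustration; not part of the PPA API above.
#[derive(Debug, Clone, PartialEq)]
pub struct Prediction {
    pub query: Vec<f32>,
    pub confidence: f32,
}

/// Keep predictions whose confidence is at least `min_confidence`,
/// order them from most to least confident, and return at most `prefetch_k`.
pub fn select_for_prefetch(
    mut predictions: Vec<Prediction>,
    min_confidence: f32,
    prefetch_k: usize,
) -> Vec<Prediction> {
    // NaN confidences fail the `>=` comparison and are dropped here.
    predictions.retain(|p| p.confidence >= min_confidence);
    // Sort descending by confidence; total_cmp gives a total order on f32.
    predictions.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
    predictions.truncate(prefetch_k);
    predictions
}
```

With the builder defaults (`min_confidence = 0.5`, `prefetch_k = 5`), at most five predictions per step would be handed to the workers, and none when the predictor is unsure about all of them.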

## Integration Points

### Affected Crates/Modules

1. **`ruvector-gnn-core/`**
   - `src/attention/mod.rs` - Add PPA as wrapper around existing attention
   - `src/cache/mod.rs` - New cache subsystem

2. **`ruvector-gnn-node/`**
   - `src/lib.rs` - Expose async PPA API to Node.js
   - Add support for session management in bindings

3. **`ruvector-core/`**
   - May benefit from PPA for index queries

### New Modules to Create

1. **`ruvector-gnn-core/src/attention/ppa/`**
   ```
   ppa/
   ├── mod.rs
   ├── config.rs
   ├── history.rs           # Query history tracking
   ├── predictor/
   │   ├── mod.rs
   │   ├── lstm.rs          # LSTM predictor
   │   ├── transformer.rs   # Transformer predictor
   │   ├── pattern.rs       # Pattern-based predictor
   │   └── ensemble.rs      # Ensemble predictor
   ├── cache/
   │   ├── mod.rs
   │   ├── entry.rs
   │   ├── eviction.rs
   │   └── metrics.rs
   ├── scheduler.rs         # Prefetch scheduler
   ├── worker.rs            # Prefetch worker pool
   └── metrics.rs           # Global metrics
   ```

2. **`ruvector-gnn-core/tests/ppa/`**
   ```
   tests/ppa/
   ├── basic.rs
   ├── prediction.rs
   ├── cache.rs
   ├── scheduler.rs
   ├── online_learning.rs
   ├── integration.rs
   └── benchmarks.rs
   ```

### Dependencies on Other Features

- **All attention features**: PPA wraps any attention mechanism
- **Feature 15 (ESA)**: Can prefetch ESA attention computations
- **Feature 19 (Consensus Attention)**: Can prefetch consensus computations

### External Dependencies

```toml
[dependencies]
tokio = { version = "1.35", features = ["full"] }
rayon = "1.7"
ndarray = "0.15"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3"
thiserror = "1.0"
dashmap = "5.5"  # Concurrent HashMap
crossbeam = "0.8"  # Lock-free data structures
```

## Regression Prevention

### What Existing Functionality Could Break

1. **Synchronous API**
   - Risk: PPA is async, while existing callers expect synchronous calls
   - Mitigation: Provide both sync and async APIs

2. **Determinism**
   - Risk: Prefetching may introduce non-determinism
   - Mitigation: Cache can be disabled for testing

3. **Memory Usage**
   - Risk: Cache and predictor increase memory significantly
   - Mitigation: Configurable limits, memory monitoring

4. **Thread Safety**
   - Risk: Concurrent prefetch could cause races
   - Mitigation: Extensive use of `Arc<RwLock<T>>` and `DashMap`
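
The "provide both sync and async APIs" mitigation can be sketched as a blocking facade that drives the async method to completion on the calling thread. The example below uses only the standard library so that it stays self-contained; `ToyPpa`, `query_blocking`, and the hand-rolled `block_on` are illustrative assumptions. Since the project already depends on tokio, a real implementation would more likely delegate to the runtime's own blocking entry point.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// Waker that unparks the thread blocked in `block_on`.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Minimal single-future executor: poll, and park the thread while pending.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

/// Toy stand-in for a PPA layer (hypothetical; not the real type).
struct ToyPpa {
    calls: usize,
}

impl ToyPpa {
    /// Async path: returns the indices of the `k` largest query components.
    async fn query(&mut self, query: &[f32], k: usize) -> Vec<usize> {
        self.calls += 1;
        let mut idx: Vec<usize> = (0..query.len()).collect();
        idx.sort_by(|&a, &b| query[b].total_cmp(&query[a]));
        idx.truncate(k);
        idx
    }

    /// Sync facade for legacy callers: same result, blocks the caller.
    fn query_blocking(&mut self, query: &[f32], k: usize) -> Vec<usize> {
        block_on(self.query(query, k))
    }
}
```

Keeping the sync method as a thin wrapper means both entry points share one code path, so behaviour cannot drift between them. A facade like this must not be called from inside an async task, because blocking there can stall the executor.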

### Test Cases

```rust
use std::time::Instant;

#[tokio::test]
async fn test_cache_hit_performance() {
    // `query` takes `&mut self`, so the layer must be bound mutably
    let mut ppa = setup_ppa().await;

    let query = vec![1.0; 128];

    // First query (cache miss)
    let start = Instant::now();
    let _ = ppa.query(&query, 10).await;
    let miss_latency = start.elapsed();

    // Second query (cache hit)
    let start = Instant::now();
    let _ = ppa.query(&query, 10).await;
    let hit_latency = start.elapsed();

    // Cache hit should be at least 10x faster
    assert!(hit_latency < miss_latency / 10);
}

#[tokio::test]
async fn test_prefetch_accuracy() {
    let mut ppa = setup_ppa().await;

    // Create predictable query sequence
    let sequence = generate_predictable_sequence(100);

    // Process sequence and measure prefetch hit rate
    let mut prefetch_hits = 0;
    for query in sequence {
        // Count a hit only if the result was cached before the query arrived
        if ppa.is_in_cache(&query) {
            prefetch_hits += 1;
        }
        let _ = ppa.query(&query, 10).await;
    }

    let hit_rate = prefetch_hits as f32 / 100.0;

    // Hit rate over the whole sequence (warm-up included) should be > 60%
    assert!(hit_rate > 0.6);
}

#[tokio::test]
async fn test_online_learning_improvement() {
    let mut ppa = setup_ppa().await;

    // Measure accuracy before learning
    let initial_accuracy = measure_prediction_accuracy(&ppa).await;

    // Process a learnable sequence to trigger online learning;
    // uniformly random queries carry no pattern for the predictor to pick up
    for query in generate_predictable_sequence(1000) {
        let _ = ppa.query(&query, 10).await;
    }

    // Measure accuracy after learning
    let final_accuracy = measure_prediction_accuracy(&ppa).await;

    // Accuracy should improve by at least 10 percentage points
    assert!(final_accuracy > initial_accuracy + 0.1);
}
```
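
The tests above rely on a helper that produces a predictable query sequence, which this section does not define. One possible shape for such a helper is sketched below, together with a trivial "last successor" baseline that shows how learnable the sequence is. Both `cyclic_query_sequence` and `successor_hit_rate` are hypothetical names, and the three-parameter signature is an assumption that differs from the one-argument call used in the tests.

```rust
use std::collections::HashMap;

/// Generate `len` queries that cycle through `period` distinct one-hot vectors.
/// Hypothetical test helper; requires `period <= dim` so the vectors differ.
pub fn cyclic_query_sequence(len: usize, period: usize, dim: usize) -> Vec<Vec<f32>> {
    assert!(period > 0 && period <= dim, "need 0 < period <= dim");
    (0..len)
        .map(|i| {
            let mut q = vec![0.0f32; dim];
            q[i % period] = 1.0;
            q
        })
        .collect()
}

/// Hashable key for a query: the raw bit pattern of each component.
fn key(q: &[f32]) -> Vec<u32> {
    q.iter().map(|x| x.to_bits()).collect()
}

/// Fraction of transitions correctly predicted by remembering, for each
/// query, the query that followed it last time.
pub fn successor_hit_rate(seq: &[Vec<f32>]) -> f32 {
    if seq.len() < 2 {
        return 0.0;
    }
    let mut next: HashMap<Vec<u32>, Vec<u32>> = HashMap::new();
    let mut hits = 0usize;
    for pair in seq.windows(2) {
        let (prev, cur) = (key(&pair[0]), key(&pair[1]));
        if next.get(&prev) == Some(&cur) {
            hits += 1;
        }
        next.insert(prev, cur);
    }
    hits as f32 / (seq.len() - 1) as f32
}
```

On a cyclic sequence the baseline misses only the first pass through the cycle, so a learned predictor that cannot beat it on such data is probably not learning. On uniformly random queries the same baseline scores near zero, which is why the learning test needs structured input.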

## Implementation Phases

### Phase 1: Research Validation (3 weeks)
- Prototype LSTM and transformer predictors in Python
- Collect real query logs for analysis
- Benchmark prediction accuracy
- Analyze cache hit rates with different policies

### Phase 2: Core Implementation (4 weeks)
- Implement query history tracking
- Implement LSTM predictor
- Implement cache with LRU eviction
- Basic prefetch scheduler
- Unit tests

### Phase 3: Advanced Predictors (3 weeks)
- Implement transformer predictor
- Implement pattern-based predictor
- Implement ensemble predictor
- Online learning updates
- Advanced eviction policies

### Phase 4: Integration & Optimization (2 weeks)
- Integrate with GNN attention layers
- Async/await optimization
- Memory optimization
- Performance benchmarking
- Production testing

## Success Metrics

### Performance Benchmarks

| Metric | Target | Measurement |
|--------|--------|-------------|
| P95 Latency (cache hit) | < 0.1 ms | 1M queries |
| P95 Latency (cache miss) | < 5 ms | 1M queries |
| Cache Hit Rate | 65-75% | After 1,000-query warm-up |
| Prefetch Hit Rate | 60-70% | Share of predicted queries that are actually requested |
| Throughput | 3-5x baseline | Queries/second |
| Prediction Accuracy | 70-85% | Top-1 prediction within cosine distance < 0.1 of the actual query |
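
The prediction accuracy target depends on how "within cosine distance < 0.1" is scored. A minimal sketch of one such scoring function follows, assuming cosine distance means `1 - cosine similarity`. The function names are hypothetical, and pairs with mismatched lengths or zero vectors are treated as misses.

```rust
/// Cosine distance `1 - cos(a, b)`; `None` when it is undefined
/// (empty input, length mismatch, or a zero-norm vector).
pub fn cosine_distance(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.is_empty() || a.len() != b.len() {
        return None;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    Some(1.0 - dot / (norm_a * norm_b))
}

/// Share of `(predicted, actual)` pairs whose cosine distance is below
/// `max_distance`. Undefined distances count as misses.
pub fn top1_accuracy(pairs: &[(Vec<f32>, Vec<f32>)], max_distance: f32) -> f32 {
    if pairs.is_empty() {
        return 0.0;
    }
    let hits = pairs
        .iter()
        .filter(|(pred, actual)| {
            matches!(cosine_distance(pred, actual), Some(d) if d < max_distance)
        })
        .count();
    hits as f32 / pairs.len() as f32
}
```

Calling `top1_accuracy(&pairs, 0.1)` over logged `(top-1 prediction, next actual query)` pairs would give the figure compared against the 70-85% target.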

### Accuracy Metrics

- **Cold Start**: 30-40% accuracy (no history)
- **After 100 Queries**: 50-60% accuracy
- **After 1000 Queries**: 70-85% accuracy
- **Confidence Calibration**: calibration error < 0.1
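
The calibration target does not say which error measure is meant. One common choice is expected calibration error (ECE), which bins predictions by confidence and averages the gap between mean confidence and observed hit rate in each bin. The sketch below assumes that reading. The function name, the equal-width binning, and the `(confidence, was_hit)` sample shape are all assumptions.

```rust
/// Expected calibration error over `(confidence, was_hit)` samples using
/// `num_bins` equal-width confidence bins on [0, 1].
/// Assumes confidences are finite; values outside [0, 1] are clamped.
pub fn expected_calibration_error(samples: &[(f32, bool)], num_bins: usize) -> f32 {
    assert!(num_bins > 0, "need at least one bin");
    if samples.is_empty() {
        return 0.0;
    }
    let mut conf_sum = vec![0.0f32; num_bins];
    let mut hit_sum = vec![0.0f32; num_bins];
    let mut count = vec![0usize; num_bins];

    for &(conf, hit) in samples {
        let c = conf.clamp(0.0, 1.0);
        // A confidence of exactly 1.0 falls into the last bin.
        let bin = ((c * num_bins as f32) as usize).min(num_bins - 1);
        conf_sum[bin] += c;
        hit_sum[bin] += if hit { 1.0 } else { 0.0 };
        count[bin] += 1;
    }

    let total = samples.len() as f32;
    (0..num_bins)
        .filter(|&b| count[b] > 0)
        .map(|b| {
            let m = count[b] as f32;
            // Bin weight times |observed hit rate - mean confidence|.
            (m / total) * (hit_sum[b] / m - conf_sum[b] / m).abs()
        })
        .sum()
}
```

A value near zero means the `min_confidence` threshold can be read as a real probability: predictions reported at 0.7 confidence are right about 70% of the time.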

## Risks and Mitigations

### Technical Risks

1. **Risk: Low Prediction Accuracy**
   - Mitigation: Ensemble of multiple predictors, start with conservative confidence thresholds

2. **Risk: Memory Overhead**
   - Mitigation: Adaptive cache sizing, configurable limits

3. **Risk: Stale Cache Entries**
   - Mitigation: TTL on cache entries, prediction-aware eviction

4. **Risk: Wasted Computation on Wrong Predictions**
   - Mitigation: Only prefetch high-confidence predictions, monitor prefetch miss rate

5. **Risk: Thread Contention**
   - Mitigation: Lock-free data structures, careful use of RwLock

6. **Risk: Cold Start Problem**
   - Mitigation: Fall back to pattern-based prediction, use temporal patterns
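
The TTL mitigation for stale cache entries can be illustrated with a small single-threaded cache that stamps each entry on insert and drops it on lookup once it is too old. This is a sketch under simplifying assumptions: `TtlCache` is a hypothetical type, keys are plain `u64` values, and the clock is passed in by the caller so that expiry is easy to test. The concurrent, prediction-aware cache described in this document would need more than this.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Toy cache whose entries expire `ttl` after insertion (hypothetical type).
pub struct TtlCache {
    ttl: Duration,
    entries: HashMap<u64, (Instant, Vec<f32>)>,
}

impl TtlCache {
    pub fn new(ttl: Duration) -> Self {
        Self { ttl, entries: HashMap::new() }
    }

    /// Store `scores` under `key`, stamped with the caller-supplied time.
    pub fn insert(&mut self, key: u64, scores: Vec<f32>, now: Instant) {
        self.entries.insert(key, (now, scores));
    }

    /// Return the entry if it is still fresh; remove it if it has expired.
    pub fn get(&mut self, key: u64, now: Instant) -> Option<&[f32]> {
        let expired = match self.entries.get(&key) {
            Some((inserted, _)) => now.duration_since(*inserted) >= self.ttl,
            None => return None,
        };
        if expired {
            self.entries.remove(&key);
            return None;
        }
        self.entries.get(&key).map(|(_, scores)| scores.as_slice())
    }

    /// Number of entries currently held, including any not yet expired lazily.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
```

Expiring lazily on lookup keeps the hot path cheap but lets entries that are never read linger, so a real cache would pair this with the size-based eviction mentioned above.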