Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
59
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/Cargo.toml
vendored
Normal file
59
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/Cargo.toml
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
[package]
|
||||
name = "sevensense-learning"
|
||||
description = "GNN-based learning and embedding refinement for 7sense bioacoustics platform"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
authors.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
# Internal crates
|
||||
sevensense-core = { workspace = true, version = "0.1.0" }
|
||||
sevensense-vector = { workspace = true, version = "0.1.0" }
|
||||
|
||||
# Core dependencies
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
# Tracing
|
||||
tracing = { workspace = true }
|
||||
|
||||
# ML and numerical computing
|
||||
ndarray = { workspace = true }
|
||||
ndarray-rand = "0.14"
|
||||
|
||||
# Graph operations
|
||||
petgraph = "0.6"
|
||||
|
||||
# Parallel processing
|
||||
rayon = "1.10"
|
||||
|
||||
# Random number generation
|
||||
rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
proptest = { workspace = true }
|
||||
criterion = { workspace = true }
|
||||
test-case = { workspace = true }
|
||||
tokio = { workspace = true, features = ["test-util", "macros"] }
|
||||
approx = "0.5"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
[[bench]]
|
||||
name = "gnn_benchmark"
|
||||
harness = false
|
||||
416
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/README.md
vendored
Normal file
416
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/README.md
vendored
Normal file
@@ -0,0 +1,416 @@
|
||||
# sevensense-learning
|
||||
|
||||
[](https://crates.io/crates/sevensense-learning)
|
||||
[](https://docs.rs/sevensense-learning)
|
||||
[](../../LICENSE)
|
||||
|
||||
> Graph Neural Network (GNN) learning for bioacoustic pattern discovery.
|
||||
|
||||
**sevensense-learning** implements online learning algorithms that discover patterns in bird vocalizations over time. Using Graph Neural Networks with Elastic Weight Consolidation (EWC), it learns species-specific call patterns, dialect variations, and behavioral signatures without forgetting previously learned knowledge.
|
||||
|
||||
## Features
|
||||
|
||||
- **GNN Architecture**: Graph-based learning on similarity networks
|
||||
- **EWC Regularization**: Prevents catastrophic forgetting in online learning
|
||||
- **Online Updates**: Continuous learning from streaming data
|
||||
- **Transition Graphs**: Model sequential call patterns
|
||||
- **Fisher Information**: Importance-weighted parameter updates
|
||||
- **Gradient Checkpointing**: Memory-efficient training
|
||||
|
||||
## Use Cases
|
||||
|
||||
| Use Case | Description | Key Functions |
|
||||
|----------|-------------|---------------|
|
||||
| Pattern Learning | Learn call patterns | `train()`, `learn_patterns()` |
|
||||
| Online Updates | Incremental learning | `online_update()` |
|
||||
| Transition Modeling | Sequential patterns | `TransitionGraph::learn()` |
|
||||
| EWC Training | Continual learning | `ewc_train()` |
|
||||
| Inference | Pattern prediction | `predict()`, `infer()` |
|
||||
|
||||
## Installation
|
||||
|
||||
Add to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
sevensense-learning = "0.1"
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{GnnModel, GnnConfig, TransitionGraph};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Create GNN model
|
||||
let config = GnnConfig {
|
||||
hidden_dim: 256,
|
||||
num_layers: 3,
|
||||
dropout: 0.1,
|
||||
..Default::default()
|
||||
};
|
||||
let mut model = GnnModel::new(config);
|
||||
|
||||
// Build transition graph from embeddings
|
||||
let graph = TransitionGraph::from_embeddings(&embeddings, 0.8)?;
|
||||
|
||||
// Train the model
|
||||
model.train(&graph, &embeddings, 100)?; // 100 epochs
|
||||
|
||||
// Make predictions
|
||||
let prediction = model.predict(&query_embedding)?;
|
||||
println!("Predicted pattern: {:?}", prediction);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary><b>Tutorial: Building Transition Graphs</b></summary>
|
||||
|
||||
### Creating from Embeddings
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{TransitionGraph, GraphConfig};
|
||||
|
||||
// Config for graph construction
|
||||
let config = GraphConfig {
|
||||
similarity_threshold: 0.8, // Edge threshold
|
||||
max_neighbors: 10, // Max edges per node
|
||||
temporal_decay: 0.95, // Time-based edge weighting
|
||||
};
|
||||
|
||||
// Build graph from embeddings with timestamps
|
||||
let mut graph = TransitionGraph::new(config);
|
||||
for (id, embedding, timestamp) in recordings.iter() {
|
||||
graph.add_node(*id, embedding, *timestamp)?;
|
||||
}
|
||||
|
||||
// Automatically computes edges based on similarity
|
||||
graph.build_edges()?;
|
||||
|
||||
println!("Graph has {} nodes and {} edges",
|
||||
graph.node_count(),
|
||||
graph.edge_count());
|
||||
```
|
||||
|
||||
### Analyzing Graph Structure
|
||||
|
||||
```rust
|
||||
use sevensense_learning::TransitionGraph;
|
||||
|
||||
// Get neighbors for a node
|
||||
let neighbors = graph.neighbors(node_id)?;
|
||||
for (neighbor_id, weight) in neighbors {
|
||||
println!("Neighbor {}: weight {:.3}", neighbor_id, weight);
|
||||
}
|
||||
|
||||
// Compute graph statistics
|
||||
let stats = graph.statistics();
|
||||
println!("Average degree: {:.2}", stats.avg_degree);
|
||||
println!("Clustering coefficient: {:.3}", stats.clustering_coeff);
|
||||
println!("Connected components: {}", stats.num_components);
|
||||
```
|
||||
|
||||
### Sequential Patterns
|
||||
|
||||
```rust
|
||||
// Analyze sequential call patterns
|
||||
let sequences = graph.find_sequences(3)?; // min_length = 3
|
||||
|
||||
for seq in sequences {
|
||||
println!("Sequence: {:?}", seq.node_ids);
|
||||
println!(" Frequency: {}", seq.count);
|
||||
println!(" Avg interval: {:.2}s", seq.avg_interval);
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Tutorial: GNN Training</b></summary>
|
||||
|
||||
### Basic Training
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{GnnModel, GnnConfig, TrainingConfig};
|
||||
|
||||
let model_config = GnnConfig {
|
||||
input_dim: 1536, // Embedding dimension
|
||||
hidden_dim: 256, // Hidden layer size
|
||||
output_dim: 64, // Output embedding size
|
||||
num_layers: 3, // GNN layers
|
||||
dropout: 0.1,
|
||||
};
|
||||
|
||||
let mut model = GnnModel::new(model_config);
|
||||
|
||||
let train_config = TrainingConfig {
|
||||
epochs: 100,
|
||||
learning_rate: 0.001,
|
||||
batch_size: 32,
|
||||
early_stopping: Some(10), // Stop if no improvement for 10 epochs
|
||||
};
|
||||
|
||||
// Train on graph
|
||||
let history = model.train(&graph, &features, train_config)?;
|
||||
|
||||
println!("Final loss: {:.4}", history.final_loss);
|
||||
println!("Best epoch: {}", history.best_epoch);
|
||||
```
|
||||
|
||||
### Training with Validation
|
||||
|
||||
```rust
|
||||
let (train_graph, val_graph) = split_graph(&graph, 0.8)?;
|
||||
|
||||
let history = model.train_with_validation(
|
||||
&train_graph,
|
||||
&val_graph,
|
||||
&features,
|
||||
train_config,
|
||||
)?;
|
||||
|
||||
// Plot training curves
|
||||
for (epoch, train_loss, val_loss) in history.iter() {
|
||||
println!("Epoch {}: train={:.4}, val={:.4}", epoch, train_loss, val_loss);
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Loss Functions
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{GnnModel, LossFunction};
|
||||
|
||||
// Contrastive loss for similarity learning
|
||||
let loss_fn = LossFunction::Contrastive {
|
||||
margin: 0.5,
|
||||
positive_weight: 1.0,
|
||||
negative_weight: 0.5,
|
||||
};
|
||||
|
||||
model.set_loss_function(loss_fn);
|
||||
model.train(&graph, &features, config)?;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Tutorial: Elastic Weight Consolidation (EWC)</b></summary>
|
||||
|
||||
### Why EWC?
|
||||
|
||||
Standard neural networks suffer from "catastrophic forgetting"—learning new patterns erases old ones. EWC prevents this by protecting important parameters.
|
||||
|
||||
### EWC Training
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{GnnModel, EwcConfig};
|
||||
|
||||
let ewc_config = EwcConfig {
|
||||
lambda: 1000.0, // Regularization strength
|
||||
fisher_samples: 200, // Samples for Fisher estimation
|
||||
online: true, // Online EWC variant
|
||||
};
|
||||
|
||||
let mut model = GnnModel::new(model_config);
|
||||
|
||||
// Train on first dataset
|
||||
model.train(&graph1, &features1, train_config)?;
|
||||
|
||||
// Compute Fisher information (importance weights)
|
||||
model.compute_fisher(&graph1, &features1, ewc_config.fisher_samples)?;
|
||||
|
||||
// Train on second dataset with EWC
|
||||
model.ewc_train(&graph2, &features2, train_config, ewc_config)?;
|
||||
|
||||
// Model remembers patterns from both datasets!
|
||||
```
|
||||
|
||||
### Continual Learning Pipeline
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{ContinualLearner, EwcConfig};
|
||||
|
||||
let mut learner = ContinualLearner::new(model, EwcConfig::default());
|
||||
|
||||
// Learn from streaming data batches
|
||||
for batch in data_stream {
|
||||
let graph = TransitionGraph::from_batch(&batch)?;
|
||||
learner.learn(&graph, &batch.features)?;
|
||||
|
||||
println!("Learned batch {}, total patterns: {}",
|
||||
batch.id, learner.pattern_count());
|
||||
}
|
||||
|
||||
// Test on all historical patterns
|
||||
let recall = learner.evaluate_recall(&all_test_data)?;
|
||||
println!("Recall on all patterns: {:.2}%", recall * 100.0);
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Tutorial: Online Learning</b></summary>
|
||||
|
||||
### Incremental Updates
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{GnnModel, OnlineConfig};
|
||||
|
||||
let online_config = OnlineConfig {
|
||||
learning_rate: 0.0001, // Lower LR for stability
|
||||
momentum: 0.9,
|
||||
max_updates_per_sample: 5,
|
||||
replay_buffer_size: 1000,
|
||||
};
|
||||
|
||||
let mut model = GnnModel::new(model_config);
|
||||
model.enable_online_learning(online_config);
|
||||
|
||||
// Process streaming data
|
||||
for sample in stream {
|
||||
// Single-sample update
|
||||
model.online_update(&sample.embedding, &sample.label)?;
|
||||
|
||||
if model.updates_count() % 100 == 0 {
|
||||
println!("Processed {} samples", model.updates_count());
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Experience Replay
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{ReplayBuffer, GnnModel};
|
||||
|
||||
let mut buffer = ReplayBuffer::new(1000); // Store 1000 samples
|
||||
let mut model = GnnModel::new(config);
|
||||
|
||||
for sample in stream {
|
||||
// Add to replay buffer
|
||||
buffer.add(sample.clone());
|
||||
|
||||
// Train on current sample + replay
|
||||
let replay_batch = buffer.sample(32)?; // 32 random historical samples
|
||||
let batch = [vec![sample], replay_batch].concat();
|
||||
|
||||
model.train_batch(&batch)?;
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Tutorial: Pattern Prediction</b></summary>
|
||||
|
||||
### Predicting Similar Patterns
|
||||
|
||||
```rust
|
||||
use sevensense_learning::GnnModel;
|
||||
|
||||
let model = GnnModel::load("trained_model.bin")?;
|
||||
|
||||
// Get learned representation
|
||||
let embedding = model.encode(&query_features)?;
|
||||
|
||||
// Find similar learned patterns
|
||||
let similar = model.find_similar(&embedding, 10)?;
|
||||
|
||||
for (pattern_id, similarity) in similar {
|
||||
println!("Pattern {}: {:.3} similarity", pattern_id, similarity);
|
||||
}
|
||||
```
|
||||
|
||||
### Predicting Next Call
|
||||
|
||||
```rust
|
||||
// Given a sequence of calls, predict the next one
|
||||
let sequence = vec![embedding1, embedding2, embedding3];
|
||||
let prediction = model.predict_next(&sequence)?;
|
||||
|
||||
println!("Predicted next call embedding: {:?}", prediction.embedding);
|
||||
println!("Confidence: {:.3}", prediction.confidence);
|
||||
```
|
||||
|
||||
### Anomaly Detection
|
||||
|
||||
```rust
|
||||
use sevensense_learning::{GnnModel, AnomalyDetector};
|
||||
|
||||
let detector = AnomalyDetector::new(&model);
|
||||
|
||||
for embedding in embeddings {
|
||||
let score = detector.anomaly_score(&embedding)?;
|
||||
|
||||
if score > 0.95 {
|
||||
println!("Anomaly detected! Score: {:.3}", score);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### GnnConfig Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `input_dim` | 1536 | Input embedding dimension |
|
||||
| `hidden_dim` | 256 | Hidden layer dimension |
|
||||
| `output_dim` | 64 | Output dimension |
|
||||
| `num_layers` | 3 | Number of GNN layers |
|
||||
| `dropout` | 0.1 | Dropout rate |
|
||||
|
||||
### EwcConfig Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `lambda` | 1000.0 | Regularization strength |
|
||||
| `fisher_samples` | 200 | Samples for Fisher estimation |
|
||||
| `online` | true | Use online EWC variant |
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Input Embeddings (1536-dim)
|
||||
│
|
||||
▼
|
||||
┌─────────┐
|
||||
│ GNN │ ◄── Graph structure (adjacency)
|
||||
│ Layer 1 │
|
||||
└────┬────┘
|
||||
│
|
||||
┌────▼────┐
|
||||
│ GNN │
|
||||
│ Layer 2 │
|
||||
└────┬────┘
|
||||
│
|
||||
┌────▼────┐
|
||||
│ GNN │
|
||||
│ Layer 3 │
|
||||
└────┬────┘
|
||||
│
|
||||
▼
|
||||
Output Embeddings (64-dim)
|
||||
```
|
||||
|
||||
## Links
|
||||
|
||||
- **Homepage**: [ruv.io](https://ruv.io)
|
||||
- **Repository**: [github.com/ruvnet/ruvector](https://github.com/ruvnet/ruvector)
|
||||
- **Crates.io**: [crates.io/crates/sevensense-learning](https://crates.io/crates/sevensense-learning)
|
||||
- **Documentation**: [docs.rs/sevensense-learning](https://docs.rs/sevensense-learning)
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see [LICENSE](../../LICENSE) for details.
|
||||
|
||||
---
|
||||
|
||||
*Part of the [7sense Bioacoustic Intelligence Platform](https://ruv.io) by rUv*
|
||||
131
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/benches/gnn_benchmark.rs
vendored
Normal file
131
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/benches/gnn_benchmark.rs
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
//! Benchmarks for GNN operations.
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId};
|
||||
use ndarray::Array2;
|
||||
|
||||
// Note: These benchmarks require the crate to compile successfully.
|
||||
// They test the core GNN operations for performance regression.
|
||||
|
||||
fn create_test_features(n: usize, dim: usize) -> Array2<f32> {
|
||||
Array2::from_elem((n, dim), 0.5)
|
||||
}
|
||||
|
||||
fn create_test_adjacency(n: usize) -> Array2<f32> {
|
||||
let mut adj = Array2::<f32>::eye(n);
|
||||
// Add some random edges
|
||||
for i in 0..n.saturating_sub(1) {
|
||||
adj[[i, i + 1]] = 0.5;
|
||||
adj[[i + 1, i]] = 0.5;
|
||||
}
|
||||
adj
|
||||
}
|
||||
|
||||
fn benchmark_gcn_forward(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("gcn_forward");
|
||||
|
||||
for n in [10, 50, 100, 500].iter() {
|
||||
group.bench_with_input(BenchmarkId::from_parameter(n), n, |b, &n| {
|
||||
let features = create_test_features(n, 64);
|
||||
let adj = create_test_adjacency(n);
|
||||
|
||||
b.iter(|| {
|
||||
// Simple matrix multiplication to simulate GCN forward
|
||||
let aggregated = black_box(&adj).dot(black_box(&features));
|
||||
black_box(aggregated)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmarks a dot-product attention score computation: a query-key
/// matmul followed by a row-wise softmax (no sqrt(d) scaling here).
fn benchmark_attention_computation(c: &mut Criterion) {
    let mut group = c.benchmark_group("attention");

    for n in [10, 50, 100].iter() {
        group.bench_with_input(BenchmarkId::from_parameter(n), n, |b, &n| {
            let query = create_test_features(n, 64);
            let key = create_test_features(n, 64);

            b.iter(|| {
                // Raw attention scores: Q . K^T, giving an n x n matrix.
                let scores = black_box(&query).dot(&black_box(&key).t());

                // Row-wise softmax with the standard max-subtraction trick
                // for numerical stability.
                let mut result = scores.clone();
                for mut row in result.rows_mut() {
                    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let sum: f32 = row.iter().map(|x| (x - max).exp()).sum();
                    for x in row.iter_mut() {
                        *x = (*x - max).exp() / sum;
                    }
                }
                black_box(result)
            });
        });
    }

    group.finish();
}
|
||||
|
||||
fn benchmark_cosine_similarity(c: &mut Criterion) {
|
||||
c.bench_function("cosine_similarity_256d", |bencher| {
|
||||
let vec_a: Vec<f32> = (0..256).map(|i| (i as f32).sin()).collect();
|
||||
let vec_b: Vec<f32> = (0..256).map(|i| (i as f32).cos()).collect();
|
||||
|
||||
bencher.iter(|| {
|
||||
let dot: f32 = black_box(&vec_a).iter().zip(black_box(&vec_b)).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = vec_a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = vec_b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
black_box(dot / (norm_a * norm_b))
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmarks an InfoNCE contrastive loss over one anchor, one positive,
/// and 10 negatives (128-dim vectors, temperature 0.07).
fn benchmark_info_nce_loss(c: &mut Criterion) {
    c.bench_function("info_nce_10_negatives", |b| {
        // Deterministic synthetic vectors; the positive is a perturbed anchor.
        let anchor: Vec<f32> = (0..128).map(|i| (i as f32 * 0.01).sin()).collect();
        let positive: Vec<f32> = (0..128).map(|i| (i as f32 * 0.01).sin() + 0.1).collect();
        let negatives: Vec<Vec<f32>> = (0..10)
            .map(|j| (0..128).map(|i| ((i + j * 10) as f32 * 0.01).cos()).collect())
            .collect();
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();

        b.iter(|| {
            let temp = 0.07;

            // Temperature-scaled cosine similarity to the positive.
            let pos_sim = {
                let dot: f32 = black_box(&anchor).iter().zip(black_box(&positive)).map(|(x, y)| x * y).sum();
                let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_b: f32 = positive.iter().map(|x| x * x).sum::<f32>().sqrt();
                dot / (norm_a * norm_b) / temp
            };

            // Temperature-scaled cosine similarity to each negative.
            let neg_sims: Vec<f32> = neg_refs.iter().map(|neg| {
                let dot: f32 = anchor.iter().zip(*neg).map(|(x, y)| x * y).sum();
                let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_b: f32 = neg.iter().map(|x| x * x).sum::<f32>().sqrt();
                dot / (norm_a * norm_b) / temp
            }).collect();

            // Numerically stable log-sum-exp over {positive} U negatives;
            // InfoNCE loss = -pos_sim + logsumexp(all similarities).
            let max_sim = neg_sims.iter().chain(std::iter::once(&pos_sim))
                .cloned().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = std::iter::once(pos_sim).chain(neg_sims)
                .map(|s| (s - max_sim).exp()).sum();

            black_box(-pos_sim + max_sim + sum_exp.ln())
        });
    });
}
|
||||
|
||||
// Register every benchmark with criterion and generate the `main` harness
// entry point (Cargo.toml sets `harness = false` for this bench target).
criterion_group!(
    benches,
    benchmark_gcn_forward,
    benchmark_attention_computation,
    benchmark_cosine_similarity,
    benchmark_info_nce_loss,
);

criterion_main!(benches);
|
||||
6
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/application/mod.rs
vendored
Normal file
6
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/application/mod.rs
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
//! Application layer for the learning bounded context.
|
||||
//!
|
||||
//! Contains business logic, services, and use cases for
|
||||
//! GNN-based learning and embedding refinement.
|
||||
|
||||
pub mod services;
|
||||
829
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/application/services.rs
vendored
Normal file
829
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/application/services.rs
vendored
Normal file
@@ -0,0 +1,829 @@
|
||||
//! Learning service implementation.
|
||||
//!
|
||||
//! Provides the main application service for GNN-based learning,
|
||||
//! including training, embedding refinement, and edge prediction.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use ndarray::Array2;
|
||||
use rayon::prelude::*;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
use crate::domain::entities::{
|
||||
EmbeddingId, GnnModelType, LearningConfig, LearningSession, RefinedEmbedding,
|
||||
TrainingMetrics, TrainingStatus, TransitionGraph,
|
||||
};
|
||||
use crate::domain::repository::LearningRepository;
|
||||
use crate::ewc::{EwcRegularizer, EwcState};
|
||||
use crate::infrastructure::gnn_model::{GnnError, GnnModel};
|
||||
use crate::loss;
|
||||
|
||||
/// Error type for learning service operations.
///
/// Covers configuration, training, model, data, persistence, and
/// session-lifecycle failures. Repository and GNN errors convert
/// automatically via `#[from]`.
#[derive(Debug, thiserror::Error)]
pub enum LearningError {
    /// Invalid configuration
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// Training error
    #[error("Training error: {0}")]
    TrainingError(String),

    /// Model error
    #[error("Model error: {0}")]
    ModelError(String),

    /// Data error
    #[error("Data error: {0}")]
    DataError(String),

    /// Repository error (auto-converted via `#[from]`, so `?` works)
    #[error("Repository error: {0}")]
    RepositoryError(#[from] crate::domain::repository::RepositoryError),

    /// GNN model error (auto-converted via `#[from]`, so `?` works)
    #[error("GNN error: {0}")]
    GnnError(#[from] GnnError),

    /// Session not found
    #[error("Session not found: {0}")]
    SessionNotFound(String),

    /// Session already running
    #[error("A training session is already running")]
    SessionAlreadyRunning,

    /// Dimension mismatch between the configured and supplied embedding sizes
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// Empty graph
    #[error("Graph is empty or invalid")]
    EmptyGraph,

    /// Internal error
    #[error("Internal error: {0}")]
    Internal(String),
}
|
||||
|
||||
/// Result type for learning operations; shorthand for `Result<T, LearningError>`.
pub type LearningResult<T> = Result<T, LearningError>;
|
||||
|
||||
/// Main learning service for GNN-based embedding refinement.
///
/// This service manages:
/// - GNN model training on transition graphs
/// - Embedding refinement through message passing
/// - Edge prediction for relationship modeling
/// - Continual learning with EWC regularization
///
/// Shared state lives behind `tokio::sync::RwLock`, which is async-aware,
/// so guards may safely be held across `.await` points.
pub struct LearningService {
    /// The GNN model
    model: Arc<RwLock<GnnModel>>,
    /// Service configuration
    config: LearningConfig,
    /// EWC state for continual learning (`None` until initialized —
    /// presumably after Fisher information is computed; confirm upstream)
    ewc_state: Arc<RwLock<Option<EwcState>>>,
    /// Optional repository for persistence; when `None`, sessions and
    /// refined embeddings are kept in memory only
    repository: Option<Arc<dyn LearningRepository>>,
    /// Current active session
    current_session: Arc<RwLock<Option<LearningSession>>>,
}
|
||||
|
||||
impl LearningService {
|
||||
/// Create a new learning service with the given configuration
|
||||
#[must_use]
|
||||
pub fn new(config: LearningConfig) -> Self {
|
||||
let model = GnnModel::new(
|
||||
config.model_type,
|
||||
config.input_dim,
|
||||
config.output_dim,
|
||||
config.hyperparameters.num_layers,
|
||||
config.hyperparameters.hidden_dim,
|
||||
config.hyperparameters.num_heads,
|
||||
config.hyperparameters.dropout,
|
||||
);
|
||||
|
||||
Self {
|
||||
model: Arc::new(RwLock::new(model)),
|
||||
config,
|
||||
ewc_state: Arc::new(RwLock::new(None)),
|
||||
repository: None,
|
||||
current_session: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a learning service with a repository
|
||||
#[must_use]
|
||||
pub fn with_repository(mut self, repository: Arc<dyn LearningRepository>) -> Self {
|
||||
self.repository = Some(repository);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the current configuration (borrowed; clone if ownership is needed).
#[must_use]
pub fn config(&self) -> &LearningConfig {
    &self.config
}
|
||||
|
||||
/// Get the model type (returned by value; `GnnModelType` is `Copy` here,
/// since no clone is performed).
#[must_use]
pub fn model_type(&self) -> GnnModelType {
    self.config.model_type
}
|
||||
|
||||
/// Start a new training session
|
||||
#[instrument(skip(self), err)]
|
||||
pub async fn start_session(&self) -> LearningResult<String> {
|
||||
// Check if a session is already running
|
||||
{
|
||||
let session = self.current_session.read().await;
|
||||
if let Some(ref s) = *session {
|
||||
if s.status.is_active() {
|
||||
return Err(LearningError::SessionAlreadyRunning);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut session = LearningSession::new(self.config.clone());
|
||||
session.start();
|
||||
|
||||
let session_id = session.id.clone();
|
||||
|
||||
// Persist if repository available
|
||||
if let Some(ref repo) = self.repository {
|
||||
repo.save_session(&session).await?;
|
||||
}
|
||||
|
||||
*self.current_session.write().await = Some(session);
|
||||
|
||||
info!(session_id = %session_id, "Started new learning session");
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
/// Train a single epoch on the transition graph
///
/// Runs one forward pass, computes a contrastive loss, applies optional
/// gradient clipping, and performs a weight update, then records the
/// metrics on the active session.
///
/// # Arguments
/// * `graph` - The transition graph to train on
///
/// # Returns
/// Training metrics for the epoch
///
/// # Errors
/// Fails on an empty graph, an embedding-dimension mismatch, a missing
/// active session, or a model/repository error.
#[instrument(skip(self, graph), fields(nodes = graph.num_nodes(), edges = graph.num_edges()), err)]
pub async fn train_epoch(&self, graph: &TransitionGraph) -> LearningResult<TrainingMetrics> {
    let start_time = Instant::now();

    // Validate graph: must be non-empty and match the configured input dim.
    if graph.num_nodes() == 0 {
        return Err(LearningError::EmptyGraph);
    }

    if let Some(dim) = graph.embedding_dim() {
        if dim != self.config.input_dim {
            return Err(LearningError::DimensionMismatch {
                expected: self.config.input_dim,
                actual: dim,
            });
        }
    }

    // Ensure we have an active session. The write guard is held for the
    // rest of the epoch, serializing metric updates across callers.
    let mut session_guard = self.current_session.write().await;
    let session = session_guard
        .as_mut()
        .ok_or_else(|| LearningError::TrainingError("No active session".to_string()))?;

    let current_epoch = session.metrics.epoch + 1;
    let lr = self.compute_learning_rate(current_epoch);

    // Build adjacency matrix
    let adj_matrix = self.build_adjacency_matrix(graph);

    // Build feature matrix from embeddings
    let features = self.build_feature_matrix(graph);

    // Forward pass through GNN (model write lock held through the update).
    let mut model = self.model.write().await;
    let output = model.forward(&features, &adj_matrix)?;

    // Compute loss using contrastive learning
    let (loss, accuracy) = self.compute_loss(graph, &output).await?;

    // Compute gradients and their norm (reported in the metrics below).
    let gradients = self.compute_gradients(graph, &features, &output, &adj_matrix, &model)?;
    let grad_norm = self.compute_gradient_norm(&gradients);

    // Apply gradient clipping if configured
    let clipped_gradients = if let Some(clip_value) = self.config.hyperparameters.gradient_clip {
        self.clip_gradients(gradients, clip_value)
    } else {
        gradients
    };

    // Update model weights (learning rate + optional weight decay).
    model.update_weights(&clipped_gradients, lr, self.config.hyperparameters.weight_decay);

    // NOTE(review): the EWC penalty is computed and logged here but is not
    // folded into the gradients or the weight update above — confirm
    // whether the regularization is meant to affect training or is
    // diagnostic-only at this point.
    if let Some(ref ewc_state) = *self.ewc_state.read().await {
        let ewc_reg = EwcRegularizer::new(self.config.hyperparameters.ewc_lambda);
        let ewc_loss = ewc_reg.compute_penalty(&model, ewc_state);
        debug!(ewc_loss = ewc_loss, "Applied EWC regularization");
    }

    let epoch_time_ms = start_time.elapsed().as_millis() as u64;

    let metrics = TrainingMetrics {
        loss,
        accuracy,
        epoch: current_epoch,
        learning_rate: lr,
        validation_loss: None,
        validation_accuracy: None,
        gradient_norm: Some(grad_norm),
        epoch_time_ms,
        custom_metrics: Default::default(),
    };

    // Update session metrics
    session.update_metrics(metrics.clone());

    // Persist session if repository available
    drop(model); // Release the model write lock before the async repository call
    if let Some(ref repo) = self.repository {
        repo.update_session(session).await?;
    }

    info!(
        epoch = current_epoch,
        loss = loss,
        accuracy = accuracy,
        time_ms = epoch_time_ms,
        "Completed training epoch"
    );

    Ok(metrics)
}
|
||||
|
||||
/// Refine embeddings using the trained GNN model
///
/// Builds an ad-hoc similarity graph over the inputs (cosine similarity
/// above 0.5 creates a symmetric edge), runs one GNN forward pass over it,
/// and scores each refined vector by how little it moved.
///
/// # Arguments
/// * `embeddings` - Input embeddings to refine
///
/// # Returns
/// Refined embeddings with quality scores
///
/// # Errors
/// Fails on an input-dimension mismatch or a model/repository error.
#[instrument(skip(self, embeddings), fields(count = embeddings.len()), err)]
pub async fn refine_embeddings(
    &self,
    embeddings: &[(EmbeddingId, Vec<f32>)],
) -> LearningResult<Vec<RefinedEmbedding>> {
    if embeddings.is_empty() {
        return Ok(Vec::new());
    }

    // Validate dimensions. Only the first entry is checked; the rest are
    // assumed to have the same length — TODO confirm callers guarantee this.
    if let Some((_, emb)) = embeddings.first() {
        if emb.len() != self.config.input_dim {
            return Err(LearningError::DimensionMismatch {
                expected: self.config.input_dim,
                actual: emb.len(),
            });
        }
    }

    let model = self.model.read().await;

    // Build a simple graph where each embedding is a node
    // Connected based on cosine similarity
    let n = embeddings.len();
    let similarity_threshold = 0.5;

    // Build feature matrix (one row per input embedding).
    let mut features = Array2::zeros((n, self.config.input_dim));
    for (i, (_, emb)) in embeddings.iter().enumerate() {
        for (j, &val) in emb.iter().enumerate() {
            features[[i, j]] = val;
        }
    }

    // Build adjacency matrix based on pairwise similarity. O(n^2) pairs —
    // acceptable for batch-sized inputs.
    let mut adj_matrix = Array2::<f32>::eye(n);
    for i in 0..n {
        for j in (i + 1)..n {
            let sim = cosine_similarity(&embeddings[i].1, &embeddings[j].1);
            if sim > similarity_threshold {
                adj_matrix[[i, j]] = sim;
                adj_matrix[[j, i]] = sim;
            }
        }
    }

    // Symmetric normalization: A[i][j] /= sqrt(deg(i) * deg(j)), as in GCN
    // preprocessing. Rows with zero degree are left untouched.
    let degrees: Vec<f32> = (0..n)
        .map(|i| adj_matrix.row(i).sum())
        .collect();
    for i in 0..n {
        for j in 0..n {
            if degrees[i] > 0.0 && degrees[j] > 0.0 {
                adj_matrix[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
            }
        }
    }

    // Forward pass
    let output = model.forward(&features, &adj_matrix)?;

    // Tag each refined embedding with the current session id, if any.
    let session_id = self
        .current_session
        .read()
        .await
        .as_ref()
        .map(|s| s.id.clone());

    // Score refinements in parallel: score = 1 / (1 + ||refined - original||).
    let refined: Vec<RefinedEmbedding> = embeddings
        .par_iter()
        .enumerate()
        .map(|(i, (id, original))| {
            let refined_vec: Vec<f32> = output.row(i).to_vec();

            // Compute refinement score based on change magnitude.
            // NOTE(review): `original` (input_dim) and `refined_vec`
            // (output_dim) may differ in length; `zip` silently truncates
            // to the shorter — confirm this is intended.
            let delta = original
                .iter()
                .zip(&refined_vec)
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();

            let score = 1.0 / (1.0 + delta); // Higher score for smaller changes

            let mut refined = RefinedEmbedding::new(id.clone(), refined_vec, score);
            refined.session_id = session_id.clone();
            refined.delta_norm = Some(delta);
            refined.normalize();
            refined
        })
        .collect();

    info!(count = refined.len(), "Refined embeddings");

    // Persist if repository available
    if let Some(ref repo) = self.repository {
        repo.save_refined_embeddings(&refined).await?;
    }

    Ok(refined)
}
|
||||
|
||||
/// Predict edge weight between two embeddings
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `from` - Source embedding
|
||||
/// * `to` - Target embedding
|
||||
///
|
||||
/// # Returns
|
||||
/// Predicted edge weight (0.0 to 1.0)
|
||||
#[instrument(skip(self, from, to), err)]
|
||||
pub async fn predict_edge(&self, from: &[f32], to: &[f32]) -> LearningResult<f32> {
|
||||
// Validate dimensions
|
||||
if from.len() != self.config.input_dim {
|
||||
return Err(LearningError::DimensionMismatch {
|
||||
expected: self.config.input_dim,
|
||||
actual: from.len(),
|
||||
});
|
||||
}
|
||||
if to.len() != self.config.input_dim {
|
||||
return Err(LearningError::DimensionMismatch {
|
||||
expected: self.config.input_dim,
|
||||
actual: to.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let model = self.model.read().await;
|
||||
|
||||
// Create a mini-graph with two nodes
|
||||
let mut features = Array2::zeros((2, self.config.input_dim));
|
||||
for (j, &val) in from.iter().enumerate() {
|
||||
features[[0, j]] = val;
|
||||
}
|
||||
for (j, &val) in to.iter().enumerate() {
|
||||
features[[1, j]] = val;
|
||||
}
|
||||
|
||||
// Simple adjacency (self-loops only initially)
|
||||
let adj_matrix = Array2::<f32>::eye(2);
|
||||
|
||||
// Forward pass
|
||||
let output = model.forward(&features, &adj_matrix)?;
|
||||
|
||||
// Compute similarity of refined embeddings
|
||||
let from_refined: Vec<f32> = output.row(0).to_vec();
|
||||
let to_refined: Vec<f32> = output.row(1).to_vec();
|
||||
|
||||
let similarity = cosine_similarity(&from_refined, &to_refined);
|
||||
let weight = (similarity + 1.0) / 2.0; // Map from [-1, 1] to [0, 1]
|
||||
|
||||
Ok(weight)
|
||||
}
|
||||
|
||||
/// Complete the current training session
|
||||
#[instrument(skip(self), err)]
|
||||
pub async fn complete_session(&self) -> LearningResult<()> {
|
||||
let mut session_guard = self.current_session.write().await;
|
||||
|
||||
if let Some(ref mut session) = *session_guard {
|
||||
session.complete();
|
||||
|
||||
// Compute and store Fisher information for EWC
|
||||
// This would be done in a real implementation with the final model state
|
||||
|
||||
if let Some(ref repo) = self.repository {
|
||||
repo.update_session(session).await?;
|
||||
}
|
||||
|
||||
info!(session_id = %session.id, "Completed learning session");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fail the current session with an error
|
||||
#[instrument(skip(self, error), err)]
|
||||
pub async fn fail_session(&self, error: impl Into<String>) -> LearningResult<()> {
|
||||
let error_msg = error.into();
|
||||
let mut session_guard = self.current_session.write().await;
|
||||
|
||||
if let Some(ref mut session) = *session_guard {
|
||||
session.fail(&error_msg);
|
||||
|
||||
if let Some(ref repo) = self.repository {
|
||||
repo.update_session(session).await?;
|
||||
}
|
||||
|
||||
warn!(session_id = %session.id, error = %error_msg, "Failed learning session");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the current session status
///
/// Returns a cloned snapshot of the active `LearningSession`, or `None` when
/// no session has been started. Holds the read lock only for the clone.
pub async fn get_session(&self) -> Option<LearningSession> {
    self.current_session.read().await.clone()
}
|
||||
|
||||
/// Save EWC state from current model for future regularization
|
||||
#[instrument(skip(self, graph), err)]
|
||||
pub async fn consolidate_ewc(&self, graph: &TransitionGraph) -> LearningResult<()> {
|
||||
let model = self.model.read().await;
|
||||
let fisher = self.compute_fisher_information(&model, graph)?;
|
||||
let state = EwcState::new(model.get_parameters(), fisher);
|
||||
|
||||
*self.ewc_state.write().await = Some(state);
|
||||
|
||||
info!("Consolidated EWC state");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// =========== Private Helper Methods ===========
|
||||
|
||||
fn build_adjacency_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
|
||||
let n = graph.num_nodes();
|
||||
let mut adj = Array2::zeros((n, n));
|
||||
|
||||
// Add self-loops
|
||||
for i in 0..n {
|
||||
adj[[i, i]] = 1.0;
|
||||
}
|
||||
|
||||
// Add edges
|
||||
for &(from, to, weight) in &graph.edges {
|
||||
adj[[from, to]] = weight;
|
||||
if !graph.directed {
|
||||
adj[[to, from]] = weight;
|
||||
}
|
||||
}
|
||||
|
||||
// Symmetric normalization: D^(-1/2) * A * D^(-1/2)
|
||||
let degrees: Vec<f32> = (0..n).map(|i| adj.row(i).sum()).collect();
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if degrees[i] > 0.0 && degrees[j] > 0.0 {
|
||||
adj[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
adj
|
||||
}
|
||||
|
||||
fn build_feature_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
|
||||
let n = graph.num_nodes();
|
||||
let dim = graph.embedding_dim().unwrap_or(self.config.input_dim);
|
||||
let mut features = Array2::zeros((n, dim));
|
||||
|
||||
for (i, emb) in graph.embeddings.iter().enumerate() {
|
||||
for (j, &val) in emb.iter().enumerate() {
|
||||
features[[i, j]] = val;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Compute (average contrastive loss, link-prediction accuracy) over all edges.
///
/// For each edge `(from, to, weight)`: the refined `from` row is the anchor,
/// the `to` row the positive, and up to `negative_ratio` other nodes the
/// negatives for an InfoNCE loss, weighted by the edge weight. Accuracy counts
/// edges whose positive is strictly more cosine-similar to the anchor than
/// every other non-endpoint node.
///
/// Returns `(0.0, 0.0)` for an empty graph (no nodes, or no edges).
async fn compute_loss(
    &self,
    graph: &TransitionGraph,
    output: &Array2<f32>,
) -> LearningResult<(f32, f32)> {
    let n = graph.num_nodes();
    if n == 0 {
        return Ok((0.0, 0.0));
    }

    let mut total_loss = 0.0;
    let mut correct = 0usize;
    let mut total = 0usize;

    let hp = &self.config.hyperparameters;

    // For each edge, compute contrastive loss
    for &(from, to, weight) in &graph.edges {
        let anchor: Vec<f32> = output.row(from).to_vec();
        let positive: Vec<f32> = output.row(to).to_vec();

        // Sample negative nodes (nodes not connected to anchor)
        // NOTE(review): despite "sample", this deterministically takes the
        // first `negative_ratio` indices that are not an endpoint; it may
        // include actual neighbors of `from` — confirm this is intended.
        let negatives: Vec<Vec<f32>> = (0..n)
            .filter(|&i| i != from && i != to)
            .take(hp.negative_ratio)
            .map(|i| output.row(i).to_vec())
            .collect();

        if !negatives.is_empty() {
            let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();

            // InfoNCE loss, scaled by the edge weight.
            let loss = loss::info_nce_loss(&anchor, &positive, &neg_refs, hp.temperature);
            total_loss += loss * weight;
        }

        // Compute accuracy based on whether positive is closer than negatives
        // (here: than *every* other node, not just the sampled negatives).
        let pos_sim = cosine_similarity(&anchor, &positive);
        let all_closer = (0..n)
            .filter(|&i| i != from && i != to)
            .all(|i| {
                let neg: Vec<f32> = output.row(i).to_vec();
                cosine_similarity(&anchor, &neg) < pos_sim
            });

        if all_closer {
            correct += 1;
        }
        total += 1;
    }

    // Average over edges; guard the empty-edge case to avoid 0/0.
    let avg_loss = if graph.edges.is_empty() {
        0.0
    } else {
        total_loss / graph.edges.len() as f32
    };

    let accuracy = if total == 0 {
        0.0
    } else {
        correct as f32 / total as f32
    };

    Ok((avg_loss, accuracy))
}
|
||||
|
||||
/// Approximate per-layer gradients for the model.
///
/// This is an explicit placeholder, not backpropagation: layer 0 receives
/// `features^T * centered_output / batch_size` (sliced or zero-padded to the
/// layer's shape), while deeper layers receive small constant matrices scaled
/// by the output's standard deviation. Each returned matrix is transposed to
/// `(out_dim, in_dim)` to match the model's weight layout.
fn compute_gradients(
    &self,
    _graph: &TransitionGraph,
    features: &Array2<f32>,
    output: &Array2<f32>,
    _adj_matrix: &Array2<f32>,
    model: &GnnModel,
) -> LearningResult<Vec<Array2<f32>>> {
    // Simplified gradient computation
    // In practice, this would use automatic differentiation
    let num_layers = model.num_layers();
    let mut gradients = Vec::with_capacity(num_layers);
    let batch_size = features.nrows() as f32;

    for layer_idx in 0..num_layers {
        let (in_dim, out_dim) = model.layer_dims(layer_idx);

        // Compute gradient approximation based on output variance
        // This is a simplified placeholder - real backprop would use chain rule
        let output_centered = &output.mapv(|x| x - output.mean().unwrap_or(0.0));

        // Approximate gradient as outer product scaled by learning signal
        let grad = if layer_idx == 0 {
            // Input layer: gradient is features^T * output_signal / batch_size
            // Slice down to the layer's shape; fall back to zeros when the
            // actual matrices are narrower than the layer expects.
            let output_slice = if output.ncols() >= out_dim {
                output_centered.slice(ndarray::s![.., ..out_dim]).to_owned()
            } else {
                Array2::zeros((output.nrows(), out_dim))
            };
            let feat_slice = if features.ncols() >= in_dim {
                features.slice(ndarray::s![.., ..in_dim]).to_owned()
            } else {
                Array2::zeros((features.nrows(), in_dim))
            };
            feat_slice.t().dot(&output_slice) / batch_size
        } else {
            // Hidden layers: use small random gradients scaled by output variance
            // (actually constant, not random — every entry is the same value)
            let variance = output.var(0.0);
            Array2::from_elem((in_dim, out_dim), 0.01 * variance.sqrt())
        };

        // Reshape to (out_dim, in_dim)
        let scaled_grad = grad.t().to_owned();
        gradients.push(scaled_grad);
    }

    Ok(gradients)
}
|
||||
|
||||
/// Global L2 norm over all gradient matrices, flattened together.
fn compute_gradient_norm(&self, gradients: &[Array2<f32>]) -> f32 {
    let sum_of_squares: f32 = gradients
        .iter()
        .flat_map(|g| g.iter())
        .map(|&x| x * x)
        .sum();
    sum_of_squares.sqrt()
}
|
||||
|
||||
/// Scale all gradients down so their global L2 norm is at most `max_norm`.
///
/// Gradients already within the bound are returned untouched.
fn clip_gradients(&self, gradients: Vec<Array2<f32>>, max_norm: f32) -> Vec<Array2<f32>> {
    let norm = self.compute_gradient_norm(&gradients);
    if norm > max_norm {
        let scale = max_norm / norm;
        gradients.into_iter().map(|g| g * scale).collect()
    } else {
        gradients
    }
}
|
||||
|
||||
fn compute_learning_rate(&self, epoch: usize) -> f32 {
|
||||
let base_lr = self.config.hyperparameters.learning_rate;
|
||||
let total_epochs = self.config.hyperparameters.epochs;
|
||||
|
||||
// Cosine annealing schedule
|
||||
let progress = epoch as f32 / total_epochs as f32;
|
||||
let cosine_factor = (1.0 + (progress * std::f32::consts::PI).cos()) / 2.0;
|
||||
|
||||
base_lr * cosine_factor
|
||||
}
|
||||
|
||||
/// Compute Fisher information for EWC regularization (currently a stub).
///
/// Always returns `FisherInformation::default()`; a real implementation would
/// estimate the diagonal of the Fisher matrix from model gradients on the
/// given graph. Kept fallible so the signature is stable for when that lands.
fn compute_fisher_information(
    &self,
    _model: &GnnModel,
    _graph: &TransitionGraph,
) -> LearningResult<crate::ewc::FisherInformation> {
    // Simplified Fisher information computation
    // In practice, this would compute the diagonal of the Fisher matrix
    Ok(crate::ewc::FisherInformation::default())
}
|
||||
}
|
||||
|
||||
/// Compute cosine similarity between two vectors
///
/// Returns 0.0 when the vectors differ in length, are empty, or either has a
/// near-zero norm (avoiding a division by zero).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: accumulate dot product and both squared norms together.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Pure-function check: identical, orthogonal and opposite vectors.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c)).abs() < 1e-6);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) + 1.0).abs() < 1e-6);
    }

    // Defaults: GCN model type with 768-dim input.
    #[tokio::test]
    async fn test_learning_service_creation() {
        let config = LearningConfig::default();
        let service = LearningService::new(config.clone());

        assert_eq!(service.model_type(), GnnModelType::Gcn);
        assert_eq!(service.config().input_dim, 768);
    }

    // A started session is retrievable and Running.
    #[tokio::test]
    async fn test_start_session() {
        let config = LearningConfig::default();
        let service = LearningService::new(config);

        let session_id = service.start_session().await.unwrap();
        assert!(!session_id.is_empty());

        let session = service.get_session().await.unwrap();
        assert_eq!(session.status, TrainingStatus::Running);
    }

    // One epoch over a tiny 3-node / 2-edge graph produces sane metrics.
    // Small dims keep the test fast.
    #[tokio::test]
    async fn test_train_epoch() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;

        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let mut graph = TransitionGraph::new();
        graph.add_node(EmbeddingId::new("n1"), vec![0.1; 8], None);
        graph.add_node(EmbeddingId::new("n2"), vec![0.2; 8], None);
        graph.add_node(EmbeddingId::new("n3"), vec![0.3; 8], None);
        graph.add_edge(0, 1, 0.8);
        graph.add_edge(1, 2, 0.7);

        let metrics = service.train_epoch(&graph).await.unwrap();
        assert_eq!(metrics.epoch, 1);
        assert!(metrics.loss >= 0.0);
    }

    // Refinement preserves count and projects to the output dimension.
    #[tokio::test]
    async fn test_refine_embeddings() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;

        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let embeddings = vec![
            (EmbeddingId::new("e1"), vec![0.1; 8]),
            (EmbeddingId::new("e2"), vec![0.2; 8]),
        ];

        let refined = service.refine_embeddings(&embeddings).await.unwrap();
        assert_eq!(refined.len(), 2);
        assert_eq!(refined[0].dim(), 4); // Output dimension
    }

    // Predicted edge weight is always mapped into [0, 1].
    #[tokio::test]
    async fn test_predict_edge() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;

        let service = LearningService::new(config);

        let from = vec![0.1; 8];
        let to = vec![0.1; 8]; // Same embedding should have high weight

        let weight = service.predict_edge(&from, &to).await.unwrap();
        assert!(weight >= 0.0 && weight <= 1.0);
    }

    // Training on a graph with no nodes must fail with EmptyGraph.
    #[tokio::test]
    async fn test_empty_graph_error() {
        let config = LearningConfig::default();
        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let graph = TransitionGraph::new();
        let result = service.train_epoch(&graph).await;

        assert!(matches!(result, Err(LearningError::EmptyGraph)));
    }

    // Node embeddings narrower than `input_dim` are rejected up front.
    #[tokio::test]
    async fn test_dimension_mismatch() {
        let mut config = LearningConfig::default();
        config.input_dim = 768;

        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let mut graph = TransitionGraph::new();
        graph.add_node(EmbeddingId::new("n1"), vec![0.1; 128], None); // Wrong dimension

        let result = service.train_epoch(&graph).await;
        assert!(matches!(
            result,
            Err(LearningError::DimensionMismatch { .. })
        ));
    }
}
|
||||
836
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/domain/entities.rs
vendored
Normal file
836
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/domain/entities.rs
vendored
Normal file
@@ -0,0 +1,836 @@
|
||||
//! Domain entities for the learning bounded context.
|
||||
//!
|
||||
//! This module defines the core domain entities including:
|
||||
//! - Learning sessions for tracking training state
|
||||
//! - GNN model types and training metrics
|
||||
//! - Transition graphs for embedding relationships
|
||||
//! - Refined embeddings as output of the learning process
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for an embedding
///
/// Thin newtype over `String`; equality and hashing follow the inner string.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(pub String);

impl EmbeddingId {
    /// Create a new embedding ID from anything convertible to a `String`
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Generate a new random embedding ID
    ///
    /// Uses a UUID v4 rendered as its canonical hyphenated string.
    #[must_use]
    pub fn generate() -> Self {
        Self(Uuid::new_v4().to_string())
    }

    /// Get the inner string value as a borrowed `&str`
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl From<String> for EmbeddingId {
    fn from(s: String) -> Self {
        Self(s)
    }
}

impl From<&str> for EmbeddingId {
    fn from(s: &str) -> Self {
        Self(s.to_string())
    }
}
|
||||
|
||||
/// Timestamp type alias for consistency
|
||||
pub type Timestamp = DateTime<Utc>;
|
||||
|
||||
/// Types of GNN models supported by the learning system
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum GnnModelType {
|
||||
/// Graph Convolutional Network
|
||||
/// Uses spectral convolutions on graph-structured data
|
||||
Gcn,
|
||||
/// GraphSAGE (SAmple and aggreGatE)
|
||||
/// Learns node embeddings through neighborhood sampling and aggregation
|
||||
GraphSage,
|
||||
/// Graph Attention Network
|
||||
/// Uses attention mechanisms to weight neighbor contributions
|
||||
Gat,
|
||||
}
|
||||
|
||||
impl Default for GnnModelType {
|
||||
fn default() -> Self {
|
||||
Self::Gcn
|
||||
}
|
||||
}
|
||||
|
||||
impl GnnModelType {
|
||||
/// Get the number of learnable parameters per layer (approximate)
|
||||
#[must_use]
|
||||
pub fn params_per_layer(&self, input_dim: usize, output_dim: usize) -> usize {
|
||||
match self {
|
||||
Self::Gcn => input_dim * output_dim + output_dim,
|
||||
Self::GraphSage => 2 * input_dim * output_dim + output_dim,
|
||||
Self::Gat => input_dim * output_dim + 2 * output_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recommended number of attention heads (only relevant for GAT)
|
||||
#[must_use]
|
||||
pub fn recommended_heads(&self) -> usize {
|
||||
match self {
|
||||
Self::Gat => 8,
|
||||
_ => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Status of a training session
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TrainingStatus {
|
||||
/// Session created but training not started
|
||||
Pending,
|
||||
/// Training is currently running
|
||||
Running,
|
||||
/// Training completed successfully
|
||||
Completed,
|
||||
/// Training failed with an error
|
||||
Failed,
|
||||
/// Training was paused
|
||||
Paused,
|
||||
/// Training was cancelled by user
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
impl TrainingStatus {
|
||||
/// Check if the status represents a terminal state
|
||||
#[must_use]
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
|
||||
}
|
||||
|
||||
/// Check if training can be resumed from this status
|
||||
#[must_use]
|
||||
pub fn can_resume(&self) -> bool {
|
||||
matches!(self, Self::Paused)
|
||||
}
|
||||
|
||||
/// Check if training is active
|
||||
#[must_use]
|
||||
pub fn is_active(&self) -> bool {
|
||||
matches!(self, Self::Running)
|
||||
}
|
||||
}
|
||||
|
||||
/// Metrics collected during training
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TrainingMetrics {
|
||||
/// Current loss value
|
||||
pub loss: f32,
|
||||
/// Training accuracy (0.0 to 1.0)
|
||||
pub accuracy: f32,
|
||||
/// Current epoch number
|
||||
pub epoch: usize,
|
||||
/// Current learning rate
|
||||
pub learning_rate: f32,
|
||||
/// Validation loss (if validation set provided)
|
||||
pub validation_loss: Option<f32>,
|
||||
/// Validation accuracy
|
||||
pub validation_accuracy: Option<f32>,
|
||||
/// Gradient norm (for monitoring stability)
|
||||
pub gradient_norm: Option<f32>,
|
||||
/// Time taken for this epoch in milliseconds
|
||||
pub epoch_time_ms: u64,
|
||||
/// Additional custom metrics
|
||||
#[serde(default)]
|
||||
pub custom_metrics: HashMap<String, f32>,
|
||||
}
|
||||
|
||||
impl Default for TrainingMetrics {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
loss: f32::INFINITY,
|
||||
accuracy: 0.0,
|
||||
epoch: 0,
|
||||
learning_rate: 0.001,
|
||||
validation_loss: None,
|
||||
validation_accuracy: None,
|
||||
gradient_norm: None,
|
||||
epoch_time_ms: 0,
|
||||
custom_metrics: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TrainingMetrics {
|
||||
/// Create new metrics for an epoch
|
||||
#[must_use]
|
||||
pub fn new(epoch: usize, loss: f32, accuracy: f32, learning_rate: f32) -> Self {
|
||||
Self {
|
||||
loss,
|
||||
accuracy,
|
||||
epoch,
|
||||
learning_rate,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set validation metrics
|
||||
#[must_use]
|
||||
pub fn with_validation(mut self, loss: f32, accuracy: f32) -> Self {
|
||||
self.validation_loss = Some(loss);
|
||||
self.validation_accuracy = Some(accuracy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a custom metric
|
||||
pub fn add_custom_metric(&mut self, name: impl Into<String>, value: f32) {
|
||||
self.custom_metrics.insert(name.into(), value);
|
||||
}
|
||||
|
||||
/// Check if training is converging (loss is decreasing)
|
||||
#[must_use]
|
||||
pub fn is_improving(&self, previous: &Self) -> bool {
|
||||
self.loss < previous.loss
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperparameters for training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperParameters {
    /// Initial learning rate
    pub learning_rate: f32,
    /// Weight decay (L2 regularization)
    pub weight_decay: f32,
    /// Dropout probability
    pub dropout: f32,
    /// Number of training epochs
    pub epochs: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Early stopping patience (epochs without improvement); `None` disables it
    pub early_stopping_patience: Option<usize>,
    /// Gradient clipping threshold (global L2 norm); `None` disables clipping
    pub gradient_clip: Option<f32>,
    /// Temperature for contrastive loss
    pub temperature: f32,
    /// Margin for triplet loss
    pub triplet_margin: f32,
    /// EWC lambda (importance of old task knowledge)
    pub ewc_lambda: f32,
    /// Number of GNN layers
    pub num_layers: usize,
    /// Hidden dimension size
    pub hidden_dim: usize,
    /// Number of attention heads (for GAT)
    pub num_heads: usize,
    /// Negative sample ratio for contrastive learning
    pub negative_ratio: usize,
}

impl Default for HyperParameters {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            weight_decay: 5e-4,
            dropout: 0.5,
            epochs: 200,
            batch_size: 32,
            early_stopping_patience: Some(20),
            gradient_clip: Some(1.0),
            // 0.07 is the SimCLR-style softmax temperature
            temperature: 0.07,
            triplet_margin: 1.0,
            ewc_lambda: 5000.0,
            num_layers: 2,
            hidden_dim: 256,
            num_heads: 8,
            negative_ratio: 5,
        }
    }
}
|
||||
|
||||
/// Configuration for the learning service
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningConfig {
    /// Type of GNN model to use
    pub model_type: GnnModelType,
    /// Input embedding dimension
    pub input_dim: usize,
    /// Output embedding dimension
    pub output_dim: usize,
    /// Training hyperparameters
    pub hyperparameters: HyperParameters,
    /// Enable mixed precision training
    pub mixed_precision: bool,
    /// Device to use for training
    pub device: Device,
    /// Random seed for reproducibility; `None` means non-deterministic runs
    pub seed: Option<u64>,
    /// Enable gradient checkpointing to save memory
    pub gradient_checkpointing: bool,
}

impl Default for LearningConfig {
    // Defaults: GCN, 768 -> 256 projection, CPU, no seed.
    fn default() -> Self {
        Self {
            model_type: GnnModelType::Gcn,
            input_dim: 768,
            output_dim: 256,
            hyperparameters: HyperParameters::default(),
            mixed_precision: false,
            device: Device::Cpu,
            seed: None,
            gradient_checkpointing: false,
        }
    }
}
|
||||
|
||||
/// Device for computation
///
/// NOTE(review): with serde's external tagging, unit variants serialize as
/// plain strings ("cpu", "metal") while `Cuda(usize)` serializes as a map
/// (e.g. `{"cuda": 0}`) — confirm consumers handle the mixed shapes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum Device {
    /// CPU computation
    #[default]
    Cpu,
    /// CUDA GPU computation (payload is presumably the device index — verify)
    Cuda(usize),
    /// Metal GPU (Apple Silicon)
    Metal,
}
|
||||
|
||||
/// A learning session tracking the state of a training run
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSession {
    /// Unique session identifier (UUID v4 string)
    pub id: String,
    /// Type of GNN model being trained
    pub model_type: GnnModelType,
    /// Current training status
    pub status: TrainingStatus,
    /// Current (most recent epoch's) training metrics
    pub metrics: TrainingMetrics,
    /// When the session was started
    pub started_at: Timestamp,
    /// When the session was last updated
    pub updated_at: Timestamp,
    /// When the session completed (if applicable)
    pub completed_at: Option<Timestamp>,
    /// Configuration used for this session
    pub config: LearningConfig,
    /// History of metrics per epoch
    #[serde(default)]
    pub metrics_history: Vec<TrainingMetrics>,
    /// Best (lowest-loss) metrics achieved during training
    pub best_metrics: Option<TrainingMetrics>,
    /// Error message if training failed
    pub error_message: Option<String>,
    /// Number of checkpoints saved
    pub checkpoint_count: usize,
}
|
||||
|
||||
impl LearningSession {
|
||||
/// Create a new learning session
|
||||
#[must_use]
|
||||
pub fn new(config: LearningConfig) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
model_type: config.model_type,
|
||||
status: TrainingStatus::Pending,
|
||||
metrics: TrainingMetrics::default(),
|
||||
started_at: now,
|
||||
updated_at: now,
|
||||
completed_at: None,
|
||||
config,
|
||||
metrics_history: Vec::new(),
|
||||
best_metrics: None,
|
||||
error_message: None,
|
||||
checkpoint_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the training session
|
||||
pub fn start(&mut self) {
|
||||
self.status = TrainingStatus::Running;
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
|
||||
/// Update metrics for a completed epoch
|
||||
pub fn update_metrics(&mut self, metrics: TrainingMetrics) {
|
||||
// Update best metrics if this is an improvement
|
||||
if self.best_metrics.is_none()
|
||||
|| metrics.loss < self.best_metrics.as_ref().unwrap().loss
|
||||
{
|
||||
self.best_metrics = Some(metrics.clone());
|
||||
}
|
||||
|
||||
self.metrics = metrics.clone();
|
||||
self.metrics_history.push(metrics);
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
|
||||
/// Mark the session as completed
|
||||
pub fn complete(&mut self) {
|
||||
self.status = TrainingStatus::Completed;
|
||||
self.completed_at = Some(Utc::now());
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
|
||||
/// Mark the session as failed
|
||||
pub fn fail(&mut self, error: impl Into<String>) {
|
||||
self.status = TrainingStatus::Failed;
|
||||
self.error_message = Some(error.into());
|
||||
self.completed_at = Some(Utc::now());
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
|
||||
/// Pause the training session
|
||||
pub fn pause(&mut self) {
|
||||
if self.status == TrainingStatus::Running {
|
||||
self.status = TrainingStatus::Paused;
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
}
|
||||
|
||||
/// Resume a paused session
|
||||
pub fn resume(&mut self) {
|
||||
if self.status == TrainingStatus::Paused {
|
||||
self.status = TrainingStatus::Running;
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the training duration
|
||||
#[must_use]
|
||||
pub fn duration(&self) -> chrono::Duration {
|
||||
let end = self.completed_at.unwrap_or_else(Utc::now);
|
||||
end - self.started_at
|
||||
}
|
||||
|
||||
/// Check if training should stop early
|
||||
#[must_use]
|
||||
pub fn should_early_stop(&self) -> bool {
|
||||
if let Some(patience) = self.config.hyperparameters.early_stopping_patience {
|
||||
if self.metrics_history.len() <= patience {
|
||||
return false;
|
||||
}
|
||||
|
||||
let best_epoch = self
|
||||
.metrics_history
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by(|(_, a), (_, b)| a.loss.partial_cmp(&b.loss).unwrap())
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(0);
|
||||
|
||||
self.metrics_history.len() - best_epoch > patience
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A node in the transition graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Embedding ID for this node
    pub id: EmbeddingId,
    /// The embedding vector
    pub embedding: Vec<f32>,
    /// Optional node features (separate from the embedding itself)
    pub features: Option<Vec<f32>>,
    /// Node label (for supervised learning)
    pub label: Option<usize>,
    /// Metadata associated with this node
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}

impl GraphNode {
    /// Create a new graph node with no features, label, or metadata
    #[must_use]
    pub fn new(id: EmbeddingId, embedding: Vec<f32>) -> Self {
        Self {
            id,
            embedding,
            features: None,
            label: None,
            metadata: HashMap::new(),
        }
    }

    /// Get the embedding dimension (length of the embedding vector)
    #[must_use]
    pub fn dim(&self) -> usize {
        self.embedding.len()
    }
}
|
||||
|
||||
/// An edge in the transition graph
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GraphEdge {
|
||||
/// Source node index
|
||||
pub from: usize,
|
||||
/// Target node index
|
||||
pub to: usize,
|
||||
/// Edge weight (e.g., similarity score)
|
||||
pub weight: f32,
|
||||
/// Edge type for heterogeneous graphs
|
||||
pub edge_type: Option<String>,
|
||||
}
|
||||
|
||||
impl GraphEdge {
|
||||
/// Create a new edge
|
||||
#[must_use]
|
||||
pub fn new(from: usize, to: usize, weight: f32) -> Self {
|
||||
Self {
|
||||
from,
|
||||
to,
|
||||
weight,
|
||||
edge_type: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a typed edge
|
||||
#[must_use]
|
||||
pub fn typed(from: usize, to: usize, weight: f32, edge_type: impl Into<String>) -> Self {
|
||||
Self {
|
||||
from,
|
||||
to,
|
||||
weight,
|
||||
edge_type: Some(edge_type.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A graph representing transitions between embeddings
///
/// `nodes`, `embeddings` and `labels` are parallel vectors indexed by node id;
/// edges refer to nodes by that index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionGraph {
    /// Nodes in the graph (embedding IDs)
    pub nodes: Vec<EmbeddingId>,
    /// Node embeddings (parallel to nodes)
    pub embeddings: Vec<Vec<f32>>,
    /// Edges as (from_index, to_index, weight) tuples
    pub edges: Vec<(usize, usize, f32)>,
    /// Optional node labels for supervised learning (parallel to nodes)
    #[serde(default)]
    pub labels: Vec<Option<usize>>,
    /// Number of unique classes (if labeled)
    pub num_classes: Option<usize>,
    /// Whether the graph is directed
    pub directed: bool,
}

impl Default for TransitionGraph {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl TransitionGraph {
|
||||
/// Create a new empty transition graph
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
embeddings: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
labels: Vec::new(),
|
||||
num_classes: None,
|
||||
directed: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an undirected graph
|
||||
#[must_use]
|
||||
pub fn undirected() -> Self {
|
||||
Self {
|
||||
directed: false,
|
||||
..Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to the graph
|
||||
pub fn add_node(&mut self, id: EmbeddingId, embedding: Vec<f32>, label: Option<usize>) {
|
||||
self.nodes.push(id);
|
||||
self.embeddings.push(embedding);
|
||||
self.labels.push(label);
|
||||
}
|
||||
|
||||
/// Add an edge to the graph
|
||||
pub fn add_edge(&mut self, from: usize, to: usize, weight: f32) {
|
||||
assert!(from < self.nodes.len(), "Invalid 'from' node index");
|
||||
assert!(to < self.nodes.len(), "Invalid 'to' node index");
|
||||
self.edges.push((from, to, weight));
|
||||
|
||||
// For undirected graphs, add reverse edge
|
||||
if !self.directed {
|
||||
self.edges.push((to, from, weight));
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of nodes
|
||||
#[must_use]
|
||||
pub fn num_nodes(&self) -> usize {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// Get the number of edges
|
||||
#[must_use]
|
||||
pub fn num_edges(&self) -> usize {
|
||||
self.edges.len()
|
||||
}
|
||||
|
||||
/// Get the embedding dimension (assumes all embeddings have same dimension)
|
||||
#[must_use]
|
||||
pub fn embedding_dim(&self) -> Option<usize> {
|
||||
self.embeddings.first().map(Vec::len)
|
||||
}
|
||||
|
||||
/// Get neighbors of a node
|
||||
#[must_use]
|
||||
pub fn neighbors(&self, node_idx: usize) -> Vec<(usize, f32)> {
|
||||
self.edges
|
||||
.iter()
|
||||
.filter(|(from, _, _)| *from == node_idx)
|
||||
.map(|(_, to, weight)| (*to, *weight))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the adjacency list representation
|
||||
#[must_use]
|
||||
pub fn adjacency_list(&self) -> Vec<Vec<(usize, f32)>> {
|
||||
let mut adj = vec![Vec::new(); self.nodes.len()];
|
||||
for &(from, to, weight) in &self.edges {
|
||||
adj[from].push((to, weight));
|
||||
}
|
||||
adj
|
||||
}
|
||||
|
||||
/// Compute node degrees
|
||||
#[must_use]
|
||||
pub fn degrees(&self) -> Vec<usize> {
|
||||
let mut degrees = vec![0; self.nodes.len()];
|
||||
for &(from, to, _) in &self.edges {
|
||||
degrees[from] += 1;
|
||||
if !self.directed {
|
||||
degrees[to] += 1;
|
||||
}
|
||||
}
|
||||
degrees
|
||||
}
|
||||
|
||||
/// Validate the graph structure
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.nodes.len() != self.embeddings.len() {
|
||||
return Err("Nodes and embeddings count mismatch".to_string());
|
||||
}
|
||||
if !self.labels.is_empty() && self.labels.len() != self.nodes.len() {
|
||||
return Err("Labels count mismatch".to_string());
|
||||
}
|
||||
for &(from, to, _) in &self.edges {
|
||||
if from >= self.nodes.len() || to >= self.nodes.len() {
|
||||
return Err(format!("Invalid edge: ({from}, {to})"));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A refined embedding produced by the learning process
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefinedEmbedding {
    /// ID of the original embedding that was refined
    pub original_id: EmbeddingId,
    /// The refined embedding vector
    pub refined_vector: Vec<f32>,
    /// Score indicating quality of refinement (0.0 to 1.0)
    pub refinement_score: f32,
    /// The session that produced this refinement
    pub session_id: Option<String>,
    /// Timestamp of refinement
    pub refined_at: Timestamp,
    /// Delta from original (optional, for analysis)
    ///
    /// Euclidean distance to the original vector; populated by
    /// `compute_delta`, `None` until then.
    pub delta_norm: Option<f32>,
    /// Confidence in the refinement
    ///
    /// Initialized to `refinement_score` by `RefinedEmbedding::new`.
    pub confidence: f32,
}
|
||||
|
||||
impl RefinedEmbedding {
|
||||
/// Create a new refined embedding
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
original_id: EmbeddingId,
|
||||
refined_vector: Vec<f32>,
|
||||
refinement_score: f32,
|
||||
) -> Self {
|
||||
Self {
|
||||
original_id,
|
||||
refined_vector,
|
||||
refinement_score,
|
||||
session_id: None,
|
||||
refined_at: Utc::now(),
|
||||
delta_norm: None,
|
||||
confidence: refinement_score,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the delta norm from original embedding
|
||||
pub fn compute_delta(&mut self, original: &[f32]) {
|
||||
if original.len() != self.refined_vector.len() {
|
||||
return;
|
||||
}
|
||||
let delta: f32 = original
|
||||
.iter()
|
||||
.zip(&self.refined_vector)
|
||||
.map(|(a, b)| (a - b).powi(2))
|
||||
.sum();
|
||||
self.delta_norm = Some(delta.sqrt());
|
||||
}
|
||||
|
||||
/// Get the embedding dimension
|
||||
#[must_use]
|
||||
pub fn dim(&self) -> usize {
|
||||
self.refined_vector.len()
|
||||
}
|
||||
|
||||
/// Normalize the refined vector to unit length
|
||||
pub fn normalize(&mut self) {
|
||||
let norm: f32 = self.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-10 {
|
||||
for x in &mut self.refined_vector {
|
||||
*x /= norm;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicit ids round-trip through `as_str`; generated ids are non-empty.
    #[test]
    fn test_embedding_id() {
        let id = EmbeddingId::new("test-123");
        assert_eq!(id.as_str(), "test-123");

        let generated = EmbeddingId::generate();
        assert!(!generated.as_str().is_empty());
    }

    /// Default model type and per-architecture head recommendations.
    #[test]
    fn test_gnn_model_type() {
        assert_eq!(GnnModelType::default(), GnnModelType::Gcn);
        assert_eq!(GnnModelType::Gat.recommended_heads(), 8);
        assert_eq!(GnnModelType::Gcn.recommended_heads(), 1);
    }

    /// Terminal / resumable classification of training states.
    #[test]
    fn test_training_status() {
        assert!(!TrainingStatus::Running.is_terminal());
        assert!(TrainingStatus::Completed.is_terminal());
        assert!(TrainingStatus::Failed.is_terminal());
        assert!(TrainingStatus::Paused.can_resume());
        assert!(!TrainingStatus::Completed.can_resume());
    }

    /// Metric construction and improvement comparison against an earlier epoch.
    #[test]
    fn test_training_metrics() {
        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        assert_eq!(metrics.epoch, 1);
        assert_eq!(metrics.loss, 0.5);

        let better = TrainingMetrics::new(2, 0.3, 0.9, 0.001);
        assert!(better.is_improving(&metrics));
    }

    /// Full session lifecycle: Pending -> Running -> Completed, with
    /// metrics recorded along the way.
    #[test]
    fn test_learning_session() {
        let config = LearningConfig::default();
        let mut session = LearningSession::new(config);

        assert_eq!(session.status, TrainingStatus::Pending);

        session.start();
        assert_eq!(session.status, TrainingStatus::Running);

        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        session.update_metrics(metrics);
        assert_eq!(session.metrics_history.len(), 1);

        session.complete();
        assert_eq!(session.status, TrainingStatus::Completed);
        assert!(session.completed_at.is_some());
    }

    /// Basic directed-graph construction, neighbor lookup, and validation.
    #[test]
    fn test_transition_graph() {
        let mut graph = TransitionGraph::new();

        let emb1 = vec![0.1, 0.2, 0.3];
        let emb2 = vec![0.4, 0.5, 0.6];

        graph.add_node(EmbeddingId::new("n1"), emb1, Some(0));
        graph.add_node(EmbeddingId::new("n2"), emb2, Some(1));
        graph.add_edge(0, 1, 0.8);

        assert_eq!(graph.num_nodes(), 2);
        assert_eq!(graph.num_edges(), 1);
        assert_eq!(graph.embedding_dim(), Some(3));

        let neighbors = graph.neighbors(0);
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0], (1, 0.8));

        assert!(graph.validate().is_ok());
    }

    /// Delta computation and unit-norm normalization of a refinement.
    #[test]
    fn test_refined_embedding() {
        let original = vec![1.0, 0.0, 0.0];
        let refined = vec![0.9, 0.1, 0.0];

        let mut re = RefinedEmbedding::new(
            EmbeddingId::new("test"),
            refined,
            0.95,
        );

        re.compute_delta(&original);
        assert!(re.delta_norm.is_some());

        re.normalize();
        let norm: f32 = re.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    /// Early stopping triggers only after a patience-length loss plateau.
    #[test]
    fn test_early_stopping() {
        let mut config = LearningConfig::default();
        config.hyperparameters.early_stopping_patience = Some(3);

        let mut session = LearningSession::new(config);
        session.start();

        // Improving metrics
        for i in 0..5 {
            let loss = 1.0 - (i as f32 * 0.1);
            session.update_metrics(TrainingMetrics::new(i, loss, 0.8, 0.001));
        }
        assert!(!session.should_early_stop());

        // Non-improving metrics
        for i in 5..10 {
            session.update_metrics(TrainingMetrics::new(i, 0.6, 0.8, 0.001));
        }
        assert!(session.should_early_stop());
    }
}
|
||||
7
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/domain/mod.rs
vendored
Normal file
7
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/domain/mod.rs
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Domain layer for the learning bounded context.
//!
//! Contains core entities, value objects, and repository traits
//! that define the learning domain model.

/// Core entities and value objects (sessions, graphs, embeddings).
pub mod entities;
/// Repository traits and error types for persistence.
pub mod repository;
|
||||
462
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/domain/repository.rs
vendored
Normal file
462
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/domain/repository.rs
vendored
Normal file
@@ -0,0 +1,462 @@
|
||||
//! Repository traits for the learning domain.
|
||||
//!
|
||||
//! Defines the persistence abstraction for learning sessions,
|
||||
//! refined embeddings, and transition graphs.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::entities::{
|
||||
EmbeddingId, LearningSession, RefinedEmbedding, TrainingStatus, TransitionGraph,
|
||||
};
|
||||
|
||||
/// Error type for repository operations
///
/// Display strings come from the `thiserror` attributes on each
/// variant; `serde_json::Error` converts into `SerializationError`
/// via the `From` impl below.
#[derive(Debug, thiserror::Error)]
pub enum RepositoryError {
    /// Session not found
    #[error("Learning session not found: {0}")]
    SessionNotFound(String),

    /// Embedding not found
    #[error("Embedding not found: {0}")]
    EmbeddingNotFound(String),

    /// Graph not found or empty
    #[error("Transition graph not found")]
    GraphNotFound,

    /// Serialization error
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// Storage error
    #[error("Storage error: {0}")]
    StorageError(String),

    /// Connection error
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// Validation error
    #[error("Validation error: {0}")]
    ValidationError(String),

    /// Concurrent modification error
    #[error("Concurrent modification detected for: {0}")]
    ConcurrentModification(String),

    /// Internal error
    #[error("Internal repository error: {0}")]
    Internal(String),
}
|
||||
|
||||
impl From<serde_json::Error> for RepositoryError {
|
||||
fn from(e: serde_json::Error) -> Self {
|
||||
Self::SerializationError(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type for repository operations
|
||||
pub type RepositoryResult<T> = Result<T, RepositoryError>;
|
||||
|
||||
/// Repository trait for learning persistence operations.
///
/// Implementors should provide durable storage for:
/// - Learning sessions and their state
/// - Refined embeddings
/// - Transition graphs
///
/// The `Send + Sync` bound lets a single handle be shared across
/// tasks (see `DynLearningRepository`).
#[async_trait]
pub trait LearningRepository: Send + Sync {
    // =========== Session Operations ===========

    /// Save a learning session
    async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()>;

    /// Get a learning session by ID
    ///
    /// Returns `Ok(None)` when no session with that ID exists.
    async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>>;

    /// Update an existing session
    async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()>;

    /// Delete a session
    async fn delete_session(&self, id: &str) -> RepositoryResult<()>;

    /// List sessions with optional status filter
    ///
    /// `limit`, when provided, caps the number of returned sessions.
    async fn list_sessions(
        &self,
        status: Option<TrainingStatus>,
        limit: Option<usize>,
    ) -> RepositoryResult<Vec<LearningSession>>;

    // =========== Embedding Operations ===========

    /// Save refined embeddings (batch)
    async fn save_refined_embeddings(
        &self,
        embeddings: &[RefinedEmbedding],
    ) -> RepositoryResult<()>;

    /// Get a refined embedding by original ID
    async fn get_refined_embedding(
        &self,
        original_id: &EmbeddingId,
    ) -> RepositoryResult<Option<RefinedEmbedding>>;

    /// Get multiple refined embeddings
    async fn get_refined_embeddings(
        &self,
        ids: &[EmbeddingId],
    ) -> RepositoryResult<Vec<RefinedEmbedding>>;

    /// Delete refined embeddings for a session
    ///
    /// Returns the number of embeddings removed.
    async fn delete_refined_embeddings(&self, session_id: &str) -> RepositoryResult<usize>;

    // =========== Graph Operations ===========

    /// Get the current transition graph
    async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph>;

    /// Save a transition graph
    async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;

    /// Update the transition graph (incremental)
    async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;

    /// Clear the transition graph
    async fn clear_transition_graph(&self) -> RepositoryResult<()>;

    // =========== Checkpoint Operations ===========

    /// Save a model checkpoint
    ///
    /// Returns an implementation-defined checkpoint identifier.
    async fn save_checkpoint(
        &self,
        session_id: &str,
        epoch: usize,
        data: &[u8],
    ) -> RepositoryResult<String>;

    /// Load a model checkpoint
    ///
    /// When `epoch` is `None`, implementations return the latest
    /// checkpoint for the session (see the in-memory test impl).
    async fn load_checkpoint(
        &self,
        session_id: &str,
        epoch: Option<usize>,
    ) -> RepositoryResult<Option<Vec<u8>>>;

    /// List available checkpoints for a session as (epoch, id) pairs
    async fn list_checkpoints(&self, session_id: &str) -> RepositoryResult<Vec<(usize, String)>>;

    /// Delete checkpoints for a session
    ///
    /// Returns the number of checkpoints removed.
    async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize>;
}
|
||||
|
||||
/// Extension trait for repository operations
|
||||
#[async_trait]
|
||||
pub trait LearningRepositoryExt: LearningRepository {
|
||||
/// Get the latest session for a model type
|
||||
async fn get_latest_session(
|
||||
&self,
|
||||
model_type: crate::GnnModelType,
|
||||
) -> RepositoryResult<Option<LearningSession>> {
|
||||
let sessions = self.list_sessions(None, Some(100)).await?;
|
||||
Ok(sessions
|
||||
.into_iter()
|
||||
.filter(|s| s.model_type == model_type)
|
||||
.max_by_key(|s| s.started_at))
|
||||
}
|
||||
|
||||
/// Get all completed sessions
|
||||
async fn get_completed_sessions(&self) -> RepositoryResult<Vec<LearningSession>> {
|
||||
self.list_sessions(Some(TrainingStatus::Completed), None).await
|
||||
}
|
||||
|
||||
/// Check if any session is currently running
|
||||
async fn has_running_session(&self) -> RepositoryResult<bool> {
|
||||
let sessions = self.list_sessions(Some(TrainingStatus::Running), Some(1)).await?;
|
||||
Ok(!sessions.is_empty())
|
||||
}
|
||||
|
||||
/// Get embeddings refined in a specific session
|
||||
async fn get_session_embeddings(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> RepositoryResult<Vec<RefinedEmbedding>> {
|
||||
// Default implementation - may be overridden for efficiency
|
||||
let session = self.get_session(session_id).await?;
|
||||
if session.is_none() {
|
||||
return Err(RepositoryError::SessionNotFound(session_id.to_string()));
|
||||
}
|
||||
|
||||
// This would need to be implemented properly in concrete implementations
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
// Blanket implementation: every LearningRepository automatically gains
// the LearningRepositoryExt convenience methods.
impl<T: LearningRepository + ?Sized> LearningRepositoryExt for T {}

/// A thread-safe repository handle
///
/// Cloning the `Arc` is cheap and shares the single underlying repository.
pub type DynLearningRepository = Arc<dyn LearningRepository>;
|
||||
|
||||
/// Unit of work pattern for transactional operations
///
/// Groups repository mutations into an atomic unit: call `begin`,
/// perform the operations, then either `commit` or `rollback`.
#[async_trait]
pub trait UnitOfWork: Send + Sync {
    /// Begin a transaction
    async fn begin(&self) -> RepositoryResult<()>;

    /// Commit the transaction
    async fn commit(&self) -> RepositoryResult<()>;

    /// Rollback the transaction
    async fn rollback(&self) -> RepositoryResult<()>;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::RwLock;

    /// In-memory implementation for testing
    ///
    /// Each store lives behind its own async `RwLock` so the trait's
    /// `&self` methods can mutate state concurrently.
    struct InMemoryRepository {
        // session id -> session
        sessions: RwLock<HashMap<String, LearningSession>>,
        // original embedding id -> refined embedding
        embeddings: RwLock<HashMap<String, RefinedEmbedding>>,
        // the single current transition graph, if any
        graph: RwLock<Option<TransitionGraph>>,
        // session id -> (epoch, raw checkpoint bytes) in insertion order
        checkpoints: RwLock<HashMap<String, Vec<(usize, Vec<u8>)>>>,
    }

    impl InMemoryRepository {
        /// Create an empty repository.
        fn new() -> Self {
            Self {
                sessions: RwLock::new(HashMap::new()),
                embeddings: RwLock::new(HashMap::new()),
                graph: RwLock::new(None),
                checkpoints: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl LearningRepository for InMemoryRepository {
        async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.insert(session.id.clone(), session.clone());
            Ok(())
        }

        async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>> {
            let sessions = self.sessions.read().await;
            Ok(sessions.get(id).cloned())
        }

        // Update is identical to save: insert overwrites existing entries.
        async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            self.save_session(session).await
        }

        async fn delete_session(&self, id: &str) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.remove(id);
            Ok(())
        }

        // Newest-first listing with optional status filter and limit.
        async fn list_sessions(
            &self,
            status: Option<TrainingStatus>,
            limit: Option<usize>,
        ) -> RepositoryResult<Vec<LearningSession>> {
            let sessions = self.sessions.read().await;
            let mut result: Vec<_> = sessions
                .values()
                .filter(|s| status.map_or(true, |st| s.status == st))
                .cloned()
                .collect();
            result.sort_by(|a, b| b.started_at.cmp(&a.started_at));
            if let Some(limit) = limit {
                result.truncate(limit);
            }
            Ok(result)
        }

        async fn save_refined_embeddings(
            &self,
            embeddings: &[RefinedEmbedding],
        ) -> RepositoryResult<()> {
            let mut store = self.embeddings.write().await;
            for emb in embeddings {
                store.insert(emb.original_id.0.clone(), emb.clone());
            }
            Ok(())
        }

        async fn get_refined_embedding(
            &self,
            original_id: &EmbeddingId,
        ) -> RepositoryResult<Option<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            Ok(store.get(&original_id.0).cloned())
        }

        // Missing ids are silently skipped rather than reported as errors.
        async fn get_refined_embeddings(
            &self,
            ids: &[EmbeddingId],
        ) -> RepositoryResult<Vec<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            Ok(ids
                .iter()
                .filter_map(|id| store.get(&id.0).cloned())
                .collect())
        }

        // Test shortcut: ignores the session id and clears every embedding.
        async fn delete_refined_embeddings(&self, _session_id: &str) -> RepositoryResult<usize> {
            let mut store = self.embeddings.write().await;
            let count = store.len();
            store.clear();
            Ok(count)
        }

        async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph> {
            let graph = self.graph.read().await;
            graph.clone().ok_or(RepositoryError::GraphNotFound)
        }

        async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = Some(graph.clone());
            Ok(())
        }

        // The "incremental" update is a full replace in this test double.
        async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            self.save_transition_graph(graph).await
        }

        async fn clear_transition_graph(&self) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = None;
            Ok(())
        }

        async fn save_checkpoint(
            &self,
            session_id: &str,
            epoch: usize,
            data: &[u8],
        ) -> RepositoryResult<String> {
            let mut store = self.checkpoints.write().await;
            let checkpoints = store.entry(session_id.to_string()).or_default();
            checkpoints.push((epoch, data.to_vec()));
            Ok(format!("{session_id}-{epoch}"))
        }

        async fn load_checkpoint(
            &self,
            session_id: &str,
            epoch: Option<usize>,
        ) -> RepositoryResult<Option<Vec<u8>>> {
            let store = self.checkpoints.read().await;
            if let Some(checkpoints) = store.get(session_id) {
                if let Some(epoch) = epoch {
                    return Ok(checkpoints
                        .iter()
                        .find(|(e, _)| *e == epoch)
                        .map(|(_, d)| d.clone()));
                }
                // No epoch requested: return the most recently saved one.
                return Ok(checkpoints.last().map(|(_, d)| d.clone()));
            }
            Ok(None)
        }

        async fn list_checkpoints(
            &self,
            session_id: &str,
        ) -> RepositoryResult<Vec<(usize, String)>> {
            let store = self.checkpoints.read().await;
            if let Some(checkpoints) = store.get(session_id) {
                return Ok(checkpoints
                    .iter()
                    .map(|(e, _)| (*e, format!("{session_id}-{e}")))
                    .collect());
            }
            Ok(Vec::new())
        }

        async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize> {
            let mut store = self.checkpoints.write().await;
            if let Some(checkpoints) = store.remove(session_id) {
                return Ok(checkpoints.len());
            }
            Ok(0)
        }
    }

    /// Round-trips a session through save/get/list/delete.
    #[tokio::test]
    async fn test_in_memory_repository() {
        let repo = InMemoryRepository::new();
        let config = crate::LearningConfig::default();
        let session = crate::LearningSession::new(config);

        // Save and retrieve session
        repo.save_session(&session).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, session.id);

        // List sessions
        let sessions = repo.list_sessions(None, None).await.unwrap();
        assert_eq!(sessions.len(), 1);

        // Delete session
        repo.delete_session(&session.id).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_none());
    }

    /// Save/retrieve/clear lifecycle of the singleton transition graph.
    #[tokio::test]
    async fn test_transition_graph_operations() {
        let repo = InMemoryRepository::new();

        // Graph should not exist initially
        assert!(repo.get_transition_graph().await.is_err());

        // Save graph
        let mut graph = TransitionGraph::new();
        graph.add_node(
            crate::EmbeddingId::new("n1"),
            vec![0.1, 0.2, 0.3],
            None,
        );
        repo.save_transition_graph(&graph).await.unwrap();

        // Retrieve graph
        let retrieved = repo.get_transition_graph().await.unwrap();
        assert_eq!(retrieved.num_nodes(), 1);

        // Clear graph
        repo.clear_transition_graph().await.unwrap();
        assert!(repo.get_transition_graph().await.is_err());
    }

    /// Checkpoint save/list/load (specific epoch and latest)/delete.
    #[tokio::test]
    async fn test_checkpoint_operations() {
        let repo = InMemoryRepository::new();
        let session_id = "test-session";

        // Save checkpoints
        repo.save_checkpoint(session_id, 1, b"data1").await.unwrap();
        repo.save_checkpoint(session_id, 2, b"data2").await.unwrap();

        // List checkpoints
        let checkpoints = repo.list_checkpoints(session_id).await.unwrap();
        assert_eq!(checkpoints.len(), 2);

        // Load specific checkpoint
        let data = repo.load_checkpoint(session_id, Some(1)).await.unwrap();
        assert_eq!(data, Some(b"data1".to_vec()));

        // Load latest checkpoint
        let data = repo.load_checkpoint(session_id, None).await.unwrap();
        assert_eq!(data, Some(b"data2".to_vec()));

        // Delete checkpoints
        let count = repo.delete_checkpoints(session_id).await.unwrap();
        assert_eq!(count, 2);
    }
}
|
||||
610
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/ewc.rs
vendored
Normal file
610
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/ewc.rs
vendored
Normal file
@@ -0,0 +1,610 @@
|
||||
//! Elastic Weight Consolidation (EWC) for continual learning.
|
||||
//!
|
||||
//! EWC prevents catastrophic forgetting by regularizing weight updates
|
||||
//! based on their importance for previously learned tasks.
|
||||
//!
|
||||
//! Reference: Kirkpatrick et al., "Overcoming catastrophic forgetting in neural networks", PNAS 2017
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::infrastructure::gnn_model::GnnModel;
|
||||
|
||||
/// Fisher Information matrix (diagonal approximation).
///
/// Stores the importance weights for each parameter in the model.
/// `diagonal[layer][param]` holds the running estimate of the squared
/// gradient for that parameter.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FisherInformation {
    /// Diagonal elements of the Fisher matrix per layer
    pub diagonal: Vec<Vec<f32>>,
    /// Optional parameter names for debugging
    #[serde(default)]
    pub param_names: Vec<String>,
    /// Number of samples used to compute Fisher
    pub num_samples: usize,
}
|
||||
|
||||
impl FisherInformation {
|
||||
/// Create a new Fisher information structure
|
||||
#[must_use]
|
||||
pub fn new(num_layers: usize) -> Self {
|
||||
Self {
|
||||
diagonal: vec![Vec::new(); num_layers],
|
||||
param_names: Vec::new(),
|
||||
num_samples: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from parameter gradients (diagonal approximation)
|
||||
///
|
||||
/// F_ii = E[(∂L/∂θ_i)²]
|
||||
#[must_use]
|
||||
pub fn from_gradients(gradients: &[Vec<f32>]) -> Self {
|
||||
let diagonal: Vec<Vec<f32>> = gradients
|
||||
.iter()
|
||||
.map(|grads| grads.iter().map(|g| g * g).collect())
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
diagonal,
|
||||
param_names: Vec::new(),
|
||||
num_samples: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update Fisher with new gradient samples (online estimation)
|
||||
pub fn update(&mut self, gradients: &[Vec<f32>]) {
|
||||
let n = self.num_samples as f32;
|
||||
|
||||
for (layer_idx, grads) in gradients.iter().enumerate() {
|
||||
if layer_idx >= self.diagonal.len() {
|
||||
self.diagonal.push(grads.iter().map(|g| g * g).collect());
|
||||
} else {
|
||||
// Running average: F_new = (n * F_old + g²) / (n + 1)
|
||||
for (i, &g) in grads.iter().enumerate() {
|
||||
if i >= self.diagonal[layer_idx].len() {
|
||||
self.diagonal[layer_idx].push(g * g);
|
||||
} else {
|
||||
let old_val = self.diagonal[layer_idx][i];
|
||||
self.diagonal[layer_idx][i] = (n * old_val + g * g) / (n + 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.num_samples += 1;
|
||||
}
|
||||
|
||||
/// Get the importance of a parameter
|
||||
#[must_use]
|
||||
pub fn get_importance(&self, layer_idx: usize, param_idx: usize) -> f32 {
|
||||
self.diagonal
|
||||
.get(layer_idx)
|
||||
.and_then(|layer| layer.get(param_idx))
|
||||
.copied()
|
||||
.unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// Get total importance (sum of all diagonal elements)
|
||||
#[must_use]
|
||||
pub fn total_importance(&self) -> f32 {
|
||||
self.diagonal.iter().flat_map(|l| l.iter()).sum()
|
||||
}
|
||||
|
||||
/// Normalize Fisher information
|
||||
pub fn normalize(&mut self) {
|
||||
let total = self.total_importance();
|
||||
if total > 1e-10 {
|
||||
for layer in &mut self.diagonal {
|
||||
for val in layer.iter_mut() {
|
||||
*val /= total;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply decay to Fisher information (for progressive consolidation)
|
||||
pub fn decay(&mut self, factor: f32) {
|
||||
for layer in &mut self.diagonal {
|
||||
for val in layer.iter_mut() {
|
||||
*val *= factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State snapshot for EWC regularization.
///
/// Stores the model parameters and their importance from previous tasks.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EwcState {
    /// Optimal parameters for previous task(s)
    ///
    /// One flattened parameter vector per consolidated task snapshot.
    pub optimal_params: Vec<Vec<f32>>,
    /// Fisher information (importance weights)
    pub fisher: FisherInformation,
    /// Task identifier
    pub task_id: Option<String>,
    /// Number of tasks consolidated
    pub num_tasks: usize,
}
|
||||
|
||||
impl EwcState {
    /// Create a new EWC state from model parameters and Fisher information
    #[must_use]
    pub fn new(params: Vec<f32>, fisher: FisherInformation) -> Self {
        Self {
            optimal_params: vec![params],
            fisher,
            task_id: None,
            num_tasks: 1,
        }
    }

    /// Create from model with computed Fisher
    ///
    /// Snapshots the model's current parameters and builds the diagonal
    /// Fisher approximation from the supplied per-layer gradients.
    #[must_use]
    pub fn from_model(model: &GnnModel, gradients: &[Vec<f32>]) -> Self {
        let params = model.get_parameters();
        let fisher = FisherInformation::from_gradients(gradients);

        Self::new(params, fisher)
    }

    /// Merge with another EWC state (for multi-task learning)
    ///
    /// Appends the other state's optimal-parameter snapshots and averages
    /// the two Fisher diagonals, weighting each side by its task count.
    /// Layers/parameters present only in `other` are copied over as-is.
    pub fn merge(&mut self, other: &EwcState) {
        // Add new optimal params
        self.optimal_params.extend(other.optimal_params.clone());
        self.num_tasks += other.num_tasks;

        // Merge Fisher information (average)
        for (layer_idx, other_layer) in other.fisher.diagonal.iter().enumerate() {
            if layer_idx >= self.fisher.diagonal.len() {
                self.fisher.diagonal.push(other_layer.clone());
            } else {
                for (i, &val) in other_layer.iter().enumerate() {
                    if i >= self.fisher.diagonal[layer_idx].len() {
                        self.fisher.diagonal[layer_idx].push(val);
                    } else {
                        // Average Fisher values
                        // NOTE: num_tasks was already incremented above, so
                        // old_n recovers this state's pre-merge task count.
                        let old_n = (self.num_tasks - other.num_tasks) as f32;
                        let new_n = other.num_tasks as f32;
                        let total_n = self.num_tasks as f32;

                        self.fisher.diagonal[layer_idx][i] =
                            (old_n * self.fisher.diagonal[layer_idx][i] + new_n * val) / total_n;
                    }
                }
            }
        }
    }
}
|
||||
|
||||
/// EWC regularizer for computing the penalty term.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EwcRegularizer {
|
||||
/// Lambda coefficient (importance of old task knowledge)
|
||||
lambda: f32,
|
||||
/// Optional per-layer lambda scaling
|
||||
layer_lambdas: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
impl Default for EwcRegularizer {
    /// Default regularizer: a single global lambda of 5000.0 and no
    /// per-layer scaling.
    fn default() -> Self {
        Self {
            lambda: 5000.0,
            layer_lambdas: None,
        }
    }
}
|
||||
|
||||
impl EwcRegularizer {
|
||||
/// Create a new EWC regularizer with the given lambda
|
||||
#[must_use]
|
||||
pub fn new(lambda: f32) -> Self {
|
||||
Self {
|
||||
lambda,
|
||||
layer_lambdas: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with per-layer lambda scaling
|
||||
#[must_use]
|
||||
pub fn with_layer_lambdas(mut self, lambdas: Vec<f32>) -> Self {
|
||||
self.layer_lambdas = Some(lambdas);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get lambda for a specific layer
|
||||
fn get_layer_lambda(&self, layer_idx: usize) -> f32 {
|
||||
self.layer_lambdas
|
||||
.as_ref()
|
||||
.and_then(|l| l.get(layer_idx))
|
||||
.copied()
|
||||
.unwrap_or(self.lambda)
|
||||
}
|
||||
|
||||
/// Compute the EWC penalty term.
|
||||
///
|
||||
/// L_EWC = (λ/2) * Σ F_i * (θ_i - θ*_i)²
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model` - Current model
|
||||
/// * `ewc_state` - Saved EWC state from previous task
|
||||
///
|
||||
/// # Returns
|
||||
/// The EWC penalty value to add to the loss
|
||||
#[must_use]
|
||||
pub fn compute_penalty(&self, model: &GnnModel, ewc_state: &EwcState) -> f32 {
|
||||
let current_params = model.get_parameters();
|
||||
|
||||
// Flatten optimal params (use the most recent one if multiple)
|
||||
let optimal_params = ewc_state
|
||||
.optimal_params
|
||||
.last()
|
||||
.map(|p| p.as_slice())
|
||||
.unwrap_or(&[]);
|
||||
|
||||
if current_params.len() != optimal_params.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute penalty: (λ/2) * Σ F_i * (θ_i - θ*_i)²
|
||||
let mut penalty = 0.0;
|
||||
let mut param_idx = 0;
|
||||
|
||||
for (layer_idx, layer_fisher) in ewc_state.fisher.diagonal.iter().enumerate() {
|
||||
let _layer_lambda = self.get_layer_lambda(layer_idx);
|
||||
|
||||
for &fisher_val in layer_fisher {
|
||||
if param_idx < current_params.len() && param_idx < optimal_params.len() {
|
||||
let diff = current_params[param_idx] - optimal_params[param_idx];
|
||||
penalty += fisher_val * diff * diff;
|
||||
}
|
||||
param_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
(self.lambda / 2.0) * penalty
|
||||
}
|
||||
|
||||
/// Compute EWC gradient contribution.
|
||||
///
|
||||
/// ∂L_EWC/∂θ_i = λ * F_i * (θ_i - θ*_i)
|
||||
#[must_use]
|
||||
pub fn compute_gradient(&self, model: &GnnModel, ewc_state: &EwcState) -> Vec<f32> {
|
||||
let current_params = model.get_parameters();
|
||||
|
||||
let optimal_params = ewc_state
|
||||
.optimal_params
|
||||
.last()
|
||||
.map(|p| p.as_slice())
|
||||
.unwrap_or(&[]);
|
||||
|
||||
let mut gradient = vec![0.0; current_params.len()];
|
||||
|
||||
if current_params.len() != optimal_params.len() {
|
||||
return gradient;
|
||||
}
|
||||
|
||||
let mut param_idx = 0;
|
||||
|
||||
for (layer_idx, layer_fisher) in ewc_state.fisher.diagonal.iter().enumerate() {
|
||||
let layer_lambda = self.get_layer_lambda(layer_idx);
|
||||
|
||||
for &fisher_val in layer_fisher {
|
||||
if param_idx < current_params.len() {
|
||||
let diff = current_params[param_idx] - optimal_params[param_idx];
|
||||
gradient[param_idx] = layer_lambda * fisher_val * diff;
|
||||
}
|
||||
param_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
gradient
|
||||
}
|
||||
}
|
||||
|
||||
/// Online EWC implementation (for streaming/continual learning).
///
/// Uses a running estimate of the Fisher information rather than one state
/// per task, so memory stays constant as tasks accumulate.
#[derive(Debug, Clone)]
pub struct OnlineEwc {
    /// Current EWC state; `None` until the first call to `update`
    state: Option<EwcState>,
    /// EWC lambda
    lambda: f32,
    /// Gamma for Fisher decay (between tasks)
    gamma: f32,
}
|
||||
|
||||
impl OnlineEwc {
    /// Create a new Online EWC instance
    ///
    /// `lambda` weights the penalty; `gamma` controls how strongly the old
    /// Fisher estimate persists between tasks.
    #[must_use]
    pub fn new(lambda: f32, gamma: f32) -> Self {
        Self {
            state: None,
            lambda,
            gamma,
        }
    }

    /// Update the EWC state after completing a task
    ///
    /// Maintains a running Fisher estimate — existing entries follow
    /// F <- decay(gamma) then += (1 - gamma) * F_new — and keeps only the
    /// most recent optimal parameters.
    pub fn update(&mut self, model: &GnnModel, gradients: &[Vec<f32>]) {
        let new_fisher = FisherInformation::from_gradients(gradients);
        let params = model.get_parameters();

        if let Some(ref mut state) = self.state {
            // Decay old Fisher
            state.fisher.decay(self.gamma);

            // Add new Fisher (scaled by (1 - gamma))
            for (layer_idx, new_layer) in new_fisher.diagonal.iter().enumerate() {
                if layer_idx >= state.fisher.diagonal.len() {
                    // NOTE(review): a brand-new layer is adopted *unscaled*
                    // here, while new entries within an existing layer (below)
                    // are scaled by (1 - gamma) — confirm this asymmetry is
                    // intentional.
                    state.fisher.diagonal.push(new_layer.clone());
                } else {
                    for (i, &val) in new_layer.iter().enumerate() {
                        if i >= state.fisher.diagonal[layer_idx].len() {
                            state.fisher.diagonal[layer_idx].push((1.0 - self.gamma) * val);
                        } else {
                            state.fisher.diagonal[layer_idx][i] += (1.0 - self.gamma) * val;
                        }
                    }
                }
            }

            // Update optimal params (keep only the latest)
            state.optimal_params = vec![params];
            state.num_tasks += 1;
        } else {
            // First task: initialize the state directly.
            self.state = Some(EwcState::new(params, new_fisher));
        }
    }

    /// Compute the EWC penalty
    ///
    /// Returns 0.0 before the first `update`.
    #[must_use]
    pub fn compute_penalty(&self, model: &GnnModel) -> f32 {
        if let Some(ref state) = self.state {
            let regularizer = EwcRegularizer::new(self.lambda);
            regularizer.compute_penalty(model, state)
        } else {
            0.0
        }
    }

    /// Compute the EWC gradient contribution
    ///
    /// Returns an all-zero vector (matching the current parameter count)
    /// before the first `update`.
    #[must_use]
    pub fn compute_gradient(&self, model: &GnnModel) -> Vec<f32> {
        if let Some(ref state) = self.state {
            let regularizer = EwcRegularizer::new(self.lambda);
            regularizer.compute_gradient(model, state)
        } else {
            vec![0.0; model.get_parameters().len()]
        }
    }

    /// Get the current state
    #[must_use]
    pub fn state(&self) -> Option<&EwcState> {
        self.state.as_ref()
    }
}
|
||||
|
||||
/// Progress & Compress (P&C) - improved EWC variant.
///
/// Alternates between progress (learning new task) and compress (consolidation).
#[derive(Debug, Clone)]
pub struct ProgressAndCompress {
    /// Knowledge base (compressed knowledge from all tasks)
    knowledge_base: Option<EwcState>,
    /// Active column (for current task); `Some` only while a progress phase
    /// is in flight
    active_params: Option<Vec<f32>>,
    /// EWC lambda
    lambda: f32,
}
|
||||
|
||||
impl ProgressAndCompress {
|
||||
/// Create a new P&C instance
|
||||
#[must_use]
|
||||
pub fn new(lambda: f32) -> Self {
|
||||
Self {
|
||||
knowledge_base: None,
|
||||
active_params: None,
|
||||
lambda,
|
||||
}
|
||||
}
|
||||
|
||||
/// Begin progress phase (start learning new task)
|
||||
pub fn begin_progress(&mut self, model: &GnnModel) {
|
||||
self.active_params = Some(model.get_parameters());
|
||||
}
|
||||
|
||||
/// End progress phase and begin compress
|
||||
pub fn compress(&mut self, model: &GnnModel, gradients: &[Vec<f32>]) {
|
||||
let fisher = FisherInformation::from_gradients(gradients);
|
||||
let params = model.get_parameters();
|
||||
let new_state = EwcState::new(params, fisher);
|
||||
|
||||
if let Some(ref mut kb) = self.knowledge_base {
|
||||
kb.merge(&new_state);
|
||||
} else {
|
||||
self.knowledge_base = Some(new_state);
|
||||
}
|
||||
|
||||
self.active_params = None;
|
||||
}
|
||||
|
||||
/// Compute penalty during progress phase
|
||||
#[must_use]
|
||||
pub fn compute_penalty(&self, model: &GnnModel) -> f32 {
|
||||
if let Some(ref kb) = self.knowledge_base {
|
||||
let regularizer = EwcRegularizer::new(self.lambda);
|
||||
regularizer.compute_penalty(model, kb)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    // Fix: `LearningConfig` was imported only for a dead local in
    // `test_ewc_regularizer` (see below); the import is trimmed with it.
    use crate::GnnModelType;

    #[test]
    fn test_fisher_information() {
        let mut fisher = FisherInformation::new(2);

        // Update with gradients
        fisher.update(&[vec![1.0, 2.0], vec![3.0]]);
        fisher.update(&[vec![2.0, 1.0], vec![4.0]]);

        assert_eq!(fisher.num_samples, 2);
        assert_eq!(fisher.diagonal.len(), 2);

        // Check averaging
        assert!(fisher.get_importance(0, 0) > 0.0);
    }

    #[test]
    fn test_fisher_from_gradients() {
        let gradients = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];
        let fisher = FisherInformation::from_gradients(&gradients);

        // The diagonal Fisher estimate is the element-wise squared gradient.
        assert_eq!(fisher.diagonal[0], vec![1.0, 4.0, 9.0]);
        assert_eq!(fisher.diagonal[1], vec![16.0, 25.0]);
    }

    #[test]
    fn test_fisher_normalize() {
        let mut fisher = FisherInformation::from_gradients(&[vec![3.0, 4.0]]);
        fisher.normalize();

        // After normalization the diagonal sums to 1.
        let total: f32 = fisher.diagonal.iter().flat_map(|l| l.iter()).sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_ewc_state() {
        let params = vec![1.0, 2.0, 3.0];
        let fisher = FisherInformation::from_gradients(&[vec![0.1, 0.2, 0.3]]);

        let state = EwcState::new(params.clone(), fisher);
        assert_eq!(state.optimal_params.len(), 1);
        assert_eq!(state.num_tasks, 1);
    }

    #[test]
    fn test_ewc_state_merge() {
        let mut state1 = EwcState::new(
            vec![1.0, 2.0],
            FisherInformation::from_gradients(&[vec![0.1, 0.2]]),
        );

        let state2 = EwcState::new(
            vec![1.5, 2.5],
            FisherInformation::from_gradients(&[vec![0.3, 0.4]]),
        );

        state1.merge(&state2);
        // Merge accumulates task counts and parameter history.
        assert_eq!(state1.num_tasks, 2);
        assert_eq!(state1.optimal_params.len(), 2);
    }

    #[test]
    fn test_ewc_regularizer() {
        // Fix: removed a `LearningConfig::default()` that was built and
        // mutated here but never used — the model below is constructed
        // directly from literal dimensions.
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );

        let params = model.get_parameters();
        let fisher = FisherInformation::from_gradients(&[vec![0.1; params.len()]]);
        let ewc_state = EwcState::new(params.clone(), fisher);

        let regularizer = EwcRegularizer::new(1000.0);
        let penalty = regularizer.compute_penalty(&model, &ewc_state);

        // Penalty should be 0 when params haven't changed
        assert!(penalty.abs() < 1e-6);
    }

    #[test]
    fn test_ewc_gradient() {
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );

        // Create state with slightly different params
        let mut optimal_params = model.get_parameters();
        for p in &mut optimal_params {
            *p += 0.1;
        }

        let fisher = FisherInformation::from_gradients(&[vec![1.0; optimal_params.len()]]);
        let ewc_state = EwcState::new(optimal_params, fisher);

        let regularizer = EwcRegularizer::new(1.0);
        let gradient = regularizer.compute_gradient(&model, &ewc_state);

        // Gradient should push towards optimal params
        assert!(!gradient.is_empty());
        for &g in &gradient {
            // Gradient should be non-zero since params differ
            assert!(g.abs() > 0.0);
        }
    }

    #[test]
    fn test_online_ewc() {
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );

        let mut online = OnlineEwc::new(1000.0, 0.9);

        // Initially no penalty
        assert_eq!(online.compute_penalty(&model), 0.0);

        // Update after task 1
        let gradients = vec![vec![0.1; 20]];
        online.update(&model, &gradients);

        assert!(online.state().is_some());
        assert_eq!(online.state().unwrap().num_tasks, 1);
    }

    #[test]
    fn test_progress_and_compress() {
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );

        let mut pc = ProgressAndCompress::new(1000.0);

        // Begin progress
        pc.begin_progress(&model);
        assert!(pc.active_params.is_some());

        // Compress
        let gradients = vec![vec![0.1; 20]];
        pc.compress(&model, &gradients);

        assert!(pc.active_params.is_none());
        assert!(pc.knowledge_base.is_some());
    }

    #[test]
    fn test_fisher_decay() {
        // from_gradients squares: [1, 2] -> diagonal [1, 4];
        // decay(0.5) then halves every entry: [0.5, 2.0].
        let mut fisher = FisherInformation::from_gradients(&[vec![1.0, 2.0]]);
        fisher.decay(0.5);

        assert_eq!(fisher.diagonal[0], vec![0.5, 2.0]);
    }
}
|
||||
576
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/infrastructure/attention.rs
vendored
Normal file
576
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/infrastructure/attention.rs
vendored
Normal file
@@ -0,0 +1,576 @@
|
||||
//! Attention mechanisms for GNN models.
|
||||
//!
|
||||
//! This module provides attention layers and mechanisms including:
|
||||
//! - Single-head attention
|
||||
//! - Multi-head attention
|
||||
//! - Graph-level attention readout
|
||||
|
||||
use ndarray::{Array1, Array2, Axis};
|
||||
use rand::Rng;
|
||||
use rand_distr::{Distribution, Uniform};
|
||||
|
||||
/// Error type for attention operations
///
/// All fallible attention APIs in this module return these variants via
/// [`AttentionResult`].
#[derive(Debug, thiserror::Error)]
pub enum AttentionError {
    /// Dimension mismatch between an input's width and the layer's
    /// configured dimension
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// Invalid configuration
    #[error("Invalid attention configuration: {0}")]
    InvalidConfig(String),

    /// Computation error
    #[error("Attention computation error: {0}")]
    ComputationError(String),
}

/// Result type for attention operations
pub type AttentionResult<T> = Result<T, AttentionError>;
|
||||
|
||||
/// Single-head attention layer
///
/// Scaled dot-product attention with learned Q/K/V projections:
/// Attention(Q, K, V) = softmax(Q·Kᵀ / scale) · V.
#[derive(Debug, Clone)]
pub struct AttentionLayer {
    /// Query weight matrix (input_dim x attention_dim)
    query_weights: Array2<f32>,
    /// Key weight matrix (input_dim x attention_dim)
    key_weights: Array2<f32>,
    /// Value weight matrix (input_dim x attention_dim)
    value_weights: Array2<f32>,
    /// Attention dimension
    attention_dim: usize,
    /// Input dimension
    input_dim: usize,
    /// Output dimension (set equal to `attention_dim` in `new`)
    output_dim: usize,
    /// Scaling factor: sqrt(attention_dim)
    scale: f32,
}
|
||||
|
||||
impl AttentionLayer {
|
||||
/// Create a new attention layer
|
||||
#[must_use]
|
||||
pub fn new(input_dim: usize, attention_dim: usize) -> Self {
|
||||
let query_weights = xavier_init(input_dim, attention_dim);
|
||||
let key_weights = xavier_init(input_dim, attention_dim);
|
||||
let value_weights = xavier_init(input_dim, attention_dim);
|
||||
let scale = (attention_dim as f32).sqrt();
|
||||
|
||||
Self {
|
||||
query_weights,
|
||||
key_weights,
|
||||
value_weights,
|
||||
attention_dim,
|
||||
input_dim,
|
||||
output_dim: attention_dim,
|
||||
scale,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the output dimension
|
||||
#[must_use]
|
||||
pub fn output_dim(&self) -> usize {
|
||||
self.output_dim
|
||||
}
|
||||
|
||||
/// Compute attention scores
|
||||
pub fn compute_attention(
|
||||
&self,
|
||||
query: &Array2<f32>,
|
||||
key: &Array2<f32>,
|
||||
mask: Option<&Array2<f32>>,
|
||||
) -> AttentionResult<Array2<f32>> {
|
||||
// Q * K^T / sqrt(d_k)
|
||||
let scores = query.dot(&key.t()) / self.scale;
|
||||
|
||||
// Apply mask if provided (set masked positions to -inf)
|
||||
let scores = if let Some(mask) = mask {
|
||||
let mut masked = scores;
|
||||
for i in 0..masked.nrows() {
|
||||
for j in 0..masked.ncols() {
|
||||
if mask[[i, j]] == 0.0 {
|
||||
masked[[i, j]] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
masked
|
||||
} else {
|
||||
scores
|
||||
};
|
||||
|
||||
// Softmax
|
||||
let attention_weights = softmax_2d(&scores);
|
||||
|
||||
Ok(attention_weights)
|
||||
}
|
||||
|
||||
/// Forward pass through the attention layer
|
||||
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
|
||||
if features.ncols() != self.input_dim {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.input_dim,
|
||||
actual: features.ncols(),
|
||||
});
|
||||
}
|
||||
|
||||
// Compute Q, K, V
|
||||
let q = features.dot(&self.query_weights);
|
||||
let k = features.dot(&self.key_weights);
|
||||
let v = features.dot(&self.value_weights);
|
||||
|
||||
// Compute attention weights
|
||||
let attention_weights = self.compute_attention(&q, &k, None)?;
|
||||
|
||||
// Apply attention to values
|
||||
let output = attention_weights.dot(&v);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Forward pass with explicit Q, K, V
|
||||
pub fn forward_qkv(
|
||||
&self,
|
||||
query: &Array2<f32>,
|
||||
key: &Array2<f32>,
|
||||
value: &Array2<f32>,
|
||||
mask: Option<&Array2<f32>>,
|
||||
) -> AttentionResult<Array2<f32>> {
|
||||
// Transform Q, K, V
|
||||
let q = query.dot(&self.query_weights);
|
||||
let k = key.dot(&self.key_weights);
|
||||
let v = value.dot(&self.value_weights);
|
||||
|
||||
// Compute attention weights
|
||||
let attention_weights = self.compute_attention(&q, &k, mask)?;
|
||||
|
||||
// Apply attention to values
|
||||
let output = attention_weights.dot(&v);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Graph-level readout using attention
|
||||
pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
|
||||
// Compute attention-weighted mean of node features
|
||||
let attended = self.forward(node_features)?;
|
||||
|
||||
// Mean over nodes
|
||||
let mean = attended.mean_axis(Axis(0)).unwrap();
|
||||
|
||||
Ok(mean)
|
||||
}
|
||||
|
||||
/// Update weights with gradient
|
||||
pub fn update_weights(&mut self, _lr: f32, weight_decay: f32) {
|
||||
self.query_weights -= &(&self.query_weights * weight_decay);
|
||||
self.key_weights -= &(&self.key_weights * weight_decay);
|
||||
self.value_weights -= &(&self.value_weights * weight_decay);
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-head attention layer
///
/// Runs several independent [`AttentionLayer`]s in parallel, concatenates
/// their outputs, and projects back to the input dimension.
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    /// Individual attention heads
    heads: Vec<AttentionLayer>,
    /// Output projection ((num_heads * head_dim) x output_dim)
    output_projection: Array2<f32>,
    /// Number of heads
    num_heads: usize,
    /// Dimension per head
    head_dim: usize,
    /// Total output dimension (set equal to the input dimension in `new`)
    output_dim: usize,
    /// Dropout probability
    /// NOTE(review): stored but never applied in this file's forward passes —
    /// confirm whether dropout was meant to be used.
    dropout: f32,
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
/// Create a new multi-head attention layer
|
||||
#[must_use]
|
||||
pub fn new(input_dim: usize, num_heads: usize, head_dim: usize, dropout: f32) -> Self {
|
||||
let mut heads = Vec::with_capacity(num_heads);
|
||||
for _ in 0..num_heads {
|
||||
heads.push(AttentionLayer::new(input_dim, head_dim));
|
||||
}
|
||||
|
||||
let total_dim = num_heads * head_dim;
|
||||
let output_projection = xavier_init(total_dim, input_dim);
|
||||
|
||||
Self {
|
||||
heads,
|
||||
output_projection,
|
||||
num_heads,
|
||||
head_dim,
|
||||
output_dim: input_dim,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of heads
|
||||
#[must_use]
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.num_heads
|
||||
}
|
||||
|
||||
/// Get the output dimension
|
||||
#[must_use]
|
||||
pub fn output_dim(&self) -> usize {
|
||||
self.output_dim
|
||||
}
|
||||
|
||||
/// Forward pass through multi-head attention
|
||||
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
|
||||
let n = features.nrows();
|
||||
|
||||
// Compute attention for each head
|
||||
let mut head_outputs = Vec::with_capacity(self.num_heads);
|
||||
for head in &self.heads {
|
||||
let output = head.forward(features)?;
|
||||
head_outputs.push(output);
|
||||
}
|
||||
|
||||
// Concatenate head outputs
|
||||
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
|
||||
for (h, output) in head_outputs.iter().enumerate() {
|
||||
let start = h * self.head_dim;
|
||||
for i in 0..n {
|
||||
for j in 0..self.head_dim {
|
||||
concat[[i, start + j]] = output[[i, j]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply output projection
|
||||
let output = concat.dot(&self.output_projection);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Forward pass with explicit Q, K, V
|
||||
pub fn forward_qkv(
|
||||
&self,
|
||||
query: &Array2<f32>,
|
||||
key: &Array2<f32>,
|
||||
value: &Array2<f32>,
|
||||
mask: Option<&Array2<f32>>,
|
||||
) -> AttentionResult<Array2<f32>> {
|
||||
let n = query.nrows();
|
||||
|
||||
// Compute attention for each head
|
||||
let mut head_outputs = Vec::with_capacity(self.num_heads);
|
||||
for head in &self.heads {
|
||||
let output = head.forward_qkv(query, key, value, mask)?;
|
||||
head_outputs.push(output);
|
||||
}
|
||||
|
||||
// Concatenate head outputs
|
||||
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
|
||||
for (h, output) in head_outputs.iter().enumerate() {
|
||||
let start = h * self.head_dim;
|
||||
for i in 0..n {
|
||||
for j in 0..self.head_dim {
|
||||
concat[[i, start + j]] = output[[i, j]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply output projection
|
||||
let output = concat.dot(&self.output_projection);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Graph-level readout using multi-head attention
|
||||
pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
|
||||
let attended = self.forward(node_features)?;
|
||||
let mean = attended.mean_axis(Axis(0)).unwrap();
|
||||
Ok(mean)
|
||||
}
|
||||
}
|
||||
|
||||
/// Cross-attention between two sequences
///
/// Queries come from the source sequence; keys and values come from the
/// target sequence, so each source row attends over the target rows.
#[derive(Debug, Clone)]
pub struct CrossAttention {
    /// Query projection for source (source_dim x attention_dim)
    query_proj: Array2<f32>,
    /// Key projection for target (target_dim x attention_dim)
    key_proj: Array2<f32>,
    /// Value projection for target (target_dim x attention_dim)
    value_proj: Array2<f32>,
    /// Attention dimension
    attention_dim: usize,
    /// Source dimension
    source_dim: usize,
    /// Target dimension
    target_dim: usize,
}
|
||||
|
||||
impl CrossAttention {
|
||||
/// Create a new cross-attention layer
|
||||
#[must_use]
|
||||
pub fn new(source_dim: usize, target_dim: usize, attention_dim: usize) -> Self {
|
||||
Self {
|
||||
query_proj: xavier_init(source_dim, attention_dim),
|
||||
key_proj: xavier_init(target_dim, attention_dim),
|
||||
value_proj: xavier_init(target_dim, attention_dim),
|
||||
attention_dim,
|
||||
source_dim,
|
||||
target_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute cross-attention between source and target
|
||||
pub fn forward(
|
||||
&self,
|
||||
source: &Array2<f32>,
|
||||
target: &Array2<f32>,
|
||||
) -> AttentionResult<Array2<f32>> {
|
||||
// Project source to query
|
||||
let query = source.dot(&self.query_proj);
|
||||
|
||||
// Project target to key and value
|
||||
let key = target.dot(&self.key_proj);
|
||||
let value = target.dot(&self.value_proj);
|
||||
|
||||
// Compute attention scores
|
||||
let scale = (self.attention_dim as f32).sqrt();
|
||||
let scores = query.dot(&key.t()) / scale;
|
||||
|
||||
// Softmax
|
||||
let attention_weights = softmax_2d(&scores);
|
||||
|
||||
// Apply attention to values
|
||||
let output = attention_weights.dot(&value);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
/// Set attention for set-to-set operations
///
/// A Set-Transformer-style block: multi-head self-attention plus a two-layer
/// feed-forward network, each with residual connection and layer norm.
#[derive(Debug, Clone)]
pub struct SetAttention {
    /// Multi-head attention
    mha: MultiHeadAttention,
    /// Layer normalization parameters (shared by both sub-layers — see the
    /// NOTE on `forward`)
    layer_norm_weight: Array1<f32>,
    layer_norm_bias: Array1<f32>,
    /// Feed-forward network (input_dim -> hidden_dim -> input_dim)
    ffn_w1: Array2<f32>,
    ffn_w2: Array2<f32>,
    /// Dimensions
    input_dim: usize,
    hidden_dim: usize,
}
|
||||
|
||||
impl SetAttention {
    /// Create a new set attention layer (Set Transformer style)
    ///
    /// NOTE(review): `head_dim` is `input_dim / num_heads` with integer
    /// division; if `num_heads` does not divide `input_dim` the concatenated
    /// head width falls short of `input_dim` (the output projection still
    /// maps back to `input_dim`) — confirm callers always pass a divisor.
    #[must_use]
    pub fn new(input_dim: usize, num_heads: usize, hidden_dim: usize) -> Self {
        let head_dim = input_dim / num_heads;
        let mha = MultiHeadAttention::new(input_dim, num_heads, head_dim, 0.0);

        Self {
            mha,
            layer_norm_weight: Array1::ones(input_dim),
            layer_norm_bias: Array1::zeros(input_dim),
            ffn_w1: xavier_init(input_dim, hidden_dim),
            ffn_w2: xavier_init(hidden_dim, input_dim),
            input_dim,
            hidden_dim,
        }
    }

    /// Forward pass with self-attention and feed-forward
    ///
    /// Transformer encoder block: attention + residual + layer norm, then a
    /// two-layer ReLU FFN + residual + layer norm.
    /// NOTE(review): both layer norms share one weight/bias pair — confirm
    /// this is intentional (separate per-sublayer parameters are more common).
    pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
        // Self-attention with residual
        let attended = self.mha.forward(features)?;
        let residual1 = features + &attended;
        let normed1 = layer_norm(&residual1, &self.layer_norm_weight, &self.layer_norm_bias);

        // Feed-forward with residual
        let hidden = normed1.dot(&self.ffn_w1).mapv(|x| x.max(0.0)); // ReLU
        let ffn_out = hidden.dot(&self.ffn_w2);
        let residual2 = &normed1 + &ffn_out;
        let output = layer_norm(&residual2, &self.layer_norm_weight, &self.layer_norm_bias);

        Ok(output)
    }

    /// Aggregate set elements using learned attention
    ///
    /// Returns the mean of the transformed set. NOTE(review): the `unwrap`
    /// panics when `features` has zero rows — confirm empty sets cannot
    /// reach this path.
    pub fn aggregate(&self, features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
        let transformed = self.forward(features)?;
        Ok(transformed.mean_axis(Axis(0)).unwrap())
    }
}
|
||||
|
||||
// =========== Helper Functions ===========
|
||||
|
||||
/// Xavier/Glorot initialization
///
/// Draws each weight uniformly from [-limit, limit] with
/// limit = sqrt(6 / (fan_in + fan_out)), using the thread-local RNG,
/// and returns a (fan_in x fan_out) matrix. Output is nondeterministic.
fn xavier_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
    let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
    let uniform = Uniform::new(-limit, limit);
    let mut rng = rand::thread_rng();

    Array2::from_shape_fn((fan_in, fan_out), |_| uniform.sample(&mut rng))
}
|
||||
|
||||
/// Softmax over 2D array (row-wise)
|
||||
fn softmax_2d(scores: &Array2<f32>) -> Array2<f32> {
|
||||
let mut result = scores.clone();
|
||||
|
||||
for mut row in result.rows_mut() {
|
||||
// Subtract max for numerical stability
|
||||
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
let mut sum = 0.0;
|
||||
for val in row.iter_mut() {
|
||||
if val.is_finite() {
|
||||
*val = (*val - max).exp();
|
||||
sum += *val;
|
||||
} else {
|
||||
*val = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
if sum > 0.0 {
|
||||
row /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Layer normalization
|
||||
fn layer_norm(x: &Array2<f32>, weight: &Array1<f32>, bias: &Array1<f32>) -> Array2<f32> {
|
||||
let eps = 1e-5;
|
||||
let mut result = x.clone();
|
||||
|
||||
for mut row in result.rows_mut() {
|
||||
let mean = row.mean().unwrap_or(0.0);
|
||||
let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / row.len() as f32;
|
||||
let std = (variance + eps).sqrt();
|
||||
|
||||
for (i, val) in row.iter_mut().enumerate() {
|
||||
*val = (*val - mean) / std * weight[i] + bias[i];
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Shape contract: (n, input_dim) -> (n, attention_dim).
    #[test]
    fn test_attention_layer() {
        let layer = AttentionLayer::new(8, 16);
        assert_eq!(layer.output_dim(), 16);

        let features = Array2::from_elem((3, 8), 0.5);
        let output = layer.forward(&features).unwrap();

        assert_eq!(output.shape(), &[3, 16]);
    }

    // Multi-head output is projected back to the input width.
    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(8, 4, 4, 0.0);
        assert_eq!(mha.num_heads(), 4);
        assert_eq!(mha.output_dim(), 8);

        let features = Array2::from_elem((5, 8), 0.5);
        let output = mha.forward(&features).unwrap();

        assert_eq!(output.shape(), &[5, 8]);
    }

    // Cross-attention output has one row per source row and attention_dim cols.
    #[test]
    fn test_cross_attention() {
        let cross = CrossAttention::new(8, 16, 32);

        let source = Array2::from_elem((3, 8), 0.5);
        let target = Array2::from_elem((5, 16), 0.5);

        let output = cross.forward(&source, &target).unwrap();
        assert_eq!(output.shape(), &[3, 32]);
    }

    // Set attention preserves the set's shape; aggregation collapses rows.
    #[test]
    fn test_set_attention() {
        let set_attn = SetAttention::new(16, 4, 64);

        let features = Array2::from_elem((4, 16), 0.5);
        let output = set_attn.forward(&features).unwrap();

        assert_eq!(output.shape(), &[4, 16]);

        let aggregated = set_attn.aggregate(&features).unwrap();
        assert_eq!(aggregated.len(), 16);
    }

    #[test]
    fn test_softmax() {
        let scores = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let probs = softmax_2d(&scores);

        // Each row should sum to 1
        for row in probs.rows() {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }

        // Higher values should have higher probabilities
        assert!(probs[[0, 2]] > probs[[0, 1]]);
        assert!(probs[[0, 1]] > probs[[0, 0]]);
    }

    #[test]
    fn test_layer_norm() {
        let x = Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();

        let weight = Array1::ones(4);
        let bias = Array1::zeros(4);

        let normed = layer_norm(&x, &weight, &bias);

        // Each row should have mean ~0 and std ~1
        for row in normed.rows() {
            let mean: f32 = row.iter().sum::<f32>() / row.len() as f32;
            assert!(mean.abs() < 1e-5);
        }
    }

    // Readout yields one value per attention dimension.
    #[test]
    fn test_graph_readout() {
        let layer = AttentionLayer::new(8, 4);
        let node_features = Array2::from_elem((5, 8), 0.5);

        let readout = layer.graph_readout(&node_features).unwrap();
        assert_eq!(readout.len(), 4);
    }

    // Zeros in the mask must zero out the corresponding attention weights.
    // (Accesses the layer's private weight fields, which is allowed here
    // because this module is a child of the layer's module.)
    #[test]
    fn test_attention_with_mask() {
        let layer = AttentionLayer::new(4, 4);

        let features = Array2::from_elem((3, 4), 0.5);

        // Create a mask that blocks second and third positions
        let mut mask = Array2::ones((3, 3));
        mask[[0, 1]] = 0.0;
        mask[[0, 2]] = 0.0;

        let query = features.dot(&layer.query_weights);
        let key = features.dot(&layer.key_weights);

        let attn_weights = layer.compute_attention(&query, &key, Some(&mask)).unwrap();

        // First row should only attend to itself
        assert!(attn_weights[[0, 0]] > 0.99); // Almost all attention to self
        assert!(attn_weights[[0, 1]] < 0.01);
        assert!(attn_weights[[0, 2]] < 0.01);
    }
}
|
||||
764
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/infrastructure/gnn_model.rs
vendored
Normal file
764
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/infrastructure/gnn_model.rs
vendored
Normal file
@@ -0,0 +1,764 @@
|
||||
//! GNN model implementation.
|
||||
//!
|
||||
//! This module provides Graph Neural Network implementations including:
|
||||
//! - GCN (Graph Convolutional Network)
|
||||
//! - GraphSAGE (Sample and Aggregate)
|
||||
//! - GAT (Graph Attention Network)
|
||||
|
||||
use ndarray::{Array1, Array2};
|
||||
use rand::Rng;
|
||||
use rand_distr::{Distribution, Normal, Uniform};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::domain::entities::GnnModelType;
|
||||
use crate::infrastructure::attention::AttentionLayer;
|
||||
|
||||
/// Error type for GNN operations.
#[derive(Debug, thiserror::Error)]
pub enum GnnError {
    /// Input/weight shapes do not line up (e.g. feature width vs layer input).
    #[error("Dimension mismatch: {0}")]
    DimensionMismatch(String),

    /// Layer was constructed or configured with invalid parameters.
    #[error("Invalid layer configuration: {0}")]
    InvalidConfig(String),

    /// A numerical computation failed during a forward/backward pass.
    #[error("Computation error: {0}")]
    ComputationError(String),
}
|
||||
|
||||
/// Convenience alias: `Result` specialized to [`GnnError`].
pub type GnnResult<T> = Result<T, GnnError>;
|
||||
|
||||
/// Aggregation method for GraphSAGE neighbor pooling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Aggregator {
    /// Degree-normalized mean of neighbor features (the default).
    Mean,
    /// Unnormalized sum of neighbor features (GCN-style).
    Sum,
    /// Element-wise max over neighbor features.
    MaxPool,
    /// LSTM aggregation (sequence-aware).
    /// NOTE(review): `sage_forward` currently implements this as a
    /// simplified stand-in rather than a real LSTM.
    Lstm,
}
|
||||
|
||||
impl Default for Aggregator {
|
||||
fn default() -> Self {
|
||||
Self::Mean
|
||||
}
|
||||
}
|
||||
|
||||
/// A single layer in a GNN.
///
/// Each variant carries its own trainable parameters; the forward pass
/// dispatches on the variant in [`GnnLayer::forward`].
#[derive(Debug, Clone)]
pub enum GnnLayer {
    /// Graph Convolutional Network layer
    /// H' = σ(D^(-1/2) A D^(-1/2) H W + b)
    Gcn {
        /// Weight matrix (input_dim x output_dim)
        weights: Array2<f32>,
        /// Bias vector (output_dim)
        bias: Array1<f32>,
    },
    /// GraphSAGE layer
    /// h_v' = σ(W * CONCAT(h_v, AGGREGATE({h_u : u ∈ N(v)})))
    /// (implemented as W_self·h_v + W_nbr·agg, which is equivalent to the
    /// concat formulation with a block weight matrix [W_self | W_nbr])
    GraphSage {
        /// Aggregation method
        aggregator: Aggregator,
        /// Self-weight matrix (input_dim x output_dim)
        self_weights: Array2<f32>,
        /// Neighbor-weight matrix (input_dim x output_dim)
        neighbor_weights: Array2<f32>,
        /// Bias vector (output_dim)
        bias: Array1<f32>,
    },
    /// Graph Attention Network layer
    /// Uses attention mechanism to weight neighbor contributions;
    /// heads are concatenated column-wise in the output.
    Gat {
        /// Weight matrix for linear transformation
        /// (input_dim x (output_dim * num_heads))
        weights: Array2<f32>,
        /// Attention vector a = [a_src || a_dst] (2 * output_dim)
        attention_weights: Array1<f32>,
        /// Number of attention heads
        num_heads: usize,
        /// Bias vector (output_dim * num_heads)
        bias: Array1<f32>,
        /// Leaky ReLU negative slope applied to attention logits
        negative_slope: f32,
    },
}
|
||||
|
||||
impl GnnLayer {
|
||||
/// Create a new GCN layer
|
||||
pub fn gcn(input_dim: usize, output_dim: usize) -> Self {
|
||||
let weights = xavier_init(input_dim, output_dim);
|
||||
let bias = Array1::zeros(output_dim);
|
||||
Self::Gcn { weights, bias }
|
||||
}
|
||||
|
||||
/// Create a new GraphSAGE layer
|
||||
pub fn graph_sage(input_dim: usize, output_dim: usize, aggregator: Aggregator) -> Self {
|
||||
let self_weights = xavier_init(input_dim, output_dim);
|
||||
let neighbor_weights = xavier_init(input_dim, output_dim);
|
||||
let bias = Array1::zeros(output_dim);
|
||||
Self::GraphSage {
|
||||
aggregator,
|
||||
self_weights,
|
||||
neighbor_weights,
|
||||
bias,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new GAT layer
|
||||
pub fn gat(input_dim: usize, output_dim: usize, num_heads: usize) -> Self {
|
||||
let weights = xavier_init(input_dim, output_dim * num_heads);
|
||||
// Initialize attention weights with small values
|
||||
let mut rng = rand::thread_rng();
|
||||
let uniform = Uniform::new(-0.1, 0.1);
|
||||
let attention_weights: Array1<f32> = Array1::from_iter(
|
||||
(0..2 * output_dim).map(|_| uniform.sample(&mut rng)),
|
||||
);
|
||||
let bias = Array1::zeros(output_dim * num_heads);
|
||||
Self::Gat {
|
||||
weights,
|
||||
attention_weights,
|
||||
num_heads,
|
||||
bias,
|
||||
negative_slope: 0.2,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the input dimension
|
||||
#[must_use]
|
||||
pub fn input_dim(&self) -> usize {
|
||||
match self {
|
||||
Self::Gcn { weights, .. } => weights.nrows(),
|
||||
Self::GraphSage { self_weights, .. } => self_weights.nrows(),
|
||||
Self::Gat { weights, .. } => weights.nrows(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the output dimension
|
||||
#[must_use]
|
||||
pub fn output_dim(&self) -> usize {
|
||||
match self {
|
||||
Self::Gcn { weights, .. } => weights.ncols(),
|
||||
Self::GraphSage { self_weights, .. } => self_weights.ncols(),
|
||||
Self::Gat { weights, num_heads, .. } => weights.ncols() / num_heads,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the layer
|
||||
pub fn forward(&self, features: &Array2<f32>, adj_matrix: &Array2<f32>) -> GnnResult<Array2<f32>> {
|
||||
match self {
|
||||
Self::Gcn { weights, bias } => self.gcn_forward(features, adj_matrix, weights, bias),
|
||||
Self::GraphSage {
|
||||
aggregator,
|
||||
self_weights,
|
||||
neighbor_weights,
|
||||
bias,
|
||||
} => self.sage_forward(features, adj_matrix, *aggregator, self_weights, neighbor_weights, bias),
|
||||
Self::Gat {
|
||||
weights,
|
||||
attention_weights,
|
||||
num_heads,
|
||||
bias,
|
||||
negative_slope,
|
||||
} => self.gat_forward(features, adj_matrix, weights, attention_weights, *num_heads, bias, *negative_slope),
|
||||
}
|
||||
}
|
||||
|
||||
fn gcn_forward(
|
||||
&self,
|
||||
features: &Array2<f32>,
|
||||
adj_matrix: &Array2<f32>,
|
||||
weights: &Array2<f32>,
|
||||
bias: &Array1<f32>,
|
||||
) -> GnnResult<Array2<f32>> {
|
||||
// H' = σ(A_norm * H * W + b)
|
||||
// A_norm is already normalized (symmetric normalization)
|
||||
|
||||
// Aggregate neighbor features: AH
|
||||
let aggregated = adj_matrix.dot(features);
|
||||
|
||||
// Transform: AH * W
|
||||
let transformed = aggregated.dot(weights);
|
||||
|
||||
// Add bias and apply activation
|
||||
let mut output = transformed;
|
||||
for mut row in output.rows_mut() {
|
||||
for (i, val) in row.iter_mut().enumerate() {
|
||||
*val = relu(*val + bias[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn sage_forward(
|
||||
&self,
|
||||
features: &Array2<f32>,
|
||||
adj_matrix: &Array2<f32>,
|
||||
aggregator: Aggregator,
|
||||
self_weights: &Array2<f32>,
|
||||
neighbor_weights: &Array2<f32>,
|
||||
bias: &Array1<f32>,
|
||||
) -> GnnResult<Array2<f32>> {
|
||||
let n = features.nrows();
|
||||
let out_dim = self_weights.ncols();
|
||||
|
||||
// Aggregate neighbor features
|
||||
let neighbor_agg = match aggregator {
|
||||
Aggregator::Mean => {
|
||||
// Mean aggregation
|
||||
let mut agg = adj_matrix.dot(features);
|
||||
let degrees: Vec<f32> = (0..n).map(|i| adj_matrix.row(i).sum().max(1.0)).collect();
|
||||
for (i, mut row) in agg.rows_mut().into_iter().enumerate() {
|
||||
row /= degrees[i];
|
||||
}
|
||||
agg
|
||||
}
|
||||
Aggregator::Sum => {
|
||||
// Sum aggregation
|
||||
adj_matrix.dot(features)
|
||||
}
|
||||
Aggregator::MaxPool => {
|
||||
// Max pooling
|
||||
let mut agg = Array2::zeros((n, features.ncols()));
|
||||
for i in 0..n {
|
||||
for j in 0..features.ncols() {
|
||||
let mut max_val = f32::NEG_INFINITY;
|
||||
for k in 0..n {
|
||||
if adj_matrix[[i, k]] > 0.0 {
|
||||
max_val = max_val.max(features[[k, j]]);
|
||||
}
|
||||
}
|
||||
agg[[i, j]] = if max_val.is_finite() { max_val } else { 0.0 };
|
||||
}
|
||||
}
|
||||
agg
|
||||
}
|
||||
Aggregator::Lstm => {
|
||||
// Simplified LSTM: just use mean for now
|
||||
// Full implementation would use actual LSTM
|
||||
adj_matrix.dot(features)
|
||||
}
|
||||
};
|
||||
|
||||
// Transform self and neighbor features
|
||||
let self_transformed = features.dot(self_weights);
|
||||
let neighbor_transformed = neighbor_agg.dot(neighbor_weights);
|
||||
|
||||
// Combine: concat and add bias
|
||||
let mut output = Array2::zeros((n, out_dim));
|
||||
for i in 0..n {
|
||||
for j in 0..out_dim {
|
||||
let val = self_transformed[[i, j]] + neighbor_transformed[[i, j]] + bias[j];
|
||||
output[[i, j]] = relu(val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn gat_forward(
|
||||
&self,
|
||||
features: &Array2<f32>,
|
||||
adj_matrix: &Array2<f32>,
|
||||
weights: &Array2<f32>,
|
||||
attention_weights: &Array1<f32>,
|
||||
num_heads: usize,
|
||||
bias: &Array1<f32>,
|
||||
negative_slope: f32,
|
||||
) -> GnnResult<Array2<f32>> {
|
||||
let n = features.nrows();
|
||||
let total_out_dim = weights.ncols();
|
||||
let head_dim = total_out_dim / num_heads;
|
||||
|
||||
// Transform features: H * W
|
||||
let transformed = features.dot(weights);
|
||||
|
||||
// Multi-head attention
|
||||
let mut outputs = Vec::with_capacity(num_heads);
|
||||
|
||||
for head in 0..num_heads {
|
||||
let start = head * head_dim;
|
||||
let end = start + head_dim;
|
||||
|
||||
// Extract features for this head
|
||||
let h = transformed.slice(ndarray::s![.., start..end]).to_owned();
|
||||
|
||||
// Compute attention coefficients
|
||||
let mut attention = Array2::zeros((n, n));
|
||||
let attention_dim = attention_weights.len() / 2;
|
||||
let a_src = attention_weights.slice(ndarray::s![..attention_dim]);
|
||||
let a_dst = attention_weights.slice(ndarray::s![attention_dim..]);
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if adj_matrix[[i, j]] > 0.0 {
|
||||
// Compute attention: a^T [Wh_i || Wh_j]
|
||||
let mut e = 0.0;
|
||||
for k in 0..head_dim.min(attention_dim) {
|
||||
e += a_src[k] * h[[i, k]] + a_dst[k] * h[[j, k]];
|
||||
}
|
||||
// Leaky ReLU
|
||||
attention[[i, j]] = leaky_relu(e, negative_slope);
|
||||
} else {
|
||||
attention[[i, j]] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Softmax over neighbors
|
||||
for mut row in attention.rows_mut() {
|
||||
let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut sum = 0.0;
|
||||
for val in row.iter_mut() {
|
||||
if val.is_finite() {
|
||||
*val = (*val - max_val).exp();
|
||||
sum += *val;
|
||||
} else {
|
||||
*val = 0.0;
|
||||
}
|
||||
}
|
||||
if sum > 0.0 {
|
||||
row /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply attention
|
||||
let head_output = attention.dot(&h);
|
||||
outputs.push(head_output);
|
||||
}
|
||||
|
||||
// Concatenate head outputs
|
||||
let mut output = Array2::zeros((n, total_out_dim));
|
||||
for (head, head_out) in outputs.iter().enumerate() {
|
||||
let start = head * head_dim;
|
||||
for i in 0..n {
|
||||
for j in 0..head_dim {
|
||||
output[[i, start + j]] = head_out[[i, j]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add bias and apply activation
|
||||
for mut row in output.rows_mut() {
|
||||
for (i, val) in row.iter_mut().enumerate() {
|
||||
if i < bias.len() {
|
||||
*val = elu(*val + bias[i], 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Update weights with gradients
|
||||
pub fn update_weights(&mut self, gradient: &Array2<f32>, lr: f32, weight_decay: f32) {
|
||||
match self {
|
||||
Self::Gcn { weights, bias: _ } => {
|
||||
// Apply weight decay
|
||||
*weights -= &(weights.clone() * weight_decay);
|
||||
// Apply gradient update
|
||||
if gradient.shape() == weights.shape() {
|
||||
*weights -= &(gradient * lr);
|
||||
}
|
||||
}
|
||||
Self::GraphSage {
|
||||
self_weights,
|
||||
neighbor_weights,
|
||||
bias: _,
|
||||
..
|
||||
} => {
|
||||
*self_weights -= &(self_weights.clone() * weight_decay);
|
||||
*neighbor_weights -= &(neighbor_weights.clone() * weight_decay);
|
||||
if gradient.shape() == self_weights.shape() {
|
||||
*self_weights -= &(gradient * lr);
|
||||
*neighbor_weights -= &(gradient * lr);
|
||||
}
|
||||
}
|
||||
Self::Gat { weights, bias: _, .. } => {
|
||||
*weights -= &(weights.clone() * weight_decay);
|
||||
if gradient.shape() == weights.shape() {
|
||||
*weights -= &(gradient * lr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get layer parameters as flattened vector
|
||||
#[must_use]
|
||||
pub fn get_parameters(&self) -> Vec<f32> {
|
||||
match self {
|
||||
Self::Gcn { weights, bias } => {
|
||||
let mut params: Vec<f32> = weights.iter().cloned().collect();
|
||||
params.extend(bias.iter().cloned());
|
||||
params
|
||||
}
|
||||
Self::GraphSage {
|
||||
self_weights,
|
||||
neighbor_weights,
|
||||
bias,
|
||||
..
|
||||
} => {
|
||||
let mut params: Vec<f32> = self_weights.iter().cloned().collect();
|
||||
params.extend(neighbor_weights.iter().cloned());
|
||||
params.extend(bias.iter().cloned());
|
||||
params
|
||||
}
|
||||
Self::Gat {
|
||||
weights,
|
||||
attention_weights,
|
||||
bias,
|
||||
..
|
||||
} => {
|
||||
let mut params: Vec<f32> = weights.iter().cloned().collect();
|
||||
params.extend(attention_weights.iter().cloned());
|
||||
params.extend(bias.iter().cloned());
|
||||
params
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete GNN model with multiple stacked layers.
#[derive(Debug, Clone)]
pub struct GnnModel {
    /// Model type (GCN / GraphSAGE / GAT) used for every layer
    model_type: GnnModelType,
    /// Stacked GNN layers, applied in order by `forward`
    layers: Vec<GnnLayer>,
    /// Optional attention layer for final graph-level aggregation
    /// (constructed in `new`; not consumed by `forward` in this impl)
    attention: Option<AttentionLayer>,
    /// Dropout probability applied between layers while training
    dropout: f32,
    /// Whether the model is in training mode (enables dropout)
    training: bool,
}
|
||||
|
||||
impl GnnModel {
    /// Create a new GNN model.
    ///
    /// Builds `num_layers` stacked layers of the requested `model_type`:
    /// the first maps `input_dim -> hidden_dim`, intermediate layers keep
    /// `hidden_dim`, and the last maps to `output_dim` (a single layer maps
    /// `input_dim -> output_dim` directly). `num_heads` is only used by GAT;
    /// GraphSAGE layers always use mean aggregation. The model starts in
    /// training mode.
    #[must_use]
    pub fn new(
        model_type: GnnModelType,
        input_dim: usize,
        output_dim: usize,
        num_layers: usize,
        hidden_dim: usize,
        num_heads: usize,
        dropout: f32,
    ) -> Self {
        let mut layers = Vec::with_capacity(num_layers);

        for i in 0..num_layers {
            let in_dim = if i == 0 { input_dim } else { hidden_dim };
            let out_dim = if i == num_layers - 1 {
                output_dim
            } else {
                hidden_dim
            };

            let layer = match model_type {
                GnnModelType::Gcn => GnnLayer::gcn(in_dim, out_dim),
                GnnModelType::GraphSage => GnnLayer::graph_sage(in_dim, out_dim, Aggregator::Mean),
                GnnModelType::Gat => GnnLayer::gat(in_dim, out_dim, num_heads),
            };

            layers.push(layer);
        }

        // Add attention layer for graph-level readout.
        // NOTE(review): `attention` is not used by `forward` in this impl —
        // presumably consumed by a readout path elsewhere; confirm.
        let attention = if num_layers > 0 {
            Some(AttentionLayer::new(output_dim, 64))
        } else {
            None
        };

        Self {
            model_type,
            layers,
            attention,
            dropout,
            training: true,
        }
    }

    /// Enable training mode (dropout active in `forward`).
    pub fn train(&mut self) {
        self.training = true;
    }

    /// Enable evaluation mode (dropout disabled).
    pub fn eval(&mut self) {
        self.training = false;
    }

    /// Get the model type.
    #[must_use]
    pub fn model_type(&self) -> GnnModelType {
        self.model_type
    }

    /// Get the number of layers.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// `(input_dim, output_dim)` of the layer at `layer_idx`, or `(0, 0)`
    /// when the index is out of range.
    #[must_use]
    pub fn layer_dims(&self, layer_idx: usize) -> (usize, usize) {
        if layer_idx < self.layers.len() {
            (
                self.layers[layer_idx].input_dim(),
                self.layers[layer_idx].output_dim(),
            )
        } else {
            (0, 0)
        }
    }

    /// Forward pass through all layers, applying dropout between layers
    /// (but not after the last) while in training mode.
    ///
    /// # Errors
    /// Propagates the first layer error encountered.
    pub fn forward(
        &self,
        features: &Array2<f32>,
        adj_matrix: &Array2<f32>,
    ) -> GnnResult<Array2<f32>> {
        let mut h = features.clone();

        for (i, layer) in self.layers.iter().enumerate() {
            h = layer.forward(&h, adj_matrix)?;

            // Apply dropout (except on last layer)
            if self.training && i < self.layers.len() - 1 {
                h = self.apply_dropout(&h);
            }
        }

        Ok(h)
    }

    /// Update model weights with per-layer gradients, zipped positionally;
    /// surplus gradients or layers are silently ignored.
    pub fn update_weights(
        &mut self,
        gradients: &[Array2<f32>],
        lr: f32,
        weight_decay: f32,
    ) {
        for (layer, grad) in self.layers.iter_mut().zip(gradients.iter()) {
            layer.update_weights(grad, lr, weight_decay);
        }
    }

    /// Get all model parameters as one flattened vector, in layer order.
    #[must_use]
    pub fn get_parameters(&self) -> Vec<f32> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.get_parameters());
        }
        params
    }

    /// Get total number of parameters.
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        self.get_parameters().len()
    }

    /// Inverted dropout: zeroes each element with probability `dropout` and
    /// scales survivors by 1/(1-dropout), keeping the expected value
    /// unchanged so no rescaling is needed at inference time.
    fn apply_dropout(&self, features: &Array2<f32>) -> Array2<f32> {
        if self.dropout <= 0.0 || !self.training {
            return features.clone();
        }

        let mut rng = rand::thread_rng();
        let scale = 1.0 / (1.0 - self.dropout);

        let mut dropped = features.clone();
        for val in dropped.iter_mut() {
            if rng.gen::<f32>() < self.dropout {
                *val = 0.0;
            } else {
                *val *= scale;
            }
        }

        dropped
    }
}
|
||||
|
||||
// =========== Activation Functions ===========
|
||||
|
||||
/// Rectified linear unit: passes positives through, clamps the rest to zero.
fn relu(x: f32) -> f32 {
    if x > 0.0 { x } else { 0.0 }
}
|
||||
|
||||
/// Leaky ReLU: identity on non-negatives, scales negatives by `negative_slope`.
fn leaky_relu(x: f32, negative_slope: f32) -> f32 {
    match x >= 0.0 {
        true => x,
        false => x * negative_slope,
    }
}
|
||||
|
||||
/// Exponential linear unit: identity on non-negatives, smooth exponential
/// saturation toward `-alpha` on negatives.
fn elu(x: f32, alpha: f32) -> f32 {
    if x < 0.0 {
        alpha * (x.exp() - 1.0)
    } else {
        x
    }
}
|
||||
|
||||
/// GELU activation (Gaussian Error Linear Unit), tanh approximation:
/// 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))), with 0.7978845608 ≈ √(2/π).
#[allow(dead_code)]
fn gelu(x: f32) -> f32 {
    let inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
|
||||
|
||||
// =========== Initialization ===========
|
||||
|
||||
/// Xavier/Glorot uniform initialization
|
||||
fn xavier_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
|
||||
let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
|
||||
let uniform = Uniform::new(-limit, limit);
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
Array2::from_shape_fn((fan_in, fan_out), |_| uniform.sample(&mut rng))
|
||||
}
|
||||
|
||||
/// Kaiming/He initialization (for ReLU)
|
||||
#[allow(dead_code)]
|
||||
fn kaiming_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
|
||||
let std = (2.0 / fan_in as f32).sqrt();
|
||||
let normal = Normal::new(0.0, std).unwrap();
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
Array2::from_shape_fn((fan_in, fan_out), |_| normal.sample(&mut rng))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// GCN layer: dimension accessors and forward shape on an identity graph.
    #[test]
    fn test_gcn_layer() {
        let layer = GnnLayer::gcn(8, 4);
        assert_eq!(layer.input_dim(), 8);
        assert_eq!(layer.output_dim(), 4);

        let features = Array2::from_elem((3, 8), 0.5);
        let adj = Array2::<f32>::eye(3);

        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }

    /// GraphSAGE layer with mean aggregation: dims and forward shape.
    #[test]
    fn test_graphsage_layer() {
        let layer = GnnLayer::graph_sage(8, 4, Aggregator::Mean);
        assert_eq!(layer.input_dim(), 8);
        assert_eq!(layer.output_dim(), 4);

        let features = Array2::from_elem((3, 8), 0.5);
        let adj = Array2::<f32>::eye(3);

        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }

    /// GAT layer: output width is per-head dim (4) times heads (2).
    #[test]
    fn test_gat_layer() {
        let layer = GnnLayer::gat(8, 4, 2);
        assert_eq!(layer.input_dim(), 8);

        let features = Array2::from_elem((3, 8), 0.5);
        let adj = Array2::<f32>::eye(3);

        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 8]); // 4 * 2 heads
    }

    /// Two-layer GCN model: end-to-end forward shape (16 -> 32 -> 8).
    #[test]
    fn test_gnn_model() {
        let model = GnnModel::new(GnnModelType::Gcn, 16, 8, 2, 32, 4, 0.5);

        assert_eq!(model.model_type(), GnnModelType::Gcn);
        assert_eq!(model.num_layers(), 2);

        let features = Array2::from_elem((5, 16), 0.5);
        let adj = Array2::<f32>::eye(5);

        let output = model.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[5, 8]);
    }

    /// Spot-check ReLU, leaky ReLU, and ELU at representative points.
    #[test]
    fn test_activation_functions() {
        assert_eq!(relu(-1.0), 0.0);
        assert_eq!(relu(1.0), 1.0);

        assert!(leaky_relu(-1.0, 0.2) < 0.0);
        assert_eq!(leaky_relu(1.0, 0.2), 1.0);

        assert!(elu(-1.0, 1.0) < 0.0);
        assert_eq!(elu(1.0, 1.0), 1.0);
    }

    /// Xavier init: correct shape and values inside (-1, 1)
    /// (the actual limit is sqrt(6/200) ≈ 0.17 for a 100x100 matrix).
    #[test]
    fn test_xavier_init() {
        let weights = xavier_init(100, 100);
        assert_eq!(weights.shape(), &[100, 100]);

        // Check values are in reasonable range
        let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let min = weights.iter().cloned().fold(f32::INFINITY, f32::min);

        assert!(max < 1.0);
        assert!(min > -1.0);
    }

    /// Parameter count for a 2-layer GCN (8 -> 16 -> 4).
    #[test]
    fn test_model_parameters() {
        let model = GnnModel::new(GnnModelType::Gcn, 8, 4, 2, 16, 1, 0.0);

        let params = model.get_parameters();
        assert!(params.len() > 0);

        // Layer 1: 8*16 weights + 16 bias = 144
        // Layer 2: 16*4 weights + 4 bias = 68
        // Total: 212
        assert_eq!(model.num_parameters(), 8 * 16 + 16 + 16 * 4 + 4);
    }

    /// Mean and max-pool aggregation both preserve the (n, out_dim) shape
    /// on a small directed graph.
    #[test]
    fn test_aggregators() {
        let features = Array2::from_shape_vec((3, 4), vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
        ]).unwrap();

        let mut adj = Array2::zeros((3, 3));
        adj[[0, 1]] = 1.0;
        adj[[0, 2]] = 1.0;
        adj[[1, 0]] = 1.0;
        adj[[2, 0]] = 1.0;

        // Test mean aggregation
        let layer = GnnLayer::graph_sage(4, 2, Aggregator::Mean);
        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 2]);

        // Test max aggregation
        let layer = GnnLayer::graph_sage(4, 2, Aggregator::MaxPool);
        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 2]);
    }
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//! Infrastructure layer for the learning bounded context.
|
||||
//!
|
||||
//! Contains GNN model implementations, attention mechanisms,
|
||||
//! and other technical components.
|
||||
|
||||
pub mod gnn_model;
|
||||
pub mod attention;
|
||||
82
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/lib.rs
vendored
Normal file
82
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/lib.rs
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
//! # sevensense-learning
|
||||
//!
|
||||
//! Graph Neural Network (GNN) based learning and embedding refinement for 7sense.
|
||||
//!
|
||||
//! This crate provides:
|
||||
//! - GNN models (GCN, GraphSAGE, GAT) for graph-based learning
|
||||
//! - Embedding refinement through message passing
|
||||
//! - Contrastive learning with InfoNCE and triplet loss
|
||||
//! - Elastic Weight Consolidation (EWC) for continual learning
|
||||
//! - Graph attention mechanisms for relationship modeling
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The crate follows Domain-Driven Design principles:
|
||||
//! - `domain`: Core entities and repository traits
|
||||
//! - `application`: Business logic and services
|
||||
//! - `infrastructure`: GNN implementations and attention mechanisms
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use sevensense_learning::{LearningService, LearningConfig, GnnModelType};
|
||||
//!
|
||||
//! let config = LearningConfig::default();
|
||||
//! let service = LearningService::new(config);
|
||||
//!
|
||||
//! // Train on transition graph
|
||||
//! let metrics = service.train_epoch(&graph).await?;
|
||||
//!
|
||||
//! // Refine embeddings
|
||||
//! let refined = service.refine_embeddings(&embeddings).await?;
|
||||
//! ```
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![warn(clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
pub mod domain;
|
||||
pub mod application;
|
||||
pub mod infrastructure;
|
||||
pub mod loss;
|
||||
pub mod ewc;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use domain::entities::{
|
||||
LearningSession, GnnModelType, TrainingStatus, TrainingMetrics,
|
||||
TransitionGraph, RefinedEmbedding, EmbeddingId, Timestamp,
|
||||
GraphNode, GraphEdge, HyperParameters, LearningConfig,
|
||||
};
|
||||
pub use domain::repository::LearningRepository;
|
||||
pub use application::services::LearningService;
|
||||
pub use infrastructure::gnn_model::{GnnModel, GnnLayer, Aggregator};
|
||||
pub use infrastructure::attention::{AttentionLayer, MultiHeadAttention};
|
||||
pub use loss::{info_nce_loss, triplet_loss, margin_ranking_loss, contrastive_loss};
|
||||
pub use ewc::{EwcState, FisherInformation, EwcRegularizer};
|
||||
|
||||
/// Crate version string, taken from Cargo.toml at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Prelude module for convenient glob imports
/// (`use sevensense_learning::prelude::*;`).
pub mod prelude {
    pub use crate::domain::entities::*;
    pub use crate::domain::repository::*;
    pub use crate::application::services::*;
    pub use crate::infrastructure::gnn_model::*;
    pub use crate::infrastructure::attention::*;
    pub use crate::loss::*;
    pub use crate::ewc::*;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the crate-root re-exports resolve and are constructible.
    #[test]
    fn test_crate_exports() {
        // Verify all public types are accessible
        let _: GnnModelType = GnnModelType::Gcn;
        let _: TrainingStatus = TrainingStatus::Pending;
    }
}
|
||||
612
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/loss.rs
vendored
Normal file
612
vendor/ruvector/examples/vibecast-7sense/crates/sevensense-learning/src/loss.rs
vendored
Normal file
@@ -0,0 +1,612 @@
|
||||
//! Loss functions for contrastive learning.
|
||||
//!
|
||||
//! This module provides various loss functions for training GNN models
|
||||
//! on graph-structured data with contrastive learning objectives.
|
||||
|
||||
|
||||
/// Compute InfoNCE (Noise Contrastive Estimation) loss.
|
||||
///
|
||||
/// InfoNCE loss encourages the model to distinguish positive samples
|
||||
/// from negative samples in the embedding space.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor` - The anchor embedding
|
||||
/// * `positive` - The positive (similar) embedding
|
||||
/// * `negatives` - Slice of negative (dissimilar) embeddings
|
||||
/// * `temperature` - Temperature parameter for softmax scaling (typical: 0.07-0.5)
|
||||
///
|
||||
/// # Returns
|
||||
/// The InfoNCE loss value (lower is better)
|
||||
///
|
||||
/// # Formula
|
||||
/// L = -log(exp(sim(a,p)/τ) / Σ exp(sim(a,n_i)/τ))
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use sevensense_learning::info_nce_loss;
|
||||
///
|
||||
/// let anchor = vec![1.0, 0.0, 0.0];
|
||||
/// let positive = vec![0.9, 0.1, 0.0];
|
||||
/// let negative = vec![0.0, 1.0, 0.0];
|
||||
///
|
||||
/// let loss = info_nce_loss(&anchor, &positive, &[&negative], 0.07);
|
||||
/// assert!(loss >= 0.0);
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn info_nce_loss(
|
||||
anchor: &[f32],
|
||||
positive: &[f32],
|
||||
negatives: &[&[f32]],
|
||||
temperature: f32,
|
||||
) -> f32 {
|
||||
if anchor.is_empty() || positive.is_empty() || negatives.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let temp = temperature.max(1e-6); // Prevent division by zero
|
||||
|
||||
// Compute similarity with positive
|
||||
let pos_sim = cosine_similarity(anchor, positive) / temp;
|
||||
|
||||
// Compute similarities with all negatives
|
||||
let neg_sims: Vec<f32> = negatives
|
||||
.iter()
|
||||
.map(|neg| cosine_similarity(anchor, neg) / temp)
|
||||
.collect();
|
||||
|
||||
// Log-sum-exp for numerical stability
|
||||
// L = -pos_sim + log(exp(pos_sim) + Σ exp(neg_sim_i))
|
||||
let max_sim = neg_sims
|
||||
.iter()
|
||||
.chain(std::iter::once(&pos_sim))
|
||||
.cloned()
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
let sum_exp: f32 = std::iter::once(pos_sim)
|
||||
.chain(neg_sims)
|
||||
.map(|s| (s - max_sim).exp())
|
||||
.sum();
|
||||
|
||||
let log_sum_exp = max_sim + sum_exp.ln();
|
||||
|
||||
-pos_sim + log_sum_exp
|
||||
}
|
||||
|
||||
/// Compute triplet loss with margin.
|
||||
///
|
||||
/// Triplet loss ensures that the anchor is closer to the positive
|
||||
/// than to the negative by at least a margin.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor` - The anchor embedding
|
||||
/// * `positive` - The positive (similar) embedding
|
||||
/// * `negative` - The negative (dissimilar) embedding
|
||||
/// * `margin` - The margin to enforce between positive and negative distances
|
||||
///
|
||||
/// # Returns
|
||||
/// The triplet loss value (lower is better)
|
||||
///
|
||||
/// # Formula
|
||||
/// L = max(0, d(a,p) - d(a,n) + margin)
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use sevensense_learning::triplet_loss;
|
||||
///
|
||||
/// let anchor = vec![1.0, 0.0, 0.0];
|
||||
/// let positive = vec![0.9, 0.1, 0.0];
|
||||
/// let negative = vec![0.0, 1.0, 0.0];
|
||||
///
|
||||
/// let loss = triplet_loss(&anchor, &positive, &negative, 1.0);
|
||||
/// assert!(loss >= 0.0);
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn triplet_loss(anchor: &[f32], positive: &[f32], negative: &[f32], margin: f32) -> f32 {
|
||||
if anchor.is_empty() || positive.is_empty() || negative.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let d_pos = euclidean_distance(anchor, positive);
|
||||
let d_neg = euclidean_distance(anchor, negative);
|
||||
|
||||
(d_pos - d_neg + margin).max(0.0)
|
||||
}
|
||||
|
||||
/// Compute margin ranking loss.
|
||||
///
|
||||
/// Similar to triplet loss but uses a ranking formulation.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor` - The anchor embedding
|
||||
/// * `positive` - The positive embedding
|
||||
/// * `negative` - The negative embedding
|
||||
/// * `margin` - The margin for ranking
|
||||
///
|
||||
/// # Returns
|
||||
/// The margin ranking loss value
|
||||
///
|
||||
/// # Formula
|
||||
/// L = max(0, margin - (sim(a,p) - sim(a,n)))
|
||||
#[must_use]
|
||||
pub fn margin_ranking_loss(
|
||||
anchor: &[f32],
|
||||
positive: &[f32],
|
||||
negative: &[f32],
|
||||
margin: f32,
|
||||
) -> f32 {
|
||||
if anchor.is_empty() || positive.is_empty() || negative.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let sim_pos = cosine_similarity(anchor, positive);
|
||||
let sim_neg = cosine_similarity(anchor, negative);
|
||||
|
||||
(margin - (sim_pos - sim_neg)).max(0.0)
|
||||
}
|
||||
|
||||
/// Compute contrastive loss (SimCLR style).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `z_i` - First view embedding
|
||||
/// * `z_j` - Second view embedding (augmented view of same sample)
|
||||
/// * `other_samples` - Embeddings of other samples in the batch
|
||||
/// * `temperature` - Temperature parameter
|
||||
///
|
||||
/// # Returns
|
||||
/// The contrastive loss value
|
||||
#[must_use]
|
||||
pub fn contrastive_loss(
|
||||
z_i: &[f32],
|
||||
z_j: &[f32],
|
||||
other_samples: &[&[f32]],
|
||||
temperature: f32,
|
||||
) -> f32 {
|
||||
if z_i.is_empty() || z_j.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let temp = temperature.max(1e-6);
|
||||
|
||||
// Similarity between positive pair
|
||||
let pos_sim = cosine_similarity(z_i, z_j) / temp;
|
||||
|
||||
// Similarities with all other samples
|
||||
let mut all_sims: Vec<f32> = vec![pos_sim];
|
||||
for sample in other_samples {
|
||||
all_sims.push(cosine_similarity(z_i, sample) / temp);
|
||||
}
|
||||
|
||||
// Log-sum-exp trick
|
||||
let max_sim = all_sims.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let sum_exp: f32 = all_sims.iter().map(|s| (s - max_sim).exp()).sum();
|
||||
let log_sum_exp = max_sim + sum_exp.ln();
|
||||
|
||||
-pos_sim + log_sum_exp
|
||||
}
|
||||
|
||||
/// Compute NT-Xent loss (Normalized Temperature-scaled Cross Entropy).
///
/// This is the loss function used in SimCLR.
///
/// # Arguments
/// * `embeddings` - All embeddings in the batch, laid out as alternating
///   views `[z_1_a, z_1_b, z_2_a, z_2_b, ...]` (2N entries for N samples)
/// * `temperature` - Temperature parameter (clamped to at least 1e-6)
///
/// # Returns
/// The average NT-Xent loss across all positive pairs; 0.0 for fewer than
/// two embeddings.
#[must_use]
pub fn nt_xent_loss(embeddings: &[Vec<f32>], temperature: f32) -> f32 {
    let n = embeddings.len();
    if n < 2 {
        return 0.0;
    }

    let temp = temperature.max(1e-6);

    // Assume embeddings are organized as [z_1_a, z_1_b, z_2_a, z_2_b, ...]
    // where z_i_a and z_i_b are two views of sample i

    // All pairwise temperature-scaled cosine similarities (O(n^2 d)).
    let mut sim_matrix = vec![vec![0.0f32; n]; n];
    for i in 0..n {
        for j in 0..n {
            sim_matrix[i][j] = cosine_similarity(&embeddings[i], &embeddings[j]) / temp;
        }
    }

    let mut total_loss = 0.0;
    let mut count = 0;

    // For each sample, compute loss against its positive pair
    for i in 0..n {
        // Find positive pair (assumes alternating views)
        let j = if i % 2 == 0 { i + 1 } else { i - 1 };
        if j >= n {
            // Odd batch size: the final embedding has no partner.
            continue;
        }

        let pos_sim = sim_matrix[i][j];

        // Denominator: log-sum-exp over all k != i. Note the positive pair
        // IS included here (standard NT-Xent); only self-similarity is
        // excluded.
        let max_sim = sim_matrix[i]
            .iter()
            .enumerate()
            .filter(|(k, _)| *k != i)
            .map(|(_, &s)| s)
            .fold(f32::NEG_INFINITY, f32::max);

        let sum_exp: f32 = sim_matrix[i]
            .iter()
            .enumerate()
            .filter(|(k, _)| *k != i)
            .map(|(_, &s)| (s - max_sim).exp())
            .sum();

        let log_sum_exp = max_sim + sum_exp.ln();

        total_loss += -pos_sim + log_sum_exp;
        count += 1;
    }

    if count > 0 {
        total_loss / count as f32
    } else {
        0.0
    }
}
|
||||
|
||||
/// Compute supervised contrastive loss (SupCon).
|
||||
///
|
||||
/// Extends contrastive loss to use label information.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embeddings` - All embeddings
|
||||
/// * `labels` - Label for each embedding
|
||||
/// * `temperature` - Temperature parameter
|
||||
///
|
||||
/// # Returns
|
||||
/// The supervised contrastive loss
|
||||
#[must_use]
|
||||
pub fn supervised_contrastive_loss(
|
||||
embeddings: &[Vec<f32>],
|
||||
labels: &[usize],
|
||||
temperature: f32,
|
||||
) -> f32 {
|
||||
let n = embeddings.len();
|
||||
if n < 2 || n != labels.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let temp = temperature.max(1e-6);
|
||||
|
||||
// Compute all pairwise similarities
|
||||
let mut sim_matrix = vec![vec![0.0f32; n]; n];
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
sim_matrix[i][j] = cosine_similarity(&embeddings[i], &embeddings[j]) / temp;
|
||||
}
|
||||
}
|
||||
|
||||
let mut total_loss = 0.0;
|
||||
|
||||
for i in 0..n {
|
||||
// Find all positive pairs (same label, excluding self)
|
||||
let positives: Vec<usize> = (0..n)
|
||||
.filter(|&j| j != i && labels[j] == labels[i])
|
||||
.collect();
|
||||
|
||||
if positives.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute denominator (all except self)
|
||||
let max_sim = sim_matrix[i]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(j, _)| *j != i)
|
||||
.map(|(_, &s)| s)
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
let denom_exp: f32 = sim_matrix[i]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(j, _)| *j != i)
|
||||
.map(|(_, &s)| (s - max_sim).exp())
|
||||
.sum();
|
||||
|
||||
let log_denom = max_sim + denom_exp.ln();
|
||||
|
||||
// Average over all positive pairs
|
||||
let pos_loss: f32 = positives
|
||||
.iter()
|
||||
.map(|&j| -sim_matrix[i][j] + log_denom)
|
||||
.sum();
|
||||
|
||||
total_loss += pos_loss / positives.len() as f32;
|
||||
}
|
||||
|
||||
total_loss / n as f32
|
||||
}
|
||||
|
||||
/// Compute center loss.
|
||||
///
|
||||
/// Encourages embeddings to be close to their class centers.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embeddings` - Embeddings
|
||||
/// * `labels` - Class labels
|
||||
/// * `centers` - Class center embeddings
|
||||
///
|
||||
/// # Returns
|
||||
/// The center loss value
|
||||
#[must_use]
|
||||
pub fn center_loss(
|
||||
embeddings: &[Vec<f32>],
|
||||
labels: &[usize],
|
||||
centers: &[Vec<f32>],
|
||||
) -> f32 {
|
||||
if embeddings.is_empty() || embeddings.len() != labels.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut total_loss = 0.0;
|
||||
|
||||
for (emb, &label) in embeddings.iter().zip(labels.iter()) {
|
||||
if label < centers.len() {
|
||||
let dist = euclidean_distance(emb, ¢ers[label]);
|
||||
total_loss += dist * dist;
|
||||
}
|
||||
}
|
||||
|
||||
total_loss / (2.0 * embeddings.len() as f32)
|
||||
}
|
||||
|
||||
/// Compute focal loss (for imbalanced classification).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `predictions` - Predicted probabilities
|
||||
/// * `targets` - Target labels
|
||||
/// * `gamma` - Focusing parameter (typical: 2.0)
|
||||
/// * `alpha` - Class weighting parameter
|
||||
///
|
||||
/// # Returns
|
||||
/// The focal loss value
|
||||
#[must_use]
|
||||
pub fn focal_loss(
|
||||
predictions: &[f32],
|
||||
targets: &[usize],
|
||||
gamma: f32,
|
||||
alpha: f32,
|
||||
) -> f32 {
|
||||
if predictions.is_empty() || predictions.len() != targets.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let eps = 1e-7;
|
||||
let mut total_loss = 0.0;
|
||||
|
||||
for (&pred, &target) in predictions.iter().zip(targets.iter()) {
|
||||
let p = pred.clamp(eps, 1.0 - eps);
|
||||
let pt = if target == 1 { p } else { 1.0 - p };
|
||||
let at = if target == 1 { alpha } else { 1.0 - alpha };
|
||||
|
||||
let loss = -at * (1.0 - pt).powf(gamma) * pt.ln();
|
||||
total_loss += loss;
|
||||
}
|
||||
|
||||
total_loss / predictions.len() as f32
|
||||
}
|
||||
|
||||
// =========== Helper Functions ===========
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
#[inline]
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a < 1e-10 || norm_b < 1e-10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
dot / (norm_a * norm_b)
|
||||
}
|
||||
|
||||
/// Compute Euclidean distance between two vectors
|
||||
#[inline]
|
||||
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() {
|
||||
return f32::INFINITY;
|
||||
}
|
||||
|
||||
a.iter()
|
||||
.zip(b)
|
||||
.map(|(x, y)| (x - y).powi(2))
|
||||
.sum::<f32>()
|
||||
.sqrt()
|
||||
}
|
||||
|
||||
/// Compute squared Euclidean distance (faster, no sqrt)
|
||||
#[inline]
|
||||
#[allow(dead_code)]
|
||||
fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() {
|
||||
return f32::INFINITY;
|
||||
}
|
||||
|
||||
a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
|
||||
}
|
||||
|
||||
/// Compute dot product
|
||||
#[inline]
|
||||
#[allow(dead_code)]
|
||||
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b).map(|(x, y)| x * y).sum()
|
||||
}
|
||||
|
||||
/// L2 normalize a vector
|
||||
#[must_use]
|
||||
pub fn l2_normalize(v: &[f32]) -> Vec<f32> {
|
||||
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm < 1e-10 {
|
||||
return v.to_vec();
|
||||
}
|
||||
v.iter().map(|x| x / norm).collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_info_nce_loss() {
|
||||
let anchor = vec![1.0, 0.0, 0.0];
|
||||
let positive = vec![0.95, 0.05, 0.0];
|
||||
let negative1 = vec![0.0, 1.0, 0.0];
|
||||
let negative2 = vec![0.0, 0.0, 1.0];
|
||||
|
||||
let loss = info_nce_loss(&anchor, &positive, &[&negative1, &negative2], 0.1);
|
||||
assert!(loss >= 0.0);
|
||||
|
||||
// Similar positive should have lower loss
|
||||
let similar_positive = vec![0.99, 0.01, 0.0];
|
||||
let lower_loss = info_nce_loss(&anchor, &similar_positive, &[&negative1, &negative2], 0.1);
|
||||
assert!(lower_loss < loss);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_triplet_loss() {
|
||||
let anchor = vec![1.0, 0.0, 0.0];
|
||||
let positive = vec![0.9, 0.1, 0.0];
|
||||
let negative = vec![0.0, 1.0, 0.0];
|
||||
|
||||
let loss = triplet_loss(&anchor, &positive, &negative, 1.0);
|
||||
assert!(loss >= 0.0);
|
||||
|
||||
// When positive is very similar, loss should be lower
|
||||
let close_positive = vec![0.99, 0.01, 0.0];
|
||||
let lower_loss = triplet_loss(&anchor, &close_positive, &negative, 1.0);
|
||||
assert!(lower_loss <= loss);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_margin_ranking_loss() {
|
||||
let anchor = vec![1.0, 0.0, 0.0];
|
||||
let positive = vec![0.9, 0.1, 0.0];
|
||||
let negative = vec![0.0, 1.0, 0.0];
|
||||
|
||||
let loss = margin_ranking_loss(&anchor, &positive, &negative, 0.5);
|
||||
assert!(loss >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_contrastive_loss() {
|
||||
let z_i = vec![1.0, 0.0, 0.0];
|
||||
let z_j = vec![0.95, 0.05, 0.0];
|
||||
let other1 = vec![0.0, 1.0, 0.0];
|
||||
let other2 = vec![0.0, 0.0, 1.0];
|
||||
|
||||
let loss = contrastive_loss(&z_i, &z_j, &[&other1, &other2], 0.1);
|
||||
assert!(loss >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nt_xent_loss() {
|
||||
let embeddings = vec![
|
||||
vec![1.0, 0.0],
|
||||
vec![0.95, 0.05], // positive pair with first
|
||||
vec![0.0, 1.0],
|
||||
vec![0.05, 0.95], // positive pair with third
|
||||
];
|
||||
|
||||
let loss = nt_xent_loss(&embeddings, 0.5);
|
||||
assert!(loss >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supervised_contrastive_loss() {
|
||||
let embeddings = vec![
|
||||
vec![1.0, 0.0],
|
||||
vec![0.9, 0.1], // same class as first
|
||||
vec![0.0, 1.0],
|
||||
vec![0.1, 0.9], // same class as third
|
||||
];
|
||||
let labels = vec![0, 0, 1, 1];
|
||||
|
||||
let loss = supervised_contrastive_loss(&embeddings, &labels, 0.1);
|
||||
assert!(loss >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_center_loss() {
|
||||
let embeddings = vec![
|
||||
vec![1.0, 0.0],
|
||||
vec![0.9, 0.1],
|
||||
vec![0.0, 1.0],
|
||||
];
|
||||
let labels = vec![0, 0, 1];
|
||||
let centers = vec![
|
||||
vec![0.95, 0.05], // center for class 0
|
||||
vec![0.05, 0.95], // center for class 1
|
||||
];
|
||||
|
||||
let loss = center_loss(&embeddings, &labels, ¢ers);
|
||||
assert!(loss >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_focal_loss() {
|
||||
let predictions = vec![0.9, 0.1, 0.8, 0.2];
|
||||
let targets = vec![1, 0, 1, 0];
|
||||
|
||||
let loss = focal_loss(&predictions, &targets, 2.0, 0.25);
|
||||
assert!(loss >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity() {
|
||||
let a = vec![1.0, 0.0, 0.0];
|
||||
let b = vec![1.0, 0.0, 0.0];
|
||||
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
|
||||
|
||||
let c = vec![0.0, 1.0, 0.0];
|
||||
assert!(cosine_similarity(&a, &c).abs() < 1e-6);
|
||||
|
||||
let d = vec![-1.0, 0.0, 0.0];
|
||||
assert!((cosine_similarity(&a, &d) + 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_euclidean_distance() {
|
||||
let a = vec![0.0, 0.0];
|
||||
let b = vec![3.0, 4.0];
|
||||
assert!((euclidean_distance(&a, &b) - 5.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_l2_normalize() {
|
||||
let v = vec![3.0, 4.0];
|
||||
let normalized = l2_normalize(&v);
|
||||
|
||||
let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
assert!((norm - 1.0).abs() < 1e-6);
|
||||
|
||||
assert!((normalized[0] - 0.6).abs() < 1e-6);
|
||||
assert!((normalized[1] - 0.8).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_inputs() {
|
||||
let empty: Vec<f32> = vec![];
|
||||
let valid = vec![1.0, 0.0, 0.0];
|
||||
|
||||
assert_eq!(info_nce_loss(&empty, &valid, &[&valid], 0.1), 0.0);
|
||||
assert_eq!(triplet_loss(&empty, &valid, &valid, 1.0), 0.0);
|
||||
assert_eq!(contrastive_loss(&empty, &valid, &[], 0.1), 0.0);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user