Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
[package]
name = "sevensense-learning"
description = "GNN-based learning and embedding refinement for 7sense bioacoustics platform"
version.workspace = true
edition.workspace = true
rust-version.workspace = true
license.workspace = true
repository.workspace = true
authors.workspace = true
readme = "README.md"
[dependencies]
# Internal crates
sevensense-core = { workspace = true, version = "0.1.0" }
sevensense-vector = { workspace = true, version = "0.1.0" }
# Core dependencies
uuid = { workspace = true }
chrono = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
# Async runtime
tokio = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
# Tracing
tracing = { workspace = true }
# ML and numerical computing
ndarray = { workspace = true }
ndarray-rand = "0.14"
# Graph operations
petgraph = "0.6"
# Parallel processing
rayon = "1.10"
# Random number generation
rand = "0.8"
rand_distr = "0.4"
[dev-dependencies]
proptest = { workspace = true }
criterion = { workspace = true }
test-case = { workspace = true }
tokio = { workspace = true, features = ["test-util", "macros"] }
approx = "0.5"
[features]
default = []
[[bench]]
name = "gnn_benchmark"
harness = false

View File

@@ -0,0 +1,416 @@
# sevensense-learning
[![Crate](https://img.shields.io/badge/crates.io-sevensense--learning-orange.svg)](https://crates.io/crates/sevensense-learning)
[![Docs](https://img.shields.io/badge/docs-sevensense--learning-blue.svg)](https://docs.rs/sevensense-learning)
[![License](https://img.shields.io/badge/license-MIT-blue.svg)](../../LICENSE)
> Graph Neural Network (GNN) learning for bioacoustic pattern discovery.
**sevensense-learning** implements online learning algorithms that discover patterns in bird vocalizations over time. Using Graph Neural Networks with Elastic Weight Consolidation (EWC), it learns species-specific call patterns, dialect variations, and behavioral signatures without forgetting previously learned knowledge.
## Features
- **GNN Architecture**: Graph-based learning on similarity networks
- **EWC Regularization**: Prevents catastrophic forgetting in online learning
- **Online Updates**: Continuous learning from streaming data
- **Transition Graphs**: Model sequential call patterns
- **Fisher Information**: Importance-weighted parameter updates
- **Gradient Checkpointing**: Memory-efficient training
## Use Cases
| Use Case | Description | Key Functions |
|----------|-------------|---------------|
| Pattern Learning | Learn call patterns | `train()`, `learn_patterns()` |
| Online Updates | Incremental learning | `online_update()` |
| Transition Modeling | Sequential patterns | `TransitionGraph::learn()` |
| EWC Training | Continual learning | `ewc_train()` |
| Inference | Pattern prediction | `predict()`, `infer()` |
## Installation
Add to your `Cargo.toml`:
```toml
[dependencies]
sevensense-learning = "0.1"
```
## Quick Start
```rust
use sevensense_learning::{GnnModel, GnnConfig, TransitionGraph};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create GNN model
let config = GnnConfig {
hidden_dim: 256,
num_layers: 3,
dropout: 0.1,
..Default::default()
};
let mut model = GnnModel::new(config);
// Build transition graph from embeddings
let graph = TransitionGraph::from_embeddings(&embeddings, 0.8)?;
// Train the model
model.train(&graph, &embeddings, 100)?; // 100 epochs
// Make predictions
let prediction = model.predict(&query_embedding)?;
println!("Predicted pattern: {:?}", prediction);
Ok(())
}
```
---
<details>
<summary><b>Tutorial: Building Transition Graphs</b></summary>
### Creating from Embeddings
```rust
use sevensense_learning::{TransitionGraph, GraphConfig};
// Config for graph construction
let config = GraphConfig {
similarity_threshold: 0.8, // Edge threshold
max_neighbors: 10, // Max edges per node
temporal_decay: 0.95, // Time-based edge weighting
};
// Build graph from embeddings with timestamps
let mut graph = TransitionGraph::new(config);
for (id, embedding, timestamp) in recordings.iter() {
graph.add_node(*id, embedding, *timestamp)?;
}
// Automatically computes edges based on similarity
graph.build_edges()?;
println!("Graph has {} nodes and {} edges",
graph.node_count(),
graph.edge_count());
```
### Analyzing Graph Structure
```rust
use sevensense_learning::TransitionGraph;
// Get neighbors for a node
let neighbors = graph.neighbors(node_id)?;
for (neighbor_id, weight) in neighbors {
println!("Neighbor {}: weight {:.3}", neighbor_id, weight);
}
// Compute graph statistics
let stats = graph.statistics();
println!("Average degree: {:.2}", stats.avg_degree);
println!("Clustering coefficient: {:.3}", stats.clustering_coeff);
println!("Connected components: {}", stats.num_components);
```
### Sequential Patterns
```rust
// Analyze sequential call patterns
let sequences = graph.find_sequences(3)?; // min_length = 3
for seq in sequences {
println!("Sequence: {:?}", seq.node_ids);
println!(" Frequency: {}", seq.count);
println!(" Avg interval: {:.2}s", seq.avg_interval);
}
```
</details>
<details>
<summary><b>Tutorial: GNN Training</b></summary>
### Basic Training
```rust
use sevensense_learning::{GnnModel, GnnConfig, TrainingConfig};
let model_config = GnnConfig {
input_dim: 1536, // Embedding dimension
hidden_dim: 256, // Hidden layer size
output_dim: 64, // Output embedding size
num_layers: 3, // GNN layers
dropout: 0.1,
};
let mut model = GnnModel::new(model_config);
let train_config = TrainingConfig {
epochs: 100,
learning_rate: 0.001,
batch_size: 32,
early_stopping: Some(10), // Stop if no improvement for 10 epochs
};
// Train on graph
let history = model.train(&graph, &features, train_config)?;
println!("Final loss: {:.4}", history.final_loss);
println!("Best epoch: {}", history.best_epoch);
```
### Training with Validation
```rust
let (train_graph, val_graph) = split_graph(&graph, 0.8)?;
let history = model.train_with_validation(
&train_graph,
&val_graph,
&features,
train_config,
)?;
// Plot training curves
for (epoch, train_loss, val_loss) in history.iter() {
println!("Epoch {}: train={:.4}, val={:.4}", epoch, train_loss, val_loss);
}
```
### Custom Loss Functions
```rust
use sevensense_learning::{GnnModel, LossFunction};
// Contrastive loss for similarity learning
let loss_fn = LossFunction::Contrastive {
margin: 0.5,
positive_weight: 1.0,
negative_weight: 0.5,
};
model.set_loss_function(loss_fn);
model.train(&graph, &features, config)?;
```
</details>
<details>
<summary><b>Tutorial: Elastic Weight Consolidation (EWC)</b></summary>
### Why EWC?
Standard neural networks suffer from "catastrophic forgetting"—learning new patterns erases old ones. EWC prevents this by protecting important parameters.
### EWC Training
```rust
use sevensense_learning::{GnnModel, EwcConfig};
let ewc_config = EwcConfig {
lambda: 1000.0, // Regularization strength
fisher_samples: 200, // Samples for Fisher estimation
online: true, // Online EWC variant
};
let mut model = GnnModel::new(model_config);
// Train on first dataset
model.train(&graph1, &features1, train_config)?;
// Compute Fisher information (importance weights)
model.compute_fisher(&graph1, &features1, ewc_config.fisher_samples)?;
// Train on second dataset with EWC
model.ewc_train(&graph2, &features2, train_config, ewc_config)?;
// Model remembers patterns from both datasets!
```
### Continual Learning Pipeline
```rust
use sevensense_learning::{ContinualLearner, EwcConfig};
let mut learner = ContinualLearner::new(model, EwcConfig::default());
// Learn from streaming data batches
for batch in data_stream {
let graph = TransitionGraph::from_batch(&batch)?;
learner.learn(&graph, &batch.features)?;
println!("Learned batch {}, total patterns: {}",
batch.id, learner.pattern_count());
}
// Test on all historical patterns
let recall = learner.evaluate_recall(&all_test_data)?;
println!("Recall on all patterns: {:.2}%", recall * 100.0);
```
</details>
<details>
<summary><b>Tutorial: Online Learning</b></summary>
### Incremental Updates
```rust
use sevensense_learning::{GnnModel, OnlineConfig};
let online_config = OnlineConfig {
learning_rate: 0.0001, // Lower LR for stability
momentum: 0.9,
max_updates_per_sample: 5,
replay_buffer_size: 1000,
};
let mut model = GnnModel::new(model_config);
model.enable_online_learning(online_config);
// Process streaming data
for sample in stream {
// Single-sample update
model.online_update(&sample.embedding, &sample.label)?;
if model.updates_count() % 100 == 0 {
println!("Processed {} samples", model.updates_count());
}
}
```
### Experience Replay
```rust
use sevensense_learning::{ReplayBuffer, GnnModel};
let mut buffer = ReplayBuffer::new(1000); // Store 1000 samples
let mut model = GnnModel::new(config);
for sample in stream {
// Add to replay buffer
buffer.add(sample.clone());
// Train on current sample + replay
let replay_batch = buffer.sample(32)?; // 32 random historical samples
let batch = [vec![sample], replay_batch].concat();
model.train_batch(&batch)?;
}
```
</details>
<details>
<summary><b>Tutorial: Pattern Prediction</b></summary>
### Predicting Similar Patterns
```rust
use sevensense_learning::GnnModel;
let model = GnnModel::load("trained_model.bin")?;
// Get learned representation
let embedding = model.encode(&query_features)?;
// Find similar learned patterns
let similar = model.find_similar(&embedding, 10)?;
for (pattern_id, similarity) in similar {
println!("Pattern {}: {:.3} similarity", pattern_id, similarity);
}
```
### Predicting Next Call
```rust
// Given a sequence of calls, predict the next one
let sequence = vec![embedding1, embedding2, embedding3];
let prediction = model.predict_next(&sequence)?;
println!("Predicted next call embedding: {:?}", prediction.embedding);
println!("Confidence: {:.3}", prediction.confidence);
```
### Anomaly Detection
```rust
use sevensense_learning::{GnnModel, AnomalyDetector};
let detector = AnomalyDetector::new(&model);
for embedding in embeddings {
let score = detector.anomaly_score(&embedding)?;
if score > 0.95 {
println!("Anomaly detected! Score: {:.3}", score);
}
}
```
</details>
---
## Configuration
### GnnConfig Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `input_dim` | 1536 | Input embedding dimension |
| `hidden_dim` | 256 | Hidden layer dimension |
| `output_dim` | 64 | Output dimension |
| `num_layers` | 3 | Number of GNN layers |
| `dropout` | 0.1 | Dropout rate |
### EwcConfig Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `lambda` | 1000.0 | Regularization strength |
| `fisher_samples` | 200 | Samples for Fisher estimation |
| `online` | true | Use online EWC variant |
## Architecture
```
Input Embeddings (1536-dim)
┌─────────┐
│ GNN │ ◄── Graph structure (adjacency)
│ Layer 1 │
└────┬────┘
┌────▼────┐
│ GNN │
│ Layer 2 │
└────┬────┘
┌────▼────┐
│ GNN │
│ Layer 3 │
└────┬────┘
Output Embeddings (64-dim)
```
## Links
- **Homepage**: [ruv.io](https://ruv.io)
- **Repository**: [github.com/ruvnet/ruvector](https://github.com/ruvnet/ruvector)
- **Crates.io**: [crates.io/crates/sevensense-learning](https://crates.io/crates/sevensense-learning)
- **Documentation**: [docs.rs/sevensense-learning](https://docs.rs/sevensense-learning)
## License
MIT License - see [LICENSE](../../LICENSE) for details.
---
*Part of the [7sense Bioacoustic Intelligence Platform](https://ruv.io) by rUv*

View File

@@ -0,0 +1,131 @@
//! Benchmarks for GNN operations.
use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId};
use ndarray::Array2;
// Note: These benchmarks require the crate to compile successfully.
// They test the core GNN operations for performance regression.
/// Build an `n x dim` feature matrix with every entry set to 0.5.
fn create_test_features(n: usize, dim: usize) -> Array2<f32> {
    Array2::from_shape_fn((n, dim), |_| 0.5)
}
/// Build an `n x n` path-graph adjacency matrix: 1.0 self-loops plus
/// symmetric 0.5 edges between consecutive nodes (deterministic, not random).
fn create_test_adjacency(n: usize) -> Array2<f32> {
    Array2::from_shape_fn((n, n), |(i, j)| {
        if i == j {
            1.0
        } else if i + 1 == j || j + 1 == i {
            0.5
        } else {
            0.0
        }
    })
}
/// Benchmark one GCN message-passing step (neighbourhood aggregation A . X)
/// across several graph sizes.
fn benchmark_gcn_forward(c: &mut Criterion) {
    let mut group = c.benchmark_group("gcn_forward");
    for &n in &[10usize, 50, 100, 500] {
        group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &size| {
            let feats = create_test_features(size, 64);
            let adj = create_test_adjacency(size);
            b.iter(|| {
                // A . X stands in for a GCN forward pass; black_box keeps
                // the optimizer from const-folding the work away.
                black_box(black_box(&adj).dot(black_box(&feats)))
            });
        });
    }
    group.finish();
}
/// Benchmark dot-product attention: Q . K^T followed by a numerically
/// stable row-wise softmax, for several node counts.
fn benchmark_attention_computation(c: &mut Criterion) {
    let mut group = c.benchmark_group("attention");
    for n in [10, 50, 100].iter() {
        group.bench_with_input(BenchmarkId::from_parameter(n), n, |b, &n| {
            let query = create_test_features(n, 64);
            let key = create_test_features(n, 64);
            b.iter(|| {
                // Compute attention scores
                let scores = black_box(&query).dot(&black_box(&key).t());
                // Softmax (simplified): subtract the per-row max before
                // exponentiating so exp() cannot overflow.
                let mut result = scores.clone();
                for mut row in result.rows_mut() {
                    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let sum: f32 = row.iter().map(|x| (x - max).exp()).sum();
                    for x in row.iter_mut() {
                        *x = (*x - max).exp() / sum;
                    }
                }
                black_box(result)
            });
        });
    }
    group.finish();
}
/// Benchmark a scalar cosine-similarity computation over two 256-d vectors.
fn benchmark_cosine_similarity(c: &mut Criterion) {
    c.bench_function("cosine_similarity_256d", |bencher| {
        let lhs: Vec<f32> = (0..256).map(|i| (i as f32).sin()).collect();
        let rhs: Vec<f32> = (0..256).map(|i| (i as f32).cos()).collect();
        bencher.iter(|| {
            // dot(a, b) / (|a| * |b|), all in one pass per reduction.
            let dot: f32 = black_box(&lhs).iter().zip(black_box(&rhs)).map(|(x, y)| x * y).sum();
            let lhs_norm: f32 = lhs.iter().map(|x| x * x).sum::<f32>().sqrt();
            let rhs_norm: f32 = rhs.iter().map(|x| x * x).sum::<f32>().sqrt();
            black_box(dot / (lhs_norm * rhs_norm))
        });
    });
}
/// Benchmark an InfoNCE contrastive loss (1 positive, 10 negatives) over
/// 128-d vectors, including the log-sum-exp normalisation term.
fn benchmark_info_nce_loss(c: &mut Criterion) {
    c.bench_function("info_nce_10_negatives", |b| {
        let anchor: Vec<f32> = (0..128).map(|i| (i as f32 * 0.01).sin()).collect();
        // Positive is a shifted copy of the anchor; negatives are phase-offset
        // cosines, so all inputs are deterministic.
        let positive: Vec<f32> = (0..128).map(|i| (i as f32 * 0.01).sin() + 0.1).collect();
        let negatives: Vec<Vec<f32>> = (0..10)
            .map(|j| (0..128).map(|i| ((i + j * 10) as f32 * 0.01).cos()).collect())
            .collect();
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();
        b.iter(|| {
            // Temperature that scales the similarity logits.
            let temp = 0.07;
            // Compute cosine similarities
            let pos_sim = {
                let dot: f32 = black_box(&anchor).iter().zip(black_box(&positive)).map(|(x, y)| x * y).sum();
                let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_b: f32 = positive.iter().map(|x| x * x).sum::<f32>().sqrt();
                dot / (norm_a * norm_b) / temp
            };
            let neg_sims: Vec<f32> = neg_refs.iter().map(|neg| {
                let dot: f32 = anchor.iter().zip(*neg).map(|(x, y)| x * y).sum();
                let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_b: f32 = neg.iter().map(|x| x * x).sum::<f32>().sqrt();
                dot / (norm_a * norm_b) / temp
            }).collect();
            // Log-sum-exp: subtract the maximum similarity before
            // exponentiating so the sum cannot overflow.
            let max_sim = neg_sims.iter().chain(std::iter::once(&pos_sim))
                .cloned().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = std::iter::once(pos_sim).chain(neg_sims)
                .map(|s| (s - max_sim).exp()).sum();
            // loss = -(pos_sim - logsumexp(pos_sim, neg_sims...))
            black_box(-pos_sim + max_sim + sum_exp.ln())
        });
    });
}
// Register every benchmark function in one Criterion group and generate the
// binary's `main` entry point (Cargo.toml sets `harness = false` for this
// bench target, so criterion supplies the harness).
criterion_group!(
    benches,
    benchmark_gcn_forward,
    benchmark_attention_computation,
    benchmark_cosine_similarity,
    benchmark_info_nce_loss,
);
criterion_main!(benches);

View File

@@ -0,0 +1,6 @@
//! Application layer for the learning bounded context.
//!
//! Contains business logic, services, and use cases for
//! GNN-based learning and embedding refinement.
pub mod services;

View File

@@ -0,0 +1,829 @@
//! Learning service implementation.
//!
//! Provides the main application service for GNN-based learning,
//! including training, embedding refinement, and edge prediction.
use std::sync::Arc;
use std::time::Instant;
use ndarray::Array2;
use rayon::prelude::*;
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, warn};
use crate::domain::entities::{
EmbeddingId, GnnModelType, LearningConfig, LearningSession, RefinedEmbedding,
TrainingMetrics, TrainingStatus, TransitionGraph,
};
use crate::domain::repository::LearningRepository;
use crate::ewc::{EwcRegularizer, EwcState};
use crate::infrastructure::gnn_model::{GnnError, GnnModel};
use crate::loss;
/// Error type for learning service operations
#[derive(Debug, thiserror::Error)]
pub enum LearningError {
    /// Invalid configuration
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// Training error
    #[error("Training error: {0}")]
    TrainingError(String),
    /// Model error
    #[error("Model error: {0}")]
    ModelError(String),
    /// Data error
    #[error("Data error: {0}")]
    DataError(String),
    /// Repository error (auto-converted via `#[from]`, so `?` works on
    /// repository calls)
    #[error("Repository error: {0}")]
    RepositoryError(#[from] crate::domain::repository::RepositoryError),
    /// GNN model error (auto-converted via `#[from]`, so `?` works on
    /// model calls)
    #[error("GNN error: {0}")]
    GnnError(#[from] GnnError),
    /// Session not found
    #[error("Session not found: {0}")]
    SessionNotFound(String),
    /// Session already running
    #[error("A training session is already running")]
    SessionAlreadyRunning,
    /// Dimension mismatch between the configured input size and actual data
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },
    /// Empty graph
    #[error("Graph is empty or invalid")]
    EmptyGraph,
    /// Internal error
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Result type for learning operations
pub type LearningResult<T> = Result<T, LearningError>;
/// Main learning service for GNN-based embedding refinement.
///
/// This service manages:
/// - GNN model training on transition graphs
/// - Embedding refinement through message passing
/// - Edge prediction for relationship modeling
/// - Continual learning with EWC regularization
///
/// Mutable state is held behind `Arc<tokio::sync::RwLock<...>>` so concurrent
/// async tasks can share the service while writes stay serialised.
pub struct LearningService {
    /// The GNN model
    model: Arc<RwLock<GnnModel>>,
    /// Service configuration
    config: LearningConfig,
    /// EWC state for continual learning; `None` until `consolidate_ewc` runs
    ewc_state: Arc<RwLock<Option<EwcState>>>,
    /// Optional repository for persistence
    repository: Option<Arc<dyn LearningRepository>>,
    /// Current active session, if any
    current_session: Arc<RwLock<Option<LearningSession>>>,
}
impl LearningService {
/// Create a new learning service with the given configuration
#[must_use]
pub fn new(config: LearningConfig) -> Self {
let model = GnnModel::new(
config.model_type,
config.input_dim,
config.output_dim,
config.hyperparameters.num_layers,
config.hyperparameters.hidden_dim,
config.hyperparameters.num_heads,
config.hyperparameters.dropout,
);
Self {
model: Arc::new(RwLock::new(model)),
config,
ewc_state: Arc::new(RwLock::new(None)),
repository: None,
current_session: Arc::new(RwLock::new(None)),
}
}
/// Create a learning service with a repository
#[must_use]
pub fn with_repository(mut self, repository: Arc<dyn LearningRepository>) -> Self {
self.repository = Some(repository);
self
}
/// Get the current configuration
#[must_use]
pub fn config(&self) -> &LearningConfig {
&self.config
}
/// Get the model type
#[must_use]
pub fn model_type(&self) -> GnnModelType {
self.config.model_type
}
/// Start a new training session
#[instrument(skip(self), err)]
pub async fn start_session(&self) -> LearningResult<String> {
// Check if a session is already running
{
let session = self.current_session.read().await;
if let Some(ref s) = *session {
if s.status.is_active() {
return Err(LearningError::SessionAlreadyRunning);
}
}
}
let mut session = LearningSession::new(self.config.clone());
session.start();
let session_id = session.id.clone();
// Persist if repository available
if let Some(ref repo) = self.repository {
repo.save_session(&session).await?;
}
*self.current_session.write().await = Some(session);
info!(session_id = %session_id, "Started new learning session");
Ok(session_id)
}
/// Train a single epoch on the transition graph
///
/// # Arguments
/// * `graph` - The transition graph to train on
///
/// # Returns
/// Training metrics for the epoch
#[instrument(skip(self, graph), fields(nodes = graph.num_nodes(), edges = graph.num_edges()), err)]
pub async fn train_epoch(&self, graph: &TransitionGraph) -> LearningResult<TrainingMetrics> {
let start_time = Instant::now();
// Validate graph
if graph.num_nodes() == 0 {
return Err(LearningError::EmptyGraph);
}
if let Some(dim) = graph.embedding_dim() {
if dim != self.config.input_dim {
return Err(LearningError::DimensionMismatch {
expected: self.config.input_dim,
actual: dim,
});
}
}
// Ensure we have an active session
let mut session_guard = self.current_session.write().await;
let session = session_guard
.as_mut()
.ok_or_else(|| LearningError::TrainingError("No active session".to_string()))?;
let current_epoch = session.metrics.epoch + 1;
let lr = self.compute_learning_rate(current_epoch);
// Build adjacency matrix
let adj_matrix = self.build_adjacency_matrix(graph);
// Build feature matrix from embeddings
let features = self.build_feature_matrix(graph);
// Forward pass through GNN
let mut model = self.model.write().await;
let output = model.forward(&features, &adj_matrix)?;
// Compute loss using contrastive learning
let (loss, accuracy) = self.compute_loss(graph, &output).await?;
// Compute gradients and update weights
let gradients = self.compute_gradients(graph, &features, &output, &adj_matrix, &model)?;
let grad_norm = self.compute_gradient_norm(&gradients);
// Apply gradient clipping if configured
let clipped_gradients = if let Some(clip_value) = self.config.hyperparameters.gradient_clip {
self.clip_gradients(gradients, clip_value)
} else {
gradients
};
// Update model weights
model.update_weights(&clipped_gradients, lr, self.config.hyperparameters.weight_decay);
// Apply EWC regularization if available
if let Some(ref ewc_state) = *self.ewc_state.read().await {
let ewc_reg = EwcRegularizer::new(self.config.hyperparameters.ewc_lambda);
let ewc_loss = ewc_reg.compute_penalty(&model, ewc_state);
debug!(ewc_loss = ewc_loss, "Applied EWC regularization");
}
let epoch_time_ms = start_time.elapsed().as_millis() as u64;
let metrics = TrainingMetrics {
loss,
accuracy,
epoch: current_epoch,
learning_rate: lr,
validation_loss: None,
validation_accuracy: None,
gradient_norm: Some(grad_norm),
epoch_time_ms,
custom_metrics: Default::default(),
};
// Update session metrics
session.update_metrics(metrics.clone());
// Persist session if repository available
drop(model); // Release write lock before async operation
if let Some(ref repo) = self.repository {
repo.update_session(session).await?;
}
info!(
epoch = current_epoch,
loss = loss,
accuracy = accuracy,
time_ms = epoch_time_ms,
"Completed training epoch"
);
Ok(metrics)
}
/// Refine embeddings using the trained GNN model
///
/// # Arguments
/// * `embeddings` - Input embeddings to refine
///
/// # Returns
/// Refined embeddings with quality scores
#[instrument(skip(self, embeddings), fields(count = embeddings.len()), err)]
pub async fn refine_embeddings(
&self,
embeddings: &[(EmbeddingId, Vec<f32>)],
) -> LearningResult<Vec<RefinedEmbedding>> {
if embeddings.is_empty() {
return Ok(Vec::new());
}
// Validate dimensions
if let Some((_, emb)) = embeddings.first() {
if emb.len() != self.config.input_dim {
return Err(LearningError::DimensionMismatch {
expected: self.config.input_dim,
actual: emb.len(),
});
}
}
let model = self.model.read().await;
// Build a simple graph where each embedding is a node
// Connected based on cosine similarity
let n = embeddings.len();
let similarity_threshold = 0.5;
// Build feature matrix
let mut features = Array2::zeros((n, self.config.input_dim));
for (i, (_, emb)) in embeddings.iter().enumerate() {
for (j, &val) in emb.iter().enumerate() {
features[[i, j]] = val;
}
}
// Build adjacency matrix based on similarity
let mut adj_matrix = Array2::<f32>::eye(n);
for i in 0..n {
for j in (i + 1)..n {
let sim = cosine_similarity(&embeddings[i].1, &embeddings[j].1);
if sim > similarity_threshold {
adj_matrix[[i, j]] = sim;
adj_matrix[[j, i]] = sim;
}
}
}
// Normalize adjacency matrix
let degrees: Vec<f32> = (0..n)
.map(|i| adj_matrix.row(i).sum())
.collect();
for i in 0..n {
for j in 0..n {
if degrees[i] > 0.0 && degrees[j] > 0.0 {
adj_matrix[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
}
}
}
// Forward pass
let output = model.forward(&features, &adj_matrix)?;
// Create refined embeddings
let session_id = self
.current_session
.read()
.await
.as_ref()
.map(|s| s.id.clone());
let refined: Vec<RefinedEmbedding> = embeddings
.par_iter()
.enumerate()
.map(|(i, (id, original))| {
let refined_vec: Vec<f32> = output.row(i).to_vec();
// Compute refinement score based on change magnitude
let delta = original
.iter()
.zip(&refined_vec)
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
.sqrt();
let score = 1.0 / (1.0 + delta); // Higher score for smaller changes
let mut refined = RefinedEmbedding::new(id.clone(), refined_vec, score);
refined.session_id = session_id.clone();
refined.delta_norm = Some(delta);
refined.normalize();
refined
})
.collect();
info!(count = refined.len(), "Refined embeddings");
// Persist if repository available
if let Some(ref repo) = self.repository {
repo.save_refined_embeddings(&refined).await?;
}
Ok(refined)
}
/// Predict edge weight between two embeddings
///
/// # Arguments
/// * `from` - Source embedding
/// * `to` - Target embedding
///
/// # Returns
/// Predicted edge weight (0.0 to 1.0)
#[instrument(skip(self, from, to), err)]
pub async fn predict_edge(&self, from: &[f32], to: &[f32]) -> LearningResult<f32> {
// Validate dimensions
if from.len() != self.config.input_dim {
return Err(LearningError::DimensionMismatch {
expected: self.config.input_dim,
actual: from.len(),
});
}
if to.len() != self.config.input_dim {
return Err(LearningError::DimensionMismatch {
expected: self.config.input_dim,
actual: to.len(),
});
}
let model = self.model.read().await;
// Create a mini-graph with two nodes
let mut features = Array2::zeros((2, self.config.input_dim));
for (j, &val) in from.iter().enumerate() {
features[[0, j]] = val;
}
for (j, &val) in to.iter().enumerate() {
features[[1, j]] = val;
}
// Simple adjacency (self-loops only initially)
let adj_matrix = Array2::<f32>::eye(2);
// Forward pass
let output = model.forward(&features, &adj_matrix)?;
// Compute similarity of refined embeddings
let from_refined: Vec<f32> = output.row(0).to_vec();
let to_refined: Vec<f32> = output.row(1).to_vec();
let similarity = cosine_similarity(&from_refined, &to_refined);
let weight = (similarity + 1.0) / 2.0; // Map from [-1, 1] to [0, 1]
Ok(weight)
}
/// Complete the current training session
#[instrument(skip(self), err)]
pub async fn complete_session(&self) -> LearningResult<()> {
let mut session_guard = self.current_session.write().await;
if let Some(ref mut session) = *session_guard {
session.complete();
// Compute and store Fisher information for EWC
// This would be done in a real implementation with the final model state
if let Some(ref repo) = self.repository {
repo.update_session(session).await?;
}
info!(session_id = %session.id, "Completed learning session");
}
Ok(())
}
/// Fail the current session with an error
#[instrument(skip(self, error), err)]
pub async fn fail_session(&self, error: impl Into<String>) -> LearningResult<()> {
let error_msg = error.into();
let mut session_guard = self.current_session.write().await;
if let Some(ref mut session) = *session_guard {
session.fail(&error_msg);
if let Some(ref repo) = self.repository {
repo.update_session(session).await?;
}
warn!(session_id = %session.id, error = %error_msg, "Failed learning session");
}
Ok(())
}
/// Get the current session status
pub async fn get_session(&self) -> Option<LearningSession> {
self.current_session.read().await.clone()
}
/// Save EWC state from current model for future regularization
#[instrument(skip(self, graph), err)]
pub async fn consolidate_ewc(&self, graph: &TransitionGraph) -> LearningResult<()> {
let model = self.model.read().await;
let fisher = self.compute_fisher_information(&model, graph)?;
let state = EwcState::new(model.get_parameters(), fisher);
*self.ewc_state.write().await = Some(state);
info!("Consolidated EWC state");
Ok(())
}
// =========== Private Helper Methods ===========
fn build_adjacency_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
let n = graph.num_nodes();
let mut adj = Array2::zeros((n, n));
// Add self-loops
for i in 0..n {
adj[[i, i]] = 1.0;
}
// Add edges
for &(from, to, weight) in &graph.edges {
adj[[from, to]] = weight;
if !graph.directed {
adj[[to, from]] = weight;
}
}
// Symmetric normalization: D^(-1/2) * A * D^(-1/2)
let degrees: Vec<f32> = (0..n).map(|i| adj.row(i).sum()).collect();
for i in 0..n {
for j in 0..n {
if degrees[i] > 0.0 && degrees[j] > 0.0 {
adj[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
}
}
}
adj
}
fn build_feature_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
let n = graph.num_nodes();
let dim = graph.embedding_dim().unwrap_or(self.config.input_dim);
let mut features = Array2::zeros((n, dim));
for (i, emb) in graph.embeddings.iter().enumerate() {
for (j, &val) in emb.iter().enumerate() {
features[[i, j]] = val;
}
}
features
}
async fn compute_loss(
&self,
graph: &TransitionGraph,
output: &Array2<f32>,
) -> LearningResult<(f32, f32)> {
let n = graph.num_nodes();
if n == 0 {
return Ok((0.0, 0.0));
}
let mut total_loss = 0.0;
let mut correct = 0usize;
let mut total = 0usize;
let hp = &self.config.hyperparameters;
// For each edge, compute contrastive loss
for &(from, to, weight) in &graph.edges {
let anchor: Vec<f32> = output.row(from).to_vec();
let positive: Vec<f32> = output.row(to).to_vec();
// Sample negative nodes (nodes not connected to anchor)
let negatives: Vec<Vec<f32>> = (0..n)
.filter(|&i| i != from && i != to)
.take(hp.negative_ratio)
.map(|i| output.row(i).to_vec())
.collect();
if !negatives.is_empty() {
let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();
// InfoNCE loss
let loss = loss::info_nce_loss(&anchor, &positive, &neg_refs, hp.temperature);
total_loss += loss * weight;
}
// Compute accuracy based on whether positive is closer than negatives
let pos_sim = cosine_similarity(&anchor, &positive);
let all_closer = (0..n)
.filter(|&i| i != from && i != to)
.all(|i| {
let neg: Vec<f32> = output.row(i).to_vec();
cosine_similarity(&anchor, &neg) < pos_sim
});
if all_closer {
correct += 1;
}
total += 1;
}
let avg_loss = if graph.edges.is_empty() {
0.0
} else {
total_loss / graph.edges.len() as f32
};
let accuracy = if total == 0 {
0.0
} else {
correct as f32 / total as f32
};
Ok((avg_loss, accuracy))
}
    /// Approximate per-layer weight gradients for `model`.
    ///
    /// NOTE(review): this is an explicit placeholder, not real backpropagation.
    /// Only the first layer's gradient is data-derived (a centered
    /// feature/output product); deeper layers receive small constant values
    /// scaled by output variance. Replace with automatic differentiation for
    /// faithful training.
    fn compute_gradients(
        &self,
        _graph: &TransitionGraph,
        features: &Array2<f32>,
        output: &Array2<f32>,
        _adj_matrix: &Array2<f32>,
        model: &GnnModel,
    ) -> LearningResult<Vec<Array2<f32>>> {
        // Simplified gradient computation
        // In practice, this would use automatic differentiation
        let num_layers = model.num_layers();
        let mut gradients = Vec::with_capacity(num_layers);
        let batch_size = features.nrows() as f32;
        for layer_idx in 0..num_layers {
            let (in_dim, out_dim) = model.layer_dims(layer_idx);
            // Compute gradient approximation based on output variance
            // This is a simplified placeholder - real backprop would use chain rule
            let output_centered = &output.mapv(|x| x - output.mean().unwrap_or(0.0));
            // Approximate gradient as outer product scaled by learning signal
            let grad = if layer_idx == 0 {
                // Input layer: gradient is features^T * output_signal / batch_size.
                // The slices guard against a mismatch between the model's declared
                // layer dims and the actual matrix widths; on mismatch a zero
                // gradient is used rather than panicking.
                let output_slice = if output.ncols() >= out_dim {
                    output_centered.slice(ndarray::s![.., ..out_dim]).to_owned()
                } else {
                    Array2::zeros((output.nrows(), out_dim))
                };
                let feat_slice = if features.ncols() >= in_dim {
                    features.slice(ndarray::s![.., ..in_dim]).to_owned()
                } else {
                    Array2::zeros((features.nrows(), in_dim))
                };
                feat_slice.t().dot(&output_slice) / batch_size
            } else {
                // Hidden layers: constant small gradients scaled by the output's
                // standard deviation (from_elem fills a uniform value — these are
                // NOT random despite the original comment).
                let variance = output.var(0.0);
                Array2::from_elem((in_dim, out_dim), 0.01 * variance.sqrt())
            };
            // Transpose to the model's (out_dim, in_dim) weight layout
            let scaled_grad = grad.t().to_owned();
            gradients.push(scaled_grad);
        }
        Ok(gradients)
    }
fn compute_gradient_norm(&self, gradients: &[Array2<f32>]) -> f32 {
gradients
.iter()
.map(|g| g.iter().map(|&x| x * x).sum::<f32>())
.sum::<f32>()
.sqrt()
}
fn clip_gradients(&self, gradients: Vec<Array2<f32>>, max_norm: f32) -> Vec<Array2<f32>> {
let current_norm = self.compute_gradient_norm(&gradients);
if current_norm <= max_norm {
return gradients;
}
let scale = max_norm / current_norm;
gradients.into_iter().map(|g| g * scale).collect()
}
fn compute_learning_rate(&self, epoch: usize) -> f32 {
let base_lr = self.config.hyperparameters.learning_rate;
let total_epochs = self.config.hyperparameters.epochs;
// Cosine annealing schedule
let progress = epoch as f32 / total_epochs as f32;
let cosine_factor = (1.0 + (progress * std::f32::consts::PI).cos()) / 2.0;
base_lr * cosine_factor
}
    /// Estimate parameter importance for EWC regularization.
    ///
    /// NOTE(review): placeholder — returns `FisherInformation::default()`
    /// instead of computing the diagonal Fisher matrix from model gradients,
    /// so EWC currently carries no real per-parameter importance.
    fn compute_fisher_information(
        &self,
        _model: &GnnModel,
        _graph: &TransitionGraph,
    ) -> LearningResult<crate::ewc::FisherInformation> {
        // Simplified Fisher information computation
        // In practice, this would compute the diagonal of the Fisher matrix
        Ok(crate::ewc::FisherInformation::default())
    }
}
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 for mismatched lengths, empty input, or vectors whose norm is
/// below 1e-10; otherwise a value in [-1, 1].
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass: accumulate the dot product and both squared norms together.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the learning service: similarity math, session
    //! lifecycle, training, refinement, and input-validation errors.
    use super::*;
    #[test]
    fn test_cosine_similarity() {
        // Identical, orthogonal, and opposite vectors hit 1, 0, and -1.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c)).abs() < 1e-6);
        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) + 1.0).abs() < 1e-6);
    }
    #[tokio::test]
    async fn test_learning_service_creation() {
        let config = LearningConfig::default();
        let service = LearningService::new(config.clone());
        // Defaults: GCN model with 768-dim inputs (see LearningConfig::default).
        assert_eq!(service.model_type(), GnnModelType::Gcn);
        assert_eq!(service.config().input_dim, 768);
    }
    #[tokio::test]
    async fn test_start_session() {
        let config = LearningConfig::default();
        let service = LearningService::new(config);
        let session_id = service.start_session().await.unwrap();
        assert!(!session_id.is_empty());
        let session = service.get_session().await.unwrap();
        assert_eq!(session.status, TrainingStatus::Running);
    }
    #[tokio::test]
    async fn test_train_epoch() {
        // Shrink dimensions so the epoch trains quickly on a 3-node chain.
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;
        let service = LearningService::new(config);
        service.start_session().await.unwrap();
        let mut graph = TransitionGraph::new();
        graph.add_node(EmbeddingId::new("n1"), vec![0.1; 8], None);
        graph.add_node(EmbeddingId::new("n2"), vec![0.2; 8], None);
        graph.add_node(EmbeddingId::new("n3"), vec![0.3; 8], None);
        graph.add_edge(0, 1, 0.8);
        graph.add_edge(1, 2, 0.7);
        let metrics = service.train_epoch(&graph).await.unwrap();
        assert_eq!(metrics.epoch, 1);
        assert!(metrics.loss >= 0.0);
    }
    #[tokio::test]
    async fn test_refine_embeddings() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;
        let service = LearningService::new(config);
        service.start_session().await.unwrap();
        let embeddings = vec![
            (EmbeddingId::new("e1"), vec![0.1; 8]),
            (EmbeddingId::new("e2"), vec![0.2; 8]),
        ];
        let refined = service.refine_embeddings(&embeddings).await.unwrap();
        assert_eq!(refined.len(), 2);
        assert_eq!(refined[0].dim(), 4); // Output dimension
    }
    #[tokio::test]
    async fn test_predict_edge() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;
        let service = LearningService::new(config);
        let from = vec![0.1; 8];
        let to = vec![0.1; 8]; // Same embedding should have high weight
        let weight = service.predict_edge(&from, &to).await.unwrap();
        // Predicted weight must be a valid probability-like score.
        assert!(weight >= 0.0 && weight <= 1.0);
    }
    #[tokio::test]
    async fn test_empty_graph_error() {
        // Training on a graph with no nodes must fail explicitly.
        let config = LearningConfig::default();
        let service = LearningService::new(config);
        service.start_session().await.unwrap();
        let graph = TransitionGraph::new();
        let result = service.train_epoch(&graph).await;
        assert!(matches!(result, Err(LearningError::EmptyGraph)));
    }
    #[tokio::test]
    async fn test_dimension_mismatch() {
        // A node embedding narrower than config.input_dim must be rejected.
        let mut config = LearningConfig::default();
        config.input_dim = 768;
        let service = LearningService::new(config);
        service.start_session().await.unwrap();
        let mut graph = TransitionGraph::new();
        graph.add_node(EmbeddingId::new("n1"), vec![0.1; 128], None); // Wrong dimension
        let result = service.train_epoch(&graph).await;
        assert!(matches!(
            result,
            Err(LearningError::DimensionMismatch { .. })
        ));
    }
}

View File

@@ -0,0 +1,836 @@
//! Domain entities for the learning bounded context.
//!
//! This module defines the core domain entities including:
//! - Learning sessions for tracking training state
//! - GNN model types and training metrics
//! - Transition graphs for embedding relationships
//! - Refined embeddings as output of the learning process
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Opaque string identifier for an embedding.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(pub String);

impl EmbeddingId {
    /// Wrap an arbitrary string as an embedding ID.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        EmbeddingId(id.into())
    }

    /// Mint a fresh random (UUID v4) embedding ID.
    #[must_use]
    pub fn generate() -> Self {
        EmbeddingId(Uuid::new_v4().to_string())
    }

    /// Borrow the underlying string value.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for EmbeddingId {
    fn from(value: String) -> Self {
        EmbeddingId(value)
    }
}

impl From<&str> for EmbeddingId {
    fn from(value: &str) -> Self {
        EmbeddingId(value.to_owned())
    }
}

/// Timestamp type alias used consistently across the learning domain.
pub type Timestamp = DateTime<Utc>;
/// GNN architectures available to the learning system.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GnnModelType {
    /// Graph Convolutional Network
    /// Uses spectral convolutions on graph-structured data
    Gcn,
    /// GraphSAGE (SAmple and aggreGatE)
    /// Learns node embeddings through neighborhood sampling and aggregation
    GraphSage,
    /// Graph Attention Network
    /// Uses attention mechanisms to weight neighbor contributions
    Gat,
}

impl Default for GnnModelType {
    fn default() -> Self {
        GnnModelType::Gcn
    }
}

impl GnnModelType {
    /// Approximate number of learnable parameters in one layer of this
    /// architecture, given its input and output widths.
    #[must_use]
    pub fn params_per_layer(&self, input_dim: usize, output_dim: usize) -> usize {
        let weights = input_dim * output_dim;
        match self {
            Self::Gcn => weights + output_dim,
            Self::GraphSage => 2 * weights + output_dim,
            Self::Gat => weights + 2 * output_dim,
        }
    }

    /// Suggested attention-head count; only GAT uses more than one head.
    #[must_use]
    pub fn recommended_heads(&self) -> usize {
        if *self == Self::Gat {
            8
        } else {
            1
        }
    }
}
/// Status of a training session
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TrainingStatus {
    /// Session created but training not started
    Pending,
    /// Training is currently running
    Running,
    /// Training completed successfully
    Completed,
    /// Training failed with an error
    Failed,
    /// Training was paused
    Paused,
    /// Training was cancelled by user
    Cancelled,
}

impl TrainingStatus {
    /// True for states the session can never leave.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        // Exhaustive match so adding a variant forces a decision here.
        match self {
            Self::Completed | Self::Failed | Self::Cancelled => true,
            Self::Pending | Self::Running | Self::Paused => false,
        }
    }

    /// True when training can be resumed from this state (only `Paused`).
    #[must_use]
    pub fn can_resume(&self) -> bool {
        *self == Self::Paused
    }

    /// True while training is actively running.
    #[must_use]
    pub fn is_active(&self) -> bool {
        *self == Self::Running
    }
}
/// Metrics collected during training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Current loss value
    pub loss: f32,
    /// Training accuracy (0.0 to 1.0)
    pub accuracy: f32,
    /// Current epoch number
    pub epoch: usize,
    /// Current learning rate
    pub learning_rate: f32,
    /// Validation loss (if validation set provided)
    pub validation_loss: Option<f32>,
    /// Validation accuracy
    pub validation_accuracy: Option<f32>,
    /// Gradient norm (for monitoring stability)
    pub gradient_norm: Option<f32>,
    /// Time taken for this epoch in milliseconds
    pub epoch_time_ms: u64,
    /// Additional custom metrics
    /// (absent in serialized input deserializes as an empty map)
    #[serde(default)]
    pub custom_metrics: HashMap<String, f32>,
}
impl Default for TrainingMetrics {
    fn default() -> Self {
        Self {
            // +infinity sentinel: any real first-epoch loss compares as an
            // improvement against a fresh default.
            loss: f32::INFINITY,
            accuracy: 0.0,
            epoch: 0,
            learning_rate: 0.001,
            validation_loss: None,
            validation_accuracy: None,
            gradient_norm: None,
            epoch_time_ms: 0,
            custom_metrics: HashMap::new(),
        }
    }
}
impl TrainingMetrics {
    /// Create new metrics for an epoch; optional fields start unset (see `Default`)
    #[must_use]
    pub fn new(epoch: usize, loss: f32, accuracy: f32, learning_rate: f32) -> Self {
        Self {
            loss,
            accuracy,
            epoch,
            learning_rate,
            ..Default::default()
        }
    }
    /// Set validation metrics (builder style: consumes and returns self)
    #[must_use]
    pub fn with_validation(mut self, loss: f32, accuracy: f32) -> Self {
        self.validation_loss = Some(loss);
        self.validation_accuracy = Some(accuracy);
        self
    }
    /// Add a custom metric (silently overwrites an existing metric of the same name)
    pub fn add_custom_metric(&mut self, name: impl Into<String>, value: f32) {
        self.custom_metrics.insert(name.into(), value);
    }
    /// Check if training is converging (loss is strictly decreasing vs `previous`)
    #[must_use]
    pub fn is_improving(&self, previous: &Self) -> bool {
        self.loss < previous.loss
    }
}
/// Hyperparameters for training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperParameters {
    /// Initial learning rate
    pub learning_rate: f32,
    /// Weight decay (L2 regularization)
    pub weight_decay: f32,
    /// Dropout probability
    pub dropout: f32,
    /// Number of training epochs
    pub epochs: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Early stopping patience (epochs without improvement); `None` disables it
    pub early_stopping_patience: Option<usize>,
    /// Gradient clipping threshold; `None` disables clipping
    pub gradient_clip: Option<f32>,
    /// Temperature for contrastive loss
    pub temperature: f32,
    /// Margin for triplet loss
    pub triplet_margin: f32,
    /// EWC lambda (importance of old task knowledge)
    pub ewc_lambda: f32,
    /// Number of GNN layers
    pub num_layers: usize,
    /// Hidden dimension size
    pub hidden_dim: usize,
    /// Number of attention heads (for GAT)
    pub num_heads: usize,
    /// Negative sample ratio for contrastive learning
    pub negative_ratio: usize,
}
impl Default for HyperParameters {
    // Defaults appear to follow common GCN/contrastive-learning recipes
    // (e.g. lr 1e-3, weight decay 5e-4, temperature 0.07); tune per dataset.
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            weight_decay: 5e-4,
            dropout: 0.5,
            epochs: 200,
            batch_size: 32,
            early_stopping_patience: Some(20),
            gradient_clip: Some(1.0),
            temperature: 0.07,
            triplet_margin: 1.0,
            ewc_lambda: 5000.0,
            num_layers: 2,
            hidden_dim: 256,
            num_heads: 8,
            negative_ratio: 5,
        }
    }
}
/// Configuration for the learning service
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningConfig {
    /// Type of GNN model to use
    pub model_type: GnnModelType,
    /// Input embedding dimension
    pub input_dim: usize,
    /// Output embedding dimension
    pub output_dim: usize,
    /// Training hyperparameters
    pub hyperparameters: HyperParameters,
    /// Enable mixed precision training
    pub mixed_precision: bool,
    /// Device to use for training
    pub device: Device,
    /// Random seed for reproducibility; `None` leaves seeding to the runtime
    pub seed: Option<u64>,
    /// Enable gradient checkpointing to save memory
    pub gradient_checkpointing: bool,
}
impl LearningConfig {
}
impl Default for LearningConfig {
    // Defaults: CPU-only GCN projecting 768-dim inputs to 256-dim outputs.
    fn default() -> Self {
        Self {
            model_type: GnnModelType::Gcn,
            input_dim: 768,
            output_dim: 256,
            hyperparameters: HyperParameters::default(),
            mixed_precision: false,
            device: Device::Cpu,
            seed: None,
            gradient_checkpointing: false,
        }
    }
}
/// Device for computation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum Device {
    /// CPU computation (the default device)
    #[default]
    Cpu,
    /// CUDA GPU computation; the payload is the CUDA device index
    Cuda(usize),
    /// Metal GPU (Apple Silicon)
    Metal,
}
/// A learning session tracking the state of a training run
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSession {
    /// Unique session identifier
    pub id: String,
    /// Type of GNN model being trained
    pub model_type: GnnModelType,
    /// Current training status
    pub status: TrainingStatus,
    /// Current training metrics
    pub metrics: TrainingMetrics,
    /// When the session was started
    pub started_at: Timestamp,
    /// When the session was last updated
    pub updated_at: Timestamp,
    /// When the session completed (if applicable)
    pub completed_at: Option<Timestamp>,
    /// Configuration used for this session
    pub config: LearningConfig,
    /// History of metrics per epoch
    #[serde(default)]
    pub metrics_history: Vec<TrainingMetrics>,
    /// Best metrics achieved during training (lowest loss seen so far)
    pub best_metrics: Option<TrainingMetrics>,
    /// Error message if training failed
    pub error_message: Option<String>,
    /// Number of checkpoints saved
    pub checkpoint_count: usize,
}
impl LearningSession {
    /// Create a new learning session in the `Pending` state.
    #[must_use]
    pub fn new(config: LearningConfig) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            model_type: config.model_type,
            status: TrainingStatus::Pending,
            metrics: TrainingMetrics::default(),
            started_at: now,
            updated_at: now,
            completed_at: None,
            config,
            metrics_history: Vec::new(),
            best_metrics: None,
            error_message: None,
            checkpoint_count: 0,
        }
    }
    /// Start the training session
    pub fn start(&mut self) {
        self.status = TrainingStatus::Running;
        self.updated_at = Utc::now();
    }
    /// Update metrics for a completed epoch, tracking the best (lowest-loss)
    /// epoch seen so far and appending to the history.
    pub fn update_metrics(&mut self, metrics: TrainingMetrics) {
        // `map_or(true, ..)` treats "no best yet" as an improvement, avoiding
        // the previous `as_ref().unwrap()` on the Option.
        let improved = self
            .best_metrics
            .as_ref()
            .map_or(true, |best| metrics.loss < best.loss);
        if improved {
            self.best_metrics = Some(metrics.clone());
        }
        self.metrics = metrics.clone();
        self.metrics_history.push(metrics);
        self.updated_at = Utc::now();
    }
    /// Mark the session as completed
    pub fn complete(&mut self) {
        self.status = TrainingStatus::Completed;
        self.completed_at = Some(Utc::now());
        self.updated_at = Utc::now();
    }
    /// Mark the session as failed, recording the error message
    pub fn fail(&mut self, error: impl Into<String>) {
        self.status = TrainingStatus::Failed;
        self.error_message = Some(error.into());
        self.completed_at = Some(Utc::now());
        self.updated_at = Utc::now();
    }
    /// Pause the training session (no-op unless currently running)
    pub fn pause(&mut self) {
        if self.status == TrainingStatus::Running {
            self.status = TrainingStatus::Paused;
            self.updated_at = Utc::now();
        }
    }
    /// Resume a paused session (no-op unless currently paused)
    pub fn resume(&mut self) {
        if self.status == TrainingStatus::Paused {
            self.status = TrainingStatus::Running;
            self.updated_at = Utc::now();
        }
    }
    /// Get the training duration (up to now for a still-running session)
    #[must_use]
    pub fn duration(&self) -> chrono::Duration {
        let end = self.completed_at.unwrap_or_else(Utc::now);
        end - self.started_at
    }
    /// Check if training should stop early.
    ///
    /// Returns `true` once more than `early_stopping_patience` epochs have
    /// elapsed since the epoch with the lowest recorded loss. Returns `false`
    /// when patience is disabled or too few epochs have run.
    #[must_use]
    pub fn should_early_stop(&self) -> bool {
        if let Some(patience) = self.config.hyperparameters.early_stopping_patience {
            if self.metrics_history.len() <= patience {
                return false;
            }
            // NaN-safe comparison: treat incomparable (NaN) losses as equal
            // instead of panicking, as `partial_cmp().unwrap()` previously did
            // whenever a NaN loss entered the history.
            let best_epoch = self
                .metrics_history
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    a.loss
                        .partial_cmp(&b.loss)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(i, _)| i)
                .unwrap_or(0);
            self.metrics_history.len() - best_epoch > patience
        } else {
            false
        }
    }
}
/// A node in the transition graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Embedding ID for this node
    pub id: EmbeddingId,
    /// The embedding vector
    pub embedding: Vec<f32>,
    /// Optional node features
    pub features: Option<Vec<f32>>,
    /// Node label (for supervised learning)
    pub label: Option<usize>,
    /// Metadata associated with this node
    /// (absent in serialized input deserializes as an empty map)
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
impl GraphNode {
    /// Create a new graph node with no features, label, or metadata
    #[must_use]
    pub fn new(id: EmbeddingId, embedding: Vec<f32>) -> Self {
        Self {
            id,
            embedding,
            features: None,
            label: None,
            metadata: HashMap::new(),
        }
    }
    /// Get the embedding dimension (length of the embedding vector)
    #[must_use]
    pub fn dim(&self) -> usize {
        self.embedding.len()
    }
}
/// An edge in the transition graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Source node index
    pub from: usize,
    /// Target node index
    pub to: usize,
    /// Edge weight (e.g. similarity score)
    pub weight: f32,
    /// Edge type for heterogeneous graphs; `None` for homogeneous edges
    pub edge_type: Option<String>,
}
impl GraphEdge {
    /// Create a new untyped edge
    #[must_use]
    pub fn new(from: usize, to: usize, weight: f32) -> Self {
        Self {
            from,
            to,
            weight,
            edge_type: None,
        }
    }
    /// Create a typed edge (for heterogeneous graphs)
    #[must_use]
    pub fn typed(from: usize, to: usize, weight: f32, edge_type: impl Into<String>) -> Self {
        Self {
            from,
            to,
            weight,
            edge_type: Some(edge_type.into()),
        }
    }
}
/// A graph representing transitions between embeddings.
///
/// `nodes`, `embeddings`, and `labels` are parallel vectors indexed by node
/// index. For undirected graphs, `add_edge` stores BOTH orientations in
/// `edges`, so tuple-based counts (`num_edges`) reflect stored tuples rather
/// than logical edges.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionGraph {
    /// Nodes in the graph (embedding IDs)
    pub nodes: Vec<EmbeddingId>,
    /// Node embeddings (parallel to nodes)
    pub embeddings: Vec<Vec<f32>>,
    /// Edges as (from_index, to_index, weight) tuples
    pub edges: Vec<(usize, usize, f32)>,
    /// Optional node labels for supervised learning
    #[serde(default)]
    pub labels: Vec<Option<usize>>,
    /// Number of unique classes (if labeled)
    pub num_classes: Option<usize>,
    /// Whether the graph is directed
    pub directed: bool,
}
impl Default for TransitionGraph {
    fn default() -> Self {
        Self::new()
    }
}
impl TransitionGraph {
    /// Create a new empty directed transition graph
    #[must_use]
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            embeddings: Vec::new(),
            edges: Vec::new(),
            labels: Vec::new(),
            num_classes: None,
            directed: true,
        }
    }
    /// Create an empty undirected graph
    #[must_use]
    pub fn undirected() -> Self {
        Self {
            directed: false,
            ..Self::new()
        }
    }
    /// Add a node to the graph
    pub fn add_node(&mut self, id: EmbeddingId, embedding: Vec<f32>, label: Option<usize>) {
        self.nodes.push(id);
        self.embeddings.push(embedding);
        self.labels.push(label);
    }
    /// Add an edge to the graph.
    ///
    /// # Panics
    /// Panics if either node index is out of bounds.
    pub fn add_edge(&mut self, from: usize, to: usize, weight: f32) {
        assert!(from < self.nodes.len(), "Invalid 'from' node index");
        assert!(to < self.nodes.len(), "Invalid 'to' node index");
        self.edges.push((from, to, weight));
        // For undirected graphs, also store the reverse orientation so that
        // neighbor queries work from either endpoint.
        if !self.directed {
            self.edges.push((to, from, weight));
        }
    }
    /// Get the number of nodes
    #[must_use]
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }
    /// Get the number of stored edge tuples.
    ///
    /// For undirected graphs each logical edge is stored twice (once per
    /// orientation), so this is double the logical edge count.
    #[must_use]
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }
    /// Get the embedding dimension (assumes all embeddings have same dimension)
    #[must_use]
    pub fn embedding_dim(&self) -> Option<usize> {
        self.embeddings.first().map(Vec::len)
    }
    /// Get neighbors of a node as (target, weight) pairs
    #[must_use]
    pub fn neighbors(&self, node_idx: usize) -> Vec<(usize, f32)> {
        self.edges
            .iter()
            .filter(|(from, _, _)| *from == node_idx)
            .map(|(_, to, weight)| (*to, *weight))
            .collect()
    }
    /// Get the adjacency list representation
    #[must_use]
    pub fn adjacency_list(&self) -> Vec<Vec<(usize, f32)>> {
        let mut adj = vec![Vec::new(); self.nodes.len()];
        for &(from, to, weight) in &self.edges {
            adj[from].push((to, weight));
        }
        adj
    }
    /// Compute node degrees (out-degree for directed graphs).
    ///
    /// Only the tail endpoint of each stored tuple is counted. This is also
    /// correct for undirected graphs because `add_edge` already stores both
    /// orientations; additionally incrementing the head endpoint (as this
    /// method previously did for undirected graphs) double-counted every
    /// undirected edge.
    #[must_use]
    pub fn degrees(&self) -> Vec<usize> {
        let mut degrees = vec![0; self.nodes.len()];
        for &(from, _, _) in &self.edges {
            degrees[from] += 1;
        }
        degrees
    }
    /// Validate the graph structure (parallel-vector lengths and edge indices)
    pub fn validate(&self) -> Result<(), String> {
        if self.nodes.len() != self.embeddings.len() {
            return Err("Nodes and embeddings count mismatch".to_string());
        }
        if !self.labels.is_empty() && self.labels.len() != self.nodes.len() {
            return Err("Labels count mismatch".to_string());
        }
        for &(from, to, _) in &self.edges {
            if from >= self.nodes.len() || to >= self.nodes.len() {
                return Err(format!("Invalid edge: ({from}, {to})"));
            }
        }
        Ok(())
    }
}
/// A refined embedding produced by the learning process
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefinedEmbedding {
    /// ID of the original embedding that was refined
    pub original_id: EmbeddingId,
    /// The refined embedding vector
    pub refined_vector: Vec<f32>,
    /// Score indicating quality of refinement (0.0 to 1.0)
    pub refinement_score: f32,
    /// The session that produced this refinement
    pub session_id: Option<String>,
    /// Timestamp of refinement
    pub refined_at: Timestamp,
    /// Delta from original (optional, for analysis); set by `compute_delta`
    pub delta_norm: Option<f32>,
    /// Confidence in the refinement (initialized to the refinement score)
    pub confidence: f32,
}
impl RefinedEmbedding {
    /// Create a new refined embedding; `confidence` starts equal to
    /// `refinement_score` and `delta_norm` is left unset.
    #[must_use]
    pub fn new(
        original_id: EmbeddingId,
        refined_vector: Vec<f32>,
        refinement_score: f32,
    ) -> Self {
        Self {
            original_id,
            refined_vector,
            refinement_score,
            session_id: None,
            refined_at: Utc::now(),
            delta_norm: None,
            confidence: refinement_score,
        }
    }
    /// Compute the Euclidean delta norm from the original embedding.
    ///
    /// Silently leaves `delta_norm` unchanged when the dimensions differ —
    /// callers get no error signal for a mismatch.
    pub fn compute_delta(&mut self, original: &[f32]) {
        if original.len() != self.refined_vector.len() {
            return;
        }
        let delta: f32 = original
            .iter()
            .zip(&self.refined_vector)
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        self.delta_norm = Some(delta.sqrt());
    }
    /// Get the embedding dimension
    #[must_use]
    pub fn dim(&self) -> usize {
        self.refined_vector.len()
    }
    /// Normalize the refined vector to unit length.
    /// A (near-)zero vector (norm <= 1e-10) is left unchanged to avoid
    /// dividing by zero.
    pub fn normalize(&mut self) {
        let norm: f32 = self.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in &mut self.refined_vector {
                *x /= norm;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the domain entities: identifiers, model/status enums,
    //! metrics, sessions, graphs, refined embeddings, and early stopping.
    use super::*;
    #[test]
    fn test_embedding_id() {
        let id = EmbeddingId::new("test-123");
        assert_eq!(id.as_str(), "test-123");
        let generated = EmbeddingId::generate();
        assert!(!generated.as_str().is_empty());
    }
    #[test]
    fn test_gnn_model_type() {
        assert_eq!(GnnModelType::default(), GnnModelType::Gcn);
        // Only GAT recommends multiple attention heads.
        assert_eq!(GnnModelType::Gat.recommended_heads(), 8);
        assert_eq!(GnnModelType::Gcn.recommended_heads(), 1);
    }
    #[test]
    fn test_training_status() {
        assert!(!TrainingStatus::Running.is_terminal());
        assert!(TrainingStatus::Completed.is_terminal());
        assert!(TrainingStatus::Failed.is_terminal());
        assert!(TrainingStatus::Paused.can_resume());
        assert!(!TrainingStatus::Completed.can_resume());
    }
    #[test]
    fn test_training_metrics() {
        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        assert_eq!(metrics.epoch, 1);
        assert_eq!(metrics.loss, 0.5);
        // Strictly lower loss counts as an improvement.
        let better = TrainingMetrics::new(2, 0.3, 0.9, 0.001);
        assert!(better.is_improving(&metrics));
    }
    #[test]
    fn test_learning_session() {
        // Full lifecycle: Pending -> Running -> Completed.
        let config = LearningConfig::default();
        let mut session = LearningSession::new(config);
        assert_eq!(session.status, TrainingStatus::Pending);
        session.start();
        assert_eq!(session.status, TrainingStatus::Running);
        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        session.update_metrics(metrics);
        assert_eq!(session.metrics_history.len(), 1);
        session.complete();
        assert_eq!(session.status, TrainingStatus::Completed);
        assert!(session.completed_at.is_some());
    }
    #[test]
    fn test_transition_graph() {
        let mut graph = TransitionGraph::new();
        let emb1 = vec![0.1, 0.2, 0.3];
        let emb2 = vec![0.4, 0.5, 0.6];
        graph.add_node(EmbeddingId::new("n1"), emb1, Some(0));
        graph.add_node(EmbeddingId::new("n2"), emb2, Some(1));
        graph.add_edge(0, 1, 0.8);
        assert_eq!(graph.num_nodes(), 2);
        assert_eq!(graph.num_edges(), 1);
        assert_eq!(graph.embedding_dim(), Some(3));
        let neighbors = graph.neighbors(0);
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0], (1, 0.8));
        assert!(graph.validate().is_ok());
    }
    #[test]
    fn test_refined_embedding() {
        let original = vec![1.0, 0.0, 0.0];
        let refined = vec![0.9, 0.1, 0.0];
        let mut re = RefinedEmbedding::new(
            EmbeddingId::new("test"),
            refined,
            0.95,
        );
        re.compute_delta(&original);
        assert!(re.delta_norm.is_some());
        // After normalize() the vector must have unit L2 norm.
        re.normalize();
        let norm: f32 = re.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_early_stopping() {
        let mut config = LearningConfig::default();
        config.hyperparameters.early_stopping_patience = Some(3);
        let mut session = LearningSession::new(config);
        session.start();
        // Improving metrics
        for i in 0..5 {
            let loss = 1.0 - (i as f32 * 0.1);
            session.update_metrics(TrainingMetrics::new(i, loss, 0.8, 0.001));
        }
        assert!(!session.should_early_stop());
        // Non-improving metrics: plateau past the patience window triggers stop.
        for i in 5..10 {
            session.update_metrics(TrainingMetrics::new(i, 0.6, 0.8, 0.001));
        }
        assert!(session.should_early_stop());
    }
}

View File

@@ -0,0 +1,7 @@
//! Domain layer for the learning bounded context.
//!
//! Contains core entities, value objects, and repository traits
//! that define the learning domain model.
pub mod entities;
pub mod repository;

View File

@@ -0,0 +1,462 @@
//! Repository traits for the learning domain.
//!
//! Defines the persistence abstraction for learning sessions,
//! refined embeddings, and transition graphs.
use async_trait::async_trait;
use std::sync::Arc;
use super::entities::{
EmbeddingId, LearningSession, RefinedEmbedding, TrainingStatus, TransitionGraph,
};
/// Error type for repository operations
#[derive(Debug, thiserror::Error)]
pub enum RepositoryError {
    /// Session not found
    #[error("Learning session not found: {0}")]
    SessionNotFound(String),
    /// Embedding not found
    #[error("Embedding not found: {0}")]
    EmbeddingNotFound(String),
    /// Graph not found or empty
    #[error("Transition graph not found")]
    GraphNotFound,
    /// Serialization error
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// Storage error
    #[error("Storage error: {0}")]
    StorageError(String),
    /// Connection error
    #[error("Connection error: {0}")]
    ConnectionError(String),
    /// Validation error
    #[error("Validation error: {0}")]
    ValidationError(String),
    /// Concurrent modification error
    #[error("Concurrent modification detected for: {0}")]
    ConcurrentModification(String),
    /// Internal error
    #[error("Internal repository error: {0}")]
    Internal(String),
}
// Allow `?` on serde_json results inside repository implementations by
// mapping them to `SerializationError`.
impl From<serde_json::Error> for RepositoryError {
    fn from(e: serde_json::Error) -> Self {
        Self::SerializationError(e.to_string())
    }
}
/// Result type for repository operations
pub type RepositoryResult<T> = Result<T, RepositoryError>;
/// Repository trait for learning persistence operations.
///
/// Implementors should provide durable storage for:
/// - Learning sessions and their state
/// - Refined embeddings
/// - Transition graphs
#[async_trait]
pub trait LearningRepository: Send + Sync {
    // =========== Session Operations ===========
    /// Save a learning session
    /// (NOTE(review): upsert vs insert-only semantics are unspecified —
    /// the in-memory test impl treats this as an upsert; confirm intent)
    async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()>;
    /// Get a learning session by ID; `Ok(None)` when it does not exist
    async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>>;
    /// Update an existing session
    async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()>;
    /// Delete a session
    async fn delete_session(&self, id: &str) -> RepositoryResult<()>;
    /// List sessions with optional status filter
    /// (result ordering is implementation-defined; the in-memory test impl
    /// sorts newest-first by `started_at`)
    async fn list_sessions(
        &self,
        status: Option<TrainingStatus>,
        limit: Option<usize>,
    ) -> RepositoryResult<Vec<LearningSession>>;
    // =========== Embedding Operations ===========
    /// Save refined embeddings (batch)
    async fn save_refined_embeddings(
        &self,
        embeddings: &[RefinedEmbedding],
    ) -> RepositoryResult<()>;
    /// Get a refined embedding by original ID
    async fn get_refined_embedding(
        &self,
        original_id: &EmbeddingId,
    ) -> RepositoryResult<Option<RefinedEmbedding>>;
    /// Get multiple refined embeddings
    /// (implementations may skip missing IDs, so the result can be shorter
    /// than `ids`)
    async fn get_refined_embeddings(
        &self,
        ids: &[EmbeddingId],
    ) -> RepositoryResult<Vec<RefinedEmbedding>>;
    /// Delete refined embeddings for a session; returns the number removed
    async fn delete_refined_embeddings(&self, session_id: &str) -> RepositoryResult<usize>;
    // =========== Graph Operations ===========
    /// Get the current transition graph; `GraphNotFound` when none is stored
    async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph>;
    /// Save a transition graph
    async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;
    /// Update the transition graph (incremental)
    async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;
    /// Clear the transition graph
    async fn clear_transition_graph(&self) -> RepositoryResult<()>;
    // =========== Checkpoint Operations ===========
    /// Save a model checkpoint; returns an implementation-defined checkpoint key
    async fn save_checkpoint(
        &self,
        session_id: &str,
        epoch: usize,
        data: &[u8],
    ) -> RepositoryResult<String>;
    /// Load a model checkpoint (`epoch: None` selects an implementation-defined
    /// default, presumably the latest — confirm per backend)
    async fn load_checkpoint(
        &self,
        session_id: &str,
        epoch: Option<usize>,
    ) -> RepositoryResult<Option<Vec<u8>>>;
    /// List available checkpoints for a session as (epoch, key) pairs
    async fn list_checkpoints(&self, session_id: &str) -> RepositoryResult<Vec<(usize, String)>>;
    /// Delete checkpoints for a session; returns the number removed
    async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize>;
}
/// Extension trait with convenience queries built on `LearningRepository`.
#[async_trait]
pub trait LearningRepositoryExt: LearningRepository {
    /// Get the latest session for a model type.
    ///
    /// NOTE(review): only inspects the first 100 sessions returned by
    /// `list_sessions`; if more than 100 exist and the backend does not
    /// return newest-first, the true latest session can be missed.
    async fn get_latest_session(
        &self,
        model_type: crate::GnnModelType,
    ) -> RepositoryResult<Option<LearningSession>> {
        let sessions = self.list_sessions(None, Some(100)).await?;
        Ok(sessions
            .into_iter()
            .filter(|s| s.model_type == model_type)
            .max_by_key(|s| s.started_at))
    }
    /// Get all completed sessions
    async fn get_completed_sessions(&self) -> RepositoryResult<Vec<LearningSession>> {
        self.list_sessions(Some(TrainingStatus::Completed), None).await
    }
    /// Check if any session is currently running
    async fn has_running_session(&self) -> RepositoryResult<bool> {
        let sessions = self.list_sessions(Some(TrainingStatus::Running), Some(1)).await?;
        Ok(!sessions.is_empty())
    }
    /// Get embeddings refined in a specific session.
    ///
    /// NOTE(review): the default implementation only verifies the session
    /// exists and then returns an EMPTY list; concrete repositories must
    /// override it to return real data.
    async fn get_session_embeddings(
        &self,
        session_id: &str,
    ) -> RepositoryResult<Vec<RefinedEmbedding>> {
        // Default implementation - may be overridden for efficiency
        let session = self.get_session(session_id).await?;
        if session.is_none() {
            return Err(RepositoryError::SessionNotFound(session_id.to_string()));
        }
        // This would need to be implemented properly in concrete implementations
        Ok(Vec::new())
    }
}
// Blanket implementation: every `LearningRepository` automatically gains the
// extension methods above.
impl<T: LearningRepository + ?Sized> LearningRepositoryExt for T {}
/// A thread-safe, shareable repository handle
pub type DynLearningRepository = Arc<dyn LearningRepository>;
/// Unit of work pattern for transactional operations.
/// Callers pair `begin` with either `commit` or `rollback`; transaction
/// semantics (isolation, nesting) are left to the implementation.
#[async_trait]
pub trait UnitOfWork: Send + Sync {
    /// Begin a transaction
    async fn begin(&self) -> RepositoryResult<()>;
    /// Commit the transaction
    async fn commit(&self) -> RepositoryResult<()>;
    /// Rollback the transaction
    async fn rollback(&self) -> RepositoryResult<()>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::RwLock;
    /// In-memory implementation for testing.
    ///
    /// Every collection is a `tokio::sync::RwLock`-guarded `HashMap`, so the
    /// async trait methods can be exercised without a real backing store.
    struct InMemoryRepository {
        // Sessions keyed by session id.
        sessions: RwLock<HashMap<String, LearningSession>>,
        // Refined embeddings keyed by the original embedding id.
        embeddings: RwLock<HashMap<String, RefinedEmbedding>>,
        // Single global transition graph (`None` until first save).
        graph: RwLock<Option<TransitionGraph>>,
        // Checkpoints per session: (epoch, serialized bytes) in insertion order.
        checkpoints: RwLock<HashMap<String, Vec<(usize, Vec<u8>)>>>,
    }
    impl InMemoryRepository {
        /// Create an empty in-memory repository.
        fn new() -> Self {
            Self {
                sessions: RwLock::new(HashMap::new()),
                embeddings: RwLock::new(HashMap::new()),
                graph: RwLock::new(None),
                checkpoints: RwLock::new(HashMap::new()),
            }
        }
    }
    #[async_trait]
    impl LearningRepository for InMemoryRepository {
        async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.insert(session.id.clone(), session.clone());
            Ok(())
        }
        async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>> {
            let sessions = self.sessions.read().await;
            Ok(sessions.get(id).cloned())
        }
        // Upsert semantics: in memory, updating is the same as saving.
        async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            self.save_session(session).await
        }
        // Deleting a missing id is a silent no-op, matching `HashMap::remove`.
        async fn delete_session(&self, id: &str) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.remove(id);
            Ok(())
        }
        // Returns sessions newest-first (by `started_at`), optionally filtered
        // by status and truncated to `limit`.
        async fn list_sessions(
            &self,
            status: Option<TrainingStatus>,
            limit: Option<usize>,
        ) -> RepositoryResult<Vec<LearningSession>> {
            let sessions = self.sessions.read().await;
            let mut result: Vec<_> = sessions
                .values()
                .filter(|s| status.map_or(true, |st| s.status == st))
                .cloned()
                .collect();
            result.sort_by(|a, b| b.started_at.cmp(&a.started_at));
            if let Some(limit) = limit {
                result.truncate(limit);
            }
            Ok(result)
        }
        // Later saves with the same original id overwrite earlier ones.
        async fn save_refined_embeddings(
            &self,
            embeddings: &[RefinedEmbedding],
        ) -> RepositoryResult<()> {
            let mut store = self.embeddings.write().await;
            for emb in embeddings {
                store.insert(emb.original_id.0.clone(), emb.clone());
            }
            Ok(())
        }
        async fn get_refined_embedding(
            &self,
            original_id: &EmbeddingId,
        ) -> RepositoryResult<Option<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            Ok(store.get(&original_id.0).cloned())
        }
        // Missing ids are silently skipped; the result may be shorter than `ids`.
        async fn get_refined_embeddings(
            &self,
            ids: &[EmbeddingId],
        ) -> RepositoryResult<Vec<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            Ok(ids
                .iter()
                .filter_map(|id| store.get(&id.0).cloned())
                .collect())
        }
        // NOTE(review): the session id is ignored — ALL embeddings are cleared,
        // not just those belonging to `_session_id`. Acceptable for this test
        // double, but a real implementation must filter by session.
        async fn delete_refined_embeddings(&self, _session_id: &str) -> RepositoryResult<usize> {
            let mut store = self.embeddings.write().await;
            let count = store.len();
            store.clear();
            Ok(count)
        }
        // Errors with `GraphNotFound` until a graph has been saved.
        async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph> {
            let graph = self.graph.read().await;
            graph.clone().ok_or(RepositoryError::GraphNotFound)
        }
        async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = Some(graph.clone());
            Ok(())
        }
        async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            self.save_transition_graph(graph).await
        }
        async fn clear_transition_graph(&self) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = None;
            Ok(())
        }
        // Checkpoint id format is "<session>-<epoch>"; duplicates for the same
        // epoch are appended, not replaced.
        async fn save_checkpoint(
            &self,
            session_id: &str,
            epoch: usize,
            data: &[u8],
        ) -> RepositoryResult<String> {
            let mut store = self.checkpoints.write().await;
            let checkpoints = store.entry(session_id.to_string()).or_default();
            checkpoints.push((epoch, data.to_vec()));
            Ok(format!("{session_id}-{epoch}"))
        }
        // With `epoch = None`, the most recently saved checkpoint is returned.
        async fn load_checkpoint(
            &self,
            session_id: &str,
            epoch: Option<usize>,
        ) -> RepositoryResult<Option<Vec<u8>>> {
            let store = self.checkpoints.read().await;
            if let Some(checkpoints) = store.get(session_id) {
                if let Some(epoch) = epoch {
                    return Ok(checkpoints
                        .iter()
                        .find(|(e, _)| *e == epoch)
                        .map(|(_, d)| d.clone()));
                }
                return Ok(checkpoints.last().map(|(_, d)| d.clone()));
            }
            Ok(None)
        }
        async fn list_checkpoints(
            &self,
            session_id: &str,
        ) -> RepositoryResult<Vec<(usize, String)>> {
            let store = self.checkpoints.read().await;
            if let Some(checkpoints) = store.get(session_id) {
                return Ok(checkpoints
                    .iter()
                    .map(|(e, _)| (*e, format!("{session_id}-{e}")))
                    .collect());
            }
            Ok(Vec::new())
        }
        async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize> {
            let mut store = self.checkpoints.write().await;
            if let Some(checkpoints) = store.remove(session_id) {
                return Ok(checkpoints.len());
            }
            Ok(0)
        }
    }
    // Round-trips a session through save / get / list / delete.
    #[tokio::test]
    async fn test_in_memory_repository() {
        let repo = InMemoryRepository::new();
        let config = crate::LearningConfig::default();
        let session = crate::LearningSession::new(config);
        // Save and retrieve session
        repo.save_session(&session).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, session.id);
        // List sessions
        let sessions = repo.list_sessions(None, None).await.unwrap();
        assert_eq!(sessions.len(), 1);
        // Delete session
        repo.delete_session(&session.id).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_none());
    }
    // Verifies the graph is absent initially, persists after save, and is
    // absent again after clear.
    #[tokio::test]
    async fn test_transition_graph_operations() {
        let repo = InMemoryRepository::new();
        // Graph should not exist initially
        assert!(repo.get_transition_graph().await.is_err());
        // Save graph
        let mut graph = TransitionGraph::new();
        graph.add_node(
            crate::EmbeddingId::new("n1"),
            vec![0.1, 0.2, 0.3],
            None,
        );
        repo.save_transition_graph(&graph).await.unwrap();
        // Retrieve graph
        let retrieved = repo.get_transition_graph().await.unwrap();
        assert_eq!(retrieved.num_nodes(), 1);
        // Clear graph
        repo.clear_transition_graph().await.unwrap();
        assert!(repo.get_transition_graph().await.is_err());
    }
    // Exercises checkpoint save/list, epoch-specific and latest loads, and
    // bulk deletion.
    #[tokio::test]
    async fn test_checkpoint_operations() {
        let repo = InMemoryRepository::new();
        let session_id = "test-session";
        // Save checkpoints
        repo.save_checkpoint(session_id, 1, b"data1").await.unwrap();
        repo.save_checkpoint(session_id, 2, b"data2").await.unwrap();
        // List checkpoints
        let checkpoints = repo.list_checkpoints(session_id).await.unwrap();
        assert_eq!(checkpoints.len(), 2);
        // Load specific checkpoint
        let data = repo.load_checkpoint(session_id, Some(1)).await.unwrap();
        assert_eq!(data, Some(b"data1".to_vec()));
        // Load latest checkpoint
        let data = repo.load_checkpoint(session_id, None).await.unwrap();
        assert_eq!(data, Some(b"data2".to_vec()));
        // Delete checkpoints
        let count = repo.delete_checkpoints(session_id).await.unwrap();
        assert_eq!(count, 2);
    }
}

View File

@@ -0,0 +1,610 @@
//! Elastic Weight Consolidation (EWC) for continual learning.
//!
//! EWC prevents catastrophic forgetting by regularizing weight updates
//! based on their importance for previously learned tasks.
//!
//! Reference: Kirkpatrick et al., "Overcoming catastrophic forgetting in neural networks", PNAS 2017
use serde::{Deserialize, Serialize};
use crate::infrastructure::gnn_model::GnnModel;
/// Fisher Information matrix (diagonal approximation).
///
/// Stores the importance weights for each parameter in the model. Only the
/// diagonal `E[(∂L/∂θ_i)²]` is kept per parameter, which is exactly what
/// EWC's quadratic penalty requires.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FisherInformation {
    /// Diagonal elements of the Fisher matrix, one `Vec` per layer
    pub diagonal: Vec<Vec<f32>>,
    /// Optional parameter names for debugging
    #[serde(default)]
    pub param_names: Vec<String>,
    /// Number of gradient samples folded into the running estimate
    pub num_samples: usize,
}
impl FisherInformation {
    /// Create an empty Fisher structure with `num_layers` empty layers.
    #[must_use]
    pub fn new(num_layers: usize) -> Self {
        Self {
            diagonal: (0..num_layers).map(|_| Vec::new()).collect(),
            param_names: Vec::new(),
            num_samples: 0,
        }
    }
    /// Build a Fisher estimate from one set of per-layer gradients using
    /// the diagonal approximation `F_ii = (∂L/∂θ_i)²`.
    #[must_use]
    pub fn from_gradients(gradients: &[Vec<f32>]) -> Self {
        let diagonal = gradients
            .iter()
            .map(|layer| layer.iter().map(|&g| g * g).collect())
            .collect();
        Self {
            diagonal,
            param_names: Vec::new(),
            num_samples: 1,
        }
    }
    /// Fold a new gradient sample into the running Fisher estimate.
    ///
    /// Existing entries are updated with the running mean
    /// `F_new = (n * F_old + g²) / (n + 1)`; layers or elements not seen
    /// before are initialized with the raw squared gradient.
    pub fn update(&mut self, gradients: &[Vec<f32>]) {
        let n = self.num_samples as f32;
        for (layer_idx, grads) in gradients.iter().enumerate() {
            if self.diagonal.len() <= layer_idx {
                // First sample for an unseen layer: store raw squared gradients.
                self.diagonal.push(grads.iter().map(|&g| g * g).collect());
                continue;
            }
            let layer = &mut self.diagonal[layer_idx];
            for (i, &g) in grads.iter().enumerate() {
                let sq = g * g;
                if i < layer.len() {
                    layer[i] = (n * layer[i] + sq) / (n + 1.0);
                } else {
                    layer.push(sq);
                }
            }
        }
        self.num_samples += 1;
    }
    /// Importance of one parameter; 0.0 for out-of-range indices.
    #[must_use]
    pub fn get_importance(&self, layer_idx: usize, param_idx: usize) -> f32 {
        match self.diagonal.get(layer_idx) {
            Some(layer) => layer.get(param_idx).copied().unwrap_or(0.0),
            None => 0.0,
        }
    }
    /// Sum of every diagonal element across all layers.
    #[must_use]
    pub fn total_importance(&self) -> f32 {
        self.diagonal.iter().flatten().sum()
    }
    /// Scale the diagonal so it sums to 1. A total at or below `1e-10`
    /// leaves the values untouched (avoids division by ~zero).
    pub fn normalize(&mut self) {
        let total = self.total_importance();
        if total > 1e-10 {
            for val in self.diagonal.iter_mut().flatten() {
                *val /= total;
            }
        }
    }
    /// Multiply every diagonal element by `factor` (progressive
    /// consolidation / exponential forgetting).
    pub fn decay(&mut self, factor: f32) {
        for val in self.diagonal.iter_mut().flatten() {
            *val *= factor;
        }
    }
}
/// State snapshot for EWC regularization.
///
/// Stores the model parameters and their importance from previous tasks.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EwcState {
    /// Optimal parameter snapshots for previous task(s); one flat
    /// parameter vector per consolidated snapshot
    pub optimal_params: Vec<Vec<f32>>,
    /// Fisher information (per-parameter importance weights)
    pub fisher: FisherInformation,
    /// Optional task identifier
    pub task_id: Option<String>,
    /// Number of tasks consolidated into this state
    pub num_tasks: usize,
}
impl EwcState {
    /// Build an EWC state from a parameter snapshot and its Fisher estimate.
    #[must_use]
    pub fn new(params: Vec<f32>, fisher: FisherInformation) -> Self {
        Self {
            optimal_params: vec![params],
            fisher,
            task_id: None,
            num_tasks: 1,
        }
    }
    /// Snapshot a model's current parameters together with a Fisher
    /// estimate computed from the supplied gradients.
    #[must_use]
    pub fn from_model(model: &GnnModel, gradients: &[Vec<f32>]) -> Self {
        Self::new(
            model.get_parameters(),
            FisherInformation::from_gradients(gradients),
        )
    }
    /// Fold another EWC state into this one (multi-task learning).
    ///
    /// Parameter snapshots are appended; overlapping Fisher entries are
    /// averaged, weighted by the number of tasks behind each estimate.
    pub fn merge(&mut self, other: &EwcState) {
        let prev_tasks = self.num_tasks as f32;
        let added_tasks = other.num_tasks as f32;
        self.optimal_params.extend(other.optimal_params.iter().cloned());
        self.num_tasks += other.num_tasks;
        let total_tasks = self.num_tasks as f32;
        for (layer_idx, other_layer) in other.fisher.diagonal.iter().enumerate() {
            if self.fisher.diagonal.len() <= layer_idx {
                // Layer only exists in `other`: adopt it wholesale.
                self.fisher.diagonal.push(other_layer.clone());
                continue;
            }
            let layer = &mut self.fisher.diagonal[layer_idx];
            for (i, &val) in other_layer.iter().enumerate() {
                if i < layer.len() {
                    // Task-count-weighted average of the two Fisher values.
                    layer[i] = (prev_tasks * layer[i] + added_tasks * val) / total_tasks;
                } else {
                    layer.push(val);
                }
            }
        }
    }
}
/// EWC regularizer for computing the penalty term.
#[derive(Debug, Clone)]
pub struct EwcRegularizer {
    /// Lambda coefficient (importance of old task knowledge)
    lambda: f32,
    /// Optional per-layer lambda scaling; when `None`, the global `lambda`
    /// applies to every layer
    layer_lambdas: Option<Vec<f32>>,
}
impl Default for EwcRegularizer {
    /// Default regularizer: `lambda = 5000.0`, no per-layer scaling.
    fn default() -> Self {
        Self {
            lambda: 5000.0,
            layer_lambdas: None,
        }
    }
}
impl EwcRegularizer {
    /// Create a new EWC regularizer with the given lambda.
    #[must_use]
    pub fn new(lambda: f32) -> Self {
        Self {
            lambda,
            layer_lambdas: None,
        }
    }
    /// Create with per-layer lambda scaling; layers beyond the supplied
    /// vector fall back to the global lambda.
    #[must_use]
    pub fn with_layer_lambdas(mut self, lambdas: Vec<f32>) -> Self {
        self.layer_lambdas = Some(lambdas);
        self
    }
    /// Lambda for a specific layer (the global lambda when no per-layer
    /// value is configured for `layer_idx`).
    fn get_layer_lambda(&self, layer_idx: usize) -> f32 {
        self.layer_lambdas
            .as_ref()
            .and_then(|l| l.get(layer_idx))
            .copied()
            .unwrap_or(self.lambda)
    }
    /// Compute the EWC penalty term.
    ///
    /// L_EWC = (1/2) * Σ λ_layer * F_i * (θ_i - θ*_i)²
    ///
    /// Fix: the per-layer lambda was previously computed but discarded
    /// (`_layer_lambda`), and the global lambda was applied to every layer,
    /// making the penalty inconsistent with [`Self::compute_gradient`]
    /// which does honor per-layer lambdas. With no per-layer lambdas
    /// configured, the result is unchanged.
    ///
    /// # Arguments
    /// * `model` - Current model
    /// * `ewc_state` - Saved EWC state from previous task
    ///
    /// # Returns
    /// The EWC penalty to add to the loss; 0.0 when the model's and the
    /// snapshot's parameter counts disagree.
    #[must_use]
    pub fn compute_penalty(&self, model: &GnnModel, ewc_state: &EwcState) -> f32 {
        let current_params = model.get_parameters();
        // Use the most recent snapshot when several are stored.
        let optimal_params = ewc_state
            .optimal_params
            .last()
            .map(|p| p.as_slice())
            .unwrap_or(&[]);
        if current_params.len() != optimal_params.len() {
            return 0.0;
        }
        let mut penalty = 0.0;
        let mut param_idx = 0;
        for (layer_idx, layer_fisher) in ewc_state.fisher.diagonal.iter().enumerate() {
            let layer_lambda = self.get_layer_lambda(layer_idx);
            for &fisher_val in layer_fisher {
                // Fisher may describe more parameters than the model exposes;
                // ignore the excess (lengths were verified equal above, so a
                // single bound check suffices).
                if param_idx < current_params.len() {
                    let diff = current_params[param_idx] - optimal_params[param_idx];
                    penalty += layer_lambda * fisher_val * diff * diff;
                }
                param_idx += 1;
            }
        }
        0.5 * penalty
    }
    /// Compute the EWC gradient contribution.
    ///
    /// ∂L_EWC/∂θ_i = λ_layer * F_i * (θ_i - θ*_i)
    ///
    /// # Returns
    /// A gradient vector the length of the model's parameters; all zeros
    /// when the parameter counts disagree.
    #[must_use]
    pub fn compute_gradient(&self, model: &GnnModel, ewc_state: &EwcState) -> Vec<f32> {
        let current_params = model.get_parameters();
        let optimal_params = ewc_state
            .optimal_params
            .last()
            .map(|p| p.as_slice())
            .unwrap_or(&[]);
        let mut gradient = vec![0.0; current_params.len()];
        if current_params.len() != optimal_params.len() {
            return gradient;
        }
        let mut param_idx = 0;
        for (layer_idx, layer_fisher) in ewc_state.fisher.diagonal.iter().enumerate() {
            let layer_lambda = self.get_layer_lambda(layer_idx);
            for &fisher_val in layer_fisher {
                if param_idx < current_params.len() {
                    let diff = current_params[param_idx] - optimal_params[param_idx];
                    gradient[param_idx] = layer_lambda * fisher_val * diff;
                }
                param_idx += 1;
            }
        }
        gradient
    }
}
/// Online EWC implementation (for streaming/continual learning).
///
/// Keeps a single running, exponentially decayed Fisher estimate instead
/// of one state per task.
#[derive(Debug, Clone)]
pub struct OnlineEwc {
    /// Current consolidated EWC state (`None` before the first task)
    state: Option<EwcState>,
    /// EWC penalty weight
    lambda: f32,
    /// Decay factor applied to the old Fisher when a new task is consolidated
    gamma: f32,
}
impl OnlineEwc {
    /// Create a new Online EWC instance.
    ///
    /// * `lambda` - weight of the EWC penalty in the loss
    /// * `gamma` - exponential decay applied to the stored Fisher when a
    ///   new task is consolidated
    #[must_use]
    pub fn new(lambda: f32, gamma: f32) -> Self {
        Self {
            state: None,
            lambda,
            gamma,
        }
    }
    /// Consolidate the just-finished task into the running EWC state.
    ///
    /// The stored Fisher becomes `gamma * F_old + (1 - gamma) * F_new`, and
    /// the optimal-parameter snapshot is replaced with the current model
    /// parameters (Online EWC keeps only the latest anchor).
    ///
    /// Fix: layers present only in the new Fisher were previously pushed
    /// UNSCALED, while individual new elements were scaled by `(1 - gamma)`
    /// — new layers are now scaled consistently.
    pub fn update(&mut self, model: &GnnModel, gradients: &[Vec<f32>]) {
        let new_fisher = FisherInformation::from_gradients(gradients);
        let params = model.get_parameters();
        if let Some(ref mut state) = self.state {
            // Decay old Fisher
            state.fisher.decay(self.gamma);
            // Blend in the new Fisher, scaled by (1 - gamma)
            for (layer_idx, new_layer) in new_fisher.diagonal.iter().enumerate() {
                if layer_idx >= state.fisher.diagonal.len() {
                    state
                        .fisher
                        .diagonal
                        .push(new_layer.iter().map(|&v| (1.0 - self.gamma) * v).collect());
                } else {
                    for (i, &val) in new_layer.iter().enumerate() {
                        if i >= state.fisher.diagonal[layer_idx].len() {
                            state.fisher.diagonal[layer_idx].push((1.0 - self.gamma) * val);
                        } else {
                            state.fisher.diagonal[layer_idx][i] += (1.0 - self.gamma) * val;
                        }
                    }
                }
            }
            // Keep only the latest parameter anchor
            state.optimal_params = vec![params];
            state.num_tasks += 1;
        } else {
            self.state = Some(EwcState::new(params, new_fisher));
        }
    }
    /// Compute the EWC penalty for the current model (0 before any task
    /// has been consolidated).
    #[must_use]
    pub fn compute_penalty(&self, model: &GnnModel) -> f32 {
        if let Some(ref state) = self.state {
            let regularizer = EwcRegularizer::new(self.lambda);
            regularizer.compute_penalty(model, state)
        } else {
            0.0
        }
    }
    /// Compute the EWC gradient contribution (all zeros before any task
    /// has been consolidated).
    #[must_use]
    pub fn compute_gradient(&self, model: &GnnModel) -> Vec<f32> {
        if let Some(ref state) = self.state {
            let regularizer = EwcRegularizer::new(self.lambda);
            regularizer.compute_gradient(model, state)
        } else {
            vec![0.0; model.get_parameters().len()]
        }
    }
    /// Get the current consolidated state, if any.
    #[must_use]
    pub fn state(&self) -> Option<&EwcState> {
        self.state.as_ref()
    }
}
/// Progress & Compress (P&C) - improved EWC variant.
///
/// Alternates between progress (learning a new task) and compress
/// (consolidating it into a shared knowledge base).
#[derive(Debug, Clone)]
pub struct ProgressAndCompress {
    /// Knowledge base (compressed knowledge from all tasks so far)
    knowledge_base: Option<EwcState>,
    /// Parameter snapshot taken when the current progress phase began
    /// (`None` outside a progress phase)
    active_params: Option<Vec<f32>>,
    /// EWC penalty weight used during the progress phase
    lambda: f32,
}
impl ProgressAndCompress {
/// Create a new P&C instance
#[must_use]
pub fn new(lambda: f32) -> Self {
Self {
knowledge_base: None,
active_params: None,
lambda,
}
}
/// Begin progress phase (start learning new task)
pub fn begin_progress(&mut self, model: &GnnModel) {
self.active_params = Some(model.get_parameters());
}
/// End progress phase and begin compress
pub fn compress(&mut self, model: &GnnModel, gradients: &[Vec<f32>]) {
let fisher = FisherInformation::from_gradients(gradients);
let params = model.get_parameters();
let new_state = EwcState::new(params, fisher);
if let Some(ref mut kb) = self.knowledge_base {
kb.merge(&new_state);
} else {
self.knowledge_base = Some(new_state);
}
self.active_params = None;
}
/// Compute penalty during progress phase
#[must_use]
pub fn compute_penalty(&self, model: &GnnModel) -> f32 {
if let Some(ref kb) = self.knowledge_base {
let regularizer = EwcRegularizer::new(self.lambda);
regularizer.compute_penalty(model, kb)
} else {
0.0
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GnnModelType, LearningConfig};
    // Running average over two samples should keep both layers populated.
    #[test]
    fn test_fisher_information() {
        let mut fisher = FisherInformation::new(2);
        // Update with gradients
        fisher.update(&[vec![1.0, 2.0], vec![3.0]]);
        fisher.update(&[vec![2.0, 1.0], vec![4.0]]);
        assert_eq!(fisher.num_samples, 2);
        assert_eq!(fisher.diagonal.len(), 2);
        // Check averaging
        assert!(fisher.get_importance(0, 0) > 0.0);
    }
    // from_gradients squares each gradient element-wise.
    #[test]
    fn test_fisher_from_gradients() {
        let gradients = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];
        let fisher = FisherInformation::from_gradients(&gradients);
        assert_eq!(fisher.diagonal[0], vec![1.0, 4.0, 9.0]);
        assert_eq!(fisher.diagonal[1], vec![16.0, 25.0]);
    }
    // After normalization the diagonal sums to 1.
    #[test]
    fn test_fisher_normalize() {
        let mut fisher = FisherInformation::from_gradients(&[vec![3.0, 4.0]]);
        fisher.normalize();
        let total: f32 = fisher.diagonal.iter().flat_map(|l| l.iter()).sum();
        assert!((total - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_ewc_state() {
        let params = vec![1.0, 2.0, 3.0];
        let fisher = FisherInformation::from_gradients(&[vec![0.1, 0.2, 0.3]]);
        let state = EwcState::new(params.clone(), fisher);
        assert_eq!(state.optimal_params.len(), 1);
        assert_eq!(state.num_tasks, 1);
    }
    // Merging accumulates both task counts and parameter snapshots.
    #[test]
    fn test_ewc_state_merge() {
        let mut state1 = EwcState::new(
            vec![1.0, 2.0],
            FisherInformation::from_gradients(&[vec![0.1, 0.2]]),
        );
        let state2 = EwcState::new(
            vec![1.5, 2.5],
            FisherInformation::from_gradients(&[vec![0.3, 0.4]]),
        );
        state1.merge(&state2);
        assert_eq!(state1.num_tasks, 2);
        assert_eq!(state1.optimal_params.len(), 2);
    }
    // Penalty is zero when the model parameters equal the snapshot.
    // NOTE(review): `config` is built but never passed to the model — the
    // dimensions are repeated as literals in `GnnModel::new`; confirm whether
    // the config was meant to be used.
    #[test]
    fn test_ewc_regularizer() {
        let mut config = LearningConfig::default();
        config.input_dim = 4;
        config.output_dim = 2;
        config.hyperparameters.num_layers = 1;
        config.hyperparameters.hidden_dim = 4;
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );
        let params = model.get_parameters();
        let fisher = FisherInformation::from_gradients(&[vec![0.1; params.len()]]);
        let ewc_state = EwcState::new(params.clone(), fisher);
        let regularizer = EwcRegularizer::new(1000.0);
        let penalty = regularizer.compute_penalty(&model, &ewc_state);
        // Penalty should be 0 when params haven't changed
        assert!(penalty.abs() < 1e-6);
    }
    // With uniform Fisher and a constant parameter offset, every gradient
    // component should be non-zero (pushing back towards the snapshot).
    #[test]
    fn test_ewc_gradient() {
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );
        // Create state with slightly different params
        let mut optimal_params = model.get_parameters();
        for p in &mut optimal_params {
            *p += 0.1;
        }
        let fisher = FisherInformation::from_gradients(&[vec![1.0; optimal_params.len()]]);
        let ewc_state = EwcState::new(optimal_params, fisher);
        let regularizer = EwcRegularizer::new(1.0);
        let gradient = regularizer.compute_gradient(&model, &ewc_state);
        // Gradient should push towards optimal params
        assert!(!gradient.is_empty());
        for &g in &gradient {
            // Gradient should be non-zero since params differ
            assert!(g.abs() > 0.0);
        }
    }
    // No penalty before the first consolidation; state appears after update.
    #[test]
    fn test_online_ewc() {
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );
        let mut online = OnlineEwc::new(1000.0, 0.9);
        // Initially no penalty
        assert_eq!(online.compute_penalty(&model), 0.0);
        // Update after task 1
        let gradients = vec![vec![0.1; 20]];
        online.update(&model, &gradients);
        assert!(online.state().is_some());
        assert_eq!(online.state().unwrap().num_tasks, 1);
    }
    // Progress phase snapshots params; compress clears them and fills the
    // knowledge base.
    #[test]
    fn test_progress_and_compress() {
        let model = crate::infrastructure::gnn_model::GnnModel::new(
            GnnModelType::Gcn,
            4, 2, 1, 4, 1, 0.0,
        );
        let mut pc = ProgressAndCompress::new(1000.0);
        // Begin progress
        pc.begin_progress(&model);
        assert!(pc.active_params.is_some());
        // Compress
        let gradients = vec![vec![0.1; 20]];
        pc.compress(&model, &gradients);
        assert!(pc.active_params.is_none());
        assert!(pc.knowledge_base.is_some());
    }
    // decay(0.5) halves [1², 2²] = [1, 4] into [0.5, 2.0].
    #[test]
    fn test_fisher_decay() {
        let mut fisher = FisherInformation::from_gradients(&[vec![1.0, 2.0]]);
        fisher.decay(0.5);
        assert_eq!(fisher.diagonal[0], vec![0.5, 2.0]);
    }
}

View File

@@ -0,0 +1,576 @@
//! Attention mechanisms for GNN models.
//!
//! This module provides attention layers and mechanisms including:
//! - Single-head attention
//! - Multi-head attention
//! - Graph-level attention readout
use ndarray::{Array1, Array2, Axis};
use rand::Rng;
use rand_distr::{Distribution, Uniform};
/// Error type for attention operations
#[derive(Debug, thiserror::Error)]
pub enum AttentionError {
    /// Feature-matrix width does not match the layer's configured input size
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },
    /// Invalid configuration
    #[error("Invalid attention configuration: {0}")]
    InvalidConfig(String),
    /// Computation error
    #[error("Attention computation error: {0}")]
    ComputationError(String),
}
/// Result type for attention operations
pub type AttentionResult<T> = Result<T, AttentionError>;
/// Single-head scaled dot-product attention layer
#[derive(Debug, Clone)]
pub struct AttentionLayer {
    /// Query projection matrix (input_dim x attention_dim)
    query_weights: Array2<f32>,
    /// Key projection matrix (input_dim x attention_dim)
    key_weights: Array2<f32>,
    /// Value projection matrix (input_dim x attention_dim)
    value_weights: Array2<f32>,
    /// Attention (projection) dimension
    attention_dim: usize,
    /// Expected width of input feature matrices
    input_dim: usize,
    /// Output dimension (set equal to `attention_dim`)
    output_dim: usize,
    /// Score scaling factor: sqrt(attention_dim)
    scale: f32,
}
impl AttentionLayer {
    /// Create a new single-head attention layer.
    ///
    /// Q/K/V projection matrices are Xavier-initialized; attention scores
    /// are scaled by `sqrt(attention_dim)`.
    #[must_use]
    pub fn new(input_dim: usize, attention_dim: usize) -> Self {
        let query_weights = xavier_init(input_dim, attention_dim);
        let key_weights = xavier_init(input_dim, attention_dim);
        let value_weights = xavier_init(input_dim, attention_dim);
        let scale = (attention_dim as f32).sqrt();
        Self {
            query_weights,
            key_weights,
            value_weights,
            attention_dim,
            input_dim,
            output_dim: attention_dim,
            scale,
        }
    }
    /// Get the output dimension (equals the attention dimension).
    #[must_use]
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }
    /// Compute the row-stochastic attention weight matrix from already
    /// projected `query` and `key` matrices.
    ///
    /// Positions where `mask` is `0.0` are set to `-inf` before the
    /// softmax, which maps them to zero probability.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn compute_attention(
        &self,
        query: &Array2<f32>,
        key: &Array2<f32>,
        mask: Option<&Array2<f32>>,
    ) -> AttentionResult<Array2<f32>> {
        // Q * K^T / sqrt(d_k)
        let mut scores = query.dot(&key.t()) / self.scale;
        // Apply mask if provided (set masked positions to -inf)
        if let Some(mask) = mask {
            for i in 0..scores.nrows() {
                for j in 0..scores.ncols() {
                    if mask[[i, j]] == 0.0 {
                        scores[[i, j]] = f32::NEG_INFINITY;
                    }
                }
            }
        }
        // Softmax
        Ok(softmax_2d(&scores))
    }
    /// Self-attention forward pass: project `features` to Q/K/V, compute
    /// attention weights, and return the attended values.
    ///
    /// # Errors
    /// Returns [`AttentionError::DimensionMismatch`] when `features` does
    /// not have `input_dim` columns.
    pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
        if features.ncols() != self.input_dim {
            return Err(AttentionError::DimensionMismatch {
                expected: self.input_dim,
                actual: features.ncols(),
            });
        }
        // Compute Q, K, V
        let q = features.dot(&self.query_weights);
        let k = features.dot(&self.key_weights);
        let v = features.dot(&self.value_weights);
        // Compute attention weights
        let attention_weights = self.compute_attention(&q, &k, None)?;
        // Apply attention to values
        Ok(attention_weights.dot(&v))
    }
    /// Forward pass with explicit (unprojected) Q, K, V and an optional mask.
    ///
    /// # Errors
    /// Currently only propagates errors from [`Self::compute_attention`].
    pub fn forward_qkv(
        &self,
        query: &Array2<f32>,
        key: &Array2<f32>,
        value: &Array2<f32>,
        mask: Option<&Array2<f32>>,
    ) -> AttentionResult<Array2<f32>> {
        // Transform Q, K, V
        let q = query.dot(&self.query_weights);
        let k = key.dot(&self.key_weights);
        let v = value.dot(&self.value_weights);
        // Compute attention weights
        let attention_weights = self.compute_attention(&q, &k, mask)?;
        // Apply attention to values
        Ok(attention_weights.dot(&v))
    }
    /// Attention-weighted graph-level readout: attend over the nodes and
    /// mean-pool the attended features into a single vector.
    ///
    /// # Errors
    /// Returns an error if `node_features` has the wrong width or contains
    /// no rows. (`mean_axis` yields `None` for a zero-length axis; this
    /// previously panicked via `unwrap`.)
    pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
        let attended = self.forward(node_features)?;
        attended.mean_axis(Axis(0)).ok_or_else(|| {
            AttentionError::ComputationError(
                "graph readout requires at least one node".to_string(),
            )
        })
    }
    /// Apply multiplicative weight decay to the projection matrices.
    ///
    /// NOTE(review): the learning-rate argument is currently unused — no
    /// gradient step happens here, only decay. Confirm whether a gradient
    /// update was intended.
    pub fn update_weights(&mut self, _lr: f32, weight_decay: f32) {
        self.query_weights -= &(&self.query_weights * weight_decay);
        self.key_weights -= &(&self.key_weights * weight_decay);
        self.value_weights -= &(&self.value_weights * weight_decay);
    }
}
/// Multi-head attention layer
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    /// Individual attention heads
    heads: Vec<AttentionLayer>,
    /// Output projection: (num_heads * head_dim) x input_dim
    output_projection: Array2<f32>,
    /// Number of heads
    num_heads: usize,
    /// Dimension per head
    head_dim: usize,
    /// Total output dimension (set equal to the input dimension)
    output_dim: usize,
    /// Dropout probability — stored but not applied anywhere in this
    /// module; NOTE(review): confirm whether dropout was meant to be used
    dropout: f32,
}
impl MultiHeadAttention {
/// Create a new multi-head attention layer
#[must_use]
pub fn new(input_dim: usize, num_heads: usize, head_dim: usize, dropout: f32) -> Self {
let mut heads = Vec::with_capacity(num_heads);
for _ in 0..num_heads {
heads.push(AttentionLayer::new(input_dim, head_dim));
}
let total_dim = num_heads * head_dim;
let output_projection = xavier_init(total_dim, input_dim);
Self {
heads,
output_projection,
num_heads,
head_dim,
output_dim: input_dim,
dropout,
}
}
/// Get the number of heads
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
/// Get the output dimension
#[must_use]
pub fn output_dim(&self) -> usize {
self.output_dim
}
/// Forward pass through multi-head attention
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
let n = features.nrows();
// Compute attention for each head
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head in &self.heads {
let output = head.forward(features)?;
head_outputs.push(output);
}
// Concatenate head outputs
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
for (h, output) in head_outputs.iter().enumerate() {
let start = h * self.head_dim;
for i in 0..n {
for j in 0..self.head_dim {
concat[[i, start + j]] = output[[i, j]];
}
}
}
// Apply output projection
let output = concat.dot(&self.output_projection);
Ok(output)
}
/// Forward pass with explicit Q, K, V
pub fn forward_qkv(
&self,
query: &Array2<f32>,
key: &Array2<f32>,
value: &Array2<f32>,
mask: Option<&Array2<f32>>,
) -> AttentionResult<Array2<f32>> {
let n = query.nrows();
// Compute attention for each head
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head in &self.heads {
let output = head.forward_qkv(query, key, value, mask)?;
head_outputs.push(output);
}
// Concatenate head outputs
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
for (h, output) in head_outputs.iter().enumerate() {
let start = h * self.head_dim;
for i in 0..n {
for j in 0..self.head_dim {
concat[[i, start + j]] = output[[i, j]];
}
}
}
// Apply output projection
let output = concat.dot(&self.output_projection);
Ok(output)
}
/// Graph-level readout using multi-head attention
pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
let attended = self.forward(node_features)?;
let mean = attended.mean_axis(Axis(0)).unwrap();
Ok(mean)
}
}
/// Cross-attention between two sequences
#[derive(Debug, Clone)]
pub struct CrossAttention {
    /// Query projection for the source sequence (source_dim x attention_dim)
    query_proj: Array2<f32>,
    /// Key projection for the target sequence (target_dim x attention_dim)
    key_proj: Array2<f32>,
    /// Value projection for the target sequence (target_dim x attention_dim)
    value_proj: Array2<f32>,
    /// Attention (projection) dimension
    attention_dim: usize,
    /// Source sequence feature dimension
    source_dim: usize,
    /// Target sequence feature dimension
    target_dim: usize,
}
impl CrossAttention {
    /// Create a new cross-attention layer with Xavier-initialized projections.
    #[must_use]
    pub fn new(source_dim: usize, target_dim: usize, attention_dim: usize) -> Self {
        Self {
            query_proj: xavier_init(source_dim, attention_dim),
            key_proj: xavier_init(target_dim, attention_dim),
            value_proj: xavier_init(target_dim, attention_dim),
            attention_dim,
            source_dim,
            target_dim,
        }
    }
    /// Attend from `source` rows over `target` rows: source is projected to
    /// queries, target to keys and values; the output has one row per
    /// source row, `attention_dim` wide.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn forward(
        &self,
        source: &Array2<f32>,
        target: &Array2<f32>,
    ) -> AttentionResult<Array2<f32>> {
        let q = source.dot(&self.query_proj);
        let k = target.dot(&self.key_proj);
        let v = target.dot(&self.value_proj);
        // Scaled dot-product scores, then row-wise softmax.
        let scale = (self.attention_dim as f32).sqrt();
        let weights = softmax_2d(&(q.dot(&k.t()) / scale));
        Ok(weights.dot(&v))
    }
}
/// Set attention for set-to-set operations (Set Transformer style)
#[derive(Debug, Clone)]
pub struct SetAttention {
    /// Multi-head self-attention block
    mha: MultiHeadAttention,
    /// Layer normalization scale — shared by BOTH add-&-norm steps
    layer_norm_weight: Array1<f32>,
    /// Layer normalization shift — shared by BOTH add-&-norm steps
    layer_norm_bias: Array1<f32>,
    /// Feed-forward first projection (input_dim x hidden_dim)
    ffn_w1: Array2<f32>,
    /// Feed-forward second projection (hidden_dim x input_dim)
    ffn_w2: Array2<f32>,
    /// Input/output feature dimension
    input_dim: usize,
    /// Hidden dimension of the feed-forward network
    hidden_dim: usize,
}
impl SetAttention {
    /// Create a Set-Transformer-style block: multi-head self-attention
    /// followed by a position-wise feed-forward network, each wrapped in a
    /// residual connection and layer normalization. Per-head width is
    /// `input_dim / num_heads` (integer division).
    #[must_use]
    pub fn new(input_dim: usize, num_heads: usize, hidden_dim: usize) -> Self {
        Self {
            mha: MultiHeadAttention::new(input_dim, num_heads, input_dim / num_heads, 0.0),
            layer_norm_weight: Array1::ones(input_dim),
            layer_norm_bias: Array1::zeros(input_dim),
            ffn_w1: xavier_init(input_dim, hidden_dim),
            ffn_w2: xavier_init(hidden_dim, input_dim),
            input_dim,
            hidden_dim,
        }
    }
    /// Self-attention + FFN, each followed by add-&-norm.
    ///
    /// NOTE(review): both normalizations share one set of scale/bias
    /// parameters — confirm this is intentional.
    pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
        // Attention sub-layer with residual connection and normalization.
        let attended = self.mha.forward(features)?;
        let normed = layer_norm(
            &(features + &attended),
            &self.layer_norm_weight,
            &self.layer_norm_bias,
        );
        // Feed-forward sub-layer (ReLU) with residual and normalization.
        let ffn_out = normed.dot(&self.ffn_w1).mapv(|v| v.max(0.0)).dot(&self.ffn_w2);
        Ok(layer_norm(
            &(&normed + &ffn_out),
            &self.layer_norm_weight,
            &self.layer_norm_bias,
        ))
    }
    /// Aggregate a set into one vector: transform, then mean over rows.
    pub fn aggregate(&self, features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
        Ok(self.forward(features)?.mean_axis(Axis(0)).unwrap())
    }
}
// =========== Helper Functions ===========
/// Xavier/Glorot uniform initialization:
/// samples U(-b, b) with b = sqrt(6 / (fan_in + fan_out)).
fn xavier_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
    let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
    let dist = Uniform::new(-bound, bound);
    let mut rng = rand::thread_rng();
    let values: Vec<f32> = (0..fan_in * fan_out).map(|_| dist.sample(&mut rng)).collect();
    Array2::from_shape_vec((fan_in, fan_out), values)
        .expect("vector length matches fan_in * fan_out by construction")
}
/// Row-wise softmax with max-subtraction for numerical stability.
/// Non-finite scores (e.g. `-inf` from masking) contribute zero
/// probability; a fully masked row is left as all zeros.
fn softmax_2d(scores: &Array2<f32>) -> Array2<f32> {
    let mut out = scores.clone();
    for mut row in out.rows_mut() {
        // Largest entry in the row (NaN-ignoring, like `f32::max`).
        let row_max = row.fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
        let mut total = 0.0_f32;
        row.map_inplace(|v| {
            *v = if v.is_finite() {
                let e = (*v - row_max).exp();
                total += e;
                e
            } else {
                0.0
            };
        });
        if total > 0.0 {
            row.map_inplace(|v| *v /= total);
        }
    }
    out
}
/// Row-wise layer normalization with learned scale (`weight`) and
/// shift (`bias`); `weight` and `bias` must match the row width.
fn layer_norm(x: &Array2<f32>, weight: &Array1<f32>, bias: &Array1<f32>) -> Array2<f32> {
    const EPS: f32 = 1e-5;
    let mut out = x.clone();
    for mut row in out.rows_mut() {
        let mean = row.mean().unwrap_or(0.0);
        // Population variance over the row's features.
        let variance = row.fold(0.0_f32, |acc, &v| acc + (v - mean).powi(2)) / row.len() as f32;
        let std = (variance + EPS).sqrt();
        for (i, v) in row.iter_mut().enumerate() {
            *v = (*v - mean) / std * weight[i] + bias[i];
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    // Output keeps the row count and takes the attention dimension as width.
    #[test]
    fn test_attention_layer() {
        let layer = AttentionLayer::new(8, 16);
        assert_eq!(layer.output_dim(), 16);
        let features = Array2::from_elem((3, 8), 0.5);
        let output = layer.forward(&features).unwrap();
        assert_eq!(output.shape(), &[3, 16]);
    }
    // Multi-head output is projected back to the input dimension.
    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(8, 4, 4, 0.0);
        assert_eq!(mha.num_heads(), 4);
        assert_eq!(mha.output_dim(), 8);
        let features = Array2::from_elem((5, 8), 0.5);
        let output = mha.forward(&features).unwrap();
        assert_eq!(output.shape(), &[5, 8]);
    }
    // Cross-attention yields one row per source row, attention_dim wide.
    #[test]
    fn test_cross_attention() {
        let cross = CrossAttention::new(8, 16, 32);
        let source = Array2::from_elem((3, 8), 0.5);
        let target = Array2::from_elem((5, 16), 0.5);
        let output = cross.forward(&source, &target).unwrap();
        assert_eq!(output.shape(), &[3, 32]);
    }
    // Set attention preserves shape; aggregation pools to a single vector.
    #[test]
    fn test_set_attention() {
        let set_attn = SetAttention::new(16, 4, 64);
        let features = Array2::from_elem((4, 16), 0.5);
        let output = set_attn.forward(&features).unwrap();
        assert_eq!(output.shape(), &[4, 16]);
        let aggregated = set_attn.aggregate(&features).unwrap();
        assert_eq!(aggregated.len(), 16);
    }
    #[test]
    fn test_softmax() {
        let scores = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let probs = softmax_2d(&scores);
        // Each row should sum to 1
        for row in probs.rows() {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
        // Higher values should have higher probabilities
        assert!(probs[[0, 2]] > probs[[0, 1]]);
        assert!(probs[[0, 1]] > probs[[0, 0]]);
    }
    #[test]
    fn test_layer_norm() {
        let x = Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let weight = Array1::ones(4);
        let bias = Array1::zeros(4);
        let normed = layer_norm(&x, &weight, &bias);
        // Each row should have mean ~0 and std ~1
        for row in normed.rows() {
            let mean: f32 = row.iter().sum::<f32>() / row.len() as f32;
            assert!(mean.abs() < 1e-5);
        }
    }
    // Graph readout pools node features to one vector of the attention dim.
    #[test]
    fn test_graph_readout() {
        let layer = AttentionLayer::new(8, 4);
        let node_features = Array2::from_elem((5, 8), 0.5);
        let readout = layer.graph_readout(&node_features).unwrap();
        assert_eq!(readout.len(), 4);
    }
    // Masked positions (mask == 0.0) must receive ~zero attention weight.
    // Accesses the layer's private Q/K weights, which is fine inside the
    // defining module.
    #[test]
    fn test_attention_with_mask() {
        let layer = AttentionLayer::new(4, 4);
        let features = Array2::from_elem((3, 4), 0.5);
        // Create a mask that blocks second and third positions
        let mut mask = Array2::ones((3, 3));
        mask[[0, 1]] = 0.0;
        mask[[0, 2]] = 0.0;
        let query = features.dot(&layer.query_weights);
        let key = features.dot(&layer.key_weights);
        let attn_weights = layer.compute_attention(&query, &key, Some(&mask)).unwrap();
        // First row should only attend to itself
        assert!(attn_weights[[0, 0]] > 0.99); // Almost all attention to self
        assert!(attn_weights[[0, 1]] < 0.01);
        assert!(attn_weights[[0, 2]] < 0.01);
    }
}

View File

@@ -0,0 +1,764 @@
//! GNN model implementation.
//!
//! This module provides Graph Neural Network implementations including:
//! - GCN (Graph Convolutional Network)
//! - GraphSAGE (Sample and Aggregate)
//! - GAT (Graph Attention Network)
use ndarray::{Array1, Array2};
use rand::Rng;
use rand_distr::{Distribution, Normal, Uniform};
use serde::{Deserialize, Serialize};
use crate::domain::entities::GnnModelType;
use crate::infrastructure::attention::AttentionLayer;
/// Error type for GNN operations
#[derive(Debug, thiserror::Error)]
pub enum GnnError {
/// Dimension mismatch
#[error("Dimension mismatch: {0}")]
DimensionMismatch(String),
/// Invalid layer configuration
#[error("Invalid layer configuration: {0}")]
InvalidConfig(String),
/// Computation error
#[error("Computation error: {0}")]
ComputationError(String),
}
/// Result type for GNN operations
pub type GnnResult<T> = Result<T, GnnError>;
/// Aggregation method for GraphSAGE
///
/// The default is [`Aggregator::Mean`]. The manual `impl Default` was
/// replaced by `#[derive(Default)]` with the `#[default]` variant attribute,
/// which is the idiomatic, clippy-clean equivalent (same behavior).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Aggregator {
    /// Mean aggregation (the default)
    #[default]
    Mean,
    /// Sum aggregation (GCN-style)
    Sum,
    /// Max pooling aggregation
    MaxPool,
    /// LSTM aggregation (sequence-aware)
    Lstm,
}
/// A single layer in a GNN
///
/// Each variant owns its parameters directly, so layers can be cloned and
/// flattened via `get_parameters`.
#[derive(Debug, Clone)]
pub enum GnnLayer {
    /// Graph Convolutional Network layer
    /// H' = σ(D^(-1/2) A D^(-1/2) H W + b)
    Gcn {
        /// Weight matrix (input_dim x output_dim)
        weights: Array2<f32>,
        /// Bias vector (output_dim)
        bias: Array1<f32>,
    },
    /// GraphSAGE layer
    /// h_v' = σ(W * CONCAT(h_v, AGGREGATE({h_u : u ∈ N(v)})))
    GraphSage {
        /// Aggregation method
        aggregator: Aggregator,
        /// Self-weight matrix
        self_weights: Array2<f32>,
        /// Neighbor-weight matrix
        neighbor_weights: Array2<f32>,
        /// Bias vector
        bias: Array1<f32>,
    },
    /// Graph Attention Network layer
    /// Uses attention mechanism to weight neighbor contributions
    Gat {
        /// Weight matrix for linear transformation
        /// (input_dim x output_dim * num_heads; heads packed side by side)
        weights: Array2<f32>,
        /// Attention weights (2 * output_dim); a single vector shared by all
        /// heads (see the `gat` constructor and `gat_forward`)
        attention_weights: Array1<f32>,
        /// Number of attention heads
        num_heads: usize,
        /// Bias vector
        bias: Array1<f32>,
        /// Leaky ReLU negative slope
        negative_slope: f32,
    },
}
impl GnnLayer {
    /// Create a new GCN layer
    ///
    /// Weights use Xavier/Glorot initialization; the bias starts at zero.
    pub fn gcn(input_dim: usize, output_dim: usize) -> Self {
        let weights = xavier_init(input_dim, output_dim);
        let bias = Array1::zeros(output_dim);
        Self::Gcn { weights, bias }
    }
    /// Create a new GraphSAGE layer
    ///
    /// Separate Xavier-initialized matrices transform the node's own
    /// features and the aggregated neighbor features.
    pub fn graph_sage(input_dim: usize, output_dim: usize, aggregator: Aggregator) -> Self {
        let self_weights = xavier_init(input_dim, output_dim);
        let neighbor_weights = xavier_init(input_dim, output_dim);
        let bias = Array1::zeros(output_dim);
        Self::GraphSage {
            aggregator,
            self_weights,
            neighbor_weights,
            bias,
        }
    }
    /// Create a new GAT layer
    ///
    /// The weight matrix packs all heads side by side
    /// (output_dim * num_heads columns). A single attention vector of length
    /// 2 * output_dim is created and shared by every head (see `gat_forward`).
    pub fn gat(input_dim: usize, output_dim: usize, num_heads: usize) -> Self {
        let weights = xavier_init(input_dim, output_dim * num_heads);
        // Initialize attention weights with small values
        let mut rng = rand::thread_rng();
        let uniform = Uniform::new(-0.1, 0.1);
        let attention_weights: Array1<f32> = Array1::from_iter(
            (0..2 * output_dim).map(|_| uniform.sample(&mut rng)),
        );
        let bias = Array1::zeros(output_dim * num_heads);
        Self::Gat {
            weights,
            attention_weights,
            num_heads,
            bias,
            negative_slope: 0.2,
        }
    }
    /// Get the input dimension
    #[must_use]
    pub fn input_dim(&self) -> usize {
        match self {
            Self::Gcn { weights, .. } => weights.nrows(),
            Self::GraphSage { self_weights, .. } => self_weights.nrows(),
            Self::Gat { weights, .. } => weights.nrows(),
        }
    }
    /// Get the output dimension
    ///
    /// For GAT this is the per-head dimension (total columns / num_heads);
    /// the concatenated forward output is `output_dim() * num_heads` wide.
    #[must_use]
    pub fn output_dim(&self) -> usize {
        match self {
            Self::Gcn { weights, .. } => weights.ncols(),
            Self::GraphSage { self_weights, .. } => self_weights.ncols(),
            Self::Gat { weights, num_heads, .. } => weights.ncols() / num_heads,
        }
    }
    /// Forward pass through the layer
    ///
    /// `features` is (num_nodes x input_dim); `adj_matrix` is the
    /// (num_nodes x num_nodes) adjacency (GCN expects it pre-normalized).
    pub fn forward(&self, features: &Array2<f32>, adj_matrix: &Array2<f32>) -> GnnResult<Array2<f32>> {
        match self {
            Self::Gcn { weights, bias } => self.gcn_forward(features, adj_matrix, weights, bias),
            Self::GraphSage {
                aggregator,
                self_weights,
                neighbor_weights,
                bias,
            } => self.sage_forward(features, adj_matrix, *aggregator, self_weights, neighbor_weights, bias),
            Self::Gat {
                weights,
                attention_weights,
                num_heads,
                bias,
                negative_slope,
            } => self.gat_forward(features, adj_matrix, weights, attention_weights, *num_heads, bias, *negative_slope),
        }
    }
    /// GCN forward: H' = ReLU(A_norm * H * W + b).
    fn gcn_forward(
        &self,
        features: &Array2<f32>,
        adj_matrix: &Array2<f32>,
        weights: &Array2<f32>,
        bias: &Array1<f32>,
    ) -> GnnResult<Array2<f32>> {
        // H' = σ(A_norm * H * W + b)
        // A_norm is already normalized (symmetric normalization)
        // Aggregate neighbor features: AH
        let aggregated = adj_matrix.dot(features);
        // Transform: AH * W
        let transformed = aggregated.dot(weights);
        // Add bias and apply activation
        let mut output = transformed;
        for mut row in output.rows_mut() {
            for (i, val) in row.iter_mut().enumerate() {
                *val = relu(*val + bias[i]);
            }
        }
        Ok(output)
    }
    /// GraphSAGE forward: combines the node's own projection with a
    /// projection of its aggregated neighborhood (summed, then ReLU).
    fn sage_forward(
        &self,
        features: &Array2<f32>,
        adj_matrix: &Array2<f32>,
        aggregator: Aggregator,
        self_weights: &Array2<f32>,
        neighbor_weights: &Array2<f32>,
        bias: &Array1<f32>,
    ) -> GnnResult<Array2<f32>> {
        let n = features.nrows();
        let out_dim = self_weights.ncols();
        // Aggregate neighbor features
        let neighbor_agg = match aggregator {
            Aggregator::Mean => {
                // Mean aggregation; degree is clamped to >= 1 so isolated
                // nodes do not divide by zero.
                let mut agg = adj_matrix.dot(features);
                let degrees: Vec<f32> = (0..n).map(|i| adj_matrix.row(i).sum().max(1.0)).collect();
                for (i, mut row) in agg.rows_mut().into_iter().enumerate() {
                    row /= degrees[i];
                }
                agg
            }
            Aggregator::Sum => {
                // Sum aggregation
                adj_matrix.dot(features)
            }
            Aggregator::MaxPool => {
                // Column-wise max over each node's neighbors; nodes with no
                // neighbors fall back to 0.0.
                let mut agg = Array2::zeros((n, features.ncols()));
                for i in 0..n {
                    for j in 0..features.ncols() {
                        let mut max_val = f32::NEG_INFINITY;
                        for k in 0..n {
                            if adj_matrix[[i, k]] > 0.0 {
                                max_val = max_val.max(features[[k, j]]);
                            }
                        }
                        agg[[i, j]] = if max_val.is_finite() { max_val } else { 0.0 };
                    }
                }
                agg
            }
            Aggregator::Lstm => {
                // NOTE(review): placeholder — despite the name this branch
                // currently performs *sum* aggregation (plain adjacency
                // product, no degree normalization), not an LSTM and not a
                // mean. A full implementation would run an LSTM over the
                // neighbor sequence.
                adj_matrix.dot(features)
            }
        };
        // Transform self and neighbor features
        let self_transformed = features.dot(self_weights);
        let neighbor_transformed = neighbor_agg.dot(neighbor_weights);
        // Combine: sum the two projections, add bias, apply ReLU
        // (a sum rather than the literal CONCAT of the GraphSAGE formula)
        let mut output = Array2::zeros((n, out_dim));
        for i in 0..n {
            for j in 0..out_dim {
                let val = self_transformed[[i, j]] + neighbor_transformed[[i, j]] + bias[j];
                output[[i, j]] = relu(val);
            }
        }
        Ok(output)
    }
    /// GAT forward: per-head masked attention over neighbors, heads
    /// concatenated, then bias + ELU.
    fn gat_forward(
        &self,
        features: &Array2<f32>,
        adj_matrix: &Array2<f32>,
        weights: &Array2<f32>,
        attention_weights: &Array1<f32>,
        num_heads: usize,
        bias: &Array1<f32>,
        negative_slope: f32,
    ) -> GnnResult<Array2<f32>> {
        let n = features.nrows();
        let total_out_dim = weights.ncols();
        let head_dim = total_out_dim / num_heads;
        // Transform features: H * W
        let transformed = features.dot(weights);
        // Multi-head attention
        let mut outputs = Vec::with_capacity(num_heads);
        for head in 0..num_heads {
            let start = head * head_dim;
            let end = start + head_dim;
            // Extract features for this head
            let h = transformed.slice(ndarray::s![.., start..end]).to_owned();
            // Compute attention coefficients
            let mut attention = Array2::zeros((n, n));
            // NOTE(review): the same attention vector is reused by every
            // head; the `head_dim.min(attention_dim)` bound below guards
            // indexing when the two sizes disagree.
            let attention_dim = attention_weights.len() / 2;
            let a_src = attention_weights.slice(ndarray::s![..attention_dim]);
            let a_dst = attention_weights.slice(ndarray::s![attention_dim..]);
            for i in 0..n {
                for j in 0..n {
                    if adj_matrix[[i, j]] > 0.0 {
                        // Compute attention: a^T [Wh_i || Wh_j]
                        let mut e = 0.0;
                        for k in 0..head_dim.min(attention_dim) {
                            e += a_src[k] * h[[i, k]] + a_dst[k] * h[[j, k]];
                        }
                        // Leaky ReLU
                        attention[[i, j]] = leaky_relu(e, negative_slope);
                    } else {
                        // Non-edges are masked to -inf so softmax zeroes them.
                        attention[[i, j]] = f32::NEG_INFINITY;
                    }
                }
            }
            // Softmax over neighbors
            for mut row in attention.rows_mut() {
                let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0;
                for val in row.iter_mut() {
                    if val.is_finite() {
                        *val = (*val - max_val).exp();
                        sum += *val;
                    } else {
                        *val = 0.0;
                    }
                }
                if sum > 0.0 {
                    row /= sum;
                }
            }
            // Apply attention
            let head_output = attention.dot(&h);
            outputs.push(head_output);
        }
        // Concatenate head outputs
        let mut output = Array2::zeros((n, total_out_dim));
        for (head, head_out) in outputs.iter().enumerate() {
            let start = head * head_dim;
            for i in 0..n {
                for j in 0..head_dim {
                    output[[i, start + j]] = head_out[[i, j]];
                }
            }
        }
        // Add bias and apply activation
        for mut row in output.rows_mut() {
            for (i, val) in row.iter_mut().enumerate() {
                if i < bias.len() {
                    *val = elu(*val + bias[i], 1.0);
                }
            }
        }
        Ok(output)
    }
    /// Update weights with gradients
    ///
    /// Applies L2 weight decay unconditionally, then a plain SGD step only
    /// when the gradient's shape matches the weight matrix. Bias vectors are
    /// never updated here.
    pub fn update_weights(&mut self, gradient: &Array2<f32>, lr: f32, weight_decay: f32) {
        match self {
            Self::Gcn { weights, bias: _ } => {
                // Apply weight decay
                *weights -= &(weights.clone() * weight_decay);
                // Apply gradient update
                if gradient.shape() == weights.shape() {
                    *weights -= &(gradient * lr);
                }
            }
            Self::GraphSage {
                self_weights,
                neighbor_weights,
                bias: _,
                ..
            } => {
                *self_weights -= &(self_weights.clone() * weight_decay);
                *neighbor_weights -= &(neighbor_weights.clone() * weight_decay);
                // NOTE(review): the same gradient is applied to both the
                // self and neighbor matrices — confirm this is intended.
                if gradient.shape() == self_weights.shape() {
                    *self_weights -= &(gradient * lr);
                    *neighbor_weights -= &(gradient * lr);
                }
            }
            Self::Gat { weights, bias: _, .. } => {
                *weights -= &(weights.clone() * weight_decay);
                if gradient.shape() == weights.shape() {
                    *weights -= &(gradient * lr);
                }
            }
        }
    }
    /// Get layer parameters as flattened vector
    ///
    /// Order: weight matrices in iteration order, then the GAT attention
    /// vector (if any), then the bias.
    #[must_use]
    pub fn get_parameters(&self) -> Vec<f32> {
        match self {
            Self::Gcn { weights, bias } => {
                let mut params: Vec<f32> = weights.iter().cloned().collect();
                params.extend(bias.iter().cloned());
                params
            }
            Self::GraphSage {
                self_weights,
                neighbor_weights,
                bias,
                ..
            } => {
                let mut params: Vec<f32> = self_weights.iter().cloned().collect();
                params.extend(neighbor_weights.iter().cloned());
                params.extend(bias.iter().cloned());
                params
            }
            Self::Gat {
                weights,
                attention_weights,
                bias,
                ..
            } => {
                let mut params: Vec<f32> = weights.iter().cloned().collect();
                params.extend(attention_weights.iter().cloned());
                params.extend(bias.iter().cloned());
                params
            }
        }
    }
}
/// Complete GNN model with multiple layers
#[derive(Debug, Clone)]
pub struct GnnModel {
    /// Model type
    model_type: GnnModelType,
    /// Stacked GNN layers
    layers: Vec<GnnLayer>,
    /// Optional attention layer for final aggregation
    // NOTE(review): initialized in `new` but never consulted by `forward`;
    // graph-level readout appears to be unwired — confirm intent.
    attention: Option<AttentionLayer>,
    /// Dropout probability (applied during training)
    dropout: f32,
    /// Whether the model is in training mode
    training: bool,
}
impl GnnModel {
    /// Create a new GNN model
    ///
    /// Builds `num_layers` stacked layers of `model_type`: the first maps
    /// `input_dim` -> `hidden_dim`, the last maps to `output_dim`.
    /// `num_heads` only affects GAT layers; `dropout` is applied between
    /// layers while in training mode.
    ///
    /// NOTE(review): for GAT, each layer's forward output is
    /// `out_dim * num_heads` columns wide, which does not match the next
    /// layer's expected `hidden_dim` input unless `num_heads == 1` —
    /// confirm stacked GAT is supported.
    #[must_use]
    pub fn new(
        model_type: GnnModelType,
        input_dim: usize,
        output_dim: usize,
        num_layers: usize,
        hidden_dim: usize,
        num_heads: usize,
        dropout: f32,
    ) -> Self {
        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            let in_dim = if i == 0 { input_dim } else { hidden_dim };
            let out_dim = if i == num_layers - 1 {
                output_dim
            } else {
                hidden_dim
            };
            let layer = match model_type {
                GnnModelType::Gcn => GnnLayer::gcn(in_dim, out_dim),
                GnnModelType::GraphSage => GnnLayer::graph_sage(in_dim, out_dim, Aggregator::Mean),
                GnnModelType::Gat => GnnLayer::gat(in_dim, out_dim, num_heads),
            };
            layers.push(layer);
        }
        // Add attention layer for graph-level readout
        // NOTE(review): stored but never used by `forward` (see struct field).
        let attention = if num_layers > 0 {
            Some(AttentionLayer::new(output_dim, 64))
        } else {
            None
        };
        Self {
            model_type,
            layers,
            attention,
            dropout,
            training: true,
        }
    }
    /// Set training mode (enables dropout in `forward`)
    pub fn train(&mut self) {
        self.training = true;
    }
    /// Set evaluation mode (disables dropout)
    pub fn eval(&mut self) {
        self.training = false;
    }
    /// Get the model type
    #[must_use]
    pub fn model_type(&self) -> GnnModelType {
        self.model_type
    }
    /// Get the number of layers
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
    /// Get layer dimensions
    ///
    /// Returns `(input_dim, output_dim)` for the layer, or `(0, 0)` when
    /// `layer_idx` is out of range.
    #[must_use]
    pub fn layer_dims(&self, layer_idx: usize) -> (usize, usize) {
        if layer_idx < self.layers.len() {
            (
                self.layers[layer_idx].input_dim(),
                self.layers[layer_idx].output_dim(),
            )
        } else {
            (0, 0)
        }
    }
    /// Forward pass through the model
    ///
    /// Runs the features through each layer in order; with zero layers the
    /// input is returned unchanged.
    pub fn forward(
        &self,
        features: &Array2<f32>,
        adj_matrix: &Array2<f32>,
    ) -> GnnResult<Array2<f32>> {
        let mut h = features.clone();
        for (i, layer) in self.layers.iter().enumerate() {
            h = layer.forward(&h, adj_matrix)?;
            // Apply dropout (except on last layer)
            if self.training && i < self.layers.len() - 1 {
                h = self.apply_dropout(&h);
            }
        }
        Ok(h)
    }
    /// Update model weights with gradients
    ///
    /// Pairs gradients with layers positionally; surplus gradients (or
    /// surplus layers) are silently ignored by `zip`.
    pub fn update_weights(
        &mut self,
        gradients: &[Array2<f32>],
        lr: f32,
        weight_decay: f32,
    ) {
        for (layer, grad) in self.layers.iter_mut().zip(gradients.iter()) {
            layer.update_weights(grad, lr, weight_decay);
        }
    }
    /// Get all model parameters as flattened vector
    #[must_use]
    pub fn get_parameters(&self) -> Vec<f32> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.get_parameters());
        }
        params
    }
    /// Get total number of parameters
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        self.get_parameters().len()
    }
    /// Inverted dropout: zeroes each entry with probability `self.dropout`
    /// and scales survivors by 1 / (1 - dropout) to preserve expectation.
    fn apply_dropout(&self, features: &Array2<f32>) -> Array2<f32> {
        if self.dropout <= 0.0 || !self.training {
            return features.clone();
        }
        let mut rng = rand::thread_rng();
        let scale = 1.0 / (1.0 - self.dropout);
        let mut dropped = features.clone();
        for val in dropped.iter_mut() {
            if rng.gen::<f32>() < self.dropout {
                *val = 0.0;
            } else {
                *val *= scale;
            }
        }
        dropped
    }
}
// =========== Activation Functions ===========
/// ReLU activation: identity for positive inputs, zero otherwise.
fn relu(x: f32) -> f32 {
    if x > 0.0 { x } else { 0.0 }
}
/// Leaky ReLU: passes non-negative inputs through unchanged and scales
/// negative inputs by `negative_slope`.
fn leaky_relu(x: f32, negative_slope: f32) -> f32 {
    if x < 0.0 { negative_slope * x } else { x }
}
/// ELU activation: `x` for non-negative inputs, `alpha * (e^x - 1)` otherwise.
///
/// Uses `f32::exp_m1` for the negative branch, which computes `e^x - 1`
/// directly and avoids the catastrophic cancellation of `x.exp() - 1.0`
/// for inputs very close to zero (where `x.exp()` rounds to exactly 1.0).
fn elu(x: f32, alpha: f32) -> f32 {
    if x >= 0.0 {
        x
    } else {
        alpha * x.exp_m1()
    }
}
/// GELU activation (Gaussian Error Linear Unit), tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2))).
#[allow(dead_code)]
fn gelu(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    let inner = x * SQRT_2_OVER_PI * (1.0 + 0.044715 * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
// =========== Initialization ===========
/// Xavier/Glorot uniform initialization.
///
/// Samples each weight from U(-limit, limit) with
/// limit = sqrt(6 / (fan_in + fan_out)), which keeps activation variance
/// roughly constant across layers.
fn xavier_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
    let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
    let dist = Uniform::new(-bound, bound);
    let mut rng = rand::thread_rng();
    Array2::from_shape_fn((fan_in, fan_out), |_| dist.sample(&mut rng))
}
/// Kaiming/He normal initialization, suited to ReLU networks.
///
/// Samples each weight from N(0, sigma) with sigma = sqrt(2 / fan_in).
#[allow(dead_code)]
fn kaiming_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
    let sigma = (2.0 / fan_in as f32).sqrt();
    let gauss = Normal::new(0.0, sigma).unwrap();
    let mut rng = rand::thread_rng();
    Array2::from_shape_fn((fan_in, fan_out), |_| gauss.sample(&mut rng))
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_gcn_layer() {
        let layer = GnnLayer::gcn(8, 4);
        assert_eq!(layer.input_dim(), 8);
        assert_eq!(layer.output_dim(), 4);
        let features = Array2::from_elem((3, 8), 0.5);
        let adj = Array2::<f32>::eye(3);
        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }
    #[test]
    fn test_graphsage_layer() {
        let layer = GnnLayer::graph_sage(8, 4, Aggregator::Mean);
        assert_eq!(layer.input_dim(), 8);
        assert_eq!(layer.output_dim(), 4);
        let features = Array2::from_elem((3, 8), 0.5);
        let adj = Array2::<f32>::eye(3);
        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }
    #[test]
    fn test_gat_layer() {
        // GAT concatenates heads, so output width is output_dim * num_heads.
        let layer = GnnLayer::gat(8, 4, 2);
        assert_eq!(layer.input_dim(), 8);
        let features = Array2::from_elem((3, 8), 0.5);
        let adj = Array2::<f32>::eye(3);
        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 8]); // 4 * 2 heads
    }
    #[test]
    fn test_gnn_model() {
        // NOTE(review): only the GCN variant is exercised with multiple
        // stacked layers; multi-layer GAT is not covered here.
        let model = GnnModel::new(GnnModelType::Gcn, 16, 8, 2, 32, 4, 0.5);
        assert_eq!(model.model_type(), GnnModelType::Gcn);
        assert_eq!(model.num_layers(), 2);
        let features = Array2::from_elem((5, 16), 0.5);
        let adj = Array2::<f32>::eye(5);
        let output = model.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[5, 8]);
    }
    #[test]
    fn test_activation_functions() {
        assert_eq!(relu(-1.0), 0.0);
        assert_eq!(relu(1.0), 1.0);
        assert!(leaky_relu(-1.0, 0.2) < 0.0);
        assert_eq!(leaky_relu(1.0, 0.2), 1.0);
        assert!(elu(-1.0, 1.0) < 0.0);
        assert_eq!(elu(1.0, 1.0), 1.0);
    }
    #[test]
    fn test_xavier_init() {
        let weights = xavier_init(100, 100);
        assert_eq!(weights.shape(), &[100, 100]);
        // Check values are in reasonable range
        // (the Xavier limit for 100x100 is sqrt(6/200) ~ 0.173)
        let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        assert!(max < 1.0);
        assert!(min > -1.0);
    }
    #[test]
    fn test_model_parameters() {
        let model = GnnModel::new(GnnModelType::Gcn, 8, 4, 2, 16, 1, 0.0);
        let params = model.get_parameters();
        assert!(params.len() > 0);
        // Layer 1: 8*16 + 16 = 144
        // Layer 2: 16*4 + 4 = 68
        // Total: 212
        assert_eq!(model.num_parameters(), 8 * 16 + 16 + 16 * 4 + 4);
    }
    #[test]
    fn test_aggregators() {
        let features = Array2::from_shape_vec((3, 4), vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
        ]).unwrap();
        // Star graph: node 0 is connected to nodes 1 and 2 (undirected).
        let mut adj = Array2::zeros((3, 3));
        adj[[0, 1]] = 1.0;
        adj[[0, 2]] = 1.0;
        adj[[1, 0]] = 1.0;
        adj[[2, 0]] = 1.0;
        // Test mean aggregation
        let layer = GnnLayer::graph_sage(4, 2, Aggregator::Mean);
        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 2]);
        // Test max aggregation
        let layer = GnnLayer::graph_sage(4, 2, Aggregator::MaxPool);
        let output = layer.forward(&features, &adj).unwrap();
        assert_eq!(output.shape(), &[3, 2]);
    }
}

View File

@@ -0,0 +1,7 @@
//! Infrastructure layer for the learning bounded context.
//!
//! Contains GNN model implementations, attention mechanisms,
//! and other technical components.
pub mod gnn_model;
pub mod attention;

View File

@@ -0,0 +1,82 @@
//! # sevensense-learning
//!
//! Graph Neural Network (GNN) based learning and embedding refinement for 7sense.
//!
//! This crate provides:
//! - GNN models (GCN, GraphSAGE, GAT) for graph-based learning
//! - Embedding refinement through message passing
//! - Contrastive learning with InfoNCE and triplet loss
//! - Elastic Weight Consolidation (EWC) for continual learning
//! - Graph attention mechanisms for relationship modeling
//!
//! ## Architecture
//!
//! The crate follows Domain-Driven Design principles:
//! - `domain`: Core entities and repository traits
//! - `application`: Business logic and services
//! - `infrastructure`: GNN implementations and attention mechanisms
//!
//! ## Example
//!
//! ```rust,ignore
//! use sevensense_learning::{LearningService, LearningConfig, GnnModelType};
//!
//! let config = LearningConfig::default();
//! let service = LearningService::new(config);
//!
//! // Train on transition graph
//! let metrics = service.train_epoch(&graph).await?;
//!
//! // Refine embeddings
//! let refined = service.refine_embeddings(&embeddings).await?;
//! ```
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
pub mod domain;
pub mod application;
pub mod infrastructure;
pub mod loss;
pub mod ewc;
// Re-exports for convenience
pub use domain::entities::{
LearningSession, GnnModelType, TrainingStatus, TrainingMetrics,
TransitionGraph, RefinedEmbedding, EmbeddingId, Timestamp,
GraphNode, GraphEdge, HyperParameters, LearningConfig,
};
pub use domain::repository::LearningRepository;
pub use application::services::LearningService;
pub use infrastructure::gnn_model::{GnnModel, GnnLayer, Aggregator};
pub use infrastructure::attention::{AttentionLayer, MultiHeadAttention};
pub use loss::{info_nce_loss, triplet_loss, margin_ranking_loss, contrastive_loss};
pub use ewc::{EwcState, FisherInformation, EwcRegularizer};
/// Crate version information
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Prelude module for convenient imports
pub mod prelude {
pub use crate::domain::entities::*;
pub use crate::domain::repository::*;
pub use crate::application::services::*;
pub use crate::infrastructure::gnn_model::*;
pub use crate::infrastructure::attention::*;
pub use crate::loss::*;
pub use crate::ewc::*;
}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke test: constructing these variants proves the crate-root
    // re-exports resolve.
    #[test]
    fn test_crate_exports() {
        // Verify all public types are accessible
        let _: GnnModelType = GnnModelType::Gcn;
        let _: TrainingStatus = TrainingStatus::Pending;
    }
}

View File

@@ -0,0 +1,612 @@
//! Loss functions for contrastive learning.
//!
//! This module provides various loss functions for training GNN models
//! on graph-structured data with contrastive learning objectives.
/// Compute InfoNCE (Noise Contrastive Estimation) loss.
///
/// Treats the positive as the "correct class" in a softmax over
/// `[positive, negatives...]` of temperature-scaled cosine similarities:
/// L = -log(exp(sim(a,p)/τ) / Σ exp(sim(a,x)/τ)).
///
/// # Arguments
/// * `anchor` - The anchor embedding
/// * `positive` - The positive (similar) embedding
/// * `negatives` - Slice of negative (dissimilar) embeddings
/// * `temperature` - Softmax temperature (typical: 0.07-0.5; clamped to >= 1e-6)
///
/// # Returns
/// The InfoNCE loss (lower is better); 0.0 when any input is empty.
///
/// # Example
/// ```
/// use sevensense_learning::info_nce_loss;
///
/// let anchor = vec![1.0, 0.0, 0.0];
/// let positive = vec![0.9, 0.1, 0.0];
/// let negative = vec![0.0, 1.0, 0.0];
///
/// let loss = info_nce_loss(&anchor, &positive, &[&negative], 0.07);
/// assert!(loss >= 0.0);
/// ```
#[must_use]
pub fn info_nce_loss(
    anchor: &[f32],
    positive: &[f32],
    negatives: &[&[f32]],
    temperature: f32,
) -> f32 {
    if anchor.is_empty() || positive.is_empty() || negatives.is_empty() {
        return 0.0;
    }
    let temp = temperature.max(1e-6); // prevent division by zero
    let pos_logit = cosine_similarity(anchor, positive) / temp;
    let neg_logits: Vec<f32> = negatives
        .iter()
        .map(|neg| cosine_similarity(anchor, neg) / temp)
        .collect();
    // Numerically stable log-sum-exp over [positive, negatives...]
    let peak = neg_logits
        .iter()
        .chain(std::iter::once(&pos_logit))
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    let total: f32 = std::iter::once(pos_logit)
        .chain(neg_logits)
        .map(|logit| (logit - peak).exp())
        .sum();
    (peak + total.ln()) - pos_logit
}
/// Compute triplet loss with margin.
///
/// Hinge on Euclidean distances: the anchor must be closer to the positive
/// than to the negative by at least `margin`,
/// i.e. L = max(0, d(a,p) - d(a,n) + margin).
///
/// # Arguments
/// * `anchor` - The anchor embedding
/// * `positive` - The positive (similar) embedding
/// * `negative` - The negative (dissimilar) embedding
/// * `margin` - Minimum required separation between the two distances
///
/// # Returns
/// The triplet loss (lower is better); 0.0 when any input is empty.
///
/// # Example
/// ```
/// use sevensense_learning::triplet_loss;
///
/// let anchor = vec![1.0, 0.0, 0.0];
/// let positive = vec![0.9, 0.1, 0.0];
/// let negative = vec![0.0, 1.0, 0.0];
///
/// let loss = triplet_loss(&anchor, &positive, &negative, 1.0);
/// assert!(loss >= 0.0);
/// ```
#[must_use]
pub fn triplet_loss(anchor: &[f32], positive: &[f32], negative: &[f32], margin: f32) -> f32 {
    if anchor.is_empty() || positive.is_empty() || negative.is_empty() {
        return 0.0;
    }
    let to_positive = euclidean_distance(anchor, positive);
    let to_negative = euclidean_distance(anchor, negative);
    (to_positive - to_negative + margin).max(0.0)
}
/// Compute margin ranking loss.
///
/// Hinge on cosine similarities: the positive must beat the negative by at
/// least `margin`, i.e. L = max(0, margin - (sim(a,p) - sim(a,n))).
///
/// # Arguments
/// * `anchor` - The anchor embedding
/// * `positive` - The positive embedding
/// * `negative` - The negative embedding
/// * `margin` - The margin for ranking
///
/// # Returns
/// The margin ranking loss; 0.0 when any input is empty.
#[must_use]
pub fn margin_ranking_loss(
    anchor: &[f32],
    positive: &[f32],
    negative: &[f32],
    margin: f32,
) -> f32 {
    if anchor.is_empty() || positive.is_empty() || negative.is_empty() {
        return 0.0;
    }
    let advantage = cosine_similarity(anchor, positive) - cosine_similarity(anchor, negative);
    (margin - advantage).max(0.0)
}
/// Compute contrastive loss (SimCLR style) for one positive pair.
///
/// Softmax cross-entropy where `z_j` is the correct match for `z_i` among
/// `[z_j, other_samples...]`, using temperature-scaled cosine similarities.
///
/// # Arguments
/// * `z_i` - First view embedding
/// * `z_j` - Second view embedding (augmented view of same sample)
/// * `other_samples` - Embeddings of other samples in the batch
/// * `temperature` - Temperature parameter (clamped to >= 1e-6)
///
/// # Returns
/// The contrastive loss; 0.0 when either view is empty.
#[must_use]
pub fn contrastive_loss(
    z_i: &[f32],
    z_j: &[f32],
    other_samples: &[&[f32]],
    temperature: f32,
) -> f32 {
    if z_i.is_empty() || z_j.is_empty() {
        return 0.0;
    }
    let temp = temperature.max(1e-6);
    let pos_logit = cosine_similarity(z_i, z_j) / temp;
    // Logits over [positive, others...], then stable log-sum-exp.
    let logits: Vec<f32> = std::iter::once(pos_logit)
        .chain(other_samples.iter().map(|s| cosine_similarity(z_i, s) / temp))
        .collect();
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let total: f32 = logits.iter().map(|l| (l - peak).exp()).sum();
    (peak + total.ln()) - pos_logit
}
/// Compute NT-Xent loss (Normalized Temperature-scaled Cross Entropy).
///
/// This is the loss function used in SimCLR.
///
/// # Arguments
/// * `embeddings` - All embeddings in the batch (2N for N samples with 2 views each)
/// * `temperature` - Temperature parameter
///
/// # Returns
/// The average NT-Xent loss across all positive pairs
#[must_use]
pub fn nt_xent_loss(embeddings: &[Vec<f32>], temperature: f32) -> f32 {
    let n = embeddings.len();
    if n < 2 {
        return 0.0;
    }
    let temp = temperature.max(1e-6); // guard against division by zero
    // Assume embeddings are organized as [z_1_a, z_1_b, z_2_a, z_2_b, ...]
    // where z_i_a and z_i_b are two views of sample i
    // Compute all pairwise similarities
    let mut sim_matrix = vec![vec![0.0f32; n]; n];
    for i in 0..n {
        for j in 0..n {
            sim_matrix[i][j] = cosine_similarity(&embeddings[i], &embeddings[j]) / temp;
        }
    }
    let mut total_loss = 0.0;
    let mut count = 0;
    // For each sample, compute loss against its positive pair
    for i in 0..n {
        // Find positive pair (assumes alternating views)
        let j = if i % 2 == 0 { i + 1 } else { i - 1 };
        if j >= n {
            // Odd batch size: the unpaired trailing element is skipped.
            continue;
        }
        let pos_sim = sim_matrix[i][j];
        // Denominator: stable log-sum-exp over every other sample — self is
        // excluded, the positive is included (standard NT-Xent).
        let max_sim = sim_matrix[i]
            .iter()
            .enumerate()
            .filter(|(k, _)| *k != i)
            .map(|(_, &s)| s)
            .fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = sim_matrix[i]
            .iter()
            .enumerate()
            .filter(|(k, _)| *k != i)
            .map(|(_, &s)| (s - max_sim).exp())
            .sum();
        let log_sum_exp = max_sim + sum_exp.ln();
        total_loss += -pos_sim + log_sum_exp;
        count += 1;
    }
    if count > 0 {
        total_loss / count as f32
    } else {
        0.0
    }
}
/// Compute supervised contrastive loss (SupCon).
///
/// Extends contrastive loss to use label information.
///
/// # Arguments
/// * `embeddings` - All embeddings
/// * `labels` - Label for each embedding
/// * `temperature` - Temperature parameter
///
/// # Returns
/// The supervised contrastive loss
#[must_use]
pub fn supervised_contrastive_loss(
    embeddings: &[Vec<f32>],
    labels: &[usize],
    temperature: f32,
) -> f32 {
    let n = embeddings.len();
    if n < 2 || n != labels.len() {
        return 0.0;
    }
    let temp = temperature.max(1e-6); // guard against division by zero
    // Compute all pairwise similarities
    let mut sim_matrix = vec![vec![0.0f32; n]; n];
    for i in 0..n {
        for j in 0..n {
            sim_matrix[i][j] = cosine_similarity(&embeddings[i], &embeddings[j]) / temp;
        }
    }
    let mut total_loss = 0.0;
    for i in 0..n {
        // Find all positive pairs (same label, excluding self)
        let positives: Vec<usize> = (0..n)
            .filter(|&j| j != i && labels[j] == labels[i])
            .collect();
        if positives.is_empty() {
            // Anchors without positives contribute 0 but still count in the
            // final division by n below.
            continue;
        }
        // Compute denominator (all except self), via stable log-sum-exp
        let max_sim = sim_matrix[i]
            .iter()
            .enumerate()
            .filter(|(j, _)| *j != i)
            .map(|(_, &s)| s)
            .fold(f32::NEG_INFINITY, f32::max);
        let denom_exp: f32 = sim_matrix[i]
            .iter()
            .enumerate()
            .filter(|(j, _)| *j != i)
            .map(|(_, &s)| (s - max_sim).exp())
            .sum();
        let log_denom = max_sim + denom_exp.ln();
        // Average over all positive pairs
        let pos_loss: f32 = positives
            .iter()
            .map(|&j| -sim_matrix[i][j] + log_denom)
            .sum();
        total_loss += pos_loss / positives.len() as f32;
    }
    total_loss / n as f32
}
/// Compute center loss.
///
/// Penalizes the squared Euclidean distance between each embedding and the
/// center of its class, halved and averaged over the batch. Samples whose
/// label has no corresponding center are skipped (contributing zero).
///
/// # Arguments
/// * `embeddings` - Embeddings
/// * `labels` - Class labels (used as indices into `centers`)
/// * `centers` - Class center embeddings
///
/// # Returns
/// The center loss; 0.0 for empty or length-mismatched inputs.
#[must_use]
pub fn center_loss(
    embeddings: &[Vec<f32>],
    labels: &[usize],
    centers: &[Vec<f32>],
) -> f32 {
    if embeddings.is_empty() || embeddings.len() != labels.len() {
        return 0.0;
    }
    let mut accum = 0.0;
    for (emb, &class) in embeddings.iter().zip(labels) {
        if let Some(center) = centers.get(class) {
            let d = euclidean_distance(emb, center);
            accum += d * d;
        }
    }
    accum / (2.0 * embeddings.len() as f32)
}
/// Compute focal loss (for imbalanced binary classification).
///
/// Down-weights easy examples via the `(1 - p_t)^gamma` modulating factor
/// so training focuses on hard, misclassified samples.
///
/// # Arguments
/// * `predictions` - Predicted probabilities (clamped into [1e-7, 1 - 1e-7])
/// * `targets` - Binary labels (1 = positive; anything else = negative)
/// * `gamma` - Focusing parameter (typical: 2.0)
/// * `alpha` - Weight applied to the positive class (`1 - alpha` to negative)
///
/// # Returns
/// The mean focal loss; 0.0 for empty or length-mismatched inputs.
#[must_use]
pub fn focal_loss(
    predictions: &[f32],
    targets: &[usize],
    gamma: f32,
    alpha: f32,
) -> f32 {
    if predictions.is_empty() || predictions.len() != targets.len() {
        return 0.0;
    }
    const EPS: f32 = 1e-7;
    let summed: f32 = predictions
        .iter()
        .zip(targets)
        .map(|(&raw, &label)| {
            let p = raw.clamp(EPS, 1.0 - EPS);
            let (pt, weight) = if label == 1 { (p, alpha) } else { (1.0 - p, 1.0 - alpha) };
            -weight * (1.0 - pt).powf(gamma) * pt.ln()
        })
        .sum();
    summed / predictions.len() as f32
}
// =========== Helper Functions ===========
/// Compute cosine similarity between two vectors
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a < 1e-10 || norm_b < 1e-10 {
return 0.0;
}
dot / (norm_a * norm_b)
}
/// Euclidean (L2) distance between two equal-length vectors;
/// returns infinity on length mismatch.
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    let sum_sq = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let d = x - y;
        acc + d * d
    });
    sum_sq.sqrt()
}
/// Squared Euclidean distance (skips the sqrt; cheaper when only relative
/// ordering matters). Returns infinity on length mismatch.
#[inline]
#[allow(dead_code)]
fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    a.iter().zip(b).fold(0.0, |acc, (x, y)| {
        let d = x - y;
        acc + d * d
    })
}
/// Compute dot product.
///
/// NOTE: unlike the distance helpers, mismatched lengths are not rejected;
/// `zip` silently truncates to the shorter input.
#[inline]
#[allow(dead_code)]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
/// L2 normalize a vector to unit length.
///
/// Vectors with near-zero norm are returned unchanged, since normalizing
/// them would divide by (almost) zero.
#[must_use]
pub fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let sum_sq: f32 = v.iter().map(|x| x * x).sum();
    let norm = sum_sq.sqrt();
    if norm < 1e-10 {
        v.to_vec()
    } else {
        v.iter().map(|x| x / norm).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_info_nce_loss() {
        // Anchor with a near-duplicate positive and two orthogonal negatives.
        let anchor = vec![1.0, 0.0, 0.0];
        let pos = vec![0.95, 0.05, 0.0];
        let neg_a = vec![0.0, 1.0, 0.0];
        let neg_b = vec![0.0, 0.0, 1.0];
        let baseline = info_nce_loss(&anchor, &pos, &[&neg_a, &neg_b], 0.1);
        assert!(baseline >= 0.0);
        // Moving the positive closer to the anchor must decrease the loss.
        let closer = vec![0.99, 0.01, 0.0];
        let improved = info_nce_loss(&anchor, &closer, &[&neg_a, &neg_b], 0.1);
        assert!(improved < baseline);
    }

    #[test]
    fn test_triplet_loss() {
        let anchor = vec![1.0, 0.0, 0.0];
        let pos = vec![0.9, 0.1, 0.0];
        let neg = vec![0.0, 1.0, 0.0];
        let baseline = triplet_loss(&anchor, &pos, &neg, 1.0);
        assert!(baseline >= 0.0);
        // A tighter positive should never make the loss worse.
        let tight_pos = vec![0.99, 0.01, 0.0];
        let improved = triplet_loss(&anchor, &tight_pos, &neg, 1.0);
        assert!(improved <= baseline);
    }

    #[test]
    fn test_margin_ranking_loss() {
        let anchor = vec![1.0, 0.0, 0.0];
        let pos = vec![0.9, 0.1, 0.0];
        let neg = vec![0.0, 1.0, 0.0];
        assert!(margin_ranking_loss(&anchor, &pos, &neg, 0.5) >= 0.0);
    }

    #[test]
    fn test_contrastive_loss() {
        let view_a = vec![1.0, 0.0, 0.0];
        let view_b = vec![0.95, 0.05, 0.0];
        let distractor_a = vec![0.0, 1.0, 0.0];
        let distractor_b = vec![0.0, 0.0, 1.0];
        let loss = contrastive_loss(&view_a, &view_b, &[&distractor_a, &distractor_b], 0.1);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_nt_xent_loss() {
        // Consecutive entries (0,1) and (2,3) act as augmented views of each other.
        let batch = vec![
            vec![1.0, 0.0],
            vec![0.95, 0.05],
            vec![0.0, 1.0],
            vec![0.05, 0.95],
        ];
        assert!(nt_xent_loss(&batch, 0.5) >= 0.0);
    }

    #[test]
    fn test_supervised_contrastive_loss() {
        // First two samples share class 0, last two share class 1.
        let batch = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.0, 1.0],
            vec![0.1, 0.9],
        ];
        let class_ids = vec![0, 0, 1, 1];
        assert!(supervised_contrastive_loss(&batch, &class_ids, 0.1) >= 0.0);
    }

    #[test]
    fn test_center_loss() {
        let batch = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0]];
        let class_ids = vec![0, 0, 1];
        // One learned center per class.
        let class_centers = vec![vec![0.95, 0.05], vec![0.05, 0.95]];
        assert!(center_loss(&batch, &class_ids, &class_centers) >= 0.0);
    }

    #[test]
    fn test_focal_loss() {
        let probs = vec![0.9, 0.1, 0.8, 0.2];
        let labels = vec![1, 0, 1, 0];
        assert!(focal_loss(&probs, &labels, 2.0, 0.25) >= 0.0);
    }

    #[test]
    fn test_cosine_similarity() {
        let e1 = vec![1.0, 0.0, 0.0];
        // Identical vectors -> similarity 1.
        assert!((cosine_similarity(&e1, &e1) - 1.0).abs() < 1e-6);
        // Orthogonal vectors -> similarity 0.
        let e2 = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&e1, &e2).abs() < 1e-6);
        // Opposite vectors -> similarity -1.
        let neg_e1 = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&e1, &neg_e1) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance() {
        // Classic 3-4-5 right triangle.
        let origin = vec![0.0, 0.0];
        let point = vec![3.0, 4.0];
        assert!((euclidean_distance(&origin, &point) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_normalize() {
        let unit = l2_normalize(&[3.0, 4.0]);
        let length: f32 = unit.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-6);
        assert!((unit[0] - 0.6).abs() < 1e-6);
        assert!((unit[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_empty_inputs() {
        // All loss functions degrade gracefully to zero on empty inputs.
        let empty: Vec<f32> = vec![];
        let unit = vec![1.0, 0.0, 0.0];
        assert_eq!(info_nce_loss(&empty, &unit, &[&unit], 0.1), 0.0);
        assert_eq!(triplet_loss(&empty, &unit, &unit, 1.0), 0.0);
        assert_eq!(contrastive_loss(&empty, &unit, &[], 0.1), 0.0);
    }
}