Files
wifi-densepose/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/sona.rs
ruv fc409dfd6a feat: ADR-023 full DensePose training pipeline (Phases 1-8)
Implement complete WiFi CSI-to-DensePose neural network pipeline:

Phase 1 - Dataset loaders: .npy/.mat v5 parsers, MM-Fi + Wi-Pose
  loaders, subcarrier resampling (114->56, 30->56), DataPipeline
Phase 2 - Graph transformer: COCO BodyGraph (17 kp, 16 edges),
  AntennaGraph, multi-head CrossAttention, GCN message passing,
  CsiToPoseTransformer full pipeline
Phase 4 - Training loop: 6-term composite loss (MSE, cross-entropy,
  UV regression, temporal consistency, bone length, symmetry),
  SGD+momentum, cosine+warmup scheduler, PCK/OKS metrics, checkpoints
Phase 5 - SONA adaptation: LoRA (rank-4, A*B delta), EWC++ Fisher
  regularization, EnvironmentDetector (3-sigma drift), temporal
  consistency loss
Phase 6 - Sparse inference: NeuronProfiler hot/cold partitioning,
  SparseLinear (skip cold rows), INT8/FP16 quantization with <0.01
  MSE, SparseModel engine, BenchmarkRunner
Phase 7 - RVF pipeline: 6 new segment types (Index, Overlay, Crypto,
  WASM, Dashboard, AggregateWeights), HNSW index, OverlayGraph,
  RvfModelBuilder, ProgressiveLoader (3-layer: A=instant, B=hot, C=full)
Phase 8 - Server integration: --model, --progressive CLI flags,
  4 new REST endpoints, WebSocket pose_keypoints + model_status

229 tests passing (147 unit + 48 bin + 34 integration)
Benchmark: 9,520 frames/sec (105μs/frame), 476x real-time at 20 Hz
7,832 lines of pure Rust, zero external ML dependencies

Co-Authored-By: claude-flow <ruv@ruv.net>
2026-02-28 23:22:15 -05:00

640 lines
24 KiB
Rust

//! SONA online adaptation: LoRA + EWC++ for WiFi-DensePose (ADR-023 Phase 5).
//!
//! Enables rapid low-parameter adaptation to changing WiFi environments without
//! catastrophic forgetting. All arithmetic uses `f32`, no external dependencies.
use std::collections::VecDeque;
// ── LoRA Adapter ────────────────────────────────────────────────────────────
/// Low-Rank Adaptation layer holding a factorised weight delta `scale * A * B`.
///
/// `A` is `in_features x rank` and `B` is `rank x out_features`, so only
/// `rank * (in_features + out_features)` values are trained instead of a dense
/// `in_features x out_features` matrix.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Left factor `A`: `in_features` rows of `rank` values.
    pub a: Vec<Vec<f32>>,
    /// Right factor `B`: `rank` rows of `out_features` values.
    pub b: Vec<Vec<f32>>,
    /// Multiplier applied to `A * B`; [`LoraAdapter::new`] sets it to `alpha / rank`.
    pub scale: f32,
    pub in_features: usize,
    pub out_features: usize,
    pub rank: usize,
}

impl LoraAdapter {
    /// Builds an adapter whose factors are all zero, so its delta starts at zero.
    ///
    /// The scale is `alpha / rank`, treating a `rank` of 0 as 1 to avoid dividing by zero.
    pub fn new(in_features: usize, out_features: usize, rank: usize, alpha: f32) -> Self {
        let scale = alpha / rank.max(1) as f32;
        Self {
            b: vec![vec![0.0f32; out_features]; rank],
            a: vec![vec![0.0f32; rank]; in_features],
            scale,
            rank,
            out_features,
            in_features,
        }
    }

    /// Applies the adapter to one input vector, returning `scale * input * A * B`
    /// (one value per output feature).
    ///
    /// # Panics
    /// Panics if `input.len() != in_features`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.in_features);
        // Down-projection: bottleneck component `r` is the dot product of `input` with column `r` of A.
        let bottleneck: Vec<f32> = (0..self.rank)
            .map(|r| {
                input
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (i, &x)| acc + x * self.a[i][r])
            })
            .collect();
        // Up-projection: output `j` is the dot product of the bottleneck with column `j` of B, then scaled.
        (0..self.out_features)
            .map(|j| {
                let dot = bottleneck
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (r, &h)| acc + h * self.b[r][j]);
                dot * self.scale
            })
            .collect()
    }

    /// Materialises the dense delta `scale * A * B` as `in_features` rows of
    /// `out_features` values.
    pub fn delta_weights(&self) -> Vec<Vec<f32>> {
        (0..self.in_features)
            .map(|i| {
                (0..self.out_features)
                    .map(|j| {
                        let dot = (0..self.rank)
                            .fold(0.0f32, |acc, r| acc + self.a[i][r] * self.b[r][j]);
                        dot * self.scale
                    })
                    .collect()
            })
            .collect()
    }

    /// Adds the LoRA delta to `base_weights` in place.
    ///
    /// Only the overlapping rows/columns of the two shapes are touched.
    pub fn merge_into(&self, base_weights: &mut [Vec<f32>]) {
        self.apply_delta(base_weights, |w, d| *w += d);
    }

    /// Subtracts the LoRA delta from `base_weights` in place, undoing `merge_into`
    /// up to floating-point rounding.
    pub fn unmerge_from(&self, base_weights: &mut [Vec<f32>]) {
        self.apply_delta(base_weights, |w, d| *w -= d);
    }

    /// Combines every overlapping base weight with its delta entry using `op`.
    fn apply_delta(&self, base_weights: &mut [Vec<f32>], op: impl Fn(&mut f32, f32)) {
        let delta = self.delta_weights();
        for (base_row, delta_row) in base_weights.iter_mut().zip(&delta) {
            for (w, &d) in base_row.iter_mut().zip(delta_row) {
                op(w, d);
            }
        }
    }

    /// Number of trainable values: `rank * (in_features + out_features)`.
    pub fn n_params(&self) -> usize {
        let per_rank = self.in_features + self.out_features;
        self.rank * per_rank
    }

    /// Zeroes both factors in place, so the delta becomes zero again.
    pub fn reset(&mut self) {
        self.a
            .iter_mut()
            .chain(self.b.iter_mut())
            .flatten()
            .for_each(|v| *v = 0.0);
    }
}
// ── EWC++ Regularizer ───────────────────────────────────────────────────────
/// Elastic Weight Consolidation++ regularizer with a running (decayed) Fisher average.
///
/// Penalises movement of parameters away from a consolidated reference point,
/// weighting each parameter by its diagonal Fisher estimate.
#[derive(Debug, Clone)]
pub struct EwcRegularizer {
    /// Overall penalty strength.
    pub lambda: f32,
    /// Weight kept from the previous Fisher estimate in `update_fisher`.
    pub decay: f32,
    /// Diagonal Fisher estimate, one entry per parameter (empty until first set).
    pub fisher_diag: Vec<f32>,
    /// Parameters saved by `consolidate` (empty until first set).
    pub reference_params: Vec<f32>,
}

impl EwcRegularizer {
    /// Creates a regularizer with no Fisher estimate and no reference point yet.
    pub fn new(lambda: f32, decay: f32) -> Self {
        Self { fisher_diag: vec![], reference_params: vec![], lambda, decay }
    }

    /// Estimates the diagonal Fisher as squared central-difference gradients of
    /// `loss_fn` at `params` (step `1e-4`): `F_i = grad_i^2`.
    ///
    /// The estimate is averaged over `max(n_samples, 1)` rounds. `loss_fn` is called
    /// twice per parameter per round.
    pub fn compute_fisher(params: &[f32], loss_fn: impl Fn(&[f32]) -> f32, n_samples: usize) -> Vec<f32> {
        const STEP: f32 = 1e-4;
        let rounds = n_samples.max(1);
        let mut totals = vec![0.0f32; params.len()];
        for _ in 0..rounds {
            let mut probe = params.to_vec();
            for (k, total) in totals.iter_mut().enumerate() {
                let centre = probe[k];
                probe[k] = centre + STEP;
                let loss_up = loss_fn(&probe);
                probe[k] = centre - STEP;
                let loss_down = loss_fn(&probe);
                // Restore the coordinate before probing the next one.
                probe[k] = centre;
                let slope = (loss_up - loss_down) / (2.0 * STEP);
                *total += slope * slope;
            }
        }
        let rounds_f = rounds as f32;
        totals.into_iter().map(|t| t / rounds_f).collect()
    }

    /// Online Fisher update: `F = decay * F_old + (1 - decay) * F_new`.
    ///
    /// The first call simply adopts `new_fisher`.
    ///
    /// # Panics
    /// Panics if an estimate already exists and its length differs from `new_fisher`.
    pub fn update_fisher(&mut self, new_fisher: &[f32]) {
        if !self.fisher_diag.is_empty() {
            assert_eq!(self.fisher_diag.len(), new_fisher.len());
            let (keep, blend) = (self.decay, 1.0 - self.decay);
            self.fisher_diag
                .iter_mut()
                .zip(new_fisher)
                .for_each(|(old, &fresh)| *old = keep * *old + blend * fresh);
        } else {
            self.fisher_diag = new_fisher.to_vec();
        }
    }

    /// Penalty `0.5 * lambda * sum(F_i * (theta_i - theta_i*)^2)` over the common prefix
    /// of the three vectors. It is 0 when either the Fisher estimate or the reference is empty.
    pub fn penalty(&self, current_params: &[f32]) -> f32 {
        if self.fisher_diag.is_empty() || self.reference_params.is_empty() {
            return 0.0;
        }
        let weighted = current_params
            .iter()
            .zip(&self.reference_params)
            .zip(&self.fisher_diag)
            .fold(0.0f32, |acc, ((&theta, &anchor), &f)| {
                let d = theta - anchor;
                acc + f * d * d
            });
        0.5 * self.lambda * weighted
    }

    /// Gradient of `penalty`: `lambda * F_i * (theta_i - theta_i*)`.
    ///
    /// The result always has `current_params.len()` entries. Positions beyond the
    /// common prefix, or every position when no Fisher/reference exists, are 0.
    pub fn penalty_gradient(&self, current_params: &[f32]) -> Vec<f32> {
        let mut grad = vec![0.0f32; current_params.len()];
        if self.fisher_diag.is_empty() || self.reference_params.is_empty() {
            return grad;
        }
        let triples = current_params
            .iter()
            .zip(&self.reference_params)
            .zip(&self.fisher_diag);
        for (g, ((&theta, &anchor), &f)) in grad.iter_mut().zip(triples) {
            *g = self.lambda * f * (theta - anchor);
        }
        grad
    }

    /// Stores a copy of `params` as the new reference point.
    pub fn consolidate(&mut self, params: &[f32]) {
        self.reference_params.clear();
        self.reference_params.extend_from_slice(params);
    }
}
// ── Configuration & Types ───────────────────────────────────────────────────
/// Tunable settings for SONA online adaptation.
#[derive(Debug, Clone)]
pub struct SonaConfig {
    /// Rank of the LoRA factorisation.
    pub lora_rank: usize,
    /// LoRA scaling numerator (the adapter scale is `alpha / rank`).
    pub lora_alpha: f32,
    /// EWC++ penalty strength.
    pub ewc_lambda: f32,
    /// Fraction of the previous Fisher estimate kept on each online update.
    pub ewc_decay: f32,
    /// Learning rate for the LoRA gradient steps.
    pub adaptation_lr: f32,
    /// Maximum number of gradient steps per adaptation run.
    pub max_steps: usize,
    /// Adaptation stops once the total loss changes by less than this between steps.
    pub convergence_threshold: f32,
    /// Weight for a temporal-consistency term; not read by `SonaAdapter::adapt`.
    pub temporal_consistency_weight: f32,
}

impl Default for SonaConfig {
    /// Rank-4 LoRA (alpha 8), EWC lambda 5000 with Fisher decay 0.99, learning rate
    /// 1e-3, at most 50 steps, convergence threshold 1e-4, temporal weight 0.1.
    fn default() -> Self {
        SonaConfig {
            lora_rank: 4,
            lora_alpha: 8.0,
            ewc_lambda: 5000.0,
            ewc_decay: 0.99,
            adaptation_lr: 0.001,
            max_steps: 50,
            convergence_threshold: 1e-4,
            temporal_consistency_weight: 0.1,
        }
    }
}
/// Single supervised training sample for online adaptation.
///
/// `SonaAdapter::adapt` fits a linear map from `csi_features` to `target`.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptationSample {
    /// Input CSI feature vector.
    pub csi_features: Vec<f32>,
    /// Desired output for `csi_features`.
    pub target: Vec<f32>,
}
/// Outcome of one `SonaAdapter::adapt` run.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptationResult {
    /// Parameters produced by the run: base plus the learned LoRA delta, or the
    /// unchanged base parameters when no samples were given.
    pub adapted_params: Vec<f32>,
    /// Optimisation-loop iterations executed (0 when no samples were given).
    pub steps_taken: usize,
    /// Total objective (sample MSE plus EWC penalty) reported by the run.
    pub final_loss: f32,
    /// Whether the loss change fell below the convergence threshold. Also `true`
    /// when there were no samples.
    pub converged: bool,
    /// EWC penalty evaluated at `adapted_params`.
    pub ewc_penalty: f32,
}
/// Saved environment-specific adaptation state, produced by `SonaAdapter::save_profile`
/// and restored with `SonaAdapter::load_profile`.
#[derive(Debug, Clone, PartialEq)]
pub struct SonaProfile {
    /// Profile (environment) name.
    pub name: String,
    /// Copy of the LoRA `A` factor.
    pub lora_a: Vec<Vec<f32>>,
    /// Copy of the LoRA `B` factor.
    pub lora_b: Vec<Vec<f32>>,
    /// Copy of the EWC diagonal Fisher estimate.
    pub fisher_diag: Vec<f32>,
    /// Copy of the EWC reference parameters.
    pub reference_params: Vec<f32>,
    /// Adaptation counter at the time the profile was saved.
    pub adaptation_count: usize,
}
// ── SONA Adapter ────────────────────────────────────────────────────────────
/// Full SONA system: LoRA adapter + EWC++ regularizer for online adaptation.
///
/// The adapter works on a flat parameter vector of length `param_count`. Its LoRA
/// adapter is built with `in_features = param_count` and `out_features = 1`, so the
/// LoRA delta is a single column that is added element-wise to the base parameters.
#[derive(Debug, Clone)]
pub struct SonaAdapter {
    pub config: SonaConfig,
    pub lora: LoraAdapter,
    pub ewc: EwcRegularizer,
    /// Length of the flat parameter vector accepted by [`SonaAdapter::adapt`].
    pub param_count: usize,
    /// Number of `adapt` runs that received at least one sample (restored by `load_profile`).
    pub adaptation_count: usize,
}

impl SonaAdapter {
    /// Creates an adapter for `param_count` parameters with a zero-initialised LoRA
    /// adapter and an empty EWC state.
    pub fn new(config: SonaConfig, param_count: usize) -> Self {
        let lora = LoraAdapter::new(param_count, 1, config.lora_rank, config.lora_alpha);
        let ewc = EwcRegularizer::new(config.ewc_lambda, config.ewc_decay);
        Self { config, lora, ewc, param_count, adaptation_count: 0 }
    }

    /// Runs gradient descent on the LoRA factors, minimising the sample MSE plus the
    /// EWC penalty.
    ///
    /// `base_params` is interpreted as a row-major `(out_dim, in_dim)` linear map.
    /// `in_dim` and `out_dim` are taken from the first sample's `csi_features` and
    /// `target` lengths. Iteration stops after `config.max_steps` steps, or as soon as
    /// the total loss changes by less than `config.convergence_threshold` between steps.
    ///
    /// The returned `final_loss` (MSE + EWC penalty) and `ewc_penalty` are both
    /// evaluated at the returned `adapted_params`, including when `max_steps` is 0.
    ///
    /// # Panics
    /// Panics if `base_params.len() != param_count`.
    pub fn adapt(&mut self, base_params: &[f32], samples: &[AdaptationSample]) -> AdaptationResult {
        assert_eq!(base_params.len(), self.param_count);
        if samples.is_empty() {
            return AdaptationResult {
                adapted_params: base_params.to_vec(), steps_taken: 0,
                final_loss: 0.0, converged: true, ewc_penalty: self.ewc.penalty(base_params),
            };
        }
        let lr = self.config.adaptation_lr;
        let out_dim = samples[0].target.len();
        let in_dim = samples[0].csi_features.len();
        let (mut prev_loss, mut steps, mut converged) = (f32::MAX, 0usize, false);
        for step in 0..self.config.max_steps {
            steps = step + 1;
            let eff = self.effective_params(base_params);
            let (data_loss, data_grad) = Self::mse_loss_grad(&eff, samples, in_dim, out_dim);
            let total = data_loss + self.ewc.penalty(&eff);
            if (prev_loss - total).abs() < self.config.convergence_threshold {
                converged = true;
                break;
            }
            prev_loss = total;
            let ewc_grad = self.ewc.penalty_gradient(&eff);
            let total_grad: Vec<f32> =
                data_grad.iter().zip(&ewc_grad).map(|(&d, &e)| d + e).collect();
            self.update_lora(&total_grad, lr);
        }
        // Re-evaluate at the final parameters. Otherwise the reported loss would lag
        // one update behind `adapted_params` (or be f32::MAX when no step ran).
        let adapted = self.effective_params(base_params);
        let (data_loss, _) = Self::mse_loss_grad(&adapted, samples, in_dim, out_dim);
        let ewc_penalty = self.ewc.penalty(&adapted);
        self.adaptation_count += 1;
        AdaptationResult {
            adapted_params: adapted,
            steps_taken: steps,
            final_loss: data_loss + ewc_penalty,
            converged,
            ewc_penalty,
        }
    }

    /// Snapshots the LoRA factors, EWC state and adaptation counter under `name`.
    pub fn save_profile(&self, name: &str) -> SonaProfile {
        SonaProfile {
            name: name.to_string(), lora_a: self.lora.a.clone(), lora_b: self.lora.b.clone(),
            fisher_diag: self.ewc.fisher_diag.clone(), reference_params: self.ewc.reference_params.clone(),
            adaptation_count: self.adaptation_count,
        }
    }

    /// Restores state saved by `save_profile`.
    ///
    /// NOTE(review): factor shapes are not checked against `param_count`/`lora.rank`.
    /// A profile from a differently-sized adapter may cause index panics later.
    pub fn load_profile(&mut self, profile: &SonaProfile) {
        self.lora.a = profile.lora_a.clone();
        self.lora.b = profile.lora_b.clone();
        self.ewc.fisher_diag = profile.fisher_diag.clone();
        self.ewc.reference_params = profile.reference_params.clone();
        self.adaptation_count = profile.adaptation_count;
    }

    /// The LoRA delta as a flat vector (the single column of `delta_weights`).
    fn lora_delta_flat(&self) -> Vec<f32> {
        self.lora.delta_weights().into_iter().map(|r| r[0]).collect()
    }

    /// Base parameters plus the current LoRA delta.
    fn effective_params(&self, base_params: &[f32]) -> Vec<f32> {
        let delta = self.lora_delta_flat();
        base_params.iter().zip(&delta).map(|(&b, &d)| b + d).collect()
    }

    /// Mean (over samples) of the summed squared error of the linear map
    /// `pred_j = sum_i params[j * in_dim + i] * x_i`, together with its gradient
    /// with respect to `params`.
    ///
    /// Features/targets beyond `in_dim`/`out_dim` are ignored. Parameters beyond
    /// `in_dim * out_dim` receive zero gradient.
    fn mse_loss_grad(params: &[f32], samples: &[AdaptationSample], in_dim: usize, out_dim: usize) -> (f32, Vec<f32>) {
        let n = samples.len() as f32;
        let ws = in_dim * out_dim;
        let mut grad = vec![0.0f32; params.len()];
        let mut loss = 0.0f32;
        for s in samples {
            let (inp, tgt) = (&s.csi_features, &s.target);
            let mut pred = vec![0.0f32; out_dim];
            for j in 0..out_dim {
                for i in 0..in_dim.min(inp.len()) {
                    let idx = j * in_dim + i;
                    if idx < ws && idx < params.len() { pred[j] += params[idx] * inp[i]; }
                }
            }
            for j in 0..out_dim.min(tgt.len()) {
                let e = pred[j] - tgt[j];
                loss += e * e;
                for i in 0..in_dim.min(inp.len()) {
                    let idx = j * in_dim + i;
                    if idx < ws && idx < grad.len() { grad[idx] += 2.0 * e * inp[i] / n; }
                }
            }
        }
        (loss / n, grad)
    }

    /// Applies one gradient-descent step to the LoRA factors, given the gradient of
    /// the loss with respect to the flat effective parameters.
    ///
    /// With `delta_i = scale * sum_r A[i][r] * B[r][0]`, the chain rule gives
    /// `dL/dA[i][r] = grad_i * scale * B[r][0]` and
    /// `dL/dB[r][0] = sum_i grad_i * scale * A[i][r]`. Both gradients are evaluated at
    /// the factors as they were before this step, so A and B move simultaneously.
    fn update_lora(&mut self, grad: &[f32], lr: f32) {
        // True when every entry of the matrix is exactly zero.
        fn is_zero(m: &[Vec<f32>]) -> bool {
            m.iter().all(|row| row.iter().all(|&v| v == 0.0))
        }
        let (scale, rank) = (self.lora.scale, self.lora.rank);
        if rank == 0 {
            return;
        }
        // A = B = 0 is a saddle point where both factor gradients vanish. Seed one entry
        // of B so A can start learning. Because A is also zero, the delta (and thus the
        // effective parameters) stays zero, so there is no jump.
        if is_zero(&self.lora.a) && is_zero(&self.lora.b) {
            self.lora.b[0][0] = 1.0;
        }
        let rows = self.lora.in_features.min(grad.len());
        // dL/dB, taken from the pre-update A.
        let grad_b: Vec<f32> = (0..rank)
            .map(|r| (0..rows).fold(0.0f32, |acc, i| acc + grad[i] * scale * self.lora.a[i][r]))
            .collect();
        // Step A using the pre-update B.
        for i in 0..rows {
            for r in 0..rank {
                self.lora.a[i][r] -= lr * grad[i] * scale * self.lora.b[r][0];
            }
        }
        for (r, g) in grad_b.into_iter().enumerate() {
            self.lora.b[r][0] -= lr * g;
        }
    }
}
// ── Environment Detector ────────────────────────────────────────────────────
/// Snapshot of CSI baseline drift, as reported by [`EnvironmentDetector::drift_info`].
#[derive(Debug, Clone, PartialEq)]
pub struct DriftInfo {
    /// Distance of `current_mean` from `baseline_mean`, measured in baseline standard
    /// deviations. It is 0 before a baseline exists. If the baseline standard deviation
    /// is at most `f32::EPSILON`, any deviation above `f32::EPSILON` is divided by
    /// `f32::EPSILON` instead, which gives a very large value.
    pub magnitude: f32,
    /// Consecutive `update` calls, ending with the latest, that detected drift
    /// (cleared by `reset_baseline`).
    pub duration_frames: usize,
    /// Mean of the window used to establish the baseline (0 before a baseline exists).
    pub baseline_mean: f32,
    /// Mean of the current rolling window (0 if it is empty).
    pub current_mean: f32,
}

/// Detects environmental drift in CSI statistics (more than 3 sigma from baseline).
///
/// Keeps a rolling window of the last `window_size` per-frame CSI means. The first
/// time the window fills, its mean and standard deviation become the baseline. From
/// then on, drift is reported while the window mean lies more than three baseline
/// standard deviations from the baseline mean.
#[derive(Debug, Clone)]
pub struct EnvironmentDetector {
    /// Rolling window length (at least 2).
    window_size: usize,
    /// Most recent per-frame CSI means, oldest first.
    means: VecDeque<f32>,
    /// Most recent per-frame CSI variances. Stored, but not read by any method of this type.
    variances: VecDeque<f32>,
    baseline_mean: f32,
    /// Population variance of the per-frame means in the baseline window.
    baseline_var: f32,
    /// Square root of `baseline_var`.
    baseline_std: f32,
    /// Whether a baseline has been established yet.
    baseline_set: bool,
    /// Consecutive updates in which drift was detected.
    drift_frames: usize,
}

impl EnvironmentDetector {
    /// Creates a detector with a rolling window of `window_size` frames (clamped to at least 2).
    pub fn new(window_size: usize) -> Self {
        let window_size = window_size.max(2);
        // `update` pushes before trimming, so reserve room for one extra frame.
        let capacity = window_size.saturating_add(1);
        Self {
            window_size,
            means: VecDeque::with_capacity(capacity),
            variances: VecDeque::with_capacity(capacity),
            baseline_mean: 0.0,
            baseline_var: 0.0,
            baseline_std: 0.0,
            baseline_set: false,
            drift_frames: 0,
        }
    }

    /// Records one frame's CSI mean and variance.
    ///
    /// Establishes the baseline the first time the window fills, then updates the
    /// consecutive-drift counter.
    pub fn update(&mut self, csi_mean: f32, csi_var: f32) {
        self.means.push_back(csi_mean);
        self.variances.push_back(csi_var);
        while self.means.len() > self.window_size { self.means.pop_front(); }
        while self.variances.len() > self.window_size { self.variances.pop_front(); }
        if !self.baseline_set && self.means.len() >= self.window_size { self.reset_baseline(); }
        if self.drift_detected() { self.drift_frames += 1; } else { self.drift_frames = 0; }
    }

    /// Returns `true` when a baseline exists and the window mean deviates from it by
    /// more than `3 * baseline_std`. When `baseline_std <= f32::EPSILON`, the threshold
    /// is `100 * f32::EPSILON` instead.
    pub fn drift_detected(&self) -> bool {
        if !self.baseline_set || self.means.is_empty() { return false; }
        let dev = (self.current_mean() - self.baseline_mean).abs();
        let thr = if self.baseline_std > f32::EPSILON { 3.0 * self.baseline_std }
                  else { f32::EPSILON * 100.0 };
        dev > thr
    }

    /// Re-establishes the baseline from the current window and clears the drift counter.
    /// Does nothing if the window is empty.
    pub fn reset_baseline(&mut self) {
        if self.means.is_empty() { return; }
        let n = self.means.len() as f32;
        self.baseline_mean = self.means.iter().sum::<f32>() / n;
        let var = self.means.iter().map(|&m| (m - self.baseline_mean).powi(2)).sum::<f32>() / n;
        self.baseline_var = var;
        self.baseline_std = var.sqrt();
        self.baseline_set = true;
        self.drift_frames = 0;
    }

    /// Summarises the current drift state.
    ///
    /// `magnitude` is 0 until a baseline exists. Without this guard, the zero default
    /// baseline would report an enormous drift while `drift_detected()` returns false.
    pub fn drift_info(&self) -> DriftInfo {
        let cm = self.current_mean();
        let abs_dev = (cm - self.baseline_mean).abs();
        let magnitude = if !self.baseline_set { 0.0 }
            else if self.baseline_std > f32::EPSILON { abs_dev / self.baseline_std }
            else if abs_dev > f32::EPSILON { abs_dev / f32::EPSILON }
            else { 0.0 };
        DriftInfo { magnitude, duration_frames: self.drift_frames, baseline_mean: self.baseline_mean, current_mean: cm }
    }

    /// Mean of the rolling window, or 0 if it is empty.
    fn current_mean(&self) -> f32 {
        if self.means.is_empty() { 0.0 }
        else { self.means.iter().sum::<f32>() / self.means.len() as f32 }
    }
}
// ── Temporal Consistency Loss ───────────────────────────────────────────────
/// Penalises large changes between consecutive outputs: `sum((curr - prev)^2) / dt`.
///
/// Only the overlapping prefix of the two output slices is compared.
pub struct TemporalConsistencyLoss;

impl TemporalConsistencyLoss {
    /// Returns the squared L2 distance between `curr_output` and `prev_output`
    /// divided by `dt`, or 0 when `dt <= 0`.
    pub fn compute(prev_output: &[f32], curr_output: &[f32], dt: f32) -> f32 {
        if dt <= 0.0 {
            return 0.0;
        }
        let squared_distance = curr_output
            .iter()
            .zip(prev_output)
            .fold(0.0f32, |acc, (&c, &p)| {
                let step = c - p;
                acc + step * step
            });
        squared_distance / dt
    }
}
// ── Tests ───────────────────────────────────────────────────────────────────
// Regression tests for the LoRA adapter, EWC++ regularizer, SONA adapter,
// environment drift detector and temporal consistency loss in the parent module.
#[cfg(test)]
mod tests {
    use super::*;
    // n_params() must equal rank * (in_features + out_features).
    #[test]
    fn lora_adapter_param_count() {
        let lora = LoraAdapter::new(64, 32, 4, 8.0);
        assert_eq!(lora.n_params(), 4 * (64 + 32));
    }
    // forward() returns one value per output feature.
    #[test]
    fn lora_adapter_forward_shape() {
        let lora = LoraAdapter::new(8, 4, 2, 4.0);
        assert_eq!(lora.forward(&vec![1.0f32; 8]).len(), 4);
    }
    // A freshly constructed adapter yields an all-zero delta of shape (in_features, out_features).
    #[test]
    fn lora_adapter_zero_init_produces_zero_delta() {
        let delta = LoraAdapter::new(8, 4, 2, 4.0).delta_weights();
        assert_eq!(delta.len(), 8);
        for row in &delta { assert_eq!(row.len(), 4); for &v in row { assert_eq!(v, 0.0); } }
    }
    // merge_into followed by unmerge_from restores the base weights (within 1e-5).
    #[test]
    fn lora_adapter_merge_unmerge_roundtrip() {
        let mut lora = LoraAdapter::new(3, 2, 1, 2.0);
        lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0;
        lora.b[0][0] = 0.5; lora.b[0][1] = -0.5;
        let mut base = vec![vec![10.0, 20.0], vec![30.0, 40.0], vec![50.0, 60.0]];
        let orig = base.clone();
        lora.merge_into(&mut base);
        assert_ne!(base, orig);
        lora.unmerge_from(&mut base);
        for (rb, ro) in base.iter().zip(orig.iter()) {
            for (&b, &o) in rb.iter().zip(ro.iter()) {
                assert!((b - o).abs() < 1e-5, "roundtrip failed: {b} vs {o}");
            }
        }
    }
    // With rank 1 and scale 1, the delta is the outer product of A's column and B's row.
    #[test]
    fn lora_adapter_rank_1_outer_product() {
        let mut lora = LoraAdapter::new(3, 2, 1, 1.0); // scale=1
        lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0;
        lora.b[0][0] = 4.0; lora.b[0][1] = 5.0;
        let d = lora.delta_weights();
        let expected = [[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() { assert!((d[i][j] - v).abs() < 1e-6); }
        }
    }
    // scale = alpha / rank.
    #[test]
    fn lora_scale_factor() {
        assert!((LoraAdapter::new(8, 4, 4, 16.0).scale - 4.0).abs() < 1e-6);
        assert!((LoraAdapter::new(8, 4, 2, 8.0).scale - 4.0).abs() < 1e-6);
    }
    // Squared finite-difference gradients are never negative.
    #[test]
    fn ewc_fisher_positive() {
        let fisher = EwcRegularizer::compute_fisher(
            &[1.0f32, -2.0, 0.5],
            |p: &[f32]| p.iter().map(|&x| x * x).sum::<f32>(), 1,
        );
        assert_eq!(fisher.len(), 3);
        for &f in &fisher { assert!(f >= 0.0, "Fisher must be >= 0, got {f}"); }
    }
    // Penalty vanishes when the current parameters equal the consolidated reference.
    #[test]
    fn ewc_penalty_zero_at_reference() {
        let mut ewc = EwcRegularizer::new(5000.0, 0.99);
        let p = vec![1.0, 2.0, 3.0];
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&p);
        assert!(ewc.penalty(&p).abs() < 1e-10);
    }
    // Unit Fisher and a unit deviation in each of 3 params give 0.5 * lambda * 3.
    #[test]
    fn ewc_penalty_positive_away_from_reference() {
        let mut ewc = EwcRegularizer::new(5000.0, 0.99);
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&[1.0, 2.0, 3.0]);
        let pen = ewc.penalty(&[2.0, 3.0, 4.0]);
        assert!(pen > 0.0); // 0.5 * 5000 * 3 = 7500
        assert!((pen - 7500.0).abs() < 1e-3, "expected ~7500, got {pen}");
    }
    // Each gradient component has the same sign as (current - reference).
    #[test]
    fn ewc_penalty_gradient_direction() {
        let mut ewc = EwcRegularizer::new(100.0, 0.99);
        let r = vec![1.0, 2.0, 3.0];
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&r);
        let c = vec![2.0, 4.0, 5.0];
        let grad = ewc.penalty_gradient(&c);
        for (i, &g) in grad.iter().enumerate() {
            assert!(g * (c[i] - r[i]) > 0.0, "gradient[{i}] wrong sign");
        }
    }
    // The first update adopts the estimate; the second blends it in with decay 0.5.
    #[test]
    fn ewc_online_update_decays() {
        let mut ewc = EwcRegularizer::new(1.0, 0.5);
        ewc.update_fisher(&[10.0, 20.0]);
        assert!((ewc.fisher_diag[0] - 10.0).abs() < 1e-6);
        ewc.update_fisher(&[0.0, 0.0]);
        assert!((ewc.fisher_diag[0] - 5.0).abs() < 1e-6); // 0.5*10 + 0.5*0
        assert!((ewc.fisher_diag[1] - 10.0).abs() < 1e-6); // 0.5*20 + 0.5*0
    }
    // consolidate() overwrites the previous reference point.
    #[test]
    fn ewc_consolidate_updates_reference() {
        let mut ewc = EwcRegularizer::new(1.0, 0.99);
        ewc.consolidate(&[1.0, 2.0]);
        assert_eq!(ewc.reference_params, vec![1.0, 2.0]);
        ewc.consolidate(&[3.0, 4.0]);
        assert_eq!(ewc.reference_params, vec![3.0, 4.0]);
    }
    // Pins every SonaConfig::default() value.
    #[test]
    fn sona_config_defaults() {
        let c = SonaConfig::default();
        assert_eq!(c.lora_rank, 4);
        assert!((c.lora_alpha - 8.0).abs() < 1e-6);
        assert!((c.ewc_lambda - 5000.0).abs() < 1e-3);
        assert!((c.ewc_decay - 0.99).abs() < 1e-6);
        assert!((c.adaptation_lr - 0.001).abs() < 1e-6);
        assert_eq!(c.max_steps, 50);
        assert!((c.convergence_threshold - 1e-4).abs() < 1e-8);
        assert!((c.temporal_consistency_weight - 0.1).abs() < 1e-6);
    }
    // One-parameter linear fit of y = 2x with EWC disabled; the loss must drop below 1.
    #[test]
    fn sona_adapter_converges_on_simple_task() {
        let cfg = SonaConfig {
            lora_rank: 1, lora_alpha: 1.0, ewc_lambda: 0.0, ewc_decay: 0.99,
            adaptation_lr: 0.01, max_steps: 200, convergence_threshold: 1e-6,
            temporal_consistency_weight: 0.0,
        };
        let mut adapter = SonaAdapter::new(cfg, 1);
        let samples: Vec<_> = (1..=5).map(|i| {
            let x = i as f32;
            AdaptationSample { csi_features: vec![x], target: vec![2.0 * x] }
        }).collect();
        let r = adapter.adapt(&[0.0f32], &samples);
        assert!(r.final_loss < 1.0, "loss should decrease, got {}", r.final_loss);
        assert!(r.steps_taken > 0);
    }
    // A zero convergence threshold can never be met (`abs() < 0.0` is false), so all 5 steps run.
    #[test]
    fn sona_adapter_respects_max_steps() {
        let cfg = SonaConfig { max_steps: 5, convergence_threshold: 0.0, ..SonaConfig::default() };
        let mut a = SonaAdapter::new(cfg, 4);
        let s = vec![AdaptationSample { csi_features: vec![1.0, 0.0, 0.0, 0.0], target: vec![1.0] }];
        assert_eq!(a.adapt(&[0.0; 4], &s).steps_taken, 5);
    }
    // save_profile/load_profile round-trip LoRA factors, EWC state and the counter.
    #[test]
    fn sona_profile_save_load_roundtrip() {
        let mut a = SonaAdapter::new(SonaConfig::default(), 8);
        a.lora.a[0][0] = 1.5; a.lora.b[0][0] = -0.3;
        a.ewc.fisher_diag = vec![1.0, 2.0, 3.0];
        a.ewc.reference_params = vec![0.1, 0.2, 0.3];
        a.adaptation_count = 42;
        let p = a.save_profile("test-env");
        assert_eq!(p.name, "test-env");
        assert_eq!(p.adaptation_count, 42);
        let mut a2 = SonaAdapter::new(SonaConfig::default(), 8);
        a2.load_profile(&p);
        assert!((a2.lora.a[0][0] - 1.5).abs() < 1e-6);
        assert!((a2.lora.b[0][0] - (-0.3)).abs() < 1e-6);
        assert_eq!(a2.ewc.fisher_diag.len(), 3);
        assert!((a2.ewc.fisher_diag[2] - 3.0).abs() < 1e-6);
        assert_eq!(a2.adaptation_count, 42);
    }
    // No baseline exists yet, so no drift can be reported.
    #[test]
    fn environment_detector_no_drift_initially() {
        assert!(!EnvironmentDetector::new(10).drift_detected());
    }
    // Baseline from ten frames at 10.0, then a shift to 50.0 must register as drift.
    #[test]
    fn environment_detector_detects_large_shift() {
        let mut d = EnvironmentDetector::new(10);
        for _ in 0..10 { d.update(10.0, 0.1); }
        assert!(!d.drift_detected());
        for _ in 0..10 { d.update(50.0, 0.1); }
        assert!(d.drift_detected());
        assert!(d.drift_info().magnitude > 3.0, "magnitude = {}", d.drift_info().magnitude);
    }
    // Re-baselining on the shifted window clears the drift.
    #[test]
    fn environment_detector_reset_baseline() {
        let mut d = EnvironmentDetector::new(10);
        for _ in 0..10 { d.update(10.0, 0.1); }
        for _ in 0..10 { d.update(50.0, 0.1); }
        assert!(d.drift_detected());
        d.reset_baseline();
        assert!(!d.drift_detected());
    }
    // Identical consecutive outputs incur no temporal penalty.
    #[test]
    fn temporal_consistency_zero_for_static() {
        let o = vec![1.0, 2.0, 3.0];
        assert!(TemporalConsistencyLoss::compute(&o, &o, 0.033).abs() < 1e-10);
    }
}