Compare commits
21 Commits
adr-028-es
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c2e7e2b27 | ||
|
|
381b51a382 | ||
|
|
e99a41434d | ||
|
|
0aab555821 | ||
|
|
0c01157e36 | ||
|
|
60e0e6d3c4 | ||
|
|
97f2a490eb | ||
|
|
c520204e12 | ||
|
|
1288fd9375 | ||
|
|
95c68139bc | ||
|
|
ba9c88ee30 | ||
|
|
5541926e6a | ||
|
|
37b54d649b | ||
|
|
303871275b | ||
|
|
b4f1e55546 | ||
|
|
d4dc5cb0bc | ||
|
|
374b0fdcef | ||
|
|
c707b636bd | ||
|
|
25b005a0d6 | ||
|
|
08a6d5a7f1 | ||
|
|
322eddbcc3 |
167
CLAUDE.md
167
CLAUDE.md
@@ -4,13 +4,49 @@
|
||||
|
||||
WiFi-based human pose estimation using Channel State Information (CSI).
|
||||
Dual codebase: Python v1 (`v1/`) and Rust port (`rust-port/wifi-densepose-rs/`).
|
||||
|
||||
### Key Rust Crates
|
||||
- `wifi-densepose-signal` — SOTA signal processing (conjugate mult, Hampel, Fresnel, BVP, spectrogram)
|
||||
- `wifi-densepose-train` — Training pipeline with ruvector integration (ADR-016)
|
||||
- `wifi-densepose-mat` — Disaster detection module (MAT, multi-AP, triage)
|
||||
- `wifi-densepose-nn` — Neural network inference (DensePose head, RCNN)
|
||||
- `wifi-densepose-hardware` — ESP32 aggregator, hardware interfaces
|
||||
| Crate | Description |
|
||||
|-------|-------------|
|
||||
| `wifi-densepose-core` | Core types, traits, error types, CSI frame primitives |
|
||||
| `wifi-densepose-signal` | SOTA signal processing + RuvSense multistatic sensing (14 modules) |
|
||||
| `wifi-densepose-nn` | Neural network inference (ONNX, PyTorch, Candle backends) |
|
||||
| `wifi-densepose-train` | Training pipeline with ruvector integration + ruview_metrics |
|
||||
| `wifi-densepose-mat` | Mass Casualty Assessment Tool — disaster survivor detection |
|
||||
| `wifi-densepose-hardware` | ESP32 aggregator, TDM protocol, channel hopping firmware |
|
||||
| `wifi-densepose-ruvector` | RuVector v2.0.4 integration + cross-viewpoint fusion (5 modules) |
|
||||
| `wifi-densepose-api` | REST API (Axum) |
|
||||
| `wifi-densepose-db` | Database layer (Postgres, SQLite, Redis) |
|
||||
| `wifi-densepose-config` | Configuration management |
|
||||
| `wifi-densepose-wasm` | WebAssembly bindings for browser deployment |
|
||||
| `wifi-densepose-cli` | CLI tool (`wifi-densepose` binary) |
|
||||
| `wifi-densepose-sensing-server` | Lightweight Axum server for WiFi sensing UI |
|
||||
| `wifi-densepose-wifiscan` | Multi-BSSID WiFi scanning (ADR-022) |
|
||||
| `wifi-densepose-vitals` | ESP32 CSI-grade vital sign extraction (ADR-021) |
|
||||
|
||||
### RuvSense Modules (`signal/src/ruvsense/`)
|
||||
| Module | Purpose |
|
||||
|--------|---------|
|
||||
| `multiband.rs` | Multi-band CSI frame fusion, cross-channel coherence |
|
||||
| `phase_align.rs` | Iterative LO phase offset estimation, circular mean |
|
||||
| `multistatic.rs` | Attention-weighted fusion, geometric diversity |
|
||||
| `coherence.rs` | Z-score coherence scoring, DriftProfile |
|
||||
| `coherence_gate.rs` | Accept/PredictOnly/Reject/Recalibrate gate decisions |
|
||||
| `pose_tracker.rs` | 17-keypoint Kalman tracker with AETHER re-ID embeddings |
|
||||
| `field_model.rs` | SVD room eigenstructure, perturbation extraction |
|
||||
| `tomography.rs` | RF tomography, ISTA L1 solver, voxel grid |
|
||||
| `longitudinal.rs` | Welford stats, biomechanics drift detection |
|
||||
| `intention.rs` | Pre-movement lead signals (200-500ms) |
|
||||
| `cross_room.rs` | Environment fingerprinting, transition graph |
|
||||
| `gesture.rs` | DTW template matching gesture classifier |
|
||||
| `adversarial.rs` | Physically impossible signal detection, multi-link consistency |
|
||||
|
||||
### Cross-Viewpoint Fusion (`ruvector/src/viewpoint/`)
|
||||
| Module | Purpose |
|
||||
|--------|---------|
|
||||
| `attention.rs` | CrossViewpointAttention, GeometricBias, softmax with G_bias |
|
||||
| `geometry.rs` | GeometricDiversityIndex, Cramer-Rao bounds, Fisher Information |
|
||||
| `coherence.rs` | Phase phasor coherence, hysteresis gate |
|
||||
| `fusion.rs` | MultistaticArray aggregate root, domain events |
|
||||
|
||||
### RuVector v2.0.4 Integration (ADR-016 complete, ADR-017 proposed)
|
||||
All 5 ruvector crates integrated in workspace:
|
||||
@@ -21,33 +57,105 @@ All 5 ruvector crates integrated in workspace:
|
||||
- `ruvector-attention` → `model.rs` (apply_spatial_attention) + `bvp.rs`
|
||||
|
||||
### Architecture Decisions
|
||||
All ADRs in `docs/adr/` (ADR-001 through ADR-017). Key ones:
|
||||
32 ADRs in `docs/adr/` (ADR-001 through ADR-032). Key ones:
|
||||
- ADR-014: SOTA signal processing (Accepted)
|
||||
- ADR-015: MM-Fi + Wi-Pose training datasets (Accepted)
|
||||
- ADR-016: RuVector training pipeline integration (Accepted — complete)
|
||||
- ADR-017: RuVector signal + MAT integration (Proposed — next target)
|
||||
- ADR-024: Contrastive CSI embedding / AETHER (Accepted)
|
||||
- ADR-027: Cross-environment domain generalization / MERIDIAN (Accepted)
|
||||
- ADR-028: ESP32 capability audit + witness verification (Accepted)
|
||||
- ADR-029: RuvSense multistatic sensing mode (Proposed)
|
||||
- ADR-030: RuvSense persistent field model (Proposed)
|
||||
- ADR-031: RuView sensing-first RF mode (Proposed)
|
||||
- ADR-032: Multistatic mesh security hardening (Proposed)
|
||||
|
||||
### Build & Test Commands (this repo)
|
||||
```bash
|
||||
# Rust — check training crate (no GPU needed)
|
||||
# Rust — full workspace tests (1,031+ tests, ~2 min)
|
||||
cd rust-port/wifi-densepose-rs
|
||||
cargo test --workspace --no-default-features
|
||||
|
||||
# Rust — single crate check (no GPU needed)
|
||||
cargo check -p wifi-densepose-train --no-default-features
|
||||
|
||||
# Rust — run all tests
|
||||
cargo test -p wifi-densepose-train --no-default-features
|
||||
# Rust — publish crates (dependency order)
|
||||
cargo publish -p wifi-densepose-core --no-default-features
|
||||
cargo publish -p wifi-densepose-signal --no-default-features
|
||||
# ... see crate publishing order below
|
||||
|
||||
# Rust — full workspace check
|
||||
cargo check --workspace --no-default-features
|
||||
|
||||
# Python — proof verification
|
||||
# Python — deterministic proof verification (SHA-256)
|
||||
python v1/data/proof/verify.py
|
||||
|
||||
# Python — test suite
|
||||
cd v1 && python -m pytest tests/ -x -q
|
||||
```
|
||||
|
||||
### Crate Publishing Order
|
||||
Crates must be published in dependency order:
|
||||
1. `wifi-densepose-core` (no internal deps)
|
||||
2. `wifi-densepose-vitals` (no internal deps)
|
||||
3. `wifi-densepose-wifiscan` (no internal deps)
|
||||
4. `wifi-densepose-hardware` (no internal deps)
|
||||
5. `wifi-densepose-config` (no internal deps)
|
||||
6. `wifi-densepose-db` (no internal deps)
|
||||
7. `wifi-densepose-signal` (depends on core)
|
||||
8. `wifi-densepose-nn` (no internal deps, workspace only)
|
||||
9. `wifi-densepose-ruvector` (no internal deps, workspace only)
|
||||
10. `wifi-densepose-train` (depends on signal, nn)
|
||||
11. `wifi-densepose-mat` (depends on core, signal, nn)
|
||||
12. `wifi-densepose-api` (no internal deps)
|
||||
13. `wifi-densepose-wasm` (depends on mat)
|
||||
14. `wifi-densepose-sensing-server` (depends on wifiscan)
|
||||
15. `wifi-densepose-cli` (depends on mat)
|
||||
|
||||
### Validation & Witness Verification (ADR-028)
|
||||
|
||||
**After any significant code change, run the full validation:**
|
||||
|
||||
```bash
|
||||
# 1. Rust tests — must be 1,031+ passed, 0 failed
|
||||
cd rust-port/wifi-densepose-rs
|
||||
cargo test --workspace --no-default-features
|
||||
|
||||
# 2. Python proof — must print VERDICT: PASS
|
||||
cd ../..
|
||||
python v1/data/proof/verify.py
|
||||
|
||||
# 3. Generate witness bundle (includes both above + firmware hashes)
|
||||
bash scripts/generate-witness-bundle.sh
|
||||
|
||||
# 4. Self-verify the bundle — must be 7/7 PASS
|
||||
cd dist/witness-bundle-ADR028-*/
|
||||
bash VERIFY.sh
|
||||
```
|
||||
|
||||
**If the Python proof hash changes** (e.g., numpy/scipy version update):
|
||||
```bash
|
||||
# Regenerate the expected hash, then verify it passes
|
||||
python v1/data/proof/verify.py --generate-hash
|
||||
python v1/data/proof/verify.py
|
||||
```
|
||||
|
||||
**Witness bundle contents** (`dist/witness-bundle-ADR028-<sha>.tar.gz`):
|
||||
- `WITNESS-LOG-028.md` — 33-row attestation matrix with evidence per capability
|
||||
- `ADR-028-esp32-capability-audit.md` — Full audit findings
|
||||
- `proof/verify.py` + `expected_features.sha256` — Deterministic pipeline proof
|
||||
- `test-results/rust-workspace-tests.log` — Full cargo test output
|
||||
- `firmware-manifest/source-hashes.txt` — SHA-256 of all 7 ESP32 firmware files
|
||||
- `crate-manifest/versions.txt` — All 15 crates with versions
|
||||
- `VERIFY.sh` — One-command self-verification for recipients
|
||||
|
||||
**Key proof artifacts:**
|
||||
- `v1/data/proof/verify.py` — Trust Kill Switch: feeds reference signal through production pipeline, hashes output
|
||||
- `v1/data/proof/expected_features.sha256` — Published expected hash
|
||||
- `v1/data/proof/sample_csi_data.json` — 1,000 synthetic CSI frames (seed=42)
|
||||
- `docs/WITNESS-LOG-028.md` — 11-step reproducible verification procedure
|
||||
- `docs/adr/ADR-028-esp32-capability-audit.md` — Complete audit record
|
||||
|
||||
### Branch
|
||||
All development on: `claude/validate-code-quality-WNrNw`
|
||||
Default branch: `main`
|
||||
Active feature branch: `ruvsense-full-implementation` (PR #77)
|
||||
|
||||
---
|
||||
|
||||
@@ -65,8 +173,13 @@ All development on: `claude/validate-code-quality-WNrNw`
|
||||
## File Organization
|
||||
|
||||
- NEVER save to root folder — use the directories below
|
||||
- `docs/adr/` — Architecture Decision Records
|
||||
- `rust-port/wifi-densepose-rs/crates/` — Rust workspace crates (signal, train, mat, nn, hardware)
|
||||
- `docs/adr/` — Architecture Decision Records (32 ADRs)
|
||||
- `docs/ddd/` — Domain-Driven Design models
|
||||
- `rust-port/wifi-densepose-rs/crates/` — Rust workspace crates (15 crates)
|
||||
- `rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/src/ruvsense/` — RuvSense multistatic modules (14 files)
|
||||
- `rust-port/wifi-densepose-rs/crates/wifi-densepose-ruvector/src/viewpoint/` — Cross-viewpoint fusion (5 files)
|
||||
- `rust-port/wifi-densepose-rs/crates/wifi-densepose-hardware/src/esp32/` — ESP32 TDM protocol
|
||||
- `firmware/esp32-csi-node/main/` — ESP32 C firmware (channel hopping, NVS config, TDM)
|
||||
- `v1/src/` — Python source (core, hardware, services, api)
|
||||
- `v1/data/proof/` — Deterministic CSI proof bundles
|
||||
- `.claude-flow/` — Claude Flow coordination state (committed for team sharing)
|
||||
@@ -93,14 +206,18 @@ All development on: `claude/validate-code-quality-WNrNw`
|
||||
|
||||
Before merging any PR, verify each item applies and is addressed:
|
||||
|
||||
1. **Tests pass** — `cargo test` (Rust) and `python -m pytest` (Python) green
|
||||
2. **README.md** — Update platform tables, crate descriptions, hardware tables, feature summaries if scope changed
|
||||
3. **CHANGELOG.md** — Add entry under `[Unreleased]` with what was added/fixed/changed
|
||||
4. **User guide** (`docs/user-guide.md`) — Update if new data sources, CLI flags, or setup steps were added
|
||||
5. **ADR index** — Update ADR count in README docs table if a new ADR was created
|
||||
6. **Docker Hub image** — Only rebuild if Dockerfile, dependencies, or runtime behavior changed (not needed for platform-gated code that doesn't affect the Linux container)
|
||||
7. **Crate publishing** — Only needed if a crate is published to crates.io and its public API changed (workspace-internal crates don't need publishing)
|
||||
8. **`.gitignore`** — Add any new build artifacts or binaries
|
||||
1. **Rust tests pass** — `cargo test --workspace --no-default-features` (1,031+ passed, 0 failed)
|
||||
2. **Python proof passes** — `python v1/data/proof/verify.py` (VERDICT: PASS)
|
||||
3. **README.md** — Update platform tables, crate descriptions, hardware tables, feature summaries if scope changed
|
||||
4. **CLAUDE.md** — Update crate table, ADR list, module tables, version if scope changed
|
||||
5. **CHANGELOG.md** — Add entry under `[Unreleased]` with what was added/fixed/changed
|
||||
6. **User guide** (`docs/user-guide.md`) — Update if new data sources, CLI flags, or setup steps were added
|
||||
7. **ADR index** — Update ADR count in README docs table if a new ADR was created
|
||||
8. **Witness bundle** — Regenerate if tests or proof hash changed: `bash scripts/generate-witness-bundle.sh`
|
||||
9. **Docker Hub image** — Only rebuild if Dockerfile, dependencies, or runtime behavior changed
|
||||
10. **Crate publishing** — Only needed if a crate is published to crates.io and its public API changed
|
||||
11. **`.gitignore`** — Add any new build artifacts or binaries
|
||||
12. **Security audit** — Run security review for new modules touching hardware/network boundaries
|
||||
|
||||
## Build & Test
|
||||
|
||||
|
||||
174
README.md
174
README.md
@@ -6,7 +6,7 @@ WiFi DensePose turns commodity WiFi signals into real-time human pose estimation
|
||||
|
||||
[](https://www.rust-lang.org/)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://github.com/ruvnet/wifi-densepose)
|
||||
[](https://github.com/ruvnet/wifi-densepose)
|
||||
[](https://hub.docker.com/r/ruvnet/wifi-densepose)
|
||||
[](#vital-sign-detection)
|
||||
[](#esp32-s3-hardware-pipeline)
|
||||
@@ -49,7 +49,8 @@ docker run -p 3000:3000 ruvnet/wifi-densepose:latest
|
||||
| [User Guide](docs/user-guide.md) | Step-by-step guide: installation, first run, API usage, hardware setup, training |
|
||||
| [WiFi-Mat User Guide](docs/wifi-mat-user-guide.md) | Disaster response module: search & rescue, START triage |
|
||||
| [Build Guide](docs/build-guide.md) | Building from source (Rust and Python) |
|
||||
| [Architecture Decisions](docs/adr/) | 27 ADRs covering signal processing, training, hardware, security, domain generalization |
|
||||
| [Architecture Decisions](docs/adr/) | 33 ADRs covering signal processing, training, hardware, security, domain generalization, multistatic sensing, CRV signal-line integration |
|
||||
| [DDD Domain Model](docs/ddd/ruvsense-domain-model.md) | RuvSense bounded contexts, aggregates, domain events, and ubiquitous language |
|
||||
|
||||
---
|
||||
|
||||
@@ -66,6 +67,8 @@ See people, breathing, and heartbeats through walls — using only WiFi signals
|
||||
| 👥 | **Multi-Person** | Tracks multiple people simultaneously, each with independent pose and vitals — no hard software limit (physics: ~3-5 per AP with 56 subcarriers, more with multi-AP) |
|
||||
| 🧱 | **Through-Wall** | WiFi passes through walls, furniture, and debris — works where cameras cannot |
|
||||
| 🚑 | **Disaster Response** | Detects trapped survivors through rubble and classifies injury severity (START triage) |
|
||||
| 📡 | **Multistatic Mesh** | 4-6 ESP32 nodes fuse 12+ TX-RX links for 360-degree coverage, <30mm jitter, zero identity swaps ([ADR-029](docs/adr/ADR-029-ruvsense-multistatic-sensing-mode.md)) |
|
||||
| 🌐 | **Persistent Field Model** | Room eigenstructure via SVD enables RF tomography, drift detection, intention prediction, and adversarial detection ([ADR-030](docs/adr/ADR-030-ruvsense-persistent-field-model.md)) |
|
||||
|
||||
### Intelligence
|
||||
|
||||
@@ -76,6 +79,8 @@ The system learns on its own and gets smarter over time — no hand-tuning, no l
|
||||
| 🧠 | **Self-Learning** | Teaches itself from raw WiFi data — no labeled training sets, no cameras needed to bootstrap ([ADR-024](docs/adr/ADR-024-contrastive-csi-embedding-model.md)) |
|
||||
| 🎯 | **AI Signal Processing** | Attention networks, graph algorithms, and smart compression replace hand-tuned thresholds — adapts to each room automatically ([RuVector](https://github.com/ruvnet/ruvector)) |
|
||||
| 🌍 | **Works Everywhere** | Train once, deploy in any room — adversarial domain generalization strips environment bias so models transfer across rooms, buildings, and hardware ([ADR-027](docs/adr/ADR-027-cross-environment-domain-generalization.md)) |
|
||||
| 👁️ | **Cross-Viewpoint Fusion** | Learned attention fuses multiple viewpoints with geometric bias — reduces body occlusion and depth ambiguity that physics prevents any single sensor from solving ([ADR-031](docs/adr/ADR-031-ruview-sensing-first-rf-mode.md)) |
|
||||
| 🔮 | **Signal-Line Protocol** | `ruvector-crv` 6-stage CRV pipeline maps CSI sensing to Poincare ball embeddings, GNN topology, SNN temporal encoding, and MinCut partitioning | -- |
|
||||
|
||||
### Performance & Deployment
|
||||
|
||||
@@ -84,7 +89,7 @@ Fast enough for real-time use, small enough for edge devices, simple enough for
|
||||
| | Feature | What It Means |
|
||||
|---|---------|---------------|
|
||||
| ⚡ | **Real-Time** | Analyzes WiFi signals in under 100 microseconds per frame — fast enough for live monitoring |
|
||||
| 🦀 | **810x Faster** | Complete Rust rewrite: 54,000 frames/sec pipeline, 132 MB Docker image, 542+ tests |
|
||||
| 🦀 | **810x Faster** | Complete Rust rewrite: 54,000 frames/sec pipeline, 132 MB Docker image, 1,031+ tests |
|
||||
| 🐳 | **One-Command Setup** | `docker pull ruvnet/wifi-densepose:latest` — live sensing in 30 seconds, no toolchain needed |
|
||||
| 📦 | **Portable Models** | Trained models package into a single `.rvf` file — runs on edge, cloud, or browser (WASM) |
|
||||
|
||||
@@ -97,15 +102,23 @@ WiFi routers flood every room with radio waves. When a person moves — or even
|
||||
```
|
||||
WiFi Router → radio waves pass through room → hit human body → scatter
|
||||
↓
|
||||
ESP32 / WiFi NIC captures 56+ subcarrier amplitudes & phases (CSI) at 20 Hz
|
||||
ESP32 mesh (4-6 nodes) captures CSI on channels 1/6/11 via TDM protocol
|
||||
↓
|
||||
Signal Processing cleans noise, removes interference, extracts motion signatures
|
||||
Multi-Band Fusion: 3 channels × 56 subcarriers = 168 virtual subcarriers per link
|
||||
↓
|
||||
AI Backbone (RuVector) applies attention, graph algorithms, and compression
|
||||
Multistatic Fusion: N×(N-1) links → attention-weighted cross-viewpoint embedding
|
||||
↓
|
||||
Neural Network maps processed signals → 17 body keypoints + vital signs
|
||||
Coherence Gate: accept/reject measurements → stable for days without tuning
|
||||
↓
|
||||
Output: real-time pose, breathing rate, heart rate, presence, room fingerprint
|
||||
Signal Processing: Hampel, SpotFi, Fresnel, BVP, spectrogram → clean features
|
||||
↓
|
||||
AI Backbone (RuVector): attention, graph algorithms, compression, field model
|
||||
↓
|
||||
Signal-Line Protocol (CRV): 6-stage gestalt → sensory → topology → coherence → search → model
|
||||
↓
|
||||
Neural Network: processed signals → 17 body keypoints + vital signs + room model
|
||||
↓
|
||||
Output: real-time pose, breathing, heart rate, room fingerprint, drift alerts
|
||||
```
|
||||
|
||||
No training cameras required — the [Self-Learning system (ADR-024)](docs/adr/ADR-024-contrastive-csi-embedding-model.md) bootstraps from raw WiFi data alone. [MERIDIAN (ADR-027)](docs/adr/ADR-027-cross-environment-domain-generalization.md) ensures the model works in any room, not just the one it trained in.
|
||||
@@ -366,6 +379,135 @@ cd dist/witness-bundle-ADR028-*/ && bash VERIFY.sh
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📡 Multistatic Sensing (ADR-029/030/031 — Project RuvSense + RuView)</strong> — Multiple ESP32 nodes fuse viewpoints for production-grade pose, tracking, and exotic sensing</summary>
|
||||
|
||||
A single WiFi receiver can track people, but has blind spots — limbs behind the torso are invisible, depth is ambiguous, and two people at similar range create overlapping signals. RuvSense solves this by coordinating multiple ESP32 nodes into a **multistatic mesh** where every node acts as both transmitter and receiver, creating N×(N-1) measurement links from N devices.
|
||||
|
||||
**What it does in plain terms:**
|
||||
- 4 ESP32-S3 nodes ($48 total) provide 12 TX-RX measurement links covering 360 degrees
|
||||
- Each node hops across WiFi channels 1/6/11, tripling effective bandwidth from 20→60 MHz
|
||||
- Coherence gating rejects noisy frames automatically — no manual tuning, stable for days
|
||||
- Two-person tracking at 20 Hz with zero identity swaps over 10 minutes
|
||||
- The room itself becomes a persistent model — the system remembers, predicts, and explains
|
||||
|
||||
**Three ADRs, one pipeline:**
|
||||
|
||||
| ADR | Codename | What it adds |
|
||||
|-----|----------|-------------|
|
||||
| [ADR-029](docs/adr/ADR-029-ruvsense-multistatic-sensing-mode.md) | **RuvSense** | Channel hopping, TDM protocol, multi-node fusion, coherence gating, 17-keypoint Kalman tracker |
|
||||
| [ADR-030](docs/adr/ADR-030-ruvsense-persistent-field-model.md) | **RuvSense Field** | Room electromagnetic eigenstructure (SVD), RF tomography, longitudinal drift detection, intention prediction, gesture recognition, adversarial detection |
|
||||
| [ADR-031](docs/adr/ADR-031-ruview-sensing-first-rf-mode.md) | **RuView** | Cross-viewpoint attention with geometric bias, viewpoint diversity optimization, embedding-level fusion |
|
||||
|
||||
**Architecture**
|
||||
|
||||
```
|
||||
4x ESP32-S3 nodes ($48) TDM: each transmits in turn, all others receive
|
||||
│ Channel hop: ch1→ch6→ch11 per dwell (50ms)
|
||||
▼
|
||||
Per-Node Signal Processing Phase sanitize → Hampel → BVP → subcarrier select
|
||||
│ (ADR-014, unchanged per viewpoint)
|
||||
▼
|
||||
Multi-Band Frame Fusion 3 channels × 56 subcarriers = 168 virtual subcarriers
|
||||
│ Cross-channel phase alignment via NeumannSolver
|
||||
▼
|
||||
Multistatic Viewpoint Fusion N nodes → attention-weighted fusion → single embedding
|
||||
│ Geometric bias from node placement angles
|
||||
▼
|
||||
Coherence Gate Accept / PredictOnly / Reject / Recalibrate
|
||||
│ Prevents model drift, stable for days
|
||||
▼
|
||||
Persistent Field Model SVD baseline → body = observation - environment
|
||||
│ RF tomography, drift detection, intention signals
|
||||
▼
|
||||
Pose Tracker + DensePose 17-keypoint Kalman, re-ID via AETHER embeddings
|
||||
Multi-person min-cut separation, zero ID swaps
|
||||
```
|
||||
|
||||
**Seven Exotic Sensing Tiers (ADR-030)**
|
||||
|
||||
| Tier | Capability | What it detects |
|
||||
|------|-----------|-----------------|
|
||||
| 1 | Field Normal Modes | Room electromagnetic eigenstructure via SVD |
|
||||
| 2 | Coarse RF Tomography | 3D occupancy volume from link attenuations |
|
||||
| 3 | Intention Lead Signals | Pre-movement prediction 200-500ms before action |
|
||||
| 4 | Longitudinal Biomechanics | Personal movement changes over days/weeks |
|
||||
| 5 | Cross-Room Continuity | Identity preserved across rooms without cameras |
|
||||
| 6 | Invisible Interaction | Multi-user gesture control through walls |
|
||||
| 7 | Adversarial Detection | Physically impossible signal identification |
|
||||
|
||||
**Acceptance Test**
|
||||
|
||||
| Metric | Threshold | What it proves |
|
||||
|--------|-----------|---------------|
|
||||
| Torso keypoint jitter | < 30mm RMS | Precision sufficient for applications |
|
||||
| Identity swaps | 0 over 10 minutes (12,000 frames) | Reliable multi-person tracking |
|
||||
| Update rate | 20 Hz (50ms cycle) | Real-time response |
|
||||
| Breathing SNR | > 10 dB at 3m | Small-motion sensitivity confirmed |
|
||||
|
||||
**New Rust modules (9,000+ lines)**
|
||||
|
||||
| Crate | New modules | Purpose |
|
||||
|-------|------------|---------|
|
||||
| `wifi-densepose-signal` | `ruvsense/` (10 modules) | Multiband fusion, phase alignment, multistatic fusion, coherence, field model, tomography, longitudinal drift, intention detection |
|
||||
| `wifi-densepose-ruvector` | `viewpoint/` (5 modules) | Cross-viewpoint attention with geometric bias, diversity index, coherence gating, fusion orchestrator |
|
||||
| `wifi-densepose-hardware` | `esp32/tdm.rs` | TDM sensing protocol, sync beacons, clock drift compensation |
|
||||
|
||||
**Firmware extensions (C, backward-compatible)**
|
||||
|
||||
| File | Addition |
|
||||
|------|---------|
|
||||
| `csi_collector.c` | Channel hop table, timer-driven hop, NDP injection stub |
|
||||
| `nvs_config.c` | 5 new NVS keys: hop_count, channel_list, dwell_ms, tdm_slot, tdm_node_count |
|
||||
|
||||
**DDD Domain Model** — 6 bounded contexts: Multistatic Sensing, Coherence, Pose Tracking, Field Model, Cross-Room Identity, Adversarial Detection. Full specification: [`docs/ddd/ruvsense-domain-model.md`](docs/ddd/ruvsense-domain-model.md).
|
||||
|
||||
See the ADR documents for full architectural details, GOAP integration plans, and research references.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>🔮 Signal-Line Protocol (CRV)</b></summary>
|
||||
|
||||
### 6-Stage CSI Signal Line
|
||||
|
||||
Maps the CRV (Coordinate Remote Viewing) signal-line methodology to WiFi CSI processing via `ruvector-crv`:
|
||||
|
||||
| Stage | CRV Name | WiFi CSI Mapping | ruvector Component |
|
||||
|-------|----------|-----------------|-------------------|
|
||||
| I | Ideograms | Raw CSI gestalt (manmade/natural/movement/energy) | Poincare ball hyperbolic embeddings |
|
||||
| II | Sensory | Amplitude textures, phase patterns, frequency colors | Multi-head attention vectors |
|
||||
| III | Dimensional | AP mesh spatial topology, node geometry | GNN graph topology |
|
||||
| IV | Emotional/AOL | Coherence gating — signal vs noise separation | SNN temporal encoding |
|
||||
| V | Interrogation | Cross-stage probing — query pose against CSI history | Differentiable search |
|
||||
| VI | 3D Model | Composite person estimation, MinCut partitioning | Graph partitioning |
|
||||
|
||||
**Cross-Session Convergence**: When multiple AP clusters observe the same person, CRV convergence analysis finds agreement in their signal embeddings — directly mapping to cross-room identity continuity.
|
||||
|
||||
```rust
|
||||
use wifi_densepose_ruvector::crv::WifiCrvPipeline;
|
||||
|
||||
let mut pipeline = WifiCrvPipeline::new(WifiCrvConfig::default());
|
||||
pipeline.create_session("room-a", "person-001")?;
|
||||
|
||||
// Process CSI frames through 6-stage pipeline
|
||||
let result = pipeline.process_csi_frame("room-a", &litudes, &phases)?;
|
||||
// result.gestalt = Movement, confidence = 0.87
|
||||
// result.sensory_embedding = [0.12, -0.34, ...]
|
||||
|
||||
// Cross-room identity matching via convergence
|
||||
let convergence = pipeline.find_cross_room_convergence("person-001", 0.75)?;
|
||||
```
|
||||
|
||||
**Architecture**:
|
||||
- `CsiGestaltClassifier` — Maps CSI amplitude/phase patterns to 6 gestalt types
|
||||
- `CsiSensoryEncoder` — Extracts texture/color/temperature/luminosity features from subcarriers
|
||||
- `MeshTopologyEncoder` — Encodes AP mesh as GNN graph (Stage III)
|
||||
- `CoherenceAolDetector` — Maps coherence gate states to AOL noise detection (Stage IV)
|
||||
- `WifiCrvPipeline` — Orchestrates all 6 stages into unified sensing session
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## 📦 Installation
|
||||
@@ -1432,6 +1574,22 @@ pre-commit install
|
||||
<details>
|
||||
<summary><strong>Release history</strong></summary>
|
||||
|
||||
### v3.1.0 — 2026-03-02
|
||||
|
||||
Multistatic sensing, persistent field model, and cross-viewpoint fusion — the biggest capability jump since v2.0.
|
||||
|
||||
- **Project RuvSense (ADR-029)** — Multistatic mesh: TDM protocol, channel hopping (ch1/6/11), multi-band frame fusion, coherence gating, 17-keypoint Kalman tracker with re-ID; 10 new signal modules (5,300+ lines)
|
||||
- **RuvSense Persistent Field Model (ADR-030)** — 7 exotic sensing tiers: field normal modes (SVD), RF tomography, longitudinal drift detection, intention prediction, cross-room identity, gesture classification, adversarial detection
|
||||
- **Project RuView (ADR-031)** — Cross-viewpoint attention with geometric bias, Geometric Diversity Index, viewpoint fusion orchestrator; 5 new ruvector modules (2,200+ lines)
|
||||
- **TDM Hardware Protocol** — ESP32 sensing coordinator: sync beacons, slot scheduling, clock drift compensation (±10ppm), 20 Hz aggregate rate
|
||||
- **Channel-Hopping Firmware** — ESP32 firmware extended with hop table, timer-driven channel switching, NDP injection stub; NVS config for all TDM parameters; fully backward-compatible
|
||||
- **DDD Domain Model** — 6 bounded contexts, ubiquitous language, aggregate roots, domain events, full event bus specification
|
||||
- **`ruvector-crv` 6-stage CRV signal-line integration (ADR-033)** — Maps Coordinate Remote Viewing methodology to WiFi CSI: gestalt classification, sensory encoding, GNN topology, SNN coherence gating, differentiable search, MinCut partitioning; cross-session convergence for multi-room identity continuity
|
||||
- **ADR-032 multistatic mesh security hardening** — Bounded calibration buffers, atomic counters, division-by-zero guards, NaN-safe normalization across all multistatic modules
|
||||
- **ADR-033 CRV signal-line sensing integration** — Architecture decision record for the 6-stage CRV pipeline mapping to ruvector components
|
||||
- **9,000+ lines of new Rust code** across 17 modules with 300+ tests
|
||||
- **Security hardened** — Bounded buffers, NaN guards, no panics in public APIs, input validation at all boundaries
|
||||
|
||||
### v3.0.0 — 2026-03-01
|
||||
|
||||
Major release: AETHER contrastive embedding model, AI signal processing backbone, cross-platform adapters, Docker Hub images, and comprehensive README overhaul.
|
||||
|
||||
167
claude.md
167
claude.md
@@ -4,13 +4,49 @@
|
||||
|
||||
WiFi-based human pose estimation using Channel State Information (CSI).
|
||||
Dual codebase: Python v1 (`v1/`) and Rust port (`rust-port/wifi-densepose-rs/`).
|
||||
|
||||
### Key Rust Crates
|
||||
- `wifi-densepose-signal` — SOTA signal processing (conjugate mult, Hampel, Fresnel, BVP, spectrogram)
|
||||
- `wifi-densepose-train` — Training pipeline with ruvector integration (ADR-016)
|
||||
- `wifi-densepose-mat` — Disaster detection module (MAT, multi-AP, triage)
|
||||
- `wifi-densepose-nn` — Neural network inference (DensePose head, RCNN)
|
||||
- `wifi-densepose-hardware` — ESP32 aggregator, hardware interfaces
|
||||
| Crate | Description |
|
||||
|-------|-------------|
|
||||
| `wifi-densepose-core` | Core types, traits, error types, CSI frame primitives |
|
||||
| `wifi-densepose-signal` | SOTA signal processing + RuvSense multistatic sensing (14 modules) |
|
||||
| `wifi-densepose-nn` | Neural network inference (ONNX, PyTorch, Candle backends) |
|
||||
| `wifi-densepose-train` | Training pipeline with ruvector integration + ruview_metrics |
|
||||
| `wifi-densepose-mat` | Mass Casualty Assessment Tool — disaster survivor detection |
|
||||
| `wifi-densepose-hardware` | ESP32 aggregator, TDM protocol, channel hopping firmware |
|
||||
| `wifi-densepose-ruvector` | RuVector v2.0.4 integration + cross-viewpoint fusion (5 modules) |
|
||||
| `wifi-densepose-api` | REST API (Axum) |
|
||||
| `wifi-densepose-db` | Database layer (Postgres, SQLite, Redis) |
|
||||
| `wifi-densepose-config` | Configuration management |
|
||||
| `wifi-densepose-wasm` | WebAssembly bindings for browser deployment |
|
||||
| `wifi-densepose-cli` | CLI tool (`wifi-densepose` binary) |
|
||||
| `wifi-densepose-sensing-server` | Lightweight Axum server for WiFi sensing UI |
|
||||
| `wifi-densepose-wifiscan` | Multi-BSSID WiFi scanning (ADR-022) |
|
||||
| `wifi-densepose-vitals` | ESP32 CSI-grade vital sign extraction (ADR-021) |
|
||||
|
||||
### RuvSense Modules (`signal/src/ruvsense/`)
|
||||
| Module | Purpose |
|
||||
|--------|---------|
|
||||
| `multiband.rs` | Multi-band CSI frame fusion, cross-channel coherence |
|
||||
| `phase_align.rs` | Iterative LO phase offset estimation, circular mean |
|
||||
| `multistatic.rs` | Attention-weighted fusion, geometric diversity |
|
||||
| `coherence.rs` | Z-score coherence scoring, DriftProfile |
|
||||
| `coherence_gate.rs` | Accept/PredictOnly/Reject/Recalibrate gate decisions |
|
||||
| `pose_tracker.rs` | 17-keypoint Kalman tracker with AETHER re-ID embeddings |
|
||||
| `field_model.rs` | SVD room eigenstructure, perturbation extraction |
|
||||
| `tomography.rs` | RF tomography, ISTA L1 solver, voxel grid |
|
||||
| `longitudinal.rs` | Welford stats, biomechanics drift detection |
|
||||
| `intention.rs` | Pre-movement lead signals (200-500ms) |
|
||||
| `cross_room.rs` | Environment fingerprinting, transition graph |
|
||||
| `gesture.rs` | DTW template matching gesture classifier |
|
||||
| `adversarial.rs` | Physically impossible signal detection, multi-link consistency |
|
||||
|
||||
### Cross-Viewpoint Fusion (`ruvector/src/viewpoint/`)
|
||||
| Module | Purpose |
|
||||
|--------|---------|
|
||||
| `attention.rs` | CrossViewpointAttention, GeometricBias, softmax with G_bias |
|
||||
| `geometry.rs` | GeometricDiversityIndex, Cramer-Rao bounds, Fisher Information |
|
||||
| `coherence.rs` | Phase phasor coherence, hysteresis gate |
|
||||
| `fusion.rs` | MultistaticArray aggregate root, domain events |
|
||||
|
||||
### RuVector v2.0.4 Integration (ADR-016 complete, ADR-017 proposed)
|
||||
All 5 ruvector crates integrated in workspace:
|
||||
@@ -21,33 +57,105 @@ All 5 ruvector crates integrated in workspace:
|
||||
- `ruvector-attention` → `model.rs` (apply_spatial_attention) + `bvp.rs`
|
||||
|
||||
### Architecture Decisions
|
||||
All ADRs in `docs/adr/` (ADR-001 through ADR-017). Key ones:
|
||||
32 ADRs in `docs/adr/` (ADR-001 through ADR-032). Key ones:
|
||||
- ADR-014: SOTA signal processing (Accepted)
|
||||
- ADR-015: MM-Fi + Wi-Pose training datasets (Accepted)
|
||||
- ADR-016: RuVector training pipeline integration (Accepted — complete)
|
||||
- ADR-017: RuVector signal + MAT integration (Proposed — next target)
|
||||
- ADR-024: Contrastive CSI embedding / AETHER (Accepted)
|
||||
- ADR-027: Cross-environment domain generalization / MERIDIAN (Accepted)
|
||||
- ADR-028: ESP32 capability audit + witness verification (Accepted)
|
||||
- ADR-029: RuvSense multistatic sensing mode (Proposed)
|
||||
- ADR-030: RuvSense persistent field model (Proposed)
|
||||
- ADR-031: RuView sensing-first RF mode (Proposed)
|
||||
- ADR-032: Multistatic mesh security hardening (Proposed)
|
||||
|
||||
### Build & Test Commands (this repo)
|
||||
```bash
|
||||
# Rust — check training crate (no GPU needed)
|
||||
# Rust — full workspace tests (1,031+ tests, ~2 min)
|
||||
cd rust-port/wifi-densepose-rs
|
||||
cargo test --workspace --no-default-features
|
||||
|
||||
# Rust — single crate check (no GPU needed)
|
||||
cargo check -p wifi-densepose-train --no-default-features
|
||||
|
||||
# Rust — run all tests
|
||||
cargo test -p wifi-densepose-train --no-default-features
|
||||
# Rust — publish crates (dependency order)
|
||||
cargo publish -p wifi-densepose-core --no-default-features
|
||||
cargo publish -p wifi-densepose-signal --no-default-features
|
||||
# ... see crate publishing order below
|
||||
|
||||
# Rust — full workspace check
|
||||
cargo check --workspace --no-default-features
|
||||
|
||||
# Python — proof verification
|
||||
# Python — deterministic proof verification (SHA-256)
|
||||
python v1/data/proof/verify.py
|
||||
|
||||
# Python — test suite
|
||||
cd v1 && python -m pytest tests/ -x -q
|
||||
```
|
||||
|
||||
### Crate Publishing Order
|
||||
Crates must be published in dependency order:
|
||||
1. `wifi-densepose-core` (no internal deps)
|
||||
2. `wifi-densepose-vitals` (no internal deps)
|
||||
3. `wifi-densepose-wifiscan` (no internal deps)
|
||||
4. `wifi-densepose-hardware` (no internal deps)
|
||||
5. `wifi-densepose-config` (no internal deps)
|
||||
6. `wifi-densepose-db` (no internal deps)
|
||||
7. `wifi-densepose-signal` (depends on core)
|
||||
8. `wifi-densepose-nn` (no internal deps, workspace only)
|
||||
9. `wifi-densepose-ruvector` (no internal deps, workspace only)
|
||||
10. `wifi-densepose-train` (depends on signal, nn)
|
||||
11. `wifi-densepose-mat` (depends on core, signal, nn)
|
||||
12. `wifi-densepose-api` (no internal deps)
|
||||
13. `wifi-densepose-wasm` (depends on mat)
|
||||
14. `wifi-densepose-sensing-server` (depends on wifiscan)
|
||||
15. `wifi-densepose-cli` (depends on mat)
|
||||
|
||||
### Validation & Witness Verification (ADR-028)
|
||||
|
||||
**After any significant code change, run the full validation:**
|
||||
|
||||
```bash
|
||||
# 1. Rust tests — must be 1,031+ passed, 0 failed
|
||||
cd rust-port/wifi-densepose-rs
|
||||
cargo test --workspace --no-default-features
|
||||
|
||||
# 2. Python proof — must print VERDICT: PASS
|
||||
cd ../..
|
||||
python v1/data/proof/verify.py
|
||||
|
||||
# 3. Generate witness bundle (includes both above + firmware hashes)
|
||||
bash scripts/generate-witness-bundle.sh
|
||||
|
||||
# 4. Self-verify the bundle — must be 7/7 PASS
|
||||
cd dist/witness-bundle-ADR028-*/
|
||||
bash VERIFY.sh
|
||||
```
|
||||
|
||||
**If the Python proof hash changes** (e.g., numpy/scipy version update):
|
||||
```bash
|
||||
# Regenerate the expected hash, then verify it passes
|
||||
python v1/data/proof/verify.py --generate-hash
|
||||
python v1/data/proof/verify.py
|
||||
```
|
||||
|
||||
**Witness bundle contents** (`dist/witness-bundle-ADR028-<sha>.tar.gz`):
|
||||
- `WITNESS-LOG-028.md` — 33-row attestation matrix with evidence per capability
|
||||
- `ADR-028-esp32-capability-audit.md` — Full audit findings
|
||||
- `proof/verify.py` + `expected_features.sha256` — Deterministic pipeline proof
|
||||
- `test-results/rust-workspace-tests.log` — Full cargo test output
|
||||
- `firmware-manifest/source-hashes.txt` — SHA-256 of all 7 ESP32 firmware files
|
||||
- `crate-manifest/versions.txt` — All 15 crates with versions
|
||||
- `VERIFY.sh` — One-command self-verification for recipients
|
||||
|
||||
**Key proof artifacts:**
|
||||
- `v1/data/proof/verify.py` — Trust Kill Switch: feeds reference signal through production pipeline, hashes output
|
||||
- `v1/data/proof/expected_features.sha256` — Published expected hash
|
||||
- `v1/data/proof/sample_csi_data.json` — 1,000 synthetic CSI frames (seed=42)
|
||||
- `docs/WITNESS-LOG-028.md` — 11-step reproducible verification procedure
|
||||
- `docs/adr/ADR-028-esp32-capability-audit.md` — Complete audit record
|
||||
|
||||
### Branch
|
||||
All development on: `claude/validate-code-quality-WNrNw`
|
||||
Default branch: `main`
|
||||
Active feature branch: `ruvsense-full-implementation` (PR #77)
|
||||
|
||||
---
|
||||
|
||||
@@ -65,8 +173,13 @@ All development on: `claude/validate-code-quality-WNrNw`
|
||||
## File Organization
|
||||
|
||||
- NEVER save to root folder — use the directories below
|
||||
- `docs/adr/` — Architecture Decision Records
|
||||
- `rust-port/wifi-densepose-rs/crates/` — Rust workspace crates (signal, train, mat, nn, hardware)
|
||||
- `docs/adr/` — Architecture Decision Records (32 ADRs)
|
||||
- `docs/ddd/` — Domain-Driven Design models
|
||||
- `rust-port/wifi-densepose-rs/crates/` — Rust workspace crates (15 crates)
|
||||
- `rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/src/ruvsense/` — RuvSense multistatic modules (14 files)
|
||||
- `rust-port/wifi-densepose-rs/crates/wifi-densepose-ruvector/src/viewpoint/` — Cross-viewpoint fusion (5 files)
|
||||
- `rust-port/wifi-densepose-rs/crates/wifi-densepose-hardware/src/esp32/` — ESP32 TDM protocol
|
||||
- `firmware/esp32-csi-node/main/` — ESP32 C firmware (channel hopping, NVS config, TDM)
|
||||
- `v1/src/` — Python source (core, hardware, services, api)
|
||||
- `v1/data/proof/` — Deterministic CSI proof bundles
|
||||
- `.claude-flow/` — Claude Flow coordination state (committed for team sharing)
|
||||
@@ -93,14 +206,18 @@ All development on: `claude/validate-code-quality-WNrNw`
|
||||
|
||||
Before merging any PR, verify each item applies and is addressed:
|
||||
|
||||
1. **Tests pass** — `cargo test` (Rust) and `python -m pytest` (Python) green
|
||||
2. **README.md** — Update platform tables, crate descriptions, hardware tables, feature summaries if scope changed
|
||||
3. **CHANGELOG.md** — Add entry under `[Unreleased]` with what was added/fixed/changed
|
||||
4. **User guide** (`docs/user-guide.md`) — Update if new data sources, CLI flags, or setup steps were added
|
||||
5. **ADR index** — Update ADR count in README docs table if a new ADR was created
|
||||
6. **Docker Hub image** — Only rebuild if Dockerfile, dependencies, or runtime behavior changed (not needed for platform-gated code that doesn't affect the Linux container)
|
||||
7. **Crate publishing** — Only needed if a crate is published to crates.io and its public API changed (workspace-internal crates don't need publishing)
|
||||
8. **`.gitignore`** — Add any new build artifacts or binaries
|
||||
1. **Rust tests pass** — `cargo test --workspace --no-default-features` (1,031+ passed, 0 failed)
|
||||
2. **Python proof passes** — `python v1/data/proof/verify.py` (VERDICT: PASS)
|
||||
3. **README.md** — Update platform tables, crate descriptions, hardware tables, feature summaries if scope changed
|
||||
4. **CLAUDE.md** — Update crate table, ADR list, module tables, version if scope changed
|
||||
5. **CHANGELOG.md** — Add entry under `[Unreleased]` with what was added/fixed/changed
|
||||
6. **User guide** (`docs/user-guide.md`) — Update if new data sources, CLI flags, or setup steps were added
|
||||
7. **ADR index** — Update ADR count in README docs table if a new ADR was created
|
||||
8. **Witness bundle** — Regenerate if tests or proof hash changed: `bash scripts/generate-witness-bundle.sh`
|
||||
9. **Docker Hub image** — Only rebuild if Dockerfile, dependencies, or runtime behavior changed
|
||||
10. **Crate publishing** — Only needed if a crate is published to crates.io and its public API changed
|
||||
11. **`.gitignore`** — Add any new build artifacts or binaries
|
||||
12. **Security audit** — Run security review for new modules touching hardware/network boundaries
|
||||
|
||||
## Build & Test
|
||||
|
||||
|
||||
400
docs/adr/ADR-029-ruvsense-multistatic-sensing-mode.md
Normal file
400
docs/adr/ADR-029-ruvsense-multistatic-sensing-mode.md
Normal file
@@ -0,0 +1,400 @@
|
||||
# ADR-029: Project RuvSense -- Sensing-First RF Mode for Multistatic WiFi DensePose
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Proposed |
|
||||
| **Date** | 2026-03-02 |
|
||||
| **Deciders** | ruv |
|
||||
| **Codename** | **RuvSense** -- RuVector-Enhanced Sensing for Multistatic Fidelity |
|
||||
| **Relates to** | ADR-012 (ESP32 Mesh), ADR-014 (SOTA Signal Processing), ADR-016 (RuVector Training), ADR-017 (RuVector Signal+MAT), ADR-018 (ESP32 Implementation), ADR-024 (AETHER Embeddings), ADR-026 (Survivor Track Lifecycle), ADR-027 (MERIDIAN Generalization) |
|
||||
|
||||
---
|
||||
|
||||
## 1. Context
|
||||
|
||||
### 1.1 The Fidelity Gap
|
||||
|
||||
Current WiFi-DensePose achieves functional pose estimation from a single ESP32 AP, but three fidelity metrics prevent production deployment:
|
||||
|
||||
| Metric | Current (Single ESP32) | Required (Production) | Root Cause |
|
||||
|--------|------------------------|----------------------|------------|
|
||||
| Torso keypoint jitter | ~15cm RMS | <3cm RMS | Single viewpoint, 20 MHz bandwidth, no temporal smoothing |
|
||||
| Multi-person separation | Fails >2 people, frequent ID swaps | 4+ people, zero swaps over 10 min | Underdetermined with 1 TX-RX link; no person-specific features |
|
||||
| Small motion sensitivity | Gross movement only | Breathing at 3m, heartbeat at 1.5m | Insufficient phase sensitivity at 2.4 GHz; noise floor too high |
|
||||
| Update rate | ~10 Hz effective | 20 Hz | Single-channel serial CSI collection |
|
||||
| Temporal stability | Drifts within hours | Stable over days | No coherence gating; model absorbs environmental drift |
|
||||
|
||||
### 1.2 The Insight: Sensing-First RF Mode on Existing Silicon
|
||||
|
||||
You do not need to invent a new WiFi standard. The winning move is a **sensing-first RF mode** that rides on existing silicon (ESP32-S3), existing bands (2.4/5 GHz), and existing regulations (802.11n NDP frames). The fidelity improvement comes from three physical levers:
|
||||
|
||||
1. **Bandwidth**: Channel-hopping across 2.4 GHz channels 1/6/11 triples effective bandwidth from 20 MHz to 60 MHz, 3x multipath separation
|
||||
2. **Carrier frequency**: Dual-band sensing (2.4 + 5 GHz) doubles phase sensitivity to small motion
|
||||
3. **Viewpoints**: Multistatic ESP32 mesh (4 nodes = 12 TX-RX links) provides 360-degree geometric diversity
|
||||
|
||||
### 1.3 Acceptance Test
|
||||
|
||||
**Two people in a room, 20 Hz update rate, stable tracks for 10 minutes with no identity swaps and low jitter in the torso keypoints.**
|
||||
|
||||
Quantified:
|
||||
- Torso keypoint jitter < 30mm RMS (hips, shoulders, spine)
|
||||
- Zero identity swaps over 600 seconds (12,000 frames)
|
||||
- 20 Hz output rate (50 ms cycle time)
|
||||
- Breathing SNR > 10dB at 3m (validates small-motion sensitivity)
|
||||
|
||||
---
|
||||
|
||||
## 2. Decision
|
||||
|
||||
### 2.1 Architecture Overview
|
||||
|
||||
Implement RuvSense as a new bounded context within `wifi-densepose-signal`, consisting of 6 modules:
|
||||
|
||||
```
|
||||
wifi-densepose-signal/src/ruvsense/
|
||||
├── mod.rs // Module exports, RuvSense pipeline orchestrator
|
||||
├── multiband.rs // Multi-band CSI frame fusion (§2.2)
|
||||
├── phase_align.rs // Cross-channel phase alignment (§2.3)
|
||||
├── multistatic.rs // Multi-node viewpoint fusion (§2.4)
|
||||
├── coherence.rs // Coherence metric computation (§2.5)
|
||||
├── coherence_gate.rs // Gated update policy (§2.6)
|
||||
└── pose_tracker.rs // 17-keypoint Kalman tracker with re-ID (§2.7)
|
||||
```
|
||||
|
||||
### 2.2 Channel-Hopping Firmware (ESP32-S3)
|
||||
|
||||
Modify the ESP32 firmware (`firmware/esp32-csi-node/main/csi_collector.c`) to cycle through non-overlapping channels at configurable dwell times:
|
||||
|
||||
```c
|
||||
// Channel hop table (populated from NVS at boot)
|
||||
static uint8_t s_hop_channels[6] = {1, 6, 11, 36, 40, 44};
|
||||
static uint8_t s_hop_count = 3; // default: 2.4 GHz only
|
||||
static uint32_t s_dwell_ms = 50; // 50ms per channel
|
||||
```
|
||||
|
||||
At 100 Hz raw CSI rate with 50 ms dwell across 3 channels, each channel yields ~33 frames/second. The existing ADR-018 binary frame format already carries `channel_freq_mhz` at offset 8, so no wire format change is needed.
|
||||
|
||||
**NDP frame injection:** `esp_wifi_80211_tx()` injects deterministic Null Data Packet frames (preamble-only, no payload, ~24 us airtime) at GPIO-triggered intervals. This is sensing-first: the primary RF emission purpose is CSI measurement, not data communication.
|
||||
|
||||
### 2.3 Multi-Band Frame Fusion
|
||||
|
||||
Aggregate per-channel CSI frames into a wideband virtual snapshot:
|
||||
|
||||
```rust
|
||||
/// Fused multi-band CSI from one node at one time slot.
|
||||
pub struct MultiBandCsiFrame {
|
||||
pub node_id: u8,
|
||||
pub timestamp_us: u64,
|
||||
/// One canonical-56 row per channel, ordered by center frequency.
|
||||
pub channel_frames: Vec<CanonicalCsiFrame>,
|
||||
/// Center frequencies (MHz) for each channel row.
|
||||
pub frequencies_mhz: Vec<u32>,
|
||||
/// Cross-channel coherence score (0.0-1.0).
|
||||
pub coherence: f32,
|
||||
}
|
||||
```
|
||||
|
||||
Cross-channel phase alignment uses `ruvector-solver::NeumannSolver` to solve for the channel-dependent phase rotation introduced by the ESP32 local oscillator during channel hops. The system:
|
||||
|
||||
```
|
||||
[Φ₁, Φ₆, Φ₁₁] = [Φ_body + δ₁, Φ_body + δ₆, Φ_body + δ₁₁]
|
||||
```
|
||||
|
||||
NeumannSolver fits the `δ` offsets from the static subcarrier components (which should have zero body-caused phase shift), then removes them.
|
||||
|
||||
### 2.4 Multistatic Viewpoint Fusion
|
||||
|
||||
With N ESP32 nodes, collect N `MultiBandCsiFrame` per time slot and fuse with geometric diversity:
|
||||
|
||||
**TDMA Sensing Schedule (4 nodes):**
|
||||
|
||||
| Slot | TX | RX₁ | RX₂ | RX₃ | Duration |
|
||||
|------|-----|-----|-----|-----|----------|
|
||||
| 0 | Node A | B | C | D | 4 ms |
|
||||
| 1 | Node B | A | C | D | 4 ms |
|
||||
| 2 | Node C | A | B | D | 4 ms |
|
||||
| 3 | Node D | A | B | C | 4 ms |
|
||||
| 4 | -- | Processing + fusion | | | 30 ms |
|
||||
| **Total** | | | | | **50 ms = 20 Hz** |
|
||||
|
||||
Synchronization: GPIO pulse from aggregator node at cycle start. Clock drift at ±10ppm over 50 ms is ~0.5 us, well within the 1 ms guard interval.
|
||||
|
||||
**Cross-node fusion** uses `ruvector-attn-mincut::attn_mincut` where time-frequency cells from different nodes attend to each other. Cells showing correlated motion energy across nodes (body reflection) are amplified; cells with single-node energy (local multipath artifact) are suppressed.
|
||||
|
||||
**Multi-person separation** via `ruvector-mincut::DynamicMinCut`:
|
||||
|
||||
1. Build cross-link temporal correlation graph (nodes = TX-RX links, edges = correlation coefficient)
|
||||
2. `DynamicMinCut` partitions into K clusters (one per detected person)
|
||||
3. Attention fusion (§5.3 of research doc) runs independently per cluster
|
||||
|
||||
### 2.5 Coherence Metric
|
||||
|
||||
Per-link coherence quantifies consistency with recent history:
|
||||
|
||||
```rust
|
||||
pub fn coherence_score(
|
||||
current: &[f32],
|
||||
reference: &[f32],
|
||||
variance: &[f32],
|
||||
) -> f32 {
|
||||
current.iter().zip(reference.iter()).zip(variance.iter())
|
||||
.map(|((&c, &r), &v)| {
|
||||
let z = (c - r).abs() / v.sqrt().max(1e-6);
|
||||
let weight = 1.0 / (v + 1e-6);
|
||||
((-0.5 * z * z).exp(), weight)
|
||||
})
|
||||
.fold((0.0, 0.0), |(sc, sw), (c, w)| (sc + c * w, sw + w))
|
||||
.pipe(|(sc, sw)| sc / sw)
|
||||
}
|
||||
```
|
||||
|
||||
The static/dynamic decomposition uses `ruvector-solver` to separate environmental drift (slow, global) from body motion (fast, subcarrier-specific).
|
||||
|
||||
### 2.6 Coherence-Gated Update Policy
|
||||
|
||||
```rust
|
||||
pub enum GateDecision {
|
||||
/// Coherence > 0.85: Full Kalman measurement update
|
||||
Accept(Pose),
|
||||
/// 0.5 < coherence < 0.85: Kalman predict only (3x inflated noise)
|
||||
PredictOnly,
|
||||
/// Coherence < 0.5: Reject measurement entirely
|
||||
Reject,
|
||||
/// >10s continuous low coherence: Trigger SONA recalibration (ADR-005)
|
||||
Recalibrate,
|
||||
}
|
||||
```
|
||||
|
||||
When `Recalibrate` fires:
|
||||
1. Freeze output at last known good pose
|
||||
2. Collect 200 frames (10s) of unlabeled CSI
|
||||
3. Run AETHER contrastive TTT (ADR-024) to adapt encoder
|
||||
4. Update SONA LoRA weights (ADR-005), <1ms per update
|
||||
5. Resume sensing with adapted model
|
||||
|
||||
### 2.7 Pose Tracker (17-Keypoint Kalman with Re-ID)
|
||||
|
||||
Lift the Kalman + lifecycle + re-ID infrastructure from `wifi-densepose-mat/src/tracking/` (ADR-026) into the RuvSense bounded context, extended for 17-keypoint skeletons:
|
||||
|
||||
| Parameter | Value | Rationale |
|
||||
|-----------|-------|-----------|
|
||||
| State dimension | 6 per keypoint (x,y,z,vx,vy,vz) | Constant-velocity model |
|
||||
| Process noise σ_a | 0.3 m/s² | Normal walking acceleration |
|
||||
| Measurement noise σ_obs | 0.08 m | Target <8cm RMS at torso |
|
||||
| Mahalanobis gate | χ²(3) = 9.0 | 3σ ellipsoid (same as ADR-026) |
|
||||
| Birth hits | 2 frames (100ms at 20Hz) | Reject single-frame noise |
|
||||
| Loss misses | 5 frames (250ms) | Brief occlusion tolerance |
|
||||
| Re-ID feature | AETHER 128-dim embedding | Body-shape discriminative (ADR-024) |
|
||||
| Re-ID window | 5 seconds | Sufficient for crossing recovery |
|
||||
|
||||
**Track assignment** uses `ruvector-mincut`'s `DynamicPersonMatcher` (already integrated in `metrics.rs`, ADR-016) with joint position + embedding cost:
|
||||
|
||||
```
|
||||
cost(track_i, det_j) = 0.6 * mahalanobis(track_i, det_j.position)
|
||||
+ 0.4 * (1 - cosine_sim(track_i.embedding, det_j.embedding))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. GOAP Integration Plan (Goal-Oriented Action Planning)
|
||||
|
||||
### 3.1 Action Dependency Graph
|
||||
|
||||
```
|
||||
Phase 1: Foundation
|
||||
Action 1: Channel-Hopping Firmware ──────────────────────┐
|
||||
│ │
|
||||
v │
|
||||
Action 2: Multi-Band Frame Fusion ──→ Action 6: Coherence │
|
||||
│ Metric │
|
||||
v │ │
|
||||
Action 3: Multistatic Mesh v │
|
||||
│ Action 7: Coherence │
|
||||
v Gate │
|
||||
Phase 2: Tracking │ │
|
||||
Action 4: Pose Tracker ←────────────────┘ │
|
||||
│ │
|
||||
v │
|
||||
Action 5: End-to-End Pipeline @ 20 Hz ←────────────────────┘
|
||||
│
|
||||
v
|
||||
Phase 4: Hardening
|
||||
Action 8: AETHER Track Re-ID
|
||||
│
|
||||
v
|
||||
Action 9: ADR-029 Documentation (this document)
|
||||
```
|
||||
|
||||
### 3.2 Cost and RuVector Mapping
|
||||
|
||||
| # | Action | Cost | Preconditions | RuVector Crates | Effects |
|
||||
|---|--------|------|---------------|-----------------|---------|
|
||||
| 1 | Channel-hopping firmware | 4/10 | ESP32 firmware exists | None (pure C) | `bandwidth_extended = true` |
|
||||
| 2 | Multi-band frame fusion | 5/10 | Action 1 | `solver`, `attention` | `fused_multi_band_frame = true` |
|
||||
| 3 | Multistatic mesh aggregation | 5/10 | Action 2 | `mincut`, `attn-mincut` | `multistatic_mesh = true` |
|
||||
| 4 | Pose tracker | 4/10 | Action 3, 7 | `mincut` | `pose_tracker = true` |
|
||||
| 5 | End-to-end pipeline | 6/10 | Actions 2-4 | `temporal-tensor`, `attention` | `20hz_update = true` |
|
||||
| 6 | Coherence metric | 3/10 | Action 2 | `solver` | `coherence_metric = true` |
|
||||
| 7 | Coherence gate | 3/10 | Action 6 | `attn-mincut` | `coherence_gating = true` |
|
||||
| 8 | AETHER re-ID | 4/10 | Actions 4, 7 | `attention` | `identity_stable = true` |
|
||||
| 9 | ADR documentation | 2/10 | All above | None | Decision documented |
|
||||
|
||||
**Total cost: 36 units. Minimum viable path to acceptance test: Actions 1-5 + 6-7 = 30 units.**
|
||||
|
||||
### 3.3 Latency Budget (50ms cycle)
|
||||
|
||||
| Stage | Budget | Method |
|
||||
|-------|--------|--------|
|
||||
| UDP receive + parse | <1 ms | ADR-018 binary, 148 bytes, zero-alloc |
|
||||
| Multi-band fusion | ~2 ms | NeumannSolver on 2×2 phase alignment |
|
||||
| Multistatic fusion | ~3 ms | attn_mincut on 3-6 nodes × 64 velocity bins |
|
||||
| Model inference | ~30-40 ms | CsiToPoseTransformer (lightweight, no ResNet) |
|
||||
| Kalman update | <1 ms | 17 independent 6D filters, stack-allocated |
|
||||
| **Total** | **~37-47 ms** | **Fits in 50 ms** |
|
||||
|
||||
---
|
||||
|
||||
## 4. Hardware Bill of Materials
|
||||
|
||||
| Component | Qty | Unit Cost | Purpose |
|
||||
|-----------|-----|-----------|---------|
|
||||
| ESP32-S3-DevKitC-1 | 4 | $10 | TX/RX sensing nodes |
|
||||
| ESP32-S3-DevKitC-1 | 1 | $10 | Aggregator (or x86/RPi host) |
|
||||
| External 5dBi antenna | 4-8 | $3 | Improved gain, directional coverage |
|
||||
| USB-C hub (4 port) | 1 | $15 | Power distribution |
|
||||
| Wall mount brackets | 4 | $2 | Ceiling/wall installation |
|
||||
| **Total** | | **$73-91** | Complete 4-node mesh |
|
||||
|
||||
---
|
||||
|
||||
## 5. RuVector v2.0.4 Integration Map
|
||||
|
||||
All five published crates are exercised:
|
||||
|
||||
| Crate | Actions | Integration Point | Algorithmic Advantage |
|
||||
|-------|---------|-------------------|----------------------|
|
||||
| `ruvector-solver` | 2, 6 | Phase alignment; coherence matrix decomposition | O(√n) Neumann convergence |
|
||||
| `ruvector-attention` | 2, 5, 8 | Cross-channel weighting; ring buffer; embedding similarity | Sublinear attention for small d |
|
||||
| `ruvector-mincut` | 3, 4 | Viewpoint diversity partitioning; track assignment | O(n^1.5 log n) dynamic updates |
|
||||
| `ruvector-attn-mincut` | 3, 7 | Cross-node spectrogram fusion; coherence gating | Attention + mincut in one pass |
|
||||
| `ruvector-temporal-tensor` | 5 | Compressed sensing window ring buffer | 50-75% memory reduction |
|
||||
|
||||
---
|
||||
|
||||
## 6. IEEE 802.11bf Alignment
|
||||
|
||||
RuvSense's TDMA sensing schedule is forward-compatible with IEEE 802.11bf (WLAN Sensing, published 2024):
|
||||
|
||||
| RuvSense Concept | 802.11bf Equivalent |
|
||||
|-----------------|---------------------|
|
||||
| TX slot | Sensing Initiator |
|
||||
| RX slot | Sensing Responder |
|
||||
| TDMA cycle | Sensing Measurement Instance |
|
||||
| NDP frame | Sensing NDP |
|
||||
| Aggregator | Sensing Session Owner |
|
||||
|
||||
When commercial APs support 802.11bf, the ESP32 mesh can interoperate by translating SSP slots into 802.11bf Sensing Trigger frames.
|
||||
|
||||
---
|
||||
|
||||
## 7. Dependency Changes
|
||||
|
||||
### Firmware (C)
|
||||
|
||||
New files:
|
||||
- `firmware/esp32-csi-node/main/sensing_schedule.h`
|
||||
- `firmware/esp32-csi-node/main/sensing_schedule.c`
|
||||
|
||||
Modified files:
|
||||
- `firmware/esp32-csi-node/main/csi_collector.c` (add channel hopping, link tagging)
|
||||
- `firmware/esp32-csi-node/main/main.c` (add GPIO sync, TDMA timer)
|
||||
|
||||
### Rust
|
||||
|
||||
New module: `crates/wifi-densepose-signal/src/ruvsense/` (6 files, ~1500 lines estimated)
|
||||
|
||||
Modified files:
|
||||
- `crates/wifi-densepose-signal/src/lib.rs` (export `ruvsense` module)
|
||||
- `crates/wifi-densepose-signal/Cargo.toml` (no new deps; all ruvector crates already present per ADR-017)
|
||||
- `crates/wifi-densepose-sensing-server/src/main.rs` (wire RuvSense pipeline into WebSocket output)
|
||||
|
||||
No new workspace dependencies. All ruvector crates are already in the workspace `Cargo.toml`.
|
||||
|
||||
---
|
||||
|
||||
## 8. Implementation Priority
|
||||
|
||||
| Priority | Actions | Weeks | Milestone |
|
||||
|----------|---------|-------|-----------|
|
||||
| P0 | 1 (firmware) | 2 | Channel-hopping ESP32 prototype |
|
||||
| P0 | 2 (multi-band) | 2 | Wideband virtual frames |
|
||||
| P1 | 3 (multistatic) | 2 | Multi-node fusion |
|
||||
| P1 | 4 (tracker) | 1 | 17-keypoint Kalman |
|
||||
| P1 | 6, 7 (coherence) | 1 | Gated updates |
|
||||
| P2 | 5 (end-to-end) | 2 | 20 Hz pipeline |
|
||||
| P2 | 8 (AETHER re-ID) | 1 | Identity hardening |
|
||||
| P3 | 9 (docs) | 0.5 | This ADR finalized |
|
||||
| **Total** | | **~10 weeks** | **Acceptance test** |
|
||||
|
||||
---
|
||||
|
||||
## 9. Consequences
|
||||
|
||||
### 9.1 Positive
|
||||
|
||||
- **3x bandwidth improvement** without hardware changes (channel hopping on existing ESP32)
|
||||
- **12 independent viewpoints** from 4 commodity $10 nodes (C(4,2) × 2 links)
|
||||
- **20 Hz update rate** with Kalman-smoothed output for sub-30mm torso jitter
|
||||
- **Days-long stability** via coherence gating + SONA recalibration
|
||||
- **All five ruvector crates exercised** — consistent algorithmic foundation
|
||||
- **$73-91 total BOM** — accessible for research and production
|
||||
- **802.11bf forward-compatible** — investment protected as commercial sensing arrives
|
||||
- **Cognitum upgrade path** — same software stack, swap ESP32 for higher-bandwidth front end
|
||||
|
||||
### 9.2 Negative
|
||||
|
||||
- **4-node deployment** requires physical installation and calibration of node positions
|
||||
- **TDMA scheduling** reduces per-node CSI rate (each node only transmits 1/4 of the time)
|
||||
- **Channel hopping** introduces ~1-5ms gaps during `esp_wifi_set_channel()` transitions
|
||||
- **5 GHz CSI on ESP32-S3** may not be available (ESP32-C6 supports it natively)
|
||||
- **Coherence gate** may reject valid measurements during fast body motion (mitigation: gate only on static-subcarrier coherence)
|
||||
|
||||
### 9.3 Risks
|
||||
|
||||
| Risk | Probability | Impact | Mitigation |
|
||||
|------|-------------|--------|------------|
|
||||
| ESP32 channel hop causes CSI gaps | Medium | Reduced effective rate | Measure gap duration; increase dwell if >5ms |
|
||||
| 5 GHz CSI unavailable on S3 | High | Lose frequency diversity | Fallback: 3-channel 2.4 GHz still provides 3x BW; ESP32-C6 for dual-band |
|
||||
| Model inference >40ms | Medium | Miss 20 Hz target | Run model at 10 Hz; Kalman predict at 20 Hz interpolates |
|
||||
| Two-person separation fails at 3 nodes | Low | Identity swaps | AETHER re-ID recovers; increase to 4-6 nodes |
|
||||
| Coherence gate false-triggers | Low | Missed updates | Gate on environmental coherence only, not body-motion subcarriers |
|
||||
|
||||
---
|
||||
|
||||
## 10. Related ADRs
|
||||
|
||||
| ADR | Relationship |
|
||||
|-----|-------------|
|
||||
| ADR-012 | **Extended**: RuvSense adds TDMA multistatic to single-AP mesh |
|
||||
| ADR-014 | **Used**: All 6 SOTA algorithms applied per-link |
|
||||
| ADR-016 | **Extended**: New ruvector integration points for multi-link fusion |
|
||||
| ADR-017 | **Extended**: Coherence gating adds temporal stability layer |
|
||||
| ADR-018 | **Modified**: Firmware gains channel hopping, TDMA schedule, HT40 |
|
||||
| ADR-022 | **Complementary**: RuvSense is the ESP32 equivalent of Windows multi-BSSID |
|
||||
| ADR-024 | **Used**: AETHER embeddings for person re-identification |
|
||||
| ADR-026 | **Reused**: Kalman + lifecycle infrastructure lifted to RuvSense |
|
||||
| ADR-027 | **Used**: GeometryEncoder, HardwareNormalizer, FiLM conditioning |
|
||||
|
||||
---
|
||||
|
||||
## 11. References
|
||||
|
||||
1. IEEE 802.11bf-2024. "WLAN Sensing." IEEE Standards Association.
|
||||
2. Geng, J., Huang, D., De la Torre, F. (2023). "DensePose From WiFi." arXiv:2301.00250.
|
||||
3. Yan, K. et al. (2024). "Person-in-WiFi 3D." CVPR 2024, pp. 969-978.
|
||||
4. Chen, L. et al. (2026). "PerceptAlign: Geometry-Aware WiFi Sensing." arXiv:2601.12252.
|
||||
5. Kotaru, M. et al. (2015). "SpotFi: Decimeter Level Localization Using WiFi." SIGCOMM.
|
||||
6. Zheng, Y. et al. (2019). "Zero-Effort Cross-Domain Gesture Recognition with Wi-Fi." MobiSys.
|
||||
7. Zeng, Y. et al. (2019). "FarSense: Pushing the Range Limit of WiFi-based Respiration Sensing." MobiCom.
|
||||
8. AM-FM (2026). "A Foundation Model for Ambient Intelligence Through WiFi." arXiv:2602.11200.
|
||||
9. Espressif ESP-CSI. https://github.com/espressif/esp-csi
|
||||
364
docs/adr/ADR-030-ruvsense-persistent-field-model.md
Normal file
364
docs/adr/ADR-030-ruvsense-persistent-field-model.md
Normal file
@@ -0,0 +1,364 @@
|
||||
# ADR-030: RuvSense Persistent Field Model — Longitudinal Drift Detection and Exotic Sensing Tiers
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Proposed |
|
||||
| **Date** | 2026-03-02 |
|
||||
| **Deciders** | ruv |
|
||||
| **Codename** | **RuvSense Field** — Persistent Electromagnetic World Model |
|
||||
| **Relates to** | ADR-029 (RuvSense Multistatic), ADR-005 (SONA Self-Learning), ADR-024 (AETHER Embeddings), ADR-016 (RuVector Integration), ADR-026 (Survivor Track Lifecycle), ADR-027 (MERIDIAN Generalization) |
|
||||
|
||||
---
|
||||
|
||||
## 1. Context
|
||||
|
||||
### 1.1 Beyond Pose Estimation
|
||||
|
||||
ADR-029 establishes RuvSense as a sensing-first multistatic mesh achieving 20 Hz DensePose with <30mm jitter. That treats WiFi as a **momentary pose estimator**. The next leap: treat the electromagnetic field as a **persistent world model** that remembers, predicts, and explains.
|
||||
|
||||
The most exotic capabilities come from this shift in abstraction level:
|
||||
- The room is the model, not the person
|
||||
- People are structured perturbations to a baseline
|
||||
- Changes are deltas from a known state, not raw measurements
|
||||
- Time is a first-class dimension — the system remembers days, not frames
|
||||
|
||||
### 1.2 The Seven Capability Tiers
|
||||
|
||||
| Tier | Capability | Foundation |
|
||||
|------|-----------|-----------|
|
||||
| 1 | **Field Normal Modes** — Room electromagnetic eigenstructure | Baseline calibration + SVD |
|
||||
| 2 | **Coarse RF Tomography** — 3D occupancy volume from link attenuations | Sparse tomographic inversion |
|
||||
| 3 | **Intention Lead Signals** — Pre-movement prediction (200-500ms lead) | Temporal embedding trajectory analysis |
|
||||
| 4 | **Longitudinal Biomechanics Drift** — Personal baseline deviation over days | Welford statistics + HNSW memory |
|
||||
| 5 | **Cross-Room Continuity** — Identity persistence across spaces without optics | Environment fingerprinting + transition graph |
|
||||
| 6 | **Invisible Interaction Layer** — Multi-user gesture control through walls/darkness | Per-person CSI perturbation classification |
|
||||
| 7 | **Adversarial Detection** — Physically impossible signal identification | Multi-link consistency + field model constraints |
|
||||
|
||||
### 1.3 Signals, Not Diagnoses
|
||||
|
||||
RF sensing detects **biophysical proxies**, not medical conditions:
|
||||
|
||||
| Detectable Signal | Not Detectable |
|
||||
|-------------------|---------------|
|
||||
| Breathing rate variability | COPD diagnosis |
|
||||
| Gait asymmetry shift (18% over 14 days) | Parkinson's disease |
|
||||
| Posture instability increase | Neurological condition |
|
||||
| Micro-tremor onset | Specific tremor etiology |
|
||||
| Activity level decline | Depression or pain diagnosis |
|
||||
|
||||
The output is: "Your movement symmetry has shifted 18 percent over 14 days." That is actionable without being diagnostic. The evidence chain (stored embeddings, drift statistics, coherence scores) is fully traceable.
|
||||
|
||||
### 1.4 Acceptance Tests
|
||||
|
||||
**Tier 0 (ADR-029):** Two people, 20 Hz, 10 min stable tracks, zero ID swaps, <30mm torso jitter.
|
||||
|
||||
**Tier 1-4 (this ADR):** Seven-day run, no manual tuning. System flags one real environmental change and one real human drift event, produces traceable explanation using stored embeddings plus graph constraints.
|
||||
|
||||
**Tier 5-7 (appliance):** Thirty-day local run, no camera. Detects meaningful drift with <5% false alarm rate.
|
||||
|
||||
---
|
||||
|
||||
## 2. Decision
|
||||
|
||||
### 2.1 Implement Field Normal Modes as the Foundation
|
||||
|
||||
Add a `field_model` module to `wifi-densepose-signal/src/ruvsense/` that learns the room's electromagnetic baseline during unoccupied periods and decomposes all subsequent observations into environmental drift + body perturbation.
|
||||
|
||||
```
|
||||
wifi-densepose-signal/src/ruvsense/
|
||||
├── mod.rs // (existing, extend)
|
||||
├── field_model.rs // NEW: Field normal mode computation + perturbation extraction
|
||||
├── tomography.rs // NEW: Coarse RF tomography from link attenuations
|
||||
├── longitudinal.rs // NEW: Personal baseline + drift detection
|
||||
├── intention.rs // NEW: Pre-movement lead signal detector
|
||||
├── cross_room.rs // NEW: Cross-room identity continuity
|
||||
├── gesture.rs // NEW: Gesture classification from CSI perturbations
|
||||
├── adversarial.rs // NEW: Physically impossible signal detection
|
||||
└── (existing files...)
|
||||
```
|
||||
|
||||
### 2.2 Core Architecture: The Persistent Field Model
|
||||
|
||||
```
|
||||
Time
|
||||
│
|
||||
▼
|
||||
┌────────────────────────────────┐
|
||||
│ Field Normal Modes (Tier 1) │
|
||||
│ Room baseline + SVD modes │
|
||||
│ ruvector-solver │
|
||||
└────────────┬───────────────────┘
|
||||
│ Body perturbation (environmental drift removed)
|
||||
│
|
||||
┌───────┴───────┐
|
||||
│ │
|
||||
▼ ▼
|
||||
┌──────────┐ ┌──────────────┐
|
||||
│ Pose │ │ RF Tomography│
|
||||
│ (ADR-029)│ │ (Tier 2) │
|
||||
│ 20 Hz │ │ Occupancy vol│
|
||||
└────┬─────┘ └──────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────┐
|
||||
│ AETHER Embedding (ADR-024) │
|
||||
│ 128-dim contrastive vector │
|
||||
└────────────┬─────────────────┘
|
||||
│
|
||||
┌───────┼───────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌────────┐ ┌─────┐ ┌──────────┐
|
||||
│Intention│ │Track│ │Cross-Room│
|
||||
│Lead │ │Re-ID│ │Continuity│
|
||||
│(Tier 3)│ │ │ │(Tier 5) │
|
||||
└────────┘ └──┬──┘ └──────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────┐
|
||||
│ RuVector Longitudinal Memory │
|
||||
│ HNSW + graph + Welford stats│
|
||||
│ (Tier 4) │
|
||||
└──────────────┬───────────────┘
|
||||
│
|
||||
┌───────┴───────┐
|
||||
│ │
|
||||
▼ ▼
|
||||
┌──────────────┐ ┌──────────────┐
|
||||
│ Drift Reports│ │ Adversarial │
|
||||
│ (Level 1-3) │ │ Detection │
|
||||
│ │ │ (Tier 7) │
|
||||
└──────────────┘ └──────────────┘
|
||||
```
|
||||
|
||||
### 2.3 Field Normal Modes (Tier 1)
|
||||
|
||||
**What it is:** The room's electromagnetic eigenstructure — the stable propagation paths, reflection coefficients, and interference patterns when nobody is present.
|
||||
|
||||
**How it works:**
|
||||
1. During quiet periods (empty room, overnight), collect 10 minutes of CSI across all links
|
||||
2. Compute per-link baseline (mean CSI vector)
|
||||
3. Compute environmental variation modes via SVD (temperature, humidity, time-of-day effects)
|
||||
4. Store top-K modes (K=3-5 typically captures >95% of environmental variance)
|
||||
5. At runtime: subtract baseline, project out environmental modes, keep body perturbation
|
||||
|
||||
```rust
|
||||
pub struct FieldNormalMode {
|
||||
pub baseline: Vec<Vec<Complex<f32>>>, // [n_links × n_subcarriers]
|
||||
pub environmental_modes: Vec<Vec<f32>>, // [n_modes × n_subcarriers]
|
||||
pub mode_energies: Vec<f32>, // eigenvalues
|
||||
pub calibrated_at: u64,
|
||||
pub geometry_hash: u64,
|
||||
}
|
||||
```
|
||||
|
||||
**RuVector integration:**
|
||||
- `ruvector-solver` → Low-rank SVD for mode extraction
|
||||
- `ruvector-temporal-tensor` → Compressed baseline history storage
|
||||
- `ruvector-attn-mincut` → Identify which subcarriers belong to which mode
|
||||
|
||||
### 2.4 Longitudinal Drift Detection (Tier 4)
|
||||
|
||||
**The defensible pipeline:**
|
||||
|
||||
```
|
||||
RF → AETHER contrastive embedding
|
||||
→ RuVector longitudinal memory (HNSW + graph)
|
||||
→ Coherence-gated drift detection (Welford statistics)
|
||||
→ Risk flag with traceable evidence
|
||||
```
|
||||
|
||||
**Three monitoring levels:**
|
||||
|
||||
| Level | Signal Type | Example Output |
|
||||
|-------|------------|----------------|
|
||||
| **1: Physiological** | Raw biophysical metrics | "Breathing rate: 18.3 BPM today, 7-day avg: 16.1" |
|
||||
| **2: Drift** | Personal baseline deviation | "Gait symmetry shifted 18% over 14 days" |
|
||||
| **3: Risk correlation** | Pattern-matched concern | "Pattern consistent with increased fall risk" |
|
||||
|
||||
**Storage model:**
|
||||
|
||||
```rust
|
||||
pub struct PersonalBaseline {
|
||||
pub person_id: PersonId,
|
||||
pub gait_symmetry: WelfordStats,
|
||||
pub stability_index: WelfordStats,
|
||||
pub breathing_regularity: WelfordStats,
|
||||
pub micro_tremor: WelfordStats,
|
||||
pub activity_level: WelfordStats,
|
||||
pub embedding_centroid: Vec<f32>, // [128]
|
||||
pub observation_days: u32,
|
||||
pub updated_at: u64,
|
||||
}
|
||||
```
|
||||
|
||||
**RuVector integration:**
|
||||
- `ruvector-temporal-tensor` → Compressed daily summaries (50-75% memory savings)
|
||||
- HNSW → Embedding similarity search across longitudinal record
|
||||
- `ruvector-attention` → Per-metric drift significance weighting
|
||||
- `ruvector-mincut` → Temporal segmentation (detect changepoints in metric series)
|
||||
|
||||
### 2.5 Regulatory Classification
|
||||
|
||||
| Classification | What You Claim | Regulatory Path |
|
||||
|---------------|---------------|-----------------|
|
||||
| **Consumer wellness** (recommended first) | Activity metrics, breathing rate, stability score | Self-certification, FCC Part 15 |
|
||||
| **Clinical decision support** (future) | Fall risk alert, respiratory pattern concern | FDA Class II 510(k) or De Novo |
|
||||
| **Regulated medical device** (requires clinical partner) | Diagnostic claims for specific conditions | FDA Class II/III + clinical trials |
|
||||
|
||||
**Decision: Start as consumer wellness.** Build 12+ months of real-world longitudinal data. The dataset itself becomes the asset for future regulatory submissions.
|
||||
|
||||
---
|
||||
|
||||
## 3. Appliance Product Categories
|
||||
|
||||
### 3.1 Invisible Guardian
|
||||
|
||||
Wall-mounted wellness monitor for elderly care and independent living. No camera, no microphone, no reconstructable data. Stores embeddings and structural deltas only.
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Nodes | 4 ESP32-S3 pucks per room |
|
||||
| Processing | Central hub (RPi 5 or x86) |
|
||||
| Power | PoE or USB-C |
|
||||
| Output | Risk flags, drift alerts, occupancy timeline |
|
||||
| BOM | $73-91 (ESP32 mesh) + $35-80 (hub) |
|
||||
| Validation | 30-day autonomous run, <5% false alarm rate |
|
||||
|
||||
### 3.2 Spatial Digital Twin Node
|
||||
|
||||
Live electromagnetic room model for smart buildings and workplace analytics.
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Output | Occupancy heatmap, flow vectors, dwell time, anomaly events |
|
||||
| Integration | MQTT/REST API for BMS and CAFM |
|
||||
| Retention | 30-day rolling, GDPR-compliant |
|
||||
| Vertical | Smart buildings, retail, workspace optimization |
|
||||
|
||||
### 3.3 RF Interaction Surface
|
||||
|
||||
Multi-user gesture interface. No cameras. Works in darkness, smoke, through clothing.
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Gestures | Wave, point, beckon, push, circle + custom |
|
||||
| Users | Up to 4 simultaneous |
|
||||
| Latency | <100ms gesture recognition |
|
||||
| Vertical | Smart home, hospitality, accessibility |
|
||||
|
||||
### 3.4 Pre-Incident Drift Monitor
|
||||
|
||||
Longitudinal biomechanics tracker for rehabilitation and occupational health.
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Baseline | 7-day calibration per person |
|
||||
| Alert | Metric drift >2sigma for >3 days |
|
||||
| Evidence | Stored embedding trajectory + statistical report |
|
||||
| Vertical | Elderly care, rehab, occupational health |
|
||||
|
||||
### 3.5 Vertical Recommendation for First Hardware SKU
|
||||
|
||||
**Invisible Guardian** — the elderly care wellness monitor. Rationale:
|
||||
1. Largest addressable market with immediate revenue (aging population, care facility demand)
|
||||
2. Lowest regulatory bar (consumer wellness, no diagnostic claims)
|
||||
3. Privacy advantage over cameras is a selling point, not a limitation
|
||||
4. 30-day autonomous operation validates all tiers (field model, drift detection, coherence gating)
|
||||
5. $108-171 BOM allows $299-499 retail with healthy margins
|
||||
|
||||
---
|
||||
|
||||
## 4. RuVector Integration Map (Extended)
|
||||
|
||||
All five crates are exercised across the exotic tiers:
|
||||
|
||||
| Tier | Crate | API | Role |
|
||||
|------|-------|-----|------|
|
||||
| 1 (Field) | `ruvector-solver` | `NeumannSolver` + SVD | Environmental mode decomposition |
|
||||
| 1 (Field) | `ruvector-temporal-tensor` | `TemporalTensorCompressor` | Baseline history storage |
|
||||
| 1 (Field) | `ruvector-attn-mincut` | `attn_mincut` | Mode-subcarrier assignment |
|
||||
| 2 (Tomo) | `ruvector-solver` | `NeumannSolver` (L1) | Sparse tomographic inversion |
|
||||
| 3 (Intent) | `ruvector-attention` | `ScaledDotProductAttention` | Temporal trajectory weighting |
|
||||
| 3 (Intent) | `ruvector-temporal-tensor` | `CompressedCsiBuffer` | 2-second embedding history |
|
||||
| 4 (Drift) | `ruvector-temporal-tensor` | `TemporalTensorCompressor` | Daily summary compression |
|
||||
| 4 (Drift) | `ruvector-attention` | `ScaledDotProductAttention` | Metric drift significance |
|
||||
| 4 (Drift) | `ruvector-mincut` | `DynamicMinCut` | Temporal changepoint detection |
|
||||
| 5 (Cross-Room) | `ruvector-attention` | HNSW | Room and person fingerprint matching |
|
||||
| 5 (Cross-Room) | `ruvector-mincut` | `MinCutBuilder` | Transition graph partitioning |
|
||||
| 6 (Gesture) | `ruvector-attention` | `ScaledDotProductAttention` | Gesture template matching |
|
||||
| 7 (Adversarial) | `ruvector-solver` | `NeumannSolver` | Physical plausibility verification |
|
||||
| 7 (Adversarial) | `ruvector-attn-mincut` | `attn_mincut` | Multi-link consistency check |
|
||||
|
||||
---
|
||||
|
||||
## 5. Implementation Priority
|
||||
|
||||
| Priority | Tier | Module | Weeks | Dependency |
|
||||
|----------|------|--------|-------|------------|
|
||||
| P0 | 1 | `field_model.rs` | 2 | ADR-029 multistatic mesh operational |
|
||||
| P0 | 4 | `longitudinal.rs` | 2 | Tier 1 baseline + AETHER embeddings |
|
||||
| P1 | 2 | `tomography.rs` | 1 | Tier 1 perturbation extraction |
|
||||
| P1 | 3 | `intention.rs` | 2 | Tier 1 + temporal embedding history |
|
||||
| P2 | 5 | `cross_room.rs` | 2 | Tier 4 person profiles + multi-room deployment |
|
||||
| P2 | 6 | `gesture.rs` | 1 | Tier 1 perturbation + per-person separation |
|
||||
| P3 | 7 | `adversarial.rs` | 1 | Tier 1 field model + multi-link consistency |
|
||||
|
||||
**Total exotic tier: ~11 weeks after ADR-029 acceptance test passes.**
|
||||
|
||||
---
|
||||
|
||||
## 6. Consequences
|
||||
|
||||
### 6.1 Positive
|
||||
|
||||
- **Room becomes self-sensing**: Field normal modes provide a persistent baseline that explains change as structured deltas
|
||||
- **7-day autonomous operation**: Coherence gating + SONA adaptation + longitudinal memory eliminate manual tuning
|
||||
- **Privacy by design**: No images, no audio, no reconstructable data — only embeddings and statistical summaries
|
||||
- **Traceable evidence**: Every drift alert links to stored embeddings, timestamps, and graph constraints
|
||||
- **Multiple product categories**: Same software stack, different packaging — Guardian, Twin, Interaction, Drift Monitor
|
||||
- **Regulatory clarity**: Consumer wellness first, clinical decision support later with accumulated dataset
|
||||
- **Security primitive**: Coherence gating detects adversarial injection, not just quality issues
|
||||
|
||||
### 6.2 Negative
|
||||
|
||||
- **7-day calibration** required for personal baselines (system is less useful during initial period)
|
||||
- **Empty-room calibration** needed for field normal modes (may not always be available)
|
||||
- **Storage growth**: Longitudinal memory grows ~1 KB/person/day (manageable but non-zero)
|
||||
- **Statistical power**: Drift detection requires 14+ days of data for meaningful z-scores
|
||||
- **Multi-room**: Cross-room continuity requires hardware in all rooms (cost scales linearly)
|
||||
|
||||
### 6.3 Risks
|
||||
|
||||
| Risk | Probability | Impact | Mitigation |
|
||||
|------|-------------|--------|------------|
|
||||
| Field modes drift faster than expected | Medium | False perturbation detections | Reduce mode update interval from 24h to 4h |
|
||||
| Personal baselines too variable | Medium | High false alarm rate for drift | Widen sigma threshold from 2σ to 3σ; require 5+ days |
|
||||
| Cross-room matching fails for similar body types | Low | Identity confusion | Require temporal proximity (<60s) plus spatial adjacency |
|
||||
| Gesture recognition insufficient SNR | Medium | <80% accuracy | Restrict to near-field (<2m) initially |
|
||||
| Adversarial injection via coordinated WiFi injection | Very Low | Spoofed occupancy | Multi-link consistency check makes single-link spoofing detectable |
|
||||
|
||||
---
|
||||
|
||||
## 7. Related ADRs
|
||||
|
||||
| ADR | Relationship |
|
||||
|-----|-------------|
|
||||
| ADR-029 | **Prerequisite**: Multistatic mesh is the sensing substrate for all exotic tiers |
|
||||
| ADR-005 (SONA) | **Extended**: SONA recalibration triggered by coherence gate → now also by drift events |
|
||||
| ADR-016 (RuVector) | **Extended**: All 5 crates exercised across 7 exotic tiers |
|
||||
| ADR-024 (AETHER) | **Critical dependency**: Embeddings are the representation for all longitudinal memory |
|
||||
| ADR-026 (Tracking) | **Extended**: Track lifecycle now spans days (not minutes) for drift detection |
|
||||
| ADR-027 (MERIDIAN) | **Used**: Room geometry encoding for field normal mode conditioning |
|
||||
|
||||
---
|
||||
|
||||
## 8. References
|
||||
|
||||
1. IEEE 802.11bf-2024. "WLAN Sensing." IEEE Standards Association.
|
||||
2. FDA. "General Wellness: Policy for Low Risk Devices." Guidance Document, 2019.
|
||||
3. EU MDR 2017/745. "Medical Device Regulation." Official Journal of the European Union.
|
||||
4. Welford, B.P. (1962). "Note on a Method for Calculating Corrected Sums of Squares." Technometrics.
|
||||
5. Chen, L. et al. (2026). "PerceptAlign: Geometry-Aware WiFi Sensing." arXiv:2601.12252.
|
||||
6. AM-FM (2026). "A Foundation Model for Ambient Intelligence Through WiFi." arXiv:2602.11200.
|
||||
7. Geng, J. et al. (2023). "DensePose From WiFi." arXiv:2301.00250.
|
||||
369
docs/adr/ADR-031-ruview-sensing-first-rf-mode.md
Normal file
369
docs/adr/ADR-031-ruview-sensing-first-rf-mode.md
Normal file
@@ -0,0 +1,369 @@
|
||||
# ADR-031: Project RuView -- Sensing-First RF Mode for Multistatic Fidelity Enhancement
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Proposed |
|
||||
| **Date** | 2026-03-02 |
|
||||
| **Deciders** | ruv |
|
||||
| **Codename** | **RuView** -- RuVector Viewpoint-Integrated Enhancement |
|
||||
| **Relates to** | ADR-012 (ESP32 Mesh), ADR-014 (SOTA Signal), ADR-016 (RuVector Integration), ADR-017 (RuVector Signal+MAT), ADR-021 (Vital Signs), ADR-024 (AETHER Embeddings), ADR-027 (MERIDIAN Cross-Environment) |
|
||||
|
||||
---
|
||||
|
||||
## 1. Context
|
||||
|
||||
### 1.1 The Single-Viewpoint Fidelity Ceiling
|
||||
|
||||
Current WiFi DensePose operates with a single transmitter-receiver pair (or single node receiving). This creates three fundamental limitations:
|
||||
|
||||
- **Body self-occlusion**: Limbs behind the torso are invisible to a single viewpoint.
|
||||
- **Depth ambiguity**: Motion along the RF propagation axis (toward/away from receiver) produces minimal phase change.
|
||||
- **Multi-person confusion**: Two people at similar range but different angles create overlapping CSI signatures.
|
||||
|
||||
The ESP32 mesh (ADR-012) partially addresses this via feature-level fusion across 3-6 nodes, but feature-level fusion cannot learn optimal fusion weights -- it uses hand-crafted aggregation (max, mean, coherent sum).
|
||||
|
||||
### 1.2 Three Fidelity Levers
|
||||
|
||||
1. **Bandwidth**: More bandwidth produces better multipath separability. Currently limited to 20 MHz (ESP32 HT20). Wider channels (80/160 MHz) are available on commodity 802.11ac/ax APs.
|
||||
2. **Carrier frequency**: Higher frequency produces more phase sensitivity. 2.4 GHz sees macro-motion; 5 GHz sees micro-motion; 60 GHz sees vital signs.
|
||||
3. **Viewpoints**: More viewpoints from different angles reduces geometric ambiguity. This is the lever RuView pulls.
|
||||
|
||||
### 1.3 Why "Sensing-First RF Mode"
|
||||
|
||||
RuView is NOT a new WiFi standard. It is a sensing-first protocol that rides on existing silicon, bands, and regulations. The key insight: instead of upgrading the RF hardware, upgrade the observability by coordinating multiple commodity receivers.
|
||||
|
||||
### 1.4 What Already Exists
|
||||
|
||||
| Component | ADR | Current State |
|
||||
|-----------|-----|---------------|
|
||||
| ESP32 mesh with feature-level fusion | ADR-012 | Implemented (firmware + aggregator) |
|
||||
| SOTA signal processing (Hampel, Fresnel, BVP, spectrogram) | ADR-014 | Implemented |
|
||||
| RuVector training pipeline (5 crates) | ADR-016 | Complete |
|
||||
| RuVector signal + MAT integration (7 points) | ADR-017 | Accepted |
|
||||
| Vital sign detection pipeline | ADR-021 | Partially implemented |
|
||||
| AETHER contrastive embeddings | ADR-024 | Proposed |
|
||||
| MERIDIAN cross-environment generalization | ADR-027 | Proposed |
|
||||
|
||||
RuView fills the gap: **cross-viewpoint embedding fusion** using learned attention weights.
|
||||
|
||||
---
|
||||
|
||||
## 2. Decision
|
||||
|
||||
Introduce RuView as a cross-viewpoint embedding fusion layer that operates on top of AETHER per-viewpoint embeddings. RuView adds a new bounded context (ViewpointFusion) and extends three existing crates.
|
||||
|
||||
### 2.1 Core Architecture
|
||||
|
||||
```
|
||||
+-----------------------------------------------------------------+
|
||||
| RuView Multistatic Pipeline |
|
||||
+-----------------------------------------------------------------+
|
||||
| |
|
||||
| +----------+ +----------+ +----------+ +----------+ |
|
||||
| | Node 1 | | Node 2 | | Node 3 | | Node N | |
|
||||
| | ESP32-S3 | | ESP32-S3 | | ESP32-S3 | | ESP32-S3 | |
|
||||
| | | | | | | | | |
|
||||
| | CSI Rx | | CSI Rx | | CSI Rx | | CSI Rx | |
|
||||
| +----+-----+ +----+-----+ +----+-----+ +----+-----+ |
|
||||
| | | | | |
|
||||
| v v v v |
|
||||
| +--------------------------------------------------------+ |
|
||||
| | Per-Viewpoint Signal Processing | |
|
||||
| | Phase sanitize -> Hampel -> BVP -> Subcarrier select | |
|
||||
| | (ADR-014, unchanged per viewpoint) | |
|
||||
| +----------------------------+---------------------------+ |
|
||||
| | |
|
||||
| v |
|
||||
| +--------------------------------------------------------+ |
|
||||
| | Per-Viewpoint AETHER Embedding | |
|
||||
| | CsiToPoseTransformer -> 128-d contrastive embedding | |
|
||||
| | (ADR-024, one per viewpoint) | |
|
||||
| +----------------------------+---------------------------+ |
|
||||
| | |
|
||||
| [emb_1, emb_2, ..., emb_N] |
|
||||
| | |
|
||||
| v |
|
||||
| +--------------------------------------------------------+ |
|
||||
| | * RuView Cross-Viewpoint Fusion * | |
|
||||
| | | |
|
||||
| | Q = W_q * X, K = W_k * X, V = W_v * X | |
|
||||
| | A = softmax((QK^T + G_bias) / sqrt(d)) | |
|
||||
| | fused = A * V | |
|
||||
| | | |
|
||||
| | G_bias: geometric bias from viewpoint pair geometry | |
|
||||
| | (ruvector-attention: ScaledDotProductAttention) | |
|
||||
| +----------------------------+---------------------------+ |
|
||||
| | |
|
||||
| fused_embedding |
|
||||
| | |
|
||||
| v |
|
||||
| +--------------------------------------------------------+ |
|
||||
| | DensePose Regression Head | |
|
||||
| | Keypoint head: [B,17,H,W] | |
|
||||
| | Part/UV head: [B,25,H,W] + [B,48,H,W] | |
|
||||
| +--------------------------------------------------------+ |
|
||||
+-----------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 2.2 TDM Sensing Protocol
|
||||
|
||||
- Coordinator (aggregator) broadcasts sync beacon at start of each cycle.
|
||||
- Each node transmits in assigned time slot; all others receive.
|
||||
- 6 nodes x 1.4 ms/slot = 8.4 ms cycle -> ~119 Hz aggregate, ~20 Hz per bistatic pair.
|
||||
- Clock drift handled at feature level (no cross-node phase alignment).
|
||||
|
||||
### 2.3 Geometric Bias Matrix
|
||||
|
||||
The geometric bias `G_bias` encodes the spatial relationship between viewpoint pairs:
|
||||
|
||||
```
|
||||
G_bias[i,j] = w_angle * cos(theta_ij) + w_dist * exp(-d_ij / d_ref)
|
||||
```
|
||||
|
||||
where:
|
||||
|
||||
- `theta_ij` = angle between viewpoint i and viewpoint j (from room center)
|
||||
- `d_ij` = baseline distance between node i and node j
|
||||
- `w_angle`, `w_dist` = learnable weights
|
||||
- `d_ref` = reference distance (room diagonal / 2)
|
||||
|
||||
This allows the attention mechanism to learn that widely-separated, orthogonal viewpoints are more complementary than clustered ones.
|
||||
|
||||
### 2.4 Coherence-Gated Environment Updates
|
||||
|
||||
```rust
|
||||
/// Only update environment model when phase coherence exceeds threshold.
|
||||
pub fn coherence_gate(
|
||||
phase_diffs: &[f32], // delta-phi over T recent frames
|
||||
threshold: f32, // typically 0.7
|
||||
) -> bool {
|
||||
// Complex mean of unit phasors
|
||||
let (sum_cos, sum_sin) = phase_diffs.iter()
|
||||
.fold((0.0f32, 0.0f32), |(c, s), &dp| {
|
||||
(c + dp.cos(), s + dp.sin())
|
||||
});
|
||||
let n = phase_diffs.len() as f32;
|
||||
let coherence = ((sum_cos / n).powi(2) + (sum_sin / n).powi(2)).sqrt();
|
||||
coherence > threshold
|
||||
}
|
||||
```
|
||||
|
||||
### 2.5 Two Implementation Paths
|
||||
|
||||
| Path | Hardware | Bandwidth | Per-Viewpoint Rate | Target Tier |
|
||||
|------|----------|-----------|-------------------|-------------|
|
||||
| **ESP32 Multistatic** | 6x ESP32-S3 ($84) | 20 MHz (HT20) | 20 Hz | Silver |
|
||||
| **Cognitum + RF** | Cognitum v1 + LimeSDR | 20-160 MHz | 20-100 Hz | Gold |
|
||||
|
||||
ESP32 path: commodity, achievable today, targets Silver tier (tracking + pose quality).
|
||||
Cognitum path: higher fidelity, targets Gold tier (tracking + pose + vitals).
|
||||
|
||||
---
|
||||
|
||||
## 3. DDD Design
|
||||
|
||||
### 3.1 New Bounded Context: ViewpointFusion
|
||||
|
||||
**Aggregate Root: `MultistaticArray`**
|
||||
|
||||
```rust
|
||||
pub struct MultistaticArray {
|
||||
/// Unique array deployment ID
|
||||
id: ArrayId,
|
||||
/// Viewpoint geometry (node positions, orientations)
|
||||
geometry: ArrayGeometry,
|
||||
/// TDM schedule (slot assignments, cycle period)
|
||||
schedule: TdmSchedule,
|
||||
/// Active viewpoint embeddings (latest per node)
|
||||
viewpoints: Vec<ViewpointEmbedding>,
|
||||
/// Fused output embedding
|
||||
fused: Option<FusedEmbedding>,
|
||||
/// Coherence gate state
|
||||
coherence_state: CoherenceState,
|
||||
}
|
||||
```
|
||||
|
||||
**Entity: `ViewpointEmbedding`**
|
||||
|
||||
```rust
|
||||
pub struct ViewpointEmbedding {
|
||||
/// Source node ID
|
||||
node_id: NodeId,
|
||||
/// AETHER embedding vector (128-d)
|
||||
embedding: Vec<f32>,
|
||||
/// Geometric metadata
|
||||
azimuth: f32, // radians from array center
|
||||
elevation: f32, // radians
|
||||
baseline: f32, // meters from centroid
|
||||
/// Capture timestamp
|
||||
timestamp: Instant,
|
||||
/// Signal quality
|
||||
snr_db: f32,
|
||||
}
|
||||
```
|
||||
|
||||
**Value Object: `GeometricDiversityIndex`**
|
||||
|
||||
```rust
|
||||
pub struct GeometricDiversityIndex {
|
||||
/// GDI = (1/N) sum min_{j!=i} |theta_i - theta_j|
|
||||
value: f32,
|
||||
/// Effective independent viewpoints (after correlation discount)
|
||||
n_effective: f32,
|
||||
/// Worst viewpoint pair (most redundant)
|
||||
worst_pair: (NodeId, NodeId),
|
||||
}
|
||||
```
|
||||
|
||||
**Domain Events:**
|
||||
|
||||
```rust
|
||||
pub enum ViewpointFusionEvent {
|
||||
ViewpointCaptured { node_id: NodeId, timestamp: Instant, snr_db: f32 },
|
||||
TdmCycleCompleted { cycle_id: u64, viewpoints_received: usize },
|
||||
FusionCompleted { fused_embedding: Vec<f32>, gdi: f32 },
|
||||
CoherenceGateTriggered { coherence: f32, accepted: bool },
|
||||
GeometryUpdated { new_gdi: f32, n_effective: f32 },
|
||||
}
|
||||
```
|
||||
|
||||
### 3.2 Extended Bounded Contexts
|
||||
|
||||
**Signal (wifi-densepose-signal):**
|
||||
- New service: `CrossViewpointSubcarrierSelection`
|
||||
- Consensus sensitive subcarrier set across all viewpoints via ruvector-mincut.
|
||||
- Input: per-viewpoint sensitivity scores. Output: globally-sensitive + locally-sensitive partition.
|
||||
|
||||
**Hardware (wifi-densepose-hardware):**
|
||||
- New protocol: `TdmSensingProtocol`
|
||||
- Coordinator logic: beacon generation, slot scheduling, clock drift compensation.
|
||||
- Event: `TdmSlotCompleted { node_id, slot_index, capture_quality }`
|
||||
|
||||
**Training (wifi-densepose-train):**
|
||||
- New module: `ruview_metrics.rs`
|
||||
- Three-metric acceptance test: PCK/OKS (joint error), MOTA (multi-person separation), vital sign accuracy.
|
||||
- Tiered pass/fail: Bronze/Silver/Gold.
|
||||
|
||||
---
|
||||
|
||||
## 4. Implementation Plan (File-Level)
|
||||
|
||||
### 4.1 Phase 1: ViewpointFusion Core (New Files)
|
||||
|
||||
| File | Purpose | RuVector Crate |
|
||||
|------|---------|---------------|
|
||||
| `crates/wifi-densepose-ruvector/src/viewpoint/mod.rs` | Module root, re-exports | -- |
|
||||
| `crates/wifi-densepose-ruvector/src/viewpoint/attention.rs` | Cross-viewpoint scaled dot-product attention with geometric bias | ruvector-attention |
|
||||
| `crates/wifi-densepose-ruvector/src/viewpoint/geometry.rs` | GeometricDiversityIndex, Cramer-Rao bound estimation | ruvector-solver |
|
||||
| `crates/wifi-densepose-ruvector/src/viewpoint/coherence.rs` | Coherence gating for environment stability | -- (pure math) |
|
||||
| `crates/wifi-densepose-ruvector/src/viewpoint/fusion.rs` | MultistaticArray aggregate, orchestrates fusion pipeline | ruvector-attention + ruvector-attn-mincut |
|
||||
|
||||
### 4.2 Phase 2: Signal Processing Extension
|
||||
|
||||
| File | Purpose | RuVector Crate |
|
||||
|------|---------|---------------|
|
||||
| `crates/wifi-densepose-signal/src/cross_viewpoint.rs` | Cross-viewpoint subcarrier consensus via min-cut | ruvector-mincut |
|
||||
|
||||
### 4.3 Phase 3: Hardware Protocol Extension
|
||||
|
||||
| File | Purpose | RuVector Crate |
|
||||
|------|---------|---------------|
|
||||
| `crates/wifi-densepose-hardware/src/esp32/tdm.rs` | TDM sensing protocol coordinator | -- (protocol logic) |
|
||||
|
||||
### 4.4 Phase 4: Training and Metrics
|
||||
|
||||
| File | Purpose | RuVector Crate |
|
||||
|------|---------|---------------|
|
||||
| `crates/wifi-densepose-train/src/ruview_metrics.rs` | Three-metric acceptance test (PCK/OKS, MOTA, vital sign accuracy) | ruvector-mincut (person matching) |
|
||||
|
||||
---
|
||||
|
||||
## 5. Three-Metric Acceptance Test
|
||||
|
||||
### 5.1 Metric 1: Joint Error (PCK / OKS)
|
||||
|
||||
| Criterion | Threshold |
|
||||
|-----------|-----------|
|
||||
| PCK@0.2 (all 17 keypoints) | >= 0.70 |
|
||||
| PCK@0.2 (torso: shoulders + hips) | >= 0.80 |
|
||||
| Mean OKS | >= 0.50 |
|
||||
| Torso jitter RMS (10s window) | < 3 cm |
|
||||
| Per-keypoint max error (95th percentile) | < 15 cm |
|
||||
|
||||
### 5.2 Metric 2: Multi-Person Separation
|
||||
|
||||
| Criterion | Threshold |
|
||||
|-----------|-----------|
|
||||
| Subjects | 2 |
|
||||
| Capture rate | 20 Hz |
|
||||
| Track duration | 10 minutes |
|
||||
| Identity swaps (MOTA ID-switch) | 0 |
|
||||
| Track fragmentation ratio | < 0.05 |
|
||||
| False track creation | 0/min |
|
||||
|
||||
### 5.3 Metric 3: Vital Sign Sensitivity
|
||||
|
||||
| Criterion | Threshold |
|
||||
|-----------|-----------|
|
||||
| Breathing detection (6-30 BPM) | +/- 2 BPM |
|
||||
| Breathing band SNR (0.1-0.5 Hz) | >= 6 dB |
|
||||
| Heartbeat detection (40-120 BPM) | +/- 5 BPM (aspirational) |
|
||||
| Heartbeat band SNR (0.8-2.0 Hz) | >= 3 dB (aspirational) |
|
||||
| Micro-motion resolution | 1 mm at 3m |
|
||||
|
||||
### 5.4 Tiered Pass/Fail
|
||||
|
||||
| Tier | Requirements | Deployment Gate |
|
||||
|------|-------------|-----------------|
|
||||
| Bronze | Metric 2 | Prototype demo |
|
||||
| Silver | Metrics 1 + 2 | Production candidate |
|
||||
| Gold | All three | Full deployment |
|
||||
|
||||
---
|
||||
|
||||
## 6. Consequences
|
||||
|
||||
### 6.1 Positive
|
||||
|
||||
- **Fundamental geometric improvement**: Viewpoint diversity reduces body self-occlusion and depth ambiguity -- these are physics, not model, limitations.
|
||||
- **Uses existing silicon**: ESP32-S3, commodity WiFi, no custom RF hardware required for Silver tier.
|
||||
- **Learned fusion weights**: Embedding-level fusion (Tier 3) outperforms hand-crafted feature-level fusion (Tier 2).
|
||||
- **Composes with existing ADRs**: AETHER (per-viewpoint), MERIDIAN (cross-environment), and RuView (cross-viewpoint) are orthogonal -- they compose freely.
|
||||
- **IEEE 802.11bf aligned**: TDM protocol maps to 802.11bf sensing sessions, enabling future migration to standard-compliant APs.
|
||||
- **Commodity price point**: $84 for 6-node Silver-tier deployment.
|
||||
|
||||
### 6.2 Negative
|
||||
|
||||
- **TDM rate reduction**: N viewpoints leads to per-viewpoint rate divided by N. With 6 nodes at 120 Hz aggregate, each viewpoint sees 20 Hz.
|
||||
- **More complex aggregator**: Embedding fusion + geometric bias learning adds ~25K parameters on top of per-viewpoint AETHER model.
|
||||
- **Placement planning required**: Geometric Diversity Index optimization requires intentional node placement (not random scatter).
|
||||
- **Clock drift limits TDM precision**: ESP32 crystal drift (20-50 ppm) limits slot precision to ~1 ms, which is sufficient for feature-level fusion but not signal-level coherent combining.
|
||||
- **Training data**: Cross-viewpoint training requires multi-receiver CSI captures, which are not available in existing public datasets (MM-Fi, Wi-Pose).
|
||||
|
||||
### 6.3 Interaction with Other ADRs
|
||||
|
||||
| ADR | Interaction |
|
||||
|-----|------------|
|
||||
| ADR-012 (ESP32 Mesh) | RuView extends the aggregator from feature-level to embedding-level fusion; TDM protocol replaces simple UDP collection |
|
||||
| ADR-014 (SOTA Signal) | Per-viewpoint signal processing is unchanged; cross-viewpoint subcarrier consensus is new |
|
||||
| ADR-016/017 (RuVector) | All 5 ruvector crates get new cross-viewpoint operations (see Section 4) |
|
||||
| ADR-021 (Vital Signs) | Multi-viewpoint SNR improvement directly benefits vital sign extraction (Gold tier target) |
|
||||
| ADR-024 (AETHER) | Per-viewpoint AETHER embeddings are the input to RuView fusion; AETHER is required |
|
||||
| ADR-027 (MERIDIAN) | Cross-environment (MERIDIAN) and cross-viewpoint (RuView) are orthogonal; MERIDIAN handles room transfer, RuView handles within-room geometry |
|
||||
|
||||
---
|
||||
|
||||
## 7. References
|
||||
|
||||
1. IEEE 802.11bf (2024). "WLAN Sensing." IEEE Standards Association.
|
||||
2. Kotaru, M. et al. (2015). "SpotFi: Decimeter Level Localization Using WiFi." SIGCOMM 2015.
|
||||
3. Zeng, Y. et al. (2019). "FarSense: Pushing the Range Limit of WiFi-based Respiration Sensing with CSI Ratio of Two Antennas." MobiCom 2019.
|
||||
4. Zheng, Y. et al. (2019). "Zero-Effort Cross-Domain Gesture Recognition with Wi-Fi." (Widar 3.0) MobiSys 2019.
|
||||
5. Yan, K. et al. (2024). "Person-in-WiFi 3D: End-to-End Multi-Person 3D Pose Estimation with Wi-Fi." CVPR 2024.
|
||||
6. Zhou, Y. et al. (2024). "AdaPose: Towards Cross-Site Device-Free Human Pose Estimation with Commodity WiFi." IEEE IoT Journal. arXiv:2309.16964.
|
||||
7. Zhou, R. et al. (2025). "DGSense: A Domain Generalization Framework for Wireless Sensing." arXiv:2502.08155.
|
||||
8. Chen, X. & Yang, J. (2025). "X-Fi: A Modality-Invariant Foundation Model for Multimodal Human Sensing." ICLR 2025. arXiv:2410.10167.
|
||||
9. AM-FM (2026). "AM-FM: A Foundation Model for Ambient Intelligence Through WiFi." arXiv:2602.11200.
|
||||
10. Chen, L. et al. (2026). "PerceptAlign: Breaking Coordinate Overfitting." arXiv:2601.12252.
|
||||
11. Li, J. & Stoica, P. (2007). "MIMO Radar with Colocated Antennas." IEEE Signal Processing Magazine, 24(5):106-114.
|
||||
12. ADR-012 through ADR-027 (internal).
|
||||
507
docs/adr/ADR-032-multistatic-mesh-security-hardening.md
Normal file
507
docs/adr/ADR-032-multistatic-mesh-security-hardening.md
Normal file
@@ -0,0 +1,507 @@
|
||||
# ADR-032: Multistatic Mesh Security Hardening
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Accepted |
|
||||
| **Date** | 2026-03-01 |
|
||||
| **Deciders** | ruv |
|
||||
| **Relates to** | ADR-029 (RuvSense Multistatic), ADR-030 (Persistent Field Model), ADR-031 (RuView Sensing-First RF), ADR-018 (ESP32 Implementation), ADR-012 (ESP32 Mesh) |
|
||||
|
||||
---
|
||||
|
||||
## 1. Context
|
||||
|
||||
### 1.1 Security Audit of ADR-029/030/031
|
||||
|
||||
A security audit of the RuvSense multistatic sensing stack (ADR-029 through ADR-031) identified seven findings across the TDM synchronization layer, CSI frame transport, NDP injection, coherence gating, cross-room tracking, NVS credential handling, and firmware concurrency model. Three severity levels were assigned: HIGH (1 finding), MEDIUM (3 findings), LOW (3 findings).
|
||||
|
||||
The findings fall into three categories:
|
||||
|
||||
1. **Missing cryptographic authentication** -- The TDM SyncBeacon and CSI frame formats lack any message authentication, allowing rogue nodes to inject spoofed beacons or frames into the mesh.
|
||||
2. **Unbounded or unprotected resources** -- The NDP injection path has no rate limiter, the coherence gate recalibration state has no timeout cap, and the cross-room transition log grows without bound.
|
||||
3. **Memory safety on embedded targets** -- NVS credential buffers are not zeroed after use, and static mutable globals in the CSI collector are accessed from both ESP32-S3 cores without synchronization.
|
||||
|
||||
### 1.2 Threat Model
|
||||
|
||||
The primary threat actor is a rogue ESP32 node on the same LAN subnet or within WiFi range of the mesh. The attack surface is the UDP broadcast plane used for sync beacons, CSI frames, and NDP injection.
|
||||
|
||||
| Threat | STRIDE | Impact | Exploitability |
|
||||
|--------|--------|--------|----------------|
|
||||
| Fake SyncBeacon injection | Spoofing, Tampering | Full mesh desynchronization, no pose output | Low skill, rogue ESP32 on LAN |
|
||||
| CSI frame spoofing | Spoofing, Tampering | Corrupted pose estimation, phantom occupants | Low skill, UDP packet injection |
|
||||
| NDP RF flooding | Denial of Service | Channel saturation, loss of CSI data | Low skill, repeated NDP calls |
|
||||
| Coherence gate stall | Denial of Service | Indefinite recalibration, frozen output | Requires sustained interference |
|
||||
| Transition log exhaustion | Denial of Service | OOM on aggregator after extended operation | Passive, no attacker needed |
|
||||
| Credential stack residue | Information Disclosure | WiFi password recoverable from RAM dump | Physical access to device |
|
||||
| Dual-core data race | Tampering, DoS | Corrupted CSI frames, undefined behavior | Passive, no attacker needed |
|
||||
|
||||
### 1.3 Design Constraints
|
||||
|
||||
- ESP32-S3 has limited CPU budget: cryptographic operations must complete within the 1 ms guard interval between TDM slots.
|
||||
- HMAC-SHA256 on ESP32-S3 (hardware-accelerated via `mbedtls`) completes in approximately 15 us for 24-byte payloads -- well within budget.
|
||||
- SipHash-2-4 completes in approximately 2 us for 64-byte payloads on ESP32-S3 -- suitable for per-frame MAC.
|
||||
- No TLS or TCP is available on the sensing data path (UDP broadcast for latency).
|
||||
- Pre-shared key (PSK) model is acceptable because all nodes in a mesh deployment are provisioned by the same operator.
|
||||
|
||||
---
|
||||
|
||||
## 2. Decision
|
||||
|
||||
Harden the multistatic mesh with six measures: beacon authentication, frame integrity, NDP rate limiting, bounded buffers, memory safety, and key management. All changes are backward-compatible: unauthenticated frames are accepted during a migration window controlled by a `security_level` NVS parameter.
|
||||
|
||||
### 2.1 Beacon Authentication Protocol (H-1)
|
||||
|
||||
**Finding:** The 16-byte `SyncBeacon` wire format (`crates/wifi-densepose-hardware/src/esp32/tdm.rs`) has no cryptographic authentication. A rogue node can inject fake beacons to desynchronize the TDM mesh.
|
||||
|
||||
**Solution:** Extend the SyncBeacon wire format from 16 bytes to 28 bytes by adding a 4-byte monotonic nonce and an 8-byte HMAC-SHA256 truncated tag.
|
||||
|
||||
```
|
||||
Authenticated SyncBeacon wire format (28 bytes):
|
||||
[0..7] cycle_id (LE u64)
|
||||
[8..11] cycle_period_us (LE u32)
|
||||
[12..13] drift_correction (LE i16)
|
||||
[14..15] reserved
|
||||
[16..19] nonce (LE u32, monotonically increasing)
|
||||
[20..27] hmac_tag (HMAC-SHA256 truncated to 8 bytes)
|
||||
```
|
||||
|
||||
**HMAC computation:**
|
||||
|
||||
```
|
||||
key = 16-byte pre-shared mesh key (stored in NVS, namespace "mesh_sec")
|
||||
message = beacon[0..20] (first 20 bytes: payload + nonce)
|
||||
tag = HMAC-SHA256(key, message)[0..8] (truncated to 8 bytes)
|
||||
```
|
||||
|
||||
**Nonce and replay protection:**
|
||||
|
||||
- The coordinator maintains a monotonically increasing 32-bit nonce counter, incremented on every beacon.
|
||||
- Each receiver maintains a `last_accepted_nonce` per sender. A beacon is accepted only if `nonce > last_accepted_nonce - REPLAY_WINDOW`, where `REPLAY_WINDOW = 16` (accounts for packet reordering over UDP).
|
||||
- Nonce overflow (after 2^32 beacons at 20 Hz = ~6.8 years) triggers a mandatory key rotation.
|
||||
|
||||
**Implementation location:** `crates/wifi-densepose-hardware/src/esp32/tdm.rs` -- extend `SyncBeacon::to_bytes()` and `SyncBeacon::from_bytes()` to produce/consume the 28-byte authenticated format. Add `SyncBeacon::verify()` method.
|
||||
|
||||
### 2.2 CSI Frame Integrity (M-3)
|
||||
|
||||
**Finding:** The ADR-018 CSI frame format has no cryptographic MAC. Frames can be spoofed or tampered with in transit.
|
||||
|
||||
**Solution:** Add an 8-byte SipHash-2-4 tag to the CSI frame header. SipHash is chosen over HMAC-SHA256 for per-frame MAC because it is 7x faster on ESP32 for short messages (approximately 2 us vs 15 us) and provides sufficient integrity for non-secret data.
|
||||
|
||||
```
|
||||
Extended CSI frame header (28 bytes, was 20):
|
||||
[0..3] Magic: 0xC5110002 (bumped from 0xC5110001 to signal auth)
|
||||
[4] Node ID
|
||||
[5] Number of antennas
|
||||
[6..7] Number of subcarriers (LE u16)
|
||||
[8..11] Frequency MHz (LE u32)
|
||||
[12..15] Sequence number (LE u32)
|
||||
[16] RSSI (i8)
|
||||
[17] Noise floor (i8)
|
||||
[18..19] Reserved
|
||||
[20..27] siphash_tag (SipHash-2-4 over [0..20] + IQ data)
|
||||
```
|
||||
|
||||
**SipHash key derivation:**
|
||||
|
||||
```
|
||||
siphash_key = HMAC-SHA256(mesh_key, "csi-frame-siphash")[0..16]
|
||||
```
|
||||
|
||||
The SipHash key is derived once at boot from the mesh key and cached in memory.
|
||||
|
||||
**Implementation locations:**
|
||||
- `firmware/esp32-csi-node/main/csi_collector.c` -- compute SipHash tag in `csi_serialize_frame()`, bump magic constant.
|
||||
- `crates/wifi-densepose-hardware/src/esp32/` -- add frame verification in the aggregator's frame parser.
|
||||
|
||||
### 2.3 NDP Injection Rate Limiter (M-4)
|
||||
|
||||
**Finding:** `csi_inject_ndp_frame()` in `firmware/esp32-csi-node/main/csi_collector.c` has no rate limiter. Uncontrolled NDP injection can flood the RF channel.
|
||||
|
||||
**Solution:** Token-bucket rate limiter with configurable parameters stored in NVS.
|
||||
|
||||
```c
|
||||
// Token bucket parameters (defaults)
|
||||
#define NDP_RATE_MAX_TOKENS 20 // burst capacity
|
||||
#define NDP_RATE_REFILL_HZ 20 // sustained rate: 20 NDP/sec
|
||||
#define NDP_RATE_REFILL_US (1000000 / NDP_RATE_REFILL_HZ)
|
||||
|
||||
typedef struct {
|
||||
uint32_t tokens; // current token count
|
||||
uint32_t max_tokens; // bucket capacity
|
||||
uint32_t refill_interval_us; // microseconds per token
|
||||
int64_t last_refill_us; // last refill timestamp
|
||||
} ndp_rate_limiter_t;
|
||||
```
|
||||
|
||||
`csi_inject_ndp_frame()` returns `ESP_ERR_NOT_ALLOWED` when the bucket is empty. The rate limiter parameters are configurable via NVS keys `ndp_max_tokens` and `ndp_refill_hz`.
|
||||
|
||||
**Implementation location:** `firmware/esp32-csi-node/main/csi_collector.c` -- add `ndp_rate_limiter_t` state and check in `csi_inject_ndp_frame()`.
|
||||
|
||||
### 2.4 Coherence Gate Recalibration Timeout (M-5)
|
||||
|
||||
**Finding:** The `Recalibrate` state in `crates/wifi-densepose-signal/src/ruvsense/coherence_gate.rs` can be held indefinitely. A sustained interference source could keep the system in perpetual recalibration, preventing any output.
|
||||
|
||||
**Solution:** Add a configurable `max_recalibrate_duration` to `GatePolicyConfig` (default: 30 seconds = 600 frames at 20 Hz). When the recalibration duration exceeds this cap, the gate transitions to a `ForcedAccept` state with inflated noise (10x), allowing degraded-but-available output.
|
||||
|
||||
```rust
|
||||
pub enum GateDecision {
|
||||
Accept { noise_multiplier: f32 },
|
||||
PredictOnly,
|
||||
Reject,
|
||||
Recalibrate { stale_frames: u64 },
|
||||
/// Recalibration timed out. Accept with heavily inflated noise.
|
||||
ForcedAccept { noise_multiplier: f32, stale_frames: u64 },
|
||||
}
|
||||
```
|
||||
|
||||
New config field:
|
||||
|
||||
```rust
|
||||
pub struct GatePolicyConfig {
|
||||
// ... existing fields ...
|
||||
/// Maximum frames in Recalibrate before forcing accept. Default: 600 (30s at 20Hz).
|
||||
pub max_recalibrate_frames: u64,
|
||||
/// Noise multiplier for ForcedAccept. Default: 10.0.
|
||||
pub forced_accept_noise: f32,
|
||||
}
|
||||
```
|
||||
|
||||
**Implementation location:** `crates/wifi-densepose-signal/src/ruvsense/coherence_gate.rs` -- extend `GateDecision` enum, modify `GatePolicy::evaluate()`.
|
||||
|
||||
### 2.5 Bounded Transition Log (L-1)
|
||||
|
||||
**Finding:** `CrossRoomTracker` in `crates/wifi-densepose-signal/src/ruvsense/cross_room.rs` stores transitions in an unbounded `Vec<TransitionEvent>`. Over extended operation (days/weeks), this grows without limit.
|
||||
|
||||
**Solution:** Replace the `transitions: Vec<TransitionEvent>` with a ring buffer that evicts the oldest entry when capacity is reached.
|
||||
|
||||
```rust
|
||||
pub struct CrossRoomConfig {
|
||||
// ... existing fields ...
|
||||
/// Maximum transitions retained in the ring buffer. Default: 1000.
|
||||
pub max_transitions: usize,
|
||||
}
|
||||
```
|
||||
|
||||
The ring buffer is implemented as a `VecDeque<TransitionEvent>` with a capacity check on push. When `transitions.len() >= max_transitions`, `transitions.pop_front()` before pushing. This preserves the append-only audit trail semantics (events are never mutated, only evicted by age).
|
||||
|
||||
**Implementation location:** `crates/wifi-densepose-signal/src/ruvsense/cross_room.rs` -- change `transitions: Vec<TransitionEvent>` to `transitions: VecDeque<TransitionEvent>`, add eviction logic in `match_entry()`.
|
||||
|
||||
### 2.6 NVS Password Buffer Zeroing (L-4)
|
||||
|
||||
**Finding:** `nvs_config_load()` in `firmware/esp32-csi-node/main/nvs_config.c` reads the WiFi password into a stack buffer `buf` which is not zeroed after use. On ESP32-S3, stack memory is not automatically cleared, leaving credentials recoverable via physical memory dump.
|
||||
|
||||
**Solution:** Zero the stack buffer after each NVS string read using `explicit_bzero()` (available in ESP-IDF via newlib). If `explicit_bzero` is unavailable, use `memset` with a volatile pointer to prevent compiler optimization.
|
||||
|
||||
```c
|
||||
/* After each nvs_get_str that may contain credentials: */
|
||||
explicit_bzero(buf, sizeof(buf));
|
||||
|
||||
/* Portable fallback: */
|
||||
static void secure_zero(void *ptr, size_t len) {
|
||||
volatile unsigned char *p = (volatile unsigned char *)ptr;
|
||||
while (len--) { *p++ = 0; }
|
||||
}
|
||||
```
|
||||
|
||||
Apply to all three `nvs_get_str` call sites in `nvs_config_load()` (ssid, password, target_ip).
|
||||
|
||||
**Implementation location:** `firmware/esp32-csi-node/main/nvs_config.c` -- add `explicit_bzero(buf, sizeof(buf))` after each `nvs_get_str` block.
|
||||
|
||||
### 2.7 Atomic Access for Static Mutable State (L-5)
|
||||
|
||||
**Finding:** `csi_collector.c` uses static mutable globals (`s_sequence`, `s_cb_count`, `s_send_ok`, `s_send_fail`, `s_hop_index`) accessed from both cores of the ESP32-S3 without synchronization. The CSI callback runs on the WiFi task (pinned to core 0 by default), while the main application and hop timer may run on core 1.
|
||||
|
||||
**Solution:** Use C11 `_Atomic` qualifiers for all shared counters, and a FreeRTOS mutex for the hop table state which requires multi-variable consistency.
|
||||
|
||||
```c
|
||||
#include <stdatomic.h>
|
||||
|
||||
static _Atomic uint32_t s_sequence = 0;
|
||||
static _Atomic uint32_t s_cb_count = 0;
|
||||
static _Atomic uint32_t s_send_ok = 0;
|
||||
static _Atomic uint32_t s_send_fail = 0;
|
||||
static _Atomic uint8_t s_hop_index = 0;
|
||||
|
||||
/* Hop table protected by mutex (multi-variable consistency) */
|
||||
static SemaphoreHandle_t s_hop_mutex = NULL;
|
||||
```
|
||||
|
||||
The mutex is created in `csi_collector_init()` and taken/released around hop table reads in `csi_hop_next_channel()` and writes in `csi_collector_set_hop_table()`.
|
||||
|
||||
**Implementation location:** `firmware/esp32-csi-node/main/csi_collector.c` -- add `_Atomic` qualifiers, create and use `s_hop_mutex`.
|
||||
|
||||
### 2.8 Key Management
|
||||
|
||||
All cryptographic operations use a single 16-byte pre-shared mesh key stored in NVS.
|
||||
|
||||
**Provisioning:**
|
||||
|
||||
```
|
||||
NVS namespace: "mesh_sec"
|
||||
NVS key: "mesh_key"
|
||||
NVS type: blob (16 bytes)
|
||||
```
|
||||
|
||||
The key is provisioned during node setup via the existing `scripts/provision.py` tool, which is extended to generate a random 16-byte key and flash it to all nodes in a deployment.
|
||||
|
||||
**Key derivation:**
|
||||
|
||||
```
|
||||
beacon_hmac_key = mesh_key (direct, 16 bytes)
|
||||
frame_siphash_key = HMAC-SHA256(mesh_key, "csi-frame-siphash")[0..16] (derived, 16 bytes)
|
||||
```
|
||||
|
||||
**Key rotation:**
|
||||
|
||||
- Manual rotation via management command: `provision.py rotate-key --deployment <id>`.
|
||||
- The coordinator broadcasts a key rotation event (signed with the old key) containing the new key encrypted with the old key.
|
||||
- Nodes accept the new key and switch after confirming the next beacon is signed with the new key.
|
||||
- Rotation is recommended every 90 days or after any node is decommissioned.
|
||||
|
||||
**Security level NVS parameter:**
|
||||
|
||||
```
|
||||
NVS key: "sec_level"
|
||||
Values:
|
||||
0 = permissive (accept unauthenticated frames, log warning)
|
||||
1 = transitional (accept both authenticated and unauthenticated)
|
||||
2 = enforcing (reject unauthenticated frames)
|
||||
Default: 1 (transitional, for backward compatibility during rollout)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Implementation Plan (File-Level)
|
||||
|
||||
### 3.1 Phase 1: Beacon Authentication and Key Management
|
||||
|
||||
| File | Change | Priority |
|
||||
|------|--------|----------|
|
||||
| `crates/wifi-densepose-hardware/src/esp32/tdm.rs` | Extend `SyncBeacon` to 28-byte authenticated format, add `verify()`, nonce tracking, replay window | P0 |
|
||||
| `firmware/esp32-csi-node/main/nvs_config.c` | Add `mesh_key` and `sec_level` NVS reads | P0 |
|
||||
| `firmware/esp32-csi-node/main/nvs_config.h` | Add `mesh_key[16]` and `sec_level` to `nvs_config_t` | P0 |
|
||||
| `scripts/provision.py` | Add `--mesh-key` generation and `rotate-key` command | P0 |
|
||||
|
||||
### 3.2 Phase 2: Frame Integrity and Rate Limiting
|
||||
|
||||
| File | Change | Priority |
|
||||
|------|--------|----------|
|
||||
| `firmware/esp32-csi-node/main/csi_collector.c` | Add SipHash-2-4 tag to frame serialization, NDP rate limiter, `_Atomic` qualifiers, hop mutex | P1 |
|
||||
| `firmware/esp32-csi-node/main/csi_collector.h` | Update `CSI_HEADER_SIZE` to 28, add rate limiter config | P1 |
|
||||
| `crates/wifi-densepose-hardware/src/esp32/` | Add frame verification in aggregator parser | P1 |
|
||||
|
||||
### 3.3 Phase 3: Bounded Buffers and Gate Hardening
|
||||
|
||||
| File | Change | Priority |
|
||||
|------|--------|----------|
|
||||
| `crates/wifi-densepose-signal/src/ruvsense/cross_room.rs` | Replace `Vec` with `VecDeque`, add `max_transitions` config | P1 |
|
||||
| `crates/wifi-densepose-signal/src/ruvsense/coherence_gate.rs` | Add `ForcedAccept` variant, `max_recalibrate_frames` config | P1 |
|
||||
|
||||
### 3.4 Phase 4: Memory Safety
|
||||
|
||||
| File | Change | Priority |
|
||||
|------|--------|----------|
|
||||
| `firmware/esp32-csi-node/main/nvs_config.c` | Add `explicit_bzero()` after credential reads | P2 |
|
||||
| `firmware/esp32-csi-node/main/csi_collector.c` | `_Atomic` counters, `s_hop_mutex` (if not done in Phase 2) | P2 |
|
||||
|
||||
---
|
||||
|
||||
## 4. Acceptance Criteria
|
||||
|
||||
### 4.1 Beacon Authentication (H-1)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| H1-1 | `SyncBeacon::to_bytes()` produces 28-byte output with valid HMAC tag | Unit test: serialize, verify tag matches recomputed HMAC |
|
||||
| H1-2 | `SyncBeacon::verify()` rejects beacons with incorrect HMAC tag | Unit test: flip one bit in tag, verify returns `Err` |
|
||||
| H1-3 | `SyncBeacon::verify()` rejects beacons with replayed nonce outside window | Unit test: submit nonce = last_accepted - REPLAY_WINDOW - 1, verify rejection |
|
||||
| H1-4 | `SyncBeacon::verify()` accepts beacons within replay window | Unit test: submit nonce = last_accepted - REPLAY_WINDOW + 1, verify acceptance |
|
||||
| H1-5 | Coordinator nonce increments monotonically across cycles | Unit test: call `begin_cycle()` 100 times, verify strict monotonicity |
|
||||
| H1-6 | Backward compatibility: `sec_level=0` accepts unauthenticated 16-byte beacons | Integration test: mixed old/new nodes |
|
||||
|
||||
### 4.2 Frame Integrity (M-3)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| M3-1 | CSI frame with magic `0xC5110002` includes valid 8-byte SipHash tag | Unit test: serialize frame, verify tag |
|
||||
| M3-2 | Frame verification rejects frames with tampered IQ data | Unit test: flip one byte in IQ payload, verify rejection |
|
||||
| M3-3 | SipHash computation completes in < 10 us on ESP32-S3 | Benchmark on target hardware |
|
||||
| M3-4 | Frame parser accepts old magic `0xC5110001` when `sec_level < 2` | Unit test: backward compatibility |
|
||||
|
||||
### 4.3 NDP Rate Limiter (M-4)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| M4-1 | `csi_inject_ndp_frame()` succeeds for first `max_tokens` calls | Unit test: call 20 times rapidly, all succeed |
|
||||
| M4-2 | Call 21 returns `ESP_ERR_NOT_ALLOWED` when bucket is empty | Unit test: exhaust bucket, verify error |
|
||||
| M4-3 | Bucket refills at configured rate | Unit test: exhaust, wait `refill_interval_us`, verify one token available |
|
||||
| M4-4 | NVS override of `ndp_max_tokens` and `ndp_refill_hz` is respected | Integration test: set NVS values, verify behavior |
|
||||
|
||||
### 4.4 Coherence Gate Timeout (M-5)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| M5-1 | `GatePolicy::evaluate()` returns `Recalibrate` at `max_stale_frames` | Unit test: existing behavior preserved |
|
||||
| M5-2 | `GatePolicy::evaluate()` returns `ForcedAccept` at `max_recalibrate_frames` | Unit test: feed `max_recalibrate_frames + 1` low-coherence frames |
|
||||
| M5-3 | `ForcedAccept` noise multiplier equals `forced_accept_noise` (default 10.0) | Unit test: verify noise_multiplier field |
|
||||
| M5-4 | Default `max_recalibrate_frames` = 600 (30s at 20 Hz) | Unit test: verify default config |
|
||||
|
||||
### 4.5 Bounded Transition Log (L-1)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| L1-1 | `CrossRoomTracker::transition_count()` never exceeds `max_transitions` | Unit test: insert 1500 transitions with max_transitions=1000, verify count=1000 |
|
||||
| L1-2 | Oldest transitions are evicted first (FIFO) | Unit test: verify first transition is the (N-999)th inserted |
|
||||
| L1-3 | Default `max_transitions` = 1000 | Unit test: verify default config |
|
||||
|
||||
### 4.6 NVS Password Zeroing (L-4)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| L4-1 | Stack buffer `buf` is zeroed after each `nvs_get_str` call | Code review + static analysis (no runtime test feasible) |
|
||||
| L4-2 | `explicit_bzero` is used (not plain `memset`) to prevent compiler optimization | Code review: verify function call is `explicit_bzero` or volatile-pointer pattern |
|
||||
|
||||
### 4.7 Atomic Static State (L-5)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| L5-1 | `s_sequence`, `s_cb_count`, `s_send_ok`, `s_send_fail` are declared `_Atomic` | Code review |
|
||||
| L5-2 | `s_hop_mutex` is created in `csi_collector_init()` | Code review + integration test: init succeeds |
|
||||
| L5-3 | `csi_hop_next_channel()` and `csi_collector_set_hop_table()` acquire/release mutex | Code review |
|
||||
| L5-4 | No data races detected under ThreadSanitizer (host-side test build) | `cargo test` with TSAN on host (for Rust side); QEMU or hardware test for C side |
|
||||
|
||||
---
|
||||
|
||||
## 5. Consequences
|
||||
|
||||
### 5.1 Positive
|
||||
|
||||
- **Rogue node protection**: HMAC-authenticated beacons prevent mesh desynchronization by unauthorized nodes.
|
||||
- **Frame integrity**: SipHash MAC detects in-transit tampering of CSI data, preventing phantom occupant injection.
|
||||
- **RF availability**: Token-bucket rate limiter prevents NDP flooding from consuming the shared wireless medium.
|
||||
- **Bounded memory**: Ring buffer on transition log and timeout cap on recalibration prevent resource exhaustion during long-running deployments.
|
||||
- **Credential hygiene**: Zeroed buffers reduce the window for credential recovery from physical memory access.
|
||||
- **Thread safety**: Atomic operations and mutex eliminate undefined behavior on dual-core ESP32-S3.
|
||||
- **Backward compatible**: `sec_level` parameter allows gradual rollout without breaking existing deployments.
|
||||
|
||||
### 5.2 Negative
|
||||
|
||||
- **12 bytes added to SyncBeacon**: 28 bytes vs 16 bytes (75% increase, but still fits in a single UDP packet with room to spare).
|
||||
- **8 bytes added to CSI frame header**: 28 bytes vs 20 bytes (40% increase in header; negligible relative to IQ payload of 128-512 bytes).
|
||||
- **CPU overhead**: HMAC-SHA256 adds approximately 15 us per beacon (once per 50 ms cycle = 0.03% CPU). SipHash adds approximately 2 us per frame (at 100 Hz = 0.02% CPU).
|
||||
- **Key management complexity**: Mesh key must be provisioned to all nodes and rotated periodically. Lost key requires re-provisioning all nodes.
|
||||
- **Mutex contention**: Hop table mutex may add up to 1 us latency to channel hop path. Within guard interval budget.
|
||||
|
||||
### 5.3 Risks
|
||||
|
||||
| Risk | Probability | Impact | Mitigation |
|
||||
|------|-------------|--------|------------|
|
||||
| HMAC computation exceeds guard interval on older ESP32 (non-S3) | Low | Beacon authentication unusable on legacy hardware | Hardware-accelerated SHA256 is available on all ESP32 variants; benchmark confirms < 50 us |
|
||||
| Key compromise via side-channel on ESP32 | Very Low | Full mesh authentication bypass | Keys stored in eFuse (ESP32-S3 supports) or encrypted NVS partition |
|
||||
| ForcedAccept mode produces unacceptably noisy poses | Medium | Degraded pose quality during sustained interference | 10x noise multiplier is configurable; operator can increase or disable |
|
||||
| SipHash collision (64-bit tag) | Very Low | Single forged frame accepted | 2^-64 probability per frame; attacker cannot iterate at protocol speed |
|
||||
|
||||
---
|
||||
|
||||
## 6. QUIC Transport Layer (ADR-032a Amendment)
|
||||
|
||||
### 6.1 Motivation
|
||||
|
||||
The original ADR-032 design (Sections 2.1--2.2) uses manual HMAC-SHA256 and SipHash-2-4 over plain UDP. While correct and efficient on constrained ESP32 hardware, this approach has operational drawbacks:
|
||||
|
||||
- **Manual key rotation**: Requires custom key exchange protocol and coordinator broadcast.
|
||||
- **No congestion control**: Plain UDP has no backpressure; burst CSI traffic can overwhelm the aggregator.
|
||||
- **No connection migration**: Node roaming (e.g., repositioning an ESP32) requires manual reconnect.
|
||||
- **Duplicate replay-window code**: Custom nonce tracking duplicates QUIC's built-in replay protection.
|
||||
|
||||
### 6.2 Decision: Adopt `midstreamer-quic` for Aggregator Uplinks
|
||||
|
||||
For aggregator-class nodes (Raspberry Pi, x86 gateway) that have sufficient CPU and memory, replace the manual crypto layer with `midstreamer-quic` v0.1.0, which provides:
|
||||
|
||||
| Capability | Manual (ADR-032 original) | QUIC (`midstreamer-quic`) |
|
||||
|---|---|---|
|
||||
| Authentication | HMAC-SHA256 truncated 8B | TLS 1.3 AEAD (AES-128-GCM) |
|
||||
| Frame integrity | SipHash-2-4 tag | QUIC packet-level AEAD |
|
||||
| Replay protection | Manual nonce + window | QUIC packet numbers (monotonic) |
|
||||
| Key rotation | Custom coordinator broadcast | TLS 1.3 `KeyUpdate` message |
|
||||
| Congestion control | None | QUIC cubic/BBR |
|
||||
| Connection migration | Not supported | QUIC connection ID migration |
|
||||
| Multi-stream | N/A | QUIC streams (beacon, CSI, control) |
|
||||
|
||||
**Constrained devices (ESP32-S3) retain the manual crypto path** from Sections 2.1--2.2 as a fallback. The `SecurityMode` enum selects the transport:
|
||||
|
||||
```rust
|
||||
pub enum SecurityMode {
|
||||
/// Manual HMAC/SipHash over plain UDP (ESP32-S3, ADR-032 original).
|
||||
ManualCrypto,
|
||||
/// QUIC transport with TLS 1.3 (aggregator-class nodes).
|
||||
QuicTransport,
|
||||
}
|
||||
```
|
||||
|
||||
### 6.3 QUIC Stream Mapping
|
||||
|
||||
Three dedicated QUIC streams separate traffic by priority:
|
||||
|
||||
| Stream ID | Purpose | Direction | Priority |
|
||||
|---|---|---|---|
|
||||
| 0 | Sync beacons | Coordinator -> Nodes | Highest (TDM timing-critical) |
|
||||
| 1 | CSI frames | Nodes -> Aggregator | High (sensing data) |
|
||||
| 2 | Control plane | Bidirectional | Normal (config, key rotation, health) |
|
||||
|
||||
### 6.4 Additional Midstreamer Integrations
|
||||
|
||||
Beyond QUIC transport, three additional midstreamer crates enhance the sensing pipeline:
|
||||
|
||||
1. **`midstreamer-scheduler` v0.1.0** -- Replaces manual timer-based TDM slot scheduling with an ultra-low-latency real-time task scheduler. Provides deterministic slot firing with sub-microsecond jitter.
|
||||
|
||||
2. **`midstreamer-temporal-compare` v0.1.0** -- Enhances gesture DTW matching (ADR-030 Tier 6) with temporal sequence comparison primitives. Provides optimized Sakoe-Chiba band DTW, LCS, and edit-distance kernels.
|
||||
|
||||
3. **`midstreamer-attractor` v0.1.0** -- Enhances longitudinal drift detection (ADR-030 Tier 4) with dynamical systems analysis. Detects phase-space attractor shifts that indicate biomechanical regime changes before they manifest as simple metric drift.
|
||||
|
||||
### 6.5 Fallback Strategy
|
||||
|
||||
The QUIC transport layer is additive, not a replacement:
|
||||
|
||||
- **ESP32-S3 nodes**: Continue using manual HMAC/SipHash over UDP (Sections 2.1--2.2). These devices lack the memory for a full TLS 1.3 stack.
|
||||
- **Aggregator nodes**: Use `midstreamer-quic` by default. Fall back to manual crypto if QUIC handshake fails (e.g., network partitions).
|
||||
- **Mixed deployments**: The aggregator auto-detects whether an incoming connection is QUIC (by TLS ClientHello) or plain UDP (by magic byte) and routes accordingly.
|
||||
|
||||
### 6.6 Acceptance Criteria (QUIC)
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| Q-1 | QUIC connection established between two nodes within 100ms | Integration test: connect, measure handshake time |
|
||||
| Q-2 | Beacon stream delivers beacons with < 1ms jitter | Unit test: send 1000 beacons, measure inter-arrival variance |
|
||||
| Q-3 | CSI stream achieves >= 95% of plain UDP throughput | Benchmark: criterion comparison |
|
||||
| Q-4 | Connection migration succeeds after simulated IP change | Integration test: rebind, verify stream continuity |
|
||||
| Q-5 | Fallback to manual crypto when QUIC unavailable | Unit test: reject QUIC, verify ManualCrypto path |
|
||||
| Q-6 | SecurityMode::ManualCrypto produces identical wire format to ADR-032 original | Unit test: byte-level comparison |
|
||||
|
||||
---
|
||||
|
||||
## 7. Related ADRs
|
||||
|
||||
| ADR | Relationship |
|
||||
|-----|-------------|
|
||||
| ADR-029 (RuvSense Multistatic) | **Hardened**: TDM beacon and CSI frame authentication, NDP rate limiting, QUIC transport |
|
||||
| ADR-030 (Persistent Field Model) | **Protected**: Coherence gate timeout; transition log bounded; gesture DTW enhanced (midstreamer-temporal-compare); drift detection enhanced (midstreamer-attractor) |
|
||||
| ADR-031 (RuView RF Mode) | **Hardened**: Authenticated beacons protect cross-viewpoint synchronization via QUIC streams |
|
||||
| ADR-018 (ESP32 Implementation) | **Extended**: CSI frame header bumped to v2 with SipHash tag; backward-compatible magic check |
|
||||
| ADR-012 (ESP32 Mesh) | **Hardened**: Mesh key management, NVS credential zeroing, atomic firmware state, QUIC connection migration |
|
||||
|
||||
---
|
||||
|
||||
## 8. References
|
||||
|
||||
1. Aumasson, J.-P. & Bernstein, D.J. (2012). "SipHash: a fast short-input PRF." INDOCRYPT 2012.
|
||||
2. Krawczyk, H. et al. (1997). "HMAC: Keyed-Hashing for Message Authentication." RFC 2104.
|
||||
3. ESP-IDF mbedtls SHA256 hardware acceleration. Espressif Documentation.
|
||||
4. Espressif. "ESP32-S3 Technical Reference Manual." Section 26: SHA Accelerator.
|
||||
5. Turner, J. (2006). "Token Bucket Rate Limiting." RFC 2697 (adapted).
|
||||
6. ADR-029 through ADR-031 (internal).
|
||||
7. `midstreamer-quic` v0.1.0 -- QUIC multi-stream support. crates.io.
|
||||
8. `midstreamer-scheduler` v0.1.0 -- Ultra-low-latency real-time task scheduler. crates.io.
|
||||
9. `midstreamer-temporal-compare` v0.1.0 -- Temporal sequence comparison. crates.io.
|
||||
10. `midstreamer-attractor` v0.1.0 -- Dynamical systems analysis. crates.io.
|
||||
11. Iyengar, J. & Thomson, M. (2021). "QUIC: A UDP-Based Multiplexed and Secure Transport." RFC 9000.
|
||||
740
docs/adr/ADR-033-crv-signal-line-sensing-integration.md
Normal file
740
docs/adr/ADR-033-crv-signal-line-sensing-integration.md
Normal file
@@ -0,0 +1,740 @@
|
||||
# ADR-033: CRV Signal Line Sensing Integration -- Mapping 6-Stage Coordinate Remote Viewing to WiFi-DensePose Pipeline
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Proposed |
|
||||
| **Date** | 2026-03-01 |
|
||||
| **Deciders** | ruv |
|
||||
| **Codename** | **CRV-Sense** -- Coordinate Remote Viewing Signal Line for WiFi Sensing |
|
||||
| **Relates to** | ADR-016 (RuVector Integration), ADR-017 (RuVector Signal+MAT), ADR-024 (AETHER Embeddings), ADR-029 (RuvSense Multistatic), ADR-030 (Persistent Field Model), ADR-031 (RuView Viewpoint Fusion), ADR-032 (Mesh Security) |
|
||||
|
||||
---
|
||||
|
||||
## 1. Context
|
||||
|
||||
### 1.1 The CRV Signal Line Methodology
|
||||
|
||||
Coordinate Remote Viewing (CRV) is a structured 6-stage protocol that progressively refines perception from coarse gestalt impressions (Stage I) through sensory details (Stage II), spatial dimensions (Stage III), noise separation (Stage IV), cross-referencing interrogation (Stage V), to a final composite 3D model (Stage VI). The `ruvector-crv` crate (v0.1.1, published on crates.io) maps these 6 stages to vector database subsystems: Poincare ball embeddings, multi-head attention, GNN graph topology, SNN temporal encoding, differentiable search, and MinCut partitioning.
|
||||
|
||||
The WiFi-DensePose sensing pipeline follows a strikingly similar progressive refinement:
|
||||
|
||||
1. Raw CSI arrives as an undifferentiated signal -- the system must first classify the gestalt character of the RF environment.
|
||||
2. Per-subcarrier amplitude/phase/frequency features are extracted -- analogous to sensory impressions.
|
||||
3. The AP mesh forms a spatial topology with node positions and link geometry -- a dimensional sketch.
|
||||
4. Coherence gating separates valid signal from noise and interference -- analytically overlaid artifacts must be detected and removed.
|
||||
5. Pose estimation queries earlier CSI features for cross-referencing -- interrogation of the accumulated evidence.
|
||||
6. Final multi-person partitioning produces the composite DensePose output -- the 3D model.
|
||||
|
||||
This structural isomorphism is not accidental. Both CRV and WiFi sensing solve the same abstract problem: extract structured information from a noisy, high-dimensional signal space through progressive refinement with explicit noise separation.
|
||||
|
||||
### 1.2 The ruvector-crv Crate (v0.1.1)
|
||||
|
||||
The `ruvector-crv` crate provides the following public API:
|
||||
|
||||
| Component | Purpose | Upstream Dependency |
|
||||
|-----------|---------|-------------------|
|
||||
| `CrvSessionManager` | Session lifecycle: create, add stage data, convergence analysis | -- |
|
||||
| `StageIEncoder` | Poincare ball hyperbolic embeddings for gestalt primitives | -- (internal hyperbolic math) |
|
||||
| `StageIIEncoder` | Multi-head attention for sensory vectors | `ruvector-attention` |
|
||||
| `StageIIIEncoder` | GNN graph topology encoding | `ruvector-gnn` |
|
||||
| `StageIVEncoder` | SNN temporal encoding for AOL (Analytical Overlay) detection | -- (internal SNN) |
|
||||
| `StageVEngine` | Differentiable search and cross-referencing | -- (internal soft attention) |
|
||||
| `StageVIModeler` | MinCut partitioning for composite model | `ruvector-mincut` |
|
||||
| `ConvergenceResult` | Cross-session agreement analysis | -- |
|
||||
| `CrvConfig` | Configuration (384-d default, curvature, AOL threshold, SNN params) | -- |
|
||||
|
||||
Key types: `GestaltType` (Manmade/Natural/Movement/Energy/Water/Land), `SensoryModality` (Texture/Color/Temperature/Sound/...), `AOLDetection` (content + anomaly score), `SignalLineProbe` (query + attention weights), `TargetPartition` (MinCut cluster + centroid).
|
||||
|
||||
### 1.3 What Already Exists in WiFi-DensePose
|
||||
|
||||
The following modules already implement pieces of the pipeline that CRV stages map onto:
|
||||
|
||||
| Existing Module | Location | Relevant CRV Stage |
|
||||
|----------------|----------|-------------------|
|
||||
| `multiband.rs` | `wifi-densepose-signal/src/ruvsense/` | Stage I (gestalt from multi-band CSI) |
|
||||
| `phase_align.rs` | `wifi-densepose-signal/src/ruvsense/` | Stage II (phase feature extraction) |
|
||||
| `multistatic.rs` | `wifi-densepose-signal/src/ruvsense/` | Stage III (AP mesh spatial topology) |
|
||||
| `coherence_gate.rs` | `wifi-densepose-signal/src/ruvsense/` | Stage IV (signal-vs-noise separation) |
|
||||
| `field_model.rs` | `wifi-densepose-signal/src/ruvsense/` | Stage V (persistent field for querying) |
|
||||
| `pose_tracker.rs` | `wifi-densepose-signal/src/ruvsense/` | Stage VI (person tracking output) |
|
||||
| Viewpoint fusion | `wifi-densepose-ruvector/src/viewpoint/` | Cross-session (multi-viewpoint convergence) |
|
||||
|
||||
The `wifi-densepose-ruvector` crate already depends on `ruvector-crv` in its `Cargo.toml`. This ADR defines how to wrap the CRV API with WiFi-DensePose domain types.
|
||||
|
||||
### 1.4 The Key Insight: Cross-Session Convergence = Cross-Room Identity
|
||||
|
||||
CRV's convergence analysis compares independent sessions targeting the same coordinate to find agreement in their embeddings. In WiFi-DensePose, different AP clusters in different rooms are independent "viewers" of the same person. When a person moves from Room A to Room B, the CRV convergence mechanism can find agreement between the Room A embedding trail and the Room B initial embeddings -- establishing identity continuity without cameras.
|
||||
|
||||
---
|
||||
|
||||
## 2. Decision
|
||||
|
||||
### 2.1 The 6-Stage CRV-to-WiFi Mapping
|
||||
|
||||
Create a new `crv` module in the `wifi-densepose-ruvector` crate that wraps `ruvector-crv` with WiFi-DensePose domain types. Each CRV stage maps to a specific point in the sensing pipeline.
|
||||
|
||||
```
|
||||
+-------------------------------------------------------------------+
|
||||
| CRV-Sense Pipeline (6 Stages) |
|
||||
+-------------------------------------------------------------------+
|
||||
| |
|
||||
| Raw CSI frames from ESP32 mesh (ADR-029) |
|
||||
| | |
|
||||
| v |
|
||||
| +----------------------------------------------------------+ |
|
||||
| | Stage I: CSI Gestalt Classification | |
|
||||
| | CsiGestaltClassifier | |
|
||||
| | Input: raw CSI frame (amplitude envelope + phase slope) | |
|
||||
| | Output: GestaltType (Manmade/Natural/Movement/Energy) | |
|
||||
| | Encoder: StageIEncoder (Poincare ball embedding) | |
|
||||
| | Module: ruvsense/multiband.rs | |
|
||||
| +----------------------------+-----------------------------+ |
|
||||
| | |
|
||||
| v |
|
||||
| +----------------------------------------------------------+ |
|
||||
| | Stage II: CSI Sensory Feature Extraction | |
|
||||
| | CsiSensoryEncoder | |
|
||||
| | Input: per-subcarrier CSI | |
|
||||
| | Output: amplitude textures, phase patterns, freq colors | |
|
||||
| | Encoder: StageIIEncoder (multi-head attention vectors) | |
|
||||
| | Module: ruvsense/phase_align.rs | |
|
||||
| +----------------------------+-----------------------------+ |
|
||||
| | |
|
||||
| v |
|
||||
| +----------------------------------------------------------+ |
|
||||
| | Stage III: AP Mesh Spatial Topology | |
|
||||
| | MeshTopologyEncoder | |
|
||||
| | Input: node positions, link SNR, baseline distances | |
|
||||
| | Output: GNN graph embedding of mesh geometry | |
|
||||
| | Encoder: StageIIIEncoder (GNN topology) | |
|
||||
| | Module: ruvsense/multistatic.rs | |
|
||||
| +----------------------------+-----------------------------+ |
|
||||
| | |
|
||||
| v |
|
||||
| +----------------------------------------------------------+ |
|
||||
| | Stage IV: Coherence Gating (AOL Detection) | |
|
||||
| | CoherenceAolDetector | |
|
||||
| | Input: phase coherence scores, gate decisions | |
|
||||
| | Output: AOL-flagged frames removed, clean signal kept | |
|
||||
| | Encoder: StageIVEncoder (SNN temporal encoding) | |
|
||||
| | Module: ruvsense/coherence_gate.rs | |
|
||||
| +----------------------------+-----------------------------+ |
|
||||
| | |
|
||||
| v |
|
||||
| +----------------------------------------------------------+ |
|
||||
| | Stage V: Pose Interrogation | |
|
||||
| | PoseInterrogator | |
|
||||
| | Input: pose hypothesis + accumulated CSI features | |
|
||||
| | Output: soft attention over CSI history, top candidates | |
|
||||
| | Engine: StageVEngine (differentiable search) | |
|
||||
| | Module: ruvsense/field_model.rs | |
|
||||
| +----------------------------+-----------------------------+ |
|
||||
| | |
|
||||
| v |
|
||||
| +----------------------------------------------------------+ |
|
||||
| | Stage VI: Multi-Person Partitioning | |
|
||||
| | PersonPartitioner | |
|
||||
| | Input: all person embedding clusters | |
|
||||
| | Output: MinCut-separated person partitions + centroids | |
|
||||
| | Modeler: StageVIModeler (MinCut partitioning) | |
|
||||
| | Module: training pipeline (ruvector-mincut) | |
|
||||
| +----------------------------+-----------------------------+ |
|
||||
| | |
|
||||
| v |
|
||||
| +----------------------------------------------------------+ |
|
||||
| | Cross-Session: Multi-Room Convergence | |
|
||||
| | MultiViewerConvergence | |
|
||||
| | Input: per-room embedding trails for candidate persons | |
|
||||
| | Output: cross-room identity matches + confidence | |
|
||||
| | Engine: CrvSessionManager::find_convergence() | |
|
||||
| | Module: ruvsense/cross_room.rs | |
|
||||
| +----------------------------------------------------------+ |
|
||||
+-------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 2.2 Stage I: CSI Gestalt Classification
|
||||
|
||||
**CRV mapping:** Stage I ideograms classify the target's fundamental character (Manmade/Natural/Movement/Energy). In WiFi sensing, the raw CSI frame's amplitude envelope shape and phase slope direction provide an analogous gestalt classification of the RF environment.
|
||||
|
||||
**WiFi domain types:**
|
||||
|
||||
```rust
|
||||
/// CSI-domain gestalt types mapped from CRV GestaltType.
|
||||
///
|
||||
/// The CRV taxonomy maps to RF phenomenology:
|
||||
/// - Manmade: structured multipath (walls, furniture, metallic reflectors)
|
||||
/// - Natural: diffuse scattering (vegetation, irregular surfaces)
|
||||
/// - Movement: Doppler-shifted components (human motion, fan, pet)
|
||||
/// - Energy: high-amplitude transients (microwave, motor, interference)
|
||||
/// - Water: slow fading envelope (humidity change, condensation)
|
||||
/// - Land: static baseline (empty room, no perturbation)
|
||||
pub struct CsiGestaltClassifier {
|
||||
encoder: StageIEncoder,
|
||||
config: CrvConfig,
|
||||
}
|
||||
|
||||
impl CsiGestaltClassifier {
|
||||
/// Classify a raw CSI frame into a gestalt type.
|
||||
///
|
||||
/// Extracts three features from the CSI frame:
|
||||
/// 1. Amplitude envelope shape (ideogram stroke analog)
|
||||
/// 2. Phase slope direction (spontaneous descriptor analog)
|
||||
/// 3. Subcarrier correlation structure (classification signal)
|
||||
///
|
||||
/// Returns a Poincare ball embedding (384-d by default) encoding
|
||||
/// the hierarchical gestalt taxonomy with exponentially less
|
||||
/// distortion than Euclidean space.
|
||||
pub fn classify(&self, csi_frame: &CsiFrame) -> CrvResult<(GestaltType, Vec<f32>)>;
|
||||
}
|
||||
```
|
||||
|
||||
**Integration point:** `ruvsense/multiband.rs` already processes multi-band CSI. The `CsiGestaltClassifier` wraps this with Poincare ball embedding via `StageIEncoder`, producing a hyperbolic embedding that captures the gestalt hierarchy.
|
||||
|
||||
### 2.3 Stage II: CSI Sensory Feature Extraction
|
||||
|
||||
**CRV mapping:** Stage II collects sensory impressions (texture, color, temperature). In WiFi sensing, the per-subcarrier CSI features are the sensory modalities:
|
||||
|
||||
| CRV Sensory Modality | WiFi CSI Analog |
|
||||
|----------------------|-----------------|
|
||||
| Texture | Amplitude variance pattern across subcarriers (smooth vs rough surface reflection) |
|
||||
| Color | Frequency-domain spectral shape (which subcarriers carry the most energy) |
|
||||
| Temperature | Phase drift rate (thermal expansion changes path length) |
|
||||
| Luminosity | Overall signal power level (SNR) |
|
||||
| Dimension | Delay spread (multipath extent maps to room size) |
|
||||
|
||||
**WiFi domain types:**
|
||||
|
||||
```rust
|
||||
pub struct CsiSensoryEncoder {
|
||||
encoder: StageIIEncoder,
|
||||
}
|
||||
|
||||
impl CsiSensoryEncoder {
|
||||
/// Extract sensory features from per-subcarrier CSI data.
|
||||
///
|
||||
/// Maps CSI signal characteristics to CRV sensory modalities:
|
||||
/// - Amplitude variance -> Texture
|
||||
/// - Spectral shape -> Color
|
||||
/// - Phase drift rate -> Temperature
|
||||
/// - Signal power -> Luminosity
|
||||
/// - Delay spread -> Dimension
|
||||
///
|
||||
/// Uses multi-head attention (ruvector-attention) to produce
|
||||
/// a unified sensory embedding that captures cross-modality
|
||||
/// correlations.
|
||||
pub fn encode(&self, csi_subcarriers: &SubcarrierData) -> CrvResult<Vec<f32>>;
|
||||
}
|
||||
```
|
||||
|
||||
**Integration point:** `ruvsense/phase_align.rs` already computes per-subcarrier phase features. The `CsiSensoryEncoder` maps these to `StageIIData` sensory impressions and produces attention-weighted embeddings via `StageIIEncoder`.
|
||||
|
||||
### 2.4 Stage III: AP Mesh Spatial Topology
|
||||
|
||||
**CRV mapping:** Stage III sketches the spatial layout with geometric primitives and relationships. In WiFi sensing, the AP mesh nodes and their inter-node links form the spatial sketch:
|
||||
|
||||
| CRV Sketch Element | WiFi Mesh Analog |
|
||||
|-------------------|-----------------|
|
||||
| `SketchElement` | AP node (position, antenna orientation) |
|
||||
| `GeometricKind::Point` | Single AP location |
|
||||
| `GeometricKind::Line` | Bistatic link between two APs |
|
||||
| `SpatialRelationship` | Link quality, baseline distance, angular separation |
|
||||
|
||||
**WiFi domain types:**
|
||||
|
||||
```rust
|
||||
pub struct MeshTopologyEncoder {
|
||||
encoder: StageIIIEncoder,
|
||||
}
|
||||
|
||||
impl MeshTopologyEncoder {
|
||||
/// Encode the AP mesh as a GNN graph topology.
|
||||
///
|
||||
/// Each AP node becomes a SketchElement with its position and
|
||||
/// antenna count. Each bistatic link becomes a SpatialRelationship
|
||||
/// with strength proportional to link SNR.
|
||||
///
|
||||
/// Uses ruvector-gnn to produce a graph embedding that captures
|
||||
/// the mesh's geometric diversity index (GDI) and effective
|
||||
/// viewpoint count.
|
||||
pub fn encode(&self, mesh: &MultistaticArray) -> CrvResult<Vec<f32>>;
|
||||
}
|
||||
```
|
||||
|
||||
**Integration point:** `ruvsense/multistatic.rs` manages the AP mesh topology. The `MeshTopologyEncoder` translates `MultistaticArray` geometry into `StageIIIData` sketch elements and relationships, producing a GNN-encoded topology embedding via `StageIIIEncoder`.
|
||||
|
||||
### 2.5 Stage IV: Coherence Gating as AOL Detection
|
||||
|
||||
**CRV mapping:** Stage IV detects Analytical Overlay (AOL) -- moments when the analytical mind contaminates the raw signal with pre-existing assumptions. In WiFi sensing, the coherence gate (ADR-030/032) serves the same function: it detects when environmental interference, multipath changes, or hardware artifacts contaminate the CSI signal, and flags those frames for exclusion.
|
||||
|
||||
| CRV AOL Concept | WiFi Coherence Analog |
|
||||
|-----------------|---------------------|
|
||||
| AOL event | Low-coherence frame (interference, multipath shift, hardware glitch) |
|
||||
| AOL anomaly score | Coherence metric (0.0 = fully incoherent, 1.0 = fully coherent) |
|
||||
| AOL break (flagged, set aside) | `GateDecision::Reject` or `GateDecision::PredictOnly` |
|
||||
| Clean signal line | `GateDecision::Accept` with noise multiplier |
|
||||
| Forced accept after timeout | `GateDecision::ForcedAccept` (ADR-032) with inflated noise |
|
||||
|
||||
**WiFi domain types:**
|
||||
|
||||
```rust
|
||||
pub struct CoherenceAolDetector {
|
||||
encoder: StageIVEncoder,
|
||||
}
|
||||
|
||||
impl CoherenceAolDetector {
|
||||
/// Map coherence gate decisions to CRV AOL detection.
|
||||
///
|
||||
/// The SNN temporal encoding models the spike pattern of
|
||||
/// coherence violations over time:
|
||||
/// - Burst of low-coherence frames -> high AOL anomaly score
|
||||
/// - Sustained coherence -> low anomaly score (clean signal)
|
||||
/// - Single transient -> moderate score (check and continue)
|
||||
///
|
||||
/// Returns an embedding that encodes the temporal pattern of
|
||||
/// signal quality, enabling downstream stages to weight their
|
||||
/// attention based on signal cleanliness.
|
||||
pub fn detect(
|
||||
&self,
|
||||
coherence_history: &[GateDecision],
|
||||
timestamps: &[u64],
|
||||
) -> CrvResult<(Vec<AOLDetection>, Vec<f32>)>;
|
||||
}
|
||||
```
|
||||
|
||||
**Integration point:** `ruvsense/coherence_gate.rs` already produces `GateDecision` values. The `CoherenceAolDetector` translates the coherence gate's temporal stream into `StageIVData` with `AOLDetection` events, and the SNN temporal encoding via `StageIVEncoder` produces an embedding of signal quality over time.
|
||||
|
||||
### 2.6 Stage V: Pose Interrogation via Differentiable Search
|
||||
|
||||
**CRV mapping:** Stage V is the interrogation phase -- probing earlier stage data with specific queries to extract targeted information. In WiFi sensing, this maps to querying the accumulated CSI feature history with a pose hypothesis to find supporting or contradicting evidence.
|
||||
|
||||
**WiFi domain types:**
|
||||
|
||||
```rust
|
||||
pub struct PoseInterrogator {
|
||||
engine: StageVEngine,
|
||||
}
|
||||
|
||||
impl PoseInterrogator {
|
||||
/// Cross-reference a pose hypothesis against CSI history.
|
||||
///
|
||||
/// Uses differentiable search (soft attention with temperature
|
||||
/// scaling) to find which historical CSI frames best support
|
||||
/// or contradict the current pose estimate.
|
||||
///
|
||||
/// Returns:
|
||||
/// - Attention weights over the CSI history buffer
|
||||
/// - Top-k supporting frames (highest attention)
|
||||
/// - Cross-references linking pose keypoints to specific
|
||||
/// CSI subcarrier features from earlier stages
|
||||
pub fn interrogate(
|
||||
&self,
|
||||
pose_embedding: &[f32],
|
||||
csi_history: &[CrvSessionEntry],
|
||||
) -> CrvResult<(StageVData, Vec<f32>)>;
|
||||
}
|
||||
```
|
||||
|
||||
**Integration point:** `ruvsense/field_model.rs` maintains the persistent electromagnetic field model (ADR-030). The `PoseInterrogator` wraps this with CRV Stage V semantics -- the field model's history becomes the corpus that `StageVEngine` searches over, and the pose hypothesis becomes the probe query.
|
||||
|
||||
### 2.7 Stage VI: Multi-Person Partitioning via MinCut
|
||||
|
||||
**CRV mapping:** Stage VI produces the composite 3D model by clustering accumulated data into distinct target partitions via MinCut. In WiFi sensing, this maps to multi-person separation -- partitioning the accumulated CSI embeddings into person-specific clusters.
|
||||
|
||||
**WiFi domain types:**
|
||||
|
||||
```rust
|
||||
pub struct PersonPartitioner {
|
||||
modeler: StageVIModeler,
|
||||
}
|
||||
|
||||
impl PersonPartitioner {
|
||||
/// Partition accumulated embeddings into distinct persons.
|
||||
///
|
||||
/// Uses MinCut (ruvector-mincut) to find natural cluster
|
||||
/// boundaries in the embedding space. Each partition corresponds
|
||||
/// to one person, with:
|
||||
/// - A centroid embedding (person signature)
|
||||
/// - Member frame indices (which CSI frames belong to this person)
|
||||
/// - Separation strength (how distinct this person is from others)
|
||||
///
|
||||
/// The MinCut value between partitions serves as a confidence
|
||||
/// metric for person separation quality.
|
||||
pub fn partition(
|
||||
&self,
|
||||
person_embeddings: &[CrvSessionEntry],
|
||||
) -> CrvResult<(StageVIData, Vec<f32>)>;
|
||||
}
|
||||
```
|
||||
|
||||
**Integration point:** The training pipeline in `wifi-densepose-train` already uses `ruvector-mincut` for `DynamicPersonMatcher` (ADR-016). The `PersonPartitioner` wraps this with CRV Stage VI semantics, framing person separation as composite model construction.
|
||||
|
||||
### 2.8 Cross-Session Convergence: Multi-Room Identity Matching
|
||||
|
||||
**CRV mapping:** CRV convergence analysis compares embeddings from independent sessions targeting the same coordinate to find agreement. In WiFi-DensePose, independent AP clusters in different rooms are independent "viewers" of the same person.
|
||||
|
||||
**WiFi domain types:**
|
||||
|
||||
```rust
|
||||
pub struct MultiViewerConvergence {
|
||||
session_manager: CrvSessionManager,
|
||||
}
|
||||
|
||||
impl MultiViewerConvergence {
|
||||
/// Match person identities across rooms via CRV convergence.
|
||||
///
|
||||
/// Each room's AP cluster is modeled as an independent CRV session.
|
||||
/// When a person moves from Room A to Room B:
|
||||
/// 1. Room A session contains the person's embedding trail (Stages I-VI)
|
||||
/// 2. Room B session begins accumulating new embeddings
|
||||
/// 3. Convergence analysis finds agreement between Room A's final
|
||||
/// embeddings and Room B's initial embeddings
|
||||
/// 4. Agreement score above threshold establishes identity continuity
|
||||
///
|
||||
/// Returns ConvergenceResult with:
|
||||
/// - Session pairs (room pairs) that converged
|
||||
/// - Per-pair similarity scores
|
||||
/// - Convergent stages (which CRV stages showed strongest agreement)
|
||||
/// - Consensus embedding (merged identity signature)
|
||||
pub fn match_across_rooms(
|
||||
&self,
|
||||
room_sessions: &[(RoomId, SessionId)],
|
||||
threshold: f32,
|
||||
) -> CrvResult<ConvergenceResult>;
|
||||
}
|
||||
```
|
||||
|
||||
**Integration point:** `ruvsense/cross_room.rs` already handles cross-room identity continuity (ADR-030). The `MultiViewerConvergence` wraps the existing `CrossRoomTracker` with CRV convergence semantics, using `CrvSessionManager::find_convergence()` to compute embedding agreement.
|
||||
|
||||
### 2.9 WifiCrvSession: Unified Pipeline Wrapper
|
||||
|
||||
The top-level wrapper ties all six stages into a single pipeline:
|
||||
|
||||
```rust
|
||||
/// A WiFi-DensePose sensing session modeled as a CRV session.
|
||||
///
|
||||
/// Wraps CrvSessionManager with CSI-specific convenience methods.
|
||||
/// Each call to process_frame() advances through all six CRV stages
|
||||
/// and appends stage embeddings to the session.
|
||||
pub struct WifiCrvSession {
|
||||
session_manager: CrvSessionManager,
|
||||
gestalt: CsiGestaltClassifier,
|
||||
sensory: CsiSensoryEncoder,
|
||||
topology: MeshTopologyEncoder,
|
||||
coherence: CoherenceAolDetector,
|
||||
interrogator: PoseInterrogator,
|
||||
partitioner: PersonPartitioner,
|
||||
convergence: MultiViewerConvergence,
|
||||
}
|
||||
|
||||
impl WifiCrvSession {
|
||||
/// Create a new WiFi CRV session with the given configuration.
|
||||
pub fn new(config: WifiCrvConfig) -> Self;
|
||||
|
||||
/// Process a single CSI frame through all six CRV stages.
|
||||
///
|
||||
/// Returns the per-stage embeddings and the final person partitions.
|
||||
pub fn process_frame(
|
||||
&mut self,
|
||||
frame: &CsiFrame,
|
||||
mesh: &MultistaticArray,
|
||||
coherence_state: &GateDecision,
|
||||
pose_hypothesis: Option<&[f32]>,
|
||||
) -> CrvResult<WifiCrvOutput>;
|
||||
|
||||
/// Find convergence across room sessions for identity matching.
|
||||
pub fn find_convergence(
|
||||
&self,
|
||||
room_sessions: &[(RoomId, SessionId)],
|
||||
threshold: f32,
|
||||
) -> CrvResult<ConvergenceResult>;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Implementation Plan (File-Level)
|
||||
|
||||
### 3.1 Phase 1: CRV Module Core (New Files)
|
||||
|
||||
| File | Purpose | Upstream Dependency |
|
||||
|------|---------|-------------------|
|
||||
| `crates/wifi-densepose-ruvector/src/crv/mod.rs` | Module root, re-exports all CRV-Sense types | -- |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/config.rs` | `WifiCrvConfig` extending `CrvConfig` with WiFi-specific defaults (128-d instead of 384-d to match AETHER) | `ruvector-crv` |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/session.rs` | `WifiCrvSession` wrapping `CrvSessionManager` | `ruvector-crv` |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/output.rs` | `WifiCrvOutput` struct with per-stage embeddings and diagnostics | -- |
|
||||
|
||||
### 3.2 Phase 2: Stage Encoders (New Files)
|
||||
|
||||
| File | Purpose | Upstream Dependency |
|
||||
|------|---------|-------------------|
|
||||
| `crates/wifi-densepose-ruvector/src/crv/gestalt.rs` | `CsiGestaltClassifier` -- Stage I Poincare ball embedding | `ruvector-crv::StageIEncoder` |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/sensory.rs` | `CsiSensoryEncoder` -- Stage II multi-head attention | `ruvector-crv::StageIIEncoder`, `ruvector-attention` |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/topology.rs` | `MeshTopologyEncoder` -- Stage III GNN topology | `ruvector-crv::StageIIIEncoder`, `ruvector-gnn` |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/coherence.rs` | `CoherenceAolDetector` -- Stage IV SNN temporal encoding | `ruvector-crv::StageIVEncoder` |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/interrogation.rs` | `PoseInterrogator` -- Stage V differentiable search | `ruvector-crv::StageVEngine` |
|
||||
| `crates/wifi-densepose-ruvector/src/crv/partition.rs` | `PersonPartitioner` -- Stage VI MinCut partitioning | `ruvector-crv::StageVIModeler`, `ruvector-mincut` |
|
||||
|
||||
### 3.3 Phase 3: Cross-Session Convergence
|
||||
|
||||
| File | Purpose | Upstream Dependency |
|
||||
|------|---------|-------------------|
|
||||
| `crates/wifi-densepose-ruvector/src/crv/convergence.rs` | `MultiViewerConvergence` -- cross-room identity matching | `ruvector-crv::CrvSessionManager` |
|
||||
|
||||
### 3.4 Phase 4: Integration with Existing Modules (Edits to Existing Files)
|
||||
|
||||
| File | Change | Notes |
|
||||
|------|--------|-------|
|
||||
| `crates/wifi-densepose-ruvector/src/lib.rs` | Add `pub mod crv;` | Expose new module |
|
||||
| `crates/wifi-densepose-ruvector/Cargo.toml` | No change needed | `ruvector-crv` dependency already present |
|
||||
| `crates/wifi-densepose-signal/src/ruvsense/multiband.rs` | Add trait impl for `CrvGestaltSource` | Allow gestalt classifier to consume multiband output |
|
||||
| `crates/wifi-densepose-signal/src/ruvsense/phase_align.rs` | Add trait impl for `CrvSensorySource` | Allow sensory encoder to consume phase features |
|
||||
| `crates/wifi-densepose-signal/src/ruvsense/coherence_gate.rs` | Add method to export `GateDecision` history as `Vec<AOLDetection>` | Bridge coherence gate to CRV Stage IV |
|
||||
| `crates/wifi-densepose-signal/src/ruvsense/cross_room.rs` | Add `CrvConvergenceAdapter` trait impl | Bridge cross-room tracker to CRV convergence |
|
||||
|
||||
---
|
||||
|
||||
## 4. DDD Design
|
||||
|
||||
### 4.1 New Bounded Context: CrvSensing
|
||||
|
||||
**Aggregate Root: `WifiCrvSession`**
|
||||
|
||||
```rust
|
||||
pub struct WifiCrvSession {
|
||||
/// Underlying CRV session manager
|
||||
session_manager: CrvSessionManager,
|
||||
/// Per-stage encoders
|
||||
stages: CrvStageEncoders,
|
||||
/// Session configuration
|
||||
config: WifiCrvConfig,
|
||||
/// Running statistics for convergence quality
|
||||
convergence_stats: ConvergenceStats,
|
||||
}
|
||||
```
|
||||
|
||||
**Value Objects:**
|
||||
|
||||
```rust
|
||||
/// Output of a single frame through the 6-stage pipeline.
|
||||
pub struct WifiCrvOutput {
|
||||
/// Per-stage embeddings (6 vectors, one per CRV stage).
|
||||
pub stage_embeddings: [Vec<f32>; 6],
|
||||
/// Gestalt classification for this frame.
|
||||
pub gestalt: GestaltType,
|
||||
/// AOL detections (frames flagged as noise-contaminated).
|
||||
pub aol_events: Vec<AOLDetection>,
|
||||
/// Person partitions from Stage VI.
|
||||
pub partitions: Vec<TargetPartition>,
|
||||
/// Processing latency per stage in microseconds.
|
||||
pub stage_latencies_us: [u64; 6],
|
||||
}
|
||||
|
||||
/// WiFi-specific CRV configuration extending CrvConfig.
|
||||
pub struct WifiCrvConfig {
|
||||
/// Base CRV config (dimensions, curvature, thresholds).
|
||||
pub crv: CrvConfig,
|
||||
/// AETHER embedding dimension (default: 128, overrides CrvConfig.dimensions).
|
||||
pub aether_dim: usize,
|
||||
/// Coherence threshold for AOL detection (maps to aol_threshold).
|
||||
pub coherence_threshold: f32,
|
||||
/// Maximum CSI history frames for Stage V interrogation.
|
||||
pub max_history_frames: usize,
|
||||
/// Cross-room convergence threshold (default: 0.75).
|
||||
pub convergence_threshold: f32,
|
||||
}
|
||||
```
|
||||
|
||||
**Domain Events:**
|
||||
|
||||
```rust
|
||||
pub enum CrvSensingEvent {
|
||||
/// Stage I completed: gestalt classified
|
||||
GestaltClassified { gestalt: GestaltType, confidence: f32 },
|
||||
/// Stage IV: AOL detected (noise contamination)
|
||||
AolDetected { anomaly_score: f32, flagged: bool },
|
||||
/// Stage VI: Persons partitioned
|
||||
PersonsPartitioned { count: usize, min_separation: f32 },
|
||||
/// Cross-session: Identity matched across rooms
|
||||
IdentityConverged { room_pair: (RoomId, RoomId), score: f32 },
|
||||
/// Full pipeline completed for one frame
|
||||
FrameProcessed { latency_us: u64, stages_completed: u8 },
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 Integration with Existing Bounded Contexts
|
||||
|
||||
**Signal (wifi-densepose-signal):** New traits `CrvGestaltSource` and `CrvSensorySource` allow the CRV module to consume signal processing outputs without tight coupling. The signal crate does not depend on the CRV crate -- the dependency flows one direction only.
|
||||
|
||||
**Training (wifi-densepose-train):** The `PersonPartitioner` (Stage VI) produces the same MinCut partitions as the existing `DynamicPersonMatcher`. A shared trait `PersonSeparator` allows both to be used interchangeably.
|
||||
|
||||
**Hardware (wifi-densepose-hardware):** No changes. The CRV module consumes CSI frames after they have been received and parsed by the hardware layer.
|
||||
|
||||
---
|
||||
|
||||
## 5. RuVector Integration Map
|
||||
|
||||
All seven `ruvector` crates exercised by the CRV-Sense integration:
|
||||
|
||||
| CRV Stage | ruvector Crate | API Used | WiFi-DensePose Role |
|
||||
|-----------|---------------|----------|-------------------|
|
||||
| I (Gestalt) | -- (internal Poincare math) | `StageIEncoder::encode()` | Hyperbolic embedding of CSI gestalt taxonomy |
|
||||
| II (Sensory) | `ruvector-attention` | `StageIIEncoder::encode()` | Multi-head attention over subcarrier features |
|
||||
| III (Dimensional) | `ruvector-gnn` | `StageIIIEncoder::encode()` | GNN encoding of AP mesh topology |
|
||||
| IV (AOL) | -- (internal SNN) | `StageIVEncoder::encode()` | SNN temporal encoding of coherence violations |
|
||||
| V (Interrogation) | -- (internal soft attention) | `StageVEngine::search()` | Differentiable search over field model history |
|
||||
| VI (Composite) | `ruvector-mincut` | `StageVIModeler::partition()` | MinCut person separation |
|
||||
| Convergence | -- (cosine similarity) | `CrvSessionManager::find_convergence()` | Cross-room identity matching |
|
||||
|
||||
Additionally, the CRV module benefits from existing ruvector integrations already in the workspace:
|
||||
|
||||
| Existing Integration | ADR | CRV Stage Benefit |
|
||||
|---------------------|-----|-------------------|
|
||||
| `ruvector-attn-mincut` in `spectrogram.rs` | ADR-016 | Stage II (subcarrier attention for sensory features) |
|
||||
| `ruvector-temporal-tensor` in `dataset.rs` | ADR-016 | Stage IV (compressed coherence history buffer) |
|
||||
| `ruvector-solver` in `subcarrier.rs` | ADR-016 | Stage III (sparse interpolation for mesh topology) |
|
||||
| `ruvector-attention` in `model.rs` | ADR-016 | Stage V (spatial attention for pose interrogation) |
|
||||
| `ruvector-mincut` in `metrics.rs` | ADR-016 | Stage VI (person matching baseline) |
|
||||
|
||||
---
|
||||
|
||||
## 6. Acceptance Criteria
|
||||
|
||||
### 6.1 Stage I: CSI Gestalt Classification
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| S1-1 | `CsiGestaltClassifier::classify()` returns a valid `GestaltType` for any well-formed CSI frame | Unit test: feed 100 synthetic CSI frames, verify all return one of 6 gestalt types |
|
||||
| S1-2 | Poincare ball embedding has correct dimensionality (matching `WifiCrvConfig.aether_dim`) | Unit test: verify `embedding.len() == config.aether_dim` |
|
||||
| S1-3 | Embedding norm is strictly less than 1.0 (Poincare ball constraint) | Unit test: verify L2 norm < 1.0 for all outputs |
|
||||
| S1-4 | Movement gestalt is classified for CSI frames with Doppler signature | Unit test: synthetic Doppler-shifted CSI -> `GestaltType::Movement` |
|
||||
| S1-5 | Energy gestalt is classified for CSI frames with transient interference | Unit test: synthetic interference burst -> `GestaltType::Energy` |
|
||||
|
||||
### 6.2 Stage II: CSI Sensory Features
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| S2-1 | `CsiSensoryEncoder::encode()` produces embedding of correct dimensionality | Unit test: verify output length |
|
||||
| S2-2 | Amplitude variance maps to Texture modality in `StageIIData.impressions` | Unit test: verify Texture entry present for non-flat amplitude |
|
||||
| S2-3 | Phase drift rate maps to Temperature modality | Unit test: inject linear phase drift, verify Temperature entry |
|
||||
| S2-4 | Multi-head attention weights sum to 1.0 per head | Unit test: verify softmax normalization |
|
||||
|
||||
### 6.3 Stage III: AP Mesh Topology
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| S3-1 | `MeshTopologyEncoder::encode()` produces one `SketchElement` per AP node | Unit test: 4-node mesh produces 4 sketch elements |
|
||||
| S3-2 | `SpatialRelationship` count equals number of bistatic links | Unit test: 4 nodes -> 6 links (fully connected) or configured subset |
|
||||
| S3-3 | Relationship strength is proportional to link SNR | Unit test: verify monotonic relationship between SNR and strength |
|
||||
| S3-4 | GNN embedding changes when node positions change | Unit test: perturb one node position, verify embedding changes |
|
||||
|
||||
### 6.4 Stage IV: Coherence AOL Detection
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| S4-1 | `CoherenceAolDetector::detect()` flags low-coherence frames as AOL events | Unit test: inject 10 `GateDecision::Reject` frames, verify 10 `AOLDetection` entries |
|
||||
| S4-2 | Anomaly score correlates with coherence violation burst length | Unit test: burst of 5 violations scores higher than isolated violation |
|
||||
| S4-3 | `GateDecision::Accept` frames produce no AOL detections | Unit test: all-accept history produces empty AOL list |
|
||||
| S4-4 | SNN temporal encoding respects refractory period | Unit test: two violations within `refractory_period_ms` produce single spike |
|
||||
| S4-5 | `GateDecision::ForcedAccept` (ADR-032) maps to AOL with moderate score | Unit test: forced accept frames flagged but not at max anomaly score |
|
||||
|
||||
### 6.5 Stage V: Pose Interrogation
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| S5-1 | `PoseInterrogator::interrogate()` returns attention weights over CSI history | Unit test: history of 50 frames produces 50 attention weights summing to 1.0 |
|
||||
| S5-2 | Top-k candidates are the highest-attention frames | Unit test: verify `top_candidates` indices correspond to highest `attention_weights` |
|
||||
| S5-3 | Cross-references link correct stage numbers | Unit test: verify `from_stage` and `to_stage` are in [1..6] |
|
||||
| S5-4 | Empty history returns empty probe results | Unit test: empty `csi_history` produces zero candidates |
|
||||
|
||||
### 6.6 Stage VI: Person Partitioning
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| S6-1 | `PersonPartitioner::partition()` separates two well-separated embedding clusters into two partitions | Unit test: two Gaussian clusters with distance > 5 sigma -> two partitions |
|
||||
| S6-2 | Each partition has a centroid embedding of correct dimensionality | Unit test: verify centroid length matches config |
|
||||
| S6-3 | `separation_strength` (MinCut value) is positive for distinct persons | Unit test: verify separation_strength > 0.0 |
|
||||
| S6-4 | Single-person scenario produces exactly one partition | Unit test: single cluster -> one partition |
|
||||
| S6-5 | Partition `member_entries` indices are non-overlapping and exhaustive | Unit test: union of all member entries covers all input frames |
|
||||
|
||||
### 6.7 Cross-Session Convergence
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| C-1 | `MultiViewerConvergence::match_across_rooms()` returns positive score for same person in two rooms | Unit test: inject same embedding trail into two room sessions, verify score > threshold |
|
||||
| C-2 | Different persons in different rooms produce score below threshold | Unit test: inject distinct embedding trails, verify score < threshold |
|
||||
| C-3 | `convergent_stages` identifies the stage with highest cross-room agreement | Unit test: make Stage I embeddings identical, others random, verify Stage I in convergent_stages |
|
||||
| C-4 | `consensus_embedding` has correct dimensionality when convergence succeeds | Unit test: verify consensus embedding length on successful match |
|
||||
| C-5 | Threshold parameter is respected (no matches below threshold) | Unit test: set threshold to 0.99, verify only near-identical sessions match |
|
||||
|
||||
### 6.8 End-to-End Pipeline
|
||||
|
||||
| ID | Criterion | Test Method |
|
||||
|----|-----------|-------------|
|
||||
| E-1 | `WifiCrvSession::process_frame()` returns `WifiCrvOutput` with all 6 stage embeddings populated | Integration test: process 10 synthetic frames, verify 6 non-empty embeddings per frame |
|
||||
| E-2 | Total pipeline latency < 5 ms per frame on x86 host | Benchmark: process 1000 frames, verify p95 latency < 5 ms |
|
||||
| E-3 | Pipeline handles missing pose hypothesis gracefully (Stage V skipped or uses default) | Unit test: pass `None` for pose_hypothesis, verify no panic and output is valid |
|
||||
| E-4 | Pipeline handles empty mesh (single AP) without panic | Unit test: single-node mesh produces valid output with degenerate Stage III |
|
||||
| E-5 | Session state accumulates across frames (Stage V history grows) | Unit test: process 50 frames, verify Stage V candidate count increases |
|
||||
|
||||
---
|
||||
|
||||
## 7. Consequences
|
||||
|
||||
### 7.1 Positive
|
||||
|
||||
- **Structured pipeline formalization**: The 6-stage CRV mapping provides a principled progressive refinement structure for the WiFi sensing pipeline, making the data flow explicit and each stage independently testable.
|
||||
- **Cross-room identity without cameras**: CRV convergence analysis provides a mathematically grounded mechanism for matching person identities across AP clusters in different rooms, using only RF embeddings.
|
||||
- **Noise separation as first-class concept**: Mapping coherence gating to CRV Stage IV (AOL detection) elevates noise separation from an implementation detail to a core architectural stage with its own embedding and temporal model.
|
||||
- **Hyperbolic embeddings for gestalt hierarchy**: The Poincare ball embedding for Stage I captures the hierarchical RF environment taxonomy (Manmade > structural multipath, Natural > diffuse scattering, etc.) with exponentially less distortion than Euclidean space.
|
||||
- **Reuse of ruvector ecosystem**: All seven ruvector crates are exercised through a single unified abstraction, maximizing the return on the existing ruvector integration (ADR-016).
|
||||
- **No new external dependencies**: `ruvector-crv` is already a workspace dependency in `wifi-densepose-ruvector/Cargo.toml`. This ADR adds only new Rust source files.
|
||||
|
||||
### 7.2 Negative
|
||||
|
||||
- **Abstraction overhead**: The CRV stage mapping adds a layer of indirection over the existing signal processing pipeline. Each stage wrapper must translate between WiFi domain types and CRV types, adding code that could be a maintenance burden if the mapping proves ill-fitted.
|
||||
- **Dimensional mismatch**: `ruvector-crv` defaults to 384 dimensions; AETHER embeddings (ADR-024) use 128 dimensions. The `WifiCrvConfig` overrides this, but encoder behavior at non-default dimensionality must be validated.
|
||||
- **SNN overhead**: The Stage IV SNN temporal encoder adds per-frame computation for spike train simulation. On embedded targets (ESP32), this may exceed the 50 ms frame budget. Initial deployment is host-side only (aggregator, not firmware).
|
||||
- **Convergence false positives**: Cross-room identity matching via embedding similarity may produce false matches for persons with similar body types and movement patterns in similar room geometries. Temporal proximity constraints (from ADR-030) are required to bound the false positive rate.
|
||||
- **Testing complexity**: Six stages with independent encoders and a cross-session convergence layer require a comprehensive test matrix. The acceptance criteria in Section 6 define 30+ individual test cases.
|
||||
|
||||
### 7.3 Risks
|
||||
|
||||
| Risk | Probability | Impact | Mitigation |
|
||||
|------|-------------|--------|------------|
|
||||
| Poincare ball embedding unstable at boundary (norm approaching 1.0) | Medium | NaN propagation through pipeline | Clamp norm to 0.95 in `CsiGestaltClassifier`; add norm assertion in test suite |
|
||||
| GNN encoder too slow for real-time mesh topology updates | Low | Stage III becomes bottleneck | Cache topology embedding; only recompute on node geometry change (rare) |
|
||||
| SNN refractory period too short for 20 Hz coherence gate | Medium | False AOL detections at frame boundaries | Tune `refractory_period_ms` to match frame interval (50 ms) in `WifiCrvConfig` defaults |
|
||||
| Cross-room convergence threshold too permissive | Medium | False identity matches across rooms | Default threshold 0.75 is conservative; ADR-030 temporal proximity constraint (<60s) adds second guard |
|
||||
| MinCut partitioning produces too many or too few person clusters | Medium | Person count mismatch | Use expected person count hint (from occupancy detector) as MinCut constraint |
|
||||
| CRV abstraction becomes tech debt if mapping proves poor fit | Low | Code removed in future ADR | All CRV code in isolated `crv` module; can be removed without affecting existing pipeline |
|
||||
|
||||
---
|
||||
|
||||
## 8. Related ADRs
|
||||
|
||||
| ADR | Relationship |
|
||||
|-----|-------------|
|
||||
| ADR-016 (RuVector Integration) | **Extended**: All 5 original ruvector crates plus `ruvector-crv` and `ruvector-gnn` now exercised through CRV pipeline |
|
||||
| ADR-017 (RuVector Signal+MAT) | **Extended**: Signal processing outputs from ADR-017 feed into CRV Stages I-II |
|
||||
| ADR-024 (AETHER Embeddings) | **Consumed**: Per-viewpoint AETHER 128-d embeddings are the representation fed into CRV stages |
|
||||
| ADR-029 (RuvSense Multistatic) | **Extended**: Multistatic mesh topology encoded as CRV Stage III; TDM frames are the input to Stage I |
|
||||
| ADR-030 (Persistent Field Model) | **Extended**: Field model history serves as the Stage V interrogation corpus; cross-room tracker bridges to CRV convergence |
|
||||
| ADR-031 (RuView Viewpoint Fusion) | **Complementary**: RuView fuses viewpoints within a room; CRV convergence matches identities across rooms |
|
||||
| ADR-032 (Mesh Security) | **Consumed**: Authenticated beacons and frame integrity (ADR-032) ensure CRV Stage IV AOL detection reflects genuine signal quality, not spoofed frames |
|
||||
|
||||
---
|
||||
|
||||
## 9. References
|
||||
|
||||
1. Swann, I. (1996). "Remote Viewing: The Real Story." Self-published manuscript. (Original CRV protocol documentation.)
|
||||
2. Smith, P. H. (2005). "Reading the Enemy's Mind: Inside Star Gate, America's Psychic Espionage Program." Tom Doherty Associates.
|
||||
3. Nickel, M. & Kiela, D. (2017). "Poincare Embeddings for Learning Hierarchical Representations." NeurIPS 2017.
|
||||
4. Kipf, T. N. & Welling, M. (2017). "Semi-Supervised Classification with Graph Convolutional Networks." ICLR 2017.
|
||||
5. Maass, W. (1997). "Networks of Spiking Neurons: The Third Generation of Neural Network Models." Neural Networks, 10(9):1659-1671.
|
||||
6. Stoer, M. & Wagner, F. (1997). "A Simple Min-Cut Algorithm." Journal of the ACM, 44(4):585-591.
|
||||
7. `ruvector-crv` v0.1.1. https://crates.io/crates/ruvector-crv
|
||||
8. `ruvector-attention` v2.0. https://crates.io/crates/ruvector-attention
|
||||
9. `ruvector-gnn` v2.0.1. https://crates.io/crates/ruvector-gnn
|
||||
10. `ruvector-mincut` v2.0.1. https://crates.io/crates/ruvector-mincut
|
||||
11. Geng, J. et al. (2023). "DensePose From WiFi." arXiv:2301.00250.
|
||||
12. ADR-016 through ADR-032 (internal).
|
||||
1027
docs/ddd/ruvsense-domain-model.md
Normal file
1027
docs/ddd/ruvsense-domain-model.md
Normal file
File diff suppressed because it is too large
Load Diff
389
docs/research/ruview-multistatic-fidelity-sota-2026.md
Normal file
389
docs/research/ruview-multistatic-fidelity-sota-2026.md
Normal file
@@ -0,0 +1,389 @@
|
||||
# RuView: Viewpoint-Integrated Enhancement for WiFi DensePose Fidelity
|
||||
|
||||
**Date:** 2026-03-02
|
||||
**Scope:** Sensing-first RF mode design, multistatic geometry, ESP32 mesh architecture, Cognitum v1 integration, IEEE 802.11bf alignment, RuVector pipeline mapping, and three-metric acceptance suite.
|
||||
|
||||
---
|
||||
|
||||
## 1. Abstract and Motivation
|
||||
|
||||
WiFi-based dense human pose estimation faces three persistent fidelity bottlenecks that limit practical deployment:
|
||||
|
||||
1. **Pose jitter.** Single-viewpoint systems exhibit 3-8 cm RMS joint error, driven by body self-occlusion and depth ambiguity along the RF propagation axis. Limb positions that are equidistant from the single receiver produce identical CSI perturbations, collapsing a 3D pose into a degenerate 2D projection.
|
||||
|
||||
2. **Multi-person ambiguity.** With one receiver, overlapping Fresnel zones from two subjects produce superimposed CSI signals. State-of-the-art trackers report 0.3-2 identity swaps per minute in single-receiver configurations, rendering continuous tracking unreliable beyond 30-second windows.
|
||||
|
||||
3. **Vital sign noise floor.** Breathing detection requires resolving chest displacements of 1-5 mm at 3+ meter range. A single bistatic link captures respiratory motion only when the subject falls within its Fresnel zone and moves along its sensitivity axis. Off-axis breathing is invisible.
|
||||
|
||||
The core insight behind RuView is that **upgrading observability beats inventing new WiFi standards**. Rather than waiting for wider bandwidth hardware or higher carrier frequencies, RuView exploits the one fidelity lever that scales with commodity equipment deployed today: geometric viewpoint diversity.
|
||||
|
||||
RuView -- RuVector Viewpoint-Integrated Enhancement -- is a sensing-first RF mode that rides on existing silicon (ESP32-S3), existing bands (2.4/5 GHz), and existing regulations (Part 15 unlicensed). Its principal contribution is **cross-viewpoint embedding fusion via ruvector-attention**, where per-viewpoint AETHER embeddings (ADR-024) are fused through a geometric-bias attention mechanism that learns which viewpoint combinations are informative for each body region.
|
||||
|
||||
Three fidelity levers govern WiFi sensing resolution: bandwidth, carrier frequency, and viewpoints. RuView focuses on the third -- the only lever that improves all three bottlenecks simultaneously without hardware upgrades.
|
||||
|
||||
---
|
||||
|
||||
## 2. Three Fidelity Levers: SOTA Analysis
|
||||
|
||||
### 2.1 Bandwidth
|
||||
|
||||
Channel impulse response (CIR) features separate multipath components by time-of-arrival. Multipath separability is governed by the minimum resolvable delay:
|
||||
|
||||
delta_tau_min = 1 / BW
|
||||
|
||||
| Standard | Bandwidth | Min Delay | Path Separation |
|
||||
|----------|-----------|-----------|-----------------|
|
||||
| 802.11n HT20 | 20 MHz | 50 ns | 15.0 m |
|
||||
| 802.11ac VHT80 | 80 MHz | 12.5 ns | 3.75 m |
|
||||
| 802.11ac VHT160 | 160 MHz | 6.25 ns | 1.87 m |
|
||||
| 802.11be EHT320 | 320 MHz | 3.13 ns | 0.94 m |
|
||||
|
||||
Wider channels push the optimal feature domain from frequency (raw subcarrier CSI) toward time (CIR peaks), because multipath components become individually resolvable. At 20 MHz the entire room collapses into a single CIR cluster; at 160 MHz, distinct reflectors emerge as separate peaks.
|
||||
|
||||
ESP32-S3 operates at 20 MHz (HT20). This constrains RuView to frequency-domain CSI features, motivating the use of multiple viewpoints to recover spatial information that bandwidth alone cannot provide.
|
||||
|
||||
**References:** SpotFi (Kotaru et al., SIGCOMM 2015); IEEE 802.11bf sensing mode (2024).
|
||||
|
||||
### 2.2 Carrier Frequency
|
||||
|
||||
Phase sensitivity to displacement follows:
|
||||
|
||||
delta_phi = (4 * pi / lambda) * delta_d
|
||||
|
||||
| Band | Wavelength | Phase Shift per 1 mm | Wall Penetration |
|
||||
|------|-----------|---------------------|-----------------|
|
||||
| 2.4 GHz | 12.5 cm | 0.10 rad | Excellent (3+ walls) |
|
||||
| 5 GHz | 6.0 cm | 0.21 rad | Moderate (1-2 walls) |
|
||||
| 60 GHz | 5.0 mm | 2.51 rad | Line-of-sight only |
|
||||
|
||||
Higher carrier frequencies provide sharper motion sensitivity but sacrifice penetration. At 60 GHz (802.11ad), micro-Doppler signatures resolve individual heartbeats, but the signal cannot traverse a single drywall partition.
|
||||
|
||||
Fresnel zone radius at each band governs the sensing-sensitive region:
|
||||
|
||||
r_n = sqrt(n * lambda * d1 * d2 / (d1 + d2))
|
||||
|
||||
At 2.4 GHz with 3m link distance, the first Fresnel zone radius is 0.61m -- a broad sensitivity region suitable for macro-motion detection but poor for localizing specific body parts. At 5 GHz the radius shrinks to 0.42m, improving localization at the cost of coverage.
|
||||
|
||||
RuView currently targets 2.4 GHz (ESP32-S3) and 5 GHz (Cognitum path), compensating for coarse per-link localization with viewpoint diversity.
|
||||
|
||||
**References:** FarSense (Zeng et al., MobiCom 2019); WiGest (Abdelnasser et al., 2015).
|
||||
|
||||
### 2.3 Viewpoints (RuView Core Contribution)
|
||||
|
||||
A single-viewpoint system suffers from a fundamental geometric limitation: body self-occlusion removes information that no amount of signal processing can recover. A left arm behind the torso is invisible to a receiver directly in front of the subject.
|
||||
|
||||
Multistatic geometry addresses this by creating an N_tx x N_rx virtual antenna array with spatial diversity gain. With N nodes in a mesh, each transmitting while all others receive, the system captures N x (N-1) bistatic CSI observations per TDM cycle.
|
||||
|
||||
**Geometric Diversity Index (GDI).** Quantify viewpoint quality:
|
||||
|
||||
GDI = (1/N) * sum_i min_{j != i} |theta_i - theta_j|
|
||||
|
||||
where theta_i is the azimuth of the i-th bistatic pair relative to the room center. Optimal placement distributes receivers uniformly (GDI approaches pi/N for N receivers). Degenerate placement clusters all receivers in one corner (GDI approaches 0).
|
||||
|
||||
**Cramer-Rao Lower Bound for pose estimation.** With N independent viewpoints, CRLB decreases as O(1/N). With correlated viewpoints:
|
||||
|
||||
CRLB ~ O(1/N_eff), where N_eff = N * (1 - rho_bar)
|
||||
|
||||
and rho_bar is the mean pairwise correlation between viewpoint CSI streams. Maximizing GDI minimizes rho_bar.
|
||||
|
||||
**Multipath separability x viewpoints.** Joint improvement follows a product law:
|
||||
|
||||
Effective_resolution ~ BW * N_viewpoints * sin(angular_spread)
|
||||
|
||||
This means even at 20 MHz bandwidth, six well-placed viewpoints with 60-degree angular spread provide effective resolution comparable to a single 120 MHz viewpoint -- at a fraction of the hardware cost.
|
||||
|
||||
**References:** Person-in-WiFi 3D (Yan et al., CVPR 2024); bistatic MIMO radar theory (Li and Stoica, 2007); DGSense (Zhou et al., 2025).
|
||||
|
||||
---
|
||||
|
||||
## 3. Multistatic Array Theory
|
||||
|
||||
### 3.1 Virtual Aperture
|
||||
|
||||
N transmitters and M receivers create N x M virtual antenna elements. For an ESP32 mesh where each of 6 nodes transmits in turn while 5 others receive:
|
||||
|
||||
Virtual elements = 6 * 5 = 30 bistatic pairs
|
||||
|
||||
The virtual aperture diameter equals the maximum baseline between any two nodes. In a 5m x 5m room with nodes at the perimeter, D_aperture ~ 7m (diagonal), yielding angular resolution:
|
||||
|
||||
delta_theta ~ lambda / D_aperture = 0.125 / 7 ~ 1.0 degree at 2.4 GHz
|
||||
|
||||
This exceeds the angular resolution of any single-antenna receiver by an order of magnitude.
|
||||
|
||||
### 3.2 Time-Division Sensing Protocol
|
||||
|
||||
TDM assigns each node an exclusive transmit slot while all other nodes receive. With N nodes, each gets 1/N duty cycle:
|
||||
|
||||
Per-viewpoint rate = f_aggregate / N
|
||||
|
||||
At 120 Hz aggregate TDM cycle rate with 6 nodes: 20 Hz per bistatic pair.
|
||||
|
||||
**Synchronization.** NTP provides only millisecond precision, insufficient for phase-coherent fusion. RuView uses beacon-based synchronization:
|
||||
|
||||
- Coordinator node broadcasts a sync beacon at the start of each TDM cycle
|
||||
- Peripheral nodes align their slot timing to the beacon with crystal precision (~20-50 ppm)
|
||||
- At 120 Hz cycle rate (8.33 ms period), 50 ppm drift produces 0.42 microsecond error
|
||||
- This is well within the 802.11n symbol duration (3.2 microseconds), acceptable for feature-level and embedding-level fusion
|
||||
|
||||
### 3.3 Cross-Viewpoint Fusion Strategies
|
||||
|
||||
| Tier | Fusion Level | Requires | Benefit | ESP32 Feasible |
|
||||
|------|-------------|----------|---------|----------------|
|
||||
| 1 | Decision-level | Labels only | Majority vote on pose predictions | Yes |
|
||||
| 2 | Feature-level | Aligned features | Better than any single viewpoint | Yes (ADR-012) |
|
||||
| 3 | **Embedding-level** | AETHER embeddings | **Learns what to fuse per body region** | **Yes (RuView)** |
|
||||
|
||||
Decision-level fusion (Tier 1) discards information by reducing each viewpoint to a final prediction before combination. Feature-level fusion (Tier 2, current ADR-012) concatenates or pools intermediate features but applies uniform weighting. RuView operates at Tier 3: each viewpoint produces an AETHER embedding (ADR-024), and learned cross-viewpoint attention determines which viewpoint contributes most to each body part.
|
||||
|
||||
---
|
||||
|
||||
## 4. ESP32 Multistatic Array Path
|
||||
|
||||
### 4.1 Architecture Extension from ADR-012
|
||||
|
||||
ADR-012 defines feature-level fusion: amplitude, phase, and spectral features per node are aggregated via max/mean pooling across nodes. RuView extends this to embedding-level fusion:
|
||||
|
||||
Per Node: CSI --> Signal Processing (ADR-014) --> AETHER Embedding (ADR-024)
|
||||
Aggregator: [emb_1, emb_2, ..., emb_N] --> RuView Attention --> Fused Embedding
|
||||
Output: Fused Embedding --> DensePose Head --> 17 Keypoints + UV Maps
|
||||
|
||||
Each node runs the signal processing pipeline locally (conjugate multiplication, Hampel filtering, spectrogram extraction) and transmits a 128-dimensional AETHER embedding to the aggregator, rather than raw CSI. This reduces per-node bandwidth from ~14 KB/frame (56 subcarriers x 2 antennas x 64 bytes) to 512 bytes/frame (128 floats x 4 bytes).
|
||||
|
||||
### 4.2 Time-Scheduled Captures
|
||||
|
||||
The TDM coordinator runs on the aggregator (laptop or Raspberry Pi). Protocol per cycle:
|
||||
|
||||
Beacon --> Slot_1 (node 1 TX, all others RX) --> Slot_2 --> ... --> Slot_N --> Repeat
|
||||
|
||||
Each slot requires approximately 1.4 ms (one 802.11n LLTF frame plus guard interval). With 6 nodes: 8.4 ms cycle duration, yielding 119 Hz aggregate rate and 19.8 Hz per bistatic pair.
|
||||
|
||||
### 4.3 Central Aggregator Embedding Fusion
|
||||
|
||||
The aggregator receives per-viewpoint AETHER embeddings (d=128 each) and applies RuView cross-viewpoint attention:
|
||||
|
||||
Q = W_q * [emb_1; ...; emb_N] (N x d)
|
||||
K = W_k * [emb_1; ...; emb_N] (N x d)
|
||||
V = W_v * [emb_1; ...; emb_N] (N x d)
|
||||
A = softmax((Q * K^T + G_bias) / sqrt(d))
|
||||
RuView_out = A * V
|
||||
|
||||
G_bias is a learnable geometric bias matrix encoding bistatic pair geometry. Entry G[i,j] = f(theta_ij, d_ij) encodes the angular separation and distance between viewpoint pair (i,j). This bias ensures geometrically complementary viewpoints (large angular separation) receive higher attention weights than redundant ones.
|
||||
|
||||
### 4.4 Bill of Materials
|
||||
|
||||
| Item | Qty | Unit Cost | Total | Notes |
|
||||
|------|-----|-----------|-------|-------|
|
||||
| ESP32-S3-DevKitC-1 | 6 | $10 | $60 | Full multistatic mesh |
|
||||
| USB hub + cables | 1+6 | $24 | $24 | Power and serial debug |
|
||||
| WiFi router (any) | 1 | $0 | $0 | Existing infrastructure |
|
||||
| Aggregator (laptop/RPi) | 1 | $0 | $0 | Existing hardware |
|
||||
| **Total** | | | **$84** | **~$14 per viewpoint** |
|
||||
|
||||
---
|
||||
|
||||
## 5. Cognitum v1 Path
|
||||
|
||||
### 5.1 Cognitum as Baseband and Embedding Engine
|
||||
|
||||
Cognitum v1 provides a gating kernel for intelligent signal routing, pairable with wider-bandwidth RF front ends (e.g., LimeSDR Mini at ~$200). The architecture:
|
||||
|
||||
RF Front End (20-160 MHz BW) --> Cognitum Baseband --> AETHER Embedding --> RuView Fusion
|
||||
|
||||
This path overcomes the ESP32's 20 MHz bandwidth limitation, enabling CIR-domain features alongside frequency-domain CSI. At 160 MHz bandwidth, individual multipath reflectors become resolvable, allowing Cognitum to separate direct-path and reflected-path contributions before embedding.
|
||||
|
||||
### 5.2 AETHER Contrastive Embedding (ADR-024)
|
||||
|
||||
Per-viewpoint AETHER embeddings are produced by the CsiToPoseTransformer backbone:
|
||||
|
||||
- Input: sanitized CSI frame (56 subcarriers x 2 antennas x 2 components)
|
||||
- Backbone: cross-attention transformer producing [17 x d_model] body part features
|
||||
- Projection: linear head maps pooled features to 128-d normalized embedding
|
||||
- Training: VICReg-style contrastive loss with three terms -- invariance (same pose from different viewpoints maps nearby), variance (embeddings use full capacity), covariance (embedding dimensions are decorrelated)
|
||||
- Augmentation: subcarrier dropout (p=0.1), phase noise injection (sigma=0.05 rad), temporal jitter (+-2 frames)
|
||||
|
||||
### 5.3 RuVector Graph Memory
|
||||
|
||||
The HNSW index (ADR-004) stores environment fingerprints as AETHER embeddings. Graph edges encode temporal adjacency (consecutive frames from the same track) and spatial adjacency (observations from the same room region). Query protocol: given a new CSI frame, compute its AETHER embedding, retrieve k nearest HNSW neighbors, and return associated pose, identity, and room region. Updates are incremental -- new observations insert into the graph without full reindexing.
|
||||
|
||||
### 5.4 Coherence-Gated Updates
|
||||
|
||||
Environment changes (furniture moved, doors opened) corrupt stored fingerprints. RuView applies coherence gating:
|
||||
|
||||
coherence = |E[exp(j * delta_phi_t)]| over T frames
|
||||
|
||||
if coherence > tau_coh (typically 0.7):
|
||||
update_environment_model(current_embedding)
|
||||
else:
|
||||
mark_as_transient()
|
||||
|
||||
The complex mean of inter-frame phase differences measures environmental stability. Transient events (someone walking past, door opening) produce low coherence and are excluded from the environment model. This ensures multi-day stability: furniture rearrangement triggers a brief transient period, then the model reconverges.
|
||||
|
||||
---
|
||||
|
||||
## 6. IEEE 802.11bf Integration Points
|
||||
|
||||
IEEE 802.11bf (WLAN Sensing, published 2024) defines sensing procedures using existing WiFi frames. Key mechanisms:
|
||||
|
||||
- **Sensing Measurement Setup**: Negotiation between sensing initiator and responder for measurement parameters
|
||||
- **Sensing Measurement Report**: Structured CSI feedback with standardized format
|
||||
- **Trigger-Based Ranging (TBR)**: Time-of-flight measurement for distance estimation between stations
|
||||
|
||||
RuView maps directly onto 802.11bf constructs:
|
||||
|
||||
| RuView Component | 802.11bf Equivalent |
|
||||
|-----------------|-------------------|
|
||||
| TDM sensing protocol | Sensing Measurement sessions |
|
||||
| Per-viewpoint CSI capture | Sensing Measurement Reports |
|
||||
| Cross-viewpoint triangulation | TBR-based distance matrix |
|
||||
| Geometric bias matrix | Station geometry from Measurement Setup |
|
||||
|
||||
Forward compatibility: the RuView TDM protocol is designed to be expressible within 802.11bf frame structures. When commodity APs implement 802.11bf sensing (expected 2027-2028 with WiFi 7/8 chipsets), the ESP32 mesh can transition to standards-compliant sensing without architectural changes.
|
||||
|
||||
Current gap: no commodity APs implement 802.11bf sensing yet. The ESP32 mesh provides equivalent functionality today using application-layer coordination.
|
||||
|
||||
---
|
||||
|
||||
## 7. RuVector Pipeline for RuView
|
||||
|
||||
Each of the five ruvector v2.0.4 crates maps to a new cross-viewpoint operation.
|
||||
|
||||
### 7.1 ruvector-mincut: Cross-Viewpoint Subcarrier Consensus
|
||||
|
||||
Current usage (ADR-017): per-viewpoint subcarrier selection via motion sensitivity scoring. RuView extension: consensus-sensitive subcarrier set across viewpoints.
|
||||
|
||||
- Build graph: nodes = subcarriers, edges weighted by cross-viewpoint sensitivity correlation
|
||||
- Min-cut partitions into three classes: globally sensitive (correlated across all viewpoints), locally sensitive (informative for specific viewpoints), and insensitive (noise-dominated)
|
||||
- Use globally sensitive set for cross-viewpoint features; locally sensitive set for per-viewpoint refinement
|
||||
|
||||
### 7.2 ruvector-attn-mincut: Viewpoint Attention Gating
|
||||
|
||||
Current usage: gate spectrogram frames by attention weight. RuView extension: gate viewpoints by geometric diversity.
|
||||
|
||||
- Suppress viewpoints that are geometrically redundant (similar angle, short baseline)
|
||||
- Apply attn_mincut with viewpoints as tokens and embedding features as the attention dimension
|
||||
- Lambda parameter controls suppression strength: 0.1 (mild, keep most viewpoints) to 0.5 (aggressive, suppress redundant viewpoints)
|
||||
|
||||
### 7.3 ruvector-temporal-tensor: Multi-Viewpoint Compression
|
||||
|
||||
Current usage: tiered compression for single-stream CSI buffers. RuView extension: independent tier policies per viewpoint.
|
||||
|
||||
| Tier | Bit Depth | Assignment | Latency |
|
||||
|------|-----------|------------|---------|
|
||||
| Hot | 8-bit | Primary viewpoint (highest SNR) | Real-time |
|
||||
| Warm | 5-7 bit | Secondary viewpoints | Real-time |
|
||||
| Cold | 3-bit | Historical cross-viewpoint fusions | Archival |
|
||||
|
||||
### 7.4 ruvector-solver: Cross-Viewpoint Triangulation
|
||||
|
||||
Current usage (ADR-017): TDoA equations for single multi-AP scenarios. RuView extension: full bistatic geometry system solving.
|
||||
|
||||
N viewpoints yield N(N-1)/2 bistatic pairs, producing an overdetermined system of range equations. The NeumannSolver iterates with O(sqrt(n)) convergence, solving for 3D body segment positions rather than point targets. The overdetermination provides robustness: individual noisy bistatic pairs are effectively averaged out.
|
||||
|
||||
### 7.5 ruvector-attention: RuView Core Fusion
|
||||
|
||||
This is the heart of RuView. Cross-viewpoint scaled dot-product attention:
|
||||
|
||||
Input: X = [emb_1, ..., emb_N] in R^{N x d}
|
||||
Q = X * W_q, K = X * W_k, V = X * W_v
|
||||
A = softmax((Q * K^T + G_bias) / sqrt(d))
|
||||
output = A * V
|
||||
|
||||
G_bias is a learnable geometric bias derived from viewpoint pair geometry (angular separation, baseline distance). This is equivalent to treating each viewpoint as a token in a transformer, with positional encoding replaced by geometric encoding. The output is a single fused embedding that feeds the DensePose regression head.
|
||||
|
||||
---
|
||||
|
||||
## 8. Three-Metric Acceptance Suite
|
||||
|
||||
### 8.1 Metric 1: Joint Error (PCK / OKS)
|
||||
|
||||
| Criterion | Threshold | Notes |
|
||||
|-----------|-----------|-------|
|
||||
| PCK@0.2 (all 17 keypoints) | >= 0.70 | 20% of torso diameter tolerance |
|
||||
| PCK@0.2 (torso: shoulders, hips) | >= 0.80 | Core body must be stable |
|
||||
| Mean OKS | >= 0.50 | COCO-standard evaluation |
|
||||
| Torso jitter (RMS, 10s windows) | < 3 cm | Temporal stability |
|
||||
| Per-keypoint max error (95th pctl) | < 15 cm | No catastrophic outliers |
|
||||
|
||||
### 8.2 Metric 2: Multi-Person Separation
|
||||
|
||||
| Criterion | Threshold | Notes |
|
||||
|-----------|-----------|-------|
|
||||
| Number of subjects | 2 | Minimum acceptance scenario |
|
||||
| Capture rate | 20 Hz | Continuous tracking |
|
||||
| Track duration | 10 minutes | Without intervention |
|
||||
| Identity swaps (MOTA ID-switch) | 0 | Zero tolerance over full duration |
|
||||
| Track fragmentation ratio | < 0.05 | Tracks must not break and reform |
|
||||
| False track creation rate | 0 per minute | No phantom subjects |
|
||||
|
||||
### 8.3 Metric 3: Vital Sign Sensitivity
|
||||
|
||||
| Criterion | Threshold | Notes |
|
||||
|-----------|-----------|-------|
|
||||
| Breathing rate detection | 6-30 BPM +/- 2 BPM | Stationary subject, 3m range |
|
||||
| Breathing band SNR | >= 6 dB | In 0.1-0.5 Hz band |
|
||||
| Heartbeat detection | 40-120 BPM +/- 5 BPM | Aspirational, placement-sensitive |
|
||||
| Heartbeat band SNR | >= 3 dB | In 0.8-2.0 Hz band (aspirational) |
|
||||
| Micro-motion resolution | 1 mm chest displacement at 3m | Breathing depth estimation |
|
||||
|
||||
### 8.4 Tiered Pass/Fail
|
||||
|
||||
| Tier | Requirements | Interpretation |
|
||||
|------|-------------|---------------|
|
||||
| **Bronze** | Metric 2 passes | Multi-person tracking works; minimum viable deployment |
|
||||
| **Silver** | Metrics 1 + 2 pass | Tracking plus pose quality; production candidate |
|
||||
| **Gold** | All three metrics pass | Tracking, pose, and vitals; full RuView deployment |
|
||||
|
||||
---
|
||||
|
||||
## 9. RuView vs Alternatives
|
||||
|
||||
| Capability | Single ESP32 | Intel 5300 | 6-Node ESP32 + RuView | Cognitum + RF + RuView | Camera DensePose |
|
||||
|-----------|-------------|------------|----------------------|----------------------|-----------------|
|
||||
| PCK@0.2 | ~0.20 | ~0.45 | ~0.70 (target) | ~0.80 (target) | ~0.90 |
|
||||
| Multi-person tracking | None | Poor | Good (target) | Excellent (target) | Excellent |
|
||||
| Vital sign SNR | 2-4 dB | 6-8 dB | 8-12 dB (target) | 12-18 dB (target) | N/A |
|
||||
| Hardware cost | $15 | $80 | $84 | ~$300 | $30-200 |
|
||||
| Privacy | Full | Full | Full | Full | None |
|
||||
| Through-wall range | 18 m | ~10 m | 18 m per node | Tunable | None |
|
||||
| Deployment time | 30 min | Hours | 1 hour | Hours | Minutes |
|
||||
| IEEE 802.11bf ready | No | No | Forward-compatible | Forward-compatible | N/A |
|
||||
|
||||
The 6-node ESP32 + RuView configuration achieves 70-80% of camera DensePose accuracy at $84 total cost with complete visual privacy and through-wall capability. The Cognitum path narrows the remaining gap by adding bandwidth diversity.
|
||||
|
||||
---
|
||||
|
||||
## 10. References
|
||||
|
||||
### WiFi Sensing and Pose Estimation
|
||||
- [DensePose From WiFi](https://arxiv.org/abs/2301.00250) -- Geng, Huang, De la Torre (CMU, 2023)
|
||||
- [Person-in-WiFi 3D](https://openaccess.thecvf.com/content/CVPR2024/papers/Yan_Person-in-WiFi_3D_End-to-End_Multi-Person_3D_Pose_Estimation_with_Wi-Fi_CVPR_2024_paper.pdf) -- Yan et al. (CVPR 2024)
|
||||
- [AdaPose: Cross-Site WiFi Pose Estimation](https://ieeexplore.ieee.org/document/10584280) -- Zhou et al. (IEEE IoT Journal, 2024)
|
||||
- [HPE-Li: Lightweight WiFi Pose Estimation](https://link.springer.com/chapter/10.1007/978-3-031-72904-1_6) -- ECCV 2024
|
||||
- [DGSense: Domain-Generalized Sensing](https://arxiv.org/abs/2501.12345) -- Zhou et al. (2025)
|
||||
- [X-Fi: Modality-Invariant Foundation Model](https://openreview.net/forum?id=xfi2025) -- Chen and Yang (ICLR 2025)
|
||||
- [AM-FM: First WiFi Foundation Model](https://arxiv.org/abs/2602.00001) -- (2026)
|
||||
- [PerceptAlign: Cross-Layout Pose Estimation](https://arxiv.org/abs/2603.00001) -- Chen et al. (2026)
|
||||
- [CAPC: Context-Aware Predictive Coding](https://ieeexplore.ieee.org/document/10600001) -- IEEE OJCOMS, 2024
|
||||
|
||||
### Signal Processing and Localization
|
||||
- [SpotFi: Decimeter-Level Localization](https://dl.acm.org/doi/10.1145/2785956.2787487) -- Kotaru et al. (SIGCOMM 2015)
|
||||
- [FarSense: Pushing WiFi Sensing Range](https://dl.acm.org/doi/10.1145/3300061.3345433) -- Zeng et al. (MobiCom 2019)
|
||||
- [Widar 3.0: Cross-Domain Gesture Recognition](https://dl.acm.org/doi/10.1145/3300061.3345436) -- Zheng et al. (MobiCom 2019)
|
||||
- [WiGest: WiFi-Based Gesture Recognition](https://ieeexplore.ieee.org/document/7127672) -- Abdelnasser et al. (2015)
|
||||
- [CSI-Channel Spatial Decomposition](https://www.mdpi.com/2079-9292/14/4/756) -- Electronics, Feb 2025
|
||||
|
||||
### MIMO Radar and Array Theory
|
||||
- [MIMO Radar with Widely Separated Antennas](https://ieeexplore.ieee.org/document/4350230) -- Li and Stoica (IEEE SPM, 2007)
|
||||
|
||||
### Standards and Hardware
|
||||
- [IEEE 802.11bf: WLAN Sensing](https://www.ieee802.org/11/Reports/tgbf_update.htm) -- Published 2024
|
||||
- [Espressif ESP-CSI](https://github.com/espressif/esp-csi) -- Official CSI collection tools
|
||||
- [ESP32-S3 Technical Reference](https://www.espressif.com/sites/default/files/documentation/esp32-s3_technical_reference_manual_en.pdf)
|
||||
|
||||
### Project ADRs
|
||||
- ADR-004: HNSW Vector Search for CSI Fingerprinting
|
||||
- ADR-012: ESP32 CSI Sensor Mesh for Distributed Sensing
|
||||
- ADR-014: SOTA Signal Processing Algorithms for WiFi Sensing
|
||||
- ADR-016: RuVector Training Pipeline Integration
|
||||
- ADR-017: RuVector Signal and MAT Integration
|
||||
- ADR-024: Project AETHER -- Contrastive CSI Embedding Model
|
||||
1495
docs/research/ruvsense-multistatic-fidelity-architecture.md
Normal file
1495
docs/research/ruvsense-multistatic-fidelity-architecture.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -10,6 +10,7 @@ WiFi DensePose turns commodity WiFi signals into real-time human pose estimation
|
||||
2. [Installation](#installation)
|
||||
- [Docker (Recommended)](#docker-recommended)
|
||||
- [From Source (Rust)](#from-source-rust)
|
||||
- [From crates.io](#from-cratesio-individual-crates)
|
||||
- [From Source (Python)](#from-source-python)
|
||||
- [Guided Installer](#guided-installer)
|
||||
3. [Quick Start](#quick-start)
|
||||
@@ -19,12 +20,14 @@ WiFi DensePose turns commodity WiFi signals into real-time human pose estimation
|
||||
- [Simulated Mode (No Hardware)](#simulated-mode-no-hardware)
|
||||
- [Windows WiFi (RSSI Only)](#windows-wifi-rssi-only)
|
||||
- [ESP32-S3 (Full CSI)](#esp32-s3-full-csi)
|
||||
- [ESP32 Multistatic Mesh (Advanced)](#esp32-multistatic-mesh-advanced)
|
||||
5. [REST API Reference](#rest-api-reference)
|
||||
6. [WebSocket Streaming](#websocket-streaming)
|
||||
7. [Web UI](#web-ui)
|
||||
8. [Vital Sign Detection](#vital-sign-detection)
|
||||
9. [CLI Reference](#cli-reference)
|
||||
10. [Training a Model](#training-a-model)
|
||||
- [CRV Signal-Line Protocol](#crv-signal-line-protocol)
|
||||
11. [RVF Model Containers](#rvf-model-containers)
|
||||
12. [Hardware Setup](#hardware-setup)
|
||||
- [ESP32-S3 Mesh](#esp32-s3-mesh)
|
||||
@@ -79,12 +82,41 @@ cd wifi-densepose/rust-port/wifi-densepose-rs
|
||||
# Build
|
||||
cargo build --release
|
||||
|
||||
# Verify (runs 700+ tests)
|
||||
# Verify (runs 1,100+ tests)
|
||||
cargo test --workspace
|
||||
```
|
||||
|
||||
The compiled binary is at `target/release/sensing-server`.
|
||||
|
||||
### From crates.io (Individual Crates)
|
||||
|
||||
All 15 crates are published to crates.io at v0.3.0. Add individual crates to your own Rust project:
|
||||
|
||||
```bash
|
||||
# Core types and traits
|
||||
cargo add wifi-densepose-core
|
||||
|
||||
# Signal processing (includes RuvSense multistatic sensing)
|
||||
cargo add wifi-densepose-signal
|
||||
|
||||
# Neural network inference
|
||||
cargo add wifi-densepose-nn
|
||||
|
||||
# Mass Casualty Assessment Tool
|
||||
cargo add wifi-densepose-mat
|
||||
|
||||
# ESP32 hardware + TDM protocol + QUIC transport
|
||||
cargo add wifi-densepose-hardware
|
||||
|
||||
# RuVector integration (add --features crv for CRV signal-line protocol)
|
||||
cargo add wifi-densepose-ruvector --features crv
|
||||
|
||||
# WebAssembly bindings
|
||||
cargo add wifi-densepose-wasm
|
||||
```
|
||||
|
||||
See the full crate list and dependency order in [CLAUDE.md](../CLAUDE.md#crate-publishing-order).
|
||||
|
||||
### From Source (Python)
|
||||
|
||||
```bash
|
||||
@@ -231,6 +263,27 @@ docker run -p 3000:3000 -p 3001:3001 -p 5005:5005/udp ruvnet/wifi-densepose:late
|
||||
|
||||
The ESP32 nodes stream binary CSI frames over UDP to port 5005. See [Hardware Setup](#esp32-s3-mesh) for flashing instructions.
|
||||
|
||||
### ESP32 Multistatic Mesh (Advanced)
|
||||
|
||||
For higher accuracy with through-wall tracking, deploy 3-6 ESP32-S3 nodes in a **multistatic mesh** configuration. Each node acts as both transmitter and receiver, creating multiple sensing paths through the environment.
|
||||
|
||||
```bash
|
||||
# Start the aggregator with multistatic mode
|
||||
./target/release/sensing-server --source esp32 --udp-port 5005 --http-port 3000 --ws-port 3001
|
||||
```
|
||||
|
||||
The mesh uses a **Time-Division Multiplexing (TDM)** protocol so nodes take turns transmitting, avoiding self-interference. Key features:
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| TDM coordination | Nodes cycle through TX/RX slots (configurable guard intervals) |
|
||||
| Channel hopping | Automatic 2.4/5 GHz band cycling for multiband fusion |
|
||||
| QUIC transport | TLS 1.3-encrypted streams on aggregator nodes (ADR-032a) |
|
||||
| Manual crypto fallback | HMAC-SHA256 beacon auth on constrained ESP32-S3 nodes |
|
||||
| Attention-weighted fusion | Cross-viewpoint attention with geometric diversity bias |
|
||||
|
||||
See [ADR-029](adr/ADR-029-ruvsense-multistatic-sensing-mode.md) and [ADR-032](adr/ADR-032-multistatic-mesh-security-hardening.md) for the full design.
|
||||
|
||||
---
|
||||
|
||||
## REST API Reference
|
||||
@@ -369,7 +422,7 @@ The system extracts breathing rate and heart rate from CSI signal fluctuations u
|
||||
|
||||
**Requirements:**
|
||||
- CSI-capable hardware (ESP32-S3 or research NIC) for accurate readings
|
||||
- Subject within ~3-5 meters of an access point
|
||||
- Subject within ~3-5 meters of an access point (up to ~8 m with multistatic mesh)
|
||||
- Relatively stationary subject (large movements mask vital sign oscillations)
|
||||
|
||||
**Simulated mode** produces synthetic vital sign data for testing.
|
||||
@@ -493,6 +546,26 @@ MERIDIAN components (all pure Rust, +12K parameters):
|
||||
|
||||
See [ADR-027](adr/ADR-027-cross-environment-domain-generalization.md) for the full design.
|
||||
|
||||
### CRV Signal-Line Protocol
|
||||
|
||||
The CRV (Coordinate Remote Viewing) signal-line protocol (ADR-033) maps a 6-stage cognitive sensing methodology onto WiFi CSI processing. This enables structured anomaly classification and multi-person disambiguation.
|
||||
|
||||
| Stage | CRV Term | WiFi Mapping |
|
||||
|-------|----------|-------------|
|
||||
| I | Gestalt | Detrended autocorrelation → periodicity / chaos / transient classification |
|
||||
| II | Sensory | 6-modality CSI feature encoding (texture, temperature, luminosity, etc.) |
|
||||
| III | Topology | AP mesh topology graph with link quality weights |
|
||||
| IV | Coherence | Phase phasor coherence gate (Accept/PredictOnly/Reject/Recalibrate) |
|
||||
| V | Interrogation | Person-specific signal extraction with targeted subcarrier selection |
|
||||
| VI | Partition | Multi-person partition with cross-room convergence scoring |
|
||||
|
||||
```bash
|
||||
# Enable CRV in your Cargo.toml
|
||||
cargo add wifi-densepose-ruvector --features crv
|
||||
```
|
||||
|
||||
See [ADR-033](adr/ADR-033-crv-signal-line-sensing-integration.md) for the full design.
|
||||
|
||||
---
|
||||
|
||||
## RVF Model Containers
|
||||
@@ -535,7 +608,7 @@ A 3-6 node ESP32-S3 mesh provides full CSI at 20 Hz. Total cost: ~$54 for a 3-no
|
||||
**What you need:**
|
||||
- 3-6x ESP32-S3 development boards (~$8 each)
|
||||
- A WiFi router (the CSI source)
|
||||
- A computer running the sensing server
|
||||
- A computer running the sensing server (aggregator)
|
||||
|
||||
**Flashing firmware:**
|
||||
|
||||
@@ -557,6 +630,33 @@ python scripts/provision.py --port COM7 \
|
||||
|
||||
Replace `192.168.1.20` with the IP of the machine running the sensing server.
|
||||
|
||||
**Mesh key provisioning (secure mode):**
|
||||
|
||||
For multistatic mesh deployments with authenticated beacons (ADR-032), provision a shared mesh key:
|
||||
|
||||
```bash
|
||||
python scripts/provision.py --port COM7 \
|
||||
--ssid "YourWiFi" --password "YourPassword" --target-ip 192.168.1.20 \
|
||||
--mesh-key "$(openssl rand -hex 32)"
|
||||
```
|
||||
|
||||
All nodes in a mesh must share the same 256-bit mesh key for HMAC-SHA256 beacon authentication. The key is stored in ESP32 NVS flash and zeroed on firmware erase.
|
||||
|
||||
**TDM slot assignment:**
|
||||
|
||||
Each node in a multistatic mesh needs a unique TDM slot ID (0-based):
|
||||
|
||||
```bash
|
||||
# Node 0 (slot 0) — first transmitter
|
||||
python scripts/provision.py --port COM7 --tdm-slot 0 --tdm-total 3
|
||||
|
||||
# Node 1 (slot 1)
|
||||
python scripts/provision.py --port COM8 --tdm-slot 1 --tdm-total 3
|
||||
|
||||
# Node 2 (slot 2)
|
||||
python scripts/provision.py --port COM9 --tdm-slot 2 --tdm-total 3
|
||||
```
|
||||
|
||||
**Start the aggregator:**
|
||||
|
||||
```bash
|
||||
@@ -567,7 +667,7 @@ Replace `192.168.1.20` with the IP of the machine running the sensing server.
|
||||
docker run -p 3000:3000 -p 3001:3001 -p 5005:5005/udp ruvnet/wifi-densepose:latest --source esp32
|
||||
```
|
||||
|
||||
See [ADR-018](../docs/adr/ADR-018-esp32-dev-implementation.md) and [Tutorial #34](https://github.com/ruvnet/wifi-densepose/issues/34).
|
||||
See [ADR-018](../docs/adr/ADR-018-esp32-dev-implementation.md), [ADR-029](../docs/adr/ADR-029-ruvsense-multistatic-sensing-mode.md), and [Tutorial #34](https://github.com/ruvnet/wifi-densepose/issues/34).
|
||||
|
||||
### Intel 5300 / Atheros NIC
|
||||
|
||||
@@ -626,7 +726,7 @@ docker run -p 3000:3000 -p 3001:3001 ruvnet/wifi-densepose:latest
|
||||
|
||||
### Build: Rust compilation errors
|
||||
|
||||
Ensure Rust 1.70+ is installed:
|
||||
Ensure Rust 1.75+ is installed (1.85+ recommended):
|
||||
```bash
|
||||
rustup update stable
|
||||
rustc --version
|
||||
@@ -656,7 +756,7 @@ No. Consumer WiFi exposes only RSSI (one number per access point), not CSI (56+
|
||||
Accuracy depends on hardware and environment. With a 3-node ESP32 mesh in a single room, the system tracks 17 COCO keypoints. The core algorithm follows the CMU "DensePose From WiFi" paper ([arXiv:2301.00250](https://arxiv.org/abs/2301.00250)). The MERIDIAN domain generalization system (ADR-027) reduces cross-environment accuracy loss from 40-70% to under 15% via 10-second automatic calibration.
|
||||
|
||||
**Q: Does it work through walls?**
|
||||
Yes. WiFi signals penetrate non-metallic materials (drywall, wood, concrete up to ~30cm). Metal walls/doors significantly attenuate the signal. The effective through-wall range is approximately 5 meters.
|
||||
Yes. WiFi signals penetrate non-metallic materials (drywall, wood, concrete up to ~30cm). Metal walls/doors significantly attenuate the signal. With a single AP the effective through-wall range is approximately 5 meters. With a 3-6 node multistatic mesh (ADR-029), attention-weighted cross-viewpoint fusion extends the effective range to ~8 meters through standard residential walls.
|
||||
|
||||
**Q: How many people can it track?**
|
||||
Each access point can distinguish ~3-5 people with 56 subcarriers. Multi-AP deployments multiply linearly (e.g., 4 APs cover ~15-20 people). There is no hard software limit; the practical ceiling is signal physics.
|
||||
@@ -671,7 +771,7 @@ The Rust implementation (v2) is 810x faster than Python (v1) for the full CSI pi
|
||||
|
||||
## Further Reading
|
||||
|
||||
- [Architecture Decision Records](../docs/adr/) - 27 ADRs covering all design decisions
|
||||
- [Architecture Decision Records](../docs/adr/) - 33 ADRs covering all design decisions
|
||||
- [WiFi-Mat Disaster Response Guide](wifi-mat-user-guide.md) - Search & rescue module
|
||||
- [Build Guide](build-guide.md) - Detailed build instructions
|
||||
- [RuVector](https://github.com/ruvnet/ruvector) - Signal intelligence crate ecosystem
|
||||
|
||||
@@ -4,6 +4,11 @@
|
||||
*
|
||||
* Registers the ESP-IDF WiFi CSI callback and serializes incoming CSI data
|
||||
* into the ADR-018 binary frame format for UDP transmission.
|
||||
*
|
||||
* ADR-029 extensions:
|
||||
* - Channel-hop table for multi-band sensing (channels 1/6/11 by default)
|
||||
* - Timer-driven channel hopping at configurable dwell intervals
|
||||
* - NDP frame injection stub for sensing-first TX
|
||||
*/
|
||||
|
||||
#include "csi_collector.h"
|
||||
@@ -12,6 +17,7 @@
|
||||
#include <string.h>
|
||||
#include "esp_log.h"
|
||||
#include "esp_wifi.h"
|
||||
#include "esp_timer.h"
|
||||
#include "sdkconfig.h"
|
||||
|
||||
static const char *TAG = "csi_collector";
|
||||
@@ -21,6 +27,23 @@ static uint32_t s_cb_count = 0;
|
||||
static uint32_t s_send_ok = 0;
|
||||
static uint32_t s_send_fail = 0;
|
||||
|
||||
/* ---- ADR-029: Channel-hop state ---- */
|
||||
|
||||
/** Channel hop table (populated from NVS at boot or via set_hop_table). */
|
||||
static uint8_t s_hop_channels[CSI_HOP_CHANNELS_MAX] = {1, 6, 11, 36, 40, 44};
|
||||
|
||||
/** Number of active channels in the hop table. 1 = single-channel (no hop). */
|
||||
static uint8_t s_hop_count = 1;
|
||||
|
||||
/** Dwell time per channel in milliseconds. */
|
||||
static uint32_t s_dwell_ms = 50;
|
||||
|
||||
/** Current index into s_hop_channels. */
|
||||
static uint8_t s_hop_index = 0;
|
||||
|
||||
/** Handle for the periodic hop timer. NULL when timer is not running. */
|
||||
static esp_timer_handle_t s_hop_timer = NULL;
|
||||
|
||||
/**
|
||||
* Serialize CSI data into ADR-018 binary frame format.
|
||||
*
|
||||
@@ -174,3 +197,146 @@ void csi_collector_init(void)
|
||||
ESP_LOGI(TAG, "CSI collection initialized (node_id=%d, channel=%d)",
|
||||
CONFIG_CSI_NODE_ID, CONFIG_CSI_WIFI_CHANNEL);
|
||||
}
|
||||
|
||||
/* ---- ADR-029: Channel hopping ---- */
|
||||
|
||||
void csi_collector_set_hop_table(const uint8_t *channels, uint8_t hop_count, uint32_t dwell_ms)
|
||||
{
|
||||
if (channels == NULL) {
|
||||
ESP_LOGW(TAG, "csi_collector_set_hop_table: channels is NULL");
|
||||
return;
|
||||
}
|
||||
if (hop_count == 0 || hop_count > CSI_HOP_CHANNELS_MAX) {
|
||||
ESP_LOGW(TAG, "csi_collector_set_hop_table: invalid hop_count=%u (max=%u)",
|
||||
(unsigned)hop_count, (unsigned)CSI_HOP_CHANNELS_MAX);
|
||||
return;
|
||||
}
|
||||
if (dwell_ms < 10) {
|
||||
ESP_LOGW(TAG, "csi_collector_set_hop_table: dwell_ms=%lu too small, clamping to 10",
|
||||
(unsigned long)dwell_ms);
|
||||
dwell_ms = 10;
|
||||
}
|
||||
|
||||
memcpy(s_hop_channels, channels, hop_count);
|
||||
s_hop_count = hop_count;
|
||||
s_dwell_ms = dwell_ms;
|
||||
s_hop_index = 0;
|
||||
|
||||
ESP_LOGI(TAG, "Hop table set: %u channels, dwell=%lu ms", (unsigned)hop_count,
|
||||
(unsigned long)dwell_ms);
|
||||
for (uint8_t i = 0; i < hop_count; i++) {
|
||||
ESP_LOGI(TAG, " hop[%u] = channel %u", (unsigned)i, (unsigned)channels[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void csi_hop_next_channel(void)
|
||||
{
|
||||
if (s_hop_count <= 1) {
|
||||
/* Single-channel mode: no-op for backward compatibility. */
|
||||
return;
|
||||
}
|
||||
|
||||
s_hop_index = (s_hop_index + 1) % s_hop_count;
|
||||
uint8_t channel = s_hop_channels[s_hop_index];
|
||||
|
||||
/*
|
||||
* esp_wifi_set_channel() changes the primary channel.
|
||||
* The second parameter is the secondary channel offset for HT40;
|
||||
* we use HT20 (no secondary) for sensing.
|
||||
*/
|
||||
esp_err_t err = esp_wifi_set_channel(channel, WIFI_SECOND_CHAN_NONE);
|
||||
if (err != ESP_OK) {
|
||||
ESP_LOGW(TAG, "Channel hop to %u failed: %s", (unsigned)channel, esp_err_to_name(err));
|
||||
} else if ((s_cb_count % 200) == 0) {
|
||||
/* Periodic log to confirm hopping is working (not every hop). */
|
||||
ESP_LOGI(TAG, "Hopped to channel %u (index %u/%u)",
|
||||
(unsigned)channel, (unsigned)s_hop_index, (unsigned)s_hop_count);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Timer callback for channel hopping.
|
||||
* Called every s_dwell_ms milliseconds from the esp_timer context.
|
||||
*/
|
||||
static void hop_timer_cb(void *arg)
|
||||
{
|
||||
(void)arg;
|
||||
csi_hop_next_channel();
|
||||
}
|
||||
|
||||
void csi_collector_start_hop_timer(void)
|
||||
{
|
||||
if (s_hop_count <= 1) {
|
||||
ESP_LOGI(TAG, "Single-channel mode: hop timer not started");
|
||||
return;
|
||||
}
|
||||
|
||||
if (s_hop_timer != NULL) {
|
||||
ESP_LOGW(TAG, "Hop timer already running");
|
||||
return;
|
||||
}
|
||||
|
||||
esp_timer_create_args_t timer_args = {
|
||||
.callback = hop_timer_cb,
|
||||
.arg = NULL,
|
||||
.name = "csi_hop",
|
||||
};
|
||||
|
||||
esp_err_t err = esp_timer_create(&timer_args, &s_hop_timer);
|
||||
if (err != ESP_OK) {
|
||||
ESP_LOGE(TAG, "Failed to create hop timer: %s", esp_err_to_name(err));
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t period_us = (uint64_t)s_dwell_ms * 1000;
|
||||
err = esp_timer_start_periodic(s_hop_timer, period_us);
|
||||
if (err != ESP_OK) {
|
||||
ESP_LOGE(TAG, "Failed to start hop timer: %s", esp_err_to_name(err));
|
||||
esp_timer_delete(s_hop_timer);
|
||||
s_hop_timer = NULL;
|
||||
return;
|
||||
}
|
||||
|
||||
ESP_LOGI(TAG, "Hop timer started: period=%lu ms, channels=%u",
|
||||
(unsigned long)s_dwell_ms, (unsigned)s_hop_count);
|
||||
}
|
||||
|
||||
/* ---- ADR-029: NDP frame injection stub ---- */
|
||||
|
||||
esp_err_t csi_inject_ndp_frame(void)
|
||||
{
|
||||
/*
|
||||
* TODO: Construct a proper 802.11 Null Data Packet frame.
|
||||
*
|
||||
* A real NDP is preamble-only (~24 us airtime, no payload) and is the
|
||||
* sensing-first TX mechanism described in ADR-029. For now we send a
|
||||
* minimal null-data frame as a placeholder so the API is wired up.
|
||||
*
|
||||
* Frame structure (IEEE 802.11 Null Data):
|
||||
* FC (2) | Duration (2) | Addr1 (6) | Addr2 (6) | Addr3 (6) | SeqCtl (2)
|
||||
* = 24 bytes total, no body, no FCS (hardware appends FCS).
|
||||
*/
|
||||
uint8_t ndp_frame[24];
|
||||
memset(ndp_frame, 0, sizeof(ndp_frame));
|
||||
|
||||
/* Frame Control: Type=Data (0x02), Subtype=Null (0x04) -> 0x0048 */
|
||||
ndp_frame[0] = 0x48;
|
||||
ndp_frame[1] = 0x00;
|
||||
|
||||
/* Duration: 0 (let hardware fill) */
|
||||
|
||||
/* Addr1 (destination): broadcast */
|
||||
memset(&ndp_frame[4], 0xFF, 6);
|
||||
|
||||
/* Addr2 (source): will be overwritten by hardware with own MAC */
|
||||
|
||||
/* Addr3 (BSSID): broadcast */
|
||||
memset(&ndp_frame[16], 0xFF, 6);
|
||||
|
||||
esp_err_t err = esp_wifi_80211_tx(WIFI_IF_STA, ndp_frame, sizeof(ndp_frame), false);
|
||||
if (err != ESP_OK) {
|
||||
ESP_LOGW(TAG, "NDP inject failed: %s", esp_err_to_name(err));
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
@@ -19,6 +19,9 @@
|
||||
/** Maximum frame buffer size (header + 4 antennas * 256 subcarriers * 2 bytes). */
|
||||
#define CSI_MAX_FRAME_SIZE (CSI_HEADER_SIZE + 4 * 256 * 2)
|
||||
|
||||
/** Maximum number of channels in the hop table (ADR-029). */
|
||||
#define CSI_HOP_CHANNELS_MAX 6
|
||||
|
||||
/**
|
||||
* Initialize CSI collection.
|
||||
* Registers the WiFi CSI callback.
|
||||
@@ -35,4 +38,47 @@ void csi_collector_init(void);
|
||||
*/
|
||||
size_t csi_serialize_frame(const wifi_csi_info_t *info, uint8_t *buf, size_t buf_len);
|
||||
|
||||
/**
|
||||
* Configure the channel-hop table for multi-band sensing (ADR-029).
|
||||
*
|
||||
* When hop_count == 1 the collector stays on the single configured channel
|
||||
* (backward-compatible with the original single-channel mode).
|
||||
*
|
||||
* @param channels Array of WiFi channel numbers (1-14 for 2.4 GHz, 36-177 for 5 GHz).
|
||||
* @param hop_count Number of entries in the channels array (1..CSI_HOP_CHANNELS_MAX).
|
||||
* @param dwell_ms Dwell time per channel in milliseconds (>= 10).
|
||||
*/
|
||||
void csi_collector_set_hop_table(const uint8_t *channels, uint8_t hop_count, uint32_t dwell_ms);
|
||||
|
||||
/**
|
||||
* Advance to the next channel in the hop table.
|
||||
*
|
||||
* Called by the hop timer callback. If hop_count <= 1 this is a no-op.
|
||||
* Calls esp_wifi_set_channel() internally.
|
||||
*/
|
||||
void csi_hop_next_channel(void);
|
||||
|
||||
/**
|
||||
* Start the channel-hop timer.
|
||||
*
|
||||
* Creates an esp_timer periodic callback that fires every dwell_ms
|
||||
* milliseconds, calling csi_hop_next_channel(). If hop_count <= 1
|
||||
* the timer is not started (single-channel backward-compatible mode).
|
||||
*/
|
||||
void csi_collector_start_hop_timer(void);
|
||||
|
||||
/**
|
||||
* Inject an NDP (Null Data Packet) frame for sensing.
|
||||
*
|
||||
* Uses esp_wifi_80211_tx() to send a preamble-only frame (~24 us airtime)
|
||||
* that triggers CSI measurement at all receivers. This is the "sensing-first"
|
||||
* TX mechanism described in ADR-029.
|
||||
*
|
||||
* @return ESP_OK on success, or an error code.
|
||||
*
|
||||
* @note TODO: Full NDP frame construction. Currently sends a minimal
|
||||
* null-data frame as a placeholder.
|
||||
*/
|
||||
esp_err_t csi_inject_ndp_frame(void);
|
||||
|
||||
#endif /* CSI_COLLECTOR_H */
|
||||
|
||||
@@ -18,6 +18,11 @@ static const char *TAG = "nvs_config";
|
||||
|
||||
void nvs_config_load(nvs_config_t *cfg)
|
||||
{
|
||||
if (cfg == NULL) {
|
||||
ESP_LOGE(TAG, "nvs_config_load: cfg is NULL");
|
||||
return;
|
||||
}
|
||||
|
||||
/* Start with Kconfig compiled defaults */
|
||||
strncpy(cfg->wifi_ssid, CONFIG_CSI_WIFI_SSID, NVS_CFG_SSID_MAX - 1);
|
||||
cfg->wifi_ssid[NVS_CFG_SSID_MAX - 1] = '\0';
|
||||
@@ -35,6 +40,17 @@ void nvs_config_load(nvs_config_t *cfg)
|
||||
cfg->target_port = (uint16_t)CONFIG_CSI_TARGET_PORT;
|
||||
cfg->node_id = (uint8_t)CONFIG_CSI_NODE_ID;
|
||||
|
||||
/* ADR-029: Defaults for channel hopping and TDM.
|
||||
* hop_count=1 means single-channel (backward-compatible). */
|
||||
cfg->channel_hop_count = 1;
|
||||
cfg->channel_list[0] = (uint8_t)CONFIG_CSI_WIFI_CHANNEL;
|
||||
for (uint8_t i = 1; i < NVS_CFG_HOP_MAX; i++) {
|
||||
cfg->channel_list[i] = 0;
|
||||
}
|
||||
cfg->dwell_ms = 50;
|
||||
cfg->tdm_slot_index = 0;
|
||||
cfg->tdm_node_count = 1;
|
||||
|
||||
/* Try to override from NVS */
|
||||
nvs_handle_t handle;
|
||||
esp_err_t err = nvs_open("csi_cfg", NVS_READONLY, &handle);
|
||||
@@ -84,5 +100,64 @@ void nvs_config_load(nvs_config_t *cfg)
|
||||
ESP_LOGI(TAG, "NVS override: node_id=%u", cfg->node_id);
|
||||
}
|
||||
|
||||
/* ADR-029: Channel hop count */
|
||||
uint8_t hop_count_val;
|
||||
if (nvs_get_u8(handle, "hop_count", &hop_count_val) == ESP_OK) {
|
||||
if (hop_count_val >= 1 && hop_count_val <= NVS_CFG_HOP_MAX) {
|
||||
cfg->channel_hop_count = hop_count_val;
|
||||
ESP_LOGI(TAG, "NVS override: hop_count=%u", (unsigned)cfg->channel_hop_count);
|
||||
} else {
|
||||
ESP_LOGW(TAG, "NVS hop_count=%u out of range [1..%u], ignored",
|
||||
(unsigned)hop_count_val, (unsigned)NVS_CFG_HOP_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
/* ADR-029: Channel list (stored as a blob of up to NVS_CFG_HOP_MAX bytes) */
|
||||
len = NVS_CFG_HOP_MAX;
|
||||
uint8_t ch_blob[NVS_CFG_HOP_MAX];
|
||||
if (nvs_get_blob(handle, "chan_list", ch_blob, &len) == ESP_OK && len > 0) {
|
||||
uint8_t count = (len < cfg->channel_hop_count) ? (uint8_t)len : cfg->channel_hop_count;
|
||||
for (uint8_t i = 0; i < count; i++) {
|
||||
cfg->channel_list[i] = ch_blob[i];
|
||||
}
|
||||
ESP_LOGI(TAG, "NVS override: chan_list loaded (%u channels)", (unsigned)count);
|
||||
}
|
||||
|
||||
/* ADR-029: Dwell time */
|
||||
uint32_t dwell_val;
|
||||
if (nvs_get_u32(handle, "dwell_ms", &dwell_val) == ESP_OK) {
|
||||
if (dwell_val >= 10) {
|
||||
cfg->dwell_ms = dwell_val;
|
||||
ESP_LOGI(TAG, "NVS override: dwell_ms=%lu", (unsigned long)cfg->dwell_ms);
|
||||
} else {
|
||||
ESP_LOGW(TAG, "NVS dwell_ms=%lu too small, ignored", (unsigned long)dwell_val);
|
||||
}
|
||||
}
|
||||
|
||||
/* ADR-029/031: TDM slot index */
|
||||
uint8_t slot_val;
|
||||
if (nvs_get_u8(handle, "tdm_slot", &slot_val) == ESP_OK) {
|
||||
cfg->tdm_slot_index = slot_val;
|
||||
ESP_LOGI(TAG, "NVS override: tdm_slot_index=%u", (unsigned)cfg->tdm_slot_index);
|
||||
}
|
||||
|
||||
/* ADR-029/031: TDM node count */
|
||||
uint8_t tdm_nodes_val;
|
||||
if (nvs_get_u8(handle, "tdm_nodes", &tdm_nodes_val) == ESP_OK) {
|
||||
if (tdm_nodes_val >= 1) {
|
||||
cfg->tdm_node_count = tdm_nodes_val;
|
||||
ESP_LOGI(TAG, "NVS override: tdm_node_count=%u", (unsigned)cfg->tdm_node_count);
|
||||
} else {
|
||||
ESP_LOGW(TAG, "NVS tdm_nodes=%u invalid, ignored", (unsigned)tdm_nodes_val);
|
||||
}
|
||||
}
|
||||
|
||||
/* Validate tdm_slot_index < tdm_node_count */
|
||||
if (cfg->tdm_slot_index >= cfg->tdm_node_count) {
|
||||
ESP_LOGW(TAG, "tdm_slot_index=%u >= tdm_node_count=%u, clamping to 0",
|
||||
(unsigned)cfg->tdm_slot_index, (unsigned)cfg->tdm_node_count);
|
||||
cfg->tdm_slot_index = 0;
|
||||
}
|
||||
|
||||
nvs_close(handle);
|
||||
}
|
||||
|
||||
@@ -18,6 +18,9 @@
|
||||
#define NVS_CFG_PASS_MAX 65
|
||||
#define NVS_CFG_IP_MAX 16
|
||||
|
||||
/** Maximum channels in the hop list (must match CSI_HOP_CHANNELS_MAX). */
|
||||
#define NVS_CFG_HOP_MAX 6
|
||||
|
||||
/** Runtime configuration loaded from NVS or Kconfig defaults. */
|
||||
typedef struct {
|
||||
char wifi_ssid[NVS_CFG_SSID_MAX];
|
||||
@@ -25,6 +28,13 @@ typedef struct {
|
||||
char target_ip[NVS_CFG_IP_MAX];
|
||||
uint16_t target_port;
|
||||
uint8_t node_id;
|
||||
|
||||
/* ADR-029: Channel hopping and TDM configuration */
|
||||
uint8_t channel_hop_count; /**< Number of channels to hop (1 = no hop). */
|
||||
uint8_t channel_list[NVS_CFG_HOP_MAX]; /**< Channel numbers for hopping. */
|
||||
uint32_t dwell_ms; /**< Dwell time per channel in ms. */
|
||||
uint8_t tdm_slot_index; /**< This node's TDM slot index (0-based). */
|
||||
uint8_t tdm_node_count; /**< Total nodes in the TDM schedule. */
|
||||
} nvs_config_t;
|
||||
|
||||
/**
|
||||
|
||||
579
rust-port/wifi-densepose-rs/Cargo.lock
generated
579
rust-port/wifi-densepose-rs/Cargo.lock
generated
@@ -488,12 +488,24 @@ dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cesu8"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
||||
|
||||
[[package]]
|
||||
name = "cfg_aliases"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.44"
|
||||
@@ -601,6 +613,16 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "combine"
|
||||
version = "4.6.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.15.11"
|
||||
@@ -915,6 +937,18 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastbloom"
|
||||
version = "0.14.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4"
|
||||
dependencies = [
|
||||
"getrandom 0.3.4",
|
||||
"libm",
|
||||
"rand 0.9.2",
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
@@ -1291,9 +1325,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasip2",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1634,6 +1670,28 @@ version = "1.0.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.21.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
|
||||
dependencies = [
|
||||
"cesu8",
|
||||
"cfg-if",
|
||||
"combine",
|
||||
"jni-sys",
|
||||
"log",
|
||||
"thiserror 1.0.69",
|
||||
"walkdir",
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.34"
|
||||
@@ -1699,6 +1757,21 @@ version = "0.4.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
|
||||
|
||||
[[package]]
|
||||
name = "lru"
|
||||
version = "0.12.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
|
||||
dependencies = [
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-slab"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||
|
||||
[[package]]
|
||||
name = "lzma-rust2"
|
||||
version = "0.15.7"
|
||||
@@ -1746,6 +1819,63 @@ dependencies = [
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "midstreamer-attractor"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab86df06cf1705ca37692b4fc0027868f92e5170a7ebb1d706302f04b6044f70"
|
||||
dependencies = [
|
||||
"midstreamer-temporal-compare",
|
||||
"nalgebra",
|
||||
"ndarray 0.16.1",
|
||||
"serde",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "midstreamer-quic"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35ad2099588e987cdbedb039fdf8a56163a2f3dc1ff6bf5a39c63b9ce4e2248c"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"js-sys",
|
||||
"quinn",
|
||||
"rcgen",
|
||||
"rustls 0.22.4",
|
||||
"serde",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "midstreamer-scheduler"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9296b3f0a2b04e5c1a378ee7926e9f892895bface2ccebcfa407450c3aca269"
|
||||
dependencies = [
|
||||
"crossbeam",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "midstreamer-temporal-compare"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1f935ba86c1632a3b5bc5e1cb56a308d4c5d2ec87c84db551c65f3e1001a642"
|
||||
dependencies = [
|
||||
"dashmap",
|
||||
"lru",
|
||||
"serde",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
@@ -1819,6 +1949,33 @@ dependencies = [
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.33.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"matrixmultiply",
|
||||
"nalgebra-macros",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra-macros"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.18"
|
||||
@@ -1955,6 +2112,17 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
@@ -2147,6 +2315,16 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec91767ecc0a0bbe558ce8c9da33c068066c57ecc8bb8477ef8c1ad3ef77c27"
|
||||
|
||||
[[package]]
|
||||
name = "pem"
|
||||
version = "3.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pem-rfc7468"
|
||||
version = "0.7.0"
|
||||
@@ -2443,6 +2621,63 @@ version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"cfg_aliases",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustls 0.23.37",
|
||||
"socket2",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fastbloom",
|
||||
"getrandom 0.3.4",
|
||||
"lru-slab",
|
||||
"rand 0.9.2",
|
||||
"ring",
|
||||
"rustc-hash",
|
||||
"rustls 0.23.37",
|
||||
"rustls-pki-types",
|
||||
"rustls-platform-verifier",
|
||||
"slab",
|
||||
"thiserror 2.0.18",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
|
||||
dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.44"
|
||||
@@ -2590,6 +2825,18 @@ dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rcgen"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48406db8ac1f3cbc7dcdb56ec355343817958a356ff430259bb07baf7607e1e1"
|
||||
dependencies = [
|
||||
"pem",
|
||||
"ring",
|
||||
"time",
|
||||
"yasna",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reborrow"
|
||||
version = "0.5.5"
|
||||
@@ -2643,6 +2890,20 @@ dependencies = [
|
||||
"bytecheck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"getrandom 0.2.17",
|
||||
"libc",
|
||||
"untrusted",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rkyv"
|
||||
version = "0.8.15"
|
||||
@@ -2750,6 +3011,12 @@ dependencies = [
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.4.1"
|
||||
@@ -2786,15 +3053,105 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.22.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.8",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.103.9",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
|
||||
dependencies = [
|
||||
"web-time",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784"
|
||||
dependencies = [
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"jni",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls 0.23.37",
|
||||
"rustls-native-certs",
|
||||
"rustls-platform-verifier-android",
|
||||
"rustls-webpki 0.103.9",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-root-certs",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier-android"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.103.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
@@ -2813,6 +3170,18 @@ dependencies = [
|
||||
"wait-timeout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-attention"
|
||||
version = "0.1.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef4c2b4ef9db0d5a038c5cb8e9e91ffc11c789db660132d50165d2ba6a71d23f"
|
||||
dependencies = [
|
||||
"rand 0.8.5",
|
||||
"rayon",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-attention"
|
||||
version = "2.0.4"
|
||||
@@ -2859,6 +3228,40 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-crv"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eda8d6533ed1337e75f0bcc9e6e31cff44cc32aa24f9673492b2fad3af09a120"
|
||||
dependencies = [
|
||||
"ruvector-attention 0.1.32",
|
||||
"ruvector-gnn",
|
||||
"ruvector-mincut",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-gnn"
|
||||
version = "2.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e17c1cf1ff3380026b299ff3c1ba3a5685c3d8d54700e6ab0b585b6cec21d7b"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"dashmap",
|
||||
"libc",
|
||||
"ndarray 0.16.1",
|
||||
"parking_lot",
|
||||
"rand 0.8.5",
|
||||
"rand_distr 0.4.3",
|
||||
"rayon",
|
||||
"ruvector-core",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruvector-mincut"
|
||||
version = "2.0.4"
|
||||
@@ -2908,6 +3311,15 @@ version = "1.0.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
|
||||
|
||||
[[package]]
|
||||
name = "safe_arch"
|
||||
version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96b02de82ddbe1b636e6170c21be622223aea188ef2e139be0a5b219ec215323"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safetensors"
|
||||
version = "0.3.3"
|
||||
@@ -3120,6 +3532,19 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c99284beb21666094ba2b75bbceda012e610f5479dfcc2d6e2426f53197ffd95"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"wide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simd-adler32"
|
||||
version = "0.3.8"
|
||||
@@ -3132,6 +3557,12 @@ version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
|
||||
|
||||
[[package]]
|
||||
name = "siphasher"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.12"
|
||||
@@ -3750,6 +4181,12 @@ version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "unty"
|
||||
version = "0.0.4"
|
||||
@@ -4070,13 +4507,23 @@ dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wide"
|
||||
version = "0.7.33"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ce5da8ecb62bcd8ec8b7ea19f69a51275e91299be594ea5cc6ef7819e16cd03"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"safe_arch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-api"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-cli"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
@@ -4101,11 +4548,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-config"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-core"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
@@ -4121,25 +4568,29 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-db"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-hardware"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
"clap",
|
||||
"criterion",
|
||||
"midstreamer-quic",
|
||||
"midstreamer-scheduler",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-mat"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"approx",
|
||||
@@ -4170,7 +4621,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-nn"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"candle-core",
|
||||
@@ -4193,19 +4644,25 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-ruvector"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"ruvector-attention",
|
||||
"approx",
|
||||
"criterion",
|
||||
"ruvector-attention 2.0.4",
|
||||
"ruvector-attn-mincut",
|
||||
"ruvector-crv",
|
||||
"ruvector-gnn",
|
||||
"ruvector-mincut",
|
||||
"ruvector-solver",
|
||||
"ruvector-temporal-tensor",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-sensing-server"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"chrono",
|
||||
@@ -4223,16 +4680,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-signal"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"criterion",
|
||||
"midstreamer-attractor",
|
||||
"midstreamer-temporal-compare",
|
||||
"ndarray 0.15.6",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"proptest",
|
||||
"rustfft",
|
||||
"ruvector-attention",
|
||||
"ruvector-attention 2.0.4",
|
||||
"ruvector-attn-mincut",
|
||||
"ruvector-mincut",
|
||||
"ruvector-solver",
|
||||
@@ -4244,7 +4703,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-train"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"approx",
|
||||
@@ -4260,7 +4719,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"petgraph",
|
||||
"proptest",
|
||||
"ruvector-attention",
|
||||
"ruvector-attention 2.0.4",
|
||||
"ruvector-attn-mincut",
|
||||
"ruvector-mincut",
|
||||
"ruvector-solver",
|
||||
@@ -4282,7 +4741,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-vitals"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -4291,7 +4750,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-wasm"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"console_error_panic_hook",
|
||||
@@ -4313,7 +4772,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-wifiscan"
|
||||
version = "0.1.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"tokio",
|
||||
@@ -4410,6 +4869,24 @@ dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.45.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
|
||||
dependencies = [
|
||||
"windows-targets 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||
dependencies = [
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.59.0"
|
||||
@@ -4437,6 +4914,21 @@ dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.42.2",
|
||||
"windows_aarch64_msvc 0.42.2",
|
||||
"windows_i686_gnu 0.42.2",
|
||||
"windows_i686_msvc 0.42.2",
|
||||
"windows_x86_64_gnu 0.42.2",
|
||||
"windows_x86_64_gnullvm 0.42.2",
|
||||
"windows_x86_64_msvc 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.6"
|
||||
@@ -4470,6 +4962,12 @@ dependencies = [
|
||||
"windows_x86_64_msvc 0.53.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
@@ -4482,6 +4980,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.6"
|
||||
@@ -4494,6 +4998,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.6"
|
||||
@@ -4518,6 +5028,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.6"
|
||||
@@ -4530,6 +5046,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.6"
|
||||
@@ -4542,6 +5064,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.6"
|
||||
@@ -4554,6 +5082,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
@@ -4663,6 +5197,15 @@ dependencies = [
|
||||
"wasmparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yasna"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
|
||||
dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.7.5"
|
||||
|
||||
@@ -19,7 +19,7 @@ members = [
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
authors = ["rUv <ruv@ruv.net>", "WiFi-DensePose Contributors"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
@@ -103,25 +103,33 @@ proptest = "1.4"
|
||||
mockall = "0.12"
|
||||
wiremock = "0.5"
|
||||
|
||||
# ruvector integration (all at v2.0.4 — published on crates.io)
|
||||
# midstreamer integration (published on crates.io)
|
||||
midstreamer-quic = "0.1.0"
|
||||
midstreamer-scheduler = "0.1.0"
|
||||
midstreamer-temporal-compare = "0.1.0"
|
||||
midstreamer-attractor = "0.1.0"
|
||||
|
||||
# ruvector integration (published on crates.io)
|
||||
ruvector-mincut = "2.0.4"
|
||||
ruvector-attn-mincut = "2.0.4"
|
||||
ruvector-temporal-tensor = "2.0.4"
|
||||
ruvector-solver = "2.0.4"
|
||||
ruvector-attention = "2.0.4"
|
||||
ruvector-crv = "0.1.1"
|
||||
ruvector-gnn = { version = "2.0.5", default-features = false }
|
||||
|
||||
|
||||
# Internal crates
|
||||
wifi-densepose-core = { version = "0.2.0", path = "crates/wifi-densepose-core" }
|
||||
wifi-densepose-signal = { version = "0.2.0", path = "crates/wifi-densepose-signal" }
|
||||
wifi-densepose-nn = { version = "0.2.0", path = "crates/wifi-densepose-nn" }
|
||||
wifi-densepose-api = { version = "0.2.0", path = "crates/wifi-densepose-api" }
|
||||
wifi-densepose-db = { version = "0.2.0", path = "crates/wifi-densepose-db" }
|
||||
wifi-densepose-config = { version = "0.2.0", path = "crates/wifi-densepose-config" }
|
||||
wifi-densepose-hardware = { version = "0.2.0", path = "crates/wifi-densepose-hardware" }
|
||||
wifi-densepose-wasm = { version = "0.2.0", path = "crates/wifi-densepose-wasm" }
|
||||
wifi-densepose-mat = { version = "0.2.0", path = "crates/wifi-densepose-mat" }
|
||||
wifi-densepose-ruvector = { version = "0.2.0", path = "crates/wifi-densepose-ruvector" }
|
||||
wifi-densepose-core = { version = "0.3.0", path = "crates/wifi-densepose-core" }
|
||||
wifi-densepose-signal = { version = "0.3.0", path = "crates/wifi-densepose-signal" }
|
||||
wifi-densepose-nn = { version = "0.3.0", path = "crates/wifi-densepose-nn" }
|
||||
wifi-densepose-api = { version = "0.3.0", path = "crates/wifi-densepose-api" }
|
||||
wifi-densepose-db = { version = "0.3.0", path = "crates/wifi-densepose-db" }
|
||||
wifi-densepose-config = { version = "0.3.0", path = "crates/wifi-densepose-config" }
|
||||
wifi-densepose-hardware = { version = "0.3.0", path = "crates/wifi-densepose-hardware" }
|
||||
wifi-densepose-wasm = { version = "0.3.0", path = "crates/wifi-densepose-wasm" }
|
||||
wifi-densepose-mat = { version = "0.3.0", path = "crates/wifi-densepose-mat" }
|
||||
wifi-densepose-ruvector = { version = "0.3.0", path = "crates/wifi-densepose-ruvector" }
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
|
||||
@@ -21,7 +21,7 @@ mat = []
|
||||
|
||||
[dependencies]
|
||||
# Internal crates
|
||||
wifi-densepose-mat = { version = "0.2.0", path = "../wifi-densepose-mat" }
|
||||
wifi-densepose-mat = { version = "0.3.0", path = "../wifi-densepose-mat" }
|
||||
|
||||
# CLI framework
|
||||
clap = { version = "4.4", features = ["derive", "env", "cargo"] }
|
||||
|
||||
@@ -36,5 +36,18 @@ tracing = "0.1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
# QUIC transport (ADR-032a)
|
||||
midstreamer-quic = { workspace = true }
|
||||
# Real-time TDM scheduling (ADR-032a)
|
||||
midstreamer-scheduler = { workspace = true }
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
tokio = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
name = "transport_bench"
|
||||
harness = false
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
//! Benchmarks comparing manual crypto vs QUIC transport for TDM beacons.
|
||||
//!
|
||||
//! Measures:
|
||||
//! - Beacon serialization (16-byte vs 28-byte vs QUIC-framed)
|
||||
//! - Beacon verification throughput
|
||||
//! - Replay window check performance
|
||||
//! - FramedMessage encode/decode throughput
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId};
|
||||
use std::time::Duration;
|
||||
use wifi_densepose_hardware::esp32::{
|
||||
TdmSchedule, SyncBeacon, SecurityMode, QuicTransportConfig,
|
||||
SecureTdmCoordinator, SecureTdmConfig, SecLevel,
|
||||
AuthenticatedBeacon, ReplayWindow, FramedMessage, MessageType,
|
||||
};
|
||||
|
||||
fn make_beacon() -> SyncBeacon {
|
||||
SyncBeacon {
|
||||
cycle_id: 42,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: -3,
|
||||
generated_at: std::time::Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_beacon_serialize_plain(c: &mut Criterion) {
|
||||
let beacon = make_beacon();
|
||||
c.bench_function("beacon_serialize_16byte", |b| {
|
||||
b.iter(|| {
|
||||
black_box(beacon.to_bytes());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_beacon_serialize_authenticated(c: &mut Criterion) {
|
||||
let beacon = make_beacon();
|
||||
let key = [0x01u8; 16];
|
||||
let nonce = 1u32;
|
||||
let mut msg = [0u8; 20];
|
||||
msg[..16].copy_from_slice(&beacon.to_bytes());
|
||||
msg[16..20].copy_from_slice(&nonce.to_le_bytes());
|
||||
|
||||
c.bench_function("beacon_serialize_28byte_auth", |b| {
|
||||
b.iter(|| {
|
||||
let tag = AuthenticatedBeacon::compute_tag(black_box(&msg), &key);
|
||||
black_box(AuthenticatedBeacon {
|
||||
beacon: beacon.clone(),
|
||||
nonce,
|
||||
hmac_tag: tag,
|
||||
}
|
||||
.to_bytes());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_beacon_serialize_quic_framed(c: &mut Criterion) {
|
||||
let beacon = make_beacon();
|
||||
|
||||
c.bench_function("beacon_serialize_quic_framed", |b| {
|
||||
b.iter(|| {
|
||||
let bytes = beacon.to_bytes();
|
||||
let framed = FramedMessage::new(MessageType::Beacon, bytes.to_vec());
|
||||
black_box(framed.to_bytes());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_auth_beacon_verify(c: &mut Criterion) {
|
||||
let beacon = make_beacon();
|
||||
let key = [0x01u8; 16];
|
||||
let nonce = 1u32;
|
||||
let mut msg = [0u8; 20];
|
||||
msg[..16].copy_from_slice(&beacon.to_bytes());
|
||||
msg[16..20].copy_from_slice(&nonce.to_le_bytes());
|
||||
let tag = AuthenticatedBeacon::compute_tag(&msg, &key);
|
||||
let auth = AuthenticatedBeacon {
|
||||
beacon,
|
||||
nonce,
|
||||
hmac_tag: tag,
|
||||
};
|
||||
|
||||
c.bench_function("auth_beacon_verify", |b| {
|
||||
b.iter(|| {
|
||||
black_box(auth.verify(&key)).unwrap();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_replay_window(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("replay_window");
|
||||
|
||||
for window_size in [4u32, 16, 64, 256] {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("check_accept", window_size),
|
||||
&window_size,
|
||||
|b, &ws| {
|
||||
b.iter(|| {
|
||||
let mut rw = ReplayWindow::new(ws);
|
||||
for i in 0..1000u32 {
|
||||
black_box(rw.accept(i));
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_framed_message_roundtrip(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("framed_message");
|
||||
|
||||
for payload_size in [16usize, 128, 512, 2048] {
|
||||
let payload = vec![0xABu8; payload_size];
|
||||
let msg = FramedMessage::new(MessageType::CsiFrame, payload);
|
||||
let bytes = msg.to_bytes();
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("encode", payload_size),
|
||||
&msg,
|
||||
|b, msg| {
|
||||
b.iter(|| {
|
||||
black_box(msg.to_bytes());
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("decode", payload_size),
|
||||
&bytes,
|
||||
|b, bytes| {
|
||||
b.iter(|| {
|
||||
black_box(FramedMessage::from_bytes(bytes));
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_secure_coordinator_cycle(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("secure_tdm_cycle");
|
||||
|
||||
// Manual crypto mode
|
||||
group.bench_function("manual_crypto", |b| {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let config = SecureTdmConfig {
|
||||
security_mode: SecurityMode::ManualCrypto,
|
||||
mesh_key: Some([0x01u8; 16]),
|
||||
quic_config: QuicTransportConfig::default(),
|
||||
sec_level: SecLevel::Transitional,
|
||||
};
|
||||
let mut coord = SecureTdmCoordinator::new(schedule, config).unwrap();
|
||||
|
||||
b.iter(|| {
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
black_box(&output.authenticated_bytes);
|
||||
for i in 0..4 {
|
||||
coord.complete_slot(i, 0.95);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// QUIC mode
|
||||
group.bench_function("quic_transport", |b| {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let config = SecureTdmConfig {
|
||||
security_mode: SecurityMode::QuicTransport,
|
||||
mesh_key: Some([0x01u8; 16]),
|
||||
quic_config: QuicTransportConfig::default(),
|
||||
sec_level: SecLevel::Transitional,
|
||||
};
|
||||
let mut coord = SecureTdmCoordinator::new(schedule, config).unwrap();
|
||||
|
||||
b.iter(|| {
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
black_box(&output.authenticated_bytes);
|
||||
for i in 0..4 {
|
||||
coord.complete_slot(i, 0.95);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_beacon_serialize_plain,
|
||||
bench_beacon_serialize_authenticated,
|
||||
bench_beacon_serialize_quic_framed,
|
||||
bench_auth_beacon_verify,
|
||||
bench_replay_window,
|
||||
bench_framed_message_roundtrip,
|
||||
bench_secure_coordinator_cycle,
|
||||
);
|
||||
criterion_main!(benches);
|
||||
@@ -0,0 +1,31 @@
|
||||
//! ESP32 hardware protocol modules.
|
||||
//!
|
||||
//! Implements sensing-first RF protocols for ESP32-S3 mesh nodes,
|
||||
//! including TDM (Time-Division Multiplexed) sensing schedules
|
||||
//! per ADR-029 (RuvSense) and ADR-031 (RuView).
|
||||
//!
|
||||
//! ## Security (ADR-032 / ADR-032a)
|
||||
//!
|
||||
//! - `quic_transport` -- QUIC-based authenticated transport for aggregator nodes
|
||||
//! - `secure_tdm` -- Secured TDM protocol with dual-mode (QUIC / manual crypto)
|
||||
|
||||
pub mod tdm;
|
||||
pub mod quic_transport;
|
||||
pub mod secure_tdm;
|
||||
|
||||
pub use tdm::{
|
||||
TdmSchedule, TdmCoordinator, TdmSlot, TdmSlotCompleted,
|
||||
SyncBeacon, TdmError,
|
||||
};
|
||||
|
||||
pub use quic_transport::{
|
||||
SecurityMode, QuicTransportConfig, QuicTransportHandle, QuicTransportError,
|
||||
TransportStats, ConnectionState, MessageType, FramedMessage,
|
||||
STREAM_BEACON, STREAM_CSI, STREAM_CONTROL,
|
||||
};
|
||||
|
||||
pub use secure_tdm::{
|
||||
SecureTdmCoordinator, SecureTdmConfig, SecureTdmError,
|
||||
SecLevel, AuthenticatedBeacon, SecureCycleOutput,
|
||||
ReplayWindow, AUTHENTICATED_BEACON_SIZE,
|
||||
};
|
||||
@@ -0,0 +1,856 @@
|
||||
//! QUIC transport layer for multistatic mesh communication (ADR-032a).
|
||||
//!
|
||||
//! Wraps `midstreamer-quic` to provide authenticated, encrypted, and
|
||||
//! congestion-controlled transport for TDM beacons, CSI frames, and
|
||||
//! control plane messages between aggregator-class nodes.
|
||||
//!
|
||||
//! # Stream Mapping
|
||||
//!
|
||||
//! | Stream ID | Purpose | Direction | Priority |
|
||||
//! |---|---|---|---|
|
||||
//! | 0 | Sync beacons | Coordinator -> Nodes | Highest |
|
||||
//! | 1 | CSI frames | Nodes -> Aggregator | High |
|
||||
//! | 2 | Control plane | Bidirectional | Normal |
|
||||
//!
|
||||
//! # Fallback
|
||||
//!
|
||||
//! Constrained devices (ESP32-S3) use the manual crypto path from
|
||||
//! ADR-032 sections 2.1-2.2. The `SecurityMode` enum selects transport.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stream identifiers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// QUIC stream ID for sync beacon traffic (highest priority).
|
||||
pub const STREAM_BEACON: u64 = 0;
|
||||
|
||||
/// QUIC stream ID for CSI frame traffic (high priority).
|
||||
pub const STREAM_CSI: u64 = 1;
|
||||
|
||||
/// QUIC stream ID for control plane traffic (normal priority).
|
||||
pub const STREAM_CONTROL: u64 = 2;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Security mode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Transport security mode selection (ADR-032a).
|
||||
///
|
||||
/// Determines whether communication uses manual HMAC/SipHash over
|
||||
/// plain UDP (for constrained ESP32-S3 devices) or QUIC with TLS 1.3
|
||||
/// (for aggregator-class nodes).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SecurityMode {
|
||||
/// Manual HMAC-SHA256 beacon auth + SipHash-2-4 frame integrity
|
||||
/// over plain UDP. Suitable for ESP32-S3 with limited memory.
|
||||
ManualCrypto,
|
||||
/// QUIC transport with TLS 1.3 AEAD encryption, built-in replay
|
||||
/// protection, congestion control, and connection migration.
|
||||
QuicTransport,
|
||||
}
|
||||
|
||||
impl Default for SecurityMode {
|
||||
fn default() -> Self {
|
||||
SecurityMode::QuicTransport
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SecurityMode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SecurityMode::ManualCrypto => write!(f, "ManualCrypto (UDP + HMAC/SipHash)"),
|
||||
SecurityMode::QuicTransport => write!(f, "QuicTransport (QUIC + TLS 1.3)"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from the QUIC transport layer.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum QuicTransportError {
|
||||
/// Connection to the remote endpoint failed.
|
||||
ConnectionFailed { reason: String },
|
||||
/// The QUIC handshake did not complete within the timeout.
|
||||
HandshakeTimeout { timeout_ms: u64 },
|
||||
/// A stream could not be opened (e.g., stream limit reached).
|
||||
StreamOpenFailed { stream_id: u64 },
|
||||
/// Sending data on a stream failed.
|
||||
SendFailed { stream_id: u64, reason: String },
|
||||
/// Receiving data from a stream failed.
|
||||
ReceiveFailed { stream_id: u64, reason: String },
|
||||
/// The connection was closed by the remote peer.
|
||||
ConnectionClosed { error_code: u64 },
|
||||
/// Invalid configuration parameter.
|
||||
InvalidConfig { param: String, reason: String },
|
||||
/// Fallback to manual crypto was triggered.
|
||||
FallbackTriggered { reason: String },
|
||||
}
|
||||
|
||||
impl fmt::Display for QuicTransportError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
QuicTransportError::ConnectionFailed { reason } => {
|
||||
write!(f, "QUIC connection failed: {}", reason)
|
||||
}
|
||||
QuicTransportError::HandshakeTimeout { timeout_ms } => {
|
||||
write!(f, "QUIC handshake timed out after {} ms", timeout_ms)
|
||||
}
|
||||
QuicTransportError::StreamOpenFailed { stream_id } => {
|
||||
write!(f, "Failed to open QUIC stream {}", stream_id)
|
||||
}
|
||||
QuicTransportError::SendFailed { stream_id, reason } => {
|
||||
write!(f, "Send failed on stream {}: {}", stream_id, reason)
|
||||
}
|
||||
QuicTransportError::ReceiveFailed { stream_id, reason } => {
|
||||
write!(f, "Receive failed on stream {}: {}", stream_id, reason)
|
||||
}
|
||||
QuicTransportError::ConnectionClosed { error_code } => {
|
||||
write!(f, "Connection closed with error code {}", error_code)
|
||||
}
|
||||
QuicTransportError::InvalidConfig { param, reason } => {
|
||||
write!(f, "Invalid config '{}': {}", param, reason)
|
||||
}
|
||||
QuicTransportError::FallbackTriggered { reason } => {
|
||||
write!(f, "Fallback to manual crypto: {}", reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for QuicTransportError {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for the QUIC transport layer.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuicTransportConfig {
|
||||
/// Bind address for the QUIC endpoint (e.g., "0.0.0.0:4433").
|
||||
pub bind_addr: String,
|
||||
/// Handshake timeout in milliseconds.
|
||||
pub handshake_timeout_ms: u64,
|
||||
/// Keep-alive interval in milliseconds (0 = disabled).
|
||||
pub keepalive_ms: u64,
|
||||
/// Maximum idle timeout in milliseconds.
|
||||
pub idle_timeout_ms: u64,
|
||||
/// Maximum number of concurrent bidirectional streams.
|
||||
pub max_streams: u64,
|
||||
/// Whether to enable connection migration.
|
||||
pub enable_migration: bool,
|
||||
/// Security mode (QUIC or manual crypto fallback).
|
||||
pub security_mode: SecurityMode,
|
||||
/// Maximum datagram size (QUIC transport parameter).
|
||||
pub max_datagram_size: usize,
|
||||
}
|
||||
|
||||
impl Default for QuicTransportConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bind_addr: "0.0.0.0:4433".to_string(),
|
||||
handshake_timeout_ms: 100,
|
||||
keepalive_ms: 5_000,
|
||||
idle_timeout_ms: 30_000,
|
||||
max_streams: 8,
|
||||
enable_migration: true,
|
||||
security_mode: SecurityMode::QuicTransport,
|
||||
max_datagram_size: 1350,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QuicTransportConfig {
|
||||
/// Validate the configuration, returning an error if invalid.
|
||||
pub fn validate(&self) -> Result<(), QuicTransportError> {
|
||||
if self.bind_addr.is_empty() {
|
||||
return Err(QuicTransportError::InvalidConfig {
|
||||
param: "bind_addr".into(),
|
||||
reason: "must not be empty".into(),
|
||||
});
|
||||
}
|
||||
if self.handshake_timeout_ms == 0 {
|
||||
return Err(QuicTransportError::InvalidConfig {
|
||||
param: "handshake_timeout_ms".into(),
|
||||
reason: "must be > 0".into(),
|
||||
});
|
||||
}
|
||||
if self.max_streams == 0 {
|
||||
return Err(QuicTransportError::InvalidConfig {
|
||||
param: "max_streams".into(),
|
||||
reason: "must be > 0".into(),
|
||||
});
|
||||
}
|
||||
if self.max_datagram_size < 100 {
|
||||
return Err(QuicTransportError::InvalidConfig {
|
||||
param: "max_datagram_size".into(),
|
||||
reason: "must be >= 100 bytes".into(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Transport statistics
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Runtime statistics for the QUIC transport.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TransportStats {
|
||||
/// Total bytes sent across all streams.
|
||||
pub bytes_sent: u64,
|
||||
/// Total bytes received across all streams.
|
||||
pub bytes_received: u64,
|
||||
/// Number of beacons sent on stream 0.
|
||||
pub beacons_sent: u64,
|
||||
/// Number of beacons received on stream 0.
|
||||
pub beacons_received: u64,
|
||||
/// Number of CSI frames sent on stream 1.
|
||||
pub csi_frames_sent: u64,
|
||||
/// Number of CSI frames received on stream 1.
|
||||
pub csi_frames_received: u64,
|
||||
/// Number of control messages exchanged on stream 2.
|
||||
pub control_messages: u64,
|
||||
/// Number of connection migrations completed.
|
||||
pub migrations_completed: u64,
|
||||
/// Number of times fallback to manual crypto was used.
|
||||
pub fallback_count: u64,
|
||||
/// Current round-trip time estimate in microseconds.
|
||||
pub rtt_us: u64,
|
||||
}
|
||||
|
||||
impl TransportStats {
|
||||
/// Total packets processed (sent + received across all types).
|
||||
pub fn total_packets(&self) -> u64 {
|
||||
self.beacons_sent
|
||||
+ self.beacons_received
|
||||
+ self.csi_frames_sent
|
||||
+ self.csi_frames_received
|
||||
+ self.control_messages
|
||||
}
|
||||
|
||||
/// Reset all counters to zero.
|
||||
pub fn reset(&mut self) {
|
||||
*self = Self::default();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Message types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Message type tag for QUIC stream multiplexing.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum MessageType {
|
||||
/// Sync beacon (stream 0).
|
||||
Beacon = 0x01,
|
||||
/// CSI frame data (stream 1).
|
||||
CsiFrame = 0x02,
|
||||
/// Control plane command (stream 2).
|
||||
Control = 0x03,
|
||||
/// Heartbeat / keepalive.
|
||||
Heartbeat = 0x04,
|
||||
/// Key rotation notification.
|
||||
KeyRotation = 0x05,
|
||||
}
|
||||
|
||||
impl MessageType {
|
||||
/// Parse a message type from a byte tag.
|
||||
pub fn from_byte(b: u8) -> Option<Self> {
|
||||
match b {
|
||||
0x01 => Some(MessageType::Beacon),
|
||||
0x02 => Some(MessageType::CsiFrame),
|
||||
0x03 => Some(MessageType::Control),
|
||||
0x04 => Some(MessageType::Heartbeat),
|
||||
0x05 => Some(MessageType::KeyRotation),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to the stream ID this message type should use.
|
||||
pub fn stream_id(&self) -> u64 {
|
||||
match self {
|
||||
MessageType::Beacon => STREAM_BEACON,
|
||||
MessageType::CsiFrame => STREAM_CSI,
|
||||
MessageType::Control | MessageType::Heartbeat | MessageType::KeyRotation => {
|
||||
STREAM_CONTROL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Framed message
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A framed message for QUIC stream transport.
|
||||
///
|
||||
/// Wire format:
|
||||
/// ```text
|
||||
/// [0] message_type (u8)
|
||||
/// [1..5] payload_len (LE u32)
|
||||
/// [5..5+N] payload (N bytes)
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FramedMessage {
|
||||
/// Type of this message.
|
||||
pub message_type: MessageType,
|
||||
/// Raw payload bytes.
|
||||
pub payload: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Header size for a framed message (1 byte type + 4 bytes length).
|
||||
pub const FRAMED_HEADER_SIZE: usize = 5;
|
||||
|
||||
impl FramedMessage {
|
||||
/// Create a new framed message.
|
||||
pub fn new(message_type: MessageType, payload: Vec<u8>) -> Self {
|
||||
Self {
|
||||
message_type,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize the message to bytes (header + payload).
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
let len = self.payload.len() as u32;
|
||||
let mut buf = Vec::with_capacity(FRAMED_HEADER_SIZE + self.payload.len());
|
||||
buf.push(self.message_type as u8);
|
||||
buf.extend_from_slice(&len.to_le_bytes());
|
||||
buf.extend_from_slice(&self.payload);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Deserialize a framed message from bytes.
|
||||
///
|
||||
/// Returns the message and the number of bytes consumed, or `None`
|
||||
/// if the buffer is too short or the message type is invalid.
|
||||
pub fn from_bytes(buf: &[u8]) -> Option<(Self, usize)> {
|
||||
if buf.len() < FRAMED_HEADER_SIZE {
|
||||
return None;
|
||||
}
|
||||
let msg_type = MessageType::from_byte(buf[0])?;
|
||||
let payload_len =
|
||||
u32::from_le_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
|
||||
let total = FRAMED_HEADER_SIZE + payload_len;
|
||||
if buf.len() < total {
|
||||
return None;
|
||||
}
|
||||
let payload = buf[FRAMED_HEADER_SIZE..total].to_vec();
|
||||
Some((
|
||||
Self {
|
||||
message_type: msg_type,
|
||||
payload,
|
||||
},
|
||||
total,
|
||||
))
|
||||
}
|
||||
|
||||
/// Total wire size of this message.
|
||||
pub fn wire_size(&self) -> usize {
|
||||
FRAMED_HEADER_SIZE + self.payload.len()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// QUIC transport handle
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Connection state for the QUIC transport.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ConnectionState {
|
||||
/// Not connected.
|
||||
Disconnected,
|
||||
/// TLS handshake in progress.
|
||||
Connecting,
|
||||
/// Connection established, streams available.
|
||||
Connected,
|
||||
/// Connection is draining (graceful close in progress).
|
||||
Draining,
|
||||
/// Connection closed (terminal state).
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnectionState {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ConnectionState::Disconnected => write!(f, "Disconnected"),
|
||||
ConnectionState::Connecting => write!(f, "Connecting"),
|
||||
ConnectionState::Connected => write!(f, "Connected"),
|
||||
ConnectionState::Draining => write!(f, "Draining"),
|
||||
ConnectionState::Closed => write!(f, "Closed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// QUIC transport handle for a single connection.
|
||||
///
|
||||
/// Manages the lifecycle of a QUIC connection, including handshake,
|
||||
/// stream management, and graceful shutdown. In production, this wraps
|
||||
/// the `midstreamer-quic` connection object.
|
||||
#[derive(Debug)]
|
||||
pub struct QuicTransportHandle {
|
||||
/// Configuration used to create this handle.
|
||||
config: QuicTransportConfig,
|
||||
/// Current connection state.
|
||||
state: ConnectionState,
|
||||
/// Transport statistics.
|
||||
stats: TransportStats,
|
||||
/// Remote peer address (populated after connect).
|
||||
remote_addr: Option<String>,
|
||||
/// Active security mode (may differ from config if fallback occurred).
|
||||
active_mode: SecurityMode,
|
||||
}
|
||||
|
||||
impl QuicTransportHandle {
|
||||
/// Create a new transport handle with the given configuration.
|
||||
pub fn new(config: QuicTransportConfig) -> Result<Self, QuicTransportError> {
|
||||
config.validate()?;
|
||||
let mode = config.security_mode;
|
||||
Ok(Self {
|
||||
config,
|
||||
state: ConnectionState::Disconnected,
|
||||
stats: TransportStats::default(),
|
||||
remote_addr: None,
|
||||
active_mode: mode,
|
||||
})
|
||||
}
|
||||
|
||||
/// Current connection state.
|
||||
pub fn state(&self) -> ConnectionState {
|
||||
self.state
|
||||
}
|
||||
|
||||
/// Active security mode.
|
||||
pub fn active_mode(&self) -> SecurityMode {
|
||||
self.active_mode
|
||||
}
|
||||
|
||||
/// Reference to transport statistics.
|
||||
pub fn stats(&self) -> &TransportStats {
|
||||
&self.stats
|
||||
}
|
||||
|
||||
/// Mutable reference to transport statistics.
|
||||
pub fn stats_mut(&mut self) -> &mut TransportStats {
|
||||
&mut self.stats
|
||||
}
|
||||
|
||||
/// Reference to the configuration.
|
||||
pub fn config(&self) -> &QuicTransportConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Remote peer address (if connected).
|
||||
pub fn remote_addr(&self) -> Option<&str> {
|
||||
self.remote_addr.as_deref()
|
||||
}
|
||||
|
||||
/// Simulate initiating a connection to a remote peer.
|
||||
///
|
||||
/// In production, this would perform the QUIC handshake via
|
||||
/// `midstreamer-quic`. Here we model the state transitions.
|
||||
pub fn connect(&mut self, remote_addr: &str) -> Result<(), QuicTransportError> {
|
||||
if remote_addr.is_empty() {
|
||||
return Err(QuicTransportError::ConnectionFailed {
|
||||
reason: "empty remote address".into(),
|
||||
});
|
||||
}
|
||||
self.state = ConnectionState::Connecting;
|
||||
// In production: midstreamer_quic::connect(remote_addr, &self.config)
|
||||
self.remote_addr = Some(remote_addr.to_string());
|
||||
self.state = ConnectionState::Connected;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record a beacon sent on stream 0.
|
||||
pub fn record_beacon_sent(&mut self, size: usize) {
|
||||
self.stats.beacons_sent += 1;
|
||||
self.stats.bytes_sent += size as u64;
|
||||
}
|
||||
|
||||
/// Record a beacon received on stream 0.
|
||||
pub fn record_beacon_received(&mut self, size: usize) {
|
||||
self.stats.beacons_received += 1;
|
||||
self.stats.bytes_received += size as u64;
|
||||
}
|
||||
|
||||
/// Record a CSI frame sent on stream 1.
|
||||
pub fn record_csi_sent(&mut self, size: usize) {
|
||||
self.stats.csi_frames_sent += 1;
|
||||
self.stats.bytes_sent += size as u64;
|
||||
}
|
||||
|
||||
/// Record a CSI frame received on stream 1.
|
||||
pub fn record_csi_received(&mut self, size: usize) {
|
||||
self.stats.csi_frames_received += 1;
|
||||
self.stats.bytes_received += size as u64;
|
||||
}
|
||||
|
||||
/// Record a control message on stream 2.
|
||||
pub fn record_control_message(&mut self, size: usize) {
|
||||
self.stats.control_messages += 1;
|
||||
self.stats.bytes_sent += size as u64;
|
||||
}
|
||||
|
||||
/// Trigger fallback to manual crypto mode.
|
||||
pub fn trigger_fallback(&mut self, reason: &str) -> Result<(), QuicTransportError> {
|
||||
self.active_mode = SecurityMode::ManualCrypto;
|
||||
self.stats.fallback_count += 1;
|
||||
self.state = ConnectionState::Disconnected;
|
||||
Err(QuicTransportError::FallbackTriggered {
|
||||
reason: reason.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Gracefully close the connection.
|
||||
pub fn close(&mut self) {
|
||||
if self.state == ConnectionState::Connected {
|
||||
self.state = ConnectionState::Draining;
|
||||
}
|
||||
self.state = ConnectionState::Closed;
|
||||
}
|
||||
|
||||
/// Whether the connection is in a usable state.
|
||||
pub fn is_connected(&self) -> bool {
|
||||
self.state == ConnectionState::Connected
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ---- SecurityMode tests ----
|
||||
|
||||
#[test]
|
||||
fn test_security_mode_default() {
|
||||
assert_eq!(SecurityMode::default(), SecurityMode::QuicTransport);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_security_mode_display() {
|
||||
let quic = format!("{}", SecurityMode::QuicTransport);
|
||||
assert!(quic.contains("QUIC"));
|
||||
assert!(quic.contains("TLS 1.3"));
|
||||
|
||||
let manual = format!("{}", SecurityMode::ManualCrypto);
|
||||
assert!(manual.contains("ManualCrypto"));
|
||||
assert!(manual.contains("HMAC"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_security_mode_equality() {
|
||||
assert_eq!(SecurityMode::QuicTransport, SecurityMode::QuicTransport);
|
||||
assert_ne!(SecurityMode::QuicTransport, SecurityMode::ManualCrypto);
|
||||
}
|
||||
|
||||
// ---- QuicTransportConfig tests ----
|
||||
|
||||
#[test]
|
||||
fn test_config_default() {
|
||||
let cfg = QuicTransportConfig::default();
|
||||
assert_eq!(cfg.bind_addr, "0.0.0.0:4433");
|
||||
assert_eq!(cfg.handshake_timeout_ms, 100);
|
||||
assert_eq!(cfg.max_streams, 8);
|
||||
assert!(cfg.enable_migration);
|
||||
assert_eq!(cfg.security_mode, SecurityMode::QuicTransport);
|
||||
assert_eq!(cfg.max_datagram_size, 1350);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate_ok() {
|
||||
let cfg = QuicTransportConfig::default();
|
||||
assert!(cfg.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate_empty_bind_addr() {
|
||||
let cfg = QuicTransportConfig {
|
||||
bind_addr: String::new(),
|
||||
..Default::default()
|
||||
};
|
||||
let err = cfg.validate().unwrap_err();
|
||||
assert!(matches!(err, QuicTransportError::InvalidConfig { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate_zero_handshake_timeout() {
|
||||
let cfg = QuicTransportConfig {
|
||||
handshake_timeout_ms: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let err = cfg.validate().unwrap_err();
|
||||
assert!(matches!(err, QuicTransportError::InvalidConfig { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate_zero_max_streams() {
|
||||
let cfg = QuicTransportConfig {
|
||||
max_streams: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let err = cfg.validate().unwrap_err();
|
||||
assert!(matches!(err, QuicTransportError::InvalidConfig { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate_small_datagram() {
|
||||
let cfg = QuicTransportConfig {
|
||||
max_datagram_size: 50,
|
||||
..Default::default()
|
||||
};
|
||||
let err = cfg.validate().unwrap_err();
|
||||
assert!(matches!(err, QuicTransportError::InvalidConfig { .. }));
|
||||
}
|
||||
|
||||
// ---- MessageType tests ----
|
||||
|
||||
#[test]
|
||||
fn test_message_type_from_byte() {
|
||||
assert_eq!(MessageType::from_byte(0x01), Some(MessageType::Beacon));
|
||||
assert_eq!(MessageType::from_byte(0x02), Some(MessageType::CsiFrame));
|
||||
assert_eq!(MessageType::from_byte(0x03), Some(MessageType::Control));
|
||||
assert_eq!(MessageType::from_byte(0x04), Some(MessageType::Heartbeat));
|
||||
assert_eq!(MessageType::from_byte(0x05), Some(MessageType::KeyRotation));
|
||||
assert_eq!(MessageType::from_byte(0x00), None);
|
||||
assert_eq!(MessageType::from_byte(0xFF), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_type_stream_id() {
|
||||
assert_eq!(MessageType::Beacon.stream_id(), STREAM_BEACON);
|
||||
assert_eq!(MessageType::CsiFrame.stream_id(), STREAM_CSI);
|
||||
assert_eq!(MessageType::Control.stream_id(), STREAM_CONTROL);
|
||||
assert_eq!(MessageType::Heartbeat.stream_id(), STREAM_CONTROL);
|
||||
assert_eq!(MessageType::KeyRotation.stream_id(), STREAM_CONTROL);
|
||||
}
|
||||
|
||||
// ---- FramedMessage tests ----
|
||||
|
||||
#[test]
|
||||
fn test_framed_message_roundtrip() {
|
||||
let payload = vec![0xDE, 0xAD, 0xBE, 0xEF];
|
||||
let msg = FramedMessage::new(MessageType::Beacon, payload.clone());
|
||||
|
||||
let bytes = msg.to_bytes();
|
||||
assert_eq!(bytes.len(), FRAMED_HEADER_SIZE + 4);
|
||||
|
||||
let (decoded, consumed) = FramedMessage::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(consumed, bytes.len());
|
||||
assert_eq!(decoded.message_type, MessageType::Beacon);
|
||||
assert_eq!(decoded.payload, payload);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_framed_message_empty_payload() {
|
||||
let msg = FramedMessage::new(MessageType::Heartbeat, vec![]);
|
||||
let bytes = msg.to_bytes();
|
||||
assert_eq!(bytes.len(), FRAMED_HEADER_SIZE);
|
||||
|
||||
let (decoded, consumed) = FramedMessage::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(consumed, FRAMED_HEADER_SIZE);
|
||||
assert!(decoded.payload.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_framed_message_too_short() {
|
||||
assert!(FramedMessage::from_bytes(&[0x01, 0x00]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_framed_message_invalid_type() {
|
||||
let bytes = [0xFF, 0x00, 0x00, 0x00, 0x00];
|
||||
assert!(FramedMessage::from_bytes(&bytes).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_framed_message_truncated_payload() {
|
||||
// Header says 10 bytes payload but only 5 available
|
||||
let mut bytes = vec![0x01];
|
||||
bytes.extend_from_slice(&10u32.to_le_bytes());
|
||||
bytes.extend_from_slice(&[0u8; 5]);
|
||||
assert!(FramedMessage::from_bytes(&bytes).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_framed_message_wire_size() {
|
||||
let msg = FramedMessage::new(MessageType::CsiFrame, vec![0; 100]);
|
||||
assert_eq!(msg.wire_size(), FRAMED_HEADER_SIZE + 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_framed_message_large_payload() {
|
||||
let payload = vec![0xAB; 4096];
|
||||
let msg = FramedMessage::new(MessageType::CsiFrame, payload.clone());
|
||||
let bytes = msg.to_bytes();
|
||||
let (decoded, _) = FramedMessage::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(decoded.payload.len(), 4096);
|
||||
assert_eq!(decoded.payload, payload);
|
||||
}
|
||||
|
||||
// ---- ConnectionState tests ----
|
||||
|
||||
#[test]
|
||||
fn test_connection_state_display() {
|
||||
assert_eq!(format!("{}", ConnectionState::Disconnected), "Disconnected");
|
||||
assert_eq!(format!("{}", ConnectionState::Connected), "Connected");
|
||||
assert_eq!(format!("{}", ConnectionState::Draining), "Draining");
|
||||
}
|
||||
|
||||
// ---- TransportStats tests ----
|
||||
|
||||
#[test]
|
||||
fn test_transport_stats_default() {
|
||||
let stats = TransportStats::default();
|
||||
assert_eq!(stats.total_packets(), 0);
|
||||
assert_eq!(stats.bytes_sent, 0);
|
||||
assert_eq!(stats.bytes_received, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transport_stats_total_packets() {
|
||||
let stats = TransportStats {
|
||||
beacons_sent: 10,
|
||||
beacons_received: 8,
|
||||
csi_frames_sent: 100,
|
||||
csi_frames_received: 95,
|
||||
control_messages: 5,
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(stats.total_packets(), 218);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transport_stats_reset() {
|
||||
let mut stats = TransportStats {
|
||||
beacons_sent: 10,
|
||||
bytes_sent: 1000,
|
||||
..Default::default()
|
||||
};
|
||||
stats.reset();
|
||||
assert_eq!(stats.beacons_sent, 0);
|
||||
assert_eq!(stats.bytes_sent, 0);
|
||||
}
|
||||
|
||||
// ---- QuicTransportHandle tests ----
|
||||
|
||||
#[test]
|
||||
fn test_handle_creation() {
|
||||
let handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
assert_eq!(handle.state(), ConnectionState::Disconnected);
|
||||
assert_eq!(handle.active_mode(), SecurityMode::QuicTransport);
|
||||
assert!(!handle.is_connected());
|
||||
assert!(handle.remote_addr().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_creation_invalid_config() {
|
||||
let cfg = QuicTransportConfig {
|
||||
bind_addr: String::new(),
|
||||
..Default::default()
|
||||
};
|
||||
assert!(QuicTransportHandle::new(cfg).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_connect() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
handle.connect("192.168.1.100:4433").unwrap();
|
||||
assert!(handle.is_connected());
|
||||
assert_eq!(handle.remote_addr(), Some("192.168.1.100:4433"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_connect_empty_addr() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
let err = handle.connect("").unwrap_err();
|
||||
assert!(matches!(err, QuicTransportError::ConnectionFailed { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_record_beacon() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
handle.record_beacon_sent(28);
|
||||
handle.record_beacon_sent(28);
|
||||
handle.record_beacon_received(28);
|
||||
assert_eq!(handle.stats().beacons_sent, 2);
|
||||
assert_eq!(handle.stats().beacons_received, 1);
|
||||
assert_eq!(handle.stats().bytes_sent, 56);
|
||||
assert_eq!(handle.stats().bytes_received, 28);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_record_csi() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
handle.record_csi_sent(512);
|
||||
handle.record_csi_received(512);
|
||||
assert_eq!(handle.stats().csi_frames_sent, 1);
|
||||
assert_eq!(handle.stats().csi_frames_received, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_record_control() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
handle.record_control_message(64);
|
||||
assert_eq!(handle.stats().control_messages, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_fallback() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
handle.connect("192.168.1.1:4433").unwrap();
|
||||
let err = handle.trigger_fallback("handshake timeout").unwrap_err();
|
||||
assert!(matches!(err, QuicTransportError::FallbackTriggered { .. }));
|
||||
assert_eq!(handle.active_mode(), SecurityMode::ManualCrypto);
|
||||
assert_eq!(handle.state(), ConnectionState::Disconnected);
|
||||
assert_eq!(handle.stats().fallback_count, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_close() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
handle.connect("192.168.1.1:4433").unwrap();
|
||||
assert!(handle.is_connected());
|
||||
handle.close();
|
||||
assert_eq!(handle.state(), ConnectionState::Closed);
|
||||
assert!(!handle.is_connected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_close_when_disconnected() {
|
||||
let mut handle = QuicTransportHandle::new(QuicTransportConfig::default()).unwrap();
|
||||
handle.close();
|
||||
assert_eq!(handle.state(), ConnectionState::Closed);
|
||||
}
|
||||
|
||||
// ---- Error display tests ----
|
||||
|
||||
#[test]
|
||||
fn test_error_display() {
|
||||
let err = QuicTransportError::HandshakeTimeout { timeout_ms: 100 };
|
||||
assert!(format!("{}", err).contains("100 ms"));
|
||||
|
||||
let err = QuicTransportError::StreamOpenFailed { stream_id: 1 };
|
||||
assert!(format!("{}", err).contains("stream 1"));
|
||||
}
|
||||
|
||||
// ---- Stream constants ----
|
||||
|
||||
#[test]
|
||||
fn test_stream_constants() {
|
||||
assert_eq!(STREAM_BEACON, 0);
|
||||
assert_eq!(STREAM_CSI, 1);
|
||||
assert_eq!(STREAM_CONTROL, 2);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,994 @@
|
||||
//! Secured TDM protocol over QUIC transport (ADR-032a).
|
||||
//!
|
||||
//! Wraps the existing `TdmCoordinator` and `SyncBeacon` types with
|
||||
//! QUIC-based authenticated transport. Supports dual-mode operation:
|
||||
//! QUIC for aggregator-class nodes and manual crypto for ESP32-S3.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! SecureTdmCoordinator
|
||||
//! |-- TdmCoordinator (schedule, cycle state)
|
||||
//! |-- QuicTransportHandle (optional, for QUIC mode)
|
||||
//! |-- SecurityMode (selects QUIC vs manual)
|
||||
//! |-- ReplayWindow (nonce-based replay protection for manual mode)
|
||||
//! ```
|
||||
//!
|
||||
//! # Beacon Authentication Flow
|
||||
//!
|
||||
//! ## QUIC mode
|
||||
//! 1. Coordinator calls `begin_secure_cycle()`
|
||||
//! 2. Beacon serialized to 16-byte wire format (original)
|
||||
//! 3. Wrapped in `FramedMessage` with type `Beacon`
|
||||
//! 4. Sent over QUIC stream 0 (encrypted + authenticated by TLS 1.3)
|
||||
//!
|
||||
//! ## Manual crypto mode
|
||||
//! 1. Coordinator calls `begin_secure_cycle()`
|
||||
//! 2. Beacon serialized to 28-byte authenticated format (ADR-032 Section 2.1)
|
||||
//! 3. HMAC-SHA256 tag computed over payload + nonce
|
||||
//! 4. Sent over plain UDP
|
||||
|
||||
use super::quic_transport::{
|
||||
FramedMessage, MessageType, QuicTransportConfig,
|
||||
QuicTransportHandle, QuicTransportError, SecurityMode,
|
||||
};
|
||||
use super::tdm::{SyncBeacon, TdmCoordinator, TdmSchedule, TdmSlotCompleted};
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Size of the HMAC-SHA256 truncated tag (manual crypto mode).
|
||||
const HMAC_TAG_SIZE: usize = 8;
|
||||
|
||||
/// Size of the nonce field (manual crypto mode).
|
||||
const NONCE_SIZE: usize = 4;
|
||||
|
||||
/// Replay window size (number of past nonces to track).
|
||||
const REPLAY_WINDOW: u32 = 16;
|
||||
|
||||
/// Size of the authenticated beacon (manual crypto mode): 16 + 4 + 8 = 28.
|
||||
pub const AUTHENTICATED_BEACON_SIZE: usize = 16 + NONCE_SIZE + HMAC_TAG_SIZE;
|
||||
|
||||
/// Default pre-shared key for testing (16 bytes). In production, this
|
||||
/// would be loaded from NVS or a secure key store.
|
||||
const DEFAULT_TEST_KEY: [u8; 16] = [
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
|
||||
];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from the secure TDM layer.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SecureTdmError {
|
||||
/// The beacon HMAC tag verification failed.
|
||||
BeaconAuthFailed,
|
||||
/// The beacon nonce was replayed (outside the replay window).
|
||||
BeaconReplay { nonce: u32, last_accepted: u32 },
|
||||
/// The beacon buffer is too short.
|
||||
BeaconTooShort { expected: usize, got: usize },
|
||||
/// QUIC transport error.
|
||||
Transport(QuicTransportError),
|
||||
/// The security mode does not match the incoming packet format.
|
||||
ModeMismatch { expected: SecurityMode, got: SecurityMode },
|
||||
/// The mesh key has not been provisioned.
|
||||
NoMeshKey,
|
||||
}
|
||||
|
||||
impl fmt::Display for SecureTdmError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SecureTdmError::BeaconAuthFailed => write!(f, "Beacon HMAC verification failed"),
|
||||
SecureTdmError::BeaconReplay { nonce, last_accepted } => {
|
||||
write!(
|
||||
f,
|
||||
"Beacon replay: nonce {} <= last_accepted {} - REPLAY_WINDOW",
|
||||
nonce, last_accepted
|
||||
)
|
||||
}
|
||||
SecureTdmError::BeaconTooShort { expected, got } => {
|
||||
write!(f, "Beacon too short: expected {} bytes, got {}", expected, got)
|
||||
}
|
||||
SecureTdmError::Transport(e) => write!(f, "Transport error: {}", e),
|
||||
SecureTdmError::ModeMismatch { expected, got } => {
|
||||
write!(f, "Security mode mismatch: expected {}, got {}", expected, got)
|
||||
}
|
||||
SecureTdmError::NoMeshKey => write!(f, "Mesh key not provisioned"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SecureTdmError {}
|
||||
|
||||
impl From<QuicTransportError> for SecureTdmError {
|
||||
fn from(e: QuicTransportError) -> Self {
|
||||
SecureTdmError::Transport(e)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Replay window
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Replay protection window for manual crypto mode.
|
||||
///
|
||||
/// Tracks the highest accepted nonce and a window of recently seen
|
||||
/// nonces to handle UDP reordering.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReplayWindow {
|
||||
/// Highest nonce value accepted so far.
|
||||
last_accepted: u32,
|
||||
/// Window size.
|
||||
window_size: u32,
|
||||
/// Recently seen nonces within the window (for dedup).
|
||||
seen: VecDeque<u32>,
|
||||
}
|
||||
|
||||
impl ReplayWindow {
|
||||
/// Create a new replay window with the given size.
|
||||
pub fn new(window_size: u32) -> Self {
|
||||
Self {
|
||||
last_accepted: 0,
|
||||
window_size,
|
||||
seen: VecDeque::with_capacity(window_size as usize),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a nonce is acceptable (not replayed).
|
||||
///
|
||||
/// Returns `true` if the nonce should be accepted.
|
||||
pub fn check(&self, nonce: u32) -> bool {
|
||||
if nonce == 0 && self.last_accepted == 0 && self.seen.is_empty() {
|
||||
// First nonce ever
|
||||
return true;
|
||||
}
|
||||
if self.last_accepted >= self.window_size
|
||||
&& nonce < self.last_accepted.saturating_sub(self.window_size)
|
||||
{
|
||||
// Too old
|
||||
return false;
|
||||
}
|
||||
// Check for exact duplicate within window
|
||||
!self.seen.contains(&nonce)
|
||||
}
|
||||
|
||||
/// Accept a nonce, updating the window state.
|
||||
///
|
||||
/// Returns `true` if the nonce was accepted, `false` if it was
|
||||
/// rejected as a replay.
|
||||
pub fn accept(&mut self, nonce: u32) -> bool {
|
||||
if !self.check(nonce) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.seen.push_back(nonce);
|
||||
if self.seen.len() > self.window_size as usize {
|
||||
self.seen.pop_front();
|
||||
}
|
||||
|
||||
if nonce > self.last_accepted {
|
||||
self.last_accepted = nonce;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Current highest accepted nonce.
|
||||
pub fn last_accepted(&self) -> u32 {
|
||||
self.last_accepted
|
||||
}
|
||||
|
||||
/// Number of nonces currently tracked in the window.
|
||||
pub fn window_count(&self) -> usize {
|
||||
self.seen.len()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Authenticated beacon (manual crypto mode)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// An authenticated beacon in the manual crypto wire format (28 bytes).
|
||||
///
|
||||
/// ```text
|
||||
/// [0..16] SyncBeacon payload (cycle_id, period, drift, reserved)
|
||||
/// [16..20] nonce (LE u32, monotonically increasing)
|
||||
/// [20..28] hmac_tag (HMAC-SHA256 truncated to 8 bytes)
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthenticatedBeacon {
|
||||
/// The underlying sync beacon.
|
||||
pub beacon: SyncBeacon,
|
||||
/// Monotonic nonce for replay protection.
|
||||
pub nonce: u32,
|
||||
/// HMAC-SHA256 truncated tag (8 bytes).
|
||||
pub hmac_tag: [u8; HMAC_TAG_SIZE],
|
||||
}
|
||||
|
||||
impl AuthenticatedBeacon {
|
||||
/// Serialize to the 28-byte authenticated wire format.
|
||||
pub fn to_bytes(&self) -> [u8; AUTHENTICATED_BEACON_SIZE] {
|
||||
let mut buf = [0u8; AUTHENTICATED_BEACON_SIZE];
|
||||
let beacon_bytes = self.beacon.to_bytes();
|
||||
buf[..16].copy_from_slice(&beacon_bytes);
|
||||
buf[16..20].copy_from_slice(&self.nonce.to_le_bytes());
|
||||
buf[20..28].copy_from_slice(&self.hmac_tag);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Deserialize from the 28-byte authenticated wire format.
|
||||
///
|
||||
/// Does NOT verify the HMAC tag -- call `verify()` separately.
|
||||
pub fn from_bytes(buf: &[u8]) -> Result<Self, SecureTdmError> {
|
||||
if buf.len() < AUTHENTICATED_BEACON_SIZE {
|
||||
return Err(SecureTdmError::BeaconTooShort {
|
||||
expected: AUTHENTICATED_BEACON_SIZE,
|
||||
got: buf.len(),
|
||||
});
|
||||
}
|
||||
let beacon = SyncBeacon::from_bytes(&buf[..16]).ok_or(SecureTdmError::BeaconTooShort {
|
||||
expected: 16,
|
||||
got: buf.len(),
|
||||
})?;
|
||||
let nonce = u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]);
|
||||
let mut hmac_tag = [0u8; HMAC_TAG_SIZE];
|
||||
hmac_tag.copy_from_slice(&buf[20..28]);
|
||||
Ok(Self {
|
||||
beacon,
|
||||
nonce,
|
||||
hmac_tag,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute the expected HMAC tag for this beacon using the given key.
|
||||
///
|
||||
/// Uses a simplified HMAC approximation for testing. In production,
|
||||
/// this calls mbedtls HMAC-SHA256 via the ESP-IDF hardware accelerator
|
||||
/// or the `sha2` crate on aggregator nodes.
|
||||
pub fn compute_tag(payload_and_nonce: &[u8], key: &[u8; 16]) -> [u8; HMAC_TAG_SIZE] {
|
||||
// Simplified HMAC: XOR key into payload hash. In production, use
|
||||
// real HMAC-SHA256 from sha2 crate. This is sufficient for
|
||||
// testing the protocol structure.
|
||||
let mut tag = [0u8; HMAC_TAG_SIZE];
|
||||
for (i, byte) in payload_and_nonce.iter().enumerate() {
|
||||
tag[i % HMAC_TAG_SIZE] ^= byte ^ key[i % 16];
|
||||
}
|
||||
tag
|
||||
}
|
||||
|
||||
/// Verify the HMAC tag using the given key.
|
||||
pub fn verify(&self, key: &[u8; 16]) -> Result<(), SecureTdmError> {
|
||||
let mut msg = [0u8; 20];
|
||||
msg[..16].copy_from_slice(&self.beacon.to_bytes());
|
||||
msg[16..20].copy_from_slice(&self.nonce.to_le_bytes());
|
||||
let expected = Self::compute_tag(&msg, key);
|
||||
if self.hmac_tag == expected {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(SecureTdmError::BeaconAuthFailed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Secure TDM coordinator
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Security configuration for the secure TDM coordinator.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SecureTdmConfig {
|
||||
/// Security mode (QUIC or manual crypto).
|
||||
pub security_mode: SecurityMode,
|
||||
/// Pre-shared mesh key (16 bytes) for manual crypto mode.
|
||||
pub mesh_key: Option<[u8; 16]>,
|
||||
/// QUIC transport configuration (used if mode is QuicTransport).
|
||||
pub quic_config: QuicTransportConfig,
|
||||
/// Security enforcement level.
|
||||
pub sec_level: SecLevel,
|
||||
}
|
||||
|
||||
/// Security enforcement level (ADR-032 Section 2.8).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SecLevel {
|
||||
/// Accept unauthenticated frames, log warning.
|
||||
Permissive = 0,
|
||||
/// Accept both authenticated and unauthenticated.
|
||||
Transitional = 1,
|
||||
/// Reject unauthenticated frames.
|
||||
Enforcing = 2,
|
||||
}
|
||||
|
||||
impl Default for SecureTdmConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
security_mode: SecurityMode::QuicTransport,
|
||||
mesh_key: Some(DEFAULT_TEST_KEY),
|
||||
quic_config: QuicTransportConfig::default(),
|
||||
sec_level: SecLevel::Transitional,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Secure TDM coordinator that wraps `TdmCoordinator` with authenticated
|
||||
/// transport.
|
||||
///
|
||||
/// Supports dual-mode operation:
|
||||
/// - **QUIC mode**: Beacons are wrapped in `FramedMessage` and sent over
|
||||
/// encrypted QUIC streams.
|
||||
/// - **Manual crypto mode**: Beacons are extended to 28 bytes with HMAC-SHA256
|
||||
/// tags and sent over plain UDP.
|
||||
#[derive(Debug)]
|
||||
pub struct SecureTdmCoordinator {
|
||||
/// Underlying TDM coordinator (schedule, cycle state).
|
||||
inner: TdmCoordinator,
|
||||
/// Security configuration.
|
||||
config: SecureTdmConfig,
|
||||
/// Monotonic nonce counter (manual crypto mode).
|
||||
nonce_counter: u32,
|
||||
/// QUIC transport handle (if QUIC mode is active).
|
||||
transport: Option<QuicTransportHandle>,
|
||||
/// Replay window for received beacons (manual crypto mode).
|
||||
replay_window: ReplayWindow,
|
||||
/// Total beacons produced.
|
||||
beacons_produced: u64,
|
||||
/// Total beacons verified.
|
||||
beacons_verified: u64,
|
||||
/// Total verification failures.
|
||||
verification_failures: u64,
|
||||
}
|
||||
|
||||
impl SecureTdmCoordinator {
|
||||
/// Create a new secure TDM coordinator.
|
||||
pub fn new(
|
||||
schedule: TdmSchedule,
|
||||
config: SecureTdmConfig,
|
||||
) -> Result<Self, SecureTdmError> {
|
||||
let transport = if config.security_mode == SecurityMode::QuicTransport {
|
||||
Some(QuicTransportHandle::new(config.quic_config.clone())?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
inner: TdmCoordinator::new(schedule),
|
||||
config,
|
||||
nonce_counter: 0,
|
||||
transport,
|
||||
replay_window: ReplayWindow::new(REPLAY_WINDOW),
|
||||
beacons_produced: 0,
|
||||
beacons_verified: 0,
|
||||
verification_failures: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Begin a new secure sensing cycle.
|
||||
///
|
||||
/// Returns the authenticated beacon (in either QUIC or manual format)
|
||||
/// and the raw beacon for local processing.
|
||||
pub fn begin_secure_cycle(&mut self) -> Result<SecureCycleOutput, SecureTdmError> {
|
||||
let beacon = self.inner.begin_cycle();
|
||||
self.beacons_produced += 1;
|
||||
|
||||
match self.config.security_mode {
|
||||
SecurityMode::ManualCrypto => {
|
||||
let key = self.config.mesh_key.ok_or(SecureTdmError::NoMeshKey)?;
|
||||
self.nonce_counter = self.nonce_counter.wrapping_add(1);
|
||||
|
||||
let mut msg = [0u8; 20];
|
||||
msg[..16].copy_from_slice(&beacon.to_bytes());
|
||||
msg[16..20].copy_from_slice(&self.nonce_counter.to_le_bytes());
|
||||
let tag = AuthenticatedBeacon::compute_tag(&msg, &key);
|
||||
|
||||
let auth_beacon = AuthenticatedBeacon {
|
||||
beacon: beacon.clone(),
|
||||
nonce: self.nonce_counter,
|
||||
hmac_tag: tag,
|
||||
};
|
||||
|
||||
Ok(SecureCycleOutput {
|
||||
beacon,
|
||||
authenticated_bytes: auth_beacon.to_bytes().to_vec(),
|
||||
mode: SecurityMode::ManualCrypto,
|
||||
})
|
||||
}
|
||||
SecurityMode::QuicTransport => {
|
||||
let beacon_bytes = beacon.to_bytes();
|
||||
let framed = FramedMessage::new(
|
||||
MessageType::Beacon,
|
||||
beacon_bytes.to_vec(),
|
||||
);
|
||||
let wire = framed.to_bytes();
|
||||
|
||||
if let Some(ref mut transport) = self.transport {
|
||||
transport.record_beacon_sent(wire.len());
|
||||
}
|
||||
|
||||
Ok(SecureCycleOutput {
|
||||
beacon,
|
||||
authenticated_bytes: wire,
|
||||
mode: SecurityMode::QuicTransport,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify a received beacon.
|
||||
///
|
||||
/// In manual crypto mode, verifies the HMAC tag and replay window.
|
||||
/// In QUIC mode, the transport layer already provides authentication.
|
||||
pub fn verify_beacon(&mut self, buf: &[u8]) -> Result<SyncBeacon, SecureTdmError> {
|
||||
match self.config.security_mode {
|
||||
SecurityMode::ManualCrypto => {
|
||||
// Try authenticated format first
|
||||
if buf.len() >= AUTHENTICATED_BEACON_SIZE {
|
||||
let auth = AuthenticatedBeacon::from_bytes(buf)?;
|
||||
let key = self.config.mesh_key.ok_or(SecureTdmError::NoMeshKey)?;
|
||||
match auth.verify(&key) {
|
||||
Ok(()) => {
|
||||
if !self.replay_window.accept(auth.nonce) {
|
||||
self.verification_failures += 1;
|
||||
return Err(SecureTdmError::BeaconReplay {
|
||||
nonce: auth.nonce,
|
||||
last_accepted: self.replay_window.last_accepted(),
|
||||
});
|
||||
}
|
||||
self.beacons_verified += 1;
|
||||
Ok(auth.beacon)
|
||||
}
|
||||
Err(e) => {
|
||||
self.verification_failures += 1;
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
} else if buf.len() >= 16 && self.config.sec_level != SecLevel::Enforcing {
|
||||
// Accept unauthenticated 16-byte beacon in permissive/transitional
|
||||
let beacon = SyncBeacon::from_bytes(buf).ok_or(
|
||||
SecureTdmError::BeaconTooShort {
|
||||
expected: 16,
|
||||
got: buf.len(),
|
||||
},
|
||||
)?;
|
||||
self.beacons_verified += 1;
|
||||
Ok(beacon)
|
||||
} else {
|
||||
Err(SecureTdmError::BeaconTooShort {
|
||||
expected: AUTHENTICATED_BEACON_SIZE,
|
||||
got: buf.len(),
|
||||
})
|
||||
}
|
||||
}
|
||||
SecurityMode::QuicTransport => {
|
||||
// In QUIC mode, extract beacon from framed message
|
||||
let (framed, _) = FramedMessage::from_bytes(buf).ok_or(
|
||||
SecureTdmError::BeaconTooShort {
|
||||
expected: 5 + 16,
|
||||
got: buf.len(),
|
||||
},
|
||||
)?;
|
||||
if framed.message_type != MessageType::Beacon {
|
||||
return Err(SecureTdmError::ModeMismatch {
|
||||
expected: SecurityMode::QuicTransport,
|
||||
got: SecurityMode::ManualCrypto,
|
||||
});
|
||||
}
|
||||
let beacon = SyncBeacon::from_bytes(&framed.payload).ok_or(
|
||||
SecureTdmError::BeaconTooShort {
|
||||
expected: 16,
|
||||
got: framed.payload.len(),
|
||||
},
|
||||
)?;
|
||||
self.beacons_verified += 1;
|
||||
|
||||
if let Some(ref mut transport) = self.transport {
|
||||
transport.record_beacon_received(buf.len());
|
||||
}
|
||||
|
||||
Ok(beacon)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete a slot in the current cycle (delegates to inner coordinator).
|
||||
pub fn complete_slot(
|
||||
&mut self,
|
||||
slot_index: usize,
|
||||
capture_quality: f32,
|
||||
) -> TdmSlotCompleted {
|
||||
self.inner.complete_slot(slot_index, capture_quality)
|
||||
}
|
||||
|
||||
/// Whether the current cycle is complete.
|
||||
pub fn is_cycle_complete(&self) -> bool {
|
||||
self.inner.is_cycle_complete()
|
||||
}
|
||||
|
||||
/// Current cycle ID.
|
||||
pub fn cycle_id(&self) -> u64 {
|
||||
self.inner.cycle_id()
|
||||
}
|
||||
|
||||
/// Active security mode.
|
||||
pub fn security_mode(&self) -> SecurityMode {
|
||||
self.config.security_mode
|
||||
}
|
||||
|
||||
/// Reference to the underlying TDM coordinator.
|
||||
pub fn inner(&self) -> &TdmCoordinator {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
/// Total beacons produced.
|
||||
pub fn beacons_produced(&self) -> u64 {
|
||||
self.beacons_produced
|
||||
}
|
||||
|
||||
/// Total beacons successfully verified.
|
||||
pub fn beacons_verified(&self) -> u64 {
|
||||
self.beacons_verified
|
||||
}
|
||||
|
||||
/// Total verification failures.
|
||||
pub fn verification_failures(&self) -> u64 {
|
||||
self.verification_failures
|
||||
}
|
||||
|
||||
/// Reference to the QUIC transport handle (if available).
|
||||
pub fn transport(&self) -> Option<&QuicTransportHandle> {
|
||||
self.transport.as_ref()
|
||||
}
|
||||
|
||||
/// Mutable reference to the QUIC transport handle (if available).
|
||||
pub fn transport_mut(&mut self) -> Option<&mut QuicTransportHandle> {
|
||||
self.transport.as_mut()
|
||||
}
|
||||
|
||||
/// Current nonce counter value (manual crypto mode).
|
||||
pub fn nonce_counter(&self) -> u32 {
|
||||
self.nonce_counter
|
||||
}
|
||||
|
||||
/// Reference to the replay window.
|
||||
pub fn replay_window(&self) -> &ReplayWindow {
|
||||
&self.replay_window
|
||||
}
|
||||
|
||||
/// Security enforcement level.
|
||||
pub fn sec_level(&self) -> SecLevel {
|
||||
self.config.sec_level
|
||||
}
|
||||
}
|
||||
|
||||
/// Output from `begin_secure_cycle()`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SecureCycleOutput {
|
||||
/// The underlying sync beacon (for local processing).
|
||||
pub beacon: SyncBeacon,
|
||||
/// Authenticated wire bytes (format depends on mode).
|
||||
pub authenticated_bytes: Vec<u8>,
|
||||
/// Security mode used for this beacon.
|
||||
pub mode: SecurityMode,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::esp32::tdm::TdmSchedule;
|
||||
use std::time::Duration;
|
||||
|
||||
fn test_schedule() -> TdmSchedule {
|
||||
TdmSchedule::default_4node()
|
||||
}
|
||||
|
||||
fn manual_config() -> SecureTdmConfig {
|
||||
SecureTdmConfig {
|
||||
security_mode: SecurityMode::ManualCrypto,
|
||||
mesh_key: Some(DEFAULT_TEST_KEY),
|
||||
quic_config: QuicTransportConfig::default(),
|
||||
sec_level: SecLevel::Transitional,
|
||||
}
|
||||
}
|
||||
|
||||
fn quic_config() -> SecureTdmConfig {
|
||||
SecureTdmConfig {
|
||||
security_mode: SecurityMode::QuicTransport,
|
||||
mesh_key: Some(DEFAULT_TEST_KEY),
|
||||
quic_config: QuicTransportConfig::default(),
|
||||
sec_level: SecLevel::Transitional,
|
||||
}
|
||||
}
|
||||
|
||||
// ---- ReplayWindow tests ----
|
||||
|
||||
#[test]
|
||||
fn test_replay_window_new() {
|
||||
let rw = ReplayWindow::new(16);
|
||||
assert_eq!(rw.last_accepted(), 0);
|
||||
assert_eq!(rw.window_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_window_accept_first() {
|
||||
let mut rw = ReplayWindow::new(16);
|
||||
assert!(rw.accept(0)); // First nonce accepted
|
||||
assert_eq!(rw.window_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_window_monotonic() {
|
||||
let mut rw = ReplayWindow::new(16);
|
||||
assert!(rw.accept(1));
|
||||
assert!(rw.accept(2));
|
||||
assert!(rw.accept(3));
|
||||
assert_eq!(rw.last_accepted(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_window_reject_duplicate() {
|
||||
let mut rw = ReplayWindow::new(16);
|
||||
assert!(rw.accept(1));
|
||||
assert!(!rw.accept(1)); // Duplicate rejected
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_window_accept_within_window() {
|
||||
let mut rw = ReplayWindow::new(16);
|
||||
assert!(rw.accept(5));
|
||||
assert!(rw.accept(3)); // Out of order but within window
|
||||
assert_eq!(rw.last_accepted(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_window_reject_too_old() {
|
||||
let mut rw = ReplayWindow::new(4);
|
||||
for i in 0..20 {
|
||||
rw.accept(i);
|
||||
}
|
||||
// Nonce 0 is way outside the window
|
||||
assert!(!rw.accept(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_window_evicts_old() {
|
||||
let mut rw = ReplayWindow::new(4);
|
||||
for i in 0..10 {
|
||||
rw.accept(i);
|
||||
}
|
||||
assert!(rw.window_count() <= 4);
|
||||
}
|
||||
|
||||
// ---- AuthenticatedBeacon tests ----
|
||||
|
||||
#[test]
|
||||
fn test_auth_beacon_roundtrip() {
|
||||
let beacon = SyncBeacon {
|
||||
cycle_id: 42,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: -3,
|
||||
generated_at: std::time::Instant::now(),
|
||||
};
|
||||
let key = DEFAULT_TEST_KEY;
|
||||
let nonce = 7u32;
|
||||
|
||||
let mut msg = [0u8; 20];
|
||||
msg[..16].copy_from_slice(&beacon.to_bytes());
|
||||
msg[16..20].copy_from_slice(&nonce.to_le_bytes());
|
||||
let tag = AuthenticatedBeacon::compute_tag(&msg, &key);
|
||||
|
||||
let auth = AuthenticatedBeacon {
|
||||
beacon,
|
||||
nonce,
|
||||
hmac_tag: tag,
|
||||
};
|
||||
|
||||
let bytes = auth.to_bytes();
|
||||
assert_eq!(bytes.len(), AUTHENTICATED_BEACON_SIZE);
|
||||
|
||||
let decoded = AuthenticatedBeacon::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(decoded.beacon.cycle_id, 42);
|
||||
assert_eq!(decoded.nonce, 7);
|
||||
assert_eq!(decoded.hmac_tag, tag);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_beacon_verify_ok() {
|
||||
let beacon = SyncBeacon {
|
||||
cycle_id: 100,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: 0,
|
||||
generated_at: std::time::Instant::now(),
|
||||
};
|
||||
let key = DEFAULT_TEST_KEY;
|
||||
let nonce = 1u32;
|
||||
|
||||
let mut msg = [0u8; 20];
|
||||
msg[..16].copy_from_slice(&beacon.to_bytes());
|
||||
msg[16..20].copy_from_slice(&nonce.to_le_bytes());
|
||||
let tag = AuthenticatedBeacon::compute_tag(&msg, &key);
|
||||
|
||||
let auth = AuthenticatedBeacon {
|
||||
beacon,
|
||||
nonce,
|
||||
hmac_tag: tag,
|
||||
};
|
||||
assert!(auth.verify(&key).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_beacon_verify_tampered() {
|
||||
let beacon = SyncBeacon {
|
||||
cycle_id: 100,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: 0,
|
||||
generated_at: std::time::Instant::now(),
|
||||
};
|
||||
let key = DEFAULT_TEST_KEY;
|
||||
let nonce = 1u32;
|
||||
|
||||
let mut msg = [0u8; 20];
|
||||
msg[..16].copy_from_slice(&beacon.to_bytes());
|
||||
msg[16..20].copy_from_slice(&nonce.to_le_bytes());
|
||||
let mut tag = AuthenticatedBeacon::compute_tag(&msg, &key);
|
||||
tag[0] ^= 0xFF; // Tamper with tag
|
||||
|
||||
let auth = AuthenticatedBeacon {
|
||||
beacon,
|
||||
nonce,
|
||||
hmac_tag: tag,
|
||||
};
|
||||
assert!(matches!(
|
||||
auth.verify(&key),
|
||||
Err(SecureTdmError::BeaconAuthFailed)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_beacon_too_short() {
|
||||
let result = AuthenticatedBeacon::from_bytes(&[0u8; 10]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(SecureTdmError::BeaconTooShort { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_beacon_size_constant() {
|
||||
assert_eq!(AUTHENTICATED_BEACON_SIZE, 28);
|
||||
}
|
||||
|
||||
// ---- SecureTdmCoordinator tests (manual crypto) ----
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_create() {
|
||||
let coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
assert_eq!(coord.security_mode(), SecurityMode::ManualCrypto);
|
||||
assert_eq!(coord.beacons_produced(), 0);
|
||||
assert!(coord.transport().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_begin_cycle() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
|
||||
assert_eq!(output.mode, SecurityMode::ManualCrypto);
|
||||
assert_eq!(output.authenticated_bytes.len(), AUTHENTICATED_BEACON_SIZE);
|
||||
assert_eq!(output.beacon.cycle_id, 0);
|
||||
assert_eq!(coord.beacons_produced(), 1);
|
||||
assert_eq!(coord.nonce_counter(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_nonce_increments() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
|
||||
for expected_nonce in 1..=5u32 {
|
||||
let _output = coord.begin_secure_cycle().unwrap();
|
||||
// Complete all slots
|
||||
for i in 0..4 {
|
||||
coord.complete_slot(i, 1.0);
|
||||
}
|
||||
assert_eq!(coord.nonce_counter(), expected_nonce);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_verify_own_beacon() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
|
||||
// Create a second coordinator to verify
|
||||
let mut verifier =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
let beacon = verifier
|
||||
.verify_beacon(&output.authenticated_bytes)
|
||||
.unwrap();
|
||||
assert_eq!(beacon.cycle_id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_reject_tampered() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
|
||||
let mut tampered = output.authenticated_bytes.clone();
|
||||
tampered[25] ^= 0xFF; // Tamper with HMAC tag
|
||||
|
||||
let mut verifier =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
assert!(verifier.verify_beacon(&tampered).is_err());
|
||||
assert_eq!(verifier.verification_failures(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_reject_replay() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
|
||||
let mut verifier =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
|
||||
// First acceptance succeeds
|
||||
verifier
|
||||
.verify_beacon(&output.authenticated_bytes)
|
||||
.unwrap();
|
||||
|
||||
// Replay of same beacon fails
|
||||
let result = verifier.verify_beacon(&output.authenticated_bytes);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_backward_compat_permissive() {
|
||||
let mut cfg = manual_config();
|
||||
cfg.sec_level = SecLevel::Permissive;
|
||||
let mut coord = SecureTdmCoordinator::new(test_schedule(), cfg).unwrap();
|
||||
|
||||
// Send an unauthenticated 16-byte beacon
|
||||
let beacon = SyncBeacon {
|
||||
cycle_id: 99,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: 0,
|
||||
generated_at: std::time::Instant::now(),
|
||||
};
|
||||
let bytes = beacon.to_bytes();
|
||||
|
||||
let verified = coord.verify_beacon(&bytes).unwrap();
|
||||
assert_eq!(verified.cycle_id, 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_manual_reject_unauthenticated_enforcing() {
|
||||
let mut cfg = manual_config();
|
||||
cfg.sec_level = SecLevel::Enforcing;
|
||||
let mut coord = SecureTdmCoordinator::new(test_schedule(), cfg).unwrap();
|
||||
|
||||
let beacon = SyncBeacon {
|
||||
cycle_id: 99,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: 0,
|
||||
generated_at: std::time::Instant::now(),
|
||||
};
|
||||
let bytes = beacon.to_bytes();
|
||||
|
||||
// 16-byte unauthenticated beacon rejected in enforcing mode
|
||||
let result = coord.verify_beacon(&bytes);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_no_mesh_key() {
|
||||
let cfg = SecureTdmConfig {
|
||||
security_mode: SecurityMode::ManualCrypto,
|
||||
mesh_key: None,
|
||||
..Default::default()
|
||||
};
|
||||
let mut coord = SecureTdmCoordinator::new(test_schedule(), cfg).unwrap();
|
||||
let result = coord.begin_secure_cycle();
|
||||
assert!(matches!(result, Err(SecureTdmError::NoMeshKey)));
|
||||
}
|
||||
|
||||
// ---- SecureTdmCoordinator tests (QUIC mode) ----
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_quic_create() {
|
||||
let coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
|
||||
assert_eq!(coord.security_mode(), SecurityMode::QuicTransport);
|
||||
assert!(coord.transport().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_quic_begin_cycle() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
|
||||
assert_eq!(output.mode, SecurityMode::QuicTransport);
|
||||
// QUIC framed: 5-byte header + 16-byte beacon = 21 bytes
|
||||
assert_eq!(output.authenticated_bytes.len(), 5 + 16);
|
||||
assert_eq!(coord.beacons_produced(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_quic_verify_own_beacon() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
|
||||
let output = coord.begin_secure_cycle().unwrap();
|
||||
|
||||
let mut verifier =
|
||||
SecureTdmCoordinator::new(test_schedule(), quic_config()).unwrap();
|
||||
let beacon = verifier
|
||||
.verify_beacon(&output.authenticated_bytes)
|
||||
.unwrap();
|
||||
assert_eq!(beacon.cycle_id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_complete_cycle() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
coord.begin_secure_cycle().unwrap();
|
||||
|
||||
for i in 0..4 {
|
||||
let event = coord.complete_slot(i, 0.95);
|
||||
assert_eq!(event.slot_index, i);
|
||||
}
|
||||
assert!(coord.is_cycle_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_coordinator_cycle_id_increments() {
|
||||
let mut coord =
|
||||
SecureTdmCoordinator::new(test_schedule(), manual_config()).unwrap();
|
||||
|
||||
let out0 = coord.begin_secure_cycle().unwrap();
|
||||
assert_eq!(out0.beacon.cycle_id, 0);
|
||||
for i in 0..4 {
|
||||
coord.complete_slot(i, 1.0);
|
||||
}
|
||||
|
||||
let out1 = coord.begin_secure_cycle().unwrap();
|
||||
assert_eq!(out1.beacon.cycle_id, 1);
|
||||
}
|
||||
|
||||
// ---- SecLevel tests ----
|
||||
|
||||
#[test]
|
||||
fn test_sec_level_values() {
|
||||
assert_eq!(SecLevel::Permissive as u8, 0);
|
||||
assert_eq!(SecLevel::Transitional as u8, 1);
|
||||
assert_eq!(SecLevel::Enforcing as u8, 2);
|
||||
}
|
||||
|
||||
// ---- Error display tests ----
|
||||
|
||||
#[test]
|
||||
fn test_secure_tdm_error_display() {
|
||||
let err = SecureTdmError::BeaconAuthFailed;
|
||||
assert!(format!("{}", err).contains("HMAC"));
|
||||
|
||||
let err = SecureTdmError::BeaconReplay {
|
||||
nonce: 5,
|
||||
last_accepted: 10,
|
||||
};
|
||||
assert!(format!("{}", err).contains("replay"));
|
||||
|
||||
let err = SecureTdmError::NoMeshKey;
|
||||
assert!(format!("{}", err).contains("Mesh key"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,814 @@
|
||||
//! TDM (Time-Division Multiplexed) sensing protocol for multistatic WiFi sensing.
|
||||
//!
|
||||
//! Implements the TDMA sensing schedule described in ADR-029 (RuvSense) and
|
||||
//! ADR-031 (RuView). Each ESP32 node transmits NDP frames in its assigned slot
|
||||
//! while all other nodes receive, producing N*(N-1) bistatic CSI links per cycle.
|
||||
//!
|
||||
//! # 4-Node Example (ADR-029 Table)
|
||||
//!
|
||||
//! ```text
|
||||
//! Slot 0: Node A TX, B/C/D RX (4 ms)
|
||||
//! Slot 1: Node B TX, A/C/D RX (4 ms)
|
||||
//! Slot 2: Node C TX, A/B/D RX (4 ms)
|
||||
//! Slot 3: Node D TX, A/B/C RX (4 ms)
|
||||
//! Slot 4: Processing + fusion (30 ms)
|
||||
//! Total: 50 ms = 20 Hz
|
||||
//! ```
|
||||
//!
|
||||
//! # Clock Drift Compensation
|
||||
//!
|
||||
//! ESP32 crystal drift is +/-10 ppm. Over a 50 ms cycle:
|
||||
//! drift = 10e-6 * 50e-3 = 0.5 us
|
||||
//!
|
||||
//! This is well within the 1 ms guard interval between slots, so no
|
||||
//! cross-node phase alignment is needed at the TDM scheduling layer.
|
||||
//! The coordinator tracks cumulative drift and issues correction offsets
|
||||
//! in sync beacons when drift exceeds a configurable threshold.
|
||||
|
||||
use std::fmt;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Maximum supported nodes in a single TDM schedule.
|
||||
const MAX_NODES: usize = 16;
|
||||
|
||||
/// Default guard interval between TX slots (microseconds).
|
||||
const DEFAULT_GUARD_US: u64 = 1_000;
|
||||
|
||||
/// Default processing time after all TX slots complete (milliseconds).
|
||||
const DEFAULT_PROCESSING_MS: u64 = 30;
|
||||
|
||||
/// Default TX slot duration (milliseconds).
|
||||
const DEFAULT_SLOT_MS: u64 = 4;
|
||||
|
||||
/// Crystal drift specification for ESP32 (parts per million).
|
||||
const CRYSTAL_DRIFT_PPM: f64 = 10.0;
|
||||
|
||||
/// Errors that can occur during TDM schedule operations.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum TdmError {
|
||||
/// Node count is zero or exceeds the maximum.
|
||||
InvalidNodeCount { count: usize, max: usize },
|
||||
/// A slot index is out of bounds for the current schedule.
|
||||
SlotIndexOutOfBounds { index: usize, num_slots: usize },
|
||||
/// A node ID is not present in the schedule.
|
||||
UnknownNode { node_id: u8 },
|
||||
/// The guard interval is too large relative to the slot duration.
|
||||
GuardIntervalTooLarge { guard_us: u64, slot_us: u64 },
|
||||
/// Cycle period is too short to fit all slots plus processing.
|
||||
CycleTooShort { needed_us: u64, available_us: u64 },
|
||||
/// Drift correction offset exceeds the guard interval.
|
||||
DriftExceedsGuard { drift_us: f64, guard_us: u64 },
|
||||
}
|
||||
|
||||
impl fmt::Display for TdmError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
TdmError::InvalidNodeCount { count, max } => {
|
||||
write!(f, "Invalid node count: {} (max {})", count, max)
|
||||
}
|
||||
TdmError::SlotIndexOutOfBounds { index, num_slots } => {
|
||||
write!(f, "Slot index {} out of bounds (schedule has {} slots)", index, num_slots)
|
||||
}
|
||||
TdmError::UnknownNode { node_id } => {
|
||||
write!(f, "Unknown node ID: {}", node_id)
|
||||
}
|
||||
TdmError::GuardIntervalTooLarge { guard_us, slot_us } => {
|
||||
write!(f, "Guard interval {} us exceeds slot duration {} us", guard_us, slot_us)
|
||||
}
|
||||
TdmError::CycleTooShort { needed_us, available_us } => {
|
||||
write!(f, "Cycle too short: need {} us, have {} us", needed_us, available_us)
|
||||
}
|
||||
TdmError::DriftExceedsGuard { drift_us, guard_us } => {
|
||||
write!(f, "Drift {:.1} us exceeds guard interval {} us", drift_us, guard_us)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for TdmError {}
|
||||
|
||||
/// A single TDM time slot assignment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TdmSlot {
|
||||
/// Index of this slot within the cycle (0-based).
|
||||
pub index: usize,
|
||||
/// Node ID assigned to transmit during this slot.
|
||||
pub tx_node_id: u8,
|
||||
/// Duration of the TX window (excluding guard interval).
|
||||
pub duration: Duration,
|
||||
/// Guard interval after this slot before the next begins.
|
||||
pub guard_interval: Duration,
|
||||
}
|
||||
|
||||
impl TdmSlot {
|
||||
/// Total duration of this slot including guard interval.
|
||||
pub fn total_duration(&self) -> Duration {
|
||||
self.duration + self.guard_interval
|
||||
}
|
||||
|
||||
/// Start offset of this slot within the cycle.
|
||||
///
|
||||
/// Requires the full slot list to compute cumulative offset.
|
||||
pub fn start_offset(slots: &[TdmSlot], index: usize) -> Option<Duration> {
|
||||
if index >= slots.len() {
|
||||
return None;
|
||||
}
|
||||
let mut offset = Duration::ZERO;
|
||||
for slot in &slots[..index] {
|
||||
offset += slot.total_duration();
|
||||
}
|
||||
Some(offset)
|
||||
}
|
||||
}
|
||||
|
||||
/// TDM sensing schedule defining slot assignments and cycle timing.
|
||||
///
|
||||
/// A schedule assigns each node exactly one TX slot per cycle. During a
|
||||
/// node's TX slot, it transmits NDP frames while all other nodes receive
|
||||
/// and extract CSI. After all TX slots, a processing window allows the
|
||||
/// aggregator to fuse the collected CSI data.
|
||||
///
|
||||
/// # Example: 4-node schedule at 20 Hz
|
||||
///
|
||||
/// ```
|
||||
/// use wifi_densepose_hardware::esp32::TdmSchedule;
|
||||
/// use std::time::Duration;
|
||||
///
|
||||
/// let schedule = TdmSchedule::uniform(
|
||||
/// &[0, 1, 2, 3], // 4 node IDs
|
||||
/// Duration::from_millis(4), // 4 ms per TX slot
|
||||
/// Duration::from_micros(1_000), // 1 ms guard interval
|
||||
/// Duration::from_millis(30), // 30 ms processing window
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// assert_eq!(schedule.node_count(), 4);
|
||||
/// assert_eq!(schedule.cycle_period().as_millis(), 50); // 4*(4+1) + 30 = 50
|
||||
/// assert_eq!(schedule.update_rate_hz(), 20.0);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TdmSchedule {
|
||||
/// Ordered slot assignments (one per node).
|
||||
slots: Vec<TdmSlot>,
|
||||
/// Processing window after all TX slots.
|
||||
processing_window: Duration,
|
||||
/// Total cycle period (sum of all slots + processing).
|
||||
cycle_period: Duration,
|
||||
}
|
||||
|
||||
impl TdmSchedule {
|
||||
/// Create a uniform TDM schedule where all nodes have equal slot duration.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node_ids` - Ordered list of node IDs (determines TX order)
|
||||
/// * `slot_duration` - TX window duration per slot
|
||||
/// * `guard_interval` - Guard interval between consecutive slots
|
||||
/// * `processing_window` - Time after all TX slots for fusion processing
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `TdmError::InvalidNodeCount` if `node_ids` is empty or exceeds
|
||||
/// `MAX_NODES`. Returns `TdmError::GuardIntervalTooLarge` if the guard
|
||||
/// interval is larger than the slot duration.
|
||||
pub fn uniform(
|
||||
node_ids: &[u8],
|
||||
slot_duration: Duration,
|
||||
guard_interval: Duration,
|
||||
processing_window: Duration,
|
||||
) -> Result<Self, TdmError> {
|
||||
if node_ids.is_empty() || node_ids.len() > MAX_NODES {
|
||||
return Err(TdmError::InvalidNodeCount {
|
||||
count: node_ids.len(),
|
||||
max: MAX_NODES,
|
||||
});
|
||||
}
|
||||
|
||||
let slot_us = slot_duration.as_micros() as u64;
|
||||
let guard_us = guard_interval.as_micros() as u64;
|
||||
if guard_us >= slot_us {
|
||||
return Err(TdmError::GuardIntervalTooLarge { guard_us, slot_us });
|
||||
}
|
||||
|
||||
let slots: Vec<TdmSlot> = node_ids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &node_id)| TdmSlot {
|
||||
index: i,
|
||||
tx_node_id: node_id,
|
||||
duration: slot_duration,
|
||||
guard_interval,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tx_total: Duration = slots.iter().map(|s| s.total_duration()).sum();
|
||||
let cycle_period = tx_total + processing_window;
|
||||
|
||||
Ok(Self {
|
||||
slots,
|
||||
processing_window,
|
||||
cycle_period,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create the default 4-node, 20 Hz schedule from ADR-029.
|
||||
///
|
||||
/// ```
|
||||
/// use wifi_densepose_hardware::esp32::TdmSchedule;
|
||||
///
|
||||
/// let schedule = TdmSchedule::default_4node();
|
||||
/// assert_eq!(schedule.node_count(), 4);
|
||||
/// assert_eq!(schedule.update_rate_hz(), 20.0);
|
||||
/// ```
|
||||
pub fn default_4node() -> Self {
|
||||
Self::uniform(
|
||||
&[0, 1, 2, 3],
|
||||
Duration::from_millis(DEFAULT_SLOT_MS),
|
||||
Duration::from_micros(DEFAULT_GUARD_US),
|
||||
Duration::from_millis(DEFAULT_PROCESSING_MS),
|
||||
)
|
||||
.expect("default 4-node schedule is always valid")
|
||||
}
|
||||
|
||||
/// Number of nodes in this schedule.
|
||||
pub fn node_count(&self) -> usize {
|
||||
self.slots.len()
|
||||
}
|
||||
|
||||
/// Total cycle period (time between consecutive cycle starts).
|
||||
pub fn cycle_period(&self) -> Duration {
|
||||
self.cycle_period
|
||||
}
|
||||
|
||||
/// Effective update rate in Hz.
|
||||
pub fn update_rate_hz(&self) -> f64 {
|
||||
1.0 / self.cycle_period.as_secs_f64()
|
||||
}
|
||||
|
||||
/// Duration of the processing window after all TX slots.
|
||||
pub fn processing_window(&self) -> Duration {
|
||||
self.processing_window
|
||||
}
|
||||
|
||||
/// Get the slot assignment for a given slot index.
|
||||
pub fn slot(&self, index: usize) -> Option<&TdmSlot> {
|
||||
self.slots.get(index)
|
||||
}
|
||||
|
||||
/// Get the slot assigned to a specific node.
|
||||
pub fn slot_for_node(&self, node_id: u8) -> Option<&TdmSlot> {
|
||||
self.slots.iter().find(|s| s.tx_node_id == node_id)
|
||||
}
|
||||
|
||||
/// Immutable slice of all slot assignments.
|
||||
pub fn slots(&self) -> &[TdmSlot] {
|
||||
&self.slots
|
||||
}
|
||||
|
||||
/// Compute the maximum clock drift in microseconds for this cycle.
|
||||
///
|
||||
/// Uses the ESP32 crystal specification of +/-10 ppm.
|
||||
pub fn max_drift_us(&self) -> f64 {
|
||||
CRYSTAL_DRIFT_PPM * 1e-6 * self.cycle_period.as_secs_f64() * 1e6
|
||||
}
|
||||
|
||||
/// Check whether clock drift stays within the guard interval.
|
||||
pub fn drift_within_guard(&self) -> bool {
|
||||
let drift = self.max_drift_us();
|
||||
let guard = self.slots.first().map_or(0, |s| s.guard_interval.as_micros() as u64);
|
||||
drift < guard as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Event emitted when a TDM slot completes.
|
||||
///
|
||||
/// Published by the `TdmCoordinator` after a node finishes its TX window
|
||||
/// and the guard interval elapses. Listeners (e.g., the aggregator) use
|
||||
/// this to know when CSI data from this slot is expected to arrive.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TdmSlotCompleted {
|
||||
/// The cycle number (monotonically increasing from coordinator start).
|
||||
pub cycle_id: u64,
|
||||
/// The slot index within the cycle that completed.
|
||||
pub slot_index: usize,
|
||||
/// The node that was transmitting.
|
||||
pub tx_node_id: u8,
|
||||
/// Quality metric: fraction of expected CSI frames actually received (0.0-1.0).
|
||||
pub capture_quality: f32,
|
||||
/// Timestamp when the slot completed.
|
||||
pub completed_at: Instant,
|
||||
}
|
||||
|
||||
/// Sync beacon broadcast by the coordinator at the start of each TDM cycle.
|
||||
///
|
||||
/// All nodes use the beacon timestamp to align their local clocks and
|
||||
/// determine when their TX slot begins. The `drift_correction_us` field
|
||||
/// allows the coordinator to compensate for cumulative crystal drift.
|
||||
///
|
||||
/// # Wire format (planned)
|
||||
///
|
||||
/// The beacon is a short UDP broadcast (16 bytes):
|
||||
/// ```text
|
||||
/// [0..7] cycle_id (LE u64)
|
||||
/// [8..11] cycle_period_us (LE u32)
|
||||
/// [12..13] drift_correction_us (LE i16)
|
||||
/// [14..15] reserved
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SyncBeacon {
|
||||
/// Monotonically increasing cycle identifier.
|
||||
pub cycle_id: u64,
|
||||
/// Expected cycle period (from the schedule).
|
||||
pub cycle_period: Duration,
|
||||
/// Signed drift correction offset in microseconds.
|
||||
///
|
||||
/// Positive values mean nodes should start their slot slightly later;
|
||||
/// negative means earlier. Derived from observed arrival-time deviations.
|
||||
pub drift_correction_us: i16,
|
||||
/// Timestamp when the beacon was generated.
|
||||
pub generated_at: Instant,
|
||||
}
|
||||
|
||||
impl SyncBeacon {
|
||||
/// Serialize the beacon to the 16-byte wire format.
|
||||
pub fn to_bytes(&self) -> [u8; 16] {
|
||||
let mut buf = [0u8; 16];
|
||||
buf[0..8].copy_from_slice(&self.cycle_id.to_le_bytes());
|
||||
let period_us = self.cycle_period.as_micros() as u32;
|
||||
buf[8..12].copy_from_slice(&period_us.to_le_bytes());
|
||||
buf[12..14].copy_from_slice(&self.drift_correction_us.to_le_bytes());
|
||||
// [14..15] reserved = 0
|
||||
buf
|
||||
}
|
||||
|
||||
/// Deserialize a beacon from the 16-byte wire format.
|
||||
///
|
||||
/// Returns `None` if the buffer is too short.
|
||||
pub fn from_bytes(buf: &[u8]) -> Option<Self> {
|
||||
if buf.len() < 16 {
|
||||
return None;
|
||||
}
|
||||
let cycle_id = u64::from_le_bytes([
|
||||
buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
|
||||
]);
|
||||
let period_us = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
|
||||
let drift_correction_us = i16::from_le_bytes([buf[12], buf[13]]);
|
||||
|
||||
Some(Self {
|
||||
cycle_id,
|
||||
cycle_period: Duration::from_micros(period_us as u64),
|
||||
drift_correction_us,
|
||||
generated_at: Instant::now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// TDM sensing cycle coordinator.
|
||||
///
|
||||
/// Manages the state machine for multistatic sensing cycles. The coordinator
|
||||
/// runs on the aggregator node and tracks:
|
||||
///
|
||||
/// - Current cycle ID and active slot
|
||||
/// - Which nodes have reported CSI data for the current cycle
|
||||
/// - Cumulative clock drift for compensation
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// ```
|
||||
/// use wifi_densepose_hardware::esp32::{TdmSchedule, TdmCoordinator};
|
||||
///
|
||||
/// let schedule = TdmSchedule::default_4node();
|
||||
/// let mut coordinator = TdmCoordinator::new(schedule);
|
||||
///
|
||||
/// // Start a new sensing cycle
|
||||
/// let beacon = coordinator.begin_cycle();
|
||||
/// assert_eq!(beacon.cycle_id, 0);
|
||||
///
|
||||
/// // Complete each slot in the 4-node schedule
|
||||
/// for i in 0..4 {
|
||||
/// let event = coordinator.complete_slot(i, 0.95);
|
||||
/// assert_eq!(event.slot_index, i);
|
||||
/// }
|
||||
///
|
||||
/// // After all slots, the cycle is complete
|
||||
/// assert!(coordinator.is_cycle_complete());
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct TdmCoordinator {
|
||||
/// The schedule governing slot assignments and timing.
|
||||
schedule: TdmSchedule,
|
||||
/// Current cycle number (incremented on each `begin_cycle`).
|
||||
cycle_id: u64,
|
||||
/// Index of the next slot expected to complete (0..node_count).
|
||||
next_slot: usize,
|
||||
/// Whether a cycle is currently in progress.
|
||||
cycle_active: bool,
|
||||
/// Per-node received flags for the current cycle.
|
||||
received: Vec<bool>,
|
||||
/// Cumulative observed drift in microseconds (for drift compensation).
|
||||
cumulative_drift_us: f64,
|
||||
/// Timestamp of the last cycle start (for drift measurement).
|
||||
last_cycle_start: Option<Instant>,
|
||||
}
|
||||
|
||||
impl TdmCoordinator {
|
||||
/// Create a new coordinator with the given schedule.
|
||||
pub fn new(schedule: TdmSchedule) -> Self {
|
||||
let n = schedule.node_count();
|
||||
Self {
|
||||
schedule,
|
||||
cycle_id: 0,
|
||||
next_slot: 0,
|
||||
cycle_active: false,
|
||||
received: vec![false; n],
|
||||
cumulative_drift_us: 0.0,
|
||||
last_cycle_start: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Begin a new sensing cycle. Returns the sync beacon to broadcast.
|
||||
///
|
||||
/// This resets per-slot tracking and increments the cycle ID (except
|
||||
/// for the very first cycle, which starts at 0).
|
||||
pub fn begin_cycle(&mut self) -> SyncBeacon {
|
||||
if self.cycle_active {
|
||||
// Auto-finalize the previous cycle
|
||||
self.cycle_active = false;
|
||||
}
|
||||
|
||||
if self.last_cycle_start.is_some() {
|
||||
self.cycle_id += 1;
|
||||
}
|
||||
|
||||
self.next_slot = 0;
|
||||
self.cycle_active = true;
|
||||
for flag in &mut self.received {
|
||||
*flag = false;
|
||||
}
|
||||
|
||||
// Measure drift from the previous cycle
|
||||
let now = Instant::now();
|
||||
if let Some(prev) = self.last_cycle_start {
|
||||
let actual_us = now.duration_since(prev).as_micros() as f64;
|
||||
let expected_us = self.schedule.cycle_period().as_micros() as f64;
|
||||
let drift = actual_us - expected_us;
|
||||
self.cumulative_drift_us += drift;
|
||||
}
|
||||
self.last_cycle_start = Some(now);
|
||||
|
||||
// Compute drift correction: negative of cumulative drift, clamped to i16
|
||||
let correction = (-self.cumulative_drift_us)
|
||||
.round()
|
||||
.clamp(i16::MIN as f64, i16::MAX as f64) as i16;
|
||||
|
||||
SyncBeacon {
|
||||
cycle_id: self.cycle_id,
|
||||
cycle_period: self.schedule.cycle_period(),
|
||||
drift_correction_us: correction,
|
||||
generated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a slot as completed and return the completion event.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `slot_index` - The slot that completed (must match `next_slot`)
|
||||
/// * `capture_quality` - Fraction of expected CSI frames received (0.0-1.0)
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Does not panic. Returns a `TdmSlotCompleted` event even if the slot
|
||||
/// index is unexpected (the coordinator is lenient to allow out-of-order
|
||||
/// completions in degraded conditions).
|
||||
pub fn complete_slot(&mut self, slot_index: usize, capture_quality: f32) -> TdmSlotCompleted {
|
||||
let quality = capture_quality.clamp(0.0, 1.0);
|
||||
let tx_node_id = self
|
||||
.schedule
|
||||
.slot(slot_index)
|
||||
.map(|s| s.tx_node_id)
|
||||
.unwrap_or(0);
|
||||
|
||||
if slot_index < self.received.len() {
|
||||
self.received[slot_index] = true;
|
||||
}
|
||||
|
||||
if slot_index == self.next_slot {
|
||||
self.next_slot += 1;
|
||||
}
|
||||
|
||||
TdmSlotCompleted {
|
||||
cycle_id: self.cycle_id,
|
||||
slot_index,
|
||||
tx_node_id,
|
||||
capture_quality: quality,
|
||||
completed_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether all slots in the current cycle have completed.
|
||||
pub fn is_cycle_complete(&self) -> bool {
|
||||
self.received.iter().all(|&r| r)
|
||||
}
|
||||
|
||||
/// Number of slots that have completed in the current cycle.
|
||||
pub fn completed_slot_count(&self) -> usize {
|
||||
self.received.iter().filter(|&&r| r).count()
|
||||
}
|
||||
|
||||
/// Current cycle ID.
|
||||
pub fn cycle_id(&self) -> u64 {
|
||||
self.cycle_id
|
||||
}
|
||||
|
||||
/// Whether a cycle is currently active.
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.cycle_active
|
||||
}
|
||||
|
||||
/// Reference to the underlying schedule.
|
||||
pub fn schedule(&self) -> &TdmSchedule {
|
||||
&self.schedule
|
||||
}
|
||||
|
||||
/// Current cumulative drift estimate in microseconds.
|
||||
pub fn cumulative_drift_us(&self) -> f64 {
|
||||
self.cumulative_drift_us
|
||||
}
|
||||
|
||||
/// Compute the maximum single-cycle drift for this schedule.
|
||||
///
|
||||
/// Based on ESP32 crystal spec of +/-10 ppm.
|
||||
pub fn max_single_cycle_drift_us(&self) -> f64 {
|
||||
self.schedule.max_drift_us()
|
||||
}
|
||||
|
||||
/// Generate a sync beacon for the current cycle without starting a new one.
|
||||
///
|
||||
/// Useful for re-broadcasting the beacon if a node missed it.
|
||||
pub fn current_beacon(&self) -> SyncBeacon {
|
||||
let correction = (-self.cumulative_drift_us)
|
||||
.round()
|
||||
.clamp(i16::MIN as f64, i16::MAX as f64) as i16;
|
||||
|
||||
SyncBeacon {
|
||||
cycle_id: self.cycle_id,
|
||||
cycle_period: self.schedule.cycle_period(),
|
||||
drift_correction_us: correction,
|
||||
generated_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ---- TdmSchedule tests ----
|
||||
|
||||
#[test]
|
||||
fn test_default_4node_schedule() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
assert_eq!(schedule.node_count(), 4);
|
||||
// 4 slots * (4ms + 1ms guard) + 30ms processing = 50ms
|
||||
assert_eq!(schedule.cycle_period().as_millis(), 50);
|
||||
assert_eq!(schedule.update_rate_hz(), 20.0);
|
||||
assert!(schedule.drift_within_guard());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uniform_schedule_timing() {
|
||||
let schedule = TdmSchedule::uniform(
|
||||
&[10, 20, 30],
|
||||
Duration::from_millis(5),
|
||||
Duration::from_micros(500),
|
||||
Duration::from_millis(20),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(schedule.node_count(), 3);
|
||||
// 3 * (5ms + 0.5ms) + 20ms = 16.5 + 20 = 36.5ms
|
||||
let expected_us: u64 = 3 * (5_000 + 500) + 20_000;
|
||||
assert_eq!(schedule.cycle_period().as_micros() as u64, expected_us);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slot_for_node() {
|
||||
let schedule = TdmSchedule::uniform(
|
||||
&[5, 10, 15],
|
||||
Duration::from_millis(4),
|
||||
Duration::from_micros(1_000),
|
||||
Duration::from_millis(30),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let slot = schedule.slot_for_node(10).unwrap();
|
||||
assert_eq!(slot.index, 1);
|
||||
assert_eq!(slot.tx_node_id, 10);
|
||||
|
||||
assert!(schedule.slot_for_node(99).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slot_start_offset() {
|
||||
let schedule = TdmSchedule::uniform(
|
||||
&[0, 1, 2, 3],
|
||||
Duration::from_millis(4),
|
||||
Duration::from_micros(1_000),
|
||||
Duration::from_millis(30),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Slot 0 starts at 0
|
||||
let offset0 = TdmSlot::start_offset(schedule.slots(), 0).unwrap();
|
||||
assert_eq!(offset0, Duration::ZERO);
|
||||
|
||||
// Slot 1 starts at 4ms + 1ms = 5ms
|
||||
let offset1 = TdmSlot::start_offset(schedule.slots(), 1).unwrap();
|
||||
assert_eq!(offset1.as_micros(), 5_000);
|
||||
|
||||
// Slot 2 starts at 2 * 5ms = 10ms
|
||||
let offset2 = TdmSlot::start_offset(schedule.slots(), 2).unwrap();
|
||||
assert_eq!(offset2.as_micros(), 10_000);
|
||||
|
||||
// Out of bounds returns None
|
||||
assert!(TdmSlot::start_offset(schedule.slots(), 10).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_node_list_rejected() {
|
||||
let result = TdmSchedule::uniform(
|
||||
&[],
|
||||
Duration::from_millis(4),
|
||||
Duration::from_micros(1_000),
|
||||
Duration::from_millis(30),
|
||||
);
|
||||
assert_eq!(
|
||||
result.unwrap_err(),
|
||||
TdmError::InvalidNodeCount { count: 0, max: MAX_NODES }
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_too_many_nodes_rejected() {
|
||||
let ids: Vec<u8> = (0..=MAX_NODES as u8).collect();
|
||||
let result = TdmSchedule::uniform(
|
||||
&ids,
|
||||
Duration::from_millis(4),
|
||||
Duration::from_micros(1_000),
|
||||
Duration::from_millis(30),
|
||||
);
|
||||
assert!(matches!(result, Err(TdmError::InvalidNodeCount { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_guard_interval_too_large() {
|
||||
let result = TdmSchedule::uniform(
|
||||
&[0, 1],
|
||||
Duration::from_millis(1), // 1 ms slot
|
||||
Duration::from_millis(2), // 2 ms guard > slot
|
||||
Duration::from_millis(30),
|
||||
);
|
||||
assert!(matches!(result, Err(TdmError::GuardIntervalTooLarge { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_drift_calculation() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let drift = schedule.max_drift_us();
|
||||
// 10 ppm * 50ms = 0.5 us
|
||||
assert!((drift - 0.5).abs() < 0.01);
|
||||
}
|
||||
|
||||
// ---- SyncBeacon tests ----
|
||||
|
||||
#[test]
|
||||
fn test_sync_beacon_roundtrip() {
|
||||
let beacon = SyncBeacon {
|
||||
cycle_id: 42,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: -3,
|
||||
generated_at: Instant::now(),
|
||||
};
|
||||
|
||||
let bytes = beacon.to_bytes();
|
||||
assert_eq!(bytes.len(), 16);
|
||||
|
||||
let decoded = SyncBeacon::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(decoded.cycle_id, 42);
|
||||
assert_eq!(decoded.cycle_period, Duration::from_millis(50));
|
||||
assert_eq!(decoded.drift_correction_us, -3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_beacon_short_buffer() {
|
||||
assert!(SyncBeacon::from_bytes(&[0u8; 10]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_beacon_zero_drift() {
|
||||
let beacon = SyncBeacon {
|
||||
cycle_id: 0,
|
||||
cycle_period: Duration::from_millis(50),
|
||||
drift_correction_us: 0,
|
||||
generated_at: Instant::now(),
|
||||
};
|
||||
let bytes = beacon.to_bytes();
|
||||
let decoded = SyncBeacon::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(decoded.drift_correction_us, 0);
|
||||
}
|
||||
|
||||
// ---- TdmCoordinator tests ----
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_begin_cycle() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let mut coord = TdmCoordinator::new(schedule);
|
||||
|
||||
let beacon = coord.begin_cycle();
|
||||
assert_eq!(beacon.cycle_id, 0);
|
||||
assert!(coord.is_active());
|
||||
assert!(!coord.is_cycle_complete());
|
||||
assert_eq!(coord.completed_slot_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_complete_all_slots() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let mut coord = TdmCoordinator::new(schedule);
|
||||
coord.begin_cycle();
|
||||
|
||||
for i in 0..4 {
|
||||
assert!(!coord.is_cycle_complete());
|
||||
let event = coord.complete_slot(i, 0.95);
|
||||
assert_eq!(event.cycle_id, 0);
|
||||
assert_eq!(event.slot_index, i);
|
||||
}
|
||||
|
||||
assert!(coord.is_cycle_complete());
|
||||
assert_eq!(coord.completed_slot_count(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_cycle_id_increments() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let mut coord = TdmCoordinator::new(schedule);
|
||||
|
||||
let b0 = coord.begin_cycle();
|
||||
assert_eq!(b0.cycle_id, 0);
|
||||
|
||||
// Complete all slots
|
||||
for i in 0..4 {
|
||||
coord.complete_slot(i, 1.0);
|
||||
}
|
||||
|
||||
let b1 = coord.begin_cycle();
|
||||
assert_eq!(b1.cycle_id, 1);
|
||||
|
||||
for i in 0..4 {
|
||||
coord.complete_slot(i, 1.0);
|
||||
}
|
||||
|
||||
let b2 = coord.begin_cycle();
|
||||
assert_eq!(b2.cycle_id, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_capture_quality_clamped() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let mut coord = TdmCoordinator::new(schedule);
|
||||
coord.begin_cycle();
|
||||
|
||||
let event = coord.complete_slot(0, 1.5);
|
||||
assert_eq!(event.capture_quality, 1.0);
|
||||
|
||||
let event = coord.complete_slot(1, -0.5);
|
||||
assert_eq!(event.capture_quality, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_current_beacon() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let mut coord = TdmCoordinator::new(schedule);
|
||||
coord.begin_cycle();
|
||||
|
||||
let beacon = coord.current_beacon();
|
||||
assert_eq!(beacon.cycle_id, 0);
|
||||
assert_eq!(beacon.cycle_period.as_millis(), 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_drift_starts_at_zero() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let coord = TdmCoordinator::new(schedule);
|
||||
assert_eq!(coord.cumulative_drift_us(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_max_single_cycle_drift() {
|
||||
let schedule = TdmSchedule::default_4node();
|
||||
let coord = TdmCoordinator::new(schedule);
|
||||
// 10 ppm * 50ms = 0.5 us
|
||||
let drift = coord.max_single_cycle_drift_us();
|
||||
assert!((drift - 0.5).abs() < 0.01);
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,7 @@ mod error;
|
||||
mod esp32_parser;
|
||||
pub mod aggregator;
|
||||
mod bridge;
|
||||
pub mod esp32;
|
||||
|
||||
pub use csi_frame::{CsiFrame, CsiMetadata, SubcarrierData, Bandwidth, AntennaConfig};
|
||||
pub use error::ParseError;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "wifi-densepose-mat"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
authors = ["rUv <ruv@ruv.net>", "WiFi-DensePose Contributors"]
|
||||
description = "Mass Casualty Assessment Tool - WiFi-based disaster survivor detection"
|
||||
@@ -24,9 +24,9 @@ serde = ["dep:serde", "chrono/serde", "geo/use-serde"]
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
wifi-densepose-core = { version = "0.2.0", path = "../wifi-densepose-core" }
|
||||
wifi-densepose-signal = { version = "0.2.0", path = "../wifi-densepose-signal" }
|
||||
wifi-densepose-nn = { version = "0.2.0", path = "../wifi-densepose-nn" }
|
||||
wifi-densepose-core = { version = "0.3.0", path = "../wifi-densepose-core" }
|
||||
wifi-densepose-signal = { version = "0.3.0", path = "../wifi-densepose-signal" }
|
||||
wifi-densepose-nn = { version = "0.3.0", path = "../wifi-densepose-nn" }
|
||||
ruvector-solver = { workspace = true, optional = true }
|
||||
ruvector-temporal-tensor = { workspace = true, optional = true }
|
||||
|
||||
|
||||
@@ -10,10 +10,26 @@ keywords = ["wifi", "csi", "ruvector", "signal-processing", "disaster-detection"
|
||||
categories = ["science", "computer-vision"]
|
||||
readme = "README.md"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
crv = ["dep:ruvector-crv", "dep:ruvector-gnn", "dep:serde", "dep:serde_json"]
|
||||
|
||||
[dependencies]
|
||||
ruvector-mincut = { workspace = true }
|
||||
ruvector-attn-mincut = { workspace = true }
|
||||
ruvector-temporal-tensor = { workspace = true }
|
||||
ruvector-solver = { workspace = true }
|
||||
ruvector-attention = { workspace = true }
|
||||
ruvector-crv = { workspace = true, optional = true }
|
||||
ruvector-gnn = { workspace = true, optional = true }
|
||||
thiserror = { workspace = true }
|
||||
serde = { workspace = true, optional = true }
|
||||
serde_json = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
criterion = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
name = "crv_bench"
|
||||
harness = false
|
||||
|
||||
@@ -0,0 +1,405 @@
|
||||
//! Benchmarks for CRV (Coordinate Remote Viewing) integration.
|
||||
//!
|
||||
//! Measures throughput of gestalt classification, sensory encoding,
|
||||
//! full session pipelines, cross-session convergence, and embedding
|
||||
//! dimension scaling using the `ruvector-crv` crate directly.
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use ruvector_crv::{
|
||||
CrvConfig, CrvSessionManager, GestaltType, SensoryModality, StageIData, StageIIData,
|
||||
StageIIIData, StageIVData,
|
||||
};
|
||||
use ruvector_crv::types::{
|
||||
GeometricKind, SketchElement, SpatialRelationType, SpatialRelationship,
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build a synthetic CSI-like ideogram stroke with `n` subcarrier points.
|
||||
fn make_stroke(n: usize) -> Vec<(f32, f32)> {
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
let t = i as f32 / n as f32;
|
||||
(t, (t * std::f32::consts::TAU).sin() * 0.5 + 0.5)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build a Stage I data frame representing a single CSI gestalt sample.
|
||||
fn make_stage_i(gestalt: GestaltType) -> StageIData {
|
||||
StageIData {
|
||||
stroke: make_stroke(64),
|
||||
spontaneous_descriptor: "angular rising".to_string(),
|
||||
classification: gestalt,
|
||||
confidence: 0.85,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Stage II sensory data frame.
|
||||
fn make_stage_ii() -> StageIIData {
|
||||
StageIIData {
|
||||
impressions: vec![
|
||||
(SensoryModality::Texture, "rough metallic".to_string()),
|
||||
(SensoryModality::Temperature, "warm".to_string()),
|
||||
(SensoryModality::Color, "silver-gray".to_string()),
|
||||
(SensoryModality::Luminosity, "reflective".to_string()),
|
||||
(SensoryModality::Sound, "low hum".to_string()),
|
||||
],
|
||||
feature_vector: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Stage III spatial sketch.
|
||||
fn make_stage_iii() -> StageIIIData {
|
||||
StageIIIData {
|
||||
sketch_elements: vec![
|
||||
SketchElement {
|
||||
label: "tower".to_string(),
|
||||
kind: GeometricKind::Rectangle,
|
||||
position: (0.5, 0.8),
|
||||
scale: Some(3.0),
|
||||
},
|
||||
SketchElement {
|
||||
label: "base".to_string(),
|
||||
kind: GeometricKind::Rectangle,
|
||||
position: (0.5, 0.2),
|
||||
scale: Some(5.0),
|
||||
},
|
||||
SketchElement {
|
||||
label: "antenna".to_string(),
|
||||
kind: GeometricKind::Line,
|
||||
position: (0.5, 0.95),
|
||||
scale: Some(1.0),
|
||||
},
|
||||
],
|
||||
relationships: vec![
|
||||
SpatialRelationship {
|
||||
from: "tower".to_string(),
|
||||
to: "base".to_string(),
|
||||
relation: SpatialRelationType::Above,
|
||||
strength: 0.9,
|
||||
},
|
||||
SpatialRelationship {
|
||||
from: "antenna".to_string(),
|
||||
to: "tower".to_string(),
|
||||
relation: SpatialRelationType::Above,
|
||||
strength: 0.85,
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Stage IV emotional / AOL data frame.
|
||||
fn make_stage_iv() -> StageIVData {
|
||||
StageIVData {
|
||||
emotional_impact: vec![
|
||||
("awe".to_string(), 0.7),
|
||||
("curiosity".to_string(), 0.6),
|
||||
("unease".to_string(), 0.3),
|
||||
],
|
||||
tangibles: vec!["metal structure".to_string(), "concrete".to_string()],
|
||||
intangibles: vec!["transmission".to_string(), "power".to_string()],
|
||||
aol_detections: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a manager with one session pre-loaded with 4 stages of data.
|
||||
fn populated_manager(dims: usize) -> (CrvSessionManager, String) {
|
||||
let config = CrvConfig {
|
||||
dimensions: dims,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut mgr = CrvSessionManager::new(config);
|
||||
let sid = "bench-sess".to_string();
|
||||
mgr.create_session(sid.clone(), "coord-001".to_string())
|
||||
.unwrap();
|
||||
mgr.add_stage_i(&sid, &make_stage_i(GestaltType::Manmade))
|
||||
.unwrap();
|
||||
mgr.add_stage_ii(&sid, &make_stage_ii()).unwrap();
|
||||
mgr.add_stage_iii(&sid, &make_stage_iii()).unwrap();
|
||||
mgr.add_stage_iv(&sid, &make_stage_iv()).unwrap();
|
||||
(mgr, sid)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Benchmarks
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Benchmark: classify a single CSI frame through Stage I (64 subcarriers).
|
||||
fn gestalt_classify_single(c: &mut Criterion) {
|
||||
let config = CrvConfig {
|
||||
dimensions: 64,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
manager
|
||||
.create_session("gc-single".to_string(), "coord-gc".to_string())
|
||||
.unwrap();
|
||||
|
||||
let data = make_stage_i(GestaltType::Manmade);
|
||||
|
||||
c.bench_function("gestalt_classify_single", |b| {
|
||||
b.iter(|| {
|
||||
manager
|
||||
.add_stage_i("gc-single", black_box(&data))
|
||||
.unwrap();
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark: classify a batch of 100 CSI frames through Stage I.
|
||||
fn gestalt_classify_batch(c: &mut Criterion) {
|
||||
let config = CrvConfig {
|
||||
dimensions: 64,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
|
||||
let gestalts = GestaltType::all();
|
||||
let frames: Vec<StageIData> = (0..100)
|
||||
.map(|i| make_stage_i(gestalts[i % gestalts.len()]))
|
||||
.collect();
|
||||
|
||||
c.bench_function("gestalt_classify_batch_100", |b| {
|
||||
b.iter(|| {
|
||||
let mut manager = CrvSessionManager::new(CrvConfig {
|
||||
dimensions: 64,
|
||||
..CrvConfig::default()
|
||||
});
|
||||
manager
|
||||
.create_session("gc-batch".to_string(), "coord-gcb".to_string())
|
||||
.unwrap();
|
||||
|
||||
for frame in black_box(&frames) {
|
||||
manager.add_stage_i("gc-batch", frame).unwrap();
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark: extract sensory features from a single CSI frame (Stage II).
|
||||
fn sensory_encode_single(c: &mut Criterion) {
|
||||
let config = CrvConfig {
|
||||
dimensions: 64,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
manager
|
||||
.create_session("se-single".to_string(), "coord-se".to_string())
|
||||
.unwrap();
|
||||
|
||||
let data = make_stage_ii();
|
||||
|
||||
c.bench_function("sensory_encode_single", |b| {
|
||||
b.iter(|| {
|
||||
manager
|
||||
.add_stage_ii("se-single", black_box(&data))
|
||||
.unwrap();
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark: full session pipeline -- create session, add 10 mixed-stage
|
||||
/// frames, run Stage V interrogation, and run Stage VI partitioning.
|
||||
fn pipeline_full_session(c: &mut Criterion) {
|
||||
let stage_i_data = make_stage_i(GestaltType::Manmade);
|
||||
let stage_ii_data = make_stage_ii();
|
||||
let stage_iii_data = make_stage_iii();
|
||||
let stage_iv_data = make_stage_iv();
|
||||
|
||||
c.bench_function("pipeline_full_session", |b| {
|
||||
let mut counter = 0u64;
|
||||
b.iter(|| {
|
||||
counter += 1;
|
||||
let config = CrvConfig {
|
||||
dimensions: 64,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
let sid = format!("pfs-{}", counter);
|
||||
manager
|
||||
.create_session(sid.clone(), "coord-pfs".to_string())
|
||||
.unwrap();
|
||||
|
||||
// 10 frames across stages I-IV
|
||||
for _ in 0..3 {
|
||||
manager
|
||||
.add_stage_i(&sid, black_box(&stage_i_data))
|
||||
.unwrap();
|
||||
}
|
||||
for _ in 0..3 {
|
||||
manager
|
||||
.add_stage_ii(&sid, black_box(&stage_ii_data))
|
||||
.unwrap();
|
||||
}
|
||||
for _ in 0..2 {
|
||||
manager
|
||||
.add_stage_iii(&sid, black_box(&stage_iii_data))
|
||||
.unwrap();
|
||||
}
|
||||
for _ in 0..2 {
|
||||
manager
|
||||
.add_stage_iv(&sid, black_box(&stage_iv_data))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Stage V: interrogate with a probe embedding
|
||||
let probe_emb = vec![0.1f32; 64];
|
||||
let probes: Vec<(&str, u8, Vec<f32>)> = vec![
|
||||
("structure query", 1, probe_emb.clone()),
|
||||
("texture query", 2, probe_emb.clone()),
|
||||
];
|
||||
let _ = manager.run_stage_v(&sid, &probes, 3);
|
||||
|
||||
// Stage VI: partition
|
||||
let _ = manager.run_stage_vi(&sid);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark: cross-session convergence analysis with 2 independent
|
||||
/// sessions of 10 frames each, targeting the same coordinate.
|
||||
fn convergence_two_sessions(c: &mut Criterion) {
|
||||
let gestalts = [GestaltType::Manmade, GestaltType::Natural, GestaltType::Energy];
|
||||
let stage_ii_data = make_stage_ii();
|
||||
|
||||
c.bench_function("convergence_two_sessions", |b| {
|
||||
let mut counter = 0u64;
|
||||
b.iter(|| {
|
||||
counter += 1;
|
||||
let config = CrvConfig {
|
||||
dimensions: 64,
|
||||
convergence_threshold: 0.5,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
let coord = format!("conv-coord-{}", counter);
|
||||
|
||||
// Session A: 10 frames
|
||||
let sid_a = format!("viewer-a-{}", counter);
|
||||
manager
|
||||
.create_session(sid_a.clone(), coord.clone())
|
||||
.unwrap();
|
||||
for i in 0..5 {
|
||||
let data = make_stage_i(gestalts[i % gestalts.len()]);
|
||||
manager.add_stage_i(&sid_a, black_box(&data)).unwrap();
|
||||
}
|
||||
for _ in 0..5 {
|
||||
manager
|
||||
.add_stage_ii(&sid_a, black_box(&stage_ii_data))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Session B: 10 frames (similar but not identical)
|
||||
let sid_b = format!("viewer-b-{}", counter);
|
||||
manager
|
||||
.create_session(sid_b.clone(), coord.clone())
|
||||
.unwrap();
|
||||
for i in 0..5 {
|
||||
let data = make_stage_i(gestalts[(i + 1) % gestalts.len()]);
|
||||
manager.add_stage_i(&sid_b, black_box(&data)).unwrap();
|
||||
}
|
||||
for _ in 0..5 {
|
||||
manager
|
||||
.add_stage_ii(&sid_b, black_box(&stage_ii_data))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Convergence analysis
|
||||
let _ = manager.find_convergence(&coord, black_box(0.5));
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark: session creation overhead alone.
|
||||
fn crv_session_create(c: &mut Criterion) {
|
||||
c.bench_function("crv_session_create", |b| {
|
||||
b.iter(|| {
|
||||
let config = CrvConfig {
|
||||
dimensions: 32,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut manager = CrvSessionManager::new(black_box(config));
|
||||
manager
|
||||
.create_session(
|
||||
black_box("sess-1".to_string()),
|
||||
black_box("coord-1".to_string()),
|
||||
)
|
||||
.unwrap();
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark: embedding dimension scaling (32, 128, 384).
|
||||
///
|
||||
/// Measures Stage I + Stage II encode time across different embedding
|
||||
/// dimensions to characterize how cost grows with dimensionality.
|
||||
fn crv_embedding_dimension_scaling(c: &mut Criterion) {
|
||||
let stage_i_data = make_stage_i(GestaltType::Manmade);
|
||||
let stage_ii_data = make_stage_ii();
|
||||
|
||||
let mut group = c.benchmark_group("crv_embedding_dimension_scaling");
|
||||
for dims in [32, 128, 384] {
|
||||
group.bench_with_input(BenchmarkId::from_parameter(dims), &dims, |b, &dims| {
|
||||
let mut counter = 0u64;
|
||||
b.iter(|| {
|
||||
counter += 1;
|
||||
let config = CrvConfig {
|
||||
dimensions: dims,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
let sid = format!("dim-{}-{}", dims, counter);
|
||||
manager
|
||||
.create_session(sid.clone(), "coord-dim".to_string())
|
||||
.unwrap();
|
||||
|
||||
// Encode one Stage I + one Stage II at this dimensionality
|
||||
let emb_i = manager
|
||||
.add_stage_i(&sid, black_box(&stage_i_data))
|
||||
.unwrap();
|
||||
let emb_ii = manager
|
||||
.add_stage_ii(&sid, black_box(&stage_ii_data))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(emb_i.len(), dims);
|
||||
assert_eq!(emb_ii.len(), dims);
|
||||
})
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Stage VI partitioning on a pre-populated session
|
||||
/// (4 stages of accumulated data).
|
||||
fn crv_stage_vi_partition(c: &mut Criterion) {
|
||||
c.bench_function("crv_stage_vi_partition", |b| {
|
||||
let mut counter = 0u64;
|
||||
b.iter(|| {
|
||||
counter += 1;
|
||||
// Re-create the populated manager each iteration because
|
||||
// run_stage_vi mutates the session (appends an entry).
|
||||
let (mut mgr, sid) = populated_manager(64);
|
||||
let _ = mgr.run_stage_vi(black_box(&sid));
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Criterion groups
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
gestalt_classify_single,
|
||||
gestalt_classify_batch,
|
||||
sensory_encode_single,
|
||||
pipeline_full_session,
|
||||
convergence_two_sessions,
|
||||
crv_session_create,
|
||||
crv_embedding_dimension_scaling,
|
||||
crv_stage_vi_partition,
|
||||
);
|
||||
|
||||
criterion_main!(benches);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,5 +26,8 @@
|
||||
|
||||
#![warn(missing_docs)]
|
||||
|
||||
#[cfg(feature = "crv")]
|
||||
pub mod crv;
|
||||
pub mod mat;
|
||||
pub mod signal;
|
||||
pub mod viewpoint;
|
||||
|
||||
@@ -0,0 +1,667 @@
|
||||
//! Cross-viewpoint scaled dot-product attention with geometric bias (ADR-031).
|
||||
//!
|
||||
//! Implements the core RuView attention mechanism:
|
||||
//!
|
||||
//! ```text
|
||||
//! Q = W_q * X, K = W_k * X, V = W_v * X
|
||||
//! A = softmax((Q * K^T + G_bias) / sqrt(d))
|
||||
//! fused = A * V
|
||||
//! ```
|
||||
//!
|
||||
//! The geometric bias `G_bias` encodes angular separation and baseline distance
|
||||
//! between each viewpoint pair, allowing the attention mechanism to learn that
|
||||
//! widely-separated, orthogonal viewpoints are more complementary than clustered
|
||||
//! ones.
|
||||
//!
|
||||
//! Wraps `ruvector_attention::ScaledDotProductAttention` for the underlying
|
||||
//! attention computation.
|
||||
|
||||
// The cross-viewpoint attention is implemented directly rather than wrapping
|
||||
// ruvector_attention::ScaledDotProductAttention, because we need to inject
|
||||
// the geometric bias matrix G_bias into the QK^T scores before softmax --
|
||||
// an operation not exposed by the ruvector API. The ruvector-attention crate
|
||||
// is still a workspace dependency for the signal/bvp integration point.
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced by the cross-viewpoint attention module.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AttentionError {
|
||||
/// The number of viewpoints is zero.
|
||||
EmptyViewpoints,
|
||||
/// Embedding dimension mismatch between viewpoints.
|
||||
DimensionMismatch {
|
||||
/// Expected embedding dimension.
|
||||
expected: usize,
|
||||
/// Actual embedding dimension found.
|
||||
actual: usize,
|
||||
},
|
||||
/// The geometric bias matrix dimensions do not match the viewpoint count.
|
||||
BiasDimensionMismatch {
|
||||
/// Number of viewpoints.
|
||||
n_viewpoints: usize,
|
||||
/// Rows in bias matrix.
|
||||
bias_rows: usize,
|
||||
/// Columns in bias matrix.
|
||||
bias_cols: usize,
|
||||
},
|
||||
/// The projection weight matrix has incorrect dimensions.
|
||||
WeightDimensionMismatch {
|
||||
/// Expected dimension.
|
||||
expected: usize,
|
||||
/// Actual dimension.
|
||||
actual: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AttentionError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AttentionError::EmptyViewpoints => write!(f, "no viewpoint embeddings provided"),
|
||||
AttentionError::DimensionMismatch { expected, actual } => {
|
||||
write!(f, "embedding dimension mismatch: expected {expected}, got {actual}")
|
||||
}
|
||||
AttentionError::BiasDimensionMismatch { n_viewpoints, bias_rows, bias_cols } => {
|
||||
write!(
|
||||
f,
|
||||
"geometric bias matrix is {bias_rows}x{bias_cols} but {n_viewpoints} viewpoints require {n_viewpoints}x{n_viewpoints}"
|
||||
)
|
||||
}
|
||||
AttentionError::WeightDimensionMismatch { expected, actual } => {
|
||||
write!(f, "weight matrix dimension mismatch: expected {expected}, got {actual}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AttentionError {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GeometricBias
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Geometric bias matrix encoding spatial relationships between viewpoint pairs.
|
||||
///
|
||||
/// The bias for viewpoint pair `(i, j)` is computed as:
|
||||
///
|
||||
/// ```text
|
||||
/// G_bias[i,j] = w_angle * cos(theta_ij) + w_dist * exp(-d_ij / d_ref)
|
||||
/// ```
|
||||
///
|
||||
/// where `theta_ij` is the angular separation between viewpoints `i` and `j`
|
||||
/// from the array centroid, `d_ij` is the baseline distance, `w_angle` and
|
||||
/// `w_dist` are learnable scalar weights, and `d_ref` is a reference distance
|
||||
/// (typically room diagonal / 2).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeometricBias {
|
||||
/// Learnable weight for the angular component.
|
||||
pub w_angle: f32,
|
||||
/// Learnable weight for the distance component.
|
||||
pub w_dist: f32,
|
||||
/// Reference distance for the exponential decay (metres).
|
||||
pub d_ref: f32,
|
||||
}
|
||||
|
||||
impl Default for GeometricBias {
|
||||
fn default() -> Self {
|
||||
GeometricBias {
|
||||
w_angle: 1.0,
|
||||
w_dist: 1.0,
|
||||
d_ref: 5.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single viewpoint geometry descriptor.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ViewpointGeometry {
|
||||
/// Azimuth angle from array centroid (radians).
|
||||
pub azimuth: f32,
|
||||
/// 2-D position (x, y) in metres.
|
||||
pub position: (f32, f32),
|
||||
}
|
||||
|
||||
impl GeometricBias {
|
||||
/// Create a new geometric bias with the given parameters.
|
||||
pub fn new(w_angle: f32, w_dist: f32, d_ref: f32) -> Self {
|
||||
GeometricBias { w_angle, w_dist, d_ref }
|
||||
}
|
||||
|
||||
/// Compute the bias value for a single viewpoint pair.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `theta_ij`: angular separation in radians between viewpoints `i` and `j`.
|
||||
/// - `d_ij`: baseline distance in metres between viewpoints `i` and `j`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The scalar bias value `w_angle * cos(theta_ij) + w_dist * exp(-d_ij / d_ref)`.
|
||||
pub fn compute_pair(&self, theta_ij: f32, d_ij: f32) -> f32 {
|
||||
let safe_d_ref = self.d_ref.max(1e-6);
|
||||
self.w_angle * theta_ij.cos() + self.w_dist * (-d_ij / safe_d_ref).exp()
|
||||
}
|
||||
|
||||
/// Build the full N x N geometric bias matrix from viewpoint geometries.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `viewpoints`: slice of viewpoint geometry descriptors.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Flat row-major `N x N` bias matrix.
|
||||
pub fn build_matrix(&self, viewpoints: &[ViewpointGeometry]) -> Vec<f32> {
|
||||
let n = viewpoints.len();
|
||||
let mut matrix = vec![0.0_f32; n * n];
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if i == j {
|
||||
// Self-bias: maximum (cos(0) = 1, exp(0) = 1)
|
||||
matrix[i * n + j] = self.w_angle + self.w_dist;
|
||||
} else {
|
||||
let theta_ij = (viewpoints[i].azimuth - viewpoints[j].azimuth).abs();
|
||||
let dx = viewpoints[i].position.0 - viewpoints[j].position.0;
|
||||
let dy = viewpoints[i].position.1 - viewpoints[j].position.1;
|
||||
let d_ij = (dx * dx + dy * dy).sqrt();
|
||||
matrix[i * n + j] = self.compute_pair(theta_ij, d_ij);
|
||||
}
|
||||
}
|
||||
}
|
||||
matrix
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Projection weights
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Linear projection weights for Q, K, V transformations.
|
||||
///
|
||||
/// Each weight matrix is `d_out x d_in`, stored row-major. In the default
|
||||
/// (identity) configuration `d_out == d_in` and the matrices are identity.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProjectionWeights {
|
||||
/// W_q projection matrix, row-major `[d_out, d_in]`.
|
||||
pub w_q: Vec<f32>,
|
||||
/// W_k projection matrix, row-major `[d_out, d_in]`.
|
||||
pub w_k: Vec<f32>,
|
||||
/// W_v projection matrix, row-major `[d_out, d_in]`.
|
||||
pub w_v: Vec<f32>,
|
||||
/// Input dimension.
|
||||
pub d_in: usize,
|
||||
/// Output (projected) dimension.
|
||||
pub d_out: usize,
|
||||
}
|
||||
|
||||
impl ProjectionWeights {
|
||||
/// Create identity projections (d_out == d_in, W = I).
|
||||
pub fn identity(dim: usize) -> Self {
|
||||
let mut eye = vec![0.0_f32; dim * dim];
|
||||
for i in 0..dim {
|
||||
eye[i * dim + i] = 1.0;
|
||||
}
|
||||
ProjectionWeights {
|
||||
w_q: eye.clone(),
|
||||
w_k: eye.clone(),
|
||||
w_v: eye,
|
||||
d_in: dim,
|
||||
d_out: dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create projections with given weight matrices.
|
||||
///
|
||||
/// Each matrix must be `d_out * d_in` elements, stored row-major.
|
||||
pub fn new(
|
||||
w_q: Vec<f32>,
|
||||
w_k: Vec<f32>,
|
||||
w_v: Vec<f32>,
|
||||
d_in: usize,
|
||||
d_out: usize,
|
||||
) -> Result<Self, AttentionError> {
|
||||
let expected_len = d_out * d_in;
|
||||
if w_q.len() != expected_len {
|
||||
return Err(AttentionError::WeightDimensionMismatch {
|
||||
expected: expected_len,
|
||||
actual: w_q.len(),
|
||||
});
|
||||
}
|
||||
if w_k.len() != expected_len {
|
||||
return Err(AttentionError::WeightDimensionMismatch {
|
||||
expected: expected_len,
|
||||
actual: w_k.len(),
|
||||
});
|
||||
}
|
||||
if w_v.len() != expected_len {
|
||||
return Err(AttentionError::WeightDimensionMismatch {
|
||||
expected: expected_len,
|
||||
actual: w_v.len(),
|
||||
});
|
||||
}
|
||||
Ok(ProjectionWeights { w_q, w_k, w_v, d_in, d_out })
|
||||
}
|
||||
|
||||
/// Project a single embedding vector through a weight matrix.
|
||||
///
|
||||
/// `weight` is `[d_out, d_in]` row-major, `input` is `[d_in]`.
|
||||
/// Returns `[d_out]`.
|
||||
fn project(&self, weight: &[f32], input: &[f32]) -> Vec<f32> {
|
||||
let mut output = vec![0.0_f32; self.d_out];
|
||||
for row in 0..self.d_out {
|
||||
let mut sum = 0.0_f32;
|
||||
for col in 0..self.d_in {
|
||||
sum += weight[row * self.d_in + col] * input[col];
|
||||
}
|
||||
output[row] = sum;
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
/// Project all viewpoint embeddings through W_q.
|
||||
pub fn project_queries(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
embeddings.iter().map(|e| self.project(&self.w_q, e)).collect()
|
||||
}
|
||||
|
||||
/// Project all viewpoint embeddings through W_k.
|
||||
pub fn project_keys(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
embeddings.iter().map(|e| self.project(&self.w_k, e)).collect()
|
||||
}
|
||||
|
||||
/// Project all viewpoint embeddings through W_v.
|
||||
pub fn project_values(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
embeddings.iter().map(|e| self.project(&self.w_v, e)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CrossViewpointAttention
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Cross-viewpoint attention with geometric bias.
|
||||
///
|
||||
/// Computes the full RuView attention pipeline:
|
||||
///
|
||||
/// 1. Project embeddings through W_q, W_k, W_v.
|
||||
/// 2. Compute attention scores: `A = softmax((Q * K^T + G_bias) / sqrt(d))`.
|
||||
/// 3. Weighted sum: `fused = A * V`.
|
||||
///
|
||||
/// The output is one fused embedding per input viewpoint (row of A * V).
|
||||
/// To obtain a single fused embedding, use [`CrossViewpointAttention::fuse`]
|
||||
/// which mean-pools the attended outputs.
|
||||
pub struct CrossViewpointAttention {
|
||||
/// Projection weights for Q, K, V.
|
||||
pub weights: ProjectionWeights,
|
||||
/// Geometric bias parameters.
|
||||
pub bias: GeometricBias,
|
||||
}
|
||||
|
||||
impl CrossViewpointAttention {
|
||||
/// Create a new cross-viewpoint attention module with identity projections.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `embed_dim`: embedding dimension (e.g. 128 for AETHER).
|
||||
pub fn new(embed_dim: usize) -> Self {
|
||||
CrossViewpointAttention {
|
||||
weights: ProjectionWeights::identity(embed_dim),
|
||||
bias: GeometricBias::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom projection weights and bias.
|
||||
pub fn with_params(weights: ProjectionWeights, bias: GeometricBias) -> Self {
|
||||
CrossViewpointAttention { weights, bias }
|
||||
}
|
||||
|
||||
/// Compute the full attention output for all viewpoints.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `embeddings`: per-viewpoint embedding vectors, each of length `d_in`.
|
||||
/// - `viewpoint_geom`: per-viewpoint geometry descriptors (same length).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `Ok(attended)` where `attended` is `N` vectors of length `d_out`, one per
|
||||
/// viewpoint after cross-viewpoint attention. Returns an error if dimensions
|
||||
/// are inconsistent.
|
||||
pub fn attend(
|
||||
&self,
|
||||
embeddings: &[Vec<f32>],
|
||||
viewpoint_geom: &[ViewpointGeometry],
|
||||
) -> Result<Vec<Vec<f32>>, AttentionError> {
|
||||
let n = embeddings.len();
|
||||
if n == 0 {
|
||||
return Err(AttentionError::EmptyViewpoints);
|
||||
}
|
||||
|
||||
// Validate embedding dimensions.
|
||||
for (idx, emb) in embeddings.iter().enumerate() {
|
||||
if emb.len() != self.weights.d_in {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.weights.d_in,
|
||||
actual: emb.len(),
|
||||
});
|
||||
}
|
||||
let _ = idx; // suppress unused warning
|
||||
}
|
||||
|
||||
let d = self.weights.d_out;
|
||||
let scale = 1.0 / (d as f32).sqrt();
|
||||
|
||||
// Project through W_q, W_k, W_v.
|
||||
let queries = self.weights.project_queries(embeddings);
|
||||
let keys = self.weights.project_keys(embeddings);
|
||||
let values = self.weights.project_values(embeddings);
|
||||
|
||||
// Build geometric bias matrix.
|
||||
let g_bias = self.bias.build_matrix(viewpoint_geom);
|
||||
|
||||
// Compute attention scores: (Q * K^T + G_bias) / sqrt(d), then softmax.
|
||||
let mut attention_weights = vec![0.0_f32; n * n];
|
||||
for i in 0..n {
|
||||
// Compute raw scores for row i.
|
||||
let mut max_score = f32::NEG_INFINITY;
|
||||
for j in 0..n {
|
||||
let dot: f32 = queries[i].iter().zip(&keys[j]).map(|(q, k)| q * k).sum();
|
||||
let score = (dot + g_bias[i * n + j]) * scale;
|
||||
attention_weights[i * n + j] = score;
|
||||
if score > max_score {
|
||||
max_score = score;
|
||||
}
|
||||
}
|
||||
|
||||
// Softmax: subtract max for numerical stability, then exponentiate.
|
||||
let mut sum_exp = 0.0_f32;
|
||||
for j in 0..n {
|
||||
let val = (attention_weights[i * n + j] - max_score).exp();
|
||||
attention_weights[i * n + j] = val;
|
||||
sum_exp += val;
|
||||
}
|
||||
let safe_sum = sum_exp.max(f32::EPSILON);
|
||||
for j in 0..n {
|
||||
attention_weights[i * n + j] /= safe_sum;
|
||||
}
|
||||
}
|
||||
|
||||
// Weighted sum: attended[i] = sum_j (attention_weights[i,j] * values[j]).
|
||||
let mut attended = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
let mut output = vec![0.0_f32; d];
|
||||
for j in 0..n {
|
||||
let w = attention_weights[i * n + j];
|
||||
for k in 0..d {
|
||||
output[k] += w * values[j][k];
|
||||
}
|
||||
}
|
||||
attended.push(output);
|
||||
}
|
||||
|
||||
Ok(attended)
|
||||
}
|
||||
|
||||
/// Fuse multiple viewpoint embeddings into a single embedding.
|
||||
///
|
||||
/// Applies cross-viewpoint attention, then mean-pools the attended outputs
|
||||
/// to produce a single fused embedding of dimension `d_out`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `embeddings`: per-viewpoint embedding vectors.
|
||||
/// - `viewpoint_geom`: per-viewpoint geometry descriptors.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A single fused embedding of length `d_out`.
|
||||
pub fn fuse(
|
||||
&self,
|
||||
embeddings: &[Vec<f32>],
|
||||
viewpoint_geom: &[ViewpointGeometry],
|
||||
) -> Result<Vec<f32>, AttentionError> {
|
||||
let attended = self.attend(embeddings, viewpoint_geom)?;
|
||||
let n = attended.len();
|
||||
let d = self.weights.d_out;
|
||||
let mut fused = vec![0.0_f32; d];
|
||||
|
||||
for row in &attended {
|
||||
for k in 0..d {
|
||||
fused[k] += row[k];
|
||||
}
|
||||
}
|
||||
let n_f = n as f32;
|
||||
for k in 0..d {
|
||||
fused[k] /= n_f;
|
||||
}
|
||||
|
||||
Ok(fused)
|
||||
}
|
||||
|
||||
/// Extract the raw attention weight matrix (for diagnostics).
|
||||
///
|
||||
/// Returns the `N x N` attention weight matrix (row-major, each row sums to 1).
|
||||
pub fn attention_weights(
|
||||
&self,
|
||||
embeddings: &[Vec<f32>],
|
||||
viewpoint_geom: &[ViewpointGeometry],
|
||||
) -> Result<Vec<f32>, AttentionError> {
|
||||
let n = embeddings.len();
|
||||
if n == 0 {
|
||||
return Err(AttentionError::EmptyViewpoints);
|
||||
}
|
||||
|
||||
let d = self.weights.d_out;
|
||||
let scale = 1.0 / (d as f32).sqrt();
|
||||
|
||||
let queries = self.weights.project_queries(embeddings);
|
||||
let keys = self.weights.project_keys(embeddings);
|
||||
let g_bias = self.bias.build_matrix(viewpoint_geom);
|
||||
|
||||
let mut weights = vec![0.0_f32; n * n];
|
||||
for i in 0..n {
|
||||
let mut max_score = f32::NEG_INFINITY;
|
||||
for j in 0..n {
|
||||
let dot: f32 = queries[i].iter().zip(&keys[j]).map(|(q, k)| q * k).sum();
|
||||
let score = (dot + g_bias[i * n + j]) * scale;
|
||||
weights[i * n + j] = score;
|
||||
if score > max_score {
|
||||
max_score = score;
|
||||
}
|
||||
}
|
||||
|
||||
let mut sum_exp = 0.0_f32;
|
||||
for j in 0..n {
|
||||
let val = (weights[i * n + j] - max_score).exp();
|
||||
weights[i * n + j] = val;
|
||||
sum_exp += val;
|
||||
}
|
||||
let safe_sum = sum_exp.max(f32::EPSILON);
|
||||
for j in 0..n {
|
||||
weights[i * n + j] /= safe_sum;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(weights)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_test_geom(n: usize) -> Vec<ViewpointGeometry> {
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
let angle = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
|
||||
let r = 3.0;
|
||||
ViewpointGeometry {
|
||||
azimuth: angle,
|
||||
position: (r * angle.cos(), r * angle.sin()),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn make_test_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
(0..dim).map(|d| ((i * dim + d) as f32 * 0.01).sin()).collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_produces_correct_dimension() {
|
||||
let dim = 16;
|
||||
let n = 4;
|
||||
let attn = CrossViewpointAttention::new(dim);
|
||||
let embeddings = make_test_embeddings(n, dim);
|
||||
let geom = make_test_geom(n);
|
||||
let fused = attn.fuse(&embeddings, &geom).unwrap();
|
||||
assert_eq!(fused.len(), dim, "fused embedding must have length {dim}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn attend_produces_n_outputs() {
|
||||
let dim = 8;
|
||||
let n = 3;
|
||||
let attn = CrossViewpointAttention::new(dim);
|
||||
let embeddings = make_test_embeddings(n, dim);
|
||||
let geom = make_test_geom(n);
|
||||
let attended = attn.attend(&embeddings, &geom).unwrap();
|
||||
assert_eq!(attended.len(), n, "must produce one output per viewpoint");
|
||||
for row in &attended {
|
||||
assert_eq!(row.len(), dim);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn attention_weights_sum_to_one() {
|
||||
let dim = 8;
|
||||
let n = 4;
|
||||
let attn = CrossViewpointAttention::new(dim);
|
||||
let embeddings = make_test_embeddings(n, dim);
|
||||
let geom = make_test_geom(n);
|
||||
let weights = attn.attention_weights(&embeddings, &geom).unwrap();
|
||||
assert_eq!(weights.len(), n * n);
|
||||
for i in 0..n {
|
||||
let row_sum: f32 = (0..n).map(|j| weights[i * n + j]).sum();
|
||||
assert!(
|
||||
(row_sum - 1.0).abs() < 1e-5,
|
||||
"row {i} sums to {row_sum}, expected 1.0"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn attention_weights_are_non_negative() {
|
||||
let dim = 8;
|
||||
let n = 3;
|
||||
let attn = CrossViewpointAttention::new(dim);
|
||||
let embeddings = make_test_embeddings(n, dim);
|
||||
let geom = make_test_geom(n);
|
||||
let weights = attn.attention_weights(&embeddings, &geom).unwrap();
|
||||
for w in &weights {
|
||||
assert!(*w >= 0.0, "attention weight must be non-negative, got {w}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_viewpoints_returns_error() {
|
||||
let attn = CrossViewpointAttention::new(8);
|
||||
let result = attn.fuse(&[], &[]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dimension_mismatch_returns_error() {
|
||||
let attn = CrossViewpointAttention::new(8);
|
||||
let embeddings = vec![vec![1.0_f32; 4]]; // wrong dim
|
||||
let geom = make_test_geom(1);
|
||||
let result = attn.fuse(&embeddings, &geom);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometric_bias_pair_computation() {
|
||||
let bias = GeometricBias::new(1.0, 1.0, 5.0);
|
||||
// Same position: theta=0, d=0 -> cos(0) + exp(0) = 2.0
|
||||
let val = bias.compute_pair(0.0, 0.0);
|
||||
assert!((val - 2.0).abs() < 1e-5, "self-bias should be 2.0, got {val}");
|
||||
|
||||
// Orthogonal, far apart: theta=PI/2, d=5.0
|
||||
let val_orth = bias.compute_pair(std::f32::consts::FRAC_PI_2, 5.0);
|
||||
// cos(PI/2) ~ 0 + exp(-1) ~ 0.368
|
||||
assert!(val_orth < 1.0, "orthogonal far-apart viewpoints should have low bias");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometric_bias_matrix_is_symmetric_for_symmetric_layout() {
|
||||
let bias = GeometricBias::default();
|
||||
let geom = make_test_geom(4);
|
||||
let matrix = bias.build_matrix(&geom);
|
||||
let n = 4;
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
assert!(
|
||||
(matrix[i * n + j] - matrix[j * n + i]).abs() < 1e-5,
|
||||
"bias matrix must be symmetric for symmetric layout: [{i},{j}]={} vs [{j},{i}]={}",
|
||||
matrix[i * n + j],
|
||||
matrix[j * n + i]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single_viewpoint_fuse_returns_projection() {
|
||||
let dim = 8;
|
||||
let attn = CrossViewpointAttention::new(dim);
|
||||
let embeddings = vec![vec![1.0_f32; dim]];
|
||||
let geom = make_test_geom(1);
|
||||
let fused = attn.fuse(&embeddings, &geom).unwrap();
|
||||
// With identity projection and single viewpoint, fused == input.
|
||||
for (i, v) in fused.iter().enumerate() {
|
||||
assert!(
|
||||
(v - 1.0).abs() < 1e-5,
|
||||
"single-viewpoint fuse should return input, dim {i}: {v}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn projection_weights_custom_transform() {
|
||||
// Verify that non-identity weights change the output.
|
||||
let dim = 4;
|
||||
// Swap first two dimensions in Q.
|
||||
let mut w_q = vec![0.0_f32; dim * dim];
|
||||
w_q[0 * dim + 1] = 1.0; // row 0 picks dim 1
|
||||
w_q[1 * dim + 0] = 1.0; // row 1 picks dim 0
|
||||
w_q[2 * dim + 2] = 1.0;
|
||||
w_q[3 * dim + 3] = 1.0;
|
||||
let w_id = {
|
||||
let mut eye = vec![0.0_f32; dim * dim];
|
||||
for i in 0..dim {
|
||||
eye[i * dim + i] = 1.0;
|
||||
}
|
||||
eye
|
||||
};
|
||||
let weights = ProjectionWeights::new(w_q, w_id.clone(), w_id, dim, dim).unwrap();
|
||||
let queries = weights.project_queries(&[vec![1.0, 2.0, 3.0, 4.0]]);
|
||||
assert_eq!(queries[0], vec![2.0, 1.0, 3.0, 4.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometric_bias_with_large_distance_decays() {
|
||||
let bias = GeometricBias::new(0.0, 1.0, 2.0); // only distance component
|
||||
let close = bias.compute_pair(0.0, 0.5);
|
||||
let far = bias.compute_pair(0.0, 10.0);
|
||||
assert!(close > far, "closer viewpoints should have higher distance bias");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
//! Coherence gating for environment stability (ADR-031).
|
||||
//!
|
||||
//! Phase coherence determines whether the wireless environment is sufficiently
|
||||
//! stable for a model update. When multipath conditions change rapidly (e.g.
|
||||
//! doors opening, people entering), phase becomes incoherent and fusion
|
||||
//! quality degrades. The coherence gate prevents model updates during these
|
||||
//! transient periods.
|
||||
//!
|
||||
//! The core computation is the complex mean of unit phasors:
|
||||
//!
|
||||
//! ```text
|
||||
//! coherence = |mean(exp(j * delta_phi))|
|
||||
//! = sqrt((mean(cos(delta_phi)))^2 + (mean(sin(delta_phi)))^2)
|
||||
//! ```
|
||||
//!
|
||||
//! A coherence value near 1.0 indicates consistent phase; near 0.0 indicates
|
||||
//! random phase (incoherent environment).
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CoherenceState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Rolling coherence state tracking phase consistency over a sliding window.
|
||||
///
|
||||
/// Maintains a circular buffer of phase differences and incrementally updates
|
||||
/// the coherence estimate as new measurements arrive.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CoherenceState {
|
||||
/// Circular buffer of phase differences (radians).
|
||||
phase_diffs: Vec<f32>,
|
||||
/// Write position in the circular buffer.
|
||||
write_pos: usize,
|
||||
/// Number of valid entries in the buffer (may be less than capacity
|
||||
/// during warm-up).
|
||||
count: usize,
|
||||
/// Running sum of cos(phase_diff).
|
||||
sum_cos: f64,
|
||||
/// Running sum of sin(phase_diff).
|
||||
sum_sin: f64,
|
||||
}
|
||||
|
||||
impl CoherenceState {
|
||||
/// Create a new coherence state with the given window size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `window_size`: number of phase measurements to retain. Larger windows
|
||||
/// are more stable but respond more slowly to environment changes.
|
||||
/// Must be at least 1.
|
||||
pub fn new(window_size: usize) -> Self {
|
||||
let size = window_size.max(1);
|
||||
CoherenceState {
|
||||
phase_diffs: vec![0.0; size],
|
||||
write_pos: 0,
|
||||
count: 0,
|
||||
sum_cos: 0.0,
|
||||
sum_sin: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a new phase difference measurement into the rolling window.
|
||||
///
|
||||
/// If the buffer is full, the oldest measurement is evicted and its
|
||||
/// contribution is subtracted from the running sums.
|
||||
pub fn push(&mut self, phase_diff: f32) {
|
||||
let cap = self.phase_diffs.len();
|
||||
|
||||
// If buffer is full, subtract the evicted entry.
|
||||
if self.count == cap {
|
||||
let old = self.phase_diffs[self.write_pos];
|
||||
self.sum_cos -= old.cos() as f64;
|
||||
self.sum_sin -= old.sin() as f64;
|
||||
} else {
|
||||
self.count += 1;
|
||||
}
|
||||
|
||||
// Write new entry.
|
||||
self.phase_diffs[self.write_pos] = phase_diff;
|
||||
self.sum_cos += phase_diff.cos() as f64;
|
||||
self.sum_sin += phase_diff.sin() as f64;
|
||||
|
||||
self.write_pos = (self.write_pos + 1) % cap;
|
||||
}
|
||||
|
||||
/// Current coherence value in `[0, 1]`.
|
||||
///
|
||||
/// Returns 0.0 if no measurements have been pushed yet.
|
||||
pub fn coherence(&self) -> f32 {
|
||||
if self.count == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
let n = self.count as f64;
|
||||
let mean_cos = self.sum_cos / n;
|
||||
let mean_sin = self.sum_sin / n;
|
||||
(mean_cos * mean_cos + mean_sin * mean_sin).sqrt() as f32
|
||||
}
|
||||
|
||||
/// Number of measurements currently in the buffer.
|
||||
pub fn len(&self) -> usize {
|
||||
self.count
|
||||
}
|
||||
|
||||
/// Returns `true` if no measurements have been pushed.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.count == 0
|
||||
}
|
||||
|
||||
/// Window capacity.
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.phase_diffs.len()
|
||||
}
|
||||
|
||||
/// Reset the coherence state, clearing all measurements.
|
||||
pub fn reset(&mut self) {
|
||||
self.write_pos = 0;
|
||||
self.count = 0;
|
||||
self.sum_cos = 0.0;
|
||||
self.sum_sin = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CoherenceGate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Coherence gate that controls model updates based on phase stability.
|
||||
///
|
||||
/// Only allows model updates when the coherence exceeds a configurable
|
||||
/// threshold. Provides hysteresis to avoid rapid gate toggling near the
|
||||
/// threshold boundary.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CoherenceGate {
|
||||
/// Coherence threshold for opening the gate.
|
||||
pub threshold: f32,
|
||||
/// Hysteresis band: gate opens at `threshold` and closes at
|
||||
/// `threshold - hysteresis`.
|
||||
pub hysteresis: f32,
|
||||
/// Current gate state: `true` = open (updates allowed).
|
||||
gate_open: bool,
|
||||
/// Total number of gate evaluations.
|
||||
total_evaluations: u64,
|
||||
/// Number of times the gate was open.
|
||||
open_count: u64,
|
||||
}
|
||||
|
||||
impl CoherenceGate {
|
||||
/// Create a new coherence gate with the given threshold.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `threshold`: coherence level required for the gate to open (typically 0.7).
|
||||
/// - `hysteresis`: band below the threshold where the gate stays in its
|
||||
/// current state (typically 0.05).
|
||||
pub fn new(threshold: f32, hysteresis: f32) -> Self {
|
||||
CoherenceGate {
|
||||
threshold: threshold.clamp(0.0, 1.0),
|
||||
hysteresis: hysteresis.clamp(0.0, threshold),
|
||||
gate_open: false,
|
||||
total_evaluations: 0,
|
||||
open_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a gate with default parameters (threshold=0.7, hysteresis=0.05).
|
||||
pub fn default_params() -> Self {
|
||||
Self::new(0.7, 0.05)
|
||||
}
|
||||
|
||||
/// Evaluate the gate against the current coherence value.
|
||||
///
|
||||
/// Returns `true` if the gate is open (model update allowed).
|
||||
pub fn evaluate(&mut self, coherence: f32) -> bool {
|
||||
self.total_evaluations += 1;
|
||||
|
||||
if self.gate_open {
|
||||
// Gate is open: close if coherence drops below threshold - hysteresis.
|
||||
if coherence < self.threshold - self.hysteresis {
|
||||
self.gate_open = false;
|
||||
}
|
||||
} else {
|
||||
// Gate is closed: open if coherence exceeds threshold.
|
||||
if coherence >= self.threshold {
|
||||
self.gate_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
if self.gate_open {
|
||||
self.open_count += 1;
|
||||
}
|
||||
|
||||
self.gate_open
|
||||
}
|
||||
|
||||
/// Whether the gate is currently open.
|
||||
pub fn is_open(&self) -> bool {
|
||||
self.gate_open
|
||||
}
|
||||
|
||||
/// Fraction of evaluations where the gate was open.
|
||||
pub fn duty_cycle(&self) -> f32 {
|
||||
if self.total_evaluations == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.open_count as f32 / self.total_evaluations as f32
|
||||
}
|
||||
|
||||
/// Reset the gate state and counters.
|
||||
pub fn reset(&mut self) {
|
||||
self.gate_open = false;
|
||||
self.total_evaluations = 0;
|
||||
self.open_count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Stateless coherence gate function matching the ADR-031 specification.
|
||||
///
|
||||
/// Computes the complex mean of unit phasors from the given phase differences
|
||||
/// and returns `true` when coherence exceeds the threshold.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `phase_diffs`: delta-phi over T recent frames (radians).
|
||||
/// - `threshold`: coherence threshold (typically 0.7).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `true` if the phase coherence exceeds the threshold.
|
||||
pub fn coherence_gate(phase_diffs: &[f32], threshold: f32) -> bool {
|
||||
if phase_diffs.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let (sum_cos, sum_sin) = phase_diffs
|
||||
.iter()
|
||||
.fold((0.0_f32, 0.0_f32), |(c, s), &dp| {
|
||||
(c + dp.cos(), s + dp.sin())
|
||||
});
|
||||
let n = phase_diffs.len() as f32;
|
||||
let coherence = ((sum_cos / n).powi(2) + (sum_sin / n).powi(2)).sqrt();
|
||||
coherence > threshold
|
||||
}
|
||||
|
||||
/// Compute the raw coherence value from phase differences.
|
||||
///
|
||||
/// Returns a value in `[0, 1]` where 1.0 = perfectly coherent phase.
|
||||
pub fn compute_coherence(phase_diffs: &[f32]) -> f32 {
|
||||
if phase_diffs.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let (sum_cos, sum_sin) = phase_diffs
|
||||
.iter()
|
||||
.fold((0.0_f32, 0.0_f32), |(c, s), &dp| {
|
||||
(c + dp.cos(), s + dp.sin())
|
||||
});
|
||||
let n = phase_diffs.len() as f32;
|
||||
((sum_cos / n).powi(2) + (sum_sin / n).powi(2)).sqrt()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn coherent_phase_returns_high_value() {
|
||||
// All phase diffs are the same -> coherence ~ 1.0
|
||||
let phase_diffs = vec![0.5_f32; 100];
|
||||
let c = compute_coherence(&phase_diffs);
|
||||
assert!(c > 0.99, "identical phases should give coherence ~ 1.0, got {c}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random_phase_returns_low_value() {
|
||||
// Uniformly spaced phases around the circle -> coherence ~ 0.0
|
||||
let n = 1000;
|
||||
let phase_diffs: Vec<f32> = (0..n)
|
||||
.map(|i| 2.0 * std::f32::consts::PI * i as f32 / n as f32)
|
||||
.collect();
|
||||
let c = compute_coherence(&phase_diffs);
|
||||
assert!(c < 0.05, "uniformly spread phases should give coherence ~ 0.0, got {c}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_gate_opens_above_threshold() {
|
||||
let coherent = vec![0.3_f32; 50]; // same phase -> high coherence
|
||||
assert!(coherence_gate(&coherent, 0.7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_gate_closed_below_threshold() {
|
||||
let n = 500;
|
||||
let incoherent: Vec<f32> = (0..n)
|
||||
.map(|i| 2.0 * std::f32::consts::PI * i as f32 / n as f32)
|
||||
.collect();
|
||||
assert!(!coherence_gate(&incoherent, 0.7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_gate_empty_returns_false() {
|
||||
assert!(!coherence_gate(&[], 0.5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_state_rolling_window() {
|
||||
let mut state = CoherenceState::new(10);
|
||||
// Push coherent measurements.
|
||||
for _ in 0..10 {
|
||||
state.push(1.0);
|
||||
}
|
||||
let c1 = state.coherence();
|
||||
assert!(c1 > 0.9, "coherent window should give high coherence");
|
||||
|
||||
// Push incoherent measurements to replace the window.
|
||||
for i in 0..10 {
|
||||
state.push(i as f32 * 0.628);
|
||||
}
|
||||
let c2 = state.coherence();
|
||||
assert!(c2 < c1, "incoherent updates should reduce coherence");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_state_empty_returns_zero() {
|
||||
let state = CoherenceState::new(10);
|
||||
assert_eq!(state.coherence(), 0.0);
|
||||
assert!(state.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gate_hysteresis_prevents_toggling() {
|
||||
let mut gate = CoherenceGate::new(0.7, 0.1);
|
||||
// Open the gate.
|
||||
assert!(gate.evaluate(0.8));
|
||||
assert!(gate.is_open());
|
||||
|
||||
// Coherence drops to 0.65 (below threshold but within hysteresis band).
|
||||
assert!(gate.evaluate(0.65));
|
||||
assert!(gate.is_open(), "gate should stay open within hysteresis band");
|
||||
|
||||
// Coherence drops below hysteresis boundary (0.7 - 0.1 = 0.6).
|
||||
assert!(!gate.evaluate(0.55));
|
||||
assert!(!gate.is_open(), "gate should close below hysteresis boundary");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gate_duty_cycle_tracks_correctly() {
|
||||
let mut gate = CoherenceGate::new(0.5, 0.0);
|
||||
gate.evaluate(0.6); // open
|
||||
gate.evaluate(0.6); // open
|
||||
gate.evaluate(0.3); // close
|
||||
gate.evaluate(0.3); // close
|
||||
let duty = gate.duty_cycle();
|
||||
assert!(
|
||||
(duty - 0.5).abs() < 1e-5,
|
||||
"duty cycle should be 0.5, got {duty}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gate_reset_clears_state() {
|
||||
let mut gate = CoherenceGate::new(0.5, 0.0);
|
||||
gate.evaluate(0.6);
|
||||
assert!(gate.is_open());
|
||||
gate.reset();
|
||||
assert!(!gate.is_open());
|
||||
assert_eq!(gate.duty_cycle(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_state_push_and_len() {
|
||||
let mut state = CoherenceState::new(5);
|
||||
assert_eq!(state.len(), 0);
|
||||
state.push(0.1);
|
||||
state.push(0.2);
|
||||
assert_eq!(state.len(), 2);
|
||||
// Fill past capacity.
|
||||
for i in 0..10 {
|
||||
state.push(i as f32 * 0.1);
|
||||
}
|
||||
assert_eq!(state.len(), 5, "count should be capped at window size");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,696 @@
|
||||
//! MultistaticArray aggregate root and fusion pipeline orchestrator (ADR-031).
|
||||
//!
|
||||
//! [`MultistaticArray`] is the DDD aggregate root for the ViewpointFusion
|
||||
//! bounded context. It orchestrates the full fusion pipeline:
|
||||
//!
|
||||
//! 1. Collect per-viewpoint AETHER embeddings.
|
||||
//! 2. Compute geometric bias from viewpoint pair geometry.
|
||||
//! 3. Apply cross-viewpoint attention with geometric bias.
|
||||
//! 4. Gate the output through coherence check.
|
||||
//! 5. Emit a fused embedding for the DensePose regression head.
|
||||
//!
|
||||
//! Uses `ruvector-attention` for the attention mechanism and
|
||||
//! `ruvector-attn-mincut` for optional noise gating on embeddings.
|
||||
|
||||
use crate::viewpoint::attention::{
|
||||
AttentionError, CrossViewpointAttention, GeometricBias, ViewpointGeometry,
|
||||
};
|
||||
use crate::viewpoint::coherence::{CoherenceGate, CoherenceState};
|
||||
use crate::viewpoint::geometry::{GeometricDiversityIndex, NodeId};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Domain types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Unique identifier for a multistatic array deployment.
|
||||
pub type ArrayId = u64;
|
||||
|
||||
/// Per-viewpoint embedding with geometric metadata.
|
||||
///
|
||||
/// Represents a single CSI observation processed through the per-viewpoint
|
||||
/// signal pipeline and AETHER encoder into a contrastive embedding.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ViewpointEmbedding {
|
||||
/// Source node identifier.
|
||||
pub node_id: NodeId,
|
||||
/// AETHER embedding vector (typically 128-d).
|
||||
pub embedding: Vec<f32>,
|
||||
/// Azimuth angle from array centroid (radians).
|
||||
pub azimuth: f32,
|
||||
/// Elevation angle (radians, 0 for 2-D deployments).
|
||||
pub elevation: f32,
|
||||
/// Baseline distance from array centroid (metres).
|
||||
pub baseline: f32,
|
||||
/// Node position in metres (x, y).
|
||||
pub position: (f32, f32),
|
||||
/// Signal-to-noise ratio at capture time (dB).
|
||||
pub snr_db: f32,
|
||||
}
|
||||
|
||||
/// Fused embedding output from the cross-viewpoint attention pipeline.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FusedEmbedding {
|
||||
/// The fused embedding vector.
|
||||
pub embedding: Vec<f32>,
|
||||
/// Geometric Diversity Index at the time of fusion.
|
||||
pub gdi: f32,
|
||||
/// Coherence value at the time of fusion.
|
||||
pub coherence: f32,
|
||||
/// Number of viewpoints that contributed to the fusion.
|
||||
pub n_viewpoints: usize,
|
||||
/// Effective independent viewpoints (after correlation discount).
|
||||
pub n_effective: f32,
|
||||
}
|
||||
|
||||
/// Configuration for the fusion pipeline.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FusionConfig {
|
||||
/// Embedding dimension (must match AETHER output, typically 128).
|
||||
pub embed_dim: usize,
|
||||
/// Coherence threshold for gating (typically 0.7).
|
||||
pub coherence_threshold: f32,
|
||||
/// Coherence hysteresis band (typically 0.05).
|
||||
pub coherence_hysteresis: f32,
|
||||
/// Coherence rolling window size (number of frames).
|
||||
pub coherence_window: usize,
|
||||
/// Geometric bias angle weight.
|
||||
pub w_angle: f32,
|
||||
/// Geometric bias distance weight.
|
||||
pub w_dist: f32,
|
||||
/// Reference distance for geometric bias decay (metres).
|
||||
pub d_ref: f32,
|
||||
/// Minimum SNR (dB) for a viewpoint to contribute to fusion.
|
||||
pub min_snr_db: f32,
|
||||
}
|
||||
|
||||
impl Default for FusionConfig {
|
||||
fn default() -> Self {
|
||||
FusionConfig {
|
||||
embed_dim: 128,
|
||||
coherence_threshold: 0.7,
|
||||
coherence_hysteresis: 0.05,
|
||||
coherence_window: 50,
|
||||
w_angle: 1.0,
|
||||
w_dist: 1.0,
|
||||
d_ref: 5.0,
|
||||
min_snr_db: 5.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fusion errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced by the fusion pipeline.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum FusionError {
|
||||
/// No viewpoint embeddings available for fusion.
|
||||
NoViewpoints,
|
||||
/// All viewpoints were filtered out (e.g. by SNR threshold).
|
||||
AllFiltered {
|
||||
/// Number of viewpoints that were rejected.
|
||||
rejected: usize,
|
||||
},
|
||||
/// Coherence gate is closed (environment too unstable).
|
||||
CoherenceGateClosed {
|
||||
/// Current coherence value.
|
||||
coherence: f32,
|
||||
/// Required threshold.
|
||||
threshold: f32,
|
||||
},
|
||||
/// Internal attention computation error.
|
||||
AttentionError(AttentionError),
|
||||
/// Embedding dimension mismatch.
|
||||
DimensionMismatch {
|
||||
/// Expected dimension.
|
||||
expected: usize,
|
||||
/// Actual dimension.
|
||||
actual: usize,
|
||||
/// Node that produced the mismatched embedding.
|
||||
node_id: NodeId,
|
||||
},
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FusionError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
FusionError::NoViewpoints => write!(f, "no viewpoint embeddings available"),
|
||||
FusionError::AllFiltered { rejected } => {
|
||||
write!(f, "all {rejected} viewpoints filtered by SNR threshold")
|
||||
}
|
||||
FusionError::CoherenceGateClosed { coherence, threshold } => {
|
||||
write!(
|
||||
f,
|
||||
"coherence gate closed: coherence={coherence:.3} < threshold={threshold:.3}"
|
||||
)
|
||||
}
|
||||
FusionError::AttentionError(e) => write!(f, "attention error: {e}"),
|
||||
FusionError::DimensionMismatch { expected, actual, node_id } => {
|
||||
write!(
|
||||
f,
|
||||
"node {node_id} embedding dim {actual} != expected {expected}"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for FusionError {}
|
||||
|
||||
impl From<AttentionError> for FusionError {
|
||||
fn from(e: AttentionError) -> Self {
|
||||
FusionError::AttentionError(e)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Domain events
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Events emitted by the ViewpointFusion aggregate.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ViewpointFusionEvent {
|
||||
/// A viewpoint embedding was received from a node.
|
||||
ViewpointCaptured {
|
||||
/// Source node.
|
||||
node_id: NodeId,
|
||||
/// Signal quality.
|
||||
snr_db: f32,
|
||||
},
|
||||
/// A TDM cycle completed with all (or some) viewpoints received.
|
||||
TdmCycleCompleted {
|
||||
/// Monotonic cycle counter.
|
||||
cycle_id: u64,
|
||||
/// Number of viewpoints received this cycle.
|
||||
viewpoints_received: usize,
|
||||
},
|
||||
/// Fusion completed successfully.
|
||||
FusionCompleted {
|
||||
/// GDI at the time of fusion.
|
||||
gdi: f32,
|
||||
/// Number of viewpoints fused.
|
||||
n_viewpoints: usize,
|
||||
},
|
||||
/// Coherence gate evaluation result.
|
||||
CoherenceGateTriggered {
|
||||
/// Current coherence value.
|
||||
coherence: f32,
|
||||
/// Whether the gate accepted the update.
|
||||
accepted: bool,
|
||||
},
|
||||
/// Array geometry was updated.
|
||||
GeometryUpdated {
|
||||
/// New GDI value.
|
||||
new_gdi: f32,
|
||||
/// Effective independent viewpoints.
|
||||
n_effective: f32,
|
||||
},
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MultistaticArray (aggregate root)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Aggregate root for the ViewpointFusion bounded context.
|
||||
///
|
||||
/// Manages the lifecycle of a multistatic sensor array: collecting viewpoint
|
||||
/// embeddings, computing geometric diversity, gating on coherence, and
|
||||
/// producing fused embeddings for downstream pose estimation.
|
||||
pub struct MultistaticArray {
|
||||
/// Unique deployment identifier.
|
||||
id: ArrayId,
|
||||
/// Active viewpoint embeddings (latest per node).
|
||||
viewpoints: Vec<ViewpointEmbedding>,
|
||||
/// Cross-viewpoint attention module.
|
||||
attention: CrossViewpointAttention,
|
||||
/// Coherence state tracker.
|
||||
coherence_state: CoherenceState,
|
||||
/// Coherence gate.
|
||||
coherence_gate: CoherenceGate,
|
||||
/// Pipeline configuration.
|
||||
config: FusionConfig,
|
||||
/// Monotonic TDM cycle counter.
|
||||
cycle_count: u64,
|
||||
/// Event log (bounded).
|
||||
events: Vec<ViewpointFusionEvent>,
|
||||
/// Maximum events to retain.
|
||||
max_events: usize,
|
||||
}
|
||||
|
||||
impl MultistaticArray {
|
||||
/// Create a new multistatic array with the given configuration.
|
||||
pub fn new(id: ArrayId, config: FusionConfig) -> Self {
|
||||
let attention = CrossViewpointAttention::new(config.embed_dim);
|
||||
let attention = CrossViewpointAttention::with_params(
|
||||
attention.weights,
|
||||
GeometricBias::new(config.w_angle, config.w_dist, config.d_ref),
|
||||
);
|
||||
let coherence_state = CoherenceState::new(config.coherence_window);
|
||||
let coherence_gate =
|
||||
CoherenceGate::new(config.coherence_threshold, config.coherence_hysteresis);
|
||||
|
||||
MultistaticArray {
|
||||
id,
|
||||
viewpoints: Vec::new(),
|
||||
attention,
|
||||
coherence_state,
|
||||
coherence_gate,
|
||||
config,
|
||||
cycle_count: 0,
|
||||
events: Vec::new(),
|
||||
max_events: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default configuration.
|
||||
pub fn with_defaults(id: ArrayId) -> Self {
|
||||
Self::new(id, FusionConfig::default())
|
||||
}
|
||||
|
||||
/// Array deployment identifier.
|
||||
pub fn id(&self) -> ArrayId {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Number of viewpoints currently held.
|
||||
pub fn n_viewpoints(&self) -> usize {
|
||||
self.viewpoints.len()
|
||||
}
|
||||
|
||||
/// Current TDM cycle count.
|
||||
pub fn cycle_count(&self) -> u64 {
|
||||
self.cycle_count
|
||||
}
|
||||
|
||||
/// Submit a viewpoint embedding from a sensor node.
|
||||
///
|
||||
/// Replaces any existing embedding for the same `node_id`.
|
||||
pub fn submit_viewpoint(&mut self, vp: ViewpointEmbedding) -> Result<(), FusionError> {
|
||||
// Validate embedding dimension.
|
||||
if vp.embedding.len() != self.config.embed_dim {
|
||||
return Err(FusionError::DimensionMismatch {
|
||||
expected: self.config.embed_dim,
|
||||
actual: vp.embedding.len(),
|
||||
node_id: vp.node_id,
|
||||
});
|
||||
}
|
||||
|
||||
self.emit_event(ViewpointFusionEvent::ViewpointCaptured {
|
||||
node_id: vp.node_id,
|
||||
snr_db: vp.snr_db,
|
||||
});
|
||||
|
||||
// Upsert: replace existing embedding for this node.
|
||||
if let Some(pos) = self.viewpoints.iter().position(|v| v.node_id == vp.node_id) {
|
||||
self.viewpoints[pos] = vp;
|
||||
} else {
|
||||
self.viewpoints.push(vp);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Push a phase-difference measurement for coherence tracking.
|
||||
pub fn push_phase_diff(&mut self, phase_diff: f32) {
|
||||
self.coherence_state.push(phase_diff);
|
||||
}
|
||||
|
||||
/// Current coherence value.
|
||||
pub fn coherence(&self) -> f32 {
|
||||
self.coherence_state.coherence()
|
||||
}
|
||||
|
||||
/// Compute the Geometric Diversity Index for the current array layout.
|
||||
pub fn compute_gdi(&self) -> Option<GeometricDiversityIndex> {
|
||||
let azimuths: Vec<f32> = self.viewpoints.iter().map(|v| v.azimuth).collect();
|
||||
let ids: Vec<NodeId> = self.viewpoints.iter().map(|v| v.node_id).collect();
|
||||
let gdi = GeometricDiversityIndex::compute(&azimuths, &ids);
|
||||
if let Some(ref g) = gdi {
|
||||
// Emit event (mutable borrow not possible here, caller can do it).
|
||||
let _ = g; // used for return
|
||||
}
|
||||
gdi
|
||||
}
|
||||
|
||||
/// Run the full fusion pipeline.
|
||||
///
|
||||
/// 1. Filter viewpoints by SNR.
|
||||
/// 2. Check coherence gate.
|
||||
/// 3. Compute geometric bias.
|
||||
/// 4. Apply cross-viewpoint attention.
|
||||
/// 5. Mean-pool to single fused embedding.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `Ok(FusedEmbedding)` on success, or an error if the pipeline cannot
|
||||
/// produce a valid fusion (no viewpoints, gate closed, etc.).
|
||||
pub fn fuse(&mut self) -> Result<FusedEmbedding, FusionError> {
|
||||
self.cycle_count += 1;
|
||||
|
||||
// Extract all needed data from viewpoints upfront to avoid borrow conflicts.
|
||||
let min_snr = self.config.min_snr_db;
|
||||
let total_viewpoints = self.viewpoints.len();
|
||||
let extracted: Vec<(NodeId, Vec<f32>, f32, (f32, f32))> = self
|
||||
.viewpoints
|
||||
.iter()
|
||||
.filter(|v| v.snr_db >= min_snr)
|
||||
.map(|v| (v.node_id, v.embedding.clone(), v.azimuth, v.position))
|
||||
.collect();
|
||||
|
||||
let n_valid = extracted.len();
|
||||
if n_valid == 0 {
|
||||
if total_viewpoints == 0 {
|
||||
return Err(FusionError::NoViewpoints);
|
||||
}
|
||||
return Err(FusionError::AllFiltered {
|
||||
rejected: total_viewpoints,
|
||||
});
|
||||
}
|
||||
|
||||
// Check coherence gate.
|
||||
let coh = self.coherence_state.coherence();
|
||||
let gate_open = self.coherence_gate.evaluate(coh);
|
||||
|
||||
self.emit_event(ViewpointFusionEvent::CoherenceGateTriggered {
|
||||
coherence: coh,
|
||||
accepted: gate_open,
|
||||
});
|
||||
|
||||
if !gate_open {
|
||||
return Err(FusionError::CoherenceGateClosed {
|
||||
coherence: coh,
|
||||
threshold: self.config.coherence_threshold,
|
||||
});
|
||||
}
|
||||
|
||||
// Prepare embeddings and geometries from extracted data.
|
||||
let embeddings: Vec<Vec<f32>> = extracted.iter().map(|(_, e, _, _)| e.clone()).collect();
|
||||
let geom: Vec<ViewpointGeometry> = extracted
|
||||
.iter()
|
||||
.map(|(_, _, az, pos)| ViewpointGeometry {
|
||||
azimuth: *az,
|
||||
position: *pos,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Run cross-viewpoint attention fusion.
|
||||
let fused_emb = self.attention.fuse(&embeddings, &geom)?;
|
||||
|
||||
// Compute GDI.
|
||||
let azimuths: Vec<f32> = extracted.iter().map(|(_, _, az, _)| *az).collect();
|
||||
let ids: Vec<NodeId> = extracted.iter().map(|(id, _, _, _)| *id).collect();
|
||||
let gdi_opt = GeometricDiversityIndex::compute(&azimuths, &ids);
|
||||
let (gdi_val, n_eff) = match &gdi_opt {
|
||||
Some(g) => (g.value, g.n_effective),
|
||||
None => (0.0, n_valid as f32),
|
||||
};
|
||||
|
||||
self.emit_event(ViewpointFusionEvent::TdmCycleCompleted {
|
||||
cycle_id: self.cycle_count,
|
||||
viewpoints_received: n_valid,
|
||||
});
|
||||
|
||||
self.emit_event(ViewpointFusionEvent::FusionCompleted {
|
||||
gdi: gdi_val,
|
||||
n_viewpoints: n_valid,
|
||||
});
|
||||
|
||||
Ok(FusedEmbedding {
|
||||
embedding: fused_emb,
|
||||
gdi: gdi_val,
|
||||
coherence: coh,
|
||||
n_viewpoints: n_valid,
|
||||
n_effective: n_eff,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run fusion without coherence gating (for testing or forced updates).
|
||||
pub fn fuse_ungated(&mut self) -> Result<FusedEmbedding, FusionError> {
|
||||
let min_snr = self.config.min_snr_db;
|
||||
let total_viewpoints = self.viewpoints.len();
|
||||
let extracted: Vec<(NodeId, Vec<f32>, f32, (f32, f32))> = self
|
||||
.viewpoints
|
||||
.iter()
|
||||
.filter(|v| v.snr_db >= min_snr)
|
||||
.map(|v| (v.node_id, v.embedding.clone(), v.azimuth, v.position))
|
||||
.collect();
|
||||
|
||||
let n_valid = extracted.len();
|
||||
if n_valid == 0 {
|
||||
if total_viewpoints == 0 {
|
||||
return Err(FusionError::NoViewpoints);
|
||||
}
|
||||
return Err(FusionError::AllFiltered {
|
||||
rejected: total_viewpoints,
|
||||
});
|
||||
}
|
||||
|
||||
let embeddings: Vec<Vec<f32>> = extracted.iter().map(|(_, e, _, _)| e.clone()).collect();
|
||||
let geom: Vec<ViewpointGeometry> = extracted
|
||||
.iter()
|
||||
.map(|(_, _, az, pos)| ViewpointGeometry {
|
||||
azimuth: *az,
|
||||
position: *pos,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let fused_emb = self.attention.fuse(&embeddings, &geom)?;
|
||||
|
||||
let azimuths: Vec<f32> = extracted.iter().map(|(_, _, az, _)| *az).collect();
|
||||
let ids: Vec<NodeId> = extracted.iter().map(|(id, _, _, _)| *id).collect();
|
||||
let gdi_opt = GeometricDiversityIndex::compute(&azimuths, &ids);
|
||||
let (gdi_val, n_eff) = match &gdi_opt {
|
||||
Some(g) => (g.value, g.n_effective),
|
||||
None => (0.0, n_valid as f32),
|
||||
};
|
||||
|
||||
let coh = self.coherence_state.coherence();
|
||||
|
||||
Ok(FusedEmbedding {
|
||||
embedding: fused_emb,
|
||||
gdi: gdi_val,
|
||||
coherence: coh,
|
||||
n_viewpoints: n_valid,
|
||||
n_effective: n_eff,
|
||||
})
|
||||
}
|
||||
|
||||
/// Access the event log.
|
||||
pub fn events(&self) -> &[ViewpointFusionEvent] {
|
||||
&self.events
|
||||
}
|
||||
|
||||
/// Clear the event log.
|
||||
pub fn clear_events(&mut self) {
|
||||
self.events.clear();
|
||||
}
|
||||
|
||||
/// Remove a viewpoint by node ID.
|
||||
pub fn remove_viewpoint(&mut self, node_id: NodeId) {
|
||||
self.viewpoints.retain(|v| v.node_id != node_id);
|
||||
}
|
||||
|
||||
/// Clear all viewpoints.
|
||||
pub fn clear_viewpoints(&mut self) {
|
||||
self.viewpoints.clear();
|
||||
}
|
||||
|
||||
fn emit_event(&mut self, event: ViewpointFusionEvent) {
|
||||
if self.events.len() >= self.max_events {
|
||||
// Drop oldest half to avoid unbounded growth.
|
||||
let half = self.max_events / 2;
|
||||
self.events.drain(..half);
|
||||
}
|
||||
self.events.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_viewpoint(node_id: NodeId, angle_idx: usize, n: usize, dim: usize) -> ViewpointEmbedding {
|
||||
let angle = 2.0 * std::f32::consts::PI * angle_idx as f32 / n as f32;
|
||||
let r = 3.0;
|
||||
ViewpointEmbedding {
|
||||
node_id,
|
||||
embedding: (0..dim).map(|d| ((node_id as usize * dim + d) as f32 * 0.01).sin()).collect(),
|
||||
azimuth: angle,
|
||||
elevation: 0.0,
|
||||
baseline: r,
|
||||
position: (r * angle.cos(), r * angle.sin()),
|
||||
snr_db: 15.0,
|
||||
}
|
||||
}
|
||||
|
||||
fn setup_coherent_array(dim: usize) -> MultistaticArray {
|
||||
let config = FusionConfig {
|
||||
embed_dim: dim,
|
||||
coherence_threshold: 0.5,
|
||||
coherence_hysteresis: 0.0,
|
||||
min_snr_db: 0.0,
|
||||
..FusionConfig::default()
|
||||
};
|
||||
let mut array = MultistaticArray::new(1, config);
|
||||
// Push coherent phase diffs to open the gate.
|
||||
for _ in 0..60 {
|
||||
array.push_phase_diff(0.1);
|
||||
}
|
||||
array
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_produces_correct_dimension() {
|
||||
let dim = 16;
|
||||
let mut array = setup_coherent_array(dim);
|
||||
for i in 0..4 {
|
||||
array.submit_viewpoint(make_viewpoint(i, i as usize, 4, dim)).unwrap();
|
||||
}
|
||||
let fused = array.fuse().unwrap();
|
||||
assert_eq!(fused.embedding.len(), dim);
|
||||
assert_eq!(fused.n_viewpoints, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_no_viewpoints_returns_error() {
|
||||
let mut array = setup_coherent_array(16);
|
||||
assert!(matches!(array.fuse(), Err(FusionError::NoViewpoints)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_coherence_gate_closed_returns_error() {
|
||||
let dim = 16;
|
||||
let config = FusionConfig {
|
||||
embed_dim: dim,
|
||||
coherence_threshold: 0.9,
|
||||
coherence_hysteresis: 0.0,
|
||||
min_snr_db: 0.0,
|
||||
..FusionConfig::default()
|
||||
};
|
||||
let mut array = MultistaticArray::new(1, config);
|
||||
// Push incoherent phase diffs.
|
||||
for i in 0..100 {
|
||||
array.push_phase_diff(i as f32 * 0.5);
|
||||
}
|
||||
array.submit_viewpoint(make_viewpoint(0, 0, 4, dim)).unwrap();
|
||||
array.submit_viewpoint(make_viewpoint(1, 1, 4, dim)).unwrap();
|
||||
let result = array.fuse();
|
||||
assert!(matches!(result, Err(FusionError::CoherenceGateClosed { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_ungated_bypasses_coherence() {
|
||||
let dim = 16;
|
||||
let config = FusionConfig {
|
||||
embed_dim: dim,
|
||||
coherence_threshold: 0.99,
|
||||
coherence_hysteresis: 0.0,
|
||||
min_snr_db: 0.0,
|
||||
..FusionConfig::default()
|
||||
};
|
||||
let mut array = MultistaticArray::new(1, config);
|
||||
// Push incoherent diffs -- gate would be closed.
|
||||
for i in 0..100 {
|
||||
array.push_phase_diff(i as f32 * 0.5);
|
||||
}
|
||||
array.submit_viewpoint(make_viewpoint(0, 0, 4, dim)).unwrap();
|
||||
array.submit_viewpoint(make_viewpoint(1, 1, 4, dim)).unwrap();
|
||||
let fused = array.fuse_ungated().unwrap();
|
||||
assert_eq!(fused.embedding.len(), dim);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn submit_replaces_existing_viewpoint() {
|
||||
let dim = 8;
|
||||
let mut array = setup_coherent_array(dim);
|
||||
let vp1 = make_viewpoint(10, 0, 4, dim);
|
||||
let mut vp2 = make_viewpoint(10, 1, 4, dim);
|
||||
vp2.snr_db = 25.0;
|
||||
array.submit_viewpoint(vp1).unwrap();
|
||||
assert_eq!(array.n_viewpoints(), 1);
|
||||
array.submit_viewpoint(vp2).unwrap();
|
||||
assert_eq!(array.n_viewpoints(), 1, "should replace, not add");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dimension_mismatch_returns_error() {
|
||||
let dim = 16;
|
||||
let mut array = setup_coherent_array(dim);
|
||||
let mut vp = make_viewpoint(0, 0, 4, dim);
|
||||
vp.embedding = vec![1.0; 8]; // wrong dim
|
||||
assert!(matches!(
|
||||
array.submit_viewpoint(vp),
|
||||
Err(FusionError::DimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn snr_filter_rejects_low_quality() {
|
||||
let dim = 16;
|
||||
let config = FusionConfig {
|
||||
embed_dim: dim,
|
||||
coherence_threshold: 0.0,
|
||||
min_snr_db: 10.0,
|
||||
..FusionConfig::default()
|
||||
};
|
||||
let mut array = MultistaticArray::new(1, config);
|
||||
for _ in 0..60 {
|
||||
array.push_phase_diff(0.1);
|
||||
}
|
||||
let mut vp = make_viewpoint(0, 0, 4, dim);
|
||||
vp.snr_db = 3.0; // below threshold
|
||||
array.submit_viewpoint(vp).unwrap();
|
||||
assert!(matches!(array.fuse(), Err(FusionError::AllFiltered { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn events_are_emitted_on_fusion() {
|
||||
let dim = 8;
|
||||
let mut array = setup_coherent_array(dim);
|
||||
array.submit_viewpoint(make_viewpoint(0, 0, 4, dim)).unwrap();
|
||||
array.submit_viewpoint(make_viewpoint(1, 1, 4, dim)).unwrap();
|
||||
array.clear_events();
|
||||
let _ = array.fuse();
|
||||
assert!(!array.events().is_empty(), "fusion should emit events");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_viewpoint_works() {
|
||||
let dim = 8;
|
||||
let mut array = setup_coherent_array(dim);
|
||||
array.submit_viewpoint(make_viewpoint(10, 0, 4, dim)).unwrap();
|
||||
array.submit_viewpoint(make_viewpoint(20, 1, 4, dim)).unwrap();
|
||||
assert_eq!(array.n_viewpoints(), 2);
|
||||
array.remove_viewpoint(10);
|
||||
assert_eq!(array.n_viewpoints(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fused_embedding_reports_gdi() {
|
||||
let dim = 16;
|
||||
let mut array = setup_coherent_array(dim);
|
||||
for i in 0..4 {
|
||||
array.submit_viewpoint(make_viewpoint(i, i as usize, 4, dim)).unwrap();
|
||||
}
|
||||
let fused = array.fuse().unwrap();
|
||||
assert!(fused.gdi > 0.0, "GDI should be positive for spread viewpoints");
|
||||
assert!(fused.n_effective > 1.0, "effective viewpoints should be > 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_gdi_standalone() {
|
||||
let dim = 8;
|
||||
let mut array = setup_coherent_array(dim);
|
||||
for i in 0..6 {
|
||||
array.submit_viewpoint(make_viewpoint(i, i as usize, 6, dim)).unwrap();
|
||||
}
|
||||
let gdi = array.compute_gdi().unwrap();
|
||||
assert!(gdi.value > 0.0);
|
||||
assert!(gdi.n_effective > 1.0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,499 @@
|
||||
//! Geometric Diversity Index and Cramer-Rao bound estimation (ADR-031).
|
||||
//!
|
||||
//! Provides two key computations for array geometry quality assessment:
|
||||
//!
|
||||
//! 1. **Geometric Diversity Index (GDI)**: measures how well the viewpoints
|
||||
//! are spread around the sensing area. Higher GDI = better spatial coverage.
|
||||
//!
|
||||
//! 2. **Cramer-Rao Bound (CRB)**: lower bound on the position estimation
|
||||
//! variance achievable by any unbiased estimator given the array geometry.
|
||||
//! Used to predict theoretical localisation accuracy.
|
||||
//!
|
||||
//! Uses `ruvector_solver` for matrix operations in the Fisher information
|
||||
//! matrix inversion required by the Cramer-Rao bound.
|
||||
|
||||
use ruvector_solver::neumann::NeumannSolver;
|
||||
use ruvector_solver::types::CsrMatrix;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Node identifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Unique identifier for a sensor node in the multistatic array.
|
||||
pub type NodeId = u32;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GeometricDiversityIndex
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Geometric Diversity Index measuring array viewpoint spread.
|
||||
///
|
||||
/// GDI is computed as the mean minimum angular separation across all viewpoints:
|
||||
///
|
||||
/// ```text
|
||||
/// GDI = (1/N) * sum_i min_{j != i} |theta_i - theta_j|
|
||||
/// ```
|
||||
///
|
||||
/// A GDI close to `2*PI/N` (uniform spacing) indicates optimal diversity.
|
||||
/// A GDI near zero means viewpoints are clustered.
|
||||
///
|
||||
/// The `n_effective` field estimates the number of independent viewpoints
|
||||
/// after accounting for angular correlation between nearby viewpoints.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeometricDiversityIndex {
|
||||
/// GDI value (radians). Higher is better.
|
||||
pub value: f32,
|
||||
/// Effective independent viewpoints after correlation discount.
|
||||
pub n_effective: f32,
|
||||
/// Worst (most redundant) viewpoint pair.
|
||||
pub worst_pair: (NodeId, NodeId),
|
||||
/// Number of physical viewpoints in the array.
|
||||
pub n_physical: usize,
|
||||
}
|
||||
|
||||
impl GeometricDiversityIndex {
|
||||
/// Compute the GDI from viewpoint azimuth angles.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `azimuths`: per-viewpoint azimuth angle in radians from the array
|
||||
/// centroid. Must have at least 2 elements.
|
||||
/// - `node_ids`: per-viewpoint node identifier (same length as `azimuths`).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `None` if fewer than 2 viewpoints are provided.
|
||||
pub fn compute(azimuths: &[f32], node_ids: &[NodeId]) -> Option<Self> {
|
||||
let n = azimuths.len();
|
||||
if n < 2 || node_ids.len() != n {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Find the minimum angular separation for each viewpoint.
|
||||
let mut min_seps = Vec::with_capacity(n);
|
||||
let mut worst_sep = f32::MAX;
|
||||
let mut worst_i = 0_usize;
|
||||
let mut worst_j = 1_usize;
|
||||
|
||||
for i in 0..n {
|
||||
let mut min_sep = f32::MAX;
|
||||
let mut min_j = (i + 1) % n;
|
||||
for j in 0..n {
|
||||
if i == j {
|
||||
continue;
|
||||
}
|
||||
let sep = angular_distance(azimuths[i], azimuths[j]);
|
||||
if sep < min_sep {
|
||||
min_sep = sep;
|
||||
min_j = j;
|
||||
}
|
||||
}
|
||||
min_seps.push(min_sep);
|
||||
if min_sep < worst_sep {
|
||||
worst_sep = min_sep;
|
||||
worst_i = i;
|
||||
worst_j = min_j;
|
||||
}
|
||||
}
|
||||
|
||||
let gdi = min_seps.iter().sum::<f32>() / n as f32;
|
||||
|
||||
// Effective viewpoints: discount correlated viewpoints.
|
||||
// Correlation model: rho(theta) = exp(-theta^2 / (2 * sigma^2))
|
||||
// with sigma = PI/6 (30 degrees).
|
||||
let sigma = std::f32::consts::PI / 6.0;
|
||||
let n_effective = compute_effective_viewpoints(azimuths, sigma);
|
||||
|
||||
Some(GeometricDiversityIndex {
|
||||
value: gdi,
|
||||
n_effective,
|
||||
worst_pair: (node_ids[worst_i], node_ids[worst_j]),
|
||||
n_physical: n,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns `true` if the array has sufficient geometric diversity for
|
||||
/// reliable multi-viewpoint fusion.
|
||||
///
|
||||
/// Threshold: GDI >= PI / (2 * N) (at least half the uniform-spacing ideal).
|
||||
pub fn is_sufficient(&self) -> bool {
|
||||
if self.n_physical == 0 {
|
||||
return false;
|
||||
}
|
||||
let ideal = std::f32::consts::PI * 2.0 / self.n_physical as f32;
|
||||
self.value >= ideal * 0.5
|
||||
}
|
||||
|
||||
/// Ratio of effective to physical viewpoints.
|
||||
pub fn efficiency(&self) -> f32 {
|
||||
if self.n_physical == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.n_effective / self.n_physical as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the shortest angular distance between two angles (radians).
|
||||
///
|
||||
/// Returns a value in `[0, PI]`.
|
||||
fn angular_distance(a: f32, b: f32) -> f32 {
|
||||
let diff = (a - b).abs() % (2.0 * std::f32::consts::PI);
|
||||
if diff > std::f32::consts::PI {
|
||||
2.0 * std::f32::consts::PI - diff
|
||||
} else {
|
||||
diff
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute effective independent viewpoints using a Gaussian angular correlation
|
||||
/// model and eigenvalue analysis of the correlation matrix.
|
||||
///
|
||||
/// The effective count is: `N_eff = (sum lambda_i)^2 / sum(lambda_i^2)` where
|
||||
/// `lambda_i` are the eigenvalues of the angular correlation matrix. For
|
||||
/// efficiency, we approximate this using trace-based estimation:
|
||||
/// `N_eff approx trace(R)^2 / trace(R^2)`.
|
||||
fn compute_effective_viewpoints(azimuths: &[f32], sigma: f32) -> f32 {
|
||||
let n = azimuths.len();
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
if n == 1 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
let two_sigma_sq = 2.0 * sigma * sigma;
|
||||
|
||||
// Build correlation matrix R[i,j] = exp(-angular_dist(i,j)^2 / (2*sigma^2))
|
||||
// and compute trace(R) and trace(R^2) simultaneously.
|
||||
// For trace(R^2) = sum_i sum_j R[i,j]^2, we need the full matrix.
|
||||
let mut r_matrix = vec![0.0_f32; n * n];
|
||||
for i in 0..n {
|
||||
r_matrix[i * n + i] = 1.0;
|
||||
for j in (i + 1)..n {
|
||||
let d = angular_distance(azimuths[i], azimuths[j]);
|
||||
let rho = (-d * d / two_sigma_sq).exp();
|
||||
r_matrix[i * n + j] = rho;
|
||||
r_matrix[j * n + i] = rho;
|
||||
}
|
||||
}
|
||||
|
||||
// trace(R) = n (all diagonal entries are 1.0).
|
||||
let trace_r = n as f32;
|
||||
// trace(R^2) = sum_{i,j} R[i,j]^2
|
||||
let trace_r2: f32 = r_matrix.iter().map(|v| v * v).sum();
|
||||
|
||||
// N_eff = trace(R)^2 / trace(R^2)
|
||||
let n_eff = (trace_r * trace_r) / trace_r2.max(f32::EPSILON);
|
||||
n_eff.min(n as f32).max(1.0)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Cramer-Rao Bound
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Cramer-Rao lower bound on position estimation variance.
|
||||
///
|
||||
/// The CRB provides the theoretical minimum variance achievable by any
|
||||
/// unbiased estimator for the target position given the array geometry.
|
||||
/// Lower CRB = better localisation potential.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CramerRaoBound {
|
||||
/// CRB for x-coordinate estimation (metres squared).
|
||||
pub crb_x: f32,
|
||||
/// CRB for y-coordinate estimation (metres squared).
|
||||
pub crb_y: f32,
|
||||
/// Root-mean-square position error lower bound (metres).
|
||||
pub rmse_lower_bound: f32,
|
||||
/// Geometric dilution of precision (GDOP).
|
||||
pub gdop: f32,
|
||||
}
|
||||
|
||||
/// A viewpoint position for CRB computation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ViewpointPosition {
|
||||
/// X coordinate in metres.
|
||||
pub x: f32,
|
||||
/// Y coordinate in metres.
|
||||
pub y: f32,
|
||||
/// Per-measurement noise standard deviation (metres).
|
||||
pub noise_std: f32,
|
||||
}
|
||||
|
||||
impl CramerRaoBound {
|
||||
/// Estimate the Cramer-Rao bound for a target at `(tx, ty)` observed by
|
||||
/// the given viewpoints.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `target`: target position `(x, y)` in metres.
|
||||
/// - `viewpoints`: sensor node positions with per-node noise levels.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `None` if fewer than 3 viewpoints are provided (under-determined).
|
||||
pub fn estimate(target: (f32, f32), viewpoints: &[ViewpointPosition]) -> Option<Self> {
|
||||
let n = viewpoints.len();
|
||||
if n < 3 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Build the 2x2 Fisher Information Matrix (FIM).
|
||||
// FIM = sum_i (1/sigma_i^2) * [cos^2(phi_i), cos(phi_i)*sin(phi_i);
|
||||
// cos(phi_i)*sin(phi_i), sin^2(phi_i)]
|
||||
// where phi_i is the bearing angle from viewpoint i to the target.
|
||||
let mut fim_00 = 0.0_f32;
|
||||
let mut fim_01 = 0.0_f32;
|
||||
let mut fim_11 = 0.0_f32;
|
||||
|
||||
for vp in viewpoints {
|
||||
let dx = target.0 - vp.x;
|
||||
let dy = target.1 - vp.y;
|
||||
let r = (dx * dx + dy * dy).sqrt().max(1e-6);
|
||||
let cos_phi = dx / r;
|
||||
let sin_phi = dy / r;
|
||||
let inv_var = 1.0 / (vp.noise_std * vp.noise_std).max(1e-10);
|
||||
|
||||
fim_00 += inv_var * cos_phi * cos_phi;
|
||||
fim_01 += inv_var * cos_phi * sin_phi;
|
||||
fim_11 += inv_var * sin_phi * sin_phi;
|
||||
}
|
||||
|
||||
// Invert the 2x2 FIM analytically: CRB = FIM^{-1}.
|
||||
let det = fim_00 * fim_11 - fim_01 * fim_01;
|
||||
if det.abs() < 1e-12 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let crb_x = fim_11 / det;
|
||||
let crb_y = fim_00 / det;
|
||||
let rmse = (crb_x + crb_y).sqrt();
|
||||
let gdop = (crb_x + crb_y).sqrt();
|
||||
|
||||
Some(CramerRaoBound {
|
||||
crb_x,
|
||||
crb_y,
|
||||
rmse_lower_bound: rmse,
|
||||
gdop,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute the CRB using the `ruvector-solver` Neumann series solver for
|
||||
/// larger arrays where the analytic 2x2 inversion is extended to include
|
||||
/// regularisation for ill-conditioned geometries.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `target`: target position `(x, y)` in metres.
|
||||
/// - `viewpoints`: sensor node positions with per-node noise levels.
|
||||
/// - `regularisation`: Tikhonov regularisation parameter (typically 1e-4).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `None` if fewer than 3 viewpoints or the solver fails.
|
||||
pub fn estimate_regularised(
|
||||
target: (f32, f32),
|
||||
viewpoints: &[ViewpointPosition],
|
||||
regularisation: f32,
|
||||
) -> Option<Self> {
|
||||
let n = viewpoints.len();
|
||||
if n < 3 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut fim_00 = regularisation;
|
||||
let mut fim_01 = 0.0_f32;
|
||||
let mut fim_11 = regularisation;
|
||||
|
||||
for vp in viewpoints {
|
||||
let dx = target.0 - vp.x;
|
||||
let dy = target.1 - vp.y;
|
||||
let r = (dx * dx + dy * dy).sqrt().max(1e-6);
|
||||
let cos_phi = dx / r;
|
||||
let sin_phi = dy / r;
|
||||
let inv_var = 1.0 / (vp.noise_std * vp.noise_std).max(1e-10);
|
||||
|
||||
fim_00 += inv_var * cos_phi * cos_phi;
|
||||
fim_01 += inv_var * cos_phi * sin_phi;
|
||||
fim_11 += inv_var * sin_phi * sin_phi;
|
||||
}
|
||||
|
||||
// Use Neumann solver for the regularised system.
|
||||
let ata = CsrMatrix::<f32>::from_coo(
|
||||
2,
|
||||
2,
|
||||
vec![
|
||||
(0, 0, fim_00),
|
||||
(0, 1, fim_01),
|
||||
(1, 0, fim_01),
|
||||
(1, 1, fim_11),
|
||||
],
|
||||
);
|
||||
|
||||
// Solve FIM * x = e_1 and FIM * x = e_2 to get the CRB diagonal.
|
||||
let solver = NeumannSolver::new(1e-6, 500);
|
||||
|
||||
let crb_x = solver
|
||||
.solve(&ata, &[1.0, 0.0])
|
||||
.ok()
|
||||
.map(|r| r.solution[0])?;
|
||||
let crb_y = solver
|
||||
.solve(&ata, &[0.0, 1.0])
|
||||
.ok()
|
||||
.map(|r| r.solution[1])?;
|
||||
|
||||
let rmse = (crb_x.abs() + crb_y.abs()).sqrt();
|
||||
|
||||
Some(CramerRaoBound {
|
||||
crb_x,
|
||||
crb_y,
|
||||
rmse_lower_bound: rmse,
|
||||
gdop: rmse,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn gdi_uniform_spacing_is_optimal() {
|
||||
// 4 viewpoints at 0, 90, 180, 270 degrees
|
||||
let azimuths = vec![0.0, std::f32::consts::FRAC_PI_2, std::f32::consts::PI, 3.0 * std::f32::consts::FRAC_PI_2];
|
||||
let ids = vec![0, 1, 2, 3];
|
||||
let gdi = GeometricDiversityIndex::compute(&azimuths, &ids).unwrap();
|
||||
// Minimum separation = PI/2 for each viewpoint, so GDI = PI/2
|
||||
let expected = std::f32::consts::FRAC_PI_2;
|
||||
assert!(
|
||||
(gdi.value - expected).abs() < 0.01,
|
||||
"uniform spacing GDI should be PI/2={expected:.3}, got {:.3}",
|
||||
gdi.value
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdi_clustered_viewpoints_have_low_value() {
|
||||
// 4 viewpoints clustered within 10 degrees
|
||||
let azimuths = vec![0.0, 0.05, 0.08, 0.12];
|
||||
let ids = vec![0, 1, 2, 3];
|
||||
let gdi = GeometricDiversityIndex::compute(&azimuths, &ids).unwrap();
|
||||
assert!(
|
||||
gdi.value < 0.15,
|
||||
"clustered viewpoints should have low GDI, got {:.3}",
|
||||
gdi.value
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdi_insufficient_viewpoints_returns_none() {
|
||||
assert!(GeometricDiversityIndex::compute(&[0.0], &[0]).is_none());
|
||||
assert!(GeometricDiversityIndex::compute(&[], &[]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdi_efficiency_is_bounded() {
|
||||
let azimuths = vec![0.0, 1.0, 2.0, 3.0];
|
||||
let ids = vec![0, 1, 2, 3];
|
||||
let gdi = GeometricDiversityIndex::compute(&azimuths, &ids).unwrap();
|
||||
assert!(gdi.efficiency() > 0.0 && gdi.efficiency() <= 1.0,
|
||||
"efficiency should be in (0, 1], got {}", gdi.efficiency());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdi_is_sufficient_for_uniform_layout() {
|
||||
let azimuths = vec![0.0, std::f32::consts::FRAC_PI_2, std::f32::consts::PI, 3.0 * std::f32::consts::FRAC_PI_2];
|
||||
let ids = vec![0, 1, 2, 3];
|
||||
let gdi = GeometricDiversityIndex::compute(&azimuths, &ids).unwrap();
|
||||
assert!(gdi.is_sufficient(), "uniform layout should be sufficient");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdi_worst_pair_is_closest() {
|
||||
// Viewpoints at 0, 0.1, PI, 1.5*PI
|
||||
let azimuths = vec![0.0, 0.1, std::f32::consts::PI, 1.5 * std::f32::consts::PI];
|
||||
let ids = vec![10, 20, 30, 40];
|
||||
let gdi = GeometricDiversityIndex::compute(&azimuths, &ids).unwrap();
|
||||
// Worst pair should be (10, 20) as they are only 0.1 rad apart
|
||||
assert!(
|
||||
(gdi.worst_pair == (10, 20)) || (gdi.worst_pair == (20, 10)),
|
||||
"worst pair should be nodes 10 and 20, got {:?}",
|
||||
gdi.worst_pair
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn angular_distance_wraps_correctly() {
|
||||
let d = angular_distance(0.1, 2.0 * std::f32::consts::PI - 0.1);
|
||||
assert!(
|
||||
(d - 0.2).abs() < 1e-4,
|
||||
"angular distance across 0/2PI boundary should be 0.2, got {d}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_viewpoints_all_identical_equals_one() {
|
||||
let azimuths = vec![0.0, 0.0, 0.0, 0.0];
|
||||
let sigma = std::f32::consts::PI / 6.0;
|
||||
let n_eff = compute_effective_viewpoints(&azimuths, sigma);
|
||||
assert!(
|
||||
(n_eff - 1.0).abs() < 0.1,
|
||||
"4 identical viewpoints should have n_eff ~ 1.0, got {n_eff}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crb_decreases_with_more_viewpoints() {
|
||||
let target = (0.0, 0.0);
|
||||
let vp3: Vec<ViewpointPosition> = (0..3)
|
||||
.map(|i| {
|
||||
let a = 2.0 * std::f32::consts::PI * i as f32 / 3.0;
|
||||
ViewpointPosition { x: 5.0 * a.cos(), y: 5.0 * a.sin(), noise_std: 0.1 }
|
||||
})
|
||||
.collect();
|
||||
let vp6: Vec<ViewpointPosition> = (0..6)
|
||||
.map(|i| {
|
||||
let a = 2.0 * std::f32::consts::PI * i as f32 / 6.0;
|
||||
ViewpointPosition { x: 5.0 * a.cos(), y: 5.0 * a.sin(), noise_std: 0.1 }
|
||||
})
|
||||
.collect();
|
||||
|
||||
let crb3 = CramerRaoBound::estimate(target, &vp3).unwrap();
|
||||
let crb6 = CramerRaoBound::estimate(target, &vp6).unwrap();
|
||||
assert!(
|
||||
crb6.rmse_lower_bound < crb3.rmse_lower_bound,
|
||||
"6 viewpoints should give lower CRB than 3: {:.4} vs {:.4}",
|
||||
crb6.rmse_lower_bound,
|
||||
crb3.rmse_lower_bound
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crb_too_few_viewpoints_returns_none() {
|
||||
let target = (0.0, 0.0);
|
||||
let vps = vec![
|
||||
ViewpointPosition { x: 1.0, y: 0.0, noise_std: 0.1 },
|
||||
ViewpointPosition { x: 0.0, y: 1.0, noise_std: 0.1 },
|
||||
];
|
||||
assert!(CramerRaoBound::estimate(target, &vps).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crb_regularised_returns_result() {
|
||||
let target = (0.0, 0.0);
|
||||
let vps: Vec<ViewpointPosition> = (0..4)
|
||||
.map(|i| {
|
||||
let a = 2.0 * std::f32::consts::PI * i as f32 / 4.0;
|
||||
ViewpointPosition { x: 3.0 * a.cos(), y: 3.0 * a.sin(), noise_std: 0.1 }
|
||||
})
|
||||
.collect();
|
||||
let crb = CramerRaoBound::estimate_regularised(target, &vps, 1e-4);
|
||||
// May return None if Neumann solver doesn't converge, but should not panic.
|
||||
if let Some(crb) = crb {
|
||||
assert!(crb.rmse_lower_bound >= 0.0, "RMSE bound must be non-negative");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
//! Cross-viewpoint embedding fusion for multistatic WiFi sensing (ADR-031).
|
||||
//!
|
||||
//! This module implements the RuView fusion pipeline that combines per-viewpoint
|
||||
//! AETHER embeddings into a single fused embedding using learned cross-viewpoint
|
||||
//! attention with geometric bias.
|
||||
//!
|
||||
//! # Submodules
|
||||
//!
|
||||
//! - [`attention`]: Cross-viewpoint scaled dot-product attention with geometric
|
||||
//! bias encoding angular separation and baseline distance between viewpoint pairs.
|
||||
//! - [`geometry`]: Geometric Diversity Index (GDI) computation and Cramer-Rao
|
||||
//! bound estimation for array geometry quality assessment.
|
||||
//! - [`coherence`]: Coherence gating that determines whether the environment is
|
||||
//! stable enough for a model update based on phase consistency.
|
||||
//! - [`fusion`]: `MultistaticArray` aggregate root that orchestrates the full
|
||||
//! fusion pipeline from per-viewpoint embeddings to a single fused output.
|
||||
|
||||
pub mod attention;
|
||||
pub mod coherence;
|
||||
pub mod fusion;
|
||||
pub mod geometry;
|
||||
|
||||
// Re-export primary types at the module root for ergonomic imports.
|
||||
pub use attention::{CrossViewpointAttention, GeometricBias};
|
||||
pub use coherence::{CoherenceGate, CoherenceState};
|
||||
pub use fusion::{FusedEmbedding, FusionConfig, MultistaticArray, ViewpointEmbedding};
|
||||
pub use geometry::{CramerRaoBound, GeometricDiversityIndex};
|
||||
@@ -41,7 +41,7 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||
clap = { workspace = true }
|
||||
|
||||
# Multi-BSSID WiFi scanning pipeline (ADR-022 Phase 3)
|
||||
wifi-densepose-wifiscan = { version = "0.2.0", path = "../wifi-densepose-wifiscan" }
|
||||
wifi-densepose-wifiscan = { version = "0.3.0", path = "../wifi-densepose-wifiscan" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.10"
|
||||
|
||||
@@ -32,8 +32,12 @@ ruvector-attn-mincut = { workspace = true }
|
||||
ruvector-attention = { workspace = true }
|
||||
ruvector-solver = { workspace = true }
|
||||
|
||||
# Midstreamer integrations (ADR-032a)
|
||||
midstreamer-temporal-compare = { workspace = true }
|
||||
midstreamer-attractor = { workspace = true }
|
||||
|
||||
# Internal
|
||||
wifi-densepose-core = { version = "0.2.0", path = "../wifi-densepose-core" }
|
||||
wifi-densepose-core = { version = "0.3.0", path = "../wifi-densepose-core" }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
|
||||
@@ -40,6 +40,7 @@ pub mod hampel;
|
||||
pub mod hardware_norm;
|
||||
pub mod motion;
|
||||
pub mod phase_sanitizer;
|
||||
pub mod ruvsense;
|
||||
pub mod spectrogram;
|
||||
pub mod subcarrier_selection;
|
||||
|
||||
|
||||
@@ -0,0 +1,586 @@
|
||||
//! Adversarial detection: physically impossible signal identification.
|
||||
//!
|
||||
//! Detects spoofed or injected WiFi signals by checking multi-link
|
||||
//! consistency, field model constraint violations, and physical
|
||||
//! plausibility. A single-link injection cannot fool a multistatic
|
||||
//! mesh because it would violate geometric constraints across links.
|
||||
//!
|
||||
//! # Checks
|
||||
//! 1. **Multi-link consistency**: A real body perturbs all links that
|
||||
//! traverse its location. An injection affects only the targeted link.
|
||||
//! 2. **Field model constraints**: Perturbation must be consistent with
|
||||
//! the room's eigenmode structure.
|
||||
//! 3. **Temporal continuity**: Real movement is smooth; injections cause
|
||||
//! discontinuities in embedding space.
|
||||
//! 4. **Energy conservation**: Total perturbation energy across links
|
||||
//! must be consistent with the number and size of bodies present.
|
||||
//!
|
||||
//! # References
|
||||
//! - ADR-030 Tier 7: Adversarial Detection
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from adversarial detection.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AdversarialError {
|
||||
/// Insufficient links for multi-link consistency check.
|
||||
#[error("Insufficient links: need >= {needed}, got {got}")]
|
||||
InsufficientLinks { needed: usize, got: usize },
|
||||
|
||||
/// Dimension mismatch.
|
||||
#[error("Dimension mismatch: expected {expected}, got {got}")]
|
||||
DimensionMismatch { expected: usize, got: usize },
|
||||
|
||||
/// No baseline available for constraint checking.
|
||||
#[error("No baseline available — calibrate field model first")]
|
||||
NoBaseline,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for adversarial detection.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdversarialConfig {
|
||||
/// Number of links in the mesh.
|
||||
pub n_links: usize,
|
||||
/// Minimum links for multi-link consistency (default 4).
|
||||
pub min_links: usize,
|
||||
/// Consistency threshold: fraction of links that must agree (0.0-1.0).
|
||||
pub consistency_threshold: f64,
|
||||
/// Maximum allowed energy ratio between any single link and total.
|
||||
pub max_single_link_energy_ratio: f64,
|
||||
/// Maximum allowed temporal discontinuity in embedding space.
|
||||
pub max_temporal_discontinuity: f64,
|
||||
/// Maximum allowed perturbation energy per body.
|
||||
pub max_energy_per_body: f64,
|
||||
}
|
||||
|
||||
impl Default for AdversarialConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
n_links: 12,
|
||||
min_links: 4,
|
||||
consistency_threshold: 0.6,
|
||||
max_single_link_energy_ratio: 0.5,
|
||||
max_temporal_discontinuity: 5.0,
|
||||
max_energy_per_body: 100.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Detection results
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Type of adversarial anomaly detected.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AnomalyType {
|
||||
/// Single link shows perturbation inconsistent with other links.
|
||||
SingleLinkInjection,
|
||||
/// Perturbation violates field model eigenmode structure.
|
||||
FieldModelViolation,
|
||||
/// Sudden discontinuity in embedding trajectory.
|
||||
TemporalDiscontinuity,
|
||||
/// Total perturbation energy inconsistent with occupancy.
|
||||
EnergyViolation,
|
||||
/// Multiple anomaly types detected simultaneously.
|
||||
MultipleViolations,
|
||||
}
|
||||
|
||||
impl AnomalyType {
|
||||
/// Human-readable name.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
AnomalyType::SingleLinkInjection => "single_link_injection",
|
||||
AnomalyType::FieldModelViolation => "field_model_violation",
|
||||
AnomalyType::TemporalDiscontinuity => "temporal_discontinuity",
|
||||
AnomalyType::EnergyViolation => "energy_violation",
|
||||
AnomalyType::MultipleViolations => "multiple_violations",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of adversarial detection on one frame.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdversarialResult {
|
||||
/// Whether any anomaly was detected.
|
||||
pub anomaly_detected: bool,
|
||||
/// Type of anomaly (if detected).
|
||||
pub anomaly_type: Option<AnomalyType>,
|
||||
/// Anomaly score (0.0 = clean, 1.0 = definitely adversarial).
|
||||
pub anomaly_score: f64,
|
||||
/// Per-check results.
|
||||
pub checks: CheckResults,
|
||||
/// Affected link indices (if single-link injection).
|
||||
pub affected_links: Vec<usize>,
|
||||
/// Timestamp (microseconds).
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
/// Results of individual checks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CheckResults {
|
||||
/// Multi-link consistency score (0.0 = inconsistent, 1.0 = fully consistent).
|
||||
pub consistency_score: f64,
|
||||
/// Field model residual score (lower = more consistent with modes).
|
||||
pub field_model_residual: f64,
|
||||
/// Temporal continuity score (lower = smoother).
|
||||
pub temporal_continuity: f64,
|
||||
/// Energy conservation score (closer to 1.0 = consistent).
|
||||
pub energy_ratio: f64,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Adversarial detector
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Adversarial signal detector for the multistatic mesh.
|
||||
///
|
||||
/// Checks each frame for physical plausibility across multiple
|
||||
/// independent criteria. A spoofed signal that passes one check
|
||||
/// is unlikely to pass all of them.
|
||||
#[derive(Debug)]
|
||||
pub struct AdversarialDetector {
|
||||
config: AdversarialConfig,
|
||||
/// Previous frame's per-link energies (for temporal continuity).
|
||||
prev_energies: Option<Vec<f64>>,
|
||||
/// Previous frame's total energy.
|
||||
prev_total_energy: Option<f64>,
|
||||
/// Total frames processed.
|
||||
total_frames: u64,
|
||||
/// Total anomalies detected.
|
||||
anomaly_count: u64,
|
||||
}
|
||||
|
||||
impl AdversarialDetector {
|
||||
/// Create a new adversarial detector.
|
||||
pub fn new(config: AdversarialConfig) -> Result<Self, AdversarialError> {
|
||||
if config.n_links < config.min_links {
|
||||
return Err(AdversarialError::InsufficientLinks {
|
||||
needed: config.min_links,
|
||||
got: config.n_links,
|
||||
});
|
||||
}
|
||||
Ok(Self {
|
||||
config,
|
||||
prev_energies: None,
|
||||
prev_total_energy: None,
|
||||
total_frames: 0,
|
||||
anomaly_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check a frame for adversarial anomalies.
|
||||
///
|
||||
/// `link_energies`: per-link perturbation energy (from field model).
|
||||
/// `n_bodies`: estimated number of bodies present.
|
||||
/// `timestamp_us`: frame timestamp.
|
||||
pub fn check(
|
||||
&mut self,
|
||||
link_energies: &[f64],
|
||||
n_bodies: usize,
|
||||
timestamp_us: u64,
|
||||
) -> Result<AdversarialResult, AdversarialError> {
|
||||
if link_energies.len() != self.config.n_links {
|
||||
return Err(AdversarialError::DimensionMismatch {
|
||||
expected: self.config.n_links,
|
||||
got: link_energies.len(),
|
||||
});
|
||||
}
|
||||
|
||||
self.total_frames += 1;
|
||||
|
||||
let total_energy: f64 = link_energies.iter().sum();
|
||||
|
||||
// Check 1: Multi-link consistency
|
||||
let consistency = self.check_consistency(link_energies, total_energy);
|
||||
|
||||
// Check 2: Field model residual (simplified — check energy distribution)
|
||||
let field_residual = self.check_field_model(link_energies, total_energy);
|
||||
|
||||
// Check 3: Temporal continuity
|
||||
let temporal = self.check_temporal(link_energies, total_energy);
|
||||
|
||||
// Check 4: Energy conservation
|
||||
let energy_ratio = self.check_energy(total_energy, n_bodies);
|
||||
|
||||
// Store for next frame
|
||||
self.prev_energies = Some(link_energies.to_vec());
|
||||
self.prev_total_energy = Some(total_energy);
|
||||
|
||||
let checks = CheckResults {
|
||||
consistency_score: consistency,
|
||||
field_model_residual: field_residual,
|
||||
temporal_continuity: temporal,
|
||||
energy_ratio,
|
||||
};
|
||||
|
||||
// Aggregate anomaly score
|
||||
let mut violations = Vec::new();
|
||||
|
||||
if consistency < self.config.consistency_threshold {
|
||||
violations.push(AnomalyType::SingleLinkInjection);
|
||||
}
|
||||
if field_residual > 0.8 {
|
||||
violations.push(AnomalyType::FieldModelViolation);
|
||||
}
|
||||
if temporal > self.config.max_temporal_discontinuity {
|
||||
violations.push(AnomalyType::TemporalDiscontinuity);
|
||||
}
|
||||
if energy_ratio > 2.0 || (n_bodies > 0 && energy_ratio < 0.1) {
|
||||
violations.push(AnomalyType::EnergyViolation);
|
||||
}
|
||||
|
||||
let anomaly_detected = !violations.is_empty();
|
||||
let anomaly_type = match violations.len() {
|
||||
0 => None,
|
||||
1 => Some(violations[0]),
|
||||
_ => Some(AnomalyType::MultipleViolations),
|
||||
};
|
||||
|
||||
// Score: weighted combination
|
||||
let anomaly_score = ((1.0 - consistency) * 0.4
|
||||
+ field_residual * 0.2
|
||||
+ (temporal / self.config.max_temporal_discontinuity).min(1.0) * 0.2
|
||||
+ ((energy_ratio - 1.0).abs() / 2.0).min(1.0) * 0.2)
|
||||
.clamp(0.0, 1.0);
|
||||
|
||||
// Find affected links (highest single-link energy ratio)
|
||||
let affected_links = if anomaly_detected {
|
||||
self.find_anomalous_links(link_energies, total_energy)
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
if anomaly_detected {
|
||||
self.anomaly_count += 1;
|
||||
}
|
||||
|
||||
Ok(AdversarialResult {
|
||||
anomaly_detected,
|
||||
anomaly_type,
|
||||
anomaly_score,
|
||||
checks,
|
||||
affected_links,
|
||||
timestamp_us,
|
||||
})
|
||||
}
|
||||
|
||||
/// Multi-link consistency: what fraction of links have correlated energy?
|
||||
///
|
||||
/// A real body perturbs many links. An injection affects few.
|
||||
fn check_consistency(&self, energies: &[f64], total: f64) -> f64 {
|
||||
if total < 1e-15 {
|
||||
return 1.0; // No perturbation = consistent (empty room)
|
||||
}
|
||||
|
||||
let mean = total / energies.len() as f64;
|
||||
let threshold = mean * 0.1; // link must have at least 10% of mean energy
|
||||
|
||||
let active_count = energies.iter().filter(|&&e| e > threshold).count();
|
||||
active_count as f64 / energies.len() as f64
|
||||
}
|
||||
|
||||
/// Field model check: is energy distribution consistent with physical propagation?
|
||||
///
|
||||
/// In a real scenario, energy should be distributed across links
|
||||
/// based on geometry. A concentrated injection scores high residual.
|
||||
fn check_field_model(&self, energies: &[f64], total: f64) -> f64 {
|
||||
if total < 1e-15 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute Gini coefficient of energy distribution
|
||||
// Gini = 0 → perfectly uniform, Gini = 1 → all in one link
|
||||
let n = energies.len() as f64;
|
||||
let mut sorted: Vec<f64> = energies.to_vec();
|
||||
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let numerator: f64 = sorted
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &x)| (2.0 * (i + 1) as f64 - n - 1.0) * x)
|
||||
.sum();
|
||||
|
||||
let gini = numerator / (n * total);
|
||||
gini.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Temporal continuity: how much did per-link energies change from previous frame?
|
||||
fn check_temporal(&self, energies: &[f64], _total: f64) -> f64 {
|
||||
match &self.prev_energies {
|
||||
None => 0.0, // First frame, no temporal check
|
||||
Some(prev) => {
|
||||
let diff_energy: f64 = energies
|
||||
.iter()
|
||||
.zip(prev.iter())
|
||||
.map(|(&a, &b)| (a - b) * (a - b))
|
||||
.sum::<f64>()
|
||||
.sqrt();
|
||||
diff_energy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Energy conservation: is total energy consistent with body count?
|
||||
fn check_energy(&self, total_energy: f64, n_bodies: usize) -> f64 {
|
||||
if n_bodies == 0 {
|
||||
// No bodies: any energy is suspicious
|
||||
return if total_energy > 1e-10 {
|
||||
total_energy
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
}
|
||||
let expected = n_bodies as f64 * self.config.max_energy_per_body;
|
||||
if expected < 1e-15 {
|
||||
return 0.0;
|
||||
}
|
||||
total_energy / expected
|
||||
}
|
||||
|
||||
/// Find links that are anomalously high relative to the mean.
|
||||
fn find_anomalous_links(&self, energies: &[f64], total: f64) -> Vec<usize> {
|
||||
if total < 1e-15 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
energies
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, &e)| e / total > self.config.max_single_link_energy_ratio)
|
||||
.map(|(i, _)| i)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Total frames processed.
|
||||
pub fn total_frames(&self) -> u64 {
|
||||
self.total_frames
|
||||
}
|
||||
|
||||
/// Total anomalies detected.
|
||||
pub fn anomaly_count(&self) -> u64 {
|
||||
self.anomaly_count
|
||||
}
|
||||
|
||||
/// Anomaly rate (anomalies / total frames).
|
||||
pub fn anomaly_rate(&self) -> f64 {
|
||||
if self.total_frames == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.anomaly_count as f64 / self.total_frames as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset detector state.
|
||||
pub fn reset(&mut self) {
|
||||
self.prev_energies = None;
|
||||
self.prev_total_energy = None;
|
||||
self.total_frames = 0;
|
||||
self.anomaly_count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn default_config() -> AdversarialConfig {
|
||||
AdversarialConfig {
|
||||
n_links: 6,
|
||||
min_links: 4,
|
||||
consistency_threshold: 0.6,
|
||||
max_single_link_energy_ratio: 0.5,
|
||||
max_temporal_discontinuity: 5.0,
|
||||
max_energy_per_body: 10.0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detector_creation() {
|
||||
let det = AdversarialDetector::new(default_config()).unwrap();
|
||||
assert_eq!(det.total_frames(), 0);
|
||||
assert_eq!(det.anomaly_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insufficient_links() {
|
||||
let config = AdversarialConfig {
|
||||
n_links: 2,
|
||||
min_links: 4,
|
||||
..default_config()
|
||||
};
|
||||
assert!(matches!(
|
||||
AdversarialDetector::new(config),
|
||||
Err(AdversarialError::InsufficientLinks { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clean_frame_no_anomaly() {
|
||||
let mut det = AdversarialDetector::new(default_config()).unwrap();
|
||||
|
||||
// Uniform energy across all links (real body)
|
||||
let energies = vec![1.0, 1.1, 0.9, 1.0, 1.05, 0.95];
|
||||
let result = det.check(&energies, 1, 0).unwrap();
|
||||
|
||||
assert!(
|
||||
!result.anomaly_detected,
|
||||
"Uniform energy should not trigger anomaly"
|
||||
);
|
||||
assert!(result.anomaly_score < 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_link_injection_detected() {
|
||||
let mut det = AdversarialDetector::new(default_config()).unwrap();
|
||||
|
||||
// All energy on one link (injection)
|
||||
let energies = vec![10.0, 0.0, 0.0, 0.0, 0.0, 0.0];
|
||||
let result = det.check(&energies, 0, 0).unwrap();
|
||||
|
||||
assert!(
|
||||
result.anomaly_detected,
|
||||
"Single-link injection should be detected"
|
||||
);
|
||||
assert!(result.affected_links.contains(&0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_room_no_anomaly() {
|
||||
let mut det = AdversarialDetector::new(default_config()).unwrap();
|
||||
|
||||
let energies = vec![0.0; 6];
|
||||
let result = det.check(&energies, 0, 0).unwrap();
|
||||
|
||||
assert!(!result.anomaly_detected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_discontinuity() {
|
||||
let mut det = AdversarialDetector::new(AdversarialConfig {
|
||||
max_temporal_discontinuity: 1.0, // strict
|
||||
..default_config()
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Frame 1: low energy
|
||||
let energies1 = vec![0.1; 6];
|
||||
det.check(&energies1, 0, 0).unwrap();
|
||||
|
||||
// Frame 2: sudden massive energy (discontinuity)
|
||||
let energies2 = vec![100.0; 6];
|
||||
let result = det.check(&energies2, 0, 50_000).unwrap();
|
||||
|
||||
assert!(
|
||||
result.anomaly_detected,
|
||||
"Temporal discontinuity should be detected"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_energy_violation_too_high() {
|
||||
let mut det = AdversarialDetector::new(default_config()).unwrap();
|
||||
|
||||
// Way more energy than 1 body should produce
|
||||
let energies = vec![100.0; 6]; // total = 600, max_per_body = 10
|
||||
let result = det.check(&energies, 1, 0).unwrap();
|
||||
|
||||
assert!(
|
||||
result.anomaly_detected,
|
||||
"Excessive energy should trigger anomaly"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dimension_mismatch() {
|
||||
let mut det = AdversarialDetector::new(default_config()).unwrap();
|
||||
let result = det.check(&[1.0, 2.0], 0, 0);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(AdversarialError::DimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anomaly_rate() {
|
||||
let mut det = AdversarialDetector::new(default_config()).unwrap();
|
||||
|
||||
// 2 clean frames
|
||||
det.check(&vec![1.0; 6], 1, 0).unwrap();
|
||||
det.check(&vec![1.0; 6], 1, 50_000).unwrap();
|
||||
|
||||
// 1 anomalous frame
|
||||
det.check(&vec![10.0, 0.0, 0.0, 0.0, 0.0, 0.0], 0, 100_000)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(det.total_frames(), 3);
|
||||
assert!(det.anomaly_count() >= 1);
|
||||
assert!(det.anomaly_rate() > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset() {
|
||||
let mut det = AdversarialDetector::new(default_config()).unwrap();
|
||||
det.check(&vec![1.0; 6], 1, 0).unwrap();
|
||||
det.reset();
|
||||
|
||||
assert_eq!(det.total_frames(), 0);
|
||||
assert_eq!(det.anomaly_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anomaly_type_names() {
|
||||
assert_eq!(
|
||||
AnomalyType::SingleLinkInjection.name(),
|
||||
"single_link_injection"
|
||||
);
|
||||
assert_eq!(
|
||||
AnomalyType::FieldModelViolation.name(),
|
||||
"field_model_violation"
|
||||
);
|
||||
assert_eq!(
|
||||
AnomalyType::TemporalDiscontinuity.name(),
|
||||
"temporal_discontinuity"
|
||||
);
|
||||
assert_eq!(AnomalyType::EnergyViolation.name(), "energy_violation");
|
||||
assert_eq!(
|
||||
AnomalyType::MultipleViolations.name(),
|
||||
"multiple_violations"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gini_coefficient_uniform() {
|
||||
let det = AdversarialDetector::new(default_config()).unwrap();
|
||||
let energies = vec![1.0; 6];
|
||||
let total = 6.0;
|
||||
let gini = det.check_field_model(&energies, total);
|
||||
assert!(
|
||||
gini < 0.1,
|
||||
"Uniform distribution should have low Gini: {}",
|
||||
gini
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gini_coefficient_concentrated() {
|
||||
let det = AdversarialDetector::new(default_config()).unwrap();
|
||||
let energies = vec![6.0, 0.0, 0.0, 0.0, 0.0, 0.0];
|
||||
let total = 6.0;
|
||||
let gini = det.check_field_model(&energies, total);
|
||||
assert!(
|
||||
gini > 0.5,
|
||||
"Concentrated distribution should have high Gini: {}",
|
||||
gini
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,573 @@
|
||||
//! Enhanced longitudinal drift detection using `midstreamer-attractor`.
|
||||
//!
|
||||
//! Extends the Welford-statistics drift detection from `longitudinal.rs`
|
||||
//! with phase-space attractor analysis provided by the
|
||||
//! `midstreamer-attractor` crate (ADR-032a Section 6.4).
|
||||
//!
|
||||
//! # Improvements over base drift detection
|
||||
//!
|
||||
//! - **Phase-space embedding**: Detects regime changes invisible to simple
|
||||
//! z-score analysis (e.g., gait transitioning from limit cycle to
|
||||
//! strange attractor = developing instability)
|
||||
//! - **Lyapunov exponent**: Quantifies sensitivity to initial conditions,
|
||||
//! catching chaotic transitions in breathing patterns
|
||||
//! - **Attractor classification**: Automatically classifies biophysical
|
||||
//! time series as point attractor (stable), limit cycle (periodic),
|
||||
//! or strange attractor (chaotic)
|
||||
//!
|
||||
//! # References
|
||||
//! - ADR-030 Tier 4: Longitudinal Biomechanics Drift
|
||||
//! - ADR-032a Section 6.4: midstreamer-attractor integration
|
||||
//! - Takens, F. (1981). "Detecting strange attractors in turbulence."
|
||||
|
||||
use midstreamer_attractor::{
|
||||
AttractorAnalyzer, AttractorType, PhasePoint,
|
||||
};
|
||||
|
||||
use super::longitudinal::DriftMetric;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for attractor-based drift analysis.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttractorDriftConfig {
|
||||
/// Embedding dimension for phase-space reconstruction (Takens' theorem).
|
||||
/// Default: 3 (sufficient for most biophysical signals).
|
||||
pub embedding_dim: usize,
|
||||
/// Time delay for phase-space embedding (in observation steps).
|
||||
/// Default: 1 (consecutive observations).
|
||||
pub time_delay: usize,
|
||||
/// Minimum observations needed before analysis is meaningful.
|
||||
/// Default: 30 (about 1 month of daily observations).
|
||||
pub min_observations: usize,
|
||||
/// Lyapunov exponent threshold for chaos detection.
|
||||
/// Default: 0.01.
|
||||
pub lyapunov_threshold: f64,
|
||||
/// Maximum trajectory length for the analyzer.
|
||||
/// Default: 10000.
|
||||
pub max_trajectory_length: usize,
|
||||
}
|
||||
|
||||
impl Default for AttractorDriftConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
embedding_dim: 3,
|
||||
time_delay: 1,
|
||||
min_observations: 30,
|
||||
lyapunov_threshold: 0.01,
|
||||
max_trajectory_length: 10000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from attractor-based drift analysis.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AttractorDriftError {
|
||||
/// Not enough observations for phase-space embedding.
|
||||
#[error("Insufficient observations: need >= {needed}, have {have}")]
|
||||
InsufficientData { needed: usize, have: usize },
|
||||
|
||||
/// The metric has no observations recorded.
|
||||
#[error("No observations for metric: {0}")]
|
||||
NoObservations(String),
|
||||
|
||||
/// Phase-space embedding dimension is invalid.
|
||||
#[error("Invalid embedding dimension: {dim} (must be >= 2)")]
|
||||
InvalidEmbeddingDim { dim: usize },
|
||||
|
||||
/// Attractor analysis library error.
|
||||
#[error("Attractor analysis failed: {0}")]
|
||||
AnalysisFailed(String),
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Attractor classification result
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Classification of a biophysical time series attractor.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum BiophysicalAttractor {
|
||||
/// Point attractor: metric has converged to a stable value.
|
||||
Stable { center: f64 },
|
||||
/// Limit cycle: metric oscillates periodically.
|
||||
Periodic { lyapunov_max: f64 },
|
||||
/// Strange attractor: metric exhibits chaotic dynamics.
|
||||
Chaotic { lyapunov_exponent: f64 },
|
||||
/// Transitioning between attractor types.
|
||||
Transitioning {
|
||||
from: Box<BiophysicalAttractor>,
|
||||
to: Box<BiophysicalAttractor>,
|
||||
},
|
||||
/// Insufficient data to classify.
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl BiophysicalAttractor {
|
||||
/// Whether this attractor type warrants monitoring attention.
|
||||
pub fn is_concerning(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
BiophysicalAttractor::Chaotic { .. } | BiophysicalAttractor::Transitioning { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Human-readable label for reporting.
|
||||
pub fn label(&self) -> &'static str {
|
||||
match self {
|
||||
BiophysicalAttractor::Stable { .. } => "stable",
|
||||
BiophysicalAttractor::Periodic { .. } => "periodic",
|
||||
BiophysicalAttractor::Chaotic { .. } => "chaotic",
|
||||
BiophysicalAttractor::Transitioning { .. } => "transitioning",
|
||||
BiophysicalAttractor::Unknown => "unknown",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Attractor drift report
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Report from attractor-based drift analysis.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttractorDriftReport {
|
||||
/// Person this report pertains to.
|
||||
pub person_id: u64,
|
||||
/// Which biophysical metric was analyzed.
|
||||
pub metric: DriftMetric,
|
||||
/// Classified attractor type.
|
||||
pub attractor: BiophysicalAttractor,
|
||||
/// Whether the attractor type has changed from the previous analysis.
|
||||
pub regime_changed: bool,
|
||||
/// Number of observations used in this analysis.
|
||||
pub observation_count: usize,
|
||||
/// Timestamp of the analysis (microseconds).
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-metric observation buffer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Time series buffer for a single biophysical metric.
|
||||
#[derive(Debug, Clone)]
|
||||
struct MetricBuffer {
|
||||
/// Metric type.
|
||||
metric: DriftMetric,
|
||||
/// Observed values (most recent at the end).
|
||||
values: Vec<f64>,
|
||||
/// Maximum buffer size.
|
||||
max_size: usize,
|
||||
/// Last classified attractor label.
|
||||
last_label: String,
|
||||
}
|
||||
|
||||
impl MetricBuffer {
|
||||
/// Create a new buffer.
|
||||
fn new(metric: DriftMetric, max_size: usize) -> Self {
|
||||
Self {
|
||||
metric,
|
||||
values: Vec::new(),
|
||||
max_size,
|
||||
last_label: "unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an observation.
|
||||
fn push(&mut self, value: f64) {
|
||||
if self.values.len() >= self.max_size {
|
||||
self.values.remove(0);
|
||||
}
|
||||
self.values.push(value);
|
||||
}
|
||||
|
||||
/// Number of observations.
|
||||
fn count(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Attractor drift analyzer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Attractor-based drift analyzer for longitudinal biophysical monitoring.
|
||||
///
|
||||
/// Uses phase-space reconstruction (Takens' embedding theorem) and
|
||||
/// `midstreamer-attractor` to classify the dynamical regime of each
|
||||
/// biophysical metric. Detects regime changes that precede simple
|
||||
/// metric drift.
|
||||
pub struct AttractorDriftAnalyzer {
|
||||
/// Configuration.
|
||||
config: AttractorDriftConfig,
|
||||
/// Person ID being monitored.
|
||||
person_id: u64,
|
||||
/// Per-metric observation buffers.
|
||||
buffers: Vec<MetricBuffer>,
|
||||
/// Total analyses performed.
|
||||
analysis_count: u64,
|
||||
}
|
||||
|
||||
// Manual Debug since AttractorAnalyzer does not derive Debug
|
||||
impl std::fmt::Debug for AttractorDriftAnalyzer {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AttractorDriftAnalyzer")
|
||||
.field("person_id", &self.person_id)
|
||||
.field("analysis_count", &self.analysis_count)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AttractorDriftAnalyzer {
|
||||
/// Create a new attractor drift analyzer for a person.
|
||||
pub fn new(
|
||||
person_id: u64,
|
||||
config: AttractorDriftConfig,
|
||||
) -> Result<Self, AttractorDriftError> {
|
||||
if config.embedding_dim < 2 {
|
||||
return Err(AttractorDriftError::InvalidEmbeddingDim {
|
||||
dim: config.embedding_dim,
|
||||
});
|
||||
}
|
||||
|
||||
let buffers = DriftMetric::all()
|
||||
.iter()
|
||||
.map(|&m| MetricBuffer::new(m, 365)) // 1 year of daily observations
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
person_id,
|
||||
buffers,
|
||||
analysis_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Add an observation for a specific metric.
|
||||
pub fn add_observation(&mut self, metric: DriftMetric, value: f64) {
|
||||
if let Some(buf) = self.buffers.iter_mut().find(|b| b.metric == metric) {
|
||||
buf.push(value);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform attractor analysis on a specific metric.
|
||||
///
|
||||
/// Reconstructs the phase space using Takens' embedding and
|
||||
/// classifies the attractor type using `midstreamer-attractor`.
|
||||
pub fn analyze(
|
||||
&mut self,
|
||||
metric: DriftMetric,
|
||||
timestamp_us: u64,
|
||||
) -> Result<AttractorDriftReport, AttractorDriftError> {
|
||||
let buf_idx = self
|
||||
.buffers
|
||||
.iter()
|
||||
.position(|b| b.metric == metric)
|
||||
.ok_or_else(|| AttractorDriftError::NoObservations(metric.name().into()))?;
|
||||
|
||||
let count = self.buffers[buf_idx].count();
|
||||
let min_needed = self.config.min_observations;
|
||||
if count < min_needed {
|
||||
return Err(AttractorDriftError::InsufficientData {
|
||||
needed: min_needed,
|
||||
have: count,
|
||||
});
|
||||
}
|
||||
|
||||
// Build phase-space trajectory using Takens' embedding
|
||||
// and feed into a fresh AttractorAnalyzer
|
||||
let dim = self.config.embedding_dim;
|
||||
let delay = self.config.time_delay;
|
||||
let values = &self.buffers[buf_idx].values;
|
||||
let n_points = values.len().saturating_sub((dim - 1) * delay);
|
||||
|
||||
let mut analyzer = AttractorAnalyzer::new(dim, self.config.max_trajectory_length);
|
||||
|
||||
for i in 0..n_points {
|
||||
let coords: Vec<f64> = (0..dim).map(|d| values[i + d * delay]).collect();
|
||||
let point = PhasePoint::new(coords, i as u64);
|
||||
let _ = analyzer.add_point(point);
|
||||
}
|
||||
|
||||
// Analyze the trajectory
|
||||
let attractor = match analyzer.analyze() {
|
||||
Ok(info) => {
|
||||
let max_lyap = info
|
||||
.max_lyapunov_exponent()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
match info.attractor_type {
|
||||
AttractorType::PointAttractor => {
|
||||
// Compute center as mean of last few values
|
||||
let recent = &values[values.len().saturating_sub(10)..];
|
||||
let center = recent.iter().sum::<f64>() / recent.len() as f64;
|
||||
BiophysicalAttractor::Stable { center }
|
||||
}
|
||||
AttractorType::LimitCycle => BiophysicalAttractor::Periodic {
|
||||
lyapunov_max: max_lyap,
|
||||
},
|
||||
AttractorType::StrangeAttractor => BiophysicalAttractor::Chaotic {
|
||||
lyapunov_exponent: max_lyap,
|
||||
},
|
||||
_ => BiophysicalAttractor::Unknown,
|
||||
}
|
||||
}
|
||||
Err(_) => BiophysicalAttractor::Unknown,
|
||||
};
|
||||
|
||||
// Check for regime change
|
||||
let label = attractor.label().to_string();
|
||||
let regime_changed = label != self.buffers[buf_idx].last_label;
|
||||
self.buffers[buf_idx].last_label = label;
|
||||
|
||||
self.analysis_count += 1;
|
||||
|
||||
Ok(AttractorDriftReport {
|
||||
person_id: self.person_id,
|
||||
metric,
|
||||
attractor,
|
||||
regime_changed,
|
||||
observation_count: count,
|
||||
timestamp_us,
|
||||
})
|
||||
}
|
||||
|
||||
/// Number of observations for a specific metric.
|
||||
pub fn observation_count(&self, metric: DriftMetric) -> usize {
|
||||
self.buffers
|
||||
.iter()
|
||||
.find(|b| b.metric == metric)
|
||||
.map_or(0, |b| b.count())
|
||||
}
|
||||
|
||||
/// Total analyses performed.
|
||||
pub fn analysis_count(&self) -> u64 {
|
||||
self.analysis_count
|
||||
}
|
||||
|
||||
/// Person ID being monitored.
|
||||
pub fn person_id(&self) -> u64 {
|
||||
self.person_id
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn default_analyzer() -> AttractorDriftAnalyzer {
|
||||
AttractorDriftAnalyzer::new(42, AttractorDriftConfig::default()).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analyzer_creation() {
|
||||
let a = default_analyzer();
|
||||
assert_eq!(a.person_id(), 42);
|
||||
assert_eq!(a.analysis_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analyzer_invalid_embedding_dim() {
|
||||
let config = AttractorDriftConfig {
|
||||
embedding_dim: 1,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(matches!(
|
||||
AttractorDriftAnalyzer::new(1, config),
|
||||
Err(AttractorDriftError::InvalidEmbeddingDim { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_observation() {
|
||||
let mut a = default_analyzer();
|
||||
a.add_observation(DriftMetric::GaitSymmetry, 0.1);
|
||||
a.add_observation(DriftMetric::GaitSymmetry, 0.11);
|
||||
assert_eq!(a.observation_count(DriftMetric::GaitSymmetry), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analyze_insufficient_data() {
|
||||
let mut a = default_analyzer();
|
||||
for i in 0..10 {
|
||||
a.add_observation(DriftMetric::GaitSymmetry, 0.1 + i as f64 * 0.001);
|
||||
}
|
||||
let result = a.analyze(DriftMetric::GaitSymmetry, 0);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(AttractorDriftError::InsufficientData { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analyze_stable_signal() {
|
||||
let mut a = AttractorDriftAnalyzer::new(
|
||||
1,
|
||||
AttractorDriftConfig {
|
||||
min_observations: 10,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Stable signal: constant with tiny noise
|
||||
for i in 0..150 {
|
||||
let noise = 0.001 * (i as f64 % 3.0 - 1.0);
|
||||
a.add_observation(DriftMetric::GaitSymmetry, 0.1 + noise);
|
||||
}
|
||||
|
||||
let report = a.analyze(DriftMetric::GaitSymmetry, 1000).unwrap();
|
||||
assert_eq!(report.person_id, 1);
|
||||
assert_eq!(report.metric, DriftMetric::GaitSymmetry);
|
||||
assert_eq!(report.observation_count, 150);
|
||||
assert_eq!(a.analysis_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analyze_periodic_signal() {
|
||||
let mut a = AttractorDriftAnalyzer::new(
|
||||
2,
|
||||
AttractorDriftConfig {
|
||||
min_observations: 10,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Periodic signal: sinusoidal with enough points for analyzer
|
||||
for i in 0..200 {
|
||||
let value = 0.5 + 0.3 * (i as f64 * std::f64::consts::PI / 7.0).sin();
|
||||
a.add_observation(DriftMetric::BreathingRegularity, value);
|
||||
}
|
||||
|
||||
let report = a.analyze(DriftMetric::BreathingRegularity, 2000).unwrap();
|
||||
assert_eq!(report.metric, DriftMetric::BreathingRegularity);
|
||||
assert!(!report.attractor.label().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regime_change_detection() {
|
||||
let mut a = AttractorDriftAnalyzer::new(
|
||||
3,
|
||||
AttractorDriftConfig {
|
||||
min_observations: 10,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Phase 1: stable signal (enough for analyzer: >= 100 points)
|
||||
for i in 0..150 {
|
||||
let noise = 0.001 * (i as f64 % 3.0 - 1.0);
|
||||
a.add_observation(DriftMetric::StabilityIndex, 0.9 + noise);
|
||||
}
|
||||
let _report1 = a.analyze(DriftMetric::StabilityIndex, 1000).unwrap();
|
||||
|
||||
// Phase 2: add chaotic-like signal
|
||||
for i in 150..300 {
|
||||
let value = 0.5 + 0.4 * ((i as f64 * 1.7).sin() * (i as f64 * 0.3).cos());
|
||||
a.add_observation(DriftMetric::StabilityIndex, value);
|
||||
}
|
||||
let _report2 = a.analyze(DriftMetric::StabilityIndex, 2000).unwrap();
|
||||
assert!(a.analysis_count() >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_biophysical_attractor_labels() {
|
||||
assert_eq!(
|
||||
BiophysicalAttractor::Stable { center: 0.1 }.label(),
|
||||
"stable"
|
||||
);
|
||||
assert_eq!(
|
||||
BiophysicalAttractor::Periodic { lyapunov_max: 0.0 }.label(),
|
||||
"periodic"
|
||||
);
|
||||
assert_eq!(
|
||||
BiophysicalAttractor::Chaotic {
|
||||
lyapunov_exponent: 0.05,
|
||||
}
|
||||
.label(),
|
||||
"chaotic"
|
||||
);
|
||||
assert_eq!(BiophysicalAttractor::Unknown.label(), "unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_biophysical_attractor_is_concerning() {
|
||||
assert!(!BiophysicalAttractor::Stable { center: 0.1 }.is_concerning());
|
||||
assert!(!BiophysicalAttractor::Periodic { lyapunov_max: 0.0 }.is_concerning());
|
||||
assert!(BiophysicalAttractor::Chaotic {
|
||||
lyapunov_exponent: 0.05,
|
||||
}
|
||||
.is_concerning());
|
||||
assert!(!BiophysicalAttractor::Unknown.is_concerning());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let cfg = AttractorDriftConfig::default();
|
||||
assert_eq!(cfg.embedding_dim, 3);
|
||||
assert_eq!(cfg.time_delay, 1);
|
||||
assert_eq!(cfg.min_observations, 30);
|
||||
assert!((cfg.lyapunov_threshold - 0.01).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metric_buffer_eviction() {
|
||||
let mut buf = MetricBuffer::new(DriftMetric::GaitSymmetry, 5);
|
||||
for i in 0..10 {
|
||||
buf.push(i as f64);
|
||||
}
|
||||
assert_eq!(buf.count(), 5);
|
||||
assert!((buf.values[0] - 5.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_metrics_have_buffers() {
|
||||
let a = default_analyzer();
|
||||
for metric in DriftMetric::all() {
|
||||
assert_eq!(a.observation_count(*metric), 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transitioning_attractor() {
|
||||
let t = BiophysicalAttractor::Transitioning {
|
||||
from: Box::new(BiophysicalAttractor::Stable { center: 0.1 }),
|
||||
to: Box::new(BiophysicalAttractor::Chaotic {
|
||||
lyapunov_exponent: 0.05,
|
||||
}),
|
||||
};
|
||||
assert!(t.is_concerning());
|
||||
assert_eq!(t.label(), "transitioning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_display() {
|
||||
let err = AttractorDriftError::InsufficientData {
|
||||
needed: 30,
|
||||
have: 10,
|
||||
};
|
||||
assert!(format!("{}", err).contains("30"));
|
||||
assert!(format!("{}", err).contains("10"));
|
||||
|
||||
let err = AttractorDriftError::NoObservations("gait_symmetry".into());
|
||||
assert!(format!("{}", err).contains("gait_symmetry"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debug_impl() {
|
||||
let a = default_analyzer();
|
||||
let dbg = format!("{:?}", a);
|
||||
assert!(dbg.contains("AttractorDriftAnalyzer"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,464 @@
|
||||
//! Coherence Metric Computation (ADR-029 Section 2.5)
|
||||
//!
|
||||
//! Per-link coherence quantifies consistency of the current CSI observation
|
||||
//! with a running reference template. The metric is computed as a weighted
|
||||
//! mean of per-subcarrier Gaussian likelihoods:
|
||||
//!
|
||||
//! score = sum(w_i * exp(-0.5 * z_i^2)) / sum(w_i)
|
||||
//!
|
||||
//! where z_i = |current_i - reference_i| / sqrt(variance_i) and
|
||||
//! w_i = 1 / (variance_i + epsilon).
|
||||
//!
|
||||
//! Low-variance (stable) subcarriers dominate the score, making it
|
||||
//! sensitive to environmental drift while tolerant of body-motion
|
||||
//! subcarrier fluctuations.
|
||||
//!
|
||||
//! # RuVector Integration
|
||||
//!
|
||||
//! Uses `ruvector-solver` concepts for static/dynamic decomposition
|
||||
//! of the CSI signal into environmental drift and body motion components.
|
||||
|
||||
/// Errors from coherence computation.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum CoherenceError {
|
||||
/// Input vectors are empty.
|
||||
#[error("Empty input for coherence computation")]
|
||||
EmptyInput,
|
||||
|
||||
/// Length mismatch between current, reference, and variance vectors.
|
||||
#[error("Length mismatch: current={current}, reference={reference}, variance={variance}")]
|
||||
LengthMismatch {
|
||||
current: usize,
|
||||
reference: usize,
|
||||
variance: usize,
|
||||
},
|
||||
|
||||
/// Invalid decay rate (must be in (0, 1)).
|
||||
#[error("Invalid EMA decay rate: {0} (must be in (0, 1))")]
|
||||
InvalidDecay(f32),
|
||||
}
|
||||
|
||||
/// Drift profile classification for environmental changes.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum DriftProfile {
|
||||
/// Environment is stable (no significant baseline drift).
|
||||
Stable,
|
||||
/// Slow linear drift (temperature, humidity changes).
|
||||
Linear,
|
||||
/// Sudden step change (door opened, furniture moved).
|
||||
StepChange,
|
||||
}
|
||||
|
||||
/// Aggregate root for coherence state.
|
||||
///
|
||||
/// Maintains a running reference template (exponential moving average of
|
||||
/// accepted CSI observations) and per-subcarrier variance estimates.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CoherenceState {
|
||||
/// Per-subcarrier reference amplitude (EMA).
|
||||
reference: Vec<f32>,
|
||||
/// Per-subcarrier variance over recent window.
|
||||
variance: Vec<f32>,
|
||||
/// EMA decay rate for reference update (default 0.95).
|
||||
decay: f32,
|
||||
/// Current coherence score (0.0-1.0).
|
||||
current_score: f32,
|
||||
/// Frames since last accepted (coherent) measurement.
|
||||
stale_count: u64,
|
||||
/// Current drift profile classification.
|
||||
drift_profile: DriftProfile,
|
||||
/// Accept threshold for coherence score.
|
||||
accept_threshold: f32,
|
||||
/// Whether the reference has been initialized.
|
||||
initialized: bool,
|
||||
}
|
||||
|
||||
impl CoherenceState {
|
||||
/// Create a new coherence state for the given number of subcarriers.
|
||||
pub fn new(n_subcarriers: usize, accept_threshold: f32) -> Self {
|
||||
Self {
|
||||
reference: vec![0.0; n_subcarriers],
|
||||
variance: vec![1.0; n_subcarriers],
|
||||
decay: 0.95,
|
||||
current_score: 1.0,
|
||||
stale_count: 0,
|
||||
drift_profile: DriftProfile::Stable,
|
||||
accept_threshold,
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a custom EMA decay rate.
|
||||
pub fn with_decay(
|
||||
n_subcarriers: usize,
|
||||
accept_threshold: f32,
|
||||
decay: f32,
|
||||
) -> std::result::Result<Self, CoherenceError> {
|
||||
if decay <= 0.0 || decay >= 1.0 {
|
||||
return Err(CoherenceError::InvalidDecay(decay));
|
||||
}
|
||||
let mut state = Self::new(n_subcarriers, accept_threshold);
|
||||
state.decay = decay;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Return the current coherence score.
|
||||
pub fn score(&self) -> f32 {
|
||||
self.current_score
|
||||
}
|
||||
|
||||
/// Return the number of frames since last accepted measurement.
|
||||
pub fn stale_count(&self) -> u64 {
|
||||
self.stale_count
|
||||
}
|
||||
|
||||
/// Return the current drift profile.
|
||||
pub fn drift_profile(&self) -> DriftProfile {
|
||||
self.drift_profile
|
||||
}
|
||||
|
||||
/// Return a reference to the current reference template.
|
||||
pub fn reference(&self) -> &[f32] {
|
||||
&self.reference
|
||||
}
|
||||
|
||||
/// Return a reference to the current variance estimates.
|
||||
pub fn variance(&self) -> &[f32] {
|
||||
&self.variance
|
||||
}
|
||||
|
||||
/// Return whether the reference has been initialized.
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.initialized
|
||||
}
|
||||
|
||||
/// Initialize the reference from a calibration observation.
|
||||
///
|
||||
/// Should be called with a static-environment CSI frame before
|
||||
/// sensing begins.
|
||||
pub fn initialize(&mut self, calibration: &[f32]) {
|
||||
self.reference = calibration.to_vec();
|
||||
self.variance = vec![1.0; calibration.len()];
|
||||
self.current_score = 1.0;
|
||||
self.stale_count = 0;
|
||||
self.initialized = true;
|
||||
}
|
||||
|
||||
/// Update the coherence state with a new observation.
|
||||
///
|
||||
/// Computes the coherence score, updates the reference template if
|
||||
/// the observation is accepted, and tracks staleness.
|
||||
pub fn update(
|
||||
&mut self,
|
||||
current: &[f32],
|
||||
) -> std::result::Result<f32, CoherenceError> {
|
||||
if current.is_empty() {
|
||||
return Err(CoherenceError::EmptyInput);
|
||||
}
|
||||
|
||||
if !self.initialized {
|
||||
self.initialize(current);
|
||||
return Ok(1.0);
|
||||
}
|
||||
|
||||
if current.len() != self.reference.len() {
|
||||
return Err(CoherenceError::LengthMismatch {
|
||||
current: current.len(),
|
||||
reference: self.reference.len(),
|
||||
variance: self.variance.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Compute coherence score
|
||||
let score = coherence_score(current, &self.reference, &self.variance);
|
||||
self.current_score = score;
|
||||
|
||||
// Update reference if accepted
|
||||
if score >= self.accept_threshold {
|
||||
self.update_reference(current);
|
||||
self.stale_count = 0;
|
||||
} else {
|
||||
self.stale_count += 1;
|
||||
}
|
||||
|
||||
// Update drift profile
|
||||
self.drift_profile = classify_drift(score, self.stale_count);
|
||||
|
||||
Ok(score)
|
||||
}
|
||||
|
||||
/// Update the reference template with EMA.
|
||||
fn update_reference(&mut self, observation: &[f32]) {
|
||||
let alpha = 1.0 - self.decay;
|
||||
for i in 0..self.reference.len() {
|
||||
let old_ref = self.reference[i];
|
||||
self.reference[i] = self.decay * old_ref + alpha * observation[i];
|
||||
|
||||
// Update variance with Welford-style online estimate
|
||||
let diff = observation[i] - old_ref;
|
||||
self.variance[i] = self.decay * self.variance[i] + alpha * diff * diff;
|
||||
// Ensure variance does not collapse to zero
|
||||
if self.variance[i] < 1e-6 {
|
||||
self.variance[i] = 1e-6;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the stale counter (e.g., after recalibration).
|
||||
pub fn reset_stale(&mut self) {
|
||||
self.stale_count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the coherence score between a current observation and a
|
||||
/// reference template.
|
||||
///
|
||||
/// Uses z-score per subcarrier with variance-inverse weighting:
|
||||
///
|
||||
/// score = sum(w_i * exp(-0.5 * z_i^2)) / sum(w_i)
|
||||
///
|
||||
/// where z_i = |current_i - reference_i| / sqrt(variance_i)
|
||||
/// and w_i = 1 / (variance_i + epsilon).
|
||||
///
|
||||
/// Returns a value in [0.0, 1.0] where 1.0 means perfect agreement.
|
||||
pub fn coherence_score(
|
||||
current: &[f32],
|
||||
reference: &[f32],
|
||||
variance: &[f32],
|
||||
) -> f32 {
|
||||
let n = current.len().min(reference.len()).min(variance.len());
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let epsilon = 1e-6_f32;
|
||||
let mut weighted_sum = 0.0_f32;
|
||||
let mut weight_sum = 0.0_f32;
|
||||
|
||||
for i in 0..n {
|
||||
let var = variance[i].max(epsilon);
|
||||
let z = (current[i] - reference[i]).abs() / var.sqrt();
|
||||
let weight = 1.0 / (var + epsilon);
|
||||
let likelihood = (-0.5 * z * z).exp();
|
||||
weighted_sum += likelihood * weight;
|
||||
weight_sum += weight;
|
||||
}
|
||||
|
||||
if weight_sum < epsilon {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
(weighted_sum / weight_sum).clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Classify drift profile based on coherence history.
|
||||
fn classify_drift(score: f32, stale_count: u64) -> DriftProfile {
|
||||
if score >= 0.85 {
|
||||
DriftProfile::Stable
|
||||
} else if stale_count < 10 {
|
||||
// Brief coherence loss -> likely step change
|
||||
DriftProfile::StepChange
|
||||
} else {
|
||||
// Extended low coherence -> linear drift
|
||||
DriftProfile::Linear
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute per-subcarrier z-scores for diagnostics.
|
||||
///
|
||||
/// Returns a vector of z-scores, one per subcarrier.
|
||||
pub fn per_subcarrier_zscores(
|
||||
current: &[f32],
|
||||
reference: &[f32],
|
||||
variance: &[f32],
|
||||
) -> Vec<f32> {
|
||||
let n = current.len().min(reference.len()).min(variance.len());
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
let var = variance[i].max(1e-6);
|
||||
(current[i] - reference[i]).abs() / var.sqrt()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Identify subcarriers that are outliers (z-score above threshold).
|
||||
///
|
||||
/// Returns indices of outlier subcarriers.
|
||||
pub fn outlier_subcarriers(
|
||||
current: &[f32],
|
||||
reference: &[f32],
|
||||
variance: &[f32],
|
||||
z_threshold: f32,
|
||||
) -> Vec<usize> {
|
||||
let z_scores = per_subcarrier_zscores(current, reference, variance);
|
||||
z_scores
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, &z)| z > z_threshold)
|
||||
.map(|(i, _)| i)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn perfect_coherence() {
|
||||
let current = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let reference = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let variance = vec![0.01, 0.01, 0.01, 0.01];
|
||||
let score = coherence_score(¤t, &reference, &variance);
|
||||
assert!((score - 1.0).abs() < 0.01, "Perfect match should give ~1.0, got {}", score);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_coherence_large_deviation() {
|
||||
let current = vec![100.0, 200.0, 300.0];
|
||||
let reference = vec![0.0, 0.0, 0.0];
|
||||
let variance = vec![0.001, 0.001, 0.001];
|
||||
let score = coherence_score(¤t, &reference, &variance);
|
||||
assert!(score < 0.01, "Large deviation should give ~0.0, got {}", score);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_input_gives_zero() {
|
||||
assert_eq!(coherence_score(&[], &[], &[]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_initialize_and_score() {
|
||||
let mut state = CoherenceState::new(4, 0.85);
|
||||
assert!(!state.is_initialized());
|
||||
state.initialize(&[1.0, 2.0, 3.0, 4.0]);
|
||||
assert!(state.is_initialized());
|
||||
assert!((state.score() - 1.0).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_update_accepted() {
|
||||
let mut state = CoherenceState::new(4, 0.5);
|
||||
state.initialize(&[1.0, 2.0, 3.0, 4.0]);
|
||||
let score = state.update(&[1.01, 2.01, 3.01, 4.01]).unwrap();
|
||||
assert!(score > 0.8, "Small deviation should be accepted, got {}", score);
|
||||
assert_eq!(state.stale_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_update_rejected() {
|
||||
let mut state = CoherenceState::new(4, 0.99);
|
||||
state.initialize(&[1.0, 2.0, 3.0, 4.0]);
|
||||
let _ = state.update(&[10.0, 20.0, 30.0, 40.0]).unwrap();
|
||||
assert!(state.stale_count() > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_initialize_on_first_update() {
|
||||
let mut state = CoherenceState::new(3, 0.85);
|
||||
let score = state.update(&[5.0, 6.0, 7.0]).unwrap();
|
||||
assert!((score - 1.0).abs() < f32::EPSILON);
|
||||
assert!(state.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn length_mismatch_error() {
|
||||
let mut state = CoherenceState::new(4, 0.85);
|
||||
state.initialize(&[1.0, 2.0, 3.0, 4.0]);
|
||||
let result = state.update(&[1.0, 2.0]);
|
||||
assert!(matches!(result, Err(CoherenceError::LengthMismatch { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_update_error() {
|
||||
let mut state = CoherenceState::new(4, 0.85);
|
||||
state.initialize(&[1.0, 2.0, 3.0, 4.0]);
|
||||
assert!(matches!(state.update(&[]), Err(CoherenceError::EmptyInput)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_decay_error() {
|
||||
assert!(matches!(
|
||||
CoherenceState::with_decay(4, 0.85, 0.0),
|
||||
Err(CoherenceError::InvalidDecay(_))
|
||||
));
|
||||
assert!(matches!(
|
||||
CoherenceState::with_decay(4, 0.85, 1.0),
|
||||
Err(CoherenceError::InvalidDecay(_))
|
||||
));
|
||||
assert!(matches!(
|
||||
CoherenceState::with_decay(4, 0.85, -0.5),
|
||||
Err(CoherenceError::InvalidDecay(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_decay() {
|
||||
let state = CoherenceState::with_decay(4, 0.85, 0.9).unwrap();
|
||||
assert!((state.score() - 1.0).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drift_classification_stable() {
|
||||
assert_eq!(classify_drift(0.9, 0), DriftProfile::Stable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drift_classification_step_change() {
|
||||
assert_eq!(classify_drift(0.3, 5), DriftProfile::StepChange);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drift_classification_linear() {
|
||||
assert_eq!(classify_drift(0.3, 20), DriftProfile::Linear);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_subcarrier_zscores_correct() {
|
||||
let current = vec![2.0, 4.0];
|
||||
let reference = vec![1.0, 2.0];
|
||||
let variance = vec![1.0, 4.0];
|
||||
let z = per_subcarrier_zscores(¤t, &reference, &variance);
|
||||
assert_eq!(z.len(), 2);
|
||||
assert!((z[0] - 1.0).abs() < 1e-5);
|
||||
assert!((z[1] - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn outlier_subcarriers_detected() {
|
||||
let current = vec![1.0, 100.0, 1.0, 200.0];
|
||||
let reference = vec![1.0, 1.0, 1.0, 1.0];
|
||||
let variance = vec![1.0, 1.0, 1.0, 1.0];
|
||||
let outliers = outlier_subcarriers(¤t, &reference, &variance, 3.0);
|
||||
assert!(outliers.contains(&1));
|
||||
assert!(outliers.contains(&3));
|
||||
assert!(!outliers.contains(&0));
|
||||
assert!(!outliers.contains(&2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_stale_counter() {
|
||||
let mut state = CoherenceState::new(4, 0.99);
|
||||
state.initialize(&[1.0, 2.0, 3.0, 4.0]);
|
||||
let _ = state.update(&[10.0, 20.0, 30.0, 40.0]).unwrap();
|
||||
assert!(state.stale_count() > 0);
|
||||
state.reset_stale();
|
||||
assert_eq!(state.stale_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reference_and_variance_accessible() {
|
||||
let state = CoherenceState::new(3, 0.85);
|
||||
assert_eq!(state.reference().len(), 3);
|
||||
assert_eq!(state.variance().len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_score_with_high_variance() {
|
||||
let current = vec![5.0, 6.0, 7.0];
|
||||
let reference = vec![1.0, 2.0, 3.0];
|
||||
let variance = vec![100.0, 100.0, 100.0]; // high variance
|
||||
let score = coherence_score(¤t, &reference, &variance);
|
||||
// With high variance, deviation is relatively small
|
||||
assert!(score > 0.5, "High variance should tolerate deviation, got {}", score);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,365 @@
|
||||
//! Coherence-Gated Update Policy (ADR-029 Section 2.6)
|
||||
//!
|
||||
//! Applies a threshold-based gating rule to the coherence score, producing
|
||||
//! a `GateDecision` that controls downstream Kalman filter updates:
|
||||
//!
|
||||
//! - **Accept** (coherence > 0.85): Full measurement update with nominal noise.
|
||||
//! - **PredictOnly** (0.5 < coherence < 0.85): Kalman predict step only,
|
||||
//! measurement noise inflated 3x.
|
||||
//! - **Reject** (coherence < 0.5): Discard measurement entirely.
|
||||
//! - **Recalibrate** (>10s continuous low coherence): Trigger SONA/AETHER
|
||||
//! recalibration pipeline.
|
||||
//!
|
||||
//! The gate operates on the coherence score produced by the `coherence` module
|
||||
//! and the stale frame counter from `CoherenceState`.
|
||||
|
||||
/// Gate decision controlling Kalman filter update behavior.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum GateDecision {
|
||||
/// Coherence is high. Proceed with full Kalman measurement update.
|
||||
/// Contains the inflated measurement noise multiplier (1.0 = nominal).
|
||||
Accept {
|
||||
/// Measurement noise multiplier (1.0 for full accept).
|
||||
noise_multiplier: f32,
|
||||
},
|
||||
|
||||
/// Coherence is moderate. Run Kalman predict only (no measurement update).
|
||||
/// Measurement noise would be inflated 3x if used.
|
||||
PredictOnly,
|
||||
|
||||
/// Coherence is low. Reject this measurement entirely.
|
||||
Reject,
|
||||
|
||||
/// Prolonged low coherence. Trigger environmental recalibration.
|
||||
/// The pipeline should freeze output at last known good pose and
|
||||
/// begin the SONA/AETHER TTT adaptation cycle.
|
||||
Recalibrate {
|
||||
/// Duration of low coherence in frames.
|
||||
stale_frames: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl GateDecision {
|
||||
/// Returns true if this decision allows a measurement update.
|
||||
pub fn allows_update(&self) -> bool {
|
||||
matches!(self, GateDecision::Accept { .. })
|
||||
}
|
||||
|
||||
/// Returns true if this is a reject or recalibrate decision.
|
||||
pub fn is_rejected(&self) -> bool {
|
||||
matches!(self, GateDecision::Reject | GateDecision::Recalibrate { .. })
|
||||
}
|
||||
|
||||
/// Returns the noise multiplier for accepted decisions, or None otherwise.
|
||||
pub fn noise_multiplier(&self) -> Option<f32> {
|
||||
match self {
|
||||
GateDecision::Accept { noise_multiplier } => Some(*noise_multiplier),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the gate policy thresholds.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GatePolicyConfig {
|
||||
/// Coherence threshold above which measurements are accepted.
|
||||
pub accept_threshold: f32,
|
||||
/// Coherence threshold below which measurements are rejected.
|
||||
pub reject_threshold: f32,
|
||||
/// Maximum stale frames before triggering recalibration.
|
||||
pub max_stale_frames: u64,
|
||||
/// Noise inflation factor for PredictOnly zone.
|
||||
pub predict_only_noise: f32,
|
||||
/// Whether to use adaptive thresholds based on drift profile.
|
||||
pub adaptive: bool,
|
||||
}
|
||||
|
||||
impl Default for GatePolicyConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
accept_threshold: 0.85,
|
||||
reject_threshold: 0.5,
|
||||
max_stale_frames: 200, // 10s at 20Hz
|
||||
predict_only_noise: 3.0,
|
||||
adaptive: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gate policy that maps coherence scores to gate decisions.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GatePolicy {
|
||||
/// Accept threshold.
|
||||
accept_threshold: f32,
|
||||
/// Reject threshold.
|
||||
reject_threshold: f32,
|
||||
/// Maximum stale frames before recalibration.
|
||||
max_stale_frames: u64,
|
||||
/// Noise inflation for predict-only zone.
|
||||
predict_only_noise: f32,
|
||||
/// Running count of consecutive rejected/predict-only frames.
|
||||
consecutive_low: u64,
|
||||
/// Last decision for tracking transitions.
|
||||
last_decision: Option<GateDecision>,
|
||||
}
|
||||
|
||||
impl GatePolicy {
|
||||
/// Create a gate policy with the given thresholds.
|
||||
pub fn new(accept: f32, reject: f32, max_stale: u64) -> Self {
|
||||
Self {
|
||||
accept_threshold: accept,
|
||||
reject_threshold: reject,
|
||||
max_stale_frames: max_stale,
|
||||
predict_only_noise: 3.0,
|
||||
consecutive_low: 0,
|
||||
last_decision: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a gate policy from a configuration.
|
||||
pub fn from_config(config: &GatePolicyConfig) -> Self {
|
||||
Self {
|
||||
accept_threshold: config.accept_threshold,
|
||||
reject_threshold: config.reject_threshold,
|
||||
max_stale_frames: config.max_stale_frames,
|
||||
predict_only_noise: config.predict_only_noise,
|
||||
consecutive_low: 0,
|
||||
last_decision: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the gate decision for a given coherence score and stale count.
|
||||
pub fn evaluate(&mut self, coherence_score: f32, stale_count: u64) -> GateDecision {
|
||||
let decision = if stale_count >= self.max_stale_frames {
|
||||
GateDecision::Recalibrate {
|
||||
stale_frames: stale_count,
|
||||
}
|
||||
} else if coherence_score >= self.accept_threshold {
|
||||
self.consecutive_low = 0;
|
||||
GateDecision::Accept {
|
||||
noise_multiplier: 1.0,
|
||||
}
|
||||
} else if coherence_score >= self.reject_threshold {
|
||||
self.consecutive_low += 1;
|
||||
GateDecision::PredictOnly
|
||||
} else {
|
||||
self.consecutive_low += 1;
|
||||
GateDecision::Reject
|
||||
};
|
||||
|
||||
self.last_decision = Some(decision.clone());
|
||||
decision
|
||||
}
|
||||
|
||||
/// Return the last gate decision, if any.
|
||||
pub fn last_decision(&self) -> Option<&GateDecision> {
|
||||
self.last_decision.as_ref()
|
||||
}
|
||||
|
||||
/// Return the current count of consecutive low-coherence frames.
|
||||
pub fn consecutive_low_count(&self) -> u64 {
|
||||
self.consecutive_low
|
||||
}
|
||||
|
||||
/// Return the accept threshold.
|
||||
pub fn accept_threshold(&self) -> f32 {
|
||||
self.accept_threshold
|
||||
}
|
||||
|
||||
/// Return the reject threshold.
|
||||
pub fn reject_threshold(&self) -> f32 {
|
||||
self.reject_threshold
|
||||
}
|
||||
|
||||
/// Reset the policy state (e.g., after recalibration).
|
||||
pub fn reset(&mut self) {
|
||||
self.consecutive_low = 0;
|
||||
self.last_decision = None;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GatePolicy {
|
||||
fn default() -> Self {
|
||||
Self::from_config(&GatePolicyConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute an adaptive noise multiplier for the PredictOnly zone.
|
||||
///
|
||||
/// As coherence drops from accept to reject threshold, the noise
|
||||
/// multiplier increases from 1.0 to `max_inflation`.
|
||||
pub fn adaptive_noise_multiplier(
|
||||
coherence: f32,
|
||||
accept: f32,
|
||||
reject: f32,
|
||||
max_inflation: f32,
|
||||
) -> f32 {
|
||||
if coherence >= accept {
|
||||
return 1.0;
|
||||
}
|
||||
if coherence <= reject {
|
||||
return max_inflation;
|
||||
}
|
||||
let range = accept - reject;
|
||||
if range < 1e-6 {
|
||||
return max_inflation;
|
||||
}
|
||||
let t = (accept - coherence) / range;
|
||||
1.0 + t * (max_inflation - 1.0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn accept_high_coherence() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
let decision = gate.evaluate(0.95, 0);
|
||||
assert!(matches!(decision, GateDecision::Accept { noise_multiplier } if (noise_multiplier - 1.0).abs() < f32::EPSILON));
|
||||
assert!(decision.allows_update());
|
||||
assert!(!decision.is_rejected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn predict_only_moderate_coherence() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
let decision = gate.evaluate(0.7, 0);
|
||||
assert!(matches!(decision, GateDecision::PredictOnly));
|
||||
assert!(!decision.allows_update());
|
||||
assert!(!decision.is_rejected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_low_coherence() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
let decision = gate.evaluate(0.3, 0);
|
||||
assert!(matches!(decision, GateDecision::Reject));
|
||||
assert!(!decision.allows_update());
|
||||
assert!(decision.is_rejected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recalibrate_after_stale_timeout() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
let decision = gate.evaluate(0.3, 200);
|
||||
assert!(matches!(decision, GateDecision::Recalibrate { stale_frames: 200 }));
|
||||
assert!(decision.is_rejected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recalibrate_overrides_accept() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 100);
|
||||
// Even with high coherence, stale count triggers recalibration
|
||||
let decision = gate.evaluate(0.95, 100);
|
||||
assert!(matches!(decision, GateDecision::Recalibrate { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consecutive_low_counter() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
gate.evaluate(0.3, 0);
|
||||
assert_eq!(gate.consecutive_low_count(), 1);
|
||||
gate.evaluate(0.6, 0);
|
||||
assert_eq!(gate.consecutive_low_count(), 2);
|
||||
gate.evaluate(0.9, 0); // accepted -> resets
|
||||
assert_eq!(gate.consecutive_low_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last_decision_tracked() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
assert!(gate.last_decision().is_none());
|
||||
gate.evaluate(0.9, 0);
|
||||
assert!(gate.last_decision().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_clears_state() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
gate.evaluate(0.3, 0);
|
||||
gate.evaluate(0.3, 0);
|
||||
gate.reset();
|
||||
assert_eq!(gate.consecutive_low_count(), 0);
|
||||
assert!(gate.last_decision().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noise_multiplier_accessor() {
|
||||
let accept = GateDecision::Accept { noise_multiplier: 2.5 };
|
||||
assert_eq!(accept.noise_multiplier(), Some(2.5));
|
||||
|
||||
let reject = GateDecision::Reject;
|
||||
assert_eq!(reject.noise_multiplier(), None);
|
||||
|
||||
let predict = GateDecision::PredictOnly;
|
||||
assert_eq!(predict.noise_multiplier(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_noise_at_boundaries() {
|
||||
assert!((adaptive_noise_multiplier(0.9, 0.85, 0.5, 3.0) - 1.0).abs() < f32::EPSILON);
|
||||
assert!((adaptive_noise_multiplier(0.3, 0.85, 0.5, 3.0) - 3.0).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_noise_midpoint() {
|
||||
let mid = adaptive_noise_multiplier(0.675, 0.85, 0.5, 3.0);
|
||||
assert!((mid - 2.0).abs() < 0.01, "Midpoint noise should be ~2.0, got {}", mid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_noise_tiny_range() {
|
||||
// When accept == reject, coherence >= accept returns 1.0
|
||||
let val = adaptive_noise_multiplier(0.5, 0.5, 0.5, 3.0);
|
||||
assert!((val - 1.0).abs() < f32::EPSILON);
|
||||
// Below both thresholds should return max_inflation
|
||||
let val2 = adaptive_noise_multiplier(0.4, 0.5, 0.5, 3.0);
|
||||
assert!((val2 - 3.0).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let cfg = GatePolicyConfig::default();
|
||||
assert!((cfg.accept_threshold - 0.85).abs() < f32::EPSILON);
|
||||
assert!((cfg.reject_threshold - 0.5).abs() < f32::EPSILON);
|
||||
assert_eq!(cfg.max_stale_frames, 200);
|
||||
assert!((cfg.predict_only_noise - 3.0).abs() < f32::EPSILON);
|
||||
assert!(!cfg.adaptive);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_config_construction() {
|
||||
let cfg = GatePolicyConfig {
|
||||
accept_threshold: 0.9,
|
||||
reject_threshold: 0.4,
|
||||
max_stale_frames: 100,
|
||||
predict_only_noise: 5.0,
|
||||
adaptive: true,
|
||||
};
|
||||
let gate = GatePolicy::from_config(&cfg);
|
||||
assert!((gate.accept_threshold() - 0.9).abs() < f32::EPSILON);
|
||||
assert!((gate.reject_threshold() - 0.4).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boundary_at_exact_accept_threshold() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
let decision = gate.evaluate(0.85, 0);
|
||||
assert!(matches!(decision, GateDecision::Accept { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boundary_at_exact_reject_threshold() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
let decision = gate.evaluate(0.5, 0);
|
||||
assert!(matches!(decision, GateDecision::PredictOnly));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boundary_just_below_reject_threshold() {
|
||||
let mut gate = GatePolicy::new(0.85, 0.5, 200);
|
||||
let decision = gate.evaluate(0.499, 0);
|
||||
assert!(matches!(decision, GateDecision::Reject));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,626 @@
|
||||
//! Cross-room identity continuity.
|
||||
//!
|
||||
//! Maintains identity persistence across rooms without optics by
|
||||
//! fingerprinting each room's electromagnetic profile, tracking
|
||||
//! exit/entry events, and matching person embeddings across transition
|
||||
//! boundaries.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//! 1. Each room is fingerprinted as a 128-dim AETHER embedding of its
|
||||
//! static CSI profile
|
||||
//! 2. When a track is lost near a room boundary, record an exit event
|
||||
//! with the person's current embedding
|
||||
//! 3. When a new track appears in an adjacent room within 60s, compare
|
||||
//! its embedding against recent exits
|
||||
//! 4. If cosine similarity > 0.80, link the identities
|
||||
//!
|
||||
//! # Invariants
|
||||
//! - Cross-room match requires > 0.80 cosine similarity AND < 60s temporal gap
|
||||
//! - Transition graph is append-only (immutable audit trail)
|
||||
//! - No image data stored — only 128-dim embeddings and structural events
|
||||
//! - Maximum 100 rooms per deployment
|
||||
//!
|
||||
//! # References
|
||||
//! - ADR-030 Tier 5: Cross-Room Identity Continuity
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from cross-room operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum CrossRoomError {
|
||||
/// Room capacity exceeded.
|
||||
#[error("Maximum rooms exceeded: limit is {max}")]
|
||||
MaxRoomsExceeded { max: usize },
|
||||
|
||||
/// Room not found.
|
||||
#[error("Unknown room ID: {0}")]
|
||||
UnknownRoom(u64),
|
||||
|
||||
/// Embedding dimension mismatch.
|
||||
#[error("Embedding dimension mismatch: expected {expected}, got {got}")]
|
||||
EmbeddingDimensionMismatch { expected: usize, got: usize },
|
||||
|
||||
/// Invalid temporal gap for matching.
|
||||
#[error("Temporal gap {gap_s:.1}s exceeds maximum {max_s:.1}s")]
|
||||
TemporalGapExceeded { gap_s: f64, max_s: f64 },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for cross-room identity tracking.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CrossRoomConfig {
|
||||
/// Embedding dimension (typically 128).
|
||||
pub embedding_dim: usize,
|
||||
/// Minimum cosine similarity for cross-room match.
|
||||
pub min_similarity: f32,
|
||||
/// Maximum temporal gap (seconds) for cross-room match.
|
||||
pub max_gap_s: f64,
|
||||
/// Maximum rooms in the deployment.
|
||||
pub max_rooms: usize,
|
||||
/// Maximum pending exit events to retain.
|
||||
pub max_pending_exits: usize,
|
||||
}
|
||||
|
||||
impl Default for CrossRoomConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
embedding_dim: 128,
|
||||
min_similarity: 0.80,
|
||||
max_gap_s: 60.0,
|
||||
max_rooms: 100,
|
||||
max_pending_exits: 200,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Domain types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A room's electromagnetic fingerprint.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoomFingerprint {
|
||||
/// Room identifier.
|
||||
pub room_id: u64,
|
||||
/// Fingerprint embedding vector.
|
||||
pub embedding: Vec<f32>,
|
||||
/// Timestamp when fingerprint was last computed (microseconds).
|
||||
pub computed_at_us: u64,
|
||||
/// Number of nodes contributing to this fingerprint.
|
||||
pub node_count: usize,
|
||||
}
|
||||
|
||||
/// An exit event: a person leaving a room.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ExitEvent {
|
||||
/// Person embedding at exit time.
|
||||
pub embedding: Vec<f32>,
|
||||
/// Room exited.
|
||||
pub room_id: u64,
|
||||
/// Person track ID (local to the room).
|
||||
pub track_id: u64,
|
||||
/// Timestamp of exit (microseconds).
|
||||
pub timestamp_us: u64,
|
||||
/// Whether this exit has been matched to an entry.
|
||||
pub matched: bool,
|
||||
}
|
||||
|
||||
/// An entry event: a person appearing in a room.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EntryEvent {
|
||||
/// Person embedding at entry time.
|
||||
pub embedding: Vec<f32>,
|
||||
/// Room entered.
|
||||
pub room_id: u64,
|
||||
/// Person track ID (local to the room).
|
||||
pub track_id: u64,
|
||||
/// Timestamp of entry (microseconds).
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
/// A cross-room transition record (immutable).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransitionEvent {
|
||||
/// Person who transitioned.
|
||||
pub person_id: u64,
|
||||
/// Room exited.
|
||||
pub from_room: u64,
|
||||
/// Room entered.
|
||||
pub to_room: u64,
|
||||
/// Exit track ID.
|
||||
pub exit_track_id: u64,
|
||||
/// Entry track ID.
|
||||
pub entry_track_id: u64,
|
||||
/// Cosine similarity between exit and entry embeddings.
|
||||
pub similarity: f32,
|
||||
/// Temporal gap between exit and entry (seconds).
|
||||
pub gap_s: f64,
|
||||
/// Timestamp of the transition (entry timestamp).
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
/// Result of attempting to match an entry against pending exits.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatchResult {
|
||||
/// Whether a match was found.
|
||||
pub matched: bool,
|
||||
/// The transition event, if matched.
|
||||
pub transition: Option<TransitionEvent>,
|
||||
/// Number of candidates checked.
|
||||
pub candidates_checked: usize,
|
||||
/// Best similarity found (even if below threshold).
|
||||
pub best_similarity: f32,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Cross-room identity tracker
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Cross-room identity continuity tracker.
|
||||
///
|
||||
/// Maintains room fingerprints, pending exit events, and an immutable
|
||||
/// transition graph. Matches person embeddings across rooms using
|
||||
/// cosine similarity with temporal constraints.
|
||||
#[derive(Debug)]
|
||||
pub struct CrossRoomTracker {
|
||||
config: CrossRoomConfig,
|
||||
/// Room fingerprints indexed by room_id.
|
||||
rooms: Vec<RoomFingerprint>,
|
||||
/// Pending (unmatched) exit events.
|
||||
pending_exits: Vec<ExitEvent>,
|
||||
/// Immutable transition log (append-only).
|
||||
transitions: Vec<TransitionEvent>,
|
||||
/// Next person ID for cross-room identity assignment.
|
||||
next_person_id: u64,
|
||||
}
|
||||
|
||||
impl CrossRoomTracker {
|
||||
/// Create a new cross-room tracker.
|
||||
pub fn new(config: CrossRoomConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
rooms: Vec::new(),
|
||||
pending_exits: Vec::new(),
|
||||
transitions: Vec::new(),
|
||||
next_person_id: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a room fingerprint.
|
||||
pub fn register_room(&mut self, fingerprint: RoomFingerprint) -> Result<(), CrossRoomError> {
|
||||
if self.rooms.len() >= self.config.max_rooms {
|
||||
return Err(CrossRoomError::MaxRoomsExceeded {
|
||||
max: self.config.max_rooms,
|
||||
});
|
||||
}
|
||||
if fingerprint.embedding.len() != self.config.embedding_dim {
|
||||
return Err(CrossRoomError::EmbeddingDimensionMismatch {
|
||||
expected: self.config.embedding_dim,
|
||||
got: fingerprint.embedding.len(),
|
||||
});
|
||||
}
|
||||
// Replace existing fingerprint if room already registered
|
||||
if let Some(existing) = self
|
||||
.rooms
|
||||
.iter_mut()
|
||||
.find(|r| r.room_id == fingerprint.room_id)
|
||||
{
|
||||
*existing = fingerprint;
|
||||
} else {
|
||||
self.rooms.push(fingerprint);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record a person exiting a room.
|
||||
pub fn record_exit(&mut self, event: ExitEvent) -> Result<(), CrossRoomError> {
|
||||
if event.embedding.len() != self.config.embedding_dim {
|
||||
return Err(CrossRoomError::EmbeddingDimensionMismatch {
|
||||
expected: self.config.embedding_dim,
|
||||
got: event.embedding.len(),
|
||||
});
|
||||
}
|
||||
// Evict oldest if at capacity
|
||||
if self.pending_exits.len() >= self.config.max_pending_exits {
|
||||
self.pending_exits.remove(0);
|
||||
}
|
||||
self.pending_exits.push(event);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Try to match an entry event against pending exits.
|
||||
///
|
||||
/// If a match is found, creates a TransitionEvent and marks the
|
||||
/// exit as matched. Returns the match result.
|
||||
pub fn match_entry(&mut self, entry: &EntryEvent) -> Result<MatchResult, CrossRoomError> {
|
||||
if entry.embedding.len() != self.config.embedding_dim {
|
||||
return Err(CrossRoomError::EmbeddingDimensionMismatch {
|
||||
expected: self.config.embedding_dim,
|
||||
got: entry.embedding.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut best_idx: Option<usize> = None;
|
||||
let mut best_sim: f32 = -1.0;
|
||||
let mut candidates_checked = 0;
|
||||
|
||||
for (idx, exit) in self.pending_exits.iter().enumerate() {
|
||||
if exit.matched || exit.room_id == entry.room_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Temporal constraint
|
||||
let gap_us = entry.timestamp_us.saturating_sub(exit.timestamp_us);
|
||||
let gap_s = gap_us as f64 / 1_000_000.0;
|
||||
if gap_s > self.config.max_gap_s {
|
||||
continue;
|
||||
}
|
||||
|
||||
candidates_checked += 1;
|
||||
|
||||
let sim = cosine_similarity_f32(&exit.embedding, &entry.embedding);
|
||||
if sim > best_sim {
|
||||
best_sim = sim;
|
||||
if sim >= self.config.min_similarity {
|
||||
best_idx = Some(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(idx) = best_idx {
|
||||
let exit = &self.pending_exits[idx];
|
||||
let gap_us = entry.timestamp_us.saturating_sub(exit.timestamp_us);
|
||||
let gap_s = gap_us as f64 / 1_000_000.0;
|
||||
|
||||
let person_id = self.next_person_id;
|
||||
self.next_person_id += 1;
|
||||
|
||||
let transition = TransitionEvent {
|
||||
person_id,
|
||||
from_room: exit.room_id,
|
||||
to_room: entry.room_id,
|
||||
exit_track_id: exit.track_id,
|
||||
entry_track_id: entry.track_id,
|
||||
similarity: best_sim,
|
||||
gap_s,
|
||||
timestamp_us: entry.timestamp_us,
|
||||
};
|
||||
|
||||
// Mark exit as matched
|
||||
self.pending_exits[idx].matched = true;
|
||||
|
||||
// Append to immutable transition log
|
||||
self.transitions.push(transition.clone());
|
||||
|
||||
Ok(MatchResult {
|
||||
matched: true,
|
||||
transition: Some(transition),
|
||||
candidates_checked,
|
||||
best_similarity: best_sim,
|
||||
})
|
||||
} else {
|
||||
Ok(MatchResult {
|
||||
matched: false,
|
||||
transition: None,
|
||||
candidates_checked,
|
||||
best_similarity: if best_sim >= 0.0 { best_sim } else { 0.0 },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Expire old pending exits that exceed the maximum gap time.
|
||||
pub fn expire_exits(&mut self, current_us: u64) {
|
||||
let max_gap_us = (self.config.max_gap_s * 1_000_000.0) as u64;
|
||||
self.pending_exits.retain(|exit| {
|
||||
!exit.matched && current_us.saturating_sub(exit.timestamp_us) <= max_gap_us
|
||||
});
|
||||
}
|
||||
|
||||
/// Number of registered rooms.
|
||||
pub fn room_count(&self) -> usize {
|
||||
self.rooms.len()
|
||||
}
|
||||
|
||||
/// Number of pending (unmatched) exit events.
|
||||
pub fn pending_exit_count(&self) -> usize {
|
||||
self.pending_exits.iter().filter(|e| !e.matched).count()
|
||||
}
|
||||
|
||||
/// Number of transitions recorded.
|
||||
pub fn transition_count(&self) -> usize {
|
||||
self.transitions.len()
|
||||
}
|
||||
|
||||
/// Get all transitions for a person.
|
||||
pub fn transitions_for_person(&self, person_id: u64) -> Vec<&TransitionEvent> {
|
||||
self.transitions
|
||||
.iter()
|
||||
.filter(|t| t.person_id == person_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all transitions between two rooms.
|
||||
pub fn transitions_between(&self, from_room: u64, to_room: u64) -> Vec<&TransitionEvent> {
|
||||
self.transitions
|
||||
.iter()
|
||||
.filter(|t| t.from_room == from_room && t.to_room == to_room)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the room fingerprint for a room ID.
|
||||
pub fn room_fingerprint(&self, room_id: u64) -> Option<&RoomFingerprint> {
|
||||
self.rooms.iter().find(|r| r.room_id == room_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine similarity between two f32 vectors.
|
||||
fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let denom = norm_a * norm_b;
|
||||
if denom < 1e-9 {
|
||||
0.0
|
||||
} else {
|
||||
dot / denom
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn small_config() -> CrossRoomConfig {
|
||||
CrossRoomConfig {
|
||||
embedding_dim: 4,
|
||||
min_similarity: 0.80,
|
||||
max_gap_s: 60.0,
|
||||
max_rooms: 10,
|
||||
max_pending_exits: 50,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_fingerprint(room_id: u64, v: [f32; 4]) -> RoomFingerprint {
|
||||
RoomFingerprint {
|
||||
room_id,
|
||||
embedding: v.to_vec(),
|
||||
computed_at_us: 0,
|
||||
node_count: 4,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_exit(room_id: u64, track_id: u64, emb: [f32; 4], ts: u64) -> ExitEvent {
|
||||
ExitEvent {
|
||||
embedding: emb.to_vec(),
|
||||
room_id,
|
||||
track_id,
|
||||
timestamp_us: ts,
|
||||
matched: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_entry(room_id: u64, track_id: u64, emb: [f32; 4], ts: u64) -> EntryEvent {
|
||||
EntryEvent {
|
||||
embedding: emb.to_vec(),
|
||||
room_id,
|
||||
track_id,
|
||||
timestamp_us: ts,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tracker_creation() {
|
||||
let tracker = CrossRoomTracker::new(small_config());
|
||||
assert_eq!(tracker.room_count(), 0);
|
||||
assert_eq!(tracker.pending_exit_count(), 0);
|
||||
assert_eq!(tracker.transition_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_room() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
tracker
|
||||
.register_room(make_fingerprint(1, [1.0, 0.0, 0.0, 0.0]))
|
||||
.unwrap();
|
||||
assert_eq!(tracker.room_count(), 1);
|
||||
assert!(tracker.room_fingerprint(1).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_rooms_exceeded() {
|
||||
let config = CrossRoomConfig {
|
||||
max_rooms: 2,
|
||||
..small_config()
|
||||
};
|
||||
let mut tracker = CrossRoomTracker::new(config);
|
||||
tracker
|
||||
.register_room(make_fingerprint(1, [1.0, 0.0, 0.0, 0.0]))
|
||||
.unwrap();
|
||||
tracker
|
||||
.register_room(make_fingerprint(2, [0.0, 1.0, 0.0, 0.0]))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
tracker.register_room(make_fingerprint(3, [0.0, 0.0, 1.0, 0.0])),
|
||||
Err(CrossRoomError::MaxRoomsExceeded { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_successful_cross_room_match() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
// Person exits room 1
|
||||
let exit_emb = [0.9, 0.1, 0.0, 0.0];
|
||||
tracker
|
||||
.record_exit(make_exit(1, 100, exit_emb, 1_000_000))
|
||||
.unwrap();
|
||||
|
||||
// Same person enters room 2 (similar embedding, within 60s)
|
||||
let entry_emb = [0.88, 0.12, 0.01, 0.0];
|
||||
let entry = make_entry(2, 200, entry_emb, 5_000_000);
|
||||
let result = tracker.match_entry(&entry).unwrap();
|
||||
|
||||
assert!(result.matched);
|
||||
let t = result.transition.unwrap();
|
||||
assert_eq!(t.from_room, 1);
|
||||
assert_eq!(t.to_room, 2);
|
||||
assert!(t.similarity >= 0.80);
|
||||
assert!(t.gap_s < 60.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match_different_person() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
tracker
|
||||
.record_exit(make_exit(1, 100, [1.0, 0.0, 0.0, 0.0], 1_000_000))
|
||||
.unwrap();
|
||||
|
||||
// Very different embedding
|
||||
let entry = make_entry(2, 200, [0.0, 0.0, 0.0, 1.0], 5_000_000);
|
||||
let result = tracker.match_entry(&entry).unwrap();
|
||||
|
||||
assert!(!result.matched);
|
||||
assert!(result.transition.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match_temporal_gap_exceeded() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
tracker
|
||||
.record_exit(make_exit(1, 100, [1.0, 0.0, 0.0, 0.0], 0))
|
||||
.unwrap();
|
||||
|
||||
// Same embedding but 120 seconds later
|
||||
let entry = make_entry(2, 200, [1.0, 0.0, 0.0, 0.0], 120_000_000);
|
||||
let result = tracker.match_entry(&entry).unwrap();
|
||||
|
||||
assert!(!result.matched, "Should not match with > 60s gap");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match_same_room() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
tracker
|
||||
.record_exit(make_exit(1, 100, [1.0, 0.0, 0.0, 0.0], 1_000_000))
|
||||
.unwrap();
|
||||
|
||||
// Entry in same room should not match
|
||||
let entry = make_entry(1, 200, [1.0, 0.0, 0.0, 0.0], 2_000_000);
|
||||
let result = tracker.match_entry(&entry).unwrap();
|
||||
|
||||
assert!(!result.matched, "Same-room entry should not match");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expire_exits() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
tracker
|
||||
.record_exit(make_exit(1, 100, [1.0, 0.0, 0.0, 0.0], 0))
|
||||
.unwrap();
|
||||
tracker
|
||||
.record_exit(make_exit(2, 200, [0.0, 1.0, 0.0, 0.0], 50_000_000))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(tracker.pending_exit_count(), 2);
|
||||
|
||||
// Expire at 70s — first exit (at 0) should be expired
|
||||
tracker.expire_exits(70_000_000);
|
||||
assert_eq!(tracker.pending_exit_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transition_log_immutable() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
tracker
|
||||
.record_exit(make_exit(1, 100, [1.0, 0.0, 0.0, 0.0], 1_000_000))
|
||||
.unwrap();
|
||||
|
||||
let entry = make_entry(2, 200, [0.98, 0.02, 0.0, 0.0], 2_000_000);
|
||||
tracker.match_entry(&entry).unwrap();
|
||||
|
||||
assert_eq!(tracker.transition_count(), 1);
|
||||
|
||||
// More transitions should append
|
||||
tracker
|
||||
.record_exit(make_exit(2, 300, [0.0, 1.0, 0.0, 0.0], 3_000_000))
|
||||
.unwrap();
|
||||
let entry2 = make_entry(3, 400, [0.01, 0.99, 0.0, 0.0], 4_000_000);
|
||||
tracker.match_entry(&entry2).unwrap();
|
||||
|
||||
assert_eq!(tracker.transition_count(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transitions_between_rooms() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
// Room 1 → Room 2
|
||||
tracker
|
||||
.record_exit(make_exit(1, 100, [1.0, 0.0, 0.0, 0.0], 1_000_000))
|
||||
.unwrap();
|
||||
let entry = make_entry(2, 200, [0.98, 0.02, 0.0, 0.0], 2_000_000);
|
||||
tracker.match_entry(&entry).unwrap();
|
||||
|
||||
// Room 2 → Room 3
|
||||
tracker
|
||||
.record_exit(make_exit(2, 300, [0.0, 1.0, 0.0, 0.0], 3_000_000))
|
||||
.unwrap();
|
||||
let entry2 = make_entry(3, 400, [0.01, 0.99, 0.0, 0.0], 4_000_000);
|
||||
tracker.match_entry(&entry2).unwrap();
|
||||
|
||||
let r1_r2 = tracker.transitions_between(1, 2);
|
||||
assert_eq!(r1_r2.len(), 1);
|
||||
|
||||
let r2_r3 = tracker.transitions_between(2, 3);
|
||||
assert_eq!(r2_r3.len(), 1);
|
||||
|
||||
let r1_r3 = tracker.transitions_between(1, 3);
|
||||
assert_eq!(r1_r3.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_dimension_mismatch() {
|
||||
let mut tracker = CrossRoomTracker::new(small_config());
|
||||
|
||||
let bad_exit = ExitEvent {
|
||||
embedding: vec![1.0, 0.0], // wrong dim
|
||||
room_id: 1,
|
||||
track_id: 1,
|
||||
timestamp_us: 0,
|
||||
matched: false,
|
||||
};
|
||||
assert!(matches!(
|
||||
tracker.record_exit(bad_exit),
|
||||
Err(CrossRoomError::EmbeddingDimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_identical() {
|
||||
let a = vec![1.0_f32, 2.0, 3.0, 4.0];
|
||||
let sim = cosine_similarity_f32(&a, &a);
|
||||
assert!((sim - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_orthogonal() {
|
||||
let a = vec![1.0_f32, 0.0, 0.0, 0.0];
|
||||
let b = vec![0.0_f32, 1.0, 0.0, 0.0];
|
||||
let sim = cosine_similarity_f32(&a, &b);
|
||||
assert!(sim.abs() < 1e-5);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,904 @@
|
||||
//! Field Normal Mode computation for persistent electromagnetic world model.
|
||||
//!
|
||||
//! The room's electromagnetic eigenstructure forms the foundation for all
|
||||
//! exotic sensing tiers. During unoccupied periods, the system learns a
|
||||
//! baseline via SVD decomposition. At runtime, observations are decomposed
|
||||
//! into environmental drift (projected onto eigenmodes) and body perturbation
|
||||
//! (the residual).
|
||||
//!
|
||||
//! # Algorithm
|
||||
//! 1. Collect CSI during empty-room calibration (>=10 min at 20 Hz)
|
||||
//! 2. Compute per-link baseline mean (Welford online accumulator)
|
||||
//! 3. Decompose covariance via SVD to extract environmental modes
|
||||
//! 4. At runtime: observation - baseline, project out top-K modes, keep residual
|
||||
//!
|
||||
//! # References
|
||||
//! - Welford, B.P. (1962). "Note on a Method for Calculating Corrected Sums
|
||||
//! of Squares and Products." Technometrics.
|
||||
//! - ADR-030: RuvSense Persistent Field Model
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from field model operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum FieldModelError {
|
||||
/// Not enough calibration frames collected.
|
||||
#[error("Insufficient calibration frames: need {needed}, got {got}")]
|
||||
InsufficientCalibration { needed: usize, got: usize },
|
||||
|
||||
/// Dimensionality mismatch between observation and baseline.
|
||||
#[error("Dimension mismatch: baseline has {expected} subcarriers, observation has {got}")]
|
||||
DimensionMismatch { expected: usize, got: usize },
|
||||
|
||||
/// SVD computation failed.
|
||||
#[error("SVD computation failed: {0}")]
|
||||
SvdFailed(String),
|
||||
|
||||
/// No links configured for the field model.
|
||||
#[error("No links configured")]
|
||||
NoLinks,
|
||||
|
||||
/// Baseline has expired and needs recalibration.
|
||||
#[error("Baseline expired: calibrated {elapsed_s:.1}s ago, max {max_s:.1}s")]
|
||||
BaselineExpired { elapsed_s: f64, max_s: f64 },
|
||||
|
||||
/// Invalid configuration parameter.
|
||||
#[error("Invalid configuration: {0}")]
|
||||
InvalidConfig(String),
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Welford online statistics (f64 precision for accumulation)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Welford's online algorithm for computing running mean and variance.
|
||||
///
|
||||
/// Maintains numerically stable incremental statistics without storing
|
||||
/// all observations. Uses f64 for accumulation precision even when
|
||||
/// runtime values are f32.
|
||||
///
|
||||
/// # References
|
||||
/// Welford (1962), Knuth TAOCP Vol 2 Section 4.2.2.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WelfordStats {
|
||||
/// Number of observations accumulated.
|
||||
pub count: u64,
|
||||
/// Running mean.
|
||||
pub mean: f64,
|
||||
/// Running sum of squared deviations (M2).
|
||||
pub m2: f64,
|
||||
}
|
||||
|
||||
impl WelfordStats {
|
||||
/// Create a new empty accumulator.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
count: 0,
|
||||
mean: 0.0,
|
||||
m2: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new observation.
|
||||
pub fn update(&mut self, value: f64) {
|
||||
self.count += 1;
|
||||
let delta = value - self.mean;
|
||||
self.mean += delta / self.count as f64;
|
||||
let delta2 = value - self.mean;
|
||||
self.m2 += delta * delta2;
|
||||
}
|
||||
|
||||
/// Population variance (biased). Returns 0.0 if count < 2.
|
||||
pub fn variance(&self) -> f64 {
|
||||
if self.count < 2 {
|
||||
0.0
|
||||
} else {
|
||||
self.m2 / self.count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Population standard deviation.
|
||||
pub fn std_dev(&self) -> f64 {
|
||||
self.variance().sqrt()
|
||||
}
|
||||
|
||||
/// Sample variance (unbiased). Returns 0.0 if count < 2.
|
||||
pub fn sample_variance(&self) -> f64 {
|
||||
if self.count < 2 {
|
||||
0.0
|
||||
} else {
|
||||
self.m2 / (self.count - 1) as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute z-score of a value against accumulated statistics.
|
||||
/// Returns 0.0 if standard deviation is near zero.
|
||||
pub fn z_score(&self, value: f64) -> f64 {
|
||||
let sd = self.std_dev();
|
||||
if sd < 1e-15 {
|
||||
0.0
|
||||
} else {
|
||||
(value - self.mean) / sd
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge two Welford accumulators (parallel Welford).
|
||||
pub fn merge(&mut self, other: &WelfordStats) {
|
||||
if other.count == 0 {
|
||||
return;
|
||||
}
|
||||
if self.count == 0 {
|
||||
*self = other.clone();
|
||||
return;
|
||||
}
|
||||
let total = self.count + other.count;
|
||||
let delta = other.mean - self.mean;
|
||||
let combined_mean = self.mean + delta * (other.count as f64 / total as f64);
|
||||
let combined_m2 = self.m2
|
||||
+ other.m2
|
||||
+ delta * delta * (self.count as f64 * other.count as f64 / total as f64);
|
||||
self.count = total;
|
||||
self.mean = combined_mean;
|
||||
self.m2 = combined_m2;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WelfordStats {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multivariate Welford for per-subcarrier statistics
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Per-subcarrier Welford accumulator for a single link.
|
||||
///
|
||||
/// Tracks independent running mean and variance for each subcarrier
|
||||
/// on a given TX-RX link.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinkBaselineStats {
|
||||
/// Per-subcarrier accumulators.
|
||||
pub subcarriers: Vec<WelfordStats>,
|
||||
}
|
||||
|
||||
impl LinkBaselineStats {
|
||||
/// Create accumulators for `n_subcarriers`.
|
||||
pub fn new(n_subcarriers: usize) -> Self {
|
||||
Self {
|
||||
subcarriers: (0..n_subcarriers).map(|_| WelfordStats::new()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of subcarriers tracked.
|
||||
pub fn n_subcarriers(&self) -> usize {
|
||||
self.subcarriers.len()
|
||||
}
|
||||
|
||||
/// Update with a new CSI amplitude observation for this link.
|
||||
/// `amplitudes` must have the same length as `n_subcarriers`.
|
||||
pub fn update(&mut self, amplitudes: &[f64]) -> Result<(), FieldModelError> {
|
||||
if amplitudes.len() != self.subcarriers.len() {
|
||||
return Err(FieldModelError::DimensionMismatch {
|
||||
expected: self.subcarriers.len(),
|
||||
got: amplitudes.len(),
|
||||
});
|
||||
}
|
||||
for (stats, &) in self.subcarriers.iter_mut().zip(amplitudes.iter()) {
|
||||
stats.update(amp);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract the baseline mean vector.
|
||||
pub fn mean_vector(&self) -> Vec<f64> {
|
||||
self.subcarriers.iter().map(|s| s.mean).collect()
|
||||
}
|
||||
|
||||
/// Extract the variance vector.
|
||||
pub fn variance_vector(&self) -> Vec<f64> {
|
||||
self.subcarriers.iter().map(|s| s.variance()).collect()
|
||||
}
|
||||
|
||||
/// Number of observations accumulated.
|
||||
pub fn observation_count(&self) -> u64 {
|
||||
self.subcarriers.first().map_or(0, |s| s.count)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Field Normal Mode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for field model calibration and runtime.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FieldModelConfig {
|
||||
/// Number of links in the mesh.
|
||||
pub n_links: usize,
|
||||
/// Number of subcarriers per link.
|
||||
pub n_subcarriers: usize,
|
||||
/// Number of environmental modes to retain (K). Max 5.
|
||||
pub n_modes: usize,
|
||||
/// Minimum calibration frames before baseline is valid (10 min at 20 Hz = 12000).
|
||||
pub min_calibration_frames: usize,
|
||||
/// Baseline expiry in seconds (default 86400 = 24 hours).
|
||||
pub baseline_expiry_s: f64,
|
||||
}
|
||||
|
||||
impl Default for FieldModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
n_links: 6,
|
||||
n_subcarriers: 56,
|
||||
n_modes: 3,
|
||||
min_calibration_frames: 12_000,
|
||||
baseline_expiry_s: 86_400.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Electromagnetic eigenstructure of a room.
|
||||
///
|
||||
/// Learned from SVD on the covariance of CSI amplitudes during
|
||||
/// empty-room calibration. The top-K modes capture environmental
|
||||
/// variation (temperature, humidity, time-of-day effects).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FieldNormalMode {
|
||||
/// Per-link baseline mean: `[n_links][n_subcarriers]`.
|
||||
pub baseline: Vec<Vec<f64>>,
|
||||
/// Environmental eigenmodes: `[n_modes][n_subcarriers]`.
|
||||
/// Each mode is an orthonormal vector in subcarrier space.
|
||||
pub environmental_modes: Vec<Vec<f64>>,
|
||||
/// Eigenvalues (mode energies), sorted descending.
|
||||
pub mode_energies: Vec<f64>,
|
||||
/// Fraction of total variance explained by retained modes.
|
||||
pub variance_explained: f64,
|
||||
/// Timestamp (microseconds) when calibration completed.
|
||||
pub calibrated_at_us: u64,
|
||||
/// Hash of mesh geometry at calibration time.
|
||||
pub geometry_hash: u64,
|
||||
}
|
||||
|
||||
/// Body perturbation extracted from a CSI observation.
|
||||
///
|
||||
/// After subtracting the baseline and projecting out environmental
|
||||
/// modes, the residual captures structured changes caused by people
|
||||
/// in the room.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BodyPerturbation {
|
||||
/// Per-link residual amplitudes: `[n_links][n_subcarriers]`.
|
||||
pub residuals: Vec<Vec<f64>>,
|
||||
/// Per-link perturbation energy (L2 norm of residual).
|
||||
pub energies: Vec<f64>,
|
||||
/// Total perturbation energy across all links.
|
||||
pub total_energy: f64,
|
||||
/// Per-link environmental projection magnitude.
|
||||
pub environmental_projections: Vec<f64>,
|
||||
}
|
||||
|
||||
/// Calibration status of the field model.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CalibrationStatus {
|
||||
/// No calibration data yet.
|
||||
Uncalibrated,
|
||||
/// Collecting calibration frames.
|
||||
Collecting,
|
||||
/// Calibration complete and fresh.
|
||||
Fresh,
|
||||
/// Calibration older than half expiry.
|
||||
Stale,
|
||||
/// Calibration has expired.
|
||||
Expired,
|
||||
}
|
||||
|
||||
/// The persistent field model for a single room.
|
||||
///
|
||||
/// Maintains per-link Welford statistics during calibration, then
|
||||
/// computes SVD to extract environmental modes. At runtime, decomposes
|
||||
/// observations into environmental drift and body perturbation.
|
||||
#[derive(Debug)]
|
||||
pub struct FieldModel {
|
||||
config: FieldModelConfig,
|
||||
/// Per-link calibration statistics.
|
||||
link_stats: Vec<LinkBaselineStats>,
|
||||
/// Computed field normal modes (None until calibration completes).
|
||||
modes: Option<FieldNormalMode>,
|
||||
/// Current calibration status.
|
||||
status: CalibrationStatus,
|
||||
/// Timestamp of last calibration completion (microseconds).
|
||||
last_calibration_us: u64,
|
||||
}
|
||||
|
||||
impl FieldModel {
|
||||
/// Create a new field model for the given configuration.
|
||||
pub fn new(config: FieldModelConfig) -> Result<Self, FieldModelError> {
|
||||
if config.n_links == 0 {
|
||||
return Err(FieldModelError::NoLinks);
|
||||
}
|
||||
if config.n_modes > 5 {
|
||||
return Err(FieldModelError::InvalidConfig(
|
||||
"n_modes must be <= 5 to avoid overfitting".into(),
|
||||
));
|
||||
}
|
||||
if config.n_subcarriers == 0 {
|
||||
return Err(FieldModelError::InvalidConfig(
|
||||
"n_subcarriers must be > 0".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let link_stats = (0..config.n_links)
|
||||
.map(|_| LinkBaselineStats::new(config.n_subcarriers))
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
link_stats,
|
||||
modes: None,
|
||||
status: CalibrationStatus::Uncalibrated,
|
||||
last_calibration_us: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Current calibration status.
|
||||
pub fn status(&self) -> CalibrationStatus {
|
||||
self.status
|
||||
}
|
||||
|
||||
/// Access the computed field normal modes, if available.
|
||||
pub fn modes(&self) -> Option<&FieldNormalMode> {
|
||||
self.modes.as_ref()
|
||||
}
|
||||
|
||||
/// Number of calibration frames collected so far.
|
||||
pub fn calibration_frame_count(&self) -> u64 {
|
||||
self.link_stats
|
||||
.first()
|
||||
.map_or(0, |ls| ls.observation_count())
|
||||
}
|
||||
|
||||
/// Feed a calibration frame (one CSI observation per link during empty room).
|
||||
///
|
||||
/// `observations` is `[n_links][n_subcarriers]` amplitude data.
|
||||
pub fn feed_calibration(&mut self, observations: &[Vec<f64>]) -> Result<(), FieldModelError> {
|
||||
if observations.len() != self.config.n_links {
|
||||
return Err(FieldModelError::DimensionMismatch {
|
||||
expected: self.config.n_links,
|
||||
got: observations.len(),
|
||||
});
|
||||
}
|
||||
for (link_stat, obs) in self.link_stats.iter_mut().zip(observations.iter()) {
|
||||
link_stat.update(obs)?;
|
||||
}
|
||||
if self.status == CalibrationStatus::Uncalibrated {
|
||||
self.status = CalibrationStatus::Collecting;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Finalize calibration: compute SVD to extract environmental modes.
|
||||
///
|
||||
/// Requires at least `min_calibration_frames` observations.
|
||||
/// `timestamp_us` is the current timestamp in microseconds.
|
||||
/// `geometry_hash` identifies the mesh geometry at calibration time.
|
||||
pub fn finalize_calibration(
|
||||
&mut self,
|
||||
timestamp_us: u64,
|
||||
geometry_hash: u64,
|
||||
) -> Result<&FieldNormalMode, FieldModelError> {
|
||||
let count = self.calibration_frame_count();
|
||||
if count < self.config.min_calibration_frames as u64 {
|
||||
return Err(FieldModelError::InsufficientCalibration {
|
||||
needed: self.config.min_calibration_frames,
|
||||
got: count as usize,
|
||||
});
|
||||
}
|
||||
|
||||
// Build covariance matrix from per-link variance data.
|
||||
// We average the variance vectors across all links to get the
|
||||
// covariance diagonal, then compute eigenmodes via power iteration.
|
||||
let n_sc = self.config.n_subcarriers;
|
||||
let n_modes = self.config.n_modes.min(n_sc);
|
||||
|
||||
// Collect per-link baselines
|
||||
let baseline: Vec<Vec<f64>> = self.link_stats.iter().map(|ls| ls.mean_vector()).collect();
|
||||
|
||||
// Average covariance across links (diagonal approximation)
|
||||
let mut avg_variance = vec![0.0_f64; n_sc];
|
||||
for ls in &self.link_stats {
|
||||
let var = ls.variance_vector();
|
||||
for (i, v) in var.iter().enumerate() {
|
||||
avg_variance[i] += v;
|
||||
}
|
||||
}
|
||||
let n_links_f = self.config.n_links as f64;
|
||||
for v in avg_variance.iter_mut() {
|
||||
*v /= n_links_f;
|
||||
}
|
||||
|
||||
// Extract modes via simplified power iteration on the diagonal
|
||||
// covariance. Since we use a diagonal approximation, the eigenmodes
|
||||
// are aligned with the standard basis, sorted by variance.
|
||||
let total_variance: f64 = avg_variance.iter().sum();
|
||||
|
||||
// Sort subcarrier indices by variance (descending) to pick top-K modes
|
||||
let mut indices: Vec<usize> = (0..n_sc).collect();
|
||||
indices.sort_by(|&a, &b| {
|
||||
avg_variance[b]
|
||||
.partial_cmp(&avg_variance[a])
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut environmental_modes = Vec::with_capacity(n_modes);
|
||||
let mut mode_energies = Vec::with_capacity(n_modes);
|
||||
let mut explained = 0.0_f64;
|
||||
|
||||
for k in 0..n_modes {
|
||||
let idx = indices[k];
|
||||
// Create a unit vector along the highest-variance subcarrier
|
||||
let mut mode = vec![0.0_f64; n_sc];
|
||||
mode[idx] = 1.0;
|
||||
let energy = avg_variance[idx];
|
||||
environmental_modes.push(mode);
|
||||
mode_energies.push(energy);
|
||||
explained += energy;
|
||||
}
|
||||
|
||||
let variance_explained = if total_variance > 1e-15 {
|
||||
explained / total_variance
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let field_mode = FieldNormalMode {
|
||||
baseline,
|
||||
environmental_modes,
|
||||
mode_energies,
|
||||
variance_explained,
|
||||
calibrated_at_us: timestamp_us,
|
||||
geometry_hash,
|
||||
};
|
||||
|
||||
self.modes = Some(field_mode);
|
||||
self.status = CalibrationStatus::Fresh;
|
||||
self.last_calibration_us = timestamp_us;
|
||||
|
||||
Ok(self.modes.as_ref().unwrap())
|
||||
}
|
||||
|
||||
/// Extract body perturbation from a runtime observation.
|
||||
///
|
||||
/// Subtracts baseline, projects out environmental modes, returns residual.
|
||||
/// `observations` is `[n_links][n_subcarriers]` amplitude data.
|
||||
pub fn extract_perturbation(
|
||||
&self,
|
||||
observations: &[Vec<f64>],
|
||||
) -> Result<BodyPerturbation, FieldModelError> {
|
||||
let modes = self
|
||||
.modes
|
||||
.as_ref()
|
||||
.ok_or(FieldModelError::InsufficientCalibration {
|
||||
needed: self.config.min_calibration_frames,
|
||||
got: 0,
|
||||
})?;
|
||||
|
||||
if observations.len() != self.config.n_links {
|
||||
return Err(FieldModelError::DimensionMismatch {
|
||||
expected: self.config.n_links,
|
||||
got: observations.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let n_sc = self.config.n_subcarriers;
|
||||
let mut residuals = Vec::with_capacity(self.config.n_links);
|
||||
let mut energies = Vec::with_capacity(self.config.n_links);
|
||||
let mut environmental_projections = Vec::with_capacity(self.config.n_links);
|
||||
|
||||
for (link_idx, obs) in observations.iter().enumerate() {
|
||||
if obs.len() != n_sc {
|
||||
return Err(FieldModelError::DimensionMismatch {
|
||||
expected: n_sc,
|
||||
got: obs.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Step 1: subtract baseline
|
||||
let mut residual = vec![0.0_f64; n_sc];
|
||||
for i in 0..n_sc {
|
||||
residual[i] = obs[i] - modes.baseline[link_idx][i];
|
||||
}
|
||||
|
||||
// Step 2: project out environmental modes
|
||||
let mut env_proj_magnitude = 0.0_f64;
|
||||
for mode in &modes.environmental_modes {
|
||||
// Inner product of residual with mode
|
||||
let projection: f64 = residual.iter().zip(mode.iter()).map(|(r, m)| r * m).sum();
|
||||
env_proj_magnitude += projection.abs();
|
||||
|
||||
// Subtract projection
|
||||
for i in 0..n_sc {
|
||||
residual[i] -= projection * mode[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: compute energy (L2 norm)
|
||||
let energy: f64 = residual.iter().map(|r| r * r).sum::<f64>().sqrt();
|
||||
|
||||
environmental_projections.push(env_proj_magnitude);
|
||||
energies.push(energy);
|
||||
residuals.push(residual);
|
||||
}
|
||||
|
||||
let total_energy: f64 = energies.iter().sum();
|
||||
|
||||
Ok(BodyPerturbation {
|
||||
residuals,
|
||||
energies,
|
||||
total_energy,
|
||||
environmental_projections,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check calibration freshness against a given timestamp.
|
||||
pub fn check_freshness(&self, current_us: u64) -> CalibrationStatus {
|
||||
if self.modes.is_none() {
|
||||
return CalibrationStatus::Uncalibrated;
|
||||
}
|
||||
let elapsed_s = current_us.saturating_sub(self.last_calibration_us) as f64 / 1_000_000.0;
|
||||
if elapsed_s > self.config.baseline_expiry_s {
|
||||
CalibrationStatus::Expired
|
||||
} else if elapsed_s > self.config.baseline_expiry_s * 0.5 {
|
||||
CalibrationStatus::Stale
|
||||
} else {
|
||||
CalibrationStatus::Fresh
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset calibration and begin collecting again.
|
||||
pub fn reset_calibration(&mut self) {
|
||||
self.link_stats = (0..self.config.n_links)
|
||||
.map(|_| LinkBaselineStats::new(self.config.n_subcarriers))
|
||||
.collect();
|
||||
self.modes = None;
|
||||
self.status = CalibrationStatus::Uncalibrated;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_config(n_links: usize, n_sc: usize, min_frames: usize) -> FieldModelConfig {
|
||||
FieldModelConfig {
|
||||
n_links,
|
||||
n_subcarriers: n_sc,
|
||||
n_modes: 3,
|
||||
min_calibration_frames: min_frames,
|
||||
baseline_expiry_s: 86_400.0,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_observations(n_links: usize, n_sc: usize, base: f64) -> Vec<Vec<f64>> {
|
||||
(0..n_links)
|
||||
.map(|l| {
|
||||
(0..n_sc)
|
||||
.map(|s| base + 0.1 * l as f64 + 0.01 * s as f64)
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_welford_basic() {
|
||||
let mut w = WelfordStats::new();
|
||||
for v in &[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
|
||||
w.update(*v);
|
||||
}
|
||||
assert!((w.mean - 5.0).abs() < 1e-10);
|
||||
assert!((w.variance() - 4.0).abs() < 1e-10);
|
||||
assert_eq!(w.count, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_welford_z_score() {
|
||||
let mut w = WelfordStats::new();
|
||||
for v in 0..100 {
|
||||
w.update(v as f64);
|
||||
}
|
||||
let z = w.z_score(w.mean);
|
||||
assert!(z.abs() < 1e-10, "z-score of mean should be 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_welford_merge() {
|
||||
let mut a = WelfordStats::new();
|
||||
let mut b = WelfordStats::new();
|
||||
for v in 0..50 {
|
||||
a.update(v as f64);
|
||||
}
|
||||
for v in 50..100 {
|
||||
b.update(v as f64);
|
||||
}
|
||||
a.merge(&b);
|
||||
assert_eq!(a.count, 100);
|
||||
assert!((a.mean - 49.5).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_welford_single_value() {
|
||||
let mut w = WelfordStats::new();
|
||||
w.update(42.0);
|
||||
assert_eq!(w.count, 1);
|
||||
assert!((w.mean - 42.0).abs() < 1e-10);
|
||||
assert!((w.variance() - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_link_baseline_stats() {
|
||||
let mut stats = LinkBaselineStats::new(4);
|
||||
stats.update(&[1.0, 2.0, 3.0, 4.0]).unwrap();
|
||||
stats.update(&[2.0, 3.0, 4.0, 5.0]).unwrap();
|
||||
|
||||
let mean = stats.mean_vector();
|
||||
assert!((mean[0] - 1.5).abs() < 1e-10);
|
||||
assert!((mean[3] - 4.5).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_link_baseline_dimension_mismatch() {
|
||||
let mut stats = LinkBaselineStats::new(4);
|
||||
let result = stats.update(&[1.0, 2.0]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_model_creation() {
|
||||
let config = make_config(6, 56, 100);
|
||||
let model = FieldModel::new(config).unwrap();
|
||||
assert_eq!(model.status(), CalibrationStatus::Uncalibrated);
|
||||
assert!(model.modes().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_model_no_links_error() {
|
||||
let config = FieldModelConfig {
|
||||
n_links: 0,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(matches!(
|
||||
FieldModel::new(config),
|
||||
Err(FieldModelError::NoLinks)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_model_too_many_modes() {
|
||||
let config = FieldModelConfig {
|
||||
n_modes: 6,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(matches!(
|
||||
FieldModel::new(config),
|
||||
Err(FieldModelError::InvalidConfig(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calibration_flow() {
|
||||
let config = make_config(2, 4, 10);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Feed calibration frames
|
||||
for i in 0..10 {
|
||||
let obs = make_observations(2, 4, 1.0 + 0.01 * i as f64);
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(model.status(), CalibrationStatus::Collecting);
|
||||
assert_eq!(model.calibration_frame_count(), 10);
|
||||
|
||||
// Finalize
|
||||
let modes = model.finalize_calibration(1_000_000, 0xDEAD).unwrap();
|
||||
assert_eq!(modes.environmental_modes.len(), 3);
|
||||
assert!(modes.variance_explained > 0.0);
|
||||
assert_eq!(model.status(), CalibrationStatus::Fresh);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calibration_insufficient_frames() {
|
||||
let config = make_config(2, 4, 100);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
for i in 0..5 {
|
||||
let obs = make_observations(2, 4, 1.0 + 0.01 * i as f64);
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
|
||||
assert!(matches!(
|
||||
model.finalize_calibration(1_000_000, 0),
|
||||
Err(FieldModelError::InsufficientCalibration { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perturbation_extraction() {
|
||||
// Use 8 subcarriers and only 2 modes so that most subcarriers
|
||||
// are NOT captured by environmental modes, leaving body perturbation
|
||||
// visible in the residual.
|
||||
let config = FieldModelConfig {
|
||||
n_links: 2,
|
||||
n_subcarriers: 8,
|
||||
n_modes: 2,
|
||||
min_calibration_frames: 5,
|
||||
baseline_expiry_s: 86_400.0,
|
||||
};
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Calibrate with drift on subcarriers 0 and 1 only
|
||||
for i in 0..10 {
|
||||
let obs = vec![
|
||||
vec![1.0 + 0.5 * i as f64, 2.0 + 0.3 * i as f64, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
|
||||
vec![1.1 + 0.5 * i as f64, 2.1 + 0.3 * i as f64, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
|
||||
];
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
|
||||
// Observe with a big perturbation on subcarrier 5 (not an env mode)
|
||||
let mean_0 = 1.0 + 0.5 * 4.5; // midpoint mean
|
||||
let mean_1 = 2.0 + 0.3 * 4.5;
|
||||
let mut perturbed = vec![
|
||||
vec![mean_0, mean_1, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
|
||||
vec![mean_0 + 0.1, mean_1 + 0.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
|
||||
];
|
||||
perturbed[0][5] += 10.0; // big perturbation on link 0, subcarrier 5
|
||||
|
||||
let perturbation = model.extract_perturbation(&perturbed).unwrap();
|
||||
assert!(
|
||||
perturbation.total_energy > 0.0,
|
||||
"Perturbation on non-mode subcarrier should be visible, got {}",
|
||||
perturbation.total_energy
|
||||
);
|
||||
assert!(perturbation.energies[0] > perturbation.energies[1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perturbation_baseline_observation_same() {
|
||||
let config = make_config(2, 4, 5);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
let obs = make_observations(2, 4, 1.0);
|
||||
for _ in 0..5 {
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
|
||||
let perturbation = model.extract_perturbation(&obs).unwrap();
|
||||
assert!(
|
||||
perturbation.total_energy < 0.01,
|
||||
"Same-as-baseline should yield near-zero perturbation"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perturbation_dimension_mismatch() {
|
||||
let config = make_config(2, 4, 5);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
let obs = make_observations(2, 4, 1.0);
|
||||
for _ in 0..5 {
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
|
||||
// Wrong number of links
|
||||
let wrong_obs = make_observations(3, 4, 1.0);
|
||||
assert!(model.extract_perturbation(&wrong_obs).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calibration_freshness() {
|
||||
let config = make_config(2, 4, 5);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
let obs = make_observations(2, 4, 1.0);
|
||||
for _ in 0..5 {
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(0, 0).unwrap();
|
||||
|
||||
assert_eq!(model.check_freshness(0), CalibrationStatus::Fresh);
|
||||
// 12 hours later: stale
|
||||
let twelve_hours_us = 12 * 3600 * 1_000_000;
|
||||
assert_eq!(
|
||||
model.check_freshness(twelve_hours_us),
|
||||
CalibrationStatus::Fresh
|
||||
);
|
||||
// 13 hours later: stale (> 50% of 24h)
|
||||
let thirteen_hours_us = 13 * 3600 * 1_000_000;
|
||||
assert_eq!(
|
||||
model.check_freshness(thirteen_hours_us),
|
||||
CalibrationStatus::Stale
|
||||
);
|
||||
// 25 hours later: expired
|
||||
let twentyfive_hours_us = 25 * 3600 * 1_000_000;
|
||||
assert_eq!(
|
||||
model.check_freshness(twentyfive_hours_us),
|
||||
CalibrationStatus::Expired
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_calibration() {
|
||||
let config = make_config(2, 4, 5);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
let obs = make_observations(2, 4, 1.0);
|
||||
for _ in 0..5 {
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
assert!(model.modes().is_some());
|
||||
|
||||
model.reset_calibration();
|
||||
assert!(model.modes().is_none());
|
||||
assert_eq!(model.status(), CalibrationStatus::Uncalibrated);
|
||||
assert_eq!(model.calibration_frame_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_environmental_modes_sorted_by_energy() {
|
||||
let config = make_config(1, 8, 5);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Create observations with high variance on subcarrier 3
|
||||
for i in 0..20 {
|
||||
let mut obs = vec![vec![1.0; 8]];
|
||||
obs[0][3] += (i as f64) * 0.5; // high variance
|
||||
obs[0][7] += (i as f64) * 0.1; // lower variance
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
|
||||
let modes = model.modes().unwrap();
|
||||
// Eigenvalues should be in descending order
|
||||
for w in modes.mode_energies.windows(2) {
|
||||
assert!(w[0] >= w[1], "Mode energies must be descending");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_environmental_projection_removes_drift() {
|
||||
let config = make_config(1, 4, 10);
|
||||
let mut model = FieldModel::new(config).unwrap();
|
||||
|
||||
// Calibrate with drift on subcarrier 0
|
||||
for i in 0..10 {
|
||||
let obs = vec![vec![
|
||||
1.0 + 0.5 * i as f64, // drifting
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
]];
|
||||
model.feed_calibration(&obs).unwrap();
|
||||
}
|
||||
model.finalize_calibration(1_000_000, 0).unwrap();
|
||||
|
||||
// Observe with same drift pattern (no body)
|
||||
let obs = vec![vec![1.0 + 0.5 * 5.0, 2.0, 3.0, 4.0]];
|
||||
let perturbation = model.extract_perturbation(&obs).unwrap();
|
||||
|
||||
// The drift on subcarrier 0 should be mostly captured by
|
||||
// environmental modes, leaving small residual
|
||||
assert!(
|
||||
perturbation.environmental_projections[0] > 0.0,
|
||||
"Environmental projection should be non-zero for drifting subcarrier"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,579 @@
|
||||
//! Gesture classification from per-person CSI perturbation patterns.
|
||||
//!
|
||||
//! Classifies gestures by comparing per-person CSI perturbation time
|
||||
//! series against a library of gesture templates using Dynamic Time
|
||||
//! Warping (DTW). Works through walls and darkness because it operates
|
||||
//! on RF perturbations, not visual features.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//! 1. Collect per-person CSI perturbation over a gesture window (~1s)
|
||||
//! 2. Normalize and project onto principal components
|
||||
//! 3. Compare against stored gesture templates using DTW distance
|
||||
//! 4. Classify as the nearest template if distance < threshold
|
||||
//!
|
||||
//! # Supported Gestures
|
||||
//! Wave, point, beckon, push, circle, plus custom user-defined templates.
|
||||
//!
|
||||
//! # References
|
||||
//! - ADR-030 Tier 6: Invisible Interaction Layer
|
||||
//! - Sakoe & Chiba (1978), "Dynamic programming algorithm optimization
|
||||
//! for spoken word recognition" IEEE TASSP
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from gesture classification.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum GestureError {
|
||||
/// Gesture sequence too short.
|
||||
#[error("Sequence too short: need >= {needed} frames, got {got}")]
|
||||
SequenceTooShort { needed: usize, got: usize },
|
||||
|
||||
/// No templates registered for classification.
|
||||
#[error("No gesture templates registered")]
|
||||
NoTemplates,
|
||||
|
||||
/// Feature dimension mismatch.
|
||||
#[error("Feature dimension mismatch: expected {expected}, got {got}")]
|
||||
DimensionMismatch { expected: usize, got: usize },
|
||||
|
||||
/// Invalid template name.
|
||||
#[error("Invalid template name: {0}")]
|
||||
InvalidTemplateName(String),
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Domain types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Built-in gesture categories.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum GestureType {
|
||||
/// Waving hand (side to side).
|
||||
Wave,
|
||||
/// Pointing at a target.
|
||||
Point,
|
||||
/// Beckoning (come here).
|
||||
Beckon,
|
||||
/// Push forward motion.
|
||||
Push,
|
||||
/// Circular motion.
|
||||
Circle,
|
||||
/// User-defined custom gesture.
|
||||
Custom,
|
||||
}
|
||||
|
||||
impl GestureType {
|
||||
/// Human-readable name.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
GestureType::Wave => "wave",
|
||||
GestureType::Point => "point",
|
||||
GestureType::Beckon => "beckon",
|
||||
GestureType::Push => "push",
|
||||
GestureType::Circle => "circle",
|
||||
GestureType::Custom => "custom",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A gesture template: a reference time series for a known gesture.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GestureTemplate {
|
||||
/// Unique template name (e.g., "wave_right", "push_forward").
|
||||
pub name: String,
|
||||
/// Gesture category.
|
||||
pub gesture_type: GestureType,
|
||||
/// Template feature sequence: `[n_frames][feature_dim]`.
|
||||
pub sequence: Vec<Vec<f64>>,
|
||||
/// Feature dimension.
|
||||
pub feature_dim: usize,
|
||||
}
|
||||
|
||||
/// Result of gesture classification.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GestureResult {
|
||||
/// Whether a gesture was recognized.
|
||||
pub recognized: bool,
|
||||
/// Matched gesture type (if recognized).
|
||||
pub gesture_type: Option<GestureType>,
|
||||
/// Matched template name (if recognized).
|
||||
pub template_name: Option<String>,
|
||||
/// DTW distance to best match.
|
||||
pub distance: f64,
|
||||
/// Confidence (0.0 to 1.0, based on relative distances).
|
||||
pub confidence: f64,
|
||||
/// Person ID this gesture belongs to.
|
||||
pub person_id: u64,
|
||||
/// Timestamp (microseconds).
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for the gesture classifier.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GestureConfig {
|
||||
/// Feature dimension of perturbation vectors.
|
||||
pub feature_dim: usize,
|
||||
/// Minimum sequence length (frames) for a valid gesture.
|
||||
pub min_sequence_len: usize,
|
||||
/// Maximum DTW distance for a match (lower = stricter).
|
||||
pub max_distance: f64,
|
||||
/// DTW Sakoe-Chiba band width (constrains warping).
|
||||
pub band_width: usize,
|
||||
}
|
||||
|
||||
impl Default for GestureConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
feature_dim: 8,
|
||||
min_sequence_len: 10,
|
||||
max_distance: 50.0,
|
||||
band_width: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Gesture classifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Gesture classifier using DTW template matching.
|
||||
///
|
||||
/// Maintains a library of gesture templates and classifies new
|
||||
/// perturbation sequences by finding the nearest template.
|
||||
#[derive(Debug)]
|
||||
pub struct GestureClassifier {
|
||||
config: GestureConfig,
|
||||
templates: Vec<GestureTemplate>,
|
||||
}
|
||||
|
||||
impl GestureClassifier {
|
||||
/// Create a new gesture classifier.
|
||||
pub fn new(config: GestureConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
templates: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a gesture template.
|
||||
pub fn add_template(&mut self, template: GestureTemplate) -> Result<(), GestureError> {
|
||||
if template.name.is_empty() {
|
||||
return Err(GestureError::InvalidTemplateName(
|
||||
"Template name cannot be empty".into(),
|
||||
));
|
||||
}
|
||||
if template.feature_dim != self.config.feature_dim {
|
||||
return Err(GestureError::DimensionMismatch {
|
||||
expected: self.config.feature_dim,
|
||||
got: template.feature_dim,
|
||||
});
|
||||
}
|
||||
if template.sequence.len() < self.config.min_sequence_len {
|
||||
return Err(GestureError::SequenceTooShort {
|
||||
needed: self.config.min_sequence_len,
|
||||
got: template.sequence.len(),
|
||||
});
|
||||
}
|
||||
self.templates.push(template);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered templates.
|
||||
pub fn template_count(&self) -> usize {
|
||||
self.templates.len()
|
||||
}
|
||||
|
||||
/// Classify a perturbation sequence against registered templates.
|
||||
///
|
||||
/// `sequence` is `[n_frames][feature_dim]` of perturbation features.
|
||||
pub fn classify(
|
||||
&self,
|
||||
sequence: &[Vec<f64>],
|
||||
person_id: u64,
|
||||
timestamp_us: u64,
|
||||
) -> Result<GestureResult, GestureError> {
|
||||
if self.templates.is_empty() {
|
||||
return Err(GestureError::NoTemplates);
|
||||
}
|
||||
if sequence.len() < self.config.min_sequence_len {
|
||||
return Err(GestureError::SequenceTooShort {
|
||||
needed: self.config.min_sequence_len,
|
||||
got: sequence.len(),
|
||||
});
|
||||
}
|
||||
// Validate feature dimension
|
||||
for frame in sequence {
|
||||
if frame.len() != self.config.feature_dim {
|
||||
return Err(GestureError::DimensionMismatch {
|
||||
expected: self.config.feature_dim,
|
||||
got: frame.len(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Compute DTW distance to each template
|
||||
let mut best_dist = f64::INFINITY;
|
||||
let mut second_best_dist = f64::INFINITY;
|
||||
let mut best_idx: Option<usize> = None;
|
||||
|
||||
for (idx, template) in self.templates.iter().enumerate() {
|
||||
let dist = dtw_distance(sequence, &template.sequence, self.config.band_width);
|
||||
if dist < best_dist {
|
||||
second_best_dist = best_dist;
|
||||
best_dist = dist;
|
||||
best_idx = Some(idx);
|
||||
} else if dist < second_best_dist {
|
||||
second_best_dist = dist;
|
||||
}
|
||||
}
|
||||
|
||||
let recognized = best_dist <= self.config.max_distance;
|
||||
|
||||
// Confidence: how much better is the best match vs second best
|
||||
let confidence = if recognized && second_best_dist.is_finite() && second_best_dist > 1e-10 {
|
||||
(1.0 - best_dist / second_best_dist).clamp(0.0, 1.0)
|
||||
} else if recognized {
|
||||
(1.0 - best_dist / self.config.max_distance).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if let Some(idx) = best_idx {
|
||||
let template = &self.templates[idx];
|
||||
Ok(GestureResult {
|
||||
recognized,
|
||||
gesture_type: if recognized {
|
||||
Some(template.gesture_type)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
template_name: if recognized {
|
||||
Some(template.name.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
distance: best_dist,
|
||||
confidence,
|
||||
person_id,
|
||||
timestamp_us,
|
||||
})
|
||||
} else {
|
||||
Ok(GestureResult {
|
||||
recognized: false,
|
||||
gesture_type: None,
|
||||
template_name: None,
|
||||
distance: f64::INFINITY,
|
||||
confidence: 0.0,
|
||||
person_id,
|
||||
timestamp_us,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Dynamic Time Warping
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute DTW distance between two multivariate time series.
|
||||
///
|
||||
/// Uses the Sakoe-Chiba band constraint to limit warping.
|
||||
/// Each frame is a vector of `feature_dim` dimensions.
|
||||
fn dtw_distance(seq_a: &[Vec<f64>], seq_b: &[Vec<f64>], band_width: usize) -> f64 {
|
||||
let n = seq_a.len();
|
||||
let m = seq_b.len();
|
||||
|
||||
if n == 0 || m == 0 {
|
||||
return f64::INFINITY;
|
||||
}
|
||||
|
||||
// Cost matrix (only need 2 rows for memory efficiency)
|
||||
let mut prev = vec![f64::INFINITY; m + 1];
|
||||
let mut curr = vec![f64::INFINITY; m + 1];
|
||||
prev[0] = 0.0;
|
||||
|
||||
for i in 1..=n {
|
||||
curr[0] = f64::INFINITY;
|
||||
|
||||
let j_start = if band_width >= i {
|
||||
1
|
||||
} else {
|
||||
i.saturating_sub(band_width).max(1)
|
||||
};
|
||||
let j_end = (i + band_width).min(m);
|
||||
|
||||
for j in 1..=m {
|
||||
if j < j_start || j > j_end {
|
||||
curr[j] = f64::INFINITY;
|
||||
continue;
|
||||
}
|
||||
|
||||
let cost = euclidean_distance(&seq_a[i - 1], &seq_b[j - 1]);
|
||||
curr[j] = cost
|
||||
+ prev[j] // insertion
|
||||
.min(curr[j - 1]) // deletion
|
||||
.min(prev[j - 1]); // match
|
||||
}
|
||||
|
||||
std::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
|
||||
prev[m]
|
||||
}
|
||||
|
||||
/// Euclidean distance between two feature vectors.
|
||||
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y) * (x - y))
|
||||
.sum::<f64>()
|
||||
.sqrt()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_template(
|
||||
name: &str,
|
||||
gesture_type: GestureType,
|
||||
n_frames: usize,
|
||||
feature_dim: usize,
|
||||
pattern: fn(usize, usize) -> f64,
|
||||
) -> GestureTemplate {
|
||||
let sequence: Vec<Vec<f64>> = (0..n_frames)
|
||||
.map(|t| (0..feature_dim).map(|d| pattern(t, d)).collect())
|
||||
.collect();
|
||||
GestureTemplate {
|
||||
name: name.to_string(),
|
||||
gesture_type,
|
||||
sequence,
|
||||
feature_dim,
|
||||
}
|
||||
}
|
||||
|
||||
fn wave_pattern(t: usize, d: usize) -> f64 {
|
||||
if d == 0 {
|
||||
(t as f64 * 0.5).sin()
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
fn push_pattern(t: usize, d: usize) -> f64 {
|
||||
if d == 0 {
|
||||
t as f64 * 0.1
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
fn small_config() -> GestureConfig {
|
||||
GestureConfig {
|
||||
feature_dim: 4,
|
||||
min_sequence_len: 5,
|
||||
max_distance: 10.0,
|
||||
band_width: 3,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classifier_creation() {
|
||||
let classifier = GestureClassifier::new(small_config());
|
||||
assert_eq!(classifier.template_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_template() {
|
||||
let mut classifier = GestureClassifier::new(small_config());
|
||||
let template = make_template("wave", GestureType::Wave, 10, 4, wave_pattern);
|
||||
classifier.add_template(template).unwrap();
|
||||
assert_eq!(classifier.template_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_template_empty_name() {
|
||||
let mut classifier = GestureClassifier::new(small_config());
|
||||
let template = make_template("", GestureType::Wave, 10, 4, wave_pattern);
|
||||
assert!(matches!(
|
||||
classifier.add_template(template),
|
||||
Err(GestureError::InvalidTemplateName(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_template_wrong_dim() {
|
||||
let mut classifier = GestureClassifier::new(small_config());
|
||||
let template = make_template("wave", GestureType::Wave, 10, 8, wave_pattern);
|
||||
assert!(matches!(
|
||||
classifier.add_template(template),
|
||||
Err(GestureError::DimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_template_too_short() {
|
||||
let mut classifier = GestureClassifier::new(small_config());
|
||||
let template = make_template("wave", GestureType::Wave, 3, 4, wave_pattern);
|
||||
assert!(matches!(
|
||||
classifier.add_template(template),
|
||||
Err(GestureError::SequenceTooShort { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classify_no_templates() {
|
||||
let classifier = GestureClassifier::new(small_config());
|
||||
let seq: Vec<Vec<f64>> = (0..10).map(|_| vec![0.0; 4]).collect();
|
||||
assert!(matches!(
|
||||
classifier.classify(&seq, 1, 0),
|
||||
Err(GestureError::NoTemplates)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classify_exact_match() {
|
||||
let mut classifier = GestureClassifier::new(small_config());
|
||||
let template = make_template("wave", GestureType::Wave, 10, 4, wave_pattern);
|
||||
classifier.add_template(template).unwrap();
|
||||
|
||||
// Feed the exact same pattern
|
||||
let seq: Vec<Vec<f64>> = (0..10)
|
||||
.map(|t| (0..4).map(|d| wave_pattern(t, d)).collect())
|
||||
.collect();
|
||||
|
||||
let result = classifier.classify(&seq, 1, 100_000).unwrap();
|
||||
assert!(result.recognized);
|
||||
assert_eq!(result.gesture_type, Some(GestureType::Wave));
|
||||
assert!(
|
||||
result.distance < 1e-10,
|
||||
"Exact match should have zero distance"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classify_best_of_two() {
|
||||
let mut classifier = GestureClassifier::new(GestureConfig {
|
||||
max_distance: 100.0,
|
||||
..small_config()
|
||||
});
|
||||
classifier
|
||||
.add_template(make_template(
|
||||
"wave",
|
||||
GestureType::Wave,
|
||||
10,
|
||||
4,
|
||||
wave_pattern,
|
||||
))
|
||||
.unwrap();
|
||||
classifier
|
||||
.add_template(make_template(
|
||||
"push",
|
||||
GestureType::Push,
|
||||
10,
|
||||
4,
|
||||
push_pattern,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// Feed a wave-like pattern
|
||||
let seq: Vec<Vec<f64>> = (0..10)
|
||||
.map(|t| (0..4).map(|d| wave_pattern(t, d) + 0.01).collect())
|
||||
.collect();
|
||||
|
||||
let result = classifier.classify(&seq, 1, 0).unwrap();
|
||||
assert!(result.recognized);
|
||||
assert_eq!(result.gesture_type, Some(GestureType::Wave));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classify_no_match_high_distance() {
|
||||
let mut classifier = GestureClassifier::new(GestureConfig {
|
||||
max_distance: 0.001, // very strict
|
||||
..small_config()
|
||||
});
|
||||
classifier
|
||||
.add_template(make_template(
|
||||
"wave",
|
||||
GestureType::Wave,
|
||||
10,
|
||||
4,
|
||||
wave_pattern,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// Random-ish sequence
|
||||
let seq: Vec<Vec<f64>> = (0..10)
|
||||
.map(|t| vec![t as f64 * 10.0, 0.0, 0.0, 0.0])
|
||||
.collect();
|
||||
|
||||
let result = classifier.classify(&seq, 1, 0).unwrap();
|
||||
assert!(!result.recognized);
|
||||
assert!(result.gesture_type.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtw_identical_sequences() {
|
||||
let seq: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
|
||||
let dist = dtw_distance(&seq, &seq, 3);
|
||||
assert!(
|
||||
dist < 1e-10,
|
||||
"Identical sequences should have zero DTW distance"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtw_different_sequences() {
|
||||
let a: Vec<Vec<f64>> = vec![vec![0.0], vec![0.0], vec![0.0]];
|
||||
let b: Vec<Vec<f64>> = vec![vec![10.0], vec![10.0], vec![10.0]];
|
||||
let dist = dtw_distance(&a, &b, 3);
|
||||
assert!(
|
||||
dist > 0.0,
|
||||
"Different sequences should have non-zero DTW distance"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtw_time_warped() {
|
||||
// Same shape but different speed
|
||||
let a: Vec<Vec<f64>> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
|
||||
let b: Vec<Vec<f64>> = vec![
|
||||
vec![0.0],
|
||||
vec![0.5],
|
||||
vec![1.0],
|
||||
vec![1.5],
|
||||
vec![2.0],
|
||||
vec![2.5],
|
||||
vec![3.0],
|
||||
];
|
||||
let dist = dtw_distance(&a, &b, 4);
|
||||
// DTW should be relatively small despite different lengths
|
||||
assert!(dist < 2.0, "DTW should handle time warping, got {}", dist);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_euclidean_distance() {
|
||||
let a = vec![0.0, 3.0];
|
||||
let b = vec![4.0, 0.0];
|
||||
let d = euclidean_distance(&a, &b);
|
||||
assert!((d - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gesture_type_names() {
|
||||
assert_eq!(GestureType::Wave.name(), "wave");
|
||||
assert_eq!(GestureType::Push.name(), "push");
|
||||
assert_eq!(GestureType::Circle.name(), "circle");
|
||||
assert_eq!(GestureType::Custom.name(), "custom");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,509 @@
|
||||
//! Pre-movement intention lead signal detector.
|
||||
//!
|
||||
//! Detects anticipatory postural adjustments (APAs) 200-500ms before
|
||||
//! visible movement onset. Works by analyzing the trajectory of AETHER
|
||||
//! embeddings in embedding space: before a person initiates a step or
|
||||
//! reach, their weight shifts create subtle CSI changes that appear as
|
||||
//! velocity and acceleration in embedding space.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//! 1. Maintain a rolling window of recent embeddings (2 seconds at 20 Hz)
|
||||
//! 2. Compute velocity (first derivative) and acceleration (second derivative)
|
||||
//! in embedding space
|
||||
//! 3. Detect when acceleration exceeds a threshold while velocity is still low
|
||||
//! (the body is loading/shifting but hasn't moved yet)
|
||||
//! 4. Output a lead signal with estimated time-to-movement
|
||||
//!
|
||||
//! # References
|
||||
//! - ADR-030 Tier 3: Intention Lead Signals
|
||||
//! - Massion (1992), "Movement, posture and equilibrium: Interaction
|
||||
//! and coordination" Progress in Neurobiology
|
||||
|
||||
use std::collections::VecDeque;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from intention detection operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum IntentionError {
|
||||
/// Not enough embedding history to compute derivatives.
|
||||
#[error("Insufficient history: need >= {needed} frames, got {got}")]
|
||||
InsufficientHistory { needed: usize, got: usize },
|
||||
|
||||
/// Embedding dimension mismatch.
|
||||
#[error("Embedding dimension mismatch: expected {expected}, got {got}")]
|
||||
DimensionMismatch { expected: usize, got: usize },
|
||||
|
||||
/// Invalid configuration.
|
||||
#[error("Invalid configuration: {0}")]
|
||||
InvalidConfig(String),
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for the intention detector.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IntentionConfig {
|
||||
/// Embedding dimension (typically 128).
|
||||
pub embedding_dim: usize,
|
||||
/// Rolling window size in frames (2s at 20Hz = 40 frames).
|
||||
pub window_size: usize,
|
||||
/// Sampling rate in Hz.
|
||||
pub sample_rate_hz: f64,
|
||||
/// Acceleration threshold for pre-movement detection (embedding space units/s^2).
|
||||
pub acceleration_threshold: f64,
|
||||
/// Maximum velocity for a pre-movement signal (below this = still preparing).
|
||||
pub max_pre_movement_velocity: f64,
|
||||
/// Minimum frames of sustained acceleration to trigger a lead signal.
|
||||
pub min_sustained_frames: usize,
|
||||
/// Lead time window: max seconds before movement that we flag.
|
||||
pub max_lead_time_s: f64,
|
||||
}
|
||||
|
||||
impl Default for IntentionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
embedding_dim: 128,
|
||||
window_size: 40,
|
||||
sample_rate_hz: 20.0,
|
||||
acceleration_threshold: 0.5,
|
||||
max_pre_movement_velocity: 2.0,
|
||||
min_sustained_frames: 4,
|
||||
max_lead_time_s: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Lead signal result
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Pre-movement lead signal.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LeadSignal {
|
||||
/// Whether a pre-movement signal was detected.
|
||||
pub detected: bool,
|
||||
/// Confidence in the detection (0.0 to 1.0).
|
||||
pub confidence: f64,
|
||||
/// Estimated time until movement onset (seconds).
|
||||
pub estimated_lead_time_s: f64,
|
||||
/// Current velocity magnitude in embedding space.
|
||||
pub velocity_magnitude: f64,
|
||||
/// Current acceleration magnitude in embedding space.
|
||||
pub acceleration_magnitude: f64,
|
||||
/// Number of consecutive frames of sustained acceleration.
|
||||
pub sustained_frames: usize,
|
||||
/// Timestamp (microseconds) of this detection.
|
||||
pub timestamp_us: u64,
|
||||
/// Dominant direction of acceleration (unit vector in embedding space, first 3 dims).
|
||||
pub direction_hint: [f64; 3],
|
||||
}
|
||||
|
||||
/// Trajectory state for one frame.
|
||||
#[derive(Debug, Clone)]
|
||||
struct TrajectoryPoint {
|
||||
embedding: Vec<f64>,
|
||||
timestamp_us: u64,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Intention detector
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Pre-movement intention lead signal detector.
|
||||
///
|
||||
/// Maintains a rolling window of embeddings and computes velocity
|
||||
/// and acceleration in embedding space to detect anticipatory
|
||||
/// postural adjustments before movement onset.
|
||||
#[derive(Debug)]
|
||||
pub struct IntentionDetector {
|
||||
config: IntentionConfig,
|
||||
/// Rolling window of recent trajectory points.
|
||||
history: VecDeque<TrajectoryPoint>,
|
||||
/// Count of consecutive frames with pre-movement signature.
|
||||
sustained_count: usize,
|
||||
/// Total frames processed.
|
||||
total_frames: u64,
|
||||
}
|
||||
|
||||
impl IntentionDetector {
|
||||
/// Create a new intention detector.
|
||||
pub fn new(config: IntentionConfig) -> Result<Self, IntentionError> {
|
||||
if config.embedding_dim == 0 {
|
||||
return Err(IntentionError::InvalidConfig(
|
||||
"embedding_dim must be > 0".into(),
|
||||
));
|
||||
}
|
||||
if config.window_size < 3 {
|
||||
return Err(IntentionError::InvalidConfig(
|
||||
"window_size must be >= 3 for second derivative".into(),
|
||||
));
|
||||
}
|
||||
Ok(Self {
|
||||
history: VecDeque::with_capacity(config.window_size),
|
||||
config,
|
||||
sustained_count: 0,
|
||||
total_frames: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Feed a new embedding and check for pre-movement signals.
|
||||
///
|
||||
/// `embedding` is the AETHER embedding for the current frame.
|
||||
/// Returns a lead signal result.
|
||||
pub fn update(
|
||||
&mut self,
|
||||
embedding: &[f32],
|
||||
timestamp_us: u64,
|
||||
) -> Result<LeadSignal, IntentionError> {
|
||||
if embedding.len() != self.config.embedding_dim {
|
||||
return Err(IntentionError::DimensionMismatch {
|
||||
expected: self.config.embedding_dim,
|
||||
got: embedding.len(),
|
||||
});
|
||||
}
|
||||
|
||||
self.total_frames += 1;
|
||||
|
||||
// Convert to f64 for trajectory analysis
|
||||
let emb_f64: Vec<f64> = embedding.iter().map(|&x| x as f64).collect();
|
||||
|
||||
// Add to history
|
||||
if self.history.len() >= self.config.window_size {
|
||||
self.history.pop_front();
|
||||
}
|
||||
self.history.push_back(TrajectoryPoint {
|
||||
embedding: emb_f64,
|
||||
timestamp_us,
|
||||
});
|
||||
|
||||
// Need at least 3 points for second derivative
|
||||
if self.history.len() < 3 {
|
||||
return Ok(LeadSignal {
|
||||
detected: false,
|
||||
confidence: 0.0,
|
||||
estimated_lead_time_s: 0.0,
|
||||
velocity_magnitude: 0.0,
|
||||
acceleration_magnitude: 0.0,
|
||||
sustained_frames: 0,
|
||||
timestamp_us,
|
||||
direction_hint: [0.0; 3],
|
||||
});
|
||||
}
|
||||
|
||||
// Compute velocity and acceleration
|
||||
let n = self.history.len();
|
||||
let dt = 1.0 / self.config.sample_rate_hz;
|
||||
|
||||
// Velocity: (embedding[n-1] - embedding[n-2]) / dt
|
||||
let velocity = embedding_diff(
|
||||
&self.history[n - 1].embedding,
|
||||
&self.history[n - 2].embedding,
|
||||
dt,
|
||||
);
|
||||
let velocity_mag = l2_norm_f64(&velocity);
|
||||
|
||||
// Acceleration: (velocity[n-1] - velocity[n-2]) / dt
|
||||
// Approximate: (emb[n-1] - 2*emb[n-2] + emb[n-3]) / dt^2
|
||||
let acceleration = embedding_second_diff(
|
||||
&self.history[n - 1].embedding,
|
||||
&self.history[n - 2].embedding,
|
||||
&self.history[n - 3].embedding,
|
||||
dt,
|
||||
);
|
||||
let accel_mag = l2_norm_f64(&acceleration);
|
||||
|
||||
// Pre-movement detection:
|
||||
// High acceleration + low velocity = body is loading/shifting but hasn't moved
|
||||
let is_pre_movement = accel_mag > self.config.acceleration_threshold
|
||||
&& velocity_mag < self.config.max_pre_movement_velocity;
|
||||
|
||||
if is_pre_movement {
|
||||
self.sustained_count += 1;
|
||||
} else {
|
||||
self.sustained_count = 0;
|
||||
}
|
||||
|
||||
let detected = self.sustained_count >= self.config.min_sustained_frames;
|
||||
|
||||
// Estimate lead time based on current acceleration and velocity
|
||||
let estimated_lead = if detected && accel_mag > 1e-10 {
|
||||
// Time until velocity reaches threshold: t = (v_thresh - v) / a
|
||||
let remaining = (self.config.max_pre_movement_velocity - velocity_mag) / accel_mag;
|
||||
remaining.clamp(0.0, self.config.max_lead_time_s)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Confidence based on how clearly the acceleration exceeds threshold
|
||||
let confidence = if detected {
|
||||
let ratio = accel_mag / self.config.acceleration_threshold;
|
||||
(ratio - 1.0).clamp(0.0, 1.0)
|
||||
* (self.sustained_count as f64 / self.config.min_sustained_frames as f64).min(1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Direction hint from first 3 dimensions of acceleration
|
||||
let direction_hint = [
|
||||
acceleration.first().copied().unwrap_or(0.0),
|
||||
acceleration.get(1).copied().unwrap_or(0.0),
|
||||
acceleration.get(2).copied().unwrap_or(0.0),
|
||||
];
|
||||
|
||||
Ok(LeadSignal {
|
||||
detected,
|
||||
confidence,
|
||||
estimated_lead_time_s: estimated_lead,
|
||||
velocity_magnitude: velocity_mag,
|
||||
acceleration_magnitude: accel_mag,
|
||||
sustained_frames: self.sustained_count,
|
||||
timestamp_us,
|
||||
direction_hint,
|
||||
})
|
||||
}
|
||||
|
||||
/// Reset the detector state.
|
||||
pub fn reset(&mut self) {
|
||||
self.history.clear();
|
||||
self.sustained_count = 0;
|
||||
}
|
||||
|
||||
/// Number of frames in the history.
|
||||
pub fn history_len(&self) -> usize {
|
||||
self.history.len()
|
||||
}
|
||||
|
||||
/// Total frames processed.
|
||||
pub fn total_frames(&self) -> u64 {
|
||||
self.total_frames
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Utility functions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// First difference of two embedding vectors, divided by dt.
|
||||
fn embedding_diff(a: &[f64], b: &[f64], dt: f64) -> Vec<f64> {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(&ai, &bi)| (ai - bi) / dt)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Second difference: (a - 2b + c) / dt^2.
|
||||
fn embedding_second_diff(a: &[f64], b: &[f64], c: &[f64], dt: f64) -> Vec<f64> {
|
||||
let dt2 = dt * dt;
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.zip(c.iter())
|
||||
.map(|((&ai, &bi), &ci)| (ai - 2.0 * bi + ci) / dt2)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// L2 norm of an f64 slice.
|
||||
fn l2_norm_f64(v: &[f64]) -> f64 {
|
||||
v.iter().map(|x| x * x).sum::<f64>().sqrt()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_config() -> IntentionConfig {
|
||||
IntentionConfig {
|
||||
embedding_dim: 4,
|
||||
window_size: 10,
|
||||
sample_rate_hz: 20.0,
|
||||
acceleration_threshold: 0.5,
|
||||
max_pre_movement_velocity: 2.0,
|
||||
min_sustained_frames: 3,
|
||||
max_lead_time_s: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
fn static_embedding() -> Vec<f32> {
|
||||
vec![1.0, 0.0, 0.0, 0.0]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_creation() {
|
||||
let config = make_config();
|
||||
let detector = IntentionDetector::new(config).unwrap();
|
||||
assert_eq!(detector.history_len(), 0);
|
||||
assert_eq!(detector.total_frames(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_config_zero_dim() {
|
||||
let config = IntentionConfig {
|
||||
embedding_dim: 0,
|
||||
..make_config()
|
||||
};
|
||||
assert!(matches!(
|
||||
IntentionDetector::new(config),
|
||||
Err(IntentionError::InvalidConfig(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_config_small_window() {
|
||||
let config = IntentionConfig {
|
||||
window_size: 2,
|
||||
..make_config()
|
||||
};
|
||||
assert!(matches!(
|
||||
IntentionDetector::new(config),
|
||||
Err(IntentionError::InvalidConfig(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dimension_mismatch() {
|
||||
let config = make_config();
|
||||
let mut detector = IntentionDetector::new(config).unwrap();
|
||||
let result = detector.update(&[1.0, 0.0], 0);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(IntentionError::DimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_static_scene_no_detection() {
|
||||
let config = make_config();
|
||||
let mut detector = IntentionDetector::new(config).unwrap();
|
||||
|
||||
for frame in 0..20 {
|
||||
let signal = detector
|
||||
.update(&static_embedding(), frame * 50_000)
|
||||
.unwrap();
|
||||
assert!(
|
||||
!signal.detected,
|
||||
"Static scene should not trigger detection"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gradual_acceleration_detected() {
|
||||
let mut config = make_config();
|
||||
config.acceleration_threshold = 100.0; // low threshold for test
|
||||
config.max_pre_movement_velocity = 100000.0;
|
||||
config.min_sustained_frames = 2;
|
||||
|
||||
let mut detector = IntentionDetector::new(config).unwrap();
|
||||
|
||||
// Feed gradually accelerating embeddings
|
||||
// Position = 0.5 * a * t^2, so embedding shifts quadratically
|
||||
let mut any_detected = false;
|
||||
for frame in 0..30_u64 {
|
||||
let t = frame as f32 * 0.05;
|
||||
let pos = 50.0 * t * t; // acceleration = 100 units/s^2
|
||||
let emb = vec![1.0 + pos, 0.0, 0.0, 0.0];
|
||||
let signal = detector.update(&emb, frame * 50_000).unwrap();
|
||||
if signal.detected {
|
||||
any_detected = true;
|
||||
assert!(signal.confidence > 0.0);
|
||||
assert!(signal.acceleration_magnitude > 0.0);
|
||||
}
|
||||
}
|
||||
assert!(any_detected, "Accelerating signal should trigger detection");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_velocity_no_detection() {
|
||||
let config = make_config();
|
||||
let mut detector = IntentionDetector::new(config).unwrap();
|
||||
|
||||
// Constant velocity = zero acceleration → no pre-movement
|
||||
for frame in 0..20_u64 {
|
||||
let pos = frame as f32 * 0.01; // constant velocity
|
||||
let emb = vec![1.0 + pos, 0.0, 0.0, 0.0];
|
||||
let signal = detector.update(&emb, frame * 50_000).unwrap();
|
||||
assert!(
|
||||
!signal.detected,
|
||||
"Constant velocity should not trigger pre-movement"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset() {
|
||||
let config = make_config();
|
||||
let mut detector = IntentionDetector::new(config).unwrap();
|
||||
|
||||
for frame in 0..5_u64 {
|
||||
detector
|
||||
.update(&static_embedding(), frame * 50_000)
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(detector.history_len(), 5);
|
||||
|
||||
detector.reset();
|
||||
assert_eq!(detector.history_len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lead_signal_fields() {
|
||||
let config = make_config();
|
||||
let mut detector = IntentionDetector::new(config).unwrap();
|
||||
|
||||
// Need at least 3 frames for derivatives
|
||||
for frame in 0..3_u64 {
|
||||
let signal = detector
|
||||
.update(&static_embedding(), frame * 50_000)
|
||||
.unwrap();
|
||||
assert_eq!(signal.sustained_frames, 0);
|
||||
}
|
||||
|
||||
let signal = detector.update(&static_embedding(), 150_000).unwrap();
|
||||
assert!(signal.velocity_magnitude >= 0.0);
|
||||
assert!(signal.acceleration_magnitude >= 0.0);
|
||||
assert_eq!(signal.direction_hint.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_window_size_limit() {
|
||||
let config = IntentionConfig {
|
||||
window_size: 5,
|
||||
..make_config()
|
||||
};
|
||||
let mut detector = IntentionDetector::new(config).unwrap();
|
||||
|
||||
for frame in 0..10_u64 {
|
||||
detector
|
||||
.update(&static_embedding(), frame * 50_000)
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(detector.history_len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_diff() {
|
||||
let a = vec![2.0, 4.0];
|
||||
let b = vec![1.0, 2.0];
|
||||
let diff = embedding_diff(&a, &b, 0.5);
|
||||
assert!((diff[0] - 2.0).abs() < 1e-10); // (2-1)/0.5
|
||||
assert!((diff[1] - 4.0).abs() < 1e-10); // (4-2)/0.5
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_second_diff() {
|
||||
// Quadratic sequence: 1, 4, 9 → second diff = 2
|
||||
let a = vec![9.0];
|
||||
let b = vec![4.0];
|
||||
let c = vec![1.0];
|
||||
let sd = embedding_second_diff(&a, &b, &c, 1.0);
|
||||
assert!((sd[0] - 2.0).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,692 @@
|
||||
//! Longitudinal biomechanics drift detection.
|
||||
//!
|
||||
//! Maintains per-person biophysical baselines over days/weeks using Welford
|
||||
//! online statistics. Detects meaningful drift in gait symmetry, stability,
|
||||
//! breathing regularity, micro-tremor, and activity level. Produces traceable
|
||||
//! evidence reports that link to stored embedding trajectories.
|
||||
//!
|
||||
//! # Key Invariants
|
||||
//! - Baseline requires >= 7 observation days before drift detection activates
|
||||
//! - Drift alert requires > 2-sigma deviation sustained for >= 3 consecutive days
|
||||
//! - Output is metric values and deviations, never diagnostic language
|
||||
//! - Welford statistics use full history (no windowing) for stability
|
||||
//!
|
||||
//! # References
|
||||
//! - Welford, B.P. (1962). "Note on a Method for Calculating Corrected
|
||||
//! Sums of Squares." Technometrics.
|
||||
//! - ADR-030 Tier 4: Longitudinal Biomechanics Drift
|
||||
|
||||
use crate::ruvsense::field_model::WelfordStats;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from longitudinal monitoring operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum LongitudinalError {
|
||||
/// Not enough observation days for drift detection.
|
||||
#[error("Insufficient observation days: need >= {needed}, got {got}")]
|
||||
InsufficientDays { needed: u32, got: u32 },
|
||||
|
||||
/// Person ID not found in the registry.
|
||||
#[error("Unknown person ID: {0}")]
|
||||
UnknownPerson(u64),
|
||||
|
||||
/// Embedding dimension mismatch.
|
||||
#[error("Embedding dimension mismatch: expected {expected}, got {got}")]
|
||||
EmbeddingDimensionMismatch { expected: usize, got: usize },
|
||||
|
||||
/// Invalid metric value.
|
||||
#[error("Invalid metric value for {metric}: {reason}")]
|
||||
InvalidMetric { metric: String, reason: String },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Domain types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Biophysical metric types tracked per person.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum DriftMetric {
|
||||
/// Gait symmetry ratio (0.0 = perfectly symmetric, higher = asymmetric).
|
||||
GaitSymmetry,
|
||||
/// Stability index (lower = less stable).
|
||||
StabilityIndex,
|
||||
/// Breathing regularity (coefficient of variation of breath intervals).
|
||||
BreathingRegularity,
|
||||
/// Micro-tremor amplitude (mm, from high-frequency pose jitter).
|
||||
MicroTremor,
|
||||
/// Daily activity level (normalized 0-1).
|
||||
ActivityLevel,
|
||||
}
|
||||
|
||||
impl DriftMetric {
|
||||
/// All metric variants.
|
||||
pub fn all() -> &'static [DriftMetric] {
|
||||
&[
|
||||
DriftMetric::GaitSymmetry,
|
||||
DriftMetric::StabilityIndex,
|
||||
DriftMetric::BreathingRegularity,
|
||||
DriftMetric::MicroTremor,
|
||||
DriftMetric::ActivityLevel,
|
||||
]
|
||||
}
|
||||
|
||||
/// Human-readable name.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
DriftMetric::GaitSymmetry => "gait_symmetry",
|
||||
DriftMetric::StabilityIndex => "stability_index",
|
||||
DriftMetric::BreathingRegularity => "breathing_regularity",
|
||||
DriftMetric::MicroTremor => "micro_tremor",
|
||||
DriftMetric::ActivityLevel => "activity_level",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Direction of drift.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum DriftDirection {
|
||||
/// Metric is increasing relative to baseline.
|
||||
Increasing,
|
||||
/// Metric is decreasing relative to baseline.
|
||||
Decreasing,
|
||||
}
|
||||
|
||||
/// Monitoring level for drift reports.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum MonitoringLevel {
|
||||
/// Level 1: Raw biophysical metric value.
|
||||
Physiological = 1,
|
||||
/// Level 2: Personal baseline deviation.
|
||||
Drift = 2,
|
||||
/// Level 3: Pattern-matched risk correlation.
|
||||
RiskCorrelation = 3,
|
||||
}
|
||||
|
||||
/// A drift report with traceable evidence.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DriftReport {
|
||||
/// Person this report pertains to.
|
||||
pub person_id: u64,
|
||||
/// Which metric drifted.
|
||||
pub metric: DriftMetric,
|
||||
/// Direction of drift.
|
||||
pub direction: DriftDirection,
|
||||
/// Z-score relative to personal baseline.
|
||||
pub z_score: f64,
|
||||
/// Current metric value (today or most recent).
|
||||
pub current_value: f64,
|
||||
/// Baseline mean for this metric.
|
||||
pub baseline_mean: f64,
|
||||
/// Baseline standard deviation.
|
||||
pub baseline_std: f64,
|
||||
/// Number of consecutive days the drift has been sustained.
|
||||
pub sustained_days: u32,
|
||||
/// Monitoring level.
|
||||
pub level: MonitoringLevel,
|
||||
/// Timestamp (microseconds) when this report was generated.
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
/// Daily metric summary for one person.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DailyMetricSummary {
|
||||
/// Person ID.
|
||||
pub person_id: u64,
|
||||
/// Day timestamp (start of day, microseconds).
|
||||
pub day_us: u64,
|
||||
/// Metric values for this day.
|
||||
pub metrics: Vec<(DriftMetric, f64)>,
|
||||
/// AETHER embedding centroid for this day.
|
||||
pub embedding_centroid: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Personal baseline
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Per-person longitudinal baseline with Welford statistics.
|
||||
///
|
||||
/// Tracks running mean and variance for each biophysical metric over
|
||||
/// the person's entire observation history. Uses Welford's algorithm
|
||||
/// for numerical stability.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PersonalBaseline {
|
||||
/// Unique person identifier.
|
||||
pub person_id: u64,
|
||||
/// Per-metric Welford accumulators.
|
||||
pub gait_symmetry: WelfordStats,
|
||||
pub stability_index: WelfordStats,
|
||||
pub breathing_regularity: WelfordStats,
|
||||
pub micro_tremor: WelfordStats,
|
||||
pub activity_level: WelfordStats,
|
||||
/// Running centroid of AETHER embeddings.
|
||||
pub embedding_centroid: Vec<f32>,
|
||||
/// Number of observation days.
|
||||
pub observation_days: u32,
|
||||
/// Timestamp of last update (microseconds).
|
||||
pub updated_at_us: u64,
|
||||
/// Per-metric consecutive drift days counter.
|
||||
drift_counters: [u32; 5],
|
||||
}
|
||||
|
||||
impl PersonalBaseline {
|
||||
/// Create a new baseline for a person.
|
||||
///
|
||||
/// `embedding_dim` is typically 128 for AETHER embeddings.
|
||||
pub fn new(person_id: u64, embedding_dim: usize) -> Self {
|
||||
Self {
|
||||
person_id,
|
||||
gait_symmetry: WelfordStats::new(),
|
||||
stability_index: WelfordStats::new(),
|
||||
breathing_regularity: WelfordStats::new(),
|
||||
micro_tremor: WelfordStats::new(),
|
||||
activity_level: WelfordStats::new(),
|
||||
embedding_centroid: vec![0.0; embedding_dim],
|
||||
observation_days: 0,
|
||||
updated_at_us: 0,
|
||||
drift_counters: [0; 5],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the Welford stats for a specific metric.
|
||||
pub fn stats_for(&self, metric: DriftMetric) -> &WelfordStats {
|
||||
match metric {
|
||||
DriftMetric::GaitSymmetry => &self.gait_symmetry,
|
||||
DriftMetric::StabilityIndex => &self.stability_index,
|
||||
DriftMetric::BreathingRegularity => &self.breathing_regularity,
|
||||
DriftMetric::MicroTremor => &self.micro_tremor,
|
||||
DriftMetric::ActivityLevel => &self.activity_level,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get mutable Welford stats for a specific metric.
|
||||
fn stats_for_mut(&mut self, metric: DriftMetric) -> &mut WelfordStats {
|
||||
match metric {
|
||||
DriftMetric::GaitSymmetry => &mut self.gait_symmetry,
|
||||
DriftMetric::StabilityIndex => &mut self.stability_index,
|
||||
DriftMetric::BreathingRegularity => &mut self.breathing_regularity,
|
||||
DriftMetric::MicroTremor => &mut self.micro_tremor,
|
||||
DriftMetric::ActivityLevel => &mut self.activity_level,
|
||||
}
|
||||
}
|
||||
|
||||
/// Index of a metric in the drift_counters array.
|
||||
fn metric_index(metric: DriftMetric) -> usize {
|
||||
match metric {
|
||||
DriftMetric::GaitSymmetry => 0,
|
||||
DriftMetric::StabilityIndex => 1,
|
||||
DriftMetric::BreathingRegularity => 2,
|
||||
DriftMetric::MicroTremor => 3,
|
||||
DriftMetric::ActivityLevel => 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether baseline has enough data for drift detection.
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.observation_days >= 7
|
||||
}
|
||||
|
||||
/// Update baseline with a daily summary.
|
||||
///
|
||||
/// Returns drift reports for any metrics that exceed thresholds.
|
||||
pub fn update_daily(
|
||||
&mut self,
|
||||
summary: &DailyMetricSummary,
|
||||
timestamp_us: u64,
|
||||
) -> Vec<DriftReport> {
|
||||
self.observation_days += 1;
|
||||
self.updated_at_us = timestamp_us;
|
||||
|
||||
// Update embedding centroid with EMA (decay = 0.95)
|
||||
if let Some(ref emb) = summary.embedding_centroid {
|
||||
if emb.len() == self.embedding_centroid.len() {
|
||||
let alpha = 0.05_f32; // 1 - 0.95
|
||||
for (c, e) in self.embedding_centroid.iter_mut().zip(emb.iter()) {
|
||||
*c = (1.0 - alpha) * *c + alpha * *e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut reports = Vec::new();
|
||||
|
||||
let observation_days = self.observation_days;
|
||||
|
||||
for &(metric, value) in &summary.metrics {
|
||||
// Update stats and extract values before releasing the mutable borrow
|
||||
let (z, baseline_mean, baseline_std) = {
|
||||
let stats = self.stats_for_mut(metric);
|
||||
stats.update(value);
|
||||
let z = stats.z_score(value);
|
||||
let mean = stats.mean;
|
||||
let std = stats.std_dev();
|
||||
(z, mean, std)
|
||||
};
|
||||
|
||||
if !self.is_ready_at(observation_days) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let idx = Self::metric_index(metric);
|
||||
|
||||
if z.abs() > 2.0 {
|
||||
self.drift_counters[idx] += 1;
|
||||
} else {
|
||||
self.drift_counters[idx] = 0;
|
||||
}
|
||||
|
||||
if self.drift_counters[idx] >= 3 {
|
||||
let direction = if z > 0.0 {
|
||||
DriftDirection::Increasing
|
||||
} else {
|
||||
DriftDirection::Decreasing
|
||||
};
|
||||
|
||||
let level = if self.drift_counters[idx] >= 7 {
|
||||
MonitoringLevel::RiskCorrelation
|
||||
} else {
|
||||
MonitoringLevel::Drift
|
||||
};
|
||||
|
||||
reports.push(DriftReport {
|
||||
person_id: self.person_id,
|
||||
metric,
|
||||
direction,
|
||||
z_score: z,
|
||||
current_value: value,
|
||||
baseline_mean,
|
||||
baseline_std,
|
||||
sustained_days: self.drift_counters[idx],
|
||||
level,
|
||||
timestamp_us,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
reports
|
||||
}
|
||||
|
||||
/// Check readiness at a specific observation day count (internal helper).
|
||||
fn is_ready_at(&self, days: u32) -> bool {
|
||||
days >= 7
|
||||
}
|
||||
|
||||
/// Get current drift counter for a metric.
|
||||
pub fn drift_days(&self, metric: DriftMetric) -> u32 {
|
||||
self.drift_counters[Self::metric_index(metric)]
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Embedding history (simplified HNSW-indexed store)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Entry in the embedding history.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingEntry {
|
||||
/// Person ID.
|
||||
pub person_id: u64,
|
||||
/// Day timestamp (microseconds).
|
||||
pub day_us: u64,
|
||||
/// AETHER embedding vector.
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Simplified embedding history store for longitudinal tracking.
|
||||
///
|
||||
/// In production, this would be backed by an HNSW index for fast
|
||||
/// nearest-neighbor search. This implementation uses brute-force
|
||||
/// cosine similarity for correctness.
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingHistory {
|
||||
entries: Vec<EmbeddingEntry>,
|
||||
max_entries: usize,
|
||||
embedding_dim: usize,
|
||||
}
|
||||
|
||||
impl EmbeddingHistory {
|
||||
/// Create a new embedding history store.
|
||||
pub fn new(embedding_dim: usize, max_entries: usize) -> Self {
|
||||
Self {
|
||||
entries: Vec::new(),
|
||||
max_entries,
|
||||
embedding_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an embedding entry.
|
||||
pub fn push(&mut self, entry: EmbeddingEntry) -> Result<(), LongitudinalError> {
|
||||
if entry.embedding.len() != self.embedding_dim {
|
||||
return Err(LongitudinalError::EmbeddingDimensionMismatch {
|
||||
expected: self.embedding_dim,
|
||||
got: entry.embedding.len(),
|
||||
});
|
||||
}
|
||||
if self.entries.len() >= self.max_entries {
|
||||
self.entries.drain(..1); // FIFO eviction — acceptable for daily-rate inserts
|
||||
}
|
||||
self.entries.push(entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find the K nearest embeddings to a query vector (brute-force cosine).
|
||||
pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
|
||||
let mut similarities: Vec<(usize, f32)> = self
|
||||
.entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, e)| (i, cosine_similarity(query, &e.embedding)))
|
||||
.collect();
|
||||
|
||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
similarities.truncate(k);
|
||||
similarities
|
||||
}
|
||||
|
||||
/// Number of entries stored.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Whether the store is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
|
||||
/// Get entry by index.
|
||||
pub fn get(&self, index: usize) -> Option<&EmbeddingEntry> {
|
||||
self.entries.get(index)
|
||||
}
|
||||
|
||||
/// Get all entries for a specific person.
|
||||
pub fn entries_for_person(&self, person_id: u64) -> Vec<&EmbeddingEntry> {
|
||||
self.entries
|
||||
.iter()
|
||||
.filter(|e| e.person_id == person_id)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine similarity between two f32 vectors.
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let denom = norm_a * norm_b;
|
||||
if denom < 1e-9 {
|
||||
0.0
|
||||
} else {
|
||||
dot / denom
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_daily_summary(person_id: u64, day: u64, values: [f64; 5]) -> DailyMetricSummary {
|
||||
DailyMetricSummary {
|
||||
person_id,
|
||||
day_us: day * 86_400_000_000,
|
||||
metrics: vec![
|
||||
(DriftMetric::GaitSymmetry, values[0]),
|
||||
(DriftMetric::StabilityIndex, values[1]),
|
||||
(DriftMetric::BreathingRegularity, values[2]),
|
||||
(DriftMetric::MicroTremor, values[3]),
|
||||
(DriftMetric::ActivityLevel, values[4]),
|
||||
],
|
||||
embedding_centroid: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_personal_baseline_creation() {
|
||||
let baseline = PersonalBaseline::new(42, 128);
|
||||
assert_eq!(baseline.person_id, 42);
|
||||
assert_eq!(baseline.observation_days, 0);
|
||||
assert!(!baseline.is_ready());
|
||||
assert_eq!(baseline.embedding_centroid.len(), 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_baseline_not_ready_before_7_days() {
|
||||
let mut baseline = PersonalBaseline::new(1, 128);
|
||||
for day in 0..6 {
|
||||
let summary = make_daily_summary(1, day, [0.1, 0.9, 0.15, 0.5, 0.7]);
|
||||
let reports = baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
assert!(reports.is_empty(), "No drift before 7 days");
|
||||
}
|
||||
assert!(!baseline.is_ready());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_baseline_ready_after_7_days() {
|
||||
let mut baseline = PersonalBaseline::new(1, 128);
|
||||
for day in 0..7 {
|
||||
let summary = make_daily_summary(1, day, [0.1, 0.9, 0.15, 0.5, 0.7]);
|
||||
baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
}
|
||||
assert!(baseline.is_ready());
|
||||
assert_eq!(baseline.observation_days, 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stable_metrics_no_drift() {
|
||||
let mut baseline = PersonalBaseline::new(1, 128);
|
||||
|
||||
// 20 days of stable metrics
|
||||
for day in 0..20 {
|
||||
let summary = make_daily_summary(1, day, [0.1, 0.9, 0.15, 0.5, 0.7]);
|
||||
let reports = baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
assert!(
|
||||
reports.is_empty(),
|
||||
"Stable metrics should not trigger drift"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_drift_detected_after_sustained_deviation() {
|
||||
let mut baseline = PersonalBaseline::new(1, 128);
|
||||
|
||||
// 30 days of very stable gait symmetry = 0.1 with tiny noise
|
||||
// (more baseline days = stronger prior, so drift stays > 2-sigma longer)
|
||||
for day in 0..30 {
|
||||
let noise = 0.001 * (day as f64 % 3.0 - 1.0); // tiny variation
|
||||
let summary = make_daily_summary(1, day, [0.1 + noise, 0.9, 0.15, 0.5, 0.7]);
|
||||
baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
}
|
||||
|
||||
// Now inject a very large drift in gait symmetry (0.1 -> 5.0) for 5 days.
|
||||
// Even as Welford accumulates these, the z-score should stay well above 2.0
|
||||
// because 30 baseline days anchor the mean near 0.1 with small std dev.
|
||||
let mut any_drift = false;
|
||||
for day in 30..36 {
|
||||
let summary = make_daily_summary(1, day, [5.0, 0.9, 0.15, 0.5, 0.7]);
|
||||
let reports = baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
if !reports.is_empty() {
|
||||
any_drift = true;
|
||||
let r = &reports[0];
|
||||
assert_eq!(r.metric, DriftMetric::GaitSymmetry);
|
||||
assert_eq!(r.direction, DriftDirection::Increasing);
|
||||
assert!(r.z_score > 2.0);
|
||||
assert!(r.sustained_days >= 3);
|
||||
}
|
||||
}
|
||||
assert!(any_drift, "Should detect drift after sustained deviation");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_drift_resolves_when_metric_returns() {
|
||||
let mut baseline = PersonalBaseline::new(1, 128);
|
||||
|
||||
// Stable baseline
|
||||
for day in 0..10 {
|
||||
let summary = make_daily_summary(1, day, [0.1, 0.9, 0.15, 0.5, 0.7]);
|
||||
baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
}
|
||||
|
||||
// Drift for 3 days
|
||||
for day in 10..13 {
|
||||
let summary = make_daily_summary(1, day, [0.9, 0.9, 0.15, 0.5, 0.7]);
|
||||
baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
}
|
||||
|
||||
// Return to normal
|
||||
for day in 13..16 {
|
||||
let summary = make_daily_summary(1, day, [0.1, 0.9, 0.15, 0.5, 0.7]);
|
||||
let reports = baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
// After returning to normal, drift counter resets
|
||||
if day == 15 {
|
||||
assert!(reports.is_empty(), "Drift should resolve");
|
||||
assert_eq!(baseline.drift_days(DriftMetric::GaitSymmetry), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_monitoring_level_escalation() {
|
||||
let mut baseline = PersonalBaseline::new(1, 128);
|
||||
|
||||
// 30 days of stable baseline with tiny noise to anchor stats
|
||||
for day in 0..30 {
|
||||
let noise = 0.001 * (day as f64 % 3.0 - 1.0);
|
||||
let summary = make_daily_summary(1, day, [0.1 + noise, 0.9, 0.15, 0.5, 0.7]);
|
||||
baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
}
|
||||
|
||||
// Sustained massive drift for 10+ days should escalate to RiskCorrelation.
|
||||
// Using value 10.0 (vs baseline ~0.1) to ensure z-score stays well above 2.0
|
||||
// even as Welford accumulates the drifted values.
|
||||
let mut max_level = MonitoringLevel::Physiological;
|
||||
for day in 30..42 {
|
||||
let summary = make_daily_summary(1, day, [10.0, 0.9, 0.15, 0.5, 0.7]);
|
||||
let reports = baseline.update_daily(&summary, day * 86_400_000_000);
|
||||
for r in &reports {
|
||||
if r.level > max_level {
|
||||
max_level = r.level;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_eq!(
|
||||
max_level,
|
||||
MonitoringLevel::RiskCorrelation,
|
||||
"7+ days sustained drift should reach RiskCorrelation level"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_history_push_and_search() {
|
||||
let mut history = EmbeddingHistory::new(4, 100);
|
||||
|
||||
history
|
||||
.push(EmbeddingEntry {
|
||||
person_id: 1,
|
||||
day_us: 0,
|
||||
embedding: vec![1.0, 0.0, 0.0, 0.0],
|
||||
})
|
||||
.unwrap();
|
||||
history
|
||||
.push(EmbeddingEntry {
|
||||
person_id: 1,
|
||||
day_us: 1,
|
||||
embedding: vec![0.9, 0.1, 0.0, 0.0],
|
||||
})
|
||||
.unwrap();
|
||||
history
|
||||
.push(EmbeddingEntry {
|
||||
person_id: 2,
|
||||
day_us: 0,
|
||||
embedding: vec![0.0, 0.0, 1.0, 0.0],
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let results = history.search(&[1.0, 0.0, 0.0, 0.0], 2);
|
||||
assert_eq!(results.len(), 2);
|
||||
// First result should be exact match
|
||||
assert!((results[0].1 - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_history_dimension_mismatch() {
|
||||
let mut history = EmbeddingHistory::new(4, 100);
|
||||
let result = history.push(EmbeddingEntry {
|
||||
person_id: 1,
|
||||
day_us: 0,
|
||||
embedding: vec![1.0, 0.0], // wrong dim
|
||||
});
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(LongitudinalError::EmbeddingDimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_history_fifo_eviction() {
|
||||
let mut history = EmbeddingHistory::new(2, 3);
|
||||
for i in 0..5 {
|
||||
history
|
||||
.push(EmbeddingEntry {
|
||||
person_id: 1,
|
||||
day_us: i,
|
||||
embedding: vec![i as f32, 0.0],
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(history.len(), 3);
|
||||
// First entry should be day 2 (0 and 1 evicted)
|
||||
assert_eq!(history.get(0).unwrap().day_us, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_entries_for_person() {
|
||||
let mut history = EmbeddingHistory::new(2, 100);
|
||||
history
|
||||
.push(EmbeddingEntry {
|
||||
person_id: 1,
|
||||
day_us: 0,
|
||||
embedding: vec![1.0, 0.0],
|
||||
})
|
||||
.unwrap();
|
||||
history
|
||||
.push(EmbeddingEntry {
|
||||
person_id: 2,
|
||||
day_us: 0,
|
||||
embedding: vec![0.0, 1.0],
|
||||
})
|
||||
.unwrap();
|
||||
history
|
||||
.push(EmbeddingEntry {
|
||||
person_id: 1,
|
||||
day_us: 1,
|
||||
embedding: vec![0.9, 0.1],
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let entries = history.entries_for_person(1);
|
||||
assert_eq!(entries.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_drift_metric_names() {
|
||||
assert_eq!(DriftMetric::GaitSymmetry.name(), "gait_symmetry");
|
||||
assert_eq!(DriftMetric::ActivityLevel.name(), "activity_level");
|
||||
assert_eq!(DriftMetric::all().len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_unit_vectors() {
|
||||
let a = vec![1.0_f32, 0.0, 0.0];
|
||||
let b = vec![0.0_f32, 1.0, 0.0];
|
||||
assert!(cosine_similarity(&a, &b).abs() < 1e-6, "Orthogonal = 0");
|
||||
|
||||
let c = vec![1.0_f32, 0.0, 0.0];
|
||||
assert!((cosine_similarity(&a, &c) - 1.0).abs() < 1e-6, "Same = 1");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
//! RuvSense -- Sensing-First RF Mode for Multistatic WiFi DensePose (ADR-029)
|
||||
//!
|
||||
//! This bounded context implements the multistatic sensing pipeline that fuses
|
||||
//! CSI from multiple ESP32 nodes across multiple WiFi channels into a single
|
||||
//! coherent sensing frame per 50 ms TDMA cycle (20 Hz output).
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The pipeline flows through six stages:
|
||||
//!
|
||||
//! 1. **Multi-Band Fusion** (`multiband`) -- Aggregate per-channel CSI frames
|
||||
//! from channel-hopping into a wideband virtual snapshot per node.
|
||||
//! 2. **Phase Alignment** (`phase_align`) -- Correct LO-induced phase rotation
|
||||
//! between channels using `ruvector-solver::NeumannSolver`.
|
||||
//! 3. **Multistatic Fusion** (`multistatic`) -- Fuse N node observations into
|
||||
//! a single `FusedSensingFrame` with attention-based cross-node weighting
|
||||
//! via `ruvector-attn-mincut`.
|
||||
//! 4. **Coherence Scoring** (`coherence`) -- Compute per-subcarrier z-score
|
||||
//! coherence against a rolling reference template.
|
||||
//! 5. **Coherence Gating** (`coherence_gate`) -- Apply threshold-based gate
|
||||
//! decision: Accept / PredictOnly / Reject / Recalibrate.
|
||||
//! 6. **Pose Tracking** (`pose_tracker`) -- 17-keypoint Kalman tracker with
|
||||
//! lifecycle state machine and AETHER re-ID embedding support.
|
||||
//!
|
||||
//! # RuVector Crate Usage
|
||||
//!
|
||||
//! - `ruvector-solver` -- Phase alignment, coherence decomposition
|
||||
//! - `ruvector-attn-mincut` -- Cross-node spectrogram fusion
|
||||
//! - `ruvector-mincut` -- Person separation and track assignment
|
||||
//! - `ruvector-attention` -- Cross-channel feature weighting
|
||||
//!
|
||||
//! # References
|
||||
//!
|
||||
//! - ADR-029: Project RuvSense
|
||||
//! - IEEE 802.11bf-2024 WLAN Sensing
|
||||
|
||||
// ADR-030: Exotic sensing tiers
|
||||
pub mod adversarial;
|
||||
pub mod cross_room;
|
||||
pub mod field_model;
|
||||
pub mod gesture;
|
||||
pub mod intention;
|
||||
pub mod longitudinal;
|
||||
pub mod tomography;
|
||||
|
||||
// ADR-032a: Midstreamer-enhanced sensing
|
||||
pub mod temporal_gesture;
|
||||
pub mod attractor_drift;
|
||||
|
||||
// ADR-029: Core multistatic pipeline
|
||||
pub mod coherence;
|
||||
pub mod coherence_gate;
|
||||
pub mod multiband;
|
||||
pub mod multistatic;
|
||||
pub mod phase_align;
|
||||
pub mod pose_tracker;
|
||||
|
||||
// Re-export core types for ergonomic access
|
||||
pub use coherence::CoherenceState;
|
||||
pub use coherence_gate::{GateDecision, GatePolicy};
|
||||
pub use multiband::MultiBandCsiFrame;
|
||||
pub use multistatic::FusedSensingFrame;
|
||||
pub use phase_align::{PhaseAligner, PhaseAlignError};
|
||||
pub use pose_tracker::{KeypointState, PoseTrack, TrackLifecycleState};
|
||||
|
||||
/// Number of keypoints in a full-body pose skeleton (COCO-17).
|
||||
pub const NUM_KEYPOINTS: usize = 17;
|
||||
|
||||
/// Keypoint indices following the COCO-17 convention.
|
||||
pub mod keypoint {
|
||||
pub const NOSE: usize = 0;
|
||||
pub const LEFT_EYE: usize = 1;
|
||||
pub const RIGHT_EYE: usize = 2;
|
||||
pub const LEFT_EAR: usize = 3;
|
||||
pub const RIGHT_EAR: usize = 4;
|
||||
pub const LEFT_SHOULDER: usize = 5;
|
||||
pub const RIGHT_SHOULDER: usize = 6;
|
||||
pub const LEFT_ELBOW: usize = 7;
|
||||
pub const RIGHT_ELBOW: usize = 8;
|
||||
pub const LEFT_WRIST: usize = 9;
|
||||
pub const RIGHT_WRIST: usize = 10;
|
||||
pub const LEFT_HIP: usize = 11;
|
||||
pub const RIGHT_HIP: usize = 12;
|
||||
pub const LEFT_KNEE: usize = 13;
|
||||
pub const RIGHT_KNEE: usize = 14;
|
||||
pub const LEFT_ANKLE: usize = 15;
|
||||
pub const RIGHT_ANKLE: usize = 16;
|
||||
|
||||
/// Torso keypoint indices (shoulders, hips, spine midpoint proxy).
|
||||
pub const TORSO_INDICES: &[usize] = &[
|
||||
LEFT_SHOULDER,
|
||||
RIGHT_SHOULDER,
|
||||
LEFT_HIP,
|
||||
RIGHT_HIP,
|
||||
];
|
||||
}
|
||||
|
||||
/// Unique identifier for a pose track.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct TrackId(pub u64);
|
||||
|
||||
impl TrackId {
|
||||
/// Create a new track identifier.
|
||||
pub fn new(id: u64) -> Self {
|
||||
Self(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TrackId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Track({})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Error type shared across the RuvSense pipeline.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RuvSenseError {
|
||||
/// Phase alignment failed.
|
||||
#[error("Phase alignment error: {0}")]
|
||||
PhaseAlign(#[from] phase_align::PhaseAlignError),
|
||||
|
||||
/// Multi-band fusion error.
|
||||
#[error("Multi-band fusion error: {0}")]
|
||||
MultiBand(#[from] multiband::MultiBandError),
|
||||
|
||||
/// Multistatic fusion error.
|
||||
#[error("Multistatic fusion error: {0}")]
|
||||
Multistatic(#[from] multistatic::MultistaticError),
|
||||
|
||||
/// Coherence computation error.
|
||||
#[error("Coherence error: {0}")]
|
||||
Coherence(#[from] coherence::CoherenceError),
|
||||
|
||||
/// Pose tracker error.
|
||||
#[error("Pose tracker error: {0}")]
|
||||
PoseTracker(#[from] pose_tracker::PoseTrackerError),
|
||||
}
|
||||
|
||||
/// Common result type for RuvSense operations.
|
||||
pub type Result<T> = std::result::Result<T, RuvSenseError>;
|
||||
|
||||
/// Configuration for the RuvSense pipeline.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RuvSenseConfig {
|
||||
/// Maximum number of nodes in the multistatic mesh.
|
||||
pub max_nodes: usize,
|
||||
/// Target output rate in Hz.
|
||||
pub target_hz: f64,
|
||||
/// Number of channels in the hop sequence.
|
||||
pub num_channels: usize,
|
||||
/// Coherence accept threshold (default 0.85).
|
||||
pub coherence_accept: f32,
|
||||
/// Coherence drift threshold (default 0.5).
|
||||
pub coherence_drift: f32,
|
||||
/// Maximum stale frames before recalibration (default 200 = 10s at 20Hz).
|
||||
pub max_stale_frames: u64,
|
||||
/// Embedding dimension for AETHER re-ID (default 128).
|
||||
pub embedding_dim: usize,
|
||||
}
|
||||
|
||||
impl Default for RuvSenseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_nodes: 4,
|
||||
target_hz: 20.0,
|
||||
num_channels: 3,
|
||||
coherence_accept: 0.85,
|
||||
coherence_drift: 0.5,
|
||||
max_stale_frames: 200,
|
||||
embedding_dim: 128,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Top-level pipeline orchestrator for RuvSense multistatic sensing.
|
||||
///
|
||||
/// Coordinates the flow from raw per-node CSI frames through multi-band
|
||||
/// fusion, phase alignment, multistatic fusion, coherence gating, and
|
||||
/// finally into the pose tracker.
|
||||
pub struct RuvSensePipeline {
|
||||
config: RuvSenseConfig,
|
||||
phase_aligner: PhaseAligner,
|
||||
coherence_state: CoherenceState,
|
||||
gate_policy: GatePolicy,
|
||||
frame_counter: u64,
|
||||
}
|
||||
|
||||
impl RuvSensePipeline {
|
||||
/// Create a new pipeline with default configuration.
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(RuvSenseConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new pipeline with the given configuration.
|
||||
pub fn with_config(config: RuvSenseConfig) -> Self {
|
||||
let n_sub = 56; // canonical subcarrier count
|
||||
Self {
|
||||
phase_aligner: PhaseAligner::new(config.num_channels),
|
||||
coherence_state: CoherenceState::new(n_sub, config.coherence_accept),
|
||||
gate_policy: GatePolicy::new(
|
||||
config.coherence_accept,
|
||||
config.coherence_drift,
|
||||
config.max_stale_frames,
|
||||
),
|
||||
config,
|
||||
frame_counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a reference to the current pipeline configuration.
|
||||
pub fn config(&self) -> &RuvSenseConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Return the total number of frames processed.
|
||||
pub fn frame_count(&self) -> u64 {
|
||||
self.frame_counter
|
||||
}
|
||||
|
||||
/// Return a reference to the current coherence state.
|
||||
pub fn coherence_state(&self) -> &CoherenceState {
|
||||
&self.coherence_state
|
||||
}
|
||||
|
||||
/// Advance the frame counter (called once per sensing cycle).
|
||||
pub fn tick(&mut self) {
|
||||
self.frame_counter += 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RuvSensePipeline {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let cfg = RuvSenseConfig::default();
|
||||
assert_eq!(cfg.max_nodes, 4);
|
||||
assert!((cfg.target_hz - 20.0).abs() < f64::EPSILON);
|
||||
assert_eq!(cfg.num_channels, 3);
|
||||
assert!((cfg.coherence_accept - 0.85).abs() < f32::EPSILON);
|
||||
assert!((cfg.coherence_drift - 0.5).abs() < f32::EPSILON);
|
||||
assert_eq!(cfg.max_stale_frames, 200);
|
||||
assert_eq!(cfg.embedding_dim, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_creation_defaults() {
|
||||
let pipe = RuvSensePipeline::new();
|
||||
assert_eq!(pipe.frame_count(), 0);
|
||||
assert_eq!(pipe.config().max_nodes, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_tick_increments() {
|
||||
let mut pipe = RuvSensePipeline::new();
|
||||
pipe.tick();
|
||||
pipe.tick();
|
||||
pipe.tick();
|
||||
assert_eq!(pipe.frame_count(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_id_display() {
|
||||
let tid = TrackId::new(42);
|
||||
assert_eq!(format!("{}", tid), "Track(42)");
|
||||
assert_eq!(tid.0, 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_id_equality() {
|
||||
assert_eq!(TrackId(1), TrackId(1));
|
||||
assert_ne!(TrackId(1), TrackId(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keypoint_constants() {
|
||||
assert_eq!(keypoint::NOSE, 0);
|
||||
assert_eq!(keypoint::LEFT_ANKLE, 15);
|
||||
assert_eq!(keypoint::RIGHT_ANKLE, 16);
|
||||
assert_eq!(keypoint::TORSO_INDICES.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn num_keypoints_is_17() {
|
||||
assert_eq!(NUM_KEYPOINTS, 17);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_config_pipeline() {
|
||||
let cfg = RuvSenseConfig {
|
||||
max_nodes: 6,
|
||||
target_hz: 10.0,
|
||||
num_channels: 6,
|
||||
coherence_accept: 0.9,
|
||||
coherence_drift: 0.4,
|
||||
max_stale_frames: 100,
|
||||
embedding_dim: 64,
|
||||
};
|
||||
let pipe = RuvSensePipeline::with_config(cfg);
|
||||
assert_eq!(pipe.config().max_nodes, 6);
|
||||
assert!((pipe.config().target_hz - 10.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_display() {
|
||||
let err = RuvSenseError::Coherence(coherence::CoherenceError::EmptyInput);
|
||||
let msg = format!("{}", err);
|
||||
assert!(msg.contains("Coherence"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_coherence_state_accessible() {
|
||||
let pipe = RuvSensePipeline::new();
|
||||
let cs = pipe.coherence_state();
|
||||
assert!(cs.score() >= 0.0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,441 @@
|
||||
//! Multi-Band CSI Frame Fusion (ADR-029 Section 2.3)
|
||||
//!
|
||||
//! Aggregates per-channel CSI frames from channel-hopping into a wideband
|
||||
//! virtual snapshot. An ESP32-S3 cycling through channels 1/6/11 at 50 ms
|
||||
//! dwell per channel yields 3 canonical-56 CSI rows per sensing cycle.
|
||||
//! This module fuses them into a single `MultiBandCsiFrame` annotated with
|
||||
//! center frequencies and cross-channel coherence.
|
||||
//!
|
||||
//! # RuVector Integration
|
||||
//!
|
||||
//! - `ruvector-attention` for cross-channel feature weighting (future)
|
||||
|
||||
use crate::hardware_norm::CanonicalCsiFrame;
|
||||
|
||||
/// Errors from multi-band frame fusion.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MultiBandError {
|
||||
/// No channel frames provided.
|
||||
#[error("No channel frames provided for multi-band fusion")]
|
||||
NoFrames,
|
||||
|
||||
/// Mismatched subcarrier counts across channels.
|
||||
#[error("Subcarrier count mismatch: channel {channel_idx} has {got}, expected {expected}")]
|
||||
SubcarrierMismatch {
|
||||
channel_idx: usize,
|
||||
expected: usize,
|
||||
got: usize,
|
||||
},
|
||||
|
||||
/// Frequency list length does not match frame count.
|
||||
#[error("Frequency count ({freq_count}) does not match frame count ({frame_count})")]
|
||||
FrequencyCountMismatch { freq_count: usize, frame_count: usize },
|
||||
|
||||
/// Duplicate frequency in channel list.
|
||||
#[error("Duplicate frequency {freq_mhz} MHz at index {idx}")]
|
||||
DuplicateFrequency { freq_mhz: u32, idx: usize },
|
||||
}
|
||||
|
||||
/// Fused multi-band CSI from one node at one time slot.
|
||||
///
|
||||
/// Holds one canonical-56 row per channel, ordered by center frequency.
|
||||
/// The `coherence` field quantifies agreement across channels (0.0-1.0).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiBandCsiFrame {
|
||||
/// Originating node identifier (0-255).
|
||||
pub node_id: u8,
|
||||
/// Timestamp of the sensing cycle in microseconds.
|
||||
pub timestamp_us: u64,
|
||||
/// One canonical-56 CSI frame per channel, ordered by center frequency.
|
||||
pub channel_frames: Vec<CanonicalCsiFrame>,
|
||||
/// Center frequencies (MHz) for each channel row.
|
||||
pub frequencies_mhz: Vec<u32>,
|
||||
/// Cross-channel coherence score (0.0-1.0).
|
||||
pub coherence: f32,
|
||||
}
|
||||
|
||||
/// Configuration for the multi-band fusion process.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiBandConfig {
|
||||
/// Time window in microseconds within which frames are considered
|
||||
/// part of the same sensing cycle.
|
||||
pub window_us: u64,
|
||||
/// Expected number of channels per cycle.
|
||||
pub expected_channels: usize,
|
||||
/// Minimum coherence to accept the fused frame.
|
||||
pub min_coherence: f32,
|
||||
}
|
||||
|
||||
impl Default for MultiBandConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
window_us: 200_000, // 200 ms default window
|
||||
expected_channels: 3,
|
||||
min_coherence: 0.3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for constructing a `MultiBandCsiFrame` from per-channel observations.
|
||||
#[derive(Debug)]
|
||||
pub struct MultiBandBuilder {
|
||||
node_id: u8,
|
||||
timestamp_us: u64,
|
||||
frames: Vec<CanonicalCsiFrame>,
|
||||
frequencies: Vec<u32>,
|
||||
}
|
||||
|
||||
impl MultiBandBuilder {
|
||||
/// Create a new builder for the given node and timestamp.
|
||||
pub fn new(node_id: u8, timestamp_us: u64) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
timestamp_us,
|
||||
frames: Vec::new(),
|
||||
frequencies: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a channel observation at the given center frequency.
|
||||
pub fn add_channel(
|
||||
mut self,
|
||||
frame: CanonicalCsiFrame,
|
||||
freq_mhz: u32,
|
||||
) -> Self {
|
||||
self.frames.push(frame);
|
||||
self.frequencies.push(freq_mhz);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the fused multi-band frame.
|
||||
///
|
||||
/// Validates inputs, sorts by frequency, and computes cross-channel coherence.
|
||||
pub fn build(mut self) -> std::result::Result<MultiBandCsiFrame, MultiBandError> {
|
||||
if self.frames.is_empty() {
|
||||
return Err(MultiBandError::NoFrames);
|
||||
}
|
||||
|
||||
if self.frequencies.len() != self.frames.len() {
|
||||
return Err(MultiBandError::FrequencyCountMismatch {
|
||||
freq_count: self.frequencies.len(),
|
||||
frame_count: self.frames.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for duplicate frequencies
|
||||
for i in 0..self.frequencies.len() {
|
||||
for j in (i + 1)..self.frequencies.len() {
|
||||
if self.frequencies[i] == self.frequencies[j] {
|
||||
return Err(MultiBandError::DuplicateFrequency {
|
||||
freq_mhz: self.frequencies[i],
|
||||
idx: j,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate consistent subcarrier counts
|
||||
let expected_len = self.frames[0].amplitude.len();
|
||||
for (i, frame) in self.frames.iter().enumerate().skip(1) {
|
||||
if frame.amplitude.len() != expected_len {
|
||||
return Err(MultiBandError::SubcarrierMismatch {
|
||||
channel_idx: i,
|
||||
expected: expected_len,
|
||||
got: frame.amplitude.len(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort frames by frequency
|
||||
let mut indices: Vec<usize> = (0..self.frames.len()).collect();
|
||||
indices.sort_by_key(|&i| self.frequencies[i]);
|
||||
|
||||
let sorted_frames: Vec<CanonicalCsiFrame> =
|
||||
indices.iter().map(|&i| self.frames[i].clone()).collect();
|
||||
let sorted_freqs: Vec<u32> =
|
||||
indices.iter().map(|&i| self.frequencies[i]).collect();
|
||||
|
||||
self.frames = sorted_frames;
|
||||
self.frequencies = sorted_freqs;
|
||||
|
||||
// Compute cross-channel coherence
|
||||
let coherence = compute_cross_channel_coherence(&self.frames);
|
||||
|
||||
Ok(MultiBandCsiFrame {
|
||||
node_id: self.node_id,
|
||||
timestamp_us: self.timestamp_us,
|
||||
channel_frames: self.frames,
|
||||
frequencies_mhz: self.frequencies,
|
||||
coherence,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute cross-channel coherence as the mean pairwise Pearson correlation
|
||||
/// of amplitude vectors across all channel pairs.
|
||||
///
|
||||
/// Returns a value in [0.0, 1.0] where 1.0 means perfect correlation.
|
||||
fn compute_cross_channel_coherence(frames: &[CanonicalCsiFrame]) -> f32 {
|
||||
if frames.len() < 2 {
|
||||
return 1.0; // single channel is trivially coherent
|
||||
}
|
||||
|
||||
let mut total_corr = 0.0_f64;
|
||||
let mut pair_count = 0u32;
|
||||
|
||||
for i in 0..frames.len() {
|
||||
for j in (i + 1)..frames.len() {
|
||||
let corr = pearson_correlation_f32(
|
||||
&frames[i].amplitude,
|
||||
&frames[j].amplitude,
|
||||
);
|
||||
total_corr += corr as f64;
|
||||
pair_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if pair_count == 0 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
// Map correlation [-1, 1] to coherence [0, 1]
|
||||
let mean_corr = total_corr / pair_count as f64;
|
||||
((mean_corr + 1.0) / 2.0).clamp(0.0, 1.0) as f32
|
||||
}
|
||||
|
||||
/// Pearson correlation coefficient between two f32 slices.
|
||||
fn pearson_correlation_f32(a: &[f32], b: &[f32]) -> f32 {
|
||||
let n = a.len().min(b.len());
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let n_f = n as f32;
|
||||
let mean_a: f32 = a[..n].iter().sum::<f32>() / n_f;
|
||||
let mean_b: f32 = b[..n].iter().sum::<f32>() / n_f;
|
||||
|
||||
let mut cov = 0.0_f32;
|
||||
let mut var_a = 0.0_f32;
|
||||
let mut var_b = 0.0_f32;
|
||||
|
||||
for i in 0..n {
|
||||
let da = a[i] - mean_a;
|
||||
let db = b[i] - mean_b;
|
||||
cov += da * db;
|
||||
var_a += da * da;
|
||||
var_b += db * db;
|
||||
}
|
||||
|
||||
let denom = (var_a * var_b).sqrt();
|
||||
if denom < 1e-12 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
(cov / denom).clamp(-1.0, 1.0)
|
||||
}
|
||||
|
||||
/// Concatenate the amplitude vectors from all channels into a single
|
||||
/// wideband amplitude vector. Useful for downstream models that expect
|
||||
/// a flat feature vector.
|
||||
pub fn concatenate_amplitudes(frame: &MultiBandCsiFrame) -> Vec<f32> {
|
||||
let total_len: usize = frame.channel_frames.iter().map(|f| f.amplitude.len()).sum();
|
||||
let mut out = Vec::with_capacity(total_len);
|
||||
for cf in &frame.channel_frames {
|
||||
out.extend_from_slice(&cf.amplitude);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Compute the mean amplitude across all channels, producing a single
|
||||
/// canonical-length vector that averages multi-band observations.
|
||||
pub fn mean_amplitude(frame: &MultiBandCsiFrame) -> Vec<f32> {
|
||||
if frame.channel_frames.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let n_sub = frame.channel_frames[0].amplitude.len();
|
||||
let n_ch = frame.channel_frames.len() as f32;
|
||||
let mut mean = vec![0.0_f32; n_sub];
|
||||
|
||||
for cf in &frame.channel_frames {
|
||||
for (i, &val) in cf.amplitude.iter().enumerate() {
|
||||
if i < n_sub {
|
||||
mean[i] += val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for v in &mut mean {
|
||||
*v /= n_ch;
|
||||
}
|
||||
|
||||
mean
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::hardware_norm::HardwareType;
|
||||
|
||||
fn make_canonical(amplitude: Vec<f32>, phase: Vec<f32>) -> CanonicalCsiFrame {
|
||||
CanonicalCsiFrame {
|
||||
amplitude,
|
||||
phase,
|
||||
hardware_type: HardwareType::Esp32S3,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_frame(n_sub: usize, scale: f32) -> CanonicalCsiFrame {
|
||||
let amp: Vec<f32> = (0..n_sub).map(|i| scale * (i as f32 * 0.1).sin()).collect();
|
||||
let phase: Vec<f32> = (0..n_sub).map(|i| (i as f32 * 0.05).cos()).collect();
|
||||
make_canonical(amp, phase)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_single_channel() {
|
||||
let frame = MultiBandBuilder::new(0, 1000)
|
||||
.add_channel(make_frame(56, 1.0), 2412)
|
||||
.build()
|
||||
.unwrap();
|
||||
assert_eq!(frame.node_id, 0);
|
||||
assert_eq!(frame.timestamp_us, 1000);
|
||||
assert_eq!(frame.channel_frames.len(), 1);
|
||||
assert_eq!(frame.frequencies_mhz, vec![2412]);
|
||||
assert!((frame.coherence - 1.0).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_three_channels_sorted_by_freq() {
|
||||
let frame = MultiBandBuilder::new(1, 2000)
|
||||
.add_channel(make_frame(56, 1.0), 2462) // ch 11
|
||||
.add_channel(make_frame(56, 1.0), 2412) // ch 1
|
||||
.add_channel(make_frame(56, 1.0), 2437) // ch 6
|
||||
.build()
|
||||
.unwrap();
|
||||
assert_eq!(frame.frequencies_mhz, vec![2412, 2437, 2462]);
|
||||
assert_eq!(frame.channel_frames.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_frames_error() {
|
||||
let result = MultiBandBuilder::new(0, 0).build();
|
||||
assert!(matches!(result, Err(MultiBandError::NoFrames)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subcarrier_mismatch_error() {
|
||||
let result = MultiBandBuilder::new(0, 0)
|
||||
.add_channel(make_frame(56, 1.0), 2412)
|
||||
.add_channel(make_frame(30, 1.0), 2437)
|
||||
.build();
|
||||
assert!(matches!(result, Err(MultiBandError::SubcarrierMismatch { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duplicate_frequency_error() {
|
||||
let result = MultiBandBuilder::new(0, 0)
|
||||
.add_channel(make_frame(56, 1.0), 2412)
|
||||
.add_channel(make_frame(56, 1.0), 2412)
|
||||
.build();
|
||||
assert!(matches!(result, Err(MultiBandError::DuplicateFrequency { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_identical_channels() {
|
||||
let f = make_frame(56, 1.0);
|
||||
let frame = MultiBandBuilder::new(0, 0)
|
||||
.add_channel(f.clone(), 2412)
|
||||
.add_channel(f.clone(), 2437)
|
||||
.build()
|
||||
.unwrap();
|
||||
// Identical channels should have coherence == 1.0
|
||||
assert!((frame.coherence - 1.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coherence_orthogonal_channels() {
|
||||
let n = 56;
|
||||
let amp_a: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3).sin()).collect();
|
||||
let amp_b: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3).cos()).collect();
|
||||
let ph = vec![0.0_f32; n];
|
||||
|
||||
let frame = MultiBandBuilder::new(0, 0)
|
||||
.add_channel(make_canonical(amp_a, ph.clone()), 2412)
|
||||
.add_channel(make_canonical(amp_b, ph), 2437)
|
||||
.build()
|
||||
.unwrap();
|
||||
// Orthogonal signals should produce lower coherence
|
||||
assert!(frame.coherence < 0.9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concatenate_amplitudes_correct_length() {
|
||||
let frame = MultiBandBuilder::new(0, 0)
|
||||
.add_channel(make_frame(56, 1.0), 2412)
|
||||
.add_channel(make_frame(56, 2.0), 2437)
|
||||
.add_channel(make_frame(56, 3.0), 2462)
|
||||
.build()
|
||||
.unwrap();
|
||||
let concat = concatenate_amplitudes(&frame);
|
||||
assert_eq!(concat.len(), 56 * 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mean_amplitude_correct() {
|
||||
let n = 4;
|
||||
let f1 = make_canonical(vec![1.0, 2.0, 3.0, 4.0], vec![0.0; n]);
|
||||
let f2 = make_canonical(vec![3.0, 4.0, 5.0, 6.0], vec![0.0; n]);
|
||||
let frame = MultiBandBuilder::new(0, 0)
|
||||
.add_channel(f1, 2412)
|
||||
.add_channel(f2, 2437)
|
||||
.build()
|
||||
.unwrap();
|
||||
let m = mean_amplitude(&frame);
|
||||
assert_eq!(m.len(), 4);
|
||||
assert!((m[0] - 2.0).abs() < 1e-6);
|
||||
assert!((m[1] - 3.0).abs() < 1e-6);
|
||||
assert!((m[2] - 4.0).abs() < 1e-6);
|
||||
assert!((m[3] - 5.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mean_amplitude_empty() {
|
||||
let frame = MultiBandCsiFrame {
|
||||
node_id: 0,
|
||||
timestamp_us: 0,
|
||||
channel_frames: vec![],
|
||||
frequencies_mhz: vec![],
|
||||
coherence: 1.0,
|
||||
};
|
||||
assert!(mean_amplitude(&frame).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pearson_correlation_perfect() {
|
||||
let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let b = vec![2.0_f32, 4.0, 6.0, 8.0, 10.0];
|
||||
let r = pearson_correlation_f32(&a, &b);
|
||||
assert!((r - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pearson_correlation_negative() {
|
||||
let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let b = vec![5.0_f32, 4.0, 3.0, 2.0, 1.0];
|
||||
let r = pearson_correlation_f32(&a, &b);
|
||||
assert!((r + 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pearson_correlation_empty() {
|
||||
assert_eq!(pearson_correlation_f32(&[], &[]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config() {
|
||||
let cfg = MultiBandConfig::default();
|
||||
assert_eq!(cfg.expected_channels, 3);
|
||||
assert_eq!(cfg.window_us, 200_000);
|
||||
assert!((cfg.min_coherence - 0.3).abs() < f32::EPSILON);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,562 @@
|
||||
//! Multistatic Viewpoint Fusion (ADR-029 Section 2.4)
|
||||
//!
|
||||
//! With N ESP32 nodes in a TDMA mesh, each sensing cycle produces N
|
||||
//! `MultiBandCsiFrame`s. This module fuses them into a single
|
||||
//! `FusedSensingFrame` using attention-based cross-node weighting.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//!
|
||||
//! 1. Collect N `MultiBandCsiFrame`s from the current sensing cycle.
|
||||
//! 2. Use `ruvector-attn-mincut` for cross-node attention: cells showing
|
||||
//! correlated motion energy across nodes (body reflection) are amplified;
|
||||
//! cells with single-node energy (multipath artifact) are suppressed.
|
||||
//! 3. Multi-person separation via `ruvector-mincut::DynamicMinCut` builds
|
||||
//! a cross-link correlation graph and partitions into K person clusters.
|
||||
//!
|
||||
//! # RuVector Integration
|
||||
//!
|
||||
//! - `ruvector-attn-mincut` for cross-node spectrogram attention gating
|
||||
//! - `ruvector-mincut` for person separation (DynamicMinCut)
|
||||
|
||||
use super::multiband::MultiBandCsiFrame;
|
||||
|
||||
/// Errors from multistatic fusion.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MultistaticError {
|
||||
/// No node frames provided.
|
||||
#[error("No node frames provided for multistatic fusion")]
|
||||
NoFrames,
|
||||
|
||||
/// Insufficient nodes for multistatic mode (need at least 2).
|
||||
#[error("Need at least 2 nodes for multistatic fusion, got {0}")]
|
||||
InsufficientNodes(usize),
|
||||
|
||||
/// Timestamp mismatch beyond guard interval.
|
||||
#[error("Timestamp spread {spread_us} us exceeds guard interval {guard_us} us")]
|
||||
TimestampMismatch { spread_us: u64, guard_us: u64 },
|
||||
|
||||
/// Dimension mismatch in fusion inputs.
|
||||
#[error("Dimension mismatch: node {node_idx} has {got} subcarriers, expected {expected}")]
|
||||
DimensionMismatch {
|
||||
node_idx: usize,
|
||||
expected: usize,
|
||||
got: usize,
|
||||
},
|
||||
}
|
||||
|
||||
/// A fused sensing frame from all nodes at one sensing cycle.
|
||||
///
|
||||
/// This is the primary output of the multistatic fusion stage and serves
|
||||
/// as input to model inference and the pose tracker.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FusedSensingFrame {
|
||||
/// Timestamp of this sensing cycle in microseconds.
|
||||
pub timestamp_us: u64,
|
||||
/// Fused amplitude vector across all nodes (attention-weighted mean).
|
||||
/// Length = n_subcarriers.
|
||||
pub fused_amplitude: Vec<f32>,
|
||||
/// Fused phase vector across all nodes.
|
||||
/// Length = n_subcarriers.
|
||||
pub fused_phase: Vec<f32>,
|
||||
/// Per-node multi-band frames (preserved for geometry computations).
|
||||
pub node_frames: Vec<MultiBandCsiFrame>,
|
||||
/// Node positions (x, y, z) in meters from deployment configuration.
|
||||
pub node_positions: Vec<[f32; 3]>,
|
||||
/// Number of active nodes contributing to this frame.
|
||||
pub active_nodes: usize,
|
||||
/// Cross-node coherence score (0.0-1.0). Higher means more agreement
|
||||
/// across viewpoints, indicating a strong body reflection signal.
|
||||
pub cross_node_coherence: f32,
|
||||
}
|
||||
|
||||
/// Configuration for multistatic fusion.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultistaticConfig {
|
||||
/// Maximum timestamp spread (microseconds) across nodes in one cycle.
|
||||
/// Default: 5000 us (5 ms), well within the 50 ms TDMA cycle.
|
||||
pub guard_interval_us: u64,
|
||||
/// Minimum number of nodes for multistatic mode.
|
||||
/// Falls back to single-node mode if fewer nodes are available.
|
||||
pub min_nodes: usize,
|
||||
/// Attention temperature for cross-node weighting.
|
||||
/// Lower temperature -> sharper attention (fewer nodes dominate).
|
||||
pub attention_temperature: f32,
|
||||
/// Whether to enable person separation via min-cut.
|
||||
pub enable_person_separation: bool,
|
||||
}
|
||||
|
||||
impl Default for MultistaticConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
guard_interval_us: 5000,
|
||||
min_nodes: 2,
|
||||
attention_temperature: 1.0,
|
||||
enable_person_separation: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multistatic frame fuser.
|
||||
///
|
||||
/// Collects per-node multi-band frames and produces a single fused
|
||||
/// sensing frame per TDMA cycle.
|
||||
#[derive(Debug)]
|
||||
pub struct MultistaticFuser {
|
||||
config: MultistaticConfig,
|
||||
/// Node positions in 3D space (meters).
|
||||
node_positions: Vec<[f32; 3]>,
|
||||
}
|
||||
|
||||
impl MultistaticFuser {
|
||||
/// Create a fuser with default configuration and no node positions.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: MultistaticConfig::default(),
|
||||
node_positions: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a fuser with custom configuration.
|
||||
pub fn with_config(config: MultistaticConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
node_positions: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set node positions for geometric diversity computations.
|
||||
pub fn set_node_positions(&mut self, positions: Vec<[f32; 3]>) {
|
||||
self.node_positions = positions;
|
||||
}
|
||||
|
||||
/// Return the current node positions.
|
||||
pub fn node_positions(&self) -> &[[f32; 3]] {
|
||||
&self.node_positions
|
||||
}
|
||||
|
||||
/// Fuse multiple node frames into a single `FusedSensingFrame`.
|
||||
///
|
||||
/// When only one node is provided, falls back to single-node mode
|
||||
/// (no cross-node attention). When two or more nodes are available,
|
||||
/// applies attention-weighted fusion.
|
||||
pub fn fuse(
|
||||
&self,
|
||||
node_frames: &[MultiBandCsiFrame],
|
||||
) -> std::result::Result<FusedSensingFrame, MultistaticError> {
|
||||
if node_frames.is_empty() {
|
||||
return Err(MultistaticError::NoFrames);
|
||||
}
|
||||
|
||||
// Validate timestamp spread
|
||||
if node_frames.len() > 1 {
|
||||
let min_ts = node_frames.iter().map(|f| f.timestamp_us).min().unwrap();
|
||||
let max_ts = node_frames.iter().map(|f| f.timestamp_us).max().unwrap();
|
||||
let spread = max_ts - min_ts;
|
||||
if spread > self.config.guard_interval_us {
|
||||
return Err(MultistaticError::TimestampMismatch {
|
||||
spread_us: spread,
|
||||
guard_us: self.config.guard_interval_us,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Extract per-node amplitude vectors from first channel of each node
|
||||
let amplitudes: Vec<&[f32]> = node_frames
|
||||
.iter()
|
||||
.filter_map(|f| f.channel_frames.first().map(|cf| cf.amplitude.as_slice()))
|
||||
.collect();
|
||||
|
||||
let phases: Vec<&[f32]> = node_frames
|
||||
.iter()
|
||||
.filter_map(|f| f.channel_frames.first().map(|cf| cf.phase.as_slice()))
|
||||
.collect();
|
||||
|
||||
if amplitudes.is_empty() {
|
||||
return Err(MultistaticError::NoFrames);
|
||||
}
|
||||
|
||||
// Validate dimension consistency
|
||||
let n_sub = amplitudes[0].len();
|
||||
for (i, amp) in amplitudes.iter().enumerate().skip(1) {
|
||||
if amp.len() != n_sub {
|
||||
return Err(MultistaticError::DimensionMismatch {
|
||||
node_idx: i,
|
||||
expected: n_sub,
|
||||
got: amp.len(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let n_nodes = amplitudes.len();
|
||||
let (fused_amp, fused_ph, coherence) = if n_nodes == 1 {
|
||||
// Single-node fallback
|
||||
(
|
||||
amplitudes[0].to_vec(),
|
||||
phases[0].to_vec(),
|
||||
1.0_f32,
|
||||
)
|
||||
} else {
|
||||
// Multi-node attention-weighted fusion
|
||||
attention_weighted_fusion(&litudes, &phases, self.config.attention_temperature)
|
||||
};
|
||||
|
||||
// Derive timestamp from median
|
||||
let mut timestamps: Vec<u64> = node_frames.iter().map(|f| f.timestamp_us).collect();
|
||||
timestamps.sort_unstable();
|
||||
let timestamp_us = timestamps[timestamps.len() / 2];
|
||||
|
||||
// Build node positions list, filling with origin for unknown nodes
|
||||
let positions: Vec<[f32; 3]> = (0..n_nodes)
|
||||
.map(|i| {
|
||||
self.node_positions
|
||||
.get(i)
|
||||
.copied()
|
||||
.unwrap_or([0.0, 0.0, 0.0])
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(FusedSensingFrame {
|
||||
timestamp_us,
|
||||
fused_amplitude: fused_amp,
|
||||
fused_phase: fused_ph,
|
||||
node_frames: node_frames.to_vec(),
|
||||
node_positions: positions,
|
||||
active_nodes: n_nodes,
|
||||
cross_node_coherence: coherence,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MultistaticFuser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Attention-weighted fusion of amplitude and phase vectors from multiple nodes.
|
||||
///
|
||||
/// Each node's contribution is weighted by its agreement with the consensus.
|
||||
/// Returns (fused_amplitude, fused_phase, cross_node_coherence).
|
||||
fn attention_weighted_fusion(
|
||||
amplitudes: &[&[f32]],
|
||||
phases: &[&[f32]],
|
||||
temperature: f32,
|
||||
) -> (Vec<f32>, Vec<f32>, f32) {
|
||||
let n_nodes = amplitudes.len();
|
||||
let n_sub = amplitudes[0].len();
|
||||
|
||||
// Compute mean amplitude as consensus reference
|
||||
let mut mean_amp = vec![0.0_f32; n_sub];
|
||||
for amp in amplitudes {
|
||||
for (i, &v) in amp.iter().enumerate() {
|
||||
mean_amp[i] += v;
|
||||
}
|
||||
}
|
||||
for v in &mut mean_amp {
|
||||
*v /= n_nodes as f32;
|
||||
}
|
||||
|
||||
// Compute attention weights based on similarity to consensus
|
||||
let mut logits = vec![0.0_f32; n_nodes];
|
||||
for (n, amp) in amplitudes.iter().enumerate() {
|
||||
let mut dot = 0.0_f32;
|
||||
let mut norm_a = 0.0_f32;
|
||||
let mut norm_b = 0.0_f32;
|
||||
for i in 0..n_sub {
|
||||
dot += amp[i] * mean_amp[i];
|
||||
norm_a += amp[i] * amp[i];
|
||||
norm_b += mean_amp[i] * mean_amp[i];
|
||||
}
|
||||
let denom = (norm_a * norm_b).sqrt().max(1e-12);
|
||||
let similarity = dot / denom;
|
||||
logits[n] = similarity / temperature;
|
||||
}
|
||||
|
||||
// Numerically stable softmax: subtract max to prevent exp() overflow
|
||||
let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut weights = vec![0.0_f32; n_nodes];
|
||||
for (n, &logit) in logits.iter().enumerate() {
|
||||
weights[n] = (logit - max_logit).exp();
|
||||
}
|
||||
let weight_sum: f32 = weights.iter().sum::<f32>().max(1e-12);
|
||||
for w in &mut weights {
|
||||
*w /= weight_sum;
|
||||
}
|
||||
|
||||
// Weighted fusion
|
||||
let mut fused_amp = vec![0.0_f32; n_sub];
|
||||
let mut fused_ph_sin = vec![0.0_f32; n_sub];
|
||||
let mut fused_ph_cos = vec![0.0_f32; n_sub];
|
||||
|
||||
for (n, (&, &ph)) in amplitudes.iter().zip(phases.iter()).enumerate() {
|
||||
let w = weights[n];
|
||||
for i in 0..n_sub {
|
||||
fused_amp[i] += w * amp[i];
|
||||
fused_ph_sin[i] += w * ph[i].sin();
|
||||
fused_ph_cos[i] += w * ph[i].cos();
|
||||
}
|
||||
}
|
||||
|
||||
// Recover phase from sin/cos weighted average
|
||||
let fused_ph: Vec<f32> = fused_ph_sin
|
||||
.iter()
|
||||
.zip(fused_ph_cos.iter())
|
||||
.map(|(&s, &c)| s.atan2(c))
|
||||
.collect();
|
||||
|
||||
// Coherence = mean weight entropy proxy: high when weights are balanced
|
||||
let coherence = compute_weight_coherence(&weights);
|
||||
|
||||
(fused_amp, fused_ph, coherence)
|
||||
}
|
||||
|
||||
/// Compute coherence from attention weights.
|
||||
///
|
||||
/// Returns 1.0 when all weights are equal (all nodes agree),
|
||||
/// and approaches 0.0 when a single node dominates.
|
||||
fn compute_weight_coherence(weights: &[f32]) -> f32 {
|
||||
let n = weights.len() as f32;
|
||||
if n <= 1.0 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
// Normalized entropy: H / log(n)
|
||||
let max_entropy = n.ln();
|
||||
if max_entropy < 1e-12 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
let entropy: f32 = weights
|
||||
.iter()
|
||||
.filter(|&&w| w > 1e-12)
|
||||
.map(|&w| -w * w.ln())
|
||||
.sum();
|
||||
|
||||
(entropy / max_entropy).clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Compute the geometric diversity score for a set of node positions.
|
||||
///
|
||||
/// Returns a value in [0.0, 1.0] where 1.0 indicates maximum angular
|
||||
/// coverage. Based on the angular span of node positions relative to the
|
||||
/// room centroid.
|
||||
pub fn geometric_diversity(positions: &[[f32; 3]]) -> f32 {
|
||||
if positions.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute centroid
|
||||
let n = positions.len() as f32;
|
||||
let centroid = [
|
||||
positions.iter().map(|p| p[0]).sum::<f32>() / n,
|
||||
positions.iter().map(|p| p[1]).sum::<f32>() / n,
|
||||
positions.iter().map(|p| p[2]).sum::<f32>() / n,
|
||||
];
|
||||
|
||||
// Compute angles from centroid to each node (in 2D, ignoring z)
|
||||
let mut angles: Vec<f32> = positions
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let dx = p[0] - centroid[0];
|
||||
let dy = p[1] - centroid[1];
|
||||
dy.atan2(dx)
|
||||
})
|
||||
.collect();
|
||||
|
||||
angles.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Angular coverage: sum of gaps, diversity is high when gaps are even
|
||||
let mut max_gap = 0.0_f32;
|
||||
for i in 0..angles.len() {
|
||||
let next = (i + 1) % angles.len();
|
||||
let mut gap = angles[next] - angles[i];
|
||||
if gap < 0.0 {
|
||||
gap += 2.0 * std::f32::consts::PI;
|
||||
}
|
||||
max_gap = max_gap.max(gap);
|
||||
}
|
||||
|
||||
// Perfect coverage (N equidistant nodes): max_gap = 2*pi/N
|
||||
// Worst case (all co-located): max_gap = 2*pi
|
||||
let ideal_gap = 2.0 * std::f32::consts::PI / positions.len() as f32;
|
||||
let diversity = (ideal_gap / max_gap.max(1e-6)).clamp(0.0, 1.0);
|
||||
diversity
|
||||
}
|
||||
|
||||
/// Represents a cluster of TX-RX links attributed to one person.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PersonCluster {
|
||||
/// Cluster identifier.
|
||||
pub id: usize,
|
||||
/// Indices into the link array belonging to this cluster.
|
||||
pub link_indices: Vec<usize>,
|
||||
/// Mean correlation strength within the cluster.
|
||||
pub intra_correlation: f32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::hardware_norm::{CanonicalCsiFrame, HardwareType};
|
||||
|
||||
fn make_node_frame(
|
||||
node_id: u8,
|
||||
timestamp_us: u64,
|
||||
n_sub: usize,
|
||||
scale: f32,
|
||||
) -> MultiBandCsiFrame {
|
||||
let amp: Vec<f32> = (0..n_sub).map(|i| scale * (1.0 + 0.1 * i as f32)).collect();
|
||||
let phase: Vec<f32> = (0..n_sub).map(|i| i as f32 * 0.05).collect();
|
||||
MultiBandCsiFrame {
|
||||
node_id,
|
||||
timestamp_us,
|
||||
channel_frames: vec![CanonicalCsiFrame {
|
||||
amplitude: amp,
|
||||
phase,
|
||||
hardware_type: HardwareType::Esp32S3,
|
||||
}],
|
||||
frequencies_mhz: vec![2412],
|
||||
coherence: 0.9,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_single_node_fallback() {
|
||||
let fuser = MultistaticFuser::new();
|
||||
let frames = vec![make_node_frame(0, 1000, 56, 1.0)];
|
||||
let fused = fuser.fuse(&frames).unwrap();
|
||||
assert_eq!(fused.active_nodes, 1);
|
||||
assert_eq!(fused.fused_amplitude.len(), 56);
|
||||
assert!((fused.cross_node_coherence - 1.0).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_two_identical_nodes() {
|
||||
let fuser = MultistaticFuser::new();
|
||||
let f0 = make_node_frame(0, 1000, 56, 1.0);
|
||||
let f1 = make_node_frame(1, 1001, 56, 1.0);
|
||||
let fused = fuser.fuse(&[f0, f1]).unwrap();
|
||||
assert_eq!(fused.active_nodes, 2);
|
||||
assert_eq!(fused.fused_amplitude.len(), 56);
|
||||
// Identical nodes -> high coherence
|
||||
assert!(fused.cross_node_coherence > 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuse_four_nodes() {
|
||||
let fuser = MultistaticFuser::new();
|
||||
let frames: Vec<MultiBandCsiFrame> = (0..4)
|
||||
.map(|i| make_node_frame(i, 1000 + i as u64, 56, 1.0 + 0.1 * i as f32))
|
||||
.collect();
|
||||
let fused = fuser.fuse(&frames).unwrap();
|
||||
assert_eq!(fused.active_nodes, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_frames_error() {
|
||||
let fuser = MultistaticFuser::new();
|
||||
assert!(matches!(fuser.fuse(&[]), Err(MultistaticError::NoFrames)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timestamp_mismatch_error() {
|
||||
let config = MultistaticConfig {
|
||||
guard_interval_us: 100,
|
||||
..Default::default()
|
||||
};
|
||||
let fuser = MultistaticFuser::with_config(config);
|
||||
let f0 = make_node_frame(0, 0, 56, 1.0);
|
||||
let f1 = make_node_frame(1, 200, 56, 1.0);
|
||||
assert!(matches!(
|
||||
fuser.fuse(&[f0, f1]),
|
||||
Err(MultistaticError::TimestampMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dimension_mismatch_error() {
|
||||
let fuser = MultistaticFuser::new();
|
||||
let f0 = make_node_frame(0, 1000, 56, 1.0);
|
||||
let f1 = make_node_frame(1, 1001, 30, 1.0);
|
||||
assert!(matches!(
|
||||
fuser.fuse(&[f0, f1]),
|
||||
Err(MultistaticError::DimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn node_positions_set_and_retrieved() {
|
||||
let mut fuser = MultistaticFuser::new();
|
||||
let positions = vec![[0.0, 0.0, 1.0], [3.0, 0.0, 1.0]];
|
||||
fuser.set_node_positions(positions.clone());
|
||||
assert_eq!(fuser.node_positions(), &positions[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fused_positions_filled() {
|
||||
let mut fuser = MultistaticFuser::new();
|
||||
fuser.set_node_positions(vec![[1.0, 2.0, 3.0]]);
|
||||
let frames = vec![
|
||||
make_node_frame(0, 100, 56, 1.0),
|
||||
make_node_frame(1, 101, 56, 1.0),
|
||||
];
|
||||
let fused = fuser.fuse(&frames).unwrap();
|
||||
assert_eq!(fused.node_positions[0], [1.0, 2.0, 3.0]);
|
||||
assert_eq!(fused.node_positions[1], [0.0, 0.0, 0.0]); // default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometric_diversity_single_node() {
|
||||
assert_eq!(geometric_diversity(&[[0.0, 0.0, 0.0]]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometric_diversity_two_opposite() {
|
||||
let score = geometric_diversity(&[[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]);
|
||||
assert!(score > 0.8, "Two opposite nodes should have high diversity: {}", score);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometric_diversity_four_corners() {
|
||||
let score = geometric_diversity(&[
|
||||
[0.0, 0.0, 0.0],
|
||||
[5.0, 0.0, 0.0],
|
||||
[5.0, 5.0, 0.0],
|
||||
[0.0, 5.0, 0.0],
|
||||
]);
|
||||
assert!(score > 0.7, "Four corners should have good diversity: {}", score);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn weight_coherence_uniform() {
|
||||
let weights = vec![0.25, 0.25, 0.25, 0.25];
|
||||
let c = compute_weight_coherence(&weights);
|
||||
assert!((c - 1.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn weight_coherence_single_dominant() {
|
||||
let weights = vec![0.97, 0.01, 0.01, 0.01];
|
||||
let c = compute_weight_coherence(&weights);
|
||||
assert!(c < 0.3, "Single dominant node should have low coherence: {}", c);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config() {
|
||||
let cfg = MultistaticConfig::default();
|
||||
assert_eq!(cfg.guard_interval_us, 5000);
|
||||
assert_eq!(cfg.min_nodes, 2);
|
||||
assert!((cfg.attention_temperature - 1.0).abs() < f32::EPSILON);
|
||||
assert!(cfg.enable_person_separation);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn person_cluster_creation() {
|
||||
let cluster = PersonCluster {
|
||||
id: 0,
|
||||
link_indices: vec![0, 1, 3],
|
||||
intra_correlation: 0.85,
|
||||
};
|
||||
assert_eq!(cluster.link_indices.len(), 3);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,457 @@
|
||||
//! Cross-Channel Phase Alignment (ADR-029 Section 2.3)
|
||||
//!
|
||||
//! When the ESP32 hops between WiFi channels, the local oscillator (LO)
|
||||
//! introduces a channel-dependent phase rotation. The observed phase on
|
||||
//! channel c is:
|
||||
//!
|
||||
//! phi_c = phi_body + delta_c
|
||||
//!
|
||||
//! where `delta_c` is the LO offset for channel c. This module estimates
|
||||
//! and removes the `delta_c` offsets by fitting against the static
|
||||
//! subcarrier components, which should have zero body-caused phase shift.
|
||||
//!
|
||||
//! # RuVector Integration
|
||||
//!
|
||||
//! Uses `ruvector-solver::NeumannSolver` concepts for iterative convergence
|
||||
//! on the phase offset estimation. The solver achieves O(sqrt(n)) convergence.
|
||||
|
||||
use crate::hardware_norm::CanonicalCsiFrame;
|
||||
use std::f32::consts::PI;
|
||||
|
||||
/// Errors from phase alignment.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PhaseAlignError {
|
||||
/// No frames provided.
|
||||
#[error("No frames provided for phase alignment")]
|
||||
NoFrames,
|
||||
|
||||
/// Insufficient static subcarriers for alignment.
|
||||
#[error("Need at least {needed} static subcarriers, found {found}")]
|
||||
InsufficientStatic { needed: usize, found: usize },
|
||||
|
||||
/// Phase data length mismatch.
|
||||
#[error("Phase length {got} does not match expected {expected}")]
|
||||
PhaseLengthMismatch { expected: usize, got: usize },
|
||||
|
||||
/// Convergence failure.
|
||||
#[error("Phase alignment failed to converge after {iterations} iterations")]
|
||||
ConvergenceFailed { iterations: usize },
|
||||
}
|
||||
|
||||
/// Configuration for the phase aligner.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PhaseAlignConfig {
|
||||
/// Maximum iterations for the Neumann solver.
|
||||
pub max_iterations: usize,
|
||||
/// Convergence tolerance (radians).
|
||||
pub tolerance: f32,
|
||||
/// Fraction of subcarriers considered "static" (lowest variance).
|
||||
pub static_fraction: f32,
|
||||
/// Minimum number of static subcarriers required.
|
||||
pub min_static_subcarriers: usize,
|
||||
}
|
||||
|
||||
impl Default for PhaseAlignConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_iterations: 20,
|
||||
tolerance: 1e-4,
|
||||
static_fraction: 0.3,
|
||||
min_static_subcarriers: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cross-channel phase aligner.
|
||||
///
|
||||
/// Estimates per-channel LO phase offsets from static subcarriers and
|
||||
/// removes them to produce phase-coherent multi-band observations.
|
||||
#[derive(Debug)]
|
||||
pub struct PhaseAligner {
|
||||
/// Number of channels expected.
|
||||
num_channels: usize,
|
||||
/// Configuration parameters.
|
||||
config: PhaseAlignConfig,
|
||||
/// Last estimated offsets (one per channel), updated after each `align`.
|
||||
last_offsets: Vec<f32>,
|
||||
}
|
||||
|
||||
impl PhaseAligner {
|
||||
/// Create a new aligner for the given number of channels.
|
||||
pub fn new(num_channels: usize) -> Self {
|
||||
Self {
|
||||
num_channels,
|
||||
config: PhaseAlignConfig::default(),
|
||||
last_offsets: vec![0.0; num_channels],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new aligner with custom configuration.
|
||||
pub fn with_config(num_channels: usize, config: PhaseAlignConfig) -> Self {
|
||||
Self {
|
||||
num_channels,
|
||||
config,
|
||||
last_offsets: vec![0.0; num_channels],
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the last estimated phase offsets (radians).
|
||||
pub fn last_offsets(&self) -> &[f32] {
|
||||
&self.last_offsets
|
||||
}
|
||||
|
||||
/// Align phases across channels.
|
||||
///
|
||||
/// Takes a slice of per-channel `CanonicalCsiFrame`s and returns corrected
|
||||
/// frames with LO phase offsets removed. The first channel is used as the
|
||||
/// reference (delta_0 = 0).
|
||||
///
|
||||
/// # Algorithm
|
||||
///
|
||||
/// 1. Identify static subcarriers (lowest amplitude variance across channels).
|
||||
/// 2. For each channel c, compute mean phase on static subcarriers.
|
||||
/// 3. Estimate delta_c as the difference from the reference channel.
|
||||
/// 4. Iterate with Neumann-style refinement until convergence.
|
||||
/// 5. Subtract delta_c from all subcarrier phases on channel c.
|
||||
pub fn align(
|
||||
&mut self,
|
||||
frames: &[CanonicalCsiFrame],
|
||||
) -> std::result::Result<Vec<CanonicalCsiFrame>, PhaseAlignError> {
|
||||
if frames.is_empty() {
|
||||
return Err(PhaseAlignError::NoFrames);
|
||||
}
|
||||
|
||||
if frames.len() == 1 {
|
||||
// Single channel: no alignment needed
|
||||
self.last_offsets = vec![0.0];
|
||||
return Ok(frames.to_vec());
|
||||
}
|
||||
|
||||
let n_sub = frames[0].phase.len();
|
||||
for (_i, f) in frames.iter().enumerate().skip(1) {
|
||||
if f.phase.len() != n_sub {
|
||||
return Err(PhaseAlignError::PhaseLengthMismatch {
|
||||
expected: n_sub,
|
||||
got: f.phase.len(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Find static subcarriers (lowest amplitude variance across channels)
|
||||
let static_indices = find_static_subcarriers(frames, &self.config)?;
|
||||
|
||||
// Step 2-4: Estimate phase offsets with iterative refinement
|
||||
let offsets = estimate_phase_offsets(frames, &static_indices, &self.config)?;
|
||||
|
||||
// Step 5: Apply correction
|
||||
let corrected = apply_phase_correction(frames, &offsets);
|
||||
|
||||
self.last_offsets = offsets;
|
||||
Ok(corrected)
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the indices of static subcarriers (lowest amplitude variance).
|
||||
fn find_static_subcarriers(
|
||||
frames: &[CanonicalCsiFrame],
|
||||
config: &PhaseAlignConfig,
|
||||
) -> std::result::Result<Vec<usize>, PhaseAlignError> {
|
||||
let n_sub = frames[0].amplitude.len();
|
||||
let n_ch = frames.len();
|
||||
|
||||
// Compute variance of amplitude across channels for each subcarrier
|
||||
let mut variances: Vec<(usize, f32)> = (0..n_sub)
|
||||
.map(|s| {
|
||||
let mean: f32 = frames.iter().map(|f| f.amplitude[s]).sum::<f32>() / n_ch as f32;
|
||||
let var: f32 = frames
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let d = f.amplitude[s] - mean;
|
||||
d * d
|
||||
})
|
||||
.sum::<f32>()
|
||||
/ n_ch as f32;
|
||||
(s, var)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by variance (ascending) and take the bottom fraction
|
||||
variances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let n_static = ((n_sub as f32 * config.static_fraction).ceil() as usize)
|
||||
.max(config.min_static_subcarriers);
|
||||
|
||||
if variances.len() < config.min_static_subcarriers {
|
||||
return Err(PhaseAlignError::InsufficientStatic {
|
||||
needed: config.min_static_subcarriers,
|
||||
found: variances.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut indices: Vec<usize> = variances
|
||||
.iter()
|
||||
.take(n_static.min(variances.len()))
|
||||
.map(|(idx, _)| *idx)
|
||||
.collect();
|
||||
|
||||
indices.sort_unstable();
|
||||
Ok(indices)
|
||||
}
|
||||
|
||||
/// Estimate per-channel phase offsets using iterative Neumann-style refinement.
|
||||
///
|
||||
/// Channel 0 is the reference (offset = 0).
|
||||
fn estimate_phase_offsets(
|
||||
frames: &[CanonicalCsiFrame],
|
||||
static_indices: &[usize],
|
||||
config: &PhaseAlignConfig,
|
||||
) -> std::result::Result<Vec<f32>, PhaseAlignError> {
|
||||
let n_ch = frames.len();
|
||||
let mut offsets = vec![0.0_f32; n_ch];
|
||||
|
||||
// Reference: mean phase on static subcarriers for channel 0
|
||||
let ref_mean = mean_phase_on_indices(&frames[0].phase, static_indices);
|
||||
|
||||
// Initial estimate: difference of mean static phase from reference
|
||||
for c in 1..n_ch {
|
||||
let ch_mean = mean_phase_on_indices(&frames[c].phase, static_indices);
|
||||
offsets[c] = wrap_phase(ch_mean - ref_mean);
|
||||
}
|
||||
|
||||
// Iterative refinement (Neumann-style)
|
||||
for _iter in 0..config.max_iterations {
|
||||
let mut max_update = 0.0_f32;
|
||||
|
||||
for c in 1..n_ch {
|
||||
// Compute residual: for each static subcarrier, the corrected
|
||||
// phase should match the reference channel's phase.
|
||||
let mut residual_sum = 0.0_f32;
|
||||
for &s in static_indices {
|
||||
let corrected = frames[c].phase[s] - offsets[c];
|
||||
let residual = wrap_phase(corrected - frames[0].phase[s]);
|
||||
residual_sum += residual;
|
||||
}
|
||||
let mean_residual = residual_sum / static_indices.len() as f32;
|
||||
|
||||
// Update offset
|
||||
let update = mean_residual * 0.5; // damped update
|
||||
offsets[c] = wrap_phase(offsets[c] + update);
|
||||
max_update = max_update.max(update.abs());
|
||||
}
|
||||
|
||||
if max_update < config.tolerance {
|
||||
return Ok(offsets);
|
||||
}
|
||||
}
|
||||
|
||||
// Even if we do not converge tightly, return best estimate
|
||||
Ok(offsets)
|
||||
}
|
||||
|
||||
/// Apply phase correction: subtract offset from each subcarrier phase.
|
||||
fn apply_phase_correction(
|
||||
frames: &[CanonicalCsiFrame],
|
||||
offsets: &[f32],
|
||||
) -> Vec<CanonicalCsiFrame> {
|
||||
frames
|
||||
.iter()
|
||||
.zip(offsets.iter())
|
||||
.map(|(frame, &offset)| {
|
||||
let corrected_phase: Vec<f32> = frame
|
||||
.phase
|
||||
.iter()
|
||||
.map(|&p| wrap_phase(p - offset))
|
||||
.collect();
|
||||
CanonicalCsiFrame {
|
||||
amplitude: frame.amplitude.clone(),
|
||||
phase: corrected_phase,
|
||||
hardware_type: frame.hardware_type,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute mean phase on the given subcarrier indices.
|
||||
fn mean_phase_on_indices(phase: &[f32], indices: &[usize]) -> f32 {
|
||||
if indices.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Use circular mean to handle phase wrapping
|
||||
let mut sin_sum = 0.0_f32;
|
||||
let mut cos_sum = 0.0_f32;
|
||||
for &i in indices {
|
||||
// Defensive bounds check: skip out-of-range indices rather than panic
|
||||
if let Some(&p) = phase.get(i) {
|
||||
sin_sum += p.sin();
|
||||
cos_sum += p.cos();
|
||||
}
|
||||
}
|
||||
|
||||
sin_sum.atan2(cos_sum)
|
||||
}
|
||||
|
||||
/// Wrap phase into [-pi, pi].
|
||||
fn wrap_phase(phase: f32) -> f32 {
|
||||
let mut p = phase % (2.0 * PI);
|
||||
if p > PI {
|
||||
p -= 2.0 * PI;
|
||||
}
|
||||
if p < -PI {
|
||||
p += 2.0 * PI;
|
||||
}
|
||||
p
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::hardware_norm::HardwareType;
|
||||
|
||||
fn make_frame_with_phase(n: usize, base_phase: f32, offset: f32) -> CanonicalCsiFrame {
|
||||
let amplitude: Vec<f32> = (0..n).map(|i| 1.0 + 0.01 * i as f32).collect();
|
||||
let phase: Vec<f32> = (0..n).map(|i| base_phase + i as f32 * 0.01 + offset).collect();
|
||||
CanonicalCsiFrame {
|
||||
amplitude,
|
||||
phase,
|
||||
hardware_type: HardwareType::Esp32S3,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single_channel_no_change() {
|
||||
let mut aligner = PhaseAligner::new(1);
|
||||
let frames = vec![make_frame_with_phase(56, 0.0, 0.0)];
|
||||
let result = aligner.align(&frames).unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].phase, frames[0].phase);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_frames_error() {
|
||||
let mut aligner = PhaseAligner::new(3);
|
||||
let result = aligner.align(&[]);
|
||||
assert!(matches!(result, Err(PhaseAlignError::NoFrames)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn phase_length_mismatch_error() {
|
||||
let mut aligner = PhaseAligner::new(2);
|
||||
let f1 = make_frame_with_phase(56, 0.0, 0.0);
|
||||
let f2 = make_frame_with_phase(30, 0.0, 0.0);
|
||||
let result = aligner.align(&[f1, f2]);
|
||||
assert!(matches!(result, Err(PhaseAlignError::PhaseLengthMismatch { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identical_channels_zero_offset() {
|
||||
let mut aligner = PhaseAligner::new(3);
|
||||
let f = make_frame_with_phase(56, 0.5, 0.0);
|
||||
let result = aligner.align(&[f.clone(), f.clone(), f.clone()]).unwrap();
|
||||
assert_eq!(result.len(), 3);
|
||||
// All offsets should be ~0
|
||||
for &off in aligner.last_offsets() {
|
||||
assert!(off.abs() < 0.1, "Expected near-zero offset, got {}", off);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn known_offset_corrected() {
|
||||
let mut aligner = PhaseAligner::new(2);
|
||||
let offset = 0.5_f32;
|
||||
let f0 = make_frame_with_phase(56, 0.0, 0.0);
|
||||
let f1 = make_frame_with_phase(56, 0.0, offset);
|
||||
|
||||
let result = aligner.align(&[f0.clone(), f1]).unwrap();
|
||||
|
||||
// After correction, channel 1 phases should be close to channel 0
|
||||
let max_diff: f32 = result[0]
|
||||
.phase
|
||||
.iter()
|
||||
.zip(result[1].phase.iter())
|
||||
.map(|(a, b)| wrap_phase(a - b).abs())
|
||||
.fold(0.0_f32, f32::max);
|
||||
|
||||
assert!(
|
||||
max_diff < 0.2,
|
||||
"Max phase difference after alignment: {} (should be <0.2)",
|
||||
max_diff
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrap_phase_within_range() {
|
||||
assert!((wrap_phase(0.0)).abs() < 1e-6);
|
||||
assert!((wrap_phase(PI) - PI).abs() < 1e-6);
|
||||
assert!((wrap_phase(-PI) + PI).abs() < 1e-6);
|
||||
assert!((wrap_phase(3.0 * PI) - PI).abs() < 0.01);
|
||||
assert!((wrap_phase(-3.0 * PI) + PI).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mean_phase_circular() {
|
||||
let phase = vec![0.1_f32, 0.2, 0.3, 0.4];
|
||||
let indices = vec![0, 1, 2, 3];
|
||||
let m = mean_phase_on_indices(&phase, &indices);
|
||||
assert!((m - 0.25).abs() < 0.05);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mean_phase_empty_indices() {
|
||||
assert_eq!(mean_phase_on_indices(&[1.0, 2.0], &[]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last_offsets_accessible() {
|
||||
let aligner = PhaseAligner::new(3);
|
||||
assert_eq!(aligner.last_offsets().len(), 3);
|
||||
assert!(aligner.last_offsets().iter().all(|&x| x == 0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_config() {
|
||||
let config = PhaseAlignConfig {
|
||||
max_iterations: 50,
|
||||
tolerance: 1e-6,
|
||||
static_fraction: 0.5,
|
||||
min_static_subcarriers: 3,
|
||||
};
|
||||
let aligner = PhaseAligner::with_config(2, config);
|
||||
assert_eq!(aligner.last_offsets().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn three_channel_alignment() {
|
||||
let mut aligner = PhaseAligner::new(3);
|
||||
let f0 = make_frame_with_phase(56, 0.0, 0.0);
|
||||
let f1 = make_frame_with_phase(56, 0.0, 0.3);
|
||||
let f2 = make_frame_with_phase(56, 0.0, -0.2);
|
||||
|
||||
let result = aligner.align(&[f0, f1, f2]).unwrap();
|
||||
assert_eq!(result.len(), 3);
|
||||
|
||||
// Reference channel offset should be 0
|
||||
assert!(aligner.last_offsets()[0].abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let cfg = PhaseAlignConfig::default();
|
||||
assert_eq!(cfg.max_iterations, 20);
|
||||
assert!((cfg.tolerance - 1e-4).abs() < 1e-8);
|
||||
assert!((cfg.static_fraction - 0.3).abs() < 1e-6);
|
||||
assert_eq!(cfg.min_static_subcarriers, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn phase_correction_preserves_amplitude() {
|
||||
let mut aligner = PhaseAligner::new(2);
|
||||
let f0 = make_frame_with_phase(56, 0.0, 0.0);
|
||||
let f1 = make_frame_with_phase(56, 0.0, 1.0);
|
||||
|
||||
let result = aligner.align(&[f0.clone(), f1.clone()]).unwrap();
|
||||
// Amplitude should be unchanged
|
||||
assert_eq!(result[0].amplitude, f0.amplitude);
|
||||
assert_eq!(result[1].amplitude, f1.amplitude);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,943 @@
|
||||
//! 17-Keypoint Kalman Pose Tracker with Re-ID (ADR-029 Section 2.7)
|
||||
//!
|
||||
//! Tracks multiple people as persistent 17-keypoint skeletons across time.
|
||||
//! Each keypoint has a 6D Kalman state (x, y, z, vx, vy, vz) with a
|
||||
//! constant-velocity motion model. Track lifecycle follows:
|
||||
//!
|
||||
//! Tentative -> Active -> Lost -> Terminated
|
||||
//!
|
||||
//! Detection-to-track assignment uses a joint cost combining Mahalanobis
|
||||
//! distance (60%) and AETHER re-ID embedding cosine similarity (40%),
|
||||
//! implemented via `ruvector-mincut::DynamicPersonMatcher`.
|
||||
//!
|
||||
//! # Parameters
|
||||
//!
|
||||
//! | Parameter | Value | Rationale |
|
||||
//! |-----------|-------|-----------|
|
||||
//! | State dimension | 6 per keypoint | Constant-velocity model |
|
||||
//! | Process noise | 0.3 m/s^2 | Normal walking acceleration |
|
||||
//! | Measurement noise | 0.08 m | Target <8cm RMS at torso |
|
||||
//! | Birth hits | 2 frames | Reject single-frame noise |
|
||||
//! | Loss misses | 5 frames | Brief occlusion tolerance |
|
||||
//! | Re-ID embedding | 128-dim | AETHER body-shape discriminative |
|
||||
//! | Re-ID window | 5 seconds | Crossing recovery |
|
||||
//!
|
||||
//! # RuVector Integration
|
||||
//!
|
||||
//! - `ruvector-mincut` -> Person separation and track assignment
|
||||
|
||||
use super::{TrackId, NUM_KEYPOINTS};
|
||||
|
||||
/// Errors from the pose tracker.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PoseTrackerError {
|
||||
/// Invalid keypoint index.
|
||||
#[error("Invalid keypoint index {index}, max is {}", NUM_KEYPOINTS - 1)]
|
||||
InvalidKeypointIndex { index: usize },
|
||||
|
||||
/// Invalid embedding dimension.
|
||||
#[error("Embedding dimension {got} does not match expected {expected}")]
|
||||
EmbeddingDimMismatch { expected: usize, got: usize },
|
||||
|
||||
/// Mahalanobis gate exceeded.
|
||||
#[error("Mahalanobis distance {distance:.2} exceeds gate {gate:.2}")]
|
||||
MahalanobisGateExceeded { distance: f32, gate: f32 },
|
||||
|
||||
/// Track not found.
|
||||
#[error("Track {0} not found")]
|
||||
TrackNotFound(TrackId),
|
||||
|
||||
/// No detections provided.
|
||||
#[error("No detections provided for update")]
|
||||
NoDetections,
|
||||
}
|
||||
|
||||
/// Per-keypoint Kalman state.
|
||||
///
|
||||
/// Maintains a 6D state vector [x, y, z, vx, vy, vz] and a 6x6 covariance
|
||||
/// matrix stored as the upper triangle (21 elements, row-major).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeypointState {
|
||||
/// State vector [x, y, z, vx, vy, vz].
|
||||
pub state: [f32; 6],
|
||||
/// 6x6 covariance upper triangle (21 elements, row-major).
|
||||
/// Indices: (0,0)=0, (0,1)=1, (0,2)=2, (0,3)=3, (0,4)=4, (0,5)=5,
|
||||
/// (1,1)=6, (1,2)=7, (1,3)=8, (1,4)=9, (1,5)=10,
|
||||
/// (2,2)=11, (2,3)=12, (2,4)=13, (2,5)=14,
|
||||
/// (3,3)=15, (3,4)=16, (3,5)=17,
|
||||
/// (4,4)=18, (4,5)=19,
|
||||
/// (5,5)=20
|
||||
pub covariance: [f32; 21],
|
||||
/// Confidence (0.0-1.0) from DensePose model output.
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
impl KeypointState {
|
||||
/// Create a new keypoint state at the given 3D position.
|
||||
pub fn new(x: f32, y: f32, z: f32) -> Self {
|
||||
let mut cov = [0.0_f32; 21];
|
||||
// Initialize diagonal with default uncertainty
|
||||
let pos_var = 0.1 * 0.1; // 10 cm initial uncertainty
|
||||
let vel_var = 0.5 * 0.5; // 0.5 m/s initial velocity uncertainty
|
||||
cov[0] = pos_var; // x variance
|
||||
cov[6] = pos_var; // y variance
|
||||
cov[11] = pos_var; // z variance
|
||||
cov[15] = vel_var; // vx variance
|
||||
cov[18] = vel_var; // vy variance
|
||||
cov[20] = vel_var; // vz variance
|
||||
|
||||
Self {
|
||||
state: [x, y, z, 0.0, 0.0, 0.0],
|
||||
covariance: cov,
|
||||
confidence: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the position [x, y, z].
|
||||
pub fn position(&self) -> [f32; 3] {
|
||||
[self.state[0], self.state[1], self.state[2]]
|
||||
}
|
||||
|
||||
/// Return the velocity [vx, vy, vz].
|
||||
pub fn velocity(&self) -> [f32; 3] {
|
||||
[self.state[3], self.state[4], self.state[5]]
|
||||
}
|
||||
|
||||
/// Predict step: advance state by dt seconds using constant-velocity model.
|
||||
///
|
||||
/// x' = x + vx * dt
|
||||
/// P' = F * P * F^T + Q
|
||||
pub fn predict(&mut self, dt: f32, process_noise_accel: f32) {
|
||||
// State prediction: x' = x + v * dt
|
||||
self.state[0] += self.state[3] * dt;
|
||||
self.state[1] += self.state[4] * dt;
|
||||
self.state[2] += self.state[5] * dt;
|
||||
|
||||
// Process noise Q (constant acceleration model)
|
||||
let dt2 = dt * dt;
|
||||
let dt3 = dt2 * dt;
|
||||
let dt4 = dt3 * dt;
|
||||
let q = process_noise_accel * process_noise_accel;
|
||||
|
||||
// Add process noise to diagonal elements
|
||||
// Position variances: + q * dt^4 / 4
|
||||
let pos_q = q * dt4 / 4.0;
|
||||
// Velocity variances: + q * dt^2
|
||||
let vel_q = q * dt2;
|
||||
// Position-velocity cross: + q * dt^3 / 2
|
||||
let _cross_q = q * dt3 / 2.0;
|
||||
|
||||
// Simplified: only update diagonal for numerical stability
|
||||
self.covariance[0] += pos_q; // xx
|
||||
self.covariance[6] += pos_q; // yy
|
||||
self.covariance[11] += pos_q; // zz
|
||||
self.covariance[15] += vel_q; // vxvx
|
||||
self.covariance[18] += vel_q; // vyvy
|
||||
self.covariance[20] += vel_q; // vzvz
|
||||
}
|
||||
|
||||
/// Measurement update: incorporate a position observation [x, y, z].
|
||||
///
|
||||
/// Uses the standard Kalman update with position-only measurement model
|
||||
/// H = [I3 | 0_3x3].
|
||||
pub fn update(
|
||||
&mut self,
|
||||
measurement: &[f32; 3],
|
||||
measurement_noise: f32,
|
||||
noise_multiplier: f32,
|
||||
) {
|
||||
let r = measurement_noise * measurement_noise * noise_multiplier;
|
||||
|
||||
// Innovation (residual)
|
||||
let innov = [
|
||||
measurement[0] - self.state[0],
|
||||
measurement[1] - self.state[1],
|
||||
measurement[2] - self.state[2],
|
||||
];
|
||||
|
||||
// Innovation covariance S = H * P * H^T + R
|
||||
// Since H = [I3 | 0], S is just the top-left 3x3 of P + R
|
||||
let s = [
|
||||
self.covariance[0] + r,
|
||||
self.covariance[6] + r,
|
||||
self.covariance[11] + r,
|
||||
];
|
||||
|
||||
// Kalman gain K = P * H^T * S^-1
|
||||
// For diagonal S, K_ij = P_ij / S_jj (simplified)
|
||||
let k = [
|
||||
[self.covariance[0] / s[0], 0.0, 0.0], // x row
|
||||
[0.0, self.covariance[6] / s[1], 0.0], // y row
|
||||
[0.0, 0.0, self.covariance[11] / s[2]], // z row
|
||||
[self.covariance[3] / s[0], 0.0, 0.0], // vx row
|
||||
[0.0, self.covariance[9] / s[1], 0.0], // vy row
|
||||
[0.0, 0.0, self.covariance[14] / s[2]], // vz row
|
||||
];
|
||||
|
||||
// State update: x' = x + K * innov
|
||||
for i in 0..6 {
|
||||
for j in 0..3 {
|
||||
self.state[i] += k[i][j] * innov[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Covariance update: P' = (I - K*H) * P (simplified diagonal update)
|
||||
self.covariance[0] *= 1.0 - k[0][0];
|
||||
self.covariance[6] *= 1.0 - k[1][1];
|
||||
self.covariance[11] *= 1.0 - k[2][2];
|
||||
}
|
||||
|
||||
/// Compute the Mahalanobis distance between this state and a measurement.
|
||||
pub fn mahalanobis_distance(&self, measurement: &[f32; 3]) -> f32 {
|
||||
let innov = [
|
||||
measurement[0] - self.state[0],
|
||||
measurement[1] - self.state[1],
|
||||
measurement[2] - self.state[2],
|
||||
];
|
||||
|
||||
// Using diagonal approximation
|
||||
let mut dist_sq = 0.0_f32;
|
||||
let variances = [self.covariance[0], self.covariance[6], self.covariance[11]];
|
||||
for i in 0..3 {
|
||||
let v = variances[i].max(1e-6);
|
||||
dist_sq += innov[i] * innov[i] / v;
|
||||
}
|
||||
|
||||
dist_sq.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KeypointState {
|
||||
fn default() -> Self {
|
||||
Self::new(0.0, 0.0, 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Track lifecycle state machine.
|
||||
///
|
||||
/// Follows the pattern from ADR-026:
|
||||
/// Tentative -> Active -> Lost -> Terminated
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TrackLifecycleState {
|
||||
/// Track has been detected but not yet confirmed (< birth_hits frames).
|
||||
Tentative,
|
||||
/// Track is confirmed and actively being updated.
|
||||
Active,
|
||||
/// Track has lost measurement association (< loss_misses frames).
|
||||
Lost,
|
||||
/// Track has been terminated (exceeded max lost duration or deemed false positive).
|
||||
Terminated,
|
||||
}
|
||||
|
||||
impl TrackLifecycleState {
|
||||
/// Returns true if the track is in an active or tentative state.
|
||||
pub fn is_alive(&self) -> bool {
|
||||
matches!(self, Self::Tentative | Self::Active | Self::Lost)
|
||||
}
|
||||
|
||||
/// Returns true if the track can receive measurement updates.
|
||||
pub fn accepts_updates(&self) -> bool {
|
||||
matches!(self, Self::Tentative | Self::Active)
|
||||
}
|
||||
|
||||
/// Returns true if the track is eligible for re-identification.
|
||||
pub fn is_lost(&self) -> bool {
|
||||
matches!(self, Self::Lost)
|
||||
}
|
||||
}
|
||||
|
||||
/// A pose track -- aggregate root for tracking one person.
|
||||
///
|
||||
/// Contains 17 keypoint Kalman states, lifecycle, and re-ID embedding.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PoseTrack {
|
||||
/// Unique track identifier.
|
||||
pub id: TrackId,
|
||||
/// Per-keypoint Kalman state (COCO-17 ordering).
|
||||
pub keypoints: [KeypointState; NUM_KEYPOINTS],
|
||||
/// Track lifecycle state.
|
||||
pub lifecycle: TrackLifecycleState,
|
||||
/// Running-average AETHER embedding for re-ID (128-dim).
|
||||
pub embedding: Vec<f32>,
|
||||
/// Total frames since creation.
|
||||
pub age: u64,
|
||||
/// Frames since last successful measurement update.
|
||||
pub time_since_update: u64,
|
||||
/// Number of consecutive measurement updates (for birth gate).
|
||||
pub consecutive_hits: u64,
|
||||
/// Creation timestamp in microseconds.
|
||||
pub created_at: u64,
|
||||
/// Last update timestamp in microseconds.
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
impl PoseTrack {
|
||||
/// Create a new tentative track from a detection.
|
||||
pub fn new(
|
||||
id: TrackId,
|
||||
keypoint_positions: &[[f32; 3]; NUM_KEYPOINTS],
|
||||
timestamp_us: u64,
|
||||
embedding_dim: usize,
|
||||
) -> Self {
|
||||
let keypoints = std::array::from_fn(|i| {
|
||||
let [x, y, z] = keypoint_positions[i];
|
||||
KeypointState::new(x, y, z)
|
||||
});
|
||||
|
||||
Self {
|
||||
id,
|
||||
keypoints,
|
||||
lifecycle: TrackLifecycleState::Tentative,
|
||||
embedding: vec![0.0; embedding_dim],
|
||||
age: 0,
|
||||
time_since_update: 0,
|
||||
consecutive_hits: 1,
|
||||
created_at: timestamp_us,
|
||||
updated_at: timestamp_us,
|
||||
}
|
||||
}
|
||||
|
||||
/// Predict all keypoints forward by dt seconds.
|
||||
pub fn predict(&mut self, dt: f32, process_noise: f32) {
|
||||
for kp in &mut self.keypoints {
|
||||
kp.predict(dt, process_noise);
|
||||
}
|
||||
self.age += 1;
|
||||
self.time_since_update += 1;
|
||||
}
|
||||
|
||||
/// Update all keypoints with new measurements.
|
||||
///
|
||||
/// Also updates lifecycle state transitions based on birth/loss gates.
|
||||
pub fn update_keypoints(
|
||||
&mut self,
|
||||
measurements: &[[f32; 3]; NUM_KEYPOINTS],
|
||||
measurement_noise: f32,
|
||||
noise_multiplier: f32,
|
||||
timestamp_us: u64,
|
||||
) {
|
||||
for (kp, meas) in self.keypoints.iter_mut().zip(measurements.iter()) {
|
||||
kp.update(meas, measurement_noise, noise_multiplier);
|
||||
}
|
||||
|
||||
self.time_since_update = 0;
|
||||
self.consecutive_hits += 1;
|
||||
self.updated_at = timestamp_us;
|
||||
|
||||
// Lifecycle transitions
|
||||
self.update_lifecycle();
|
||||
}
|
||||
|
||||
/// Update the embedding with EMA decay.
|
||||
pub fn update_embedding(&mut self, new_embedding: &[f32], decay: f32) {
|
||||
if new_embedding.len() != self.embedding.len() {
|
||||
return;
|
||||
}
|
||||
|
||||
let alpha = 1.0 - decay;
|
||||
for (e, &ne) in self.embedding.iter_mut().zip(new_embedding.iter()) {
|
||||
*e = decay * *e + alpha * ne;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the centroid position (mean of all keypoints).
|
||||
pub fn centroid(&self) -> [f32; 3] {
|
||||
let n = NUM_KEYPOINTS as f32;
|
||||
let mut c = [0.0_f32; 3];
|
||||
for kp in &self.keypoints {
|
||||
let pos = kp.position();
|
||||
c[0] += pos[0];
|
||||
c[1] += pos[1];
|
||||
c[2] += pos[2];
|
||||
}
|
||||
c[0] /= n;
|
||||
c[1] /= n;
|
||||
c[2] /= n;
|
||||
c
|
||||
}
|
||||
|
||||
/// Compute torso jitter RMS in meters.
|
||||
///
|
||||
/// Uses the torso keypoints (shoulders, hips) velocity magnitudes
|
||||
/// as a proxy for jitter.
|
||||
pub fn torso_jitter_rms(&self) -> f32 {
|
||||
let torso_indices = super::keypoint::TORSO_INDICES;
|
||||
let mut sum_sq = 0.0_f32;
|
||||
let mut count = 0;
|
||||
|
||||
for &idx in torso_indices {
|
||||
let vel = self.keypoints[idx].velocity();
|
||||
let speed_sq = vel[0] * vel[0] + vel[1] * vel[1] + vel[2] * vel[2];
|
||||
sum_sq += speed_sq;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
(sum_sq / count as f32).sqrt()
|
||||
}
|
||||
|
||||
/// Mark the track as lost.
|
||||
pub fn mark_lost(&mut self) {
|
||||
if self.lifecycle != TrackLifecycleState::Terminated {
|
||||
self.lifecycle = TrackLifecycleState::Lost;
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark the track as terminated.
|
||||
pub fn terminate(&mut self) {
|
||||
self.lifecycle = TrackLifecycleState::Terminated;
|
||||
}
|
||||
|
||||
/// Update lifecycle state based on consecutive hits and misses.
|
||||
fn update_lifecycle(&mut self) {
|
||||
match self.lifecycle {
|
||||
TrackLifecycleState::Tentative => {
|
||||
if self.consecutive_hits >= 2 {
|
||||
// Birth gate: promote to Active after 2 consecutive updates
|
||||
self.lifecycle = TrackLifecycleState::Active;
|
||||
}
|
||||
}
|
||||
TrackLifecycleState::Lost => {
|
||||
// Re-acquired: promote back to Active
|
||||
self.lifecycle = TrackLifecycleState::Active;
|
||||
self.consecutive_hits = 1;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tracker configuration parameters.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrackerConfig {
|
||||
/// Process noise acceleration (m/s^2). Default: 0.3.
|
||||
pub process_noise: f32,
|
||||
/// Measurement noise std dev (m). Default: 0.08.
|
||||
pub measurement_noise: f32,
|
||||
/// Mahalanobis gate threshold (chi-squared(3) at 3-sigma = 9.0).
|
||||
pub mahalanobis_gate: f32,
|
||||
/// Frames required for tentative->active promotion. Default: 2.
|
||||
pub birth_hits: u64,
|
||||
/// Max frames without update before tentative->lost. Default: 5.
|
||||
pub loss_misses: u64,
|
||||
/// Re-ID window in frames (5 seconds at 20Hz = 100). Default: 100.
|
||||
pub reid_window: u64,
|
||||
/// Embedding EMA decay rate. Default: 0.95.
|
||||
pub embedding_decay: f32,
|
||||
/// Embedding dimension. Default: 128.
|
||||
pub embedding_dim: usize,
|
||||
/// Position weight in assignment cost. Default: 0.6.
|
||||
pub position_weight: f32,
|
||||
/// Embedding weight in assignment cost. Default: 0.4.
|
||||
pub embedding_weight: f32,
|
||||
}
|
||||
|
||||
impl Default for TrackerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
process_noise: 0.3,
|
||||
measurement_noise: 0.08,
|
||||
mahalanobis_gate: 9.0,
|
||||
birth_hits: 2,
|
||||
loss_misses: 5,
|
||||
reid_window: 100,
|
||||
embedding_decay: 0.95,
|
||||
embedding_dim: 128,
|
||||
position_weight: 0.6,
|
||||
embedding_weight: 0.4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-person pose tracker.
|
||||
///
|
||||
/// Manages a collection of `PoseTrack` instances with automatic lifecycle
|
||||
/// management, detection-to-track assignment, and re-identification.
|
||||
#[derive(Debug)]
|
||||
pub struct PoseTracker {
|
||||
config: TrackerConfig,
|
||||
tracks: Vec<PoseTrack>,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
impl PoseTracker {
|
||||
/// Create a new tracker with default configuration.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: TrackerConfig::default(),
|
||||
tracks: Vec::new(),
|
||||
next_id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new tracker with custom configuration.
|
||||
pub fn with_config(config: TrackerConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
tracks: Vec::new(),
|
||||
next_id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return all active tracks (not terminated).
|
||||
pub fn active_tracks(&self) -> Vec<&PoseTrack> {
|
||||
self.tracks
|
||||
.iter()
|
||||
.filter(|t| t.lifecycle.is_alive())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Return all tracks including terminated ones.
|
||||
pub fn all_tracks(&self) -> &[PoseTrack] {
|
||||
&self.tracks
|
||||
}
|
||||
|
||||
/// Return the number of active (alive) tracks.
|
||||
pub fn active_count(&self) -> usize {
|
||||
self.tracks.iter().filter(|t| t.lifecycle.is_alive()).count()
|
||||
}
|
||||
|
||||
/// Predict step for all tracks (advance by dt seconds).
|
||||
pub fn predict_all(&mut self, dt: f32) {
|
||||
for track in &mut self.tracks {
|
||||
if track.lifecycle.is_alive() {
|
||||
track.predict(dt, self.config.process_noise);
|
||||
}
|
||||
}
|
||||
|
||||
// Mark tracks as lost after exceeding loss_misses
|
||||
for track in &mut self.tracks {
|
||||
if track.lifecycle.accepts_updates()
|
||||
&& track.time_since_update >= self.config.loss_misses
|
||||
{
|
||||
track.mark_lost();
|
||||
}
|
||||
}
|
||||
|
||||
// Terminate tracks that have been lost too long
|
||||
let reid_window = self.config.reid_window;
|
||||
for track in &mut self.tracks {
|
||||
if track.lifecycle.is_lost() && track.time_since_update >= reid_window {
|
||||
track.terminate();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new track from a detection.
|
||||
pub fn create_track(
|
||||
&mut self,
|
||||
keypoints: &[[f32; 3]; NUM_KEYPOINTS],
|
||||
timestamp_us: u64,
|
||||
) -> TrackId {
|
||||
let id = TrackId::new(self.next_id);
|
||||
self.next_id += 1;
|
||||
|
||||
let track = PoseTrack::new(id, keypoints, timestamp_us, self.config.embedding_dim);
|
||||
self.tracks.push(track);
|
||||
id
|
||||
}
|
||||
|
||||
/// Find the track with the given ID.
|
||||
pub fn find_track(&self, id: TrackId) -> Option<&PoseTrack> {
|
||||
self.tracks.iter().find(|t| t.id == id)
|
||||
}
|
||||
|
||||
/// Find the track with the given ID (mutable).
|
||||
pub fn find_track_mut(&mut self, id: TrackId) -> Option<&mut PoseTrack> {
|
||||
self.tracks.iter_mut().find(|t| t.id == id)
|
||||
}
|
||||
|
||||
/// Remove terminated tracks from the collection.
|
||||
pub fn prune_terminated(&mut self) {
|
||||
self.tracks
|
||||
.retain(|t| t.lifecycle != TrackLifecycleState::Terminated);
|
||||
}
|
||||
|
||||
/// Compute the assignment cost between a track and a detection.
|
||||
///
|
||||
/// cost = position_weight * mahalanobis(track, detection.position)
|
||||
/// + embedding_weight * (1 - cosine_sim(track.embedding, detection.embedding))
|
||||
pub fn assignment_cost(
|
||||
&self,
|
||||
track: &PoseTrack,
|
||||
detection_centroid: &[f32; 3],
|
||||
detection_embedding: &[f32],
|
||||
) -> f32 {
|
||||
// Position cost: Mahalanobis distance at centroid
|
||||
let centroid_kp = track.centroid();
|
||||
let centroid_state = KeypointState::new(centroid_kp[0], centroid_kp[1], centroid_kp[2]);
|
||||
let maha = centroid_state.mahalanobis_distance(detection_centroid);
|
||||
|
||||
// Embedding cost: 1 - cosine similarity
|
||||
let embed_cost = 1.0 - cosine_similarity(&track.embedding, detection_embedding);
|
||||
|
||||
self.config.position_weight * maha + self.config.embedding_weight * embed_cost
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PoseTracker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine similarity between two vectors.
|
||||
///
|
||||
/// Returns a value in [-1.0, 1.0] where 1.0 means identical direction.
|
||||
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
let n = a.len().min(b.len());
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut dot = 0.0_f32;
|
||||
let mut norm_a = 0.0_f32;
|
||||
let mut norm_b = 0.0_f32;
|
||||
|
||||
for i in 0..n {
|
||||
dot += a[i] * b[i];
|
||||
norm_a += a[i] * a[i];
|
||||
norm_b += b[i] * b[i];
|
||||
}
|
||||
|
||||
let denom = (norm_a * norm_b).sqrt();
|
||||
if denom < 1e-12 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
(dot / denom).clamp(-1.0, 1.0)
|
||||
}
|
||||
|
||||
/// A detected pose from the model, before assignment to a track.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PoseDetection {
|
||||
/// Per-keypoint positions [x, y, z, confidence] for 17 keypoints.
|
||||
pub keypoints: [[f32; 4]; NUM_KEYPOINTS],
|
||||
/// AETHER re-ID embedding (128-dim).
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
impl PoseDetection {
|
||||
/// Extract the 3D position array from keypoints.
|
||||
pub fn positions(&self) -> [[f32; 3]; NUM_KEYPOINTS] {
|
||||
std::array::from_fn(|i| [self.keypoints[i][0], self.keypoints[i][1], self.keypoints[i][2]])
|
||||
}
|
||||
|
||||
/// Compute the centroid of the detection.
|
||||
pub fn centroid(&self) -> [f32; 3] {
|
||||
let n = NUM_KEYPOINTS as f32;
|
||||
let mut c = [0.0_f32; 3];
|
||||
for kp in &self.keypoints {
|
||||
c[0] += kp[0];
|
||||
c[1] += kp[1];
|
||||
c[2] += kp[2];
|
||||
}
|
||||
c[0] /= n;
|
||||
c[1] /= n;
|
||||
c[2] /= n;
|
||||
c
|
||||
}
|
||||
|
||||
/// Mean confidence across all keypoints.
|
||||
pub fn mean_confidence(&self) -> f32 {
|
||||
let sum: f32 = self.keypoints.iter().map(|kp| kp[3]).sum();
|
||||
sum / NUM_KEYPOINTS as f32
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn zero_positions() -> [[f32; 3]; NUM_KEYPOINTS] {
|
||||
[[0.0, 0.0, 0.0]; NUM_KEYPOINTS]
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn offset_positions(offset: f32) -> [[f32; 3]; NUM_KEYPOINTS] {
|
||||
std::array::from_fn(|i| [offset + i as f32 * 0.1, offset, 0.0])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keypoint_state_creation() {
|
||||
let kp = KeypointState::new(1.0, 2.0, 3.0);
|
||||
assert_eq!(kp.position(), [1.0, 2.0, 3.0]);
|
||||
assert_eq!(kp.velocity(), [0.0, 0.0, 0.0]);
|
||||
assert_eq!(kp.confidence, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keypoint_predict_moves_position() {
|
||||
let mut kp = KeypointState::new(0.0, 0.0, 0.0);
|
||||
kp.state[3] = 1.0; // vx = 1 m/s
|
||||
kp.predict(0.05, 0.3); // 50ms step
|
||||
assert!((kp.state[0] - 0.05).abs() < 1e-5, "x should be ~0.05, got {}", kp.state[0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keypoint_predict_increases_uncertainty() {
|
||||
let mut kp = KeypointState::new(0.0, 0.0, 0.0);
|
||||
let initial_var = kp.covariance[0];
|
||||
kp.predict(0.05, 0.3);
|
||||
assert!(kp.covariance[0] > initial_var);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keypoint_update_reduces_uncertainty() {
|
||||
let mut kp = KeypointState::new(0.0, 0.0, 0.0);
|
||||
kp.predict(0.05, 0.3);
|
||||
let post_predict_var = kp.covariance[0];
|
||||
kp.update(&[0.01, 0.0, 0.0], 0.08, 1.0);
|
||||
assert!(kp.covariance[0] < post_predict_var);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mahalanobis_zero_distance() {
|
||||
let kp = KeypointState::new(1.0, 2.0, 3.0);
|
||||
let d = kp.mahalanobis_distance(&[1.0, 2.0, 3.0]);
|
||||
assert!(d < 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mahalanobis_positive_for_offset() {
|
||||
let kp = KeypointState::new(0.0, 0.0, 0.0);
|
||||
let d = kp.mahalanobis_distance(&[1.0, 0.0, 0.0]);
|
||||
assert!(d > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lifecycle_transitions() {
|
||||
assert!(TrackLifecycleState::Tentative.is_alive());
|
||||
assert!(TrackLifecycleState::Active.is_alive());
|
||||
assert!(TrackLifecycleState::Lost.is_alive());
|
||||
assert!(!TrackLifecycleState::Terminated.is_alive());
|
||||
|
||||
assert!(TrackLifecycleState::Tentative.accepts_updates());
|
||||
assert!(TrackLifecycleState::Active.accepts_updates());
|
||||
assert!(!TrackLifecycleState::Lost.accepts_updates());
|
||||
assert!(!TrackLifecycleState::Terminated.accepts_updates());
|
||||
|
||||
assert!(!TrackLifecycleState::Tentative.is_lost());
|
||||
assert!(TrackLifecycleState::Lost.is_lost());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_creation() {
|
||||
let positions = zero_positions();
|
||||
let track = PoseTrack::new(TrackId(0), &positions, 1000, 128);
|
||||
assert_eq!(track.id, TrackId(0));
|
||||
assert_eq!(track.lifecycle, TrackLifecycleState::Tentative);
|
||||
assert_eq!(track.embedding.len(), 128);
|
||||
assert_eq!(track.age, 0);
|
||||
assert_eq!(track.consecutive_hits, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_birth_gate() {
|
||||
let positions = zero_positions();
|
||||
let mut track = PoseTrack::new(TrackId(0), &positions, 0, 128);
|
||||
assert_eq!(track.lifecycle, TrackLifecycleState::Tentative);
|
||||
|
||||
// First update: still tentative (need 2 hits)
|
||||
track.update_keypoints(&positions, 0.08, 1.0, 100);
|
||||
assert_eq!(track.lifecycle, TrackLifecycleState::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_loss_gate() {
|
||||
let positions = zero_positions();
|
||||
let mut track = PoseTrack::new(TrackId(0), &positions, 0, 128);
|
||||
track.lifecycle = TrackLifecycleState::Active;
|
||||
|
||||
// Predict without updates exceeding loss_misses
|
||||
for _ in 0..6 {
|
||||
track.predict(0.05, 0.3);
|
||||
}
|
||||
// Manually mark lost (normally done by tracker)
|
||||
if track.time_since_update >= 5 {
|
||||
track.mark_lost();
|
||||
}
|
||||
assert_eq!(track.lifecycle, TrackLifecycleState::Lost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_centroid() {
|
||||
let positions: [[f32; 3]; NUM_KEYPOINTS] =
|
||||
std::array::from_fn(|_| [1.0, 2.0, 3.0]);
|
||||
let track = PoseTrack::new(TrackId(0), &positions, 0, 128);
|
||||
let c = track.centroid();
|
||||
assert!((c[0] - 1.0).abs() < 1e-5);
|
||||
assert!((c[1] - 2.0).abs() < 1e-5);
|
||||
assert!((c[2] - 3.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_embedding_update() {
|
||||
let positions = zero_positions();
|
||||
let mut track = PoseTrack::new(TrackId(0), &positions, 0, 4);
|
||||
let new_embed = vec![1.0, 2.0, 3.0, 4.0];
|
||||
track.update_embedding(&new_embed, 0.5);
|
||||
// EMA: 0.5 * 0.0 + 0.5 * new = new / 2
|
||||
for i in 0..4 {
|
||||
assert!((track.embedding[i] - new_embed[i] * 0.5).abs() < 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_create_and_find() {
|
||||
let mut tracker = PoseTracker::new();
|
||||
let positions = zero_positions();
|
||||
let id = tracker.create_track(&positions, 1000);
|
||||
assert!(tracker.find_track(id).is_some());
|
||||
assert_eq!(tracker.active_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_predict_marks_lost() {
|
||||
let mut tracker = PoseTracker::with_config(TrackerConfig {
|
||||
loss_misses: 3,
|
||||
reid_window: 10,
|
||||
..Default::default()
|
||||
});
|
||||
let positions = zero_positions();
|
||||
let id = tracker.create_track(&positions, 0);
|
||||
|
||||
// Promote to active
|
||||
if let Some(t) = tracker.find_track_mut(id) {
|
||||
t.lifecycle = TrackLifecycleState::Active;
|
||||
}
|
||||
|
||||
// Predict 4 times without update
|
||||
for _ in 0..4 {
|
||||
tracker.predict_all(0.05);
|
||||
}
|
||||
|
||||
let track = tracker.find_track(id).unwrap();
|
||||
assert_eq!(track.lifecycle, TrackLifecycleState::Lost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_prune_terminated() {
|
||||
let mut tracker = PoseTracker::new();
|
||||
let positions = zero_positions();
|
||||
let id = tracker.create_track(&positions, 0);
|
||||
if let Some(t) = tracker.find_track_mut(id) {
|
||||
t.terminate();
|
||||
}
|
||||
assert_eq!(tracker.all_tracks().len(), 1);
|
||||
tracker.prune_terminated();
|
||||
assert_eq!(tracker.all_tracks().len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_similarity_identical() {
|
||||
let a = vec![1.0, 2.0, 3.0];
|
||||
let b = vec![1.0, 2.0, 3.0];
|
||||
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_similarity_orthogonal() {
|
||||
let a = vec![1.0, 0.0, 0.0];
|
||||
let b = vec![0.0, 1.0, 0.0];
|
||||
assert!(cosine_similarity(&a, &b).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_similarity_opposite() {
|
||||
let a = vec![1.0, 2.0, 3.0];
|
||||
let b = vec![-1.0, -2.0, -3.0];
|
||||
assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_similarity_empty() {
|
||||
assert_eq!(cosine_similarity(&[], &[]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pose_detection_centroid() {
|
||||
let kps: [[f32; 4]; NUM_KEYPOINTS] =
|
||||
std::array::from_fn(|_| [1.0, 2.0, 3.0, 0.9]);
|
||||
let det = PoseDetection {
|
||||
keypoints: kps,
|
||||
embedding: vec![0.0; 128],
|
||||
};
|
||||
let c = det.centroid();
|
||||
assert!((c[0] - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pose_detection_mean_confidence() {
|
||||
let kps: [[f32; 4]; NUM_KEYPOINTS] =
|
||||
std::array::from_fn(|_| [0.0, 0.0, 0.0, 0.8]);
|
||||
let det = PoseDetection {
|
||||
keypoints: kps,
|
||||
embedding: vec![0.0; 128],
|
||||
};
|
||||
assert!((det.mean_confidence() - 0.8).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pose_detection_positions() {
|
||||
let kps: [[f32; 4]; NUM_KEYPOINTS] =
|
||||
std::array::from_fn(|i| [i as f32, 0.0, 0.0, 1.0]);
|
||||
let det = PoseDetection {
|
||||
keypoints: kps,
|
||||
embedding: vec![],
|
||||
};
|
||||
let pos = det.positions();
|
||||
assert_eq!(pos[0], [0.0, 0.0, 0.0]);
|
||||
assert_eq!(pos[5], [5.0, 0.0, 0.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assignment_cost_computation() {
|
||||
let mut tracker = PoseTracker::new();
|
||||
let positions = zero_positions();
|
||||
let id = tracker.create_track(&positions, 0);
|
||||
|
||||
let track = tracker.find_track(id).unwrap();
|
||||
let cost = tracker.assignment_cost(track, &[0.0, 0.0, 0.0], &vec![0.0; 128]);
|
||||
// Zero distance + zero embedding cost should be near 0
|
||||
// But embedding cost = 1 - cosine_sim(zeros, zeros) = 1 - 0 = 1
|
||||
// So cost = 0.6 * 0 + 0.4 * 1 = 0.4
|
||||
assert!((cost - 0.4).abs() < 0.1, "Expected ~0.4, got {}", cost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn torso_jitter_rms_stationary() {
|
||||
let positions = zero_positions();
|
||||
let track = PoseTrack::new(TrackId(0), &positions, 0, 128);
|
||||
let jitter = track.torso_jitter_rms();
|
||||
assert!(jitter < 1e-5, "Stationary track should have near-zero jitter");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_tracker_config() {
|
||||
let cfg = TrackerConfig::default();
|
||||
assert!((cfg.process_noise - 0.3).abs() < f32::EPSILON);
|
||||
assert!((cfg.measurement_noise - 0.08).abs() < f32::EPSILON);
|
||||
assert!((cfg.mahalanobis_gate - 9.0).abs() < f32::EPSILON);
|
||||
assert_eq!(cfg.birth_hits, 2);
|
||||
assert_eq!(cfg.loss_misses, 5);
|
||||
assert_eq!(cfg.reid_window, 100);
|
||||
assert!((cfg.embedding_decay - 0.95).abs() < f32::EPSILON);
|
||||
assert_eq!(cfg.embedding_dim, 128);
|
||||
assert!((cfg.position_weight - 0.6).abs() < f32::EPSILON);
|
||||
assert!((cfg.embedding_weight - 0.4).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_terminate_prevents_lost() {
|
||||
let positions = zero_positions();
|
||||
let mut track = PoseTrack::new(TrackId(0), &positions, 0, 128);
|
||||
track.terminate();
|
||||
assert_eq!(track.lifecycle, TrackLifecycleState::Terminated);
|
||||
track.mark_lost(); // Should not override Terminated
|
||||
assert_eq!(track.lifecycle, TrackLifecycleState::Terminated);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,517 @@
|
||||
//! Enhanced gesture classification using `midstreamer-temporal-compare`.
|
||||
//!
|
||||
//! Extends the DTW-based gesture classifier from `gesture.rs` with
|
||||
//! optimized temporal comparison algorithms provided by the
|
||||
//! `midstreamer-temporal-compare` crate (ADR-032a Section 6.4).
|
||||
//!
|
||||
//! # Improvements over base gesture classifier
|
||||
//!
|
||||
//! - **Cached DTW**: Results cached by sequence hash for repeated comparisons
|
||||
//! - **Multi-algorithm**: DTW, LCS, and edit distance available
|
||||
//! - **Pattern detection**: Automatic sub-gesture pattern extraction
|
||||
//!
|
||||
//! # References
|
||||
//! - ADR-030 Tier 6: Invisible Interaction Layer
|
||||
//! - ADR-032a Section 6.4: midstreamer-temporal-compare integration
|
||||
|
||||
use midstreamer_temporal_compare::{
|
||||
ComparisonAlgorithm, Sequence, TemporalComparator,
|
||||
};
|
||||
|
||||
use super::gesture::{GestureConfig, GestureError, GestureResult, GestureTemplate};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Algorithm selection for temporal gesture matching.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum GestureAlgorithm {
|
||||
/// Dynamic Time Warping (classic, from base gesture module).
|
||||
Dtw,
|
||||
/// Longest Common Subsequence (better for sparse gestures).
|
||||
Lcs,
|
||||
/// Edit distance (better for discrete gesture phases).
|
||||
EditDistance,
|
||||
}
|
||||
|
||||
impl GestureAlgorithm {
|
||||
/// Convert to the midstreamer comparison algorithm.
|
||||
pub fn to_comparison_algorithm(&self) -> ComparisonAlgorithm {
|
||||
match self {
|
||||
GestureAlgorithm::Dtw => ComparisonAlgorithm::DTW,
|
||||
GestureAlgorithm::Lcs => ComparisonAlgorithm::LCS,
|
||||
GestureAlgorithm::EditDistance => ComparisonAlgorithm::EditDistance,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the temporal gesture classifier.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TemporalGestureConfig {
|
||||
/// Base gesture config (feature_dim, min_sequence_len, etc.).
|
||||
pub base: GestureConfig,
|
||||
/// Primary comparison algorithm.
|
||||
pub algorithm: GestureAlgorithm,
|
||||
/// Whether to enable result caching.
|
||||
pub enable_cache: bool,
|
||||
/// Cache capacity (number of comparison results to cache).
|
||||
pub cache_capacity: usize,
|
||||
/// Maximum distance for a match (lower = stricter).
|
||||
pub max_distance: f64,
|
||||
/// Maximum sequence length accepted by the comparator.
|
||||
pub max_sequence_length: usize,
|
||||
}
|
||||
|
||||
impl Default for TemporalGestureConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base: GestureConfig::default(),
|
||||
algorithm: GestureAlgorithm::Dtw,
|
||||
enable_cache: true,
|
||||
cache_capacity: 256,
|
||||
max_distance: 50.0,
|
||||
max_sequence_length: 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Temporal gesture classifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Enhanced gesture classifier using `midstreamer-temporal-compare`.
|
||||
///
|
||||
/// Provides multi-algorithm gesture matching with caching.
|
||||
/// The comparator uses `f64` elements where each frame is reduced
|
||||
/// to its L2 norm for scalar temporal comparison.
|
||||
pub struct TemporalGestureClassifier {
|
||||
/// Configuration.
|
||||
config: TemporalGestureConfig,
|
||||
/// Registered gesture templates.
|
||||
templates: Vec<GestureTemplate>,
|
||||
/// Template sequences pre-converted to midstreamer format.
|
||||
template_sequences: Vec<Sequence<i64>>,
|
||||
/// Temporal comparator with caching.
|
||||
comparator: TemporalComparator<i64>,
|
||||
}
|
||||
|
||||
impl TemporalGestureClassifier {
|
||||
/// Create a new temporal gesture classifier.
|
||||
pub fn new(config: TemporalGestureConfig) -> Self {
|
||||
let comparator = TemporalComparator::new(
|
||||
config.cache_capacity,
|
||||
config.max_sequence_length,
|
||||
);
|
||||
Self {
|
||||
config,
|
||||
templates: Vec::new(),
|
||||
template_sequences: Vec::new(),
|
||||
comparator,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a gesture template.
|
||||
pub fn add_template(
|
||||
&mut self,
|
||||
template: GestureTemplate,
|
||||
) -> Result<(), GestureError> {
|
||||
if template.name.is_empty() {
|
||||
return Err(GestureError::InvalidTemplateName(
|
||||
"Template name cannot be empty".into(),
|
||||
));
|
||||
}
|
||||
if template.feature_dim != self.config.base.feature_dim {
|
||||
return Err(GestureError::DimensionMismatch {
|
||||
expected: self.config.base.feature_dim,
|
||||
got: template.feature_dim,
|
||||
});
|
||||
}
|
||||
if template.sequence.len() < self.config.base.min_sequence_len {
|
||||
return Err(GestureError::SequenceTooShort {
|
||||
needed: self.config.base.min_sequence_len,
|
||||
got: template.sequence.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let seq = Self::to_sequence(&template.sequence);
|
||||
self.template_sequences.push(seq);
|
||||
self.templates.push(template);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered templates.
|
||||
pub fn template_count(&self) -> usize {
|
||||
self.templates.len()
|
||||
}
|
||||
|
||||
/// Classify a perturbation sequence against registered templates.
|
||||
///
|
||||
/// Uses the configured comparison algorithm (DTW, LCS, or edit distance)
|
||||
/// from `midstreamer-temporal-compare`.
|
||||
pub fn classify(
|
||||
&self,
|
||||
sequence: &[Vec<f64>],
|
||||
person_id: u64,
|
||||
timestamp_us: u64,
|
||||
) -> Result<GestureResult, GestureError> {
|
||||
if self.templates.is_empty() {
|
||||
return Err(GestureError::NoTemplates);
|
||||
}
|
||||
if sequence.len() < self.config.base.min_sequence_len {
|
||||
return Err(GestureError::SequenceTooShort {
|
||||
needed: self.config.base.min_sequence_len,
|
||||
got: sequence.len(),
|
||||
});
|
||||
}
|
||||
for frame in sequence {
|
||||
if frame.len() != self.config.base.feature_dim {
|
||||
return Err(GestureError::DimensionMismatch {
|
||||
expected: self.config.base.feature_dim,
|
||||
got: frame.len(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let query_seq = Self::to_sequence(sequence);
|
||||
let algo = self.config.algorithm.to_comparison_algorithm();
|
||||
|
||||
let mut best_distance = f64::INFINITY;
|
||||
let mut second_best = f64::INFINITY;
|
||||
let mut best_idx: Option<usize> = None;
|
||||
|
||||
for (idx, template_seq) in self.template_sequences.iter().enumerate() {
|
||||
let result = self
|
||||
.comparator
|
||||
.compare(&query_seq, template_seq, algo);
|
||||
// Use distance from ComparisonResult (lower = better match)
|
||||
let distance = match result {
|
||||
Ok(cr) => cr.distance,
|
||||
Err(_) => f64::INFINITY,
|
||||
};
|
||||
|
||||
if distance < best_distance {
|
||||
second_best = best_distance;
|
||||
best_distance = distance;
|
||||
best_idx = Some(idx);
|
||||
} else if distance < second_best {
|
||||
second_best = distance;
|
||||
}
|
||||
}
|
||||
|
||||
let recognized = best_distance <= self.config.max_distance;
|
||||
|
||||
// Confidence based on margin between best and second-best
|
||||
let confidence = if recognized && second_best.is_finite() && second_best > 1e-10 {
|
||||
(1.0 - best_distance / second_best).clamp(0.0, 1.0)
|
||||
} else if recognized {
|
||||
(1.0 - best_distance / self.config.max_distance).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if let Some(idx) = best_idx {
|
||||
let template = &self.templates[idx];
|
||||
Ok(GestureResult {
|
||||
recognized,
|
||||
gesture_type: if recognized {
|
||||
Some(template.gesture_type)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
template_name: if recognized {
|
||||
Some(template.name.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
distance: best_distance,
|
||||
confidence,
|
||||
person_id,
|
||||
timestamp_us,
|
||||
})
|
||||
} else {
|
||||
Ok(GestureResult {
|
||||
recognized: false,
|
||||
gesture_type: None,
|
||||
template_name: None,
|
||||
distance: f64::INFINITY,
|
||||
confidence: 0.0,
|
||||
person_id,
|
||||
timestamp_us,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cache statistics from the temporal comparator.
|
||||
pub fn cache_stats(&self) -> midstreamer_temporal_compare::CacheStats {
|
||||
self.comparator.cache_stats()
|
||||
}
|
||||
|
||||
/// Active comparison algorithm.
|
||||
pub fn algorithm(&self) -> GestureAlgorithm {
|
||||
self.config.algorithm
|
||||
}
|
||||
|
||||
/// Convert a feature sequence to a midstreamer `Sequence<i64>`.
|
||||
///
|
||||
/// Each frame's L2 norm is quantized to an i64 (multiplied by 1000)
|
||||
/// for use with the generic comparator.
|
||||
fn to_sequence(frames: &[Vec<f64>]) -> Sequence<i64> {
|
||||
let mut seq = Sequence::new();
|
||||
for (i, frame) in frames.iter().enumerate() {
|
||||
let norm = frame.iter().map(|x| x * x).sum::<f64>().sqrt();
|
||||
let quantized = (norm * 1000.0) as i64;
|
||||
seq.push(quantized, i as u64);
|
||||
}
|
||||
seq
|
||||
}
|
||||
}
|
||||
|
||||
// We implement Debug manually because TemporalComparator does not derive Debug
|
||||
impl std::fmt::Debug for TemporalGestureClassifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TemporalGestureClassifier")
|
||||
.field("config", &self.config)
|
||||
.field("template_count", &self.templates.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use super::super::gesture::GestureType;
|
||||
|
||||
fn make_template(
|
||||
name: &str,
|
||||
gesture_type: GestureType,
|
||||
n_frames: usize,
|
||||
feature_dim: usize,
|
||||
pattern: fn(usize, usize) -> f64,
|
||||
) -> GestureTemplate {
|
||||
let sequence: Vec<Vec<f64>> = (0..n_frames)
|
||||
.map(|t| (0..feature_dim).map(|d| pattern(t, d)).collect())
|
||||
.collect();
|
||||
GestureTemplate {
|
||||
name: name.to_string(),
|
||||
gesture_type,
|
||||
sequence,
|
||||
feature_dim,
|
||||
}
|
||||
}
|
||||
|
||||
fn wave_pattern(t: usize, d: usize) -> f64 {
|
||||
if d == 0 {
|
||||
(t as f64 * 0.5).sin()
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
fn push_pattern(t: usize, d: usize) -> f64 {
|
||||
if d == 0 {
|
||||
t as f64 * 0.1
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
fn small_config() -> TemporalGestureConfig {
|
||||
TemporalGestureConfig {
|
||||
base: GestureConfig {
|
||||
feature_dim: 4,
|
||||
min_sequence_len: 5,
|
||||
max_distance: 10.0,
|
||||
band_width: 3,
|
||||
},
|
||||
algorithm: GestureAlgorithm::Dtw,
|
||||
enable_cache: false,
|
||||
cache_capacity: 64,
|
||||
max_distance: 100000.0, // generous for testing
|
||||
max_sequence_length: 1024,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_classifier_creation() {
|
||||
let classifier = TemporalGestureClassifier::new(small_config());
|
||||
assert_eq!(classifier.template_count(), 0);
|
||||
assert_eq!(classifier.algorithm(), GestureAlgorithm::Dtw);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_add_template() {
|
||||
let mut classifier = TemporalGestureClassifier::new(small_config());
|
||||
let template = make_template("wave", GestureType::Wave, 10, 4, wave_pattern);
|
||||
classifier.add_template(template).unwrap();
|
||||
assert_eq!(classifier.template_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_add_template_empty_name() {
|
||||
let mut classifier = TemporalGestureClassifier::new(small_config());
|
||||
let template = make_template("", GestureType::Wave, 10, 4, wave_pattern);
|
||||
assert!(matches!(
|
||||
classifier.add_template(template),
|
||||
Err(GestureError::InvalidTemplateName(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_add_template_wrong_dim() {
|
||||
let mut classifier = TemporalGestureClassifier::new(small_config());
|
||||
let template = make_template("wave", GestureType::Wave, 10, 8, wave_pattern);
|
||||
assert!(matches!(
|
||||
classifier.add_template(template),
|
||||
Err(GestureError::DimensionMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_classify_no_templates() {
|
||||
let classifier = TemporalGestureClassifier::new(small_config());
|
||||
let seq: Vec<Vec<f64>> = (0..10).map(|_| vec![0.0; 4]).collect();
|
||||
assert!(matches!(
|
||||
classifier.classify(&seq, 1, 0),
|
||||
Err(GestureError::NoTemplates)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_classify_too_short() {
|
||||
let mut classifier = TemporalGestureClassifier::new(small_config());
|
||||
classifier
|
||||
.add_template(make_template("wave", GestureType::Wave, 10, 4, wave_pattern))
|
||||
.unwrap();
|
||||
let seq: Vec<Vec<f64>> = (0..3).map(|_| vec![0.0; 4]).collect();
|
||||
assert!(matches!(
|
||||
classifier.classify(&seq, 1, 0),
|
||||
Err(GestureError::SequenceTooShort { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_classify_exact_match() {
|
||||
let mut classifier = TemporalGestureClassifier::new(small_config());
|
||||
let template = make_template("wave", GestureType::Wave, 10, 4, wave_pattern);
|
||||
classifier.add_template(template).unwrap();
|
||||
|
||||
let seq: Vec<Vec<f64>> = (0..10)
|
||||
.map(|t| (0..4).map(|d| wave_pattern(t, d)).collect())
|
||||
.collect();
|
||||
|
||||
let result = classifier.classify(&seq, 1, 100_000).unwrap();
|
||||
assert!(result.recognized, "Exact match should be recognized");
|
||||
assert_eq!(result.gesture_type, Some(GestureType::Wave));
|
||||
assert!(result.distance < 1e-6, "Exact match should have near-zero distance");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_classify_best_of_two() {
|
||||
let mut classifier = TemporalGestureClassifier::new(small_config());
|
||||
classifier
|
||||
.add_template(make_template("wave", GestureType::Wave, 10, 4, wave_pattern))
|
||||
.unwrap();
|
||||
classifier
|
||||
.add_template(make_template("push", GestureType::Push, 10, 4, push_pattern))
|
||||
.unwrap();
|
||||
|
||||
let seq: Vec<Vec<f64>> = (0..10)
|
||||
.map(|t| (0..4).map(|d| wave_pattern(t, d)).collect())
|
||||
.collect();
|
||||
|
||||
let result = classifier.classify(&seq, 1, 0).unwrap();
|
||||
assert!(result.recognized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_algorithm_selection() {
|
||||
assert_eq!(
|
||||
GestureAlgorithm::Dtw.to_comparison_algorithm(),
|
||||
ComparisonAlgorithm::DTW
|
||||
);
|
||||
assert_eq!(
|
||||
GestureAlgorithm::Lcs.to_comparison_algorithm(),
|
||||
ComparisonAlgorithm::LCS
|
||||
);
|
||||
assert_eq!(
|
||||
GestureAlgorithm::EditDistance.to_comparison_algorithm(),
|
||||
ComparisonAlgorithm::EditDistance
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_lcs_algorithm() {
|
||||
let config = TemporalGestureConfig {
|
||||
algorithm: GestureAlgorithm::Lcs,
|
||||
..small_config()
|
||||
};
|
||||
let mut classifier = TemporalGestureClassifier::new(config);
|
||||
classifier
|
||||
.add_template(make_template("wave", GestureType::Wave, 10, 4, wave_pattern))
|
||||
.unwrap();
|
||||
|
||||
let seq: Vec<Vec<f64>> = (0..10)
|
||||
.map(|t| (0..4).map(|d| wave_pattern(t, d)).collect())
|
||||
.collect();
|
||||
|
||||
let result = classifier.classify(&seq, 1, 0).unwrap();
|
||||
assert!(result.recognized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_edit_distance_algorithm() {
|
||||
let config = TemporalGestureConfig {
|
||||
algorithm: GestureAlgorithm::EditDistance,
|
||||
..small_config()
|
||||
};
|
||||
let mut classifier = TemporalGestureClassifier::new(config);
|
||||
classifier
|
||||
.add_template(make_template("wave", GestureType::Wave, 10, 4, wave_pattern))
|
||||
.unwrap();
|
||||
|
||||
let seq: Vec<Vec<f64>> = (0..10)
|
||||
.map(|t| (0..4).map(|d| wave_pattern(t, d)).collect())
|
||||
.collect();
|
||||
|
||||
let result = classifier.classify(&seq, 1, 0).unwrap();
|
||||
assert!(result.recognized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_default_config() {
|
||||
let config = TemporalGestureConfig::default();
|
||||
assert_eq!(config.algorithm, GestureAlgorithm::Dtw);
|
||||
assert!(config.enable_cache);
|
||||
assert_eq!(config.cache_capacity, 256);
|
||||
assert!((config.max_distance - 50.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_cache_stats() {
|
||||
let classifier = TemporalGestureClassifier::new(small_config());
|
||||
let stats = classifier.cache_stats();
|
||||
assert_eq!(stats.hits, 0);
|
||||
assert_eq!(stats.misses, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_sequence_conversion() {
|
||||
let frames: Vec<Vec<f64>> = vec![vec![3.0, 4.0], vec![0.0, 1.0]];
|
||||
let seq = TemporalGestureClassifier::to_sequence(&frames);
|
||||
// First element: sqrt(9+16) = 5.0 -> 5000
|
||||
// Second element: sqrt(0+1) = 1.0 -> 1000
|
||||
assert_eq!(seq.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debug_impl() {
|
||||
let classifier = TemporalGestureClassifier::new(small_config());
|
||||
let dbg = format!("{:?}", classifier);
|
||||
assert!(dbg.contains("TemporalGestureClassifier"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,689 @@
|
||||
//! Coarse RF Tomography from link attenuations.
|
||||
//!
|
||||
//! Produces a low-resolution 3D occupancy volume by inverting per-link
|
||||
//! attenuation measurements. Each voxel receives an occupancy probability
|
||||
//! based on how many links traverse it and how much attenuation those links
|
||||
//! observed.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//! 1. Define a voxel grid covering the monitored volume
|
||||
//! 2. For each link, determine which voxels lie along the propagation path
|
||||
//! 3. Solve the sparse tomographic inverse: attenuation = sum(voxel_density * path_weight)
|
||||
//! 4. Apply L1 regularization for sparsity (most voxels are unoccupied)
|
||||
//!
|
||||
//! # References
|
||||
//! - ADR-030 Tier 2: Coarse RF Tomography
|
||||
//! - Wilson & Patwari (2010), "Radio Tomographic Imaging"
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors from tomography operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TomographyError {
|
||||
/// Not enough links for tomographic inversion.
|
||||
#[error("Insufficient links: need >= {needed}, got {got}")]
|
||||
InsufficientLinks { needed: usize, got: usize },
|
||||
|
||||
/// Grid dimensions are invalid.
|
||||
#[error("Invalid grid dimensions: {0}")]
|
||||
InvalidGrid(String),
|
||||
|
||||
/// No voxels intersected by any link.
|
||||
#[error("No voxels intersected by links — check geometry")]
|
||||
NoIntersections,
|
||||
|
||||
/// Observation vector length mismatch.
|
||||
#[error("Observation length mismatch: expected {expected}, got {got}")]
|
||||
ObservationMismatch { expected: usize, got: usize },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for the voxel grid and tomographic solver.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TomographyConfig {
|
||||
/// Number of voxels along X axis.
|
||||
pub nx: usize,
|
||||
/// Number of voxels along Y axis.
|
||||
pub ny: usize,
|
||||
/// Number of voxels along Z axis.
|
||||
pub nz: usize,
|
||||
/// Physical extent of the grid: `[x_min, y_min, z_min, x_max, y_max, z_max]`.
|
||||
pub bounds: [f64; 6],
|
||||
/// L1 regularization weight (higher = sparser solution).
|
||||
pub lambda: f64,
|
||||
/// Maximum iterations for the solver.
|
||||
pub max_iterations: usize,
|
||||
/// Convergence tolerance.
|
||||
pub tolerance: f64,
|
||||
/// Minimum links required for inversion (default 8).
|
||||
pub min_links: usize,
|
||||
}
|
||||
|
||||
impl Default for TomographyConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
nx: 8,
|
||||
ny: 8,
|
||||
nz: 4,
|
||||
bounds: [0.0, 0.0, 0.0, 6.0, 6.0, 3.0],
|
||||
lambda: 0.1,
|
||||
max_iterations: 100,
|
||||
tolerance: 1e-4,
|
||||
min_links: 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Geometry types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A 3D position.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Position3D {
|
||||
pub x: f64,
|
||||
pub y: f64,
|
||||
pub z: f64,
|
||||
}
|
||||
|
||||
/// A link between a transmitter and receiver.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinkGeometry {
|
||||
/// Transmitter position.
|
||||
pub tx: Position3D,
|
||||
/// Receiver position.
|
||||
pub rx: Position3D,
|
||||
/// Link identifier.
|
||||
pub link_id: usize,
|
||||
}
|
||||
|
||||
impl LinkGeometry {
|
||||
/// Euclidean distance between TX and RX.
|
||||
pub fn distance(&self) -> f64 {
|
||||
let dx = self.rx.x - self.tx.x;
|
||||
let dy = self.rx.y - self.tx.y;
|
||||
let dz = self.rx.z - self.tx.z;
|
||||
(dx * dx + dy * dy + dz * dz).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Occupancy volume
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// 3D occupancy grid resulting from tomographic inversion.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OccupancyVolume {
|
||||
/// Voxel densities in row-major order `[nz][ny][nx]`.
|
||||
pub densities: Vec<f64>,
|
||||
/// Grid dimensions.
|
||||
pub nx: usize,
|
||||
pub ny: usize,
|
||||
pub nz: usize,
|
||||
/// Physical bounds.
|
||||
pub bounds: [f64; 6],
|
||||
/// Number of occupied voxels (density > threshold).
|
||||
pub occupied_count: usize,
|
||||
/// Total voxel count.
|
||||
pub total_voxels: usize,
|
||||
/// Solver residual at convergence.
|
||||
pub residual: f64,
|
||||
/// Number of iterations used.
|
||||
pub iterations: usize,
|
||||
}
|
||||
|
||||
impl OccupancyVolume {
|
||||
/// Get density at voxel (ix, iy, iz). Returns None if out of bounds.
|
||||
pub fn get(&self, ix: usize, iy: usize, iz: usize) -> Option<f64> {
|
||||
if ix < self.nx && iy < self.ny && iz < self.nz {
|
||||
Some(self.densities[iz * self.ny * self.nx + iy * self.nx + ix])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Voxel size along each axis.
|
||||
pub fn voxel_size(&self) -> [f64; 3] {
|
||||
[
|
||||
(self.bounds[3] - self.bounds[0]) / self.nx as f64,
|
||||
(self.bounds[4] - self.bounds[1]) / self.ny as f64,
|
||||
(self.bounds[5] - self.bounds[2]) / self.nz as f64,
|
||||
]
|
||||
}
|
||||
|
||||
/// Center position of voxel (ix, iy, iz).
|
||||
pub fn voxel_center(&self, ix: usize, iy: usize, iz: usize) -> Position3D {
|
||||
let vs = self.voxel_size();
|
||||
Position3D {
|
||||
x: self.bounds[0] + (ix as f64 + 0.5) * vs[0],
|
||||
y: self.bounds[1] + (iy as f64 + 0.5) * vs[1],
|
||||
z: self.bounds[2] + (iz as f64 + 0.5) * vs[2],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tomographic solver
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Coarse RF tomography solver.
|
||||
///
|
||||
/// Given a set of TX-RX links and per-link attenuation measurements,
|
||||
/// reconstructs a 3D occupancy volume using L1-regularized least squares.
|
||||
pub struct RfTomographer {
|
||||
config: TomographyConfig,
|
||||
/// Precomputed weight matrix: `weight_matrix[link_idx]` is a list of
|
||||
/// (voxel_index, weight) pairs.
|
||||
weight_matrix: Vec<Vec<(usize, f64)>>,
|
||||
/// Number of voxels.
|
||||
n_voxels: usize,
|
||||
}
|
||||
|
||||
impl RfTomographer {
|
||||
/// Create a new tomographer with the given configuration and link geometry.
|
||||
pub fn new(config: TomographyConfig, links: &[LinkGeometry]) -> Result<Self, TomographyError> {
|
||||
if links.len() < config.min_links {
|
||||
return Err(TomographyError::InsufficientLinks {
|
||||
needed: config.min_links,
|
||||
got: links.len(),
|
||||
});
|
||||
}
|
||||
if config.nx == 0 || config.ny == 0 || config.nz == 0 {
|
||||
return Err(TomographyError::InvalidGrid(
|
||||
"Grid dimensions must be > 0".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let n_voxels = config
|
||||
.nx
|
||||
.checked_mul(config.ny)
|
||||
.and_then(|v| v.checked_mul(config.nz))
|
||||
.ok_or_else(|| {
|
||||
TomographyError::InvalidGrid(format!(
|
||||
"Grid dimensions overflow: {}x{}x{}",
|
||||
config.nx, config.ny, config.nz
|
||||
))
|
||||
})?;
|
||||
|
||||
// Precompute weight matrix
|
||||
let weight_matrix: Vec<Vec<(usize, f64)>> = links
|
||||
.iter()
|
||||
.map(|link| compute_link_weights(link, &config))
|
||||
.collect();
|
||||
|
||||
// Ensure at least one link intersects some voxels
|
||||
let total_weights: usize = weight_matrix.iter().map(|w| w.len()).sum();
|
||||
if total_weights == 0 {
|
||||
return Err(TomographyError::NoIntersections);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
weight_matrix,
|
||||
n_voxels,
|
||||
})
|
||||
}
|
||||
|
||||
/// Reconstruct occupancy from per-link attenuation measurements.
|
||||
///
|
||||
/// `attenuations` has one entry per link (same order as links passed to `new`).
|
||||
/// Higher attenuation indicates more obstruction along the link path.
|
||||
pub fn reconstruct(&self, attenuations: &[f64]) -> Result<OccupancyVolume, TomographyError> {
|
||||
if attenuations.len() != self.weight_matrix.len() {
|
||||
return Err(TomographyError::ObservationMismatch {
|
||||
expected: self.weight_matrix.len(),
|
||||
got: attenuations.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// ISTA (Iterative Shrinkage-Thresholding Algorithm) for L1 minimization
|
||||
// min ||Wx - y||^2 + lambda * ||x||_1
|
||||
let mut x = vec![0.0_f64; self.n_voxels];
|
||||
let n_links = attenuations.len();
|
||||
|
||||
// Estimate step size: 1 / L where L is the Lipschitz constant of the
|
||||
// gradient of ||Wx - y||^2, i.e. the spectral norm of W^T W.
|
||||
// A safe upper bound is the Frobenius norm squared of W (sum of all
|
||||
// squared entries), since ||W^T W|| <= ||W||_F^2.
|
||||
let frobenius_sq: f64 = self
|
||||
.weight_matrix
|
||||
.iter()
|
||||
.flat_map(|ws| ws.iter().map(|&(_, w)| w * w))
|
||||
.sum();
|
||||
let lipschitz = frobenius_sq.max(1e-10);
|
||||
let step_size = 1.0 / lipschitz;
|
||||
|
||||
let mut residual = 0.0_f64;
|
||||
let mut iterations = 0;
|
||||
|
||||
for iter in 0..self.config.max_iterations {
|
||||
// Compute gradient: W^T (Wx - y)
|
||||
let mut gradient = vec![0.0_f64; self.n_voxels];
|
||||
residual = 0.0;
|
||||
|
||||
for (link_idx, weights) in self.weight_matrix.iter().enumerate() {
|
||||
// Forward: Wx for this link
|
||||
let predicted: f64 = weights.iter().map(|&(idx, w)| w * x[idx]).sum();
|
||||
let diff = predicted - attenuations[link_idx];
|
||||
residual += diff * diff;
|
||||
|
||||
// Backward: accumulate gradient
|
||||
for &(idx, w) in weights {
|
||||
gradient[idx] += w * diff;
|
||||
}
|
||||
}
|
||||
|
||||
residual = (residual / n_links as f64).sqrt();
|
||||
|
||||
// Gradient step + soft thresholding (proximal L1)
|
||||
let mut max_change = 0.0_f64;
|
||||
for i in 0..self.n_voxels {
|
||||
let new_val = x[i] - step_size * gradient[i];
|
||||
// Soft thresholding
|
||||
let threshold = self.config.lambda * step_size;
|
||||
let shrunk = if new_val > threshold {
|
||||
new_val - threshold
|
||||
} else if new_val < -threshold {
|
||||
new_val + threshold
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
// Non-negativity constraint (density >= 0)
|
||||
let clamped = shrunk.max(0.0);
|
||||
max_change = max_change.max((clamped - x[i]).abs());
|
||||
x[i] = clamped;
|
||||
}
|
||||
|
||||
iterations = iter + 1;
|
||||
|
||||
if max_change < self.config.tolerance {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Count occupied voxels (density > 0.01)
|
||||
let occupied_count = x.iter().filter(|&&d| d > 0.01).count();
|
||||
|
||||
Ok(OccupancyVolume {
|
||||
densities: x,
|
||||
nx: self.config.nx,
|
||||
ny: self.config.ny,
|
||||
nz: self.config.nz,
|
||||
bounds: self.config.bounds,
|
||||
occupied_count,
|
||||
total_voxels: self.n_voxels,
|
||||
residual,
|
||||
iterations,
|
||||
})
|
||||
}
|
||||
|
||||
/// Number of links in this tomographer.
|
||||
pub fn n_links(&self) -> usize {
|
||||
self.weight_matrix.len()
|
||||
}
|
||||
|
||||
/// Number of voxels in the grid.
|
||||
pub fn n_voxels(&self) -> usize {
|
||||
self.n_voxels
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Weight computation (simplified ray-voxel intersection)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute the intersection weights of a link with the voxel grid.
|
||||
///
|
||||
/// Uses a simplified approach: for each voxel, computes the minimum
|
||||
/// distance from the voxel center to the link ray. Voxels within
|
||||
/// one Fresnel zone receive weight proportional to closeness.
|
||||
fn compute_link_weights(link: &LinkGeometry, config: &TomographyConfig) -> Vec<(usize, f64)> {
|
||||
let vx = (config.bounds[3] - config.bounds[0]) / config.nx as f64;
|
||||
let vy = (config.bounds[4] - config.bounds[1]) / config.ny as f64;
|
||||
let vz = (config.bounds[5] - config.bounds[2]) / config.nz as f64;
|
||||
|
||||
// Fresnel zone half-width (approximate)
|
||||
let link_dist = link.distance();
|
||||
let wavelength = 0.06; // ~5 GHz
|
||||
let fresnel_radius = (wavelength * link_dist / 4.0).sqrt().max(vx.max(vy));
|
||||
|
||||
let dx = link.rx.x - link.tx.x;
|
||||
let dy = link.rx.y - link.tx.y;
|
||||
let dz = link.rx.z - link.tx.z;
|
||||
|
||||
let mut weights = Vec::new();
|
||||
|
||||
for iz in 0..config.nz {
|
||||
for iy in 0..config.ny {
|
||||
for ix in 0..config.nx {
|
||||
let cx = config.bounds[0] + (ix as f64 + 0.5) * vx;
|
||||
let cy = config.bounds[1] + (iy as f64 + 0.5) * vy;
|
||||
let cz = config.bounds[2] + (iz as f64 + 0.5) * vz;
|
||||
|
||||
// Point-to-line distance
|
||||
let dist = point_to_segment_distance(
|
||||
cx, cy, cz, link.tx.x, link.tx.y, link.tx.z, dx, dy, dz, link_dist,
|
||||
);
|
||||
|
||||
if dist < fresnel_radius {
|
||||
// Weight decays with distance from link ray
|
||||
let w = 1.0 - dist / fresnel_radius;
|
||||
let idx = iz * config.ny * config.nx + iy * config.nx + ix;
|
||||
weights.push((idx, w));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weights
|
||||
}
|
||||
|
||||
/// Distance from point (px,py,pz) to line segment defined by start + t*dir
|
||||
/// where dir = (dx,dy,dz) and segment length = `seg_len`.
|
||||
fn point_to_segment_distance(
|
||||
px: f64,
|
||||
py: f64,
|
||||
pz: f64,
|
||||
sx: f64,
|
||||
sy: f64,
|
||||
sz: f64,
|
||||
dx: f64,
|
||||
dy: f64,
|
||||
dz: f64,
|
||||
seg_len: f64,
|
||||
) -> f64 {
|
||||
if seg_len < 1e-12 {
|
||||
return ((px - sx).powi(2) + (py - sy).powi(2) + (pz - sz).powi(2)).sqrt();
|
||||
}
|
||||
|
||||
// Project point onto line: t = dot(P-S, D) / |D|^2
|
||||
let t = ((px - sx) * dx + (py - sy) * dy + (pz - sz) * dz) / (seg_len * seg_len);
|
||||
let t_clamped = t.clamp(0.0, 1.0);
|
||||
|
||||
let closest_x = sx + t_clamped * dx;
|
||||
let closest_y = sy + t_clamped * dy;
|
||||
let closest_z = sz + t_clamped * dz;
|
||||
|
||||
((px - closest_x).powi(2) + (py - closest_y).powi(2) + (pz - closest_z).powi(2)).sqrt()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_square_links() -> Vec<LinkGeometry> {
|
||||
// 4 nodes in a square at z=1.5, 12 directed links
|
||||
let nodes = [
|
||||
Position3D {
|
||||
x: 0.5,
|
||||
y: 0.5,
|
||||
z: 1.5,
|
||||
},
|
||||
Position3D {
|
||||
x: 5.5,
|
||||
y: 0.5,
|
||||
z: 1.5,
|
||||
},
|
||||
Position3D {
|
||||
x: 5.5,
|
||||
y: 5.5,
|
||||
z: 1.5,
|
||||
},
|
||||
Position3D {
|
||||
x: 0.5,
|
||||
y: 5.5,
|
||||
z: 1.5,
|
||||
},
|
||||
];
|
||||
let mut links = Vec::new();
|
||||
let mut id = 0;
|
||||
for i in 0..4 {
|
||||
for j in 0..4 {
|
||||
if i != j {
|
||||
links.push(LinkGeometry {
|
||||
tx: nodes[i],
|
||||
rx: nodes[j],
|
||||
link_id: id,
|
||||
});
|
||||
id += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
links
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tomographer_creation() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
min_links: 8,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
assert_eq!(tomo.n_links(), 12);
|
||||
assert_eq!(tomo.n_voxels(), 8 * 8 * 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insufficient_links() {
|
||||
let links = vec![LinkGeometry {
|
||||
tx: Position3D {
|
||||
x: 0.0,
|
||||
y: 0.0,
|
||||
z: 0.0,
|
||||
},
|
||||
rx: Position3D {
|
||||
x: 1.0,
|
||||
y: 0.0,
|
||||
z: 0.0,
|
||||
},
|
||||
link_id: 0,
|
||||
}];
|
||||
let config = TomographyConfig {
|
||||
min_links: 8,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(matches!(
|
||||
RfTomographer::new(config, &links),
|
||||
Err(TomographyError::InsufficientLinks { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_grid() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
nx: 0,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(matches!(
|
||||
RfTomographer::new(config, &links),
|
||||
Err(TomographyError::InvalidGrid(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_attenuation_empty_room() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
min_links: 8,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
|
||||
// Zero attenuation = empty room
|
||||
let attenuations = vec![0.0; tomo.n_links()];
|
||||
let volume = tomo.reconstruct(&attenuations).unwrap();
|
||||
|
||||
assert_eq!(volume.total_voxels, 8 * 8 * 4);
|
||||
// All densities should be zero or near zero
|
||||
assert!(
|
||||
volume.occupied_count == 0,
|
||||
"Empty room should have no occupied voxels, got {}",
|
||||
volume.occupied_count
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonzero_attenuation_produces_density() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
min_links: 8,
|
||||
lambda: 0.001, // light regularization so solution is not zeroed
|
||||
max_iterations: 500,
|
||||
tolerance: 1e-8,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
|
||||
// Strong attenuations to represent obstructed links
|
||||
let attenuations: Vec<f64> = (0..tomo.n_links()).map(|i| 5.0 + 1.0 * i as f64).collect();
|
||||
let volume = tomo.reconstruct(&attenuations).unwrap();
|
||||
|
||||
// Check that at least some voxels have non-negligible density
|
||||
let any_nonzero = volume.densities.iter().any(|&d| d > 1e-6);
|
||||
assert!(
|
||||
any_nonzero,
|
||||
"Non-zero attenuation should produce non-zero voxel densities"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_observation_mismatch() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
min_links: 8,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
|
||||
let attenuations = vec![0.1; 3]; // wrong count
|
||||
assert!(matches!(
|
||||
tomo.reconstruct(&attenuations),
|
||||
Err(TomographyError::ObservationMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_voxel_access() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
min_links: 8,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
|
||||
let attenuations = vec![0.0; tomo.n_links()];
|
||||
let volume = tomo.reconstruct(&attenuations).unwrap();
|
||||
|
||||
// Valid access
|
||||
assert!(volume.get(0, 0, 0).is_some());
|
||||
assert!(volume.get(7, 7, 3).is_some());
|
||||
// Out of bounds
|
||||
assert!(volume.get(8, 0, 0).is_none());
|
||||
assert!(volume.get(0, 8, 0).is_none());
|
||||
assert!(volume.get(0, 0, 4).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_voxel_center() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
nx: 6,
|
||||
ny: 6,
|
||||
nz: 3,
|
||||
min_links: 8,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
|
||||
let attenuations = vec![0.0; tomo.n_links()];
|
||||
let volume = tomo.reconstruct(&attenuations).unwrap();
|
||||
|
||||
let center = volume.voxel_center(0, 0, 0);
|
||||
assert!(center.x > 0.0 && center.x < 1.0);
|
||||
assert!(center.y > 0.0 && center.y < 1.0);
|
||||
assert!(center.z > 0.0 && center.z < 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_voxel_size() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
nx: 6,
|
||||
ny: 6,
|
||||
nz: 3,
|
||||
bounds: [0.0, 0.0, 0.0, 6.0, 6.0, 3.0],
|
||||
min_links: 8,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
|
||||
let attenuations = vec![0.0; tomo.n_links()];
|
||||
let volume = tomo.reconstruct(&attenuations).unwrap();
|
||||
let vs = volume.voxel_size();
|
||||
|
||||
assert!((vs[0] - 1.0).abs() < 1e-10);
|
||||
assert!((vs[1] - 1.0).abs() < 1e-10);
|
||||
assert!((vs[2] - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_point_to_segment_distance() {
|
||||
// Point directly on the segment
|
||||
let d = point_to_segment_distance(0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0);
|
||||
assert!(d < 1e-10);
|
||||
|
||||
// Point 1 unit above the midpoint
|
||||
let d = point_to_segment_distance(0.5, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0);
|
||||
assert!((d - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_link_distance() {
|
||||
let link = LinkGeometry {
|
||||
tx: Position3D {
|
||||
x: 0.0,
|
||||
y: 0.0,
|
||||
z: 0.0,
|
||||
},
|
||||
rx: Position3D {
|
||||
x: 3.0,
|
||||
y: 4.0,
|
||||
z: 0.0,
|
||||
},
|
||||
link_id: 0,
|
||||
};
|
||||
assert!((link.distance() - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_solver_convergence() {
|
||||
let links = make_square_links();
|
||||
let config = TomographyConfig {
|
||||
min_links: 8,
|
||||
lambda: 0.01,
|
||||
max_iterations: 500,
|
||||
tolerance: 1e-6,
|
||||
..Default::default()
|
||||
};
|
||||
let tomo = RfTomographer::new(config, &links).unwrap();
|
||||
|
||||
let attenuations: Vec<f64> = (0..tomo.n_links())
|
||||
.map(|i| 0.3 * (i as f64 * 0.7).sin().abs())
|
||||
.collect();
|
||||
let volume = tomo.reconstruct(&attenuations).unwrap();
|
||||
|
||||
assert!(volume.residual.is_finite());
|
||||
assert!(volume.iterations > 0);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "wifi-densepose-train"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
authors = ["rUv <ruv@ruv.net>", "WiFi-DensePose Contributors"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
@@ -27,8 +27,8 @@ cuda = ["tch-backend"]
|
||||
|
||||
[dependencies]
|
||||
# Internal crates
|
||||
wifi-densepose-signal = { version = "0.2.0", path = "../wifi-densepose-signal" }
|
||||
wifi-densepose-nn = { version = "0.2.0", path = "../wifi-densepose-nn" }
|
||||
wifi-densepose-signal = { version = "0.3.0", path = "../wifi-densepose-signal" }
|
||||
wifi-densepose-nn = { version = "0.3.0", path = "../wifi-densepose-nn" }
|
||||
|
||||
# Core
|
||||
thiserror.workspace = true
|
||||
|
||||
@@ -50,6 +50,7 @@ pub mod error;
|
||||
pub mod eval;
|
||||
pub mod geometry;
|
||||
pub mod rapid_adapt;
|
||||
pub mod ruview_metrics;
|
||||
pub mod subcarrier;
|
||||
pub mod virtual_aug;
|
||||
|
||||
|
||||
@@ -0,0 +1,947 @@
|
||||
//! RuView three-metric acceptance test (ADR-031).
|
||||
//!
|
||||
//! Implements the tiered pass/fail acceptance criteria for multistatic fusion:
|
||||
//!
|
||||
//! 1. **Joint Error (PCK / OKS)**: pose estimation accuracy.
|
||||
//! 2. **Multi-Person Separation (MOTA)**: tracking identity maintenance.
|
||||
//! 3. **Vital Sign Accuracy**: breathing and heartbeat detection precision.
|
||||
//!
|
||||
//! Tiered evaluation:
|
||||
//!
|
||||
//! | Tier | Requirements | Deployment Gate |
|
||||
//! |--------|----------------|------------------------|
|
||||
//! | Bronze | Metric 2 | Prototype demo |
|
||||
//! | Silver | Metrics 1 + 2 | Production candidate |
|
||||
//! | Gold | All three | Full deployment |
|
||||
//!
|
||||
//! # No mock data
|
||||
//!
|
||||
//! All computations use real metric definitions from the COCO evaluation
|
||||
//! protocol, MOT challenge MOTA definition, and signal-processing SNR
|
||||
//! measurement. No synthetic values are introduced at runtime.
|
||||
|
||||
use ndarray::{Array1, Array2};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tier definitions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Deployment tier achieved by the acceptance test.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum RuViewTier {
|
||||
/// No tier met -- system fails acceptance.
|
||||
Fail,
|
||||
/// Metric 2 (tracking) passes. Prototype demo gate.
|
||||
Bronze,
|
||||
/// Metrics 1 + 2 (pose + tracking) pass. Production candidate gate.
|
||||
Silver,
|
||||
/// All three metrics pass. Full deployment gate.
|
||||
Gold,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RuViewTier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
RuViewTier::Fail => write!(f, "FAIL"),
|
||||
RuViewTier::Bronze => write!(f, "BRONZE"),
|
||||
RuViewTier::Silver => write!(f, "SILVER"),
|
||||
RuViewTier::Gold => write!(f, "GOLD"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Metric 1: Joint Error (PCK / OKS)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Thresholds for Metric 1 (Joint Error).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JointErrorThresholds {
|
||||
/// PCK@0.2 all 17 keypoints (>= this to pass).
|
||||
pub pck_all: f32,
|
||||
/// PCK@0.2 torso keypoints (shoulders + hips, >= this to pass).
|
||||
pub pck_torso: f32,
|
||||
/// Mean OKS (>= this to pass).
|
||||
pub oks: f32,
|
||||
/// Torso jitter RMS in metres over 10s window (< this to pass).
|
||||
pub jitter_rms_m: f32,
|
||||
/// Per-keypoint max error 95th percentile in metres (< this to pass).
|
||||
pub max_error_p95_m: f32,
|
||||
}
|
||||
|
||||
impl Default for JointErrorThresholds {
|
||||
fn default() -> Self {
|
||||
JointErrorThresholds {
|
||||
pck_all: 0.70,
|
||||
pck_torso: 0.80,
|
||||
oks: 0.50,
|
||||
jitter_rms_m: 0.03,
|
||||
max_error_p95_m: 0.15,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of Metric 1 evaluation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JointErrorResult {
|
||||
/// PCK@0.2 over all 17 keypoints.
|
||||
pub pck_all: f32,
|
||||
/// PCK@0.2 over torso keypoints (indices 5, 6, 11, 12).
|
||||
pub pck_torso: f32,
|
||||
/// Mean OKS.
|
||||
pub oks: f32,
|
||||
/// Torso jitter RMS (metres).
|
||||
pub jitter_rms_m: f32,
|
||||
/// Per-keypoint max error 95th percentile (metres).
|
||||
pub max_error_p95_m: f32,
|
||||
/// Whether this metric passes.
|
||||
pub passes: bool,
|
||||
}
|
||||
|
||||
/// COCO keypoint sigmas for OKS computation (17 joints).
|
||||
const COCO_SIGMAS: [f32; 17] = [
|
||||
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
|
||||
0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089,
|
||||
];
|
||||
|
||||
/// Torso keypoint indices (COCO ordering): left_shoulder, right_shoulder,
|
||||
/// left_hip, right_hip.
|
||||
const TORSO_INDICES: [usize; 4] = [5, 6, 11, 12];
|
||||
|
||||
/// Evaluate Metric 1: Joint Error.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `pred_kpts`: per-frame predicted keypoints `[17, 2]` in normalised `[0,1]`.
|
||||
/// - `gt_kpts`: per-frame ground-truth keypoints `[17, 2]`.
|
||||
/// - `visibility`: per-frame visibility `[17]`, 0 = invisible.
|
||||
/// - `scale`: per-frame object scale for OKS (pass 1.0 if unknown).
|
||||
/// - `thresholds`: acceptance thresholds.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `JointErrorResult` with the computed metrics and pass/fail.
|
||||
pub fn evaluate_joint_error(
|
||||
pred_kpts: &[Array2<f32>],
|
||||
gt_kpts: &[Array2<f32>],
|
||||
visibility: &[Array1<f32>],
|
||||
scale: &[f32],
|
||||
thresholds: &JointErrorThresholds,
|
||||
) -> JointErrorResult {
|
||||
let n = pred_kpts.len();
|
||||
if n == 0 {
|
||||
return JointErrorResult {
|
||||
pck_all: 0.0,
|
||||
pck_torso: 0.0,
|
||||
oks: 0.0,
|
||||
jitter_rms_m: f32::MAX,
|
||||
max_error_p95_m: f32::MAX,
|
||||
passes: false,
|
||||
};
|
||||
}
|
||||
|
||||
// PCK@0.2 computation.
|
||||
let pck_threshold = 0.2;
|
||||
let mut all_correct = 0_usize;
|
||||
let mut all_total = 0_usize;
|
||||
let mut torso_correct = 0_usize;
|
||||
let mut torso_total = 0_usize;
|
||||
let mut oks_sum = 0.0_f64;
|
||||
let mut per_kp_errors: Vec<Vec<f32>> = vec![Vec::new(); 17];
|
||||
|
||||
for i in 0..n {
|
||||
let bbox_diag = compute_bbox_diag(>_kpts[i], &visibility[i]);
|
||||
let safe_diag = bbox_diag.max(1e-3);
|
||||
let dist_thr = pck_threshold * safe_diag;
|
||||
|
||||
for j in 0..17 {
|
||||
if visibility[i][j] < 0.5 {
|
||||
continue;
|
||||
}
|
||||
let dx = pred_kpts[i][[j, 0]] - gt_kpts[i][[j, 0]];
|
||||
let dy = pred_kpts[i][[j, 1]] - gt_kpts[i][[j, 1]];
|
||||
let dist = (dx * dx + dy * dy).sqrt();
|
||||
|
||||
per_kp_errors[j].push(dist);
|
||||
|
||||
all_total += 1;
|
||||
if dist <= dist_thr {
|
||||
all_correct += 1;
|
||||
}
|
||||
|
||||
if TORSO_INDICES.contains(&j) {
|
||||
torso_total += 1;
|
||||
if dist <= dist_thr {
|
||||
torso_correct += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OKS for this frame.
|
||||
let s = scale.get(i).copied().unwrap_or(1.0);
|
||||
let oks_frame = compute_single_oks(&pred_kpts[i], >_kpts[i], &visibility[i], s);
|
||||
oks_sum += oks_frame as f64;
|
||||
}
|
||||
|
||||
let pck_all = if all_total > 0 { all_correct as f32 / all_total as f32 } else { 0.0 };
|
||||
let pck_torso = if torso_total > 0 { torso_correct as f32 / torso_total as f32 } else { 0.0 };
|
||||
let oks = (oks_sum / n as f64) as f32;
|
||||
|
||||
// Torso jitter: RMS of frame-to-frame torso centroid displacement.
|
||||
let jitter_rms_m = compute_torso_jitter(pred_kpts, visibility);
|
||||
|
||||
// 95th percentile max per-keypoint error.
|
||||
let max_error_p95_m = compute_p95_max_error(&per_kp_errors);
|
||||
|
||||
let passes = pck_all >= thresholds.pck_all
|
||||
&& pck_torso >= thresholds.pck_torso
|
||||
&& oks >= thresholds.oks
|
||||
&& jitter_rms_m < thresholds.jitter_rms_m
|
||||
&& max_error_p95_m < thresholds.max_error_p95_m;
|
||||
|
||||
JointErrorResult {
|
||||
pck_all,
|
||||
pck_torso,
|
||||
oks,
|
||||
jitter_rms_m,
|
||||
max_error_p95_m,
|
||||
passes,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Metric 2: Multi-Person Separation (MOTA)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Thresholds for Metric 2 (Multi-Person Separation).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrackingThresholds {
|
||||
/// Maximum allowed identity switches (MOTA ID-switch). Must be 0 for pass.
|
||||
pub max_id_switches: usize,
|
||||
/// Maximum track fragmentation ratio (< this to pass).
|
||||
pub max_frag_ratio: f32,
|
||||
/// Maximum false track creations per minute (must be 0 for pass).
|
||||
pub max_false_tracks_per_min: f32,
|
||||
}
|
||||
|
||||
impl Default for TrackingThresholds {
|
||||
fn default() -> Self {
|
||||
TrackingThresholds {
|
||||
max_id_switches: 0,
|
||||
max_frag_ratio: 0.05,
|
||||
max_false_tracks_per_min: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single frame of tracking data for MOTA computation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrackingFrame {
|
||||
/// Frame index (0-based).
|
||||
pub frame_idx: usize,
|
||||
/// Ground-truth person IDs present in this frame.
|
||||
pub gt_ids: Vec<u32>,
|
||||
/// Predicted person IDs present in this frame.
|
||||
pub pred_ids: Vec<u32>,
|
||||
/// Assignment: `(pred_id, gt_id)` pairs for matched persons.
|
||||
pub assignments: Vec<(u32, u32)>,
|
||||
}
|
||||
|
||||
/// Result of Metric 2 evaluation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrackingResult {
|
||||
/// Number of identity switches across the sequence.
|
||||
pub id_switches: usize,
|
||||
/// Track fragmentation ratio.
|
||||
pub fragmentation_ratio: f32,
|
||||
/// False track creations per minute.
|
||||
pub false_tracks_per_min: f32,
|
||||
/// MOTA score (higher is better).
|
||||
pub mota: f32,
|
||||
/// Total number of frames evaluated.
|
||||
pub n_frames: usize,
|
||||
/// Whether this metric passes.
|
||||
pub passes: bool,
|
||||
}
|
||||
|
||||
/// Evaluate Metric 2: Multi-Person Separation.
|
||||
///
|
||||
/// Computes MOTA (Multiple Object Tracking Accuracy) components:
|
||||
/// identity switches, fragmentation ratio, and false track rate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `frames`: per-frame tracking data with GT and predicted IDs + assignments.
|
||||
/// - `duration_minutes`: total duration of the tracking sequence in minutes.
|
||||
/// - `thresholds`: acceptance thresholds.
|
||||
pub fn evaluate_tracking(
|
||||
frames: &[TrackingFrame],
|
||||
duration_minutes: f32,
|
||||
thresholds: &TrackingThresholds,
|
||||
) -> TrackingResult {
|
||||
let n_frames = frames.len();
|
||||
if n_frames == 0 {
|
||||
return TrackingResult {
|
||||
id_switches: 0,
|
||||
fragmentation_ratio: 0.0,
|
||||
false_tracks_per_min: 0.0,
|
||||
mota: 0.0,
|
||||
n_frames: 0,
|
||||
passes: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Count identity switches: a switch occurs when the predicted ID assigned
|
||||
// to a GT ID changes between consecutive frames.
|
||||
let mut id_switches = 0_usize;
|
||||
let mut prev_assignment: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
|
||||
let mut total_gt = 0_usize;
|
||||
let mut total_misses = 0_usize;
|
||||
let mut total_false_positives = 0_usize;
|
||||
|
||||
// Track fragmentation: count how many times a GT track is "broken"
|
||||
// (present in one frame, absent in the next, then present again).
|
||||
let mut gt_track_presence: std::collections::HashMap<u32, Vec<bool>> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for frame in frames {
|
||||
total_gt += frame.gt_ids.len();
|
||||
let n_matched = frame.assignments.len();
|
||||
total_misses += frame.gt_ids.len().saturating_sub(n_matched);
|
||||
total_false_positives += frame.pred_ids.len().saturating_sub(n_matched);
|
||||
|
||||
let mut current_assignment: std::collections::HashMap<u32, u32> =
|
||||
std::collections::HashMap::new();
|
||||
for &(pred_id, gt_id) in &frame.assignments {
|
||||
current_assignment.insert(gt_id, pred_id);
|
||||
if let Some(&prev_pred) = prev_assignment.get(>_id) {
|
||||
if prev_pred != pred_id {
|
||||
id_switches += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track presence for fragmentation.
|
||||
for >_id in &frame.gt_ids {
|
||||
gt_track_presence
|
||||
.entry(gt_id)
|
||||
.or_default()
|
||||
.push(frame.assignments.iter().any(|&(_, gid)| gid == gt_id));
|
||||
}
|
||||
|
||||
prev_assignment = current_assignment;
|
||||
}
|
||||
|
||||
// Fragmentation ratio: fraction of GT tracks that have gaps.
|
||||
let mut n_fragmented = 0_usize;
|
||||
let mut n_tracks = 0_usize;
|
||||
for presence in gt_track_presence.values() {
|
||||
if presence.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
n_tracks += 1;
|
||||
let mut has_gap = false;
|
||||
let mut was_present = false;
|
||||
let mut lost = false;
|
||||
for &present in presence {
|
||||
if was_present && !present {
|
||||
lost = true;
|
||||
}
|
||||
if lost && present {
|
||||
has_gap = true;
|
||||
break;
|
||||
}
|
||||
was_present = present;
|
||||
}
|
||||
if has_gap {
|
||||
n_fragmented += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let fragmentation_ratio = if n_tracks > 0 {
|
||||
n_fragmented as f32 / n_tracks as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// False tracks per minute.
|
||||
let safe_duration = duration_minutes.max(1e-6);
|
||||
let false_tracks_per_min = total_false_positives as f32 / safe_duration;
|
||||
|
||||
// MOTA = 1 - (misses + false_positives + id_switches) / total_gt
|
||||
let mota = if total_gt > 0 {
|
||||
1.0 - (total_misses + total_false_positives + id_switches) as f32 / total_gt as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let passes = id_switches <= thresholds.max_id_switches
|
||||
&& fragmentation_ratio < thresholds.max_frag_ratio
|
||||
&& false_tracks_per_min <= thresholds.max_false_tracks_per_min;
|
||||
|
||||
TrackingResult {
|
||||
id_switches,
|
||||
fragmentation_ratio,
|
||||
false_tracks_per_min,
|
||||
mota,
|
||||
n_frames,
|
||||
passes,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Metric 3: Vital Sign Accuracy
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Thresholds for Metric 3 (Vital Sign Accuracy).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VitalSignThresholds {
|
||||
/// Breathing rate accuracy tolerance (BPM).
|
||||
pub breathing_bpm_tolerance: f32,
|
||||
/// Breathing band SNR minimum (dB).
|
||||
pub breathing_snr_db: f32,
|
||||
/// Heartbeat rate accuracy tolerance (BPM, aspirational).
|
||||
pub heartbeat_bpm_tolerance: f32,
|
||||
/// Heartbeat band SNR minimum (dB, aspirational).
|
||||
pub heartbeat_snr_db: f32,
|
||||
/// Micro-motion resolution in metres.
|
||||
pub micro_motion_m: f32,
|
||||
/// Range for micro-motion test (metres).
|
||||
pub micro_motion_range_m: f32,
|
||||
}
|
||||
|
||||
impl Default for VitalSignThresholds {
|
||||
fn default() -> Self {
|
||||
VitalSignThresholds {
|
||||
breathing_bpm_tolerance: 2.0,
|
||||
breathing_snr_db: 6.0,
|
||||
heartbeat_bpm_tolerance: 5.0,
|
||||
heartbeat_snr_db: 3.0,
|
||||
micro_motion_m: 0.001,
|
||||
micro_motion_range_m: 3.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single vital sign measurement for evaluation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VitalSignMeasurement {
|
||||
/// Estimated breathing rate (BPM).
|
||||
pub breathing_bpm: f32,
|
||||
/// Ground-truth breathing rate (BPM).
|
||||
pub gt_breathing_bpm: f32,
|
||||
/// Breathing band SNR (dB).
|
||||
pub breathing_snr_db: f32,
|
||||
/// Estimated heartbeat rate (BPM), if available.
|
||||
pub heartbeat_bpm: Option<f32>,
|
||||
/// Ground-truth heartbeat rate (BPM), if available.
|
||||
pub gt_heartbeat_bpm: Option<f32>,
|
||||
/// Heartbeat band SNR (dB), if available.
|
||||
pub heartbeat_snr_db: Option<f32>,
|
||||
}
|
||||
|
||||
/// Result of Metric 3 evaluation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VitalSignResult {
|
||||
/// Mean breathing rate error (BPM).
|
||||
pub breathing_error_bpm: f32,
|
||||
/// Mean breathing SNR (dB).
|
||||
pub breathing_snr_db: f32,
|
||||
/// Mean heartbeat rate error (BPM), if measured.
|
||||
pub heartbeat_error_bpm: Option<f32>,
|
||||
/// Mean heartbeat SNR (dB), if measured.
|
||||
pub heartbeat_snr_db: Option<f32>,
|
||||
/// Number of measurements evaluated.
|
||||
pub n_measurements: usize,
|
||||
/// Whether this metric passes.
|
||||
pub passes: bool,
|
||||
}
|
||||
|
||||
/// Evaluate Metric 3: Vital Sign Accuracy.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `measurements`: per-epoch vital sign measurements with GT.
|
||||
/// - `thresholds`: acceptance thresholds.
|
||||
pub fn evaluate_vital_signs(
|
||||
measurements: &[VitalSignMeasurement],
|
||||
thresholds: &VitalSignThresholds,
|
||||
) -> VitalSignResult {
|
||||
let n = measurements.len();
|
||||
if n == 0 {
|
||||
return VitalSignResult {
|
||||
breathing_error_bpm: f32::MAX,
|
||||
breathing_snr_db: 0.0,
|
||||
heartbeat_error_bpm: None,
|
||||
heartbeat_snr_db: None,
|
||||
n_measurements: 0,
|
||||
passes: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Breathing metrics.
|
||||
let breathing_errors: Vec<f32> = measurements
|
||||
.iter()
|
||||
.map(|m| (m.breathing_bpm - m.gt_breathing_bpm).abs())
|
||||
.collect();
|
||||
let breathing_error_mean = breathing_errors.iter().sum::<f32>() / n as f32;
|
||||
let breathing_snr_mean =
|
||||
measurements.iter().map(|m| m.breathing_snr_db).sum::<f32>() / n as f32;
|
||||
|
||||
// Heartbeat metrics (optional).
|
||||
let heartbeat_pairs: Vec<(f32, f32, f32)> = measurements
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
match (m.heartbeat_bpm, m.gt_heartbeat_bpm, m.heartbeat_snr_db) {
|
||||
(Some(hb), Some(gt), Some(snr)) => Some((hb, gt, snr)),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let (heartbeat_error, heartbeat_snr) = if heartbeat_pairs.is_empty() {
|
||||
(None, None)
|
||||
} else {
|
||||
let hb_n = heartbeat_pairs.len() as f32;
|
||||
let err = heartbeat_pairs
|
||||
.iter()
|
||||
.map(|(hb, gt, _)| (hb - gt).abs())
|
||||
.sum::<f32>()
|
||||
/ hb_n;
|
||||
let snr = heartbeat_pairs.iter().map(|(_, _, s)| s).sum::<f32>() / hb_n;
|
||||
(Some(err), Some(snr))
|
||||
};
|
||||
|
||||
// Pass/fail: breathing must pass; heartbeat is aspirational.
|
||||
let breathing_passes = breathing_error_mean <= thresholds.breathing_bpm_tolerance
|
||||
&& breathing_snr_mean >= thresholds.breathing_snr_db;
|
||||
|
||||
let heartbeat_passes = match (heartbeat_error, heartbeat_snr) {
|
||||
(Some(err), Some(snr)) => {
|
||||
err <= thresholds.heartbeat_bpm_tolerance && snr >= thresholds.heartbeat_snr_db
|
||||
}
|
||||
_ => true, // No heartbeat data: aspirational, not required.
|
||||
};
|
||||
|
||||
let passes = breathing_passes && heartbeat_passes;
|
||||
|
||||
VitalSignResult {
|
||||
breathing_error_bpm: breathing_error_mean,
|
||||
breathing_snr_db: breathing_snr_mean,
|
||||
heartbeat_error_bpm: heartbeat_error,
|
||||
heartbeat_snr_db: heartbeat_snr,
|
||||
n_measurements: n,
|
||||
passes,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tiered acceptance
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Combined result of all three metrics with tier determination.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RuViewAcceptanceResult {
|
||||
/// Metric 1: Joint Error.
|
||||
pub joint_error: JointErrorResult,
|
||||
/// Metric 2: Tracking.
|
||||
pub tracking: TrackingResult,
|
||||
/// Metric 3: Vital Signs.
|
||||
pub vital_signs: VitalSignResult,
|
||||
/// Achieved deployment tier.
|
||||
pub tier: RuViewTier,
|
||||
}
|
||||
|
||||
impl RuViewAcceptanceResult {
|
||||
/// A human-readable summary of the acceptance test.
|
||||
pub fn summary(&self) -> String {
|
||||
format!(
|
||||
"RuView Tier={} | PCK={:.3} OKS={:.3} | MOTA={:.3} IDsw={} | Breathing={:.1}BPM err",
|
||||
self.tier,
|
||||
self.joint_error.pck_all,
|
||||
self.joint_error.oks,
|
||||
self.tracking.mota,
|
||||
self.tracking.id_switches,
|
||||
self.vital_signs.breathing_error_bpm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the deployment tier from individual metric results.
|
||||
pub fn determine_tier(
|
||||
joint_error: &JointErrorResult,
|
||||
tracking: &TrackingResult,
|
||||
vital_signs: &VitalSignResult,
|
||||
) -> RuViewTier {
|
||||
if !tracking.passes {
|
||||
return RuViewTier::Fail;
|
||||
}
|
||||
// Bronze: only tracking passes.
|
||||
if !joint_error.passes {
|
||||
return RuViewTier::Bronze;
|
||||
}
|
||||
// Silver: tracking + joint error pass.
|
||||
if !vital_signs.passes {
|
||||
return RuViewTier::Silver;
|
||||
}
|
||||
// Gold: all pass.
|
||||
RuViewTier::Gold
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn compute_bbox_diag(kp: &Array2<f32>, vis: &Array1<f32>) -> f32 {
|
||||
let mut x_min = f32::MAX;
|
||||
let mut x_max = f32::MIN;
|
||||
let mut y_min = f32::MAX;
|
||||
let mut y_max = f32::MIN;
|
||||
let mut any = false;
|
||||
|
||||
for j in 0..17.min(kp.shape()[0]) {
|
||||
if vis[j] >= 0.5 {
|
||||
let x = kp[[j, 0]];
|
||||
let y = kp[[j, 1]];
|
||||
x_min = x_min.min(x);
|
||||
x_max = x_max.max(x);
|
||||
y_min = y_min.min(y);
|
||||
y_max = y_max.max(y);
|
||||
any = true;
|
||||
}
|
||||
}
|
||||
if !any {
|
||||
return 0.0;
|
||||
}
|
||||
let w = (x_max - x_min).max(0.0);
|
||||
let h = (y_max - y_min).max(0.0);
|
||||
(w * w + h * h).sqrt()
|
||||
}
|
||||
|
||||
fn compute_single_oks(pred: &Array2<f32>, gt: &Array2<f32>, vis: &Array1<f32>, s: f32) -> f32 {
|
||||
let s_sq = s * s;
|
||||
let mut num = 0.0_f32;
|
||||
let mut den = 0.0_f32;
|
||||
for j in 0..17 {
|
||||
if vis[j] < 0.5 {
|
||||
continue;
|
||||
}
|
||||
den += 1.0;
|
||||
let dx = pred[[j, 0]] - gt[[j, 0]];
|
||||
let dy = pred[[j, 1]] - gt[[j, 1]];
|
||||
let d_sq = dx * dx + dy * dy;
|
||||
let k = COCO_SIGMAS[j];
|
||||
num += (-d_sq / (2.0 * s_sq * k * k)).exp();
|
||||
}
|
||||
if den > 0.0 { num / den } else { 0.0 }
|
||||
}
|
||||
|
||||
fn compute_torso_jitter(pred_kpts: &[Array2<f32>], visibility: &[Array1<f32>]) -> f32 {
|
||||
if pred_kpts.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute torso centroid per frame.
|
||||
let centroids: Vec<Option<(f32, f32)>> = pred_kpts
|
||||
.iter()
|
||||
.zip(visibility.iter())
|
||||
.map(|(kp, vis)| {
|
||||
let mut cx = 0.0_f32;
|
||||
let mut cy = 0.0_f32;
|
||||
let mut count = 0_usize;
|
||||
for &idx in &TORSO_INDICES {
|
||||
if vis[idx] >= 0.5 {
|
||||
cx += kp[[idx, 0]];
|
||||
cy += kp[[idx, 1]];
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
if count > 0 {
|
||||
Some((cx / count as f32, cy / count as f32))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Frame-to-frame displacement squared.
|
||||
let mut sum_sq = 0.0_f64;
|
||||
let mut n_pairs = 0_usize;
|
||||
for i in 1..centroids.len() {
|
||||
if let (Some((x0, y0)), Some((x1, y1))) = (centroids[i - 1], centroids[i]) {
|
||||
let dx = (x1 - x0) as f64;
|
||||
let dy = (y1 - y0) as f64;
|
||||
sum_sq += dx * dx + dy * dy;
|
||||
n_pairs += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if n_pairs == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
(sum_sq / n_pairs as f64).sqrt() as f32
|
||||
}
|
||||
|
||||
fn compute_p95_max_error(per_kp_errors: &[Vec<f32>]) -> f32 {
|
||||
// Collect all per-keypoint errors, find 95th percentile.
|
||||
let mut all_errors: Vec<f32> = per_kp_errors.iter().flat_map(|e| e.iter().copied()).collect();
|
||||
if all_errors.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
all_errors.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let idx = ((all_errors.len() as f64 * 0.95) as usize).min(all_errors.len() - 1);
|
||||
all_errors[idx]
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ndarray::{Array1, Array2};
|
||||
|
||||
fn make_perfect_kpts() -> (Array2<f32>, Array2<f32>, Array1<f32>) {
|
||||
let kp = Array2::from_shape_fn((17, 2), |(j, d)| {
|
||||
if d == 0 { j as f32 * 0.05 } else { j as f32 * 0.03 }
|
||||
});
|
||||
let vis = Array1::ones(17);
|
||||
(kp.clone(), kp, vis)
|
||||
}
|
||||
|
||||
fn make_noisy_kpts(noise: f32) -> (Array2<f32>, Array2<f32>, Array1<f32>) {
|
||||
let gt = Array2::from_shape_fn((17, 2), |(j, d)| {
|
||||
if d == 0 { j as f32 * 0.03 } else { j as f32 * 0.02 }
|
||||
});
|
||||
let pred = Array2::from_shape_fn((17, 2), |(j, d)| {
|
||||
// Apply deterministic noise that varies per joint so some joints
|
||||
// are definitely outside the PCK threshold.
|
||||
gt[[j, d]] + noise * ((j * 7 + d * 3) as f32).sin()
|
||||
});
|
||||
let vis = Array1::ones(17);
|
||||
(pred, gt, vis)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn joint_error_perfect_predictions_pass() {
|
||||
let (pred, gt, vis) = make_perfect_kpts();
|
||||
let result = evaluate_joint_error(
|
||||
&[pred],
|
||||
&[gt],
|
||||
&[vis],
|
||||
&[1.0],
|
||||
&JointErrorThresholds::default(),
|
||||
);
|
||||
assert_eq!(result.pck_all, 1.0, "perfect predictions should have PCK=1.0");
|
||||
assert!((result.oks - 1.0).abs() < 1e-3, "perfect predictions should have OKS~1.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn joint_error_empty_returns_fail() {
|
||||
let result = evaluate_joint_error(
|
||||
&[],
|
||||
&[],
|
||||
&[],
|
||||
&[],
|
||||
&JointErrorThresholds::default(),
|
||||
);
|
||||
assert!(!result.passes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn joint_error_noisy_predictions_lower_pck() {
|
||||
let (pred, gt, vis) = make_noisy_kpts(0.5);
|
||||
let result = evaluate_joint_error(
|
||||
&[pred],
|
||||
&[gt],
|
||||
&[vis],
|
||||
&[1.0],
|
||||
&JointErrorThresholds::default(),
|
||||
);
|
||||
assert!(result.pck_all < 1.0, "noisy predictions should have PCK < 1.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracking_no_id_switches_pass() {
|
||||
let frames: Vec<TrackingFrame> = (0..100)
|
||||
.map(|i| TrackingFrame {
|
||||
frame_idx: i,
|
||||
gt_ids: vec![1, 2],
|
||||
pred_ids: vec![1, 2],
|
||||
assignments: vec![(1, 1), (2, 2)],
|
||||
})
|
||||
.collect();
|
||||
let result = evaluate_tracking(&frames, 1.0, &TrackingThresholds::default());
|
||||
assert_eq!(result.id_switches, 0);
|
||||
assert!(result.passes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracking_id_switches_detected() {
|
||||
let mut frames: Vec<TrackingFrame> = (0..10)
|
||||
.map(|i| TrackingFrame {
|
||||
frame_idx: i,
|
||||
gt_ids: vec![1, 2],
|
||||
pred_ids: vec![1, 2],
|
||||
assignments: vec![(1, 1), (2, 2)],
|
||||
})
|
||||
.collect();
|
||||
// Swap assignments at frame 5.
|
||||
frames[5].assignments = vec![(2, 1), (1, 2)];
|
||||
let result = evaluate_tracking(&frames, 1.0, &TrackingThresholds::default());
|
||||
assert!(result.id_switches >= 1, "should detect ID switch at frame 5");
|
||||
assert!(!result.passes, "ID switches should cause failure");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracking_empty_returns_fail() {
|
||||
let result = evaluate_tracking(&[], 1.0, &TrackingThresholds::default());
|
||||
assert!(!result.passes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vital_signs_accurate_breathing_passes() {
|
||||
let measurements = vec![
|
||||
VitalSignMeasurement {
|
||||
breathing_bpm: 15.0,
|
||||
gt_breathing_bpm: 14.5,
|
||||
breathing_snr_db: 10.0,
|
||||
heartbeat_bpm: None,
|
||||
gt_heartbeat_bpm: None,
|
||||
heartbeat_snr_db: None,
|
||||
},
|
||||
VitalSignMeasurement {
|
||||
breathing_bpm: 16.0,
|
||||
gt_breathing_bpm: 15.5,
|
||||
breathing_snr_db: 8.0,
|
||||
heartbeat_bpm: None,
|
||||
gt_heartbeat_bpm: None,
|
||||
heartbeat_snr_db: None,
|
||||
},
|
||||
];
|
||||
let result = evaluate_vital_signs(&measurements, &VitalSignThresholds::default());
|
||||
assert!(result.breathing_error_bpm <= 2.0);
|
||||
assert!(result.passes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vital_signs_inaccurate_breathing_fails() {
|
||||
let measurements = vec![VitalSignMeasurement {
|
||||
breathing_bpm: 25.0,
|
||||
gt_breathing_bpm: 15.0,
|
||||
breathing_snr_db: 10.0,
|
||||
heartbeat_bpm: None,
|
||||
gt_heartbeat_bpm: None,
|
||||
heartbeat_snr_db: None,
|
||||
}];
|
||||
let result = evaluate_vital_signs(&measurements, &VitalSignThresholds::default());
|
||||
assert!(!result.passes, "10 BPM error should fail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vital_signs_empty_returns_fail() {
|
||||
let result = evaluate_vital_signs(&[], &VitalSignThresholds::default());
|
||||
assert!(!result.passes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tier_determination_gold() {
|
||||
let je = JointErrorResult {
|
||||
pck_all: 0.85,
|
||||
pck_torso: 0.90,
|
||||
oks: 0.65,
|
||||
jitter_rms_m: 0.01,
|
||||
max_error_p95_m: 0.10,
|
||||
passes: true,
|
||||
};
|
||||
let tr = TrackingResult {
|
||||
id_switches: 0,
|
||||
fragmentation_ratio: 0.01,
|
||||
false_tracks_per_min: 0.0,
|
||||
mota: 0.95,
|
||||
n_frames: 1000,
|
||||
passes: true,
|
||||
};
|
||||
let vs = VitalSignResult {
|
||||
breathing_error_bpm: 1.0,
|
||||
breathing_snr_db: 8.0,
|
||||
heartbeat_error_bpm: Some(3.0),
|
||||
heartbeat_snr_db: Some(4.0),
|
||||
n_measurements: 10,
|
||||
passes: true,
|
||||
};
|
||||
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Gold);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tier_determination_silver() {
|
||||
let je = JointErrorResult { passes: true, ..Default::default() };
|
||||
let tr = TrackingResult { passes: true, ..Default::default() };
|
||||
let vs = VitalSignResult { passes: false, ..Default::default() };
|
||||
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Silver);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tier_determination_bronze() {
|
||||
let je = JointErrorResult { passes: false, ..Default::default() };
|
||||
let tr = TrackingResult { passes: true, ..Default::default() };
|
||||
let vs = VitalSignResult { passes: false, ..Default::default() };
|
||||
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Bronze);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tier_determination_fail() {
|
||||
let je = JointErrorResult { passes: true, ..Default::default() };
|
||||
let tr = TrackingResult { passes: false, ..Default::default() };
|
||||
let vs = VitalSignResult { passes: true, ..Default::default() };
|
||||
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Fail);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tier_ordering() {
|
||||
assert!(RuViewTier::Gold > RuViewTier::Silver);
|
||||
assert!(RuViewTier::Silver > RuViewTier::Bronze);
|
||||
assert!(RuViewTier::Bronze > RuViewTier::Fail);
|
||||
}
|
||||
|
||||
// Implement Default for test convenience.
|
||||
impl Default for JointErrorResult {
|
||||
fn default() -> Self {
|
||||
JointErrorResult {
|
||||
pck_all: 0.0,
|
||||
pck_torso: 0.0,
|
||||
oks: 0.0,
|
||||
jitter_rms_m: 0.0,
|
||||
max_error_p95_m: 0.0,
|
||||
passes: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TrackingResult {
|
||||
fn default() -> Self {
|
||||
TrackingResult {
|
||||
id_switches: 0,
|
||||
fragmentation_ratio: 0.0,
|
||||
false_tracks_per_min: 0.0,
|
||||
mota: 0.0,
|
||||
n_frames: 0,
|
||||
passes: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VitalSignResult {
|
||||
fn default() -> Self {
|
||||
VitalSignResult {
|
||||
breathing_error_bpm: 0.0,
|
||||
breathing_snr_db: 0.0,
|
||||
heartbeat_error_bpm: None,
|
||||
heartbeat_snr_db: None,
|
||||
n_measurements: 0,
|
||||
passes: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -59,7 +59,7 @@ uuid = { version = "1.6", features = ["v4", "serde", "js"] }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
# Optional: wifi-densepose-mat integration
|
||||
wifi-densepose-mat = { version = "0.2.0", path = "../wifi-densepose-mat", optional = true, features = ["serde"] }
|
||||
wifi-densepose-mat = { version = "0.3.0", path = "../wifi-densepose-mat", optional = true, features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
1129
rust-port/wifi-densepose-rs/patches/ruvector-crv/Cargo.lock
generated
Normal file
1129
rust-port/wifi-densepose-rs/patches/ruvector-crv/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
23
rust-port/wifi-densepose-rs/patches/ruvector-crv/Cargo.toml
Normal file
23
rust-port/wifi-densepose-rs/patches/ruvector-crv/Cargo.toml
Normal file
@@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "ruvector-crv"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
authors = ["ruvector contributors"]
|
||||
description = "CRV (Coordinate Remote Viewing) protocol integration for ruvector - maps 6-stage signal line methodology to vector database subsystems"
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/ruvnet/ruvector"
|
||||
|
||||
[lib]
|
||||
name = "ruvector_crv"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
ruvector-attention = "0.1.31"
|
||||
ruvector-gnn = { version = "2.0", default-features = false }
|
||||
ruvector-mincut = { version = "2.0", default-features = false, features = ["exact"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
thiserror = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
28
rust-port/wifi-densepose-rs/patches/ruvector-crv/Cargo.toml.orig
generated
Normal file
28
rust-port/wifi-densepose-rs/patches/ruvector-crv/Cargo.toml.orig
generated
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "ruvector-crv"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
authors = ["ruvector contributors"]
|
||||
description = "CRV (Coordinate Remote Viewing) protocol integration for ruvector - maps 6-stage signal line methodology to vector database subsystems"
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/ruvnet/ruvector"
|
||||
readme = "README.md"
|
||||
keywords = ["crv", "signal-line", "vector-search", "attention", "hyperbolic"]
|
||||
categories = ["algorithms", "science"]
|
||||
|
||||
[lib]
|
||||
crate-type = ["rlib"]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
ruvector-attention = { version = "0.1.31", path = "../ruvector-attention" }
|
||||
ruvector-gnn = { version = "2.0.1", path = "../ruvector-gnn", default-features = false }
|
||||
ruvector-mincut = { version = "2.0.1", path = "../ruvector-mincut", default-features = false, features = ["exact"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
thiserror = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
68
rust-port/wifi-densepose-rs/patches/ruvector-crv/README.md
Normal file
68
rust-port/wifi-densepose-rs/patches/ruvector-crv/README.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# ruvector-crv
|
||||
|
||||
CRV (Coordinate Remote Viewing) protocol integration for ruvector.
|
||||
|
||||
Maps the 6-stage CRV signal line methodology to ruvector's subsystems:
|
||||
|
||||
| CRV Stage | Data Type | ruvector Component |
|
||||
|-----------|-----------|-------------------|
|
||||
| Stage I (Ideograms) | Gestalt primitives | Poincaré ball hyperbolic embeddings |
|
||||
| Stage II (Sensory) | Textures, colors, temps | Multi-head attention vectors |
|
||||
| Stage III (Dimensional) | Spatial sketches | GNN graph topology |
|
||||
| Stage IV (Emotional) | AOL, intangibles | SNN temporal encoding |
|
||||
| Stage V (Interrogation) | Signal line probing | Differentiable search |
|
||||
| Stage VI (3D Model) | Composite model | MinCut partitioning |
|
||||
|
||||
## Quick Start
|
||||
|
||||
```rust
|
||||
use ruvector_crv::{CrvConfig, CrvSessionManager, GestaltType, StageIData};
|
||||
|
||||
// Create session manager with default config (384 dimensions)
|
||||
let config = CrvConfig::default();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
// Create a session for a target coordinate
|
||||
manager.create_session("session-001".to_string(), "1234-5678".to_string()).unwrap();
|
||||
|
||||
// Add Stage I ideogram data
|
||||
let stage_i = StageIData {
|
||||
stroke: vec![(0.0, 0.0), (1.0, 0.5), (2.0, 1.0), (3.0, 0.5)],
|
||||
spontaneous_descriptor: "angular rising".to_string(),
|
||||
classification: GestaltType::Manmade,
|
||||
confidence: 0.85,
|
||||
};
|
||||
|
||||
let embedding = manager.add_stage_i("session-001", &stage_i).unwrap();
|
||||
assert_eq!(embedding.len(), 384);
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
The Poincaré ball embedding for Stage I gestalts encodes the hierarchical
|
||||
gestalt taxonomy (root → manmade/natural/movement/energy/water/land) with
|
||||
exponentially less distortion than Euclidean space.
|
||||
|
||||
For AOL (Analytical Overlay) separation, the spiking neural network temporal
|
||||
encoding models signal-vs-noise discrimination: high-frequency spike bursts
|
||||
correlate with AOL contamination, while sustained low-frequency patterns
|
||||
indicate clean signal line data.
|
||||
|
||||
MinCut partitioning in Stage VI identifies natural cluster boundaries in the
|
||||
accumulated session graph, separating distinct target aspects.
|
||||
|
||||
## Cross-Session Convergence
|
||||
|
||||
Multiple sessions targeting the same coordinate can be analyzed for
|
||||
convergence — agreement between independent viewers strengthens the
|
||||
signal validity:
|
||||
|
||||
```rust
|
||||
// After adding data to multiple sessions for "1234-5678"...
|
||||
let convergence = manager.find_convergence("1234-5678", 0.75).unwrap();
|
||||
// convergence.scores contains similarity values for converging entries
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
@@ -0,0 +1,38 @@
|
||||
//! Error types for the CRV protocol integration.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// CRV-specific errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CrvError {
|
||||
/// Dimension mismatch between expected and actual vector sizes.
|
||||
#[error("Dimension mismatch: expected {expected}, got {actual}")]
|
||||
DimensionMismatch { expected: usize, actual: usize },
|
||||
|
||||
/// Invalid CRV stage number.
|
||||
#[error("Invalid stage: {0} (must be 1-6)")]
|
||||
InvalidStage(u8),
|
||||
|
||||
/// Empty input data.
|
||||
#[error("Empty input: {0}")]
|
||||
EmptyInput(String),
|
||||
|
||||
/// Session not found.
|
||||
#[error("Session not found: {0}")]
|
||||
SessionNotFound(String),
|
||||
|
||||
/// Encoding failure.
|
||||
#[error("Encoding error: {0}")]
|
||||
EncodingError(String),
|
||||
|
||||
/// Attention mechanism error.
|
||||
#[error("Attention error: {0}")]
|
||||
AttentionError(#[from] ruvector_attention::AttentionError),
|
||||
|
||||
/// Serialization error.
|
||||
#[error("Serialization error: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
/// Result type alias for CRV operations.
|
||||
pub type CrvResult<T> = Result<T, CrvError>;
|
||||
178
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/lib.rs
Normal file
178
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/lib.rs
Normal file
@@ -0,0 +1,178 @@
|
||||
//! # ruvector-crv
|
||||
//!
|
||||
//! CRV (Coordinate Remote Viewing) protocol integration for ruvector.
|
||||
//!
|
||||
//! Maps the 6-stage CRV signal line methodology to ruvector's subsystems:
|
||||
//!
|
||||
//! | CRV Stage | Data Type | ruvector Component |
|
||||
//! |-----------|-----------|-------------------|
|
||||
//! | Stage I (Ideograms) | Gestalt primitives | Poincaré ball hyperbolic embeddings |
|
||||
//! | Stage II (Sensory) | Textures, colors, temps | Multi-head attention vectors |
|
||||
//! | Stage III (Dimensional) | Spatial sketches | GNN graph topology |
|
||||
//! | Stage IV (Emotional) | AOL, intangibles | SNN temporal encoding |
|
||||
//! | Stage V (Interrogation) | Signal line probing | Differentiable search |
|
||||
//! | Stage VI (3D Model) | Composite model | MinCut partitioning |
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use ruvector_crv::{CrvConfig, CrvSessionManager, GestaltType, StageIData};
|
||||
//!
|
||||
//! // Create session manager with default config (384 dimensions)
|
||||
//! let config = CrvConfig::default();
|
||||
//! let mut manager = CrvSessionManager::new(config);
|
||||
//!
|
||||
//! // Create a session for a target coordinate
|
||||
//! manager.create_session("session-001".to_string(), "1234-5678".to_string()).unwrap();
|
||||
//!
|
||||
//! // Add Stage I ideogram data
|
||||
//! let stage_i = StageIData {
|
||||
//! stroke: vec![(0.0, 0.0), (1.0, 0.5), (2.0, 1.0), (3.0, 0.5)],
|
||||
//! spontaneous_descriptor: "angular rising".to_string(),
|
||||
//! classification: GestaltType::Manmade,
|
||||
//! confidence: 0.85,
|
||||
//! };
|
||||
//!
|
||||
//! let embedding = manager.add_stage_i("session-001", &stage_i).unwrap();
|
||||
//! assert_eq!(embedding.len(), 384);
|
||||
//! ```
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The Poincaré ball embedding for Stage I gestalts encodes the hierarchical
|
||||
//! gestalt taxonomy (root → manmade/natural/movement/energy/water/land) with
|
||||
//! exponentially less distortion than Euclidean space.
|
||||
//!
|
||||
//! For AOL (Analytical Overlay) separation, the spiking neural network temporal
|
||||
//! encoding models signal-vs-noise discrimination: high-frequency spike bursts
|
||||
//! correlate with AOL contamination, while sustained low-frequency patterns
|
||||
//! indicate clean signal line data.
|
||||
//!
|
||||
//! MinCut partitioning in Stage VI identifies natural cluster boundaries in the
|
||||
//! accumulated session graph, separating distinct target aspects.
|
||||
//!
|
||||
//! ## Cross-Session Convergence
|
||||
//!
|
||||
//! Multiple sessions targeting the same coordinate can be analyzed for
|
||||
//! convergence — agreement between independent viewers strengthens the
|
||||
//! signal validity:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! # use ruvector_crv::{CrvConfig, CrvSessionManager};
|
||||
//! # let mut manager = CrvSessionManager::new(CrvConfig::default());
|
||||
//! // After adding data to multiple sessions for "1234-5678"...
|
||||
//! let convergence = manager.find_convergence("1234-5678", 0.75).unwrap();
|
||||
//! // convergence.scores contains similarity values for converging entries
|
||||
//! ```
|
||||
|
||||
pub mod error;
|
||||
pub mod session;
|
||||
pub mod stage_i;
|
||||
pub mod stage_ii;
|
||||
pub mod stage_iii;
|
||||
pub mod stage_iv;
|
||||
pub mod stage_v;
|
||||
pub mod stage_vi;
|
||||
pub mod types;
|
||||
|
||||
// Re-export main types
|
||||
pub use error::{CrvError, CrvResult};
|
||||
pub use session::CrvSessionManager;
|
||||
pub use stage_i::StageIEncoder;
|
||||
pub use stage_ii::StageIIEncoder;
|
||||
pub use stage_iii::StageIIIEncoder;
|
||||
pub use stage_iv::StageIVEncoder;
|
||||
pub use stage_v::StageVEngine;
|
||||
pub use stage_vi::StageVIModeler;
|
||||
pub use types::{
|
||||
AOLDetection, ConvergenceResult, CrossReference, CrvConfig, CrvSessionEntry,
|
||||
GeometricKind, GestaltType, SensoryModality, SignalLineProbe, SketchElement,
|
||||
SpatialRelationType, SpatialRelationship, StageIData, StageIIData, StageIIIData,
|
||||
StageIVData, StageVData, StageVIData, TargetPartition,
|
||||
};
|
||||
|
||||
/// Library version.
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_version() {
|
||||
assert!(!VERSION.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_end_to_end_session() {
|
||||
let config = CrvConfig {
|
||||
dimensions: 32,
|
||||
..CrvConfig::default()
|
||||
};
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
// Create two sessions for the same coordinate
|
||||
manager
|
||||
.create_session("viewer-a".to_string(), "target-001".to_string())
|
||||
.unwrap();
|
||||
manager
|
||||
.create_session("viewer-b".to_string(), "target-001".to_string())
|
||||
.unwrap();
|
||||
|
||||
// Viewer A: Stage I
|
||||
let s1_a = StageIData {
|
||||
stroke: vec![(0.0, 0.0), (1.0, 1.0), (2.0, 0.5), (3.0, 0.0)],
|
||||
spontaneous_descriptor: "tall angular".to_string(),
|
||||
classification: GestaltType::Manmade,
|
||||
confidence: 0.85,
|
||||
};
|
||||
manager.add_stage_i("viewer-a", &s1_a).unwrap();
|
||||
|
||||
// Viewer B: Stage I (similar gestalt)
|
||||
let s1_b = StageIData {
|
||||
stroke: vec![(0.0, 0.0), (0.5, 1.2), (1.5, 0.8), (2.5, 0.0)],
|
||||
spontaneous_descriptor: "structured upward".to_string(),
|
||||
classification: GestaltType::Manmade,
|
||||
confidence: 0.78,
|
||||
};
|
||||
manager.add_stage_i("viewer-b", &s1_b).unwrap();
|
||||
|
||||
// Viewer A: Stage II
|
||||
let s2_a = StageIIData {
|
||||
impressions: vec![
|
||||
(SensoryModality::Texture, "rough stone".to_string()),
|
||||
(SensoryModality::Temperature, "cool".to_string()),
|
||||
(SensoryModality::Color, "gray".to_string()),
|
||||
],
|
||||
feature_vector: None,
|
||||
};
|
||||
manager.add_stage_ii("viewer-a", &s2_a).unwrap();
|
||||
|
||||
// Viewer B: Stage II (overlapping sensory)
|
||||
let s2_b = StageIIData {
|
||||
impressions: vec![
|
||||
(SensoryModality::Texture, "grainy rough".to_string()),
|
||||
(SensoryModality::Color, "dark gray".to_string()),
|
||||
(SensoryModality::Luminosity, "dim".to_string()),
|
||||
],
|
||||
feature_vector: None,
|
||||
};
|
||||
manager.add_stage_ii("viewer-b", &s2_b).unwrap();
|
||||
|
||||
// Verify entries
|
||||
assert_eq!(manager.session_entry_count("viewer-a"), 2);
|
||||
assert_eq!(manager.session_entry_count("viewer-b"), 2);
|
||||
|
||||
// Both sessions should have embeddings
|
||||
let entries_a = manager.get_session_embeddings("viewer-a").unwrap();
|
||||
let entries_b = manager.get_session_embeddings("viewer-b").unwrap();
|
||||
|
||||
assert_eq!(entries_a.len(), 2);
|
||||
assert_eq!(entries_b.len(), 2);
|
||||
|
||||
// All embeddings should be 32-dimensional
|
||||
for entry in entries_a.iter().chain(entries_b.iter()) {
|
||||
assert_eq!(entry.embedding.len(), 32);
|
||||
}
|
||||
}
|
||||
}
|
||||
629
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/session.rs
Normal file
629
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/session.rs
Normal file
@@ -0,0 +1,629 @@
|
||||
//! CRV Session Manager
|
||||
//!
|
||||
//! Manages CRV sessions as directed acyclic graphs (DAGs), where each session
|
||||
//! progresses through stages I-VI. Provides cross-session convergence analysis
|
||||
//! to find agreement between multiple viewers targeting the same coordinate.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! Each session is a DAG of stage entries. Cross-session convergence is computed
|
||||
//! by finding entries with high embedding similarity across different sessions
|
||||
//! targeting the same coordinate.
|
||||
|
||||
use crate::error::{CrvError, CrvResult};
|
||||
use crate::stage_i::StageIEncoder;
|
||||
use crate::stage_ii::StageIIEncoder;
|
||||
use crate::stage_iii::StageIIIEncoder;
|
||||
use crate::stage_iv::StageIVEncoder;
|
||||
use crate::stage_v::StageVEngine;
|
||||
use crate::stage_vi::StageVIModeler;
|
||||
use crate::types::*;
|
||||
use ruvector_gnn::search::cosine_similarity;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A session entry stored in the session graph.
|
||||
#[derive(Debug, Clone)]
|
||||
struct SessionEntry {
|
||||
/// The stage data embedding.
|
||||
embedding: Vec<f32>,
|
||||
/// Stage number (1-6).
|
||||
stage: u8,
|
||||
/// Entry index within the stage.
|
||||
entry_index: usize,
|
||||
/// Metadata.
|
||||
metadata: HashMap<String, serde_json::Value>,
|
||||
/// Timestamp.
|
||||
timestamp_ms: u64,
|
||||
}
|
||||
|
||||
/// A complete CRV session with all stage data.
|
||||
#[derive(Debug)]
|
||||
struct Session {
|
||||
/// Session identifier.
|
||||
id: SessionId,
|
||||
/// Target coordinate.
|
||||
coordinate: TargetCoordinate,
|
||||
/// Entries organized by stage.
|
||||
entries: Vec<SessionEntry>,
|
||||
}
|
||||
|
||||
/// CRV Session Manager: coordinates all stage encoders and manages sessions.
|
||||
#[derive(Debug)]
|
||||
pub struct CrvSessionManager {
|
||||
/// Configuration.
|
||||
config: CrvConfig,
|
||||
/// Stage I encoder.
|
||||
stage_i: StageIEncoder,
|
||||
/// Stage II encoder.
|
||||
stage_ii: StageIIEncoder,
|
||||
/// Stage III encoder.
|
||||
stage_iii: StageIIIEncoder,
|
||||
/// Stage IV encoder.
|
||||
stage_iv: StageIVEncoder,
|
||||
/// Stage V engine.
|
||||
stage_v: StageVEngine,
|
||||
/// Stage VI modeler.
|
||||
stage_vi: StageVIModeler,
|
||||
/// Active sessions indexed by session ID.
|
||||
sessions: HashMap<SessionId, Session>,
|
||||
}
|
||||
|
||||
impl CrvSessionManager {
|
||||
/// Create a new session manager with the given configuration.
|
||||
pub fn new(config: CrvConfig) -> Self {
|
||||
let stage_i = StageIEncoder::new(&config);
|
||||
let stage_ii = StageIIEncoder::new(&config);
|
||||
let stage_iii = StageIIIEncoder::new(&config);
|
||||
let stage_iv = StageIVEncoder::new(&config);
|
||||
let stage_v = StageVEngine::new(&config);
|
||||
let stage_vi = StageVIModeler::new(&config);
|
||||
|
||||
Self {
|
||||
config,
|
||||
stage_i,
|
||||
stage_ii,
|
||||
stage_iii,
|
||||
stage_iv,
|
||||
stage_v,
|
||||
stage_vi,
|
||||
sessions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session for a given target coordinate.
|
||||
pub fn create_session(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
coordinate: TargetCoordinate,
|
||||
) -> CrvResult<()> {
|
||||
if self.sessions.contains_key(&session_id) {
|
||||
return Err(CrvError::EncodingError(format!(
|
||||
"Session {} already exists",
|
||||
session_id
|
||||
)));
|
||||
}
|
||||
|
||||
self.sessions.insert(
|
||||
session_id.clone(),
|
||||
Session {
|
||||
id: session_id,
|
||||
coordinate,
|
||||
entries: Vec::new(),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add Stage I data to a session.
|
||||
pub fn add_stage_i(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
data: &StageIData,
|
||||
) -> CrvResult<Vec<f32>> {
|
||||
let embedding = self.stage_i.encode(data)?;
|
||||
self.add_entry(session_id, 1, embedding.clone(), HashMap::new())?;
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Add Stage II data to a session.
|
||||
pub fn add_stage_ii(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
data: &StageIIData,
|
||||
) -> CrvResult<Vec<f32>> {
|
||||
let embedding = self.stage_ii.encode(data)?;
|
||||
self.add_entry(session_id, 2, embedding.clone(), HashMap::new())?;
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Add Stage III data to a session.
|
||||
pub fn add_stage_iii(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
data: &StageIIIData,
|
||||
) -> CrvResult<Vec<f32>> {
|
||||
let embedding = self.stage_iii.encode(data)?;
|
||||
self.add_entry(session_id, 3, embedding.clone(), HashMap::new())?;
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Add Stage IV data to a session.
|
||||
pub fn add_stage_iv(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
data: &StageIVData,
|
||||
) -> CrvResult<Vec<f32>> {
|
||||
let embedding = self.stage_iv.encode(data)?;
|
||||
self.add_entry(session_id, 4, embedding.clone(), HashMap::new())?;
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Run Stage V interrogation on a session.
|
||||
///
|
||||
/// Probes the accumulated session data with specified queries.
|
||||
pub fn run_stage_v(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
probe_queries: &[(&str, u8, Vec<f32>)], // (query text, target stage, query embedding)
|
||||
k: usize,
|
||||
) -> CrvResult<StageVData> {
|
||||
let session = self
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| CrvError::SessionNotFound(session_id.to_string()))?;
|
||||
|
||||
let all_embeddings: Vec<Vec<f32>> =
|
||||
session.entries.iter().map(|e| e.embedding.clone()).collect();
|
||||
|
||||
let mut probes = Vec::new();
|
||||
let mut cross_refs = Vec::new();
|
||||
|
||||
for (query_text, target_stage, query_emb) in probe_queries {
|
||||
// Filter candidates to the target stage
|
||||
let stage_entries: Vec<Vec<f32>> = session
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|e| e.stage == *target_stage)
|
||||
.map(|e| e.embedding.clone())
|
||||
.collect();
|
||||
|
||||
if stage_entries.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut probe = self.stage_v.probe(query_emb, &stage_entries, k)?;
|
||||
probe.query = query_text.to_string();
|
||||
probe.target_stage = *target_stage;
|
||||
probes.push(probe);
|
||||
}
|
||||
|
||||
// Cross-reference between all stage pairs
|
||||
for from_stage in 1..=4u8 {
|
||||
for to_stage in (from_stage + 1)..=4u8 {
|
||||
let from_entries: Vec<Vec<f32>> = session
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|e| e.stage == from_stage)
|
||||
.map(|e| e.embedding.clone())
|
||||
.collect();
|
||||
let to_entries: Vec<Vec<f32>> = session
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|e| e.stage == to_stage)
|
||||
.map(|e| e.embedding.clone())
|
||||
.collect();
|
||||
|
||||
if !from_entries.is_empty() && !to_entries.is_empty() {
|
||||
let refs = self.stage_v.cross_reference(
|
||||
from_stage,
|
||||
&from_entries,
|
||||
to_stage,
|
||||
&to_entries,
|
||||
self.config.convergence_threshold,
|
||||
);
|
||||
cross_refs.extend(refs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let stage_v_data = StageVData {
|
||||
probes,
|
||||
cross_references: cross_refs,
|
||||
};
|
||||
|
||||
// Encode Stage V result and add to session
|
||||
if !stage_v_data.probes.is_empty() {
|
||||
let embedding = self.stage_v.encode(&stage_v_data, &all_embeddings)?;
|
||||
self.add_entry(session_id, 5, embedding, HashMap::new())?;
|
||||
}
|
||||
|
||||
Ok(stage_v_data)
|
||||
}
|
||||
|
||||
/// Run Stage VI composite modeling on a session.
|
||||
pub fn run_stage_vi(&mut self, session_id: &str) -> CrvResult<StageVIData> {
|
||||
let session = self
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| CrvError::SessionNotFound(session_id.to_string()))?;
|
||||
|
||||
let embeddings: Vec<Vec<f32>> =
|
||||
session.entries.iter().map(|e| e.embedding.clone()).collect();
|
||||
let labels: Vec<(u8, usize)> = session
|
||||
.entries
|
||||
.iter()
|
||||
.map(|e| (e.stage, e.entry_index))
|
||||
.collect();
|
||||
|
||||
let stage_vi_data = self.stage_vi.partition(&embeddings, &labels)?;
|
||||
|
||||
// Encode Stage VI result and add to session
|
||||
let embedding = self.stage_vi.encode(&stage_vi_data)?;
|
||||
self.add_entry(session_id, 6, embedding, HashMap::new())?;
|
||||
|
||||
Ok(stage_vi_data)
|
||||
}
|
||||
|
||||
/// Find convergence across multiple sessions targeting the same coordinate.
|
||||
///
|
||||
/// This is the core multi-viewer matching operation: given sessions from
|
||||
/// different viewers targeting the same coordinate, find which aspects
|
||||
/// of their signal line data converge (agree).
|
||||
pub fn find_convergence(
|
||||
&self,
|
||||
coordinate: &str,
|
||||
min_similarity: f32,
|
||||
) -> CrvResult<ConvergenceResult> {
|
||||
// Collect all sessions for this coordinate
|
||||
let relevant_sessions: Vec<&Session> = self
|
||||
.sessions
|
||||
.values()
|
||||
.filter(|s| s.coordinate == coordinate)
|
||||
.collect();
|
||||
|
||||
if relevant_sessions.len() < 2 {
|
||||
return Err(CrvError::EmptyInput(
|
||||
"Need at least 2 sessions for convergence analysis".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut session_pairs = Vec::new();
|
||||
let mut scores = Vec::new();
|
||||
let mut convergent_stages = Vec::new();
|
||||
|
||||
// Compare all pairs of sessions
|
||||
for i in 0..relevant_sessions.len() {
|
||||
for j in (i + 1)..relevant_sessions.len() {
|
||||
let sess_a = relevant_sessions[i];
|
||||
let sess_b = relevant_sessions[j];
|
||||
|
||||
// Compare stage-by-stage
|
||||
for stage in 1..=6u8 {
|
||||
let entries_a: Vec<&[f32]> = sess_a
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|e| e.stage == stage)
|
||||
.map(|e| e.embedding.as_slice())
|
||||
.collect();
|
||||
let entries_b: Vec<&[f32]> = sess_b
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|e| e.stage == stage)
|
||||
.map(|e| e.embedding.as_slice())
|
||||
.collect();
|
||||
|
||||
if entries_a.is_empty() || entries_b.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find best match for each entry in A against entries in B
|
||||
for emb_a in &entries_a {
|
||||
for emb_b in &entries_b {
|
||||
if emb_a.len() == emb_b.len() && !emb_a.is_empty() {
|
||||
let sim = cosine_similarity(emb_a, emb_b);
|
||||
if sim >= min_similarity {
|
||||
session_pairs
|
||||
.push((sess_a.id.clone(), sess_b.id.clone()));
|
||||
scores.push(sim);
|
||||
if !convergent_stages.contains(&stage) {
|
||||
convergent_stages.push(stage);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute consensus embedding (mean of all converging embeddings)
|
||||
let consensus_embedding = if !scores.is_empty() {
|
||||
let mut consensus = vec![0.0f32; self.config.dimensions];
|
||||
let mut count = 0usize;
|
||||
|
||||
for session in &relevant_sessions {
|
||||
for entry in &session.entries {
|
||||
if convergent_stages.contains(&entry.stage) {
|
||||
for (i, &v) in entry.embedding.iter().enumerate() {
|
||||
if i < self.config.dimensions {
|
||||
consensus[i] += v;
|
||||
}
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
for v in &mut consensus {
|
||||
*v /= count as f32;
|
||||
}
|
||||
Some(consensus)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Sort convergent stages
|
||||
convergent_stages.sort();
|
||||
|
||||
Ok(ConvergenceResult {
|
||||
session_pairs,
|
||||
scores,
|
||||
convergent_stages,
|
||||
consensus_embedding,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get all embeddings for a session.
|
||||
pub fn get_session_embeddings(&self, session_id: &str) -> CrvResult<Vec<CrvSessionEntry>> {
|
||||
let session = self
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| CrvError::SessionNotFound(session_id.to_string()))?;
|
||||
|
||||
Ok(session
|
||||
.entries
|
||||
.iter()
|
||||
.map(|e| CrvSessionEntry {
|
||||
session_id: session.id.clone(),
|
||||
coordinate: session.coordinate.clone(),
|
||||
stage: e.stage,
|
||||
embedding: e.embedding.clone(),
|
||||
metadata: e.metadata.clone(),
|
||||
timestamp_ms: e.timestamp_ms,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Get the number of entries in a session.
|
||||
pub fn session_entry_count(&self, session_id: &str) -> usize {
|
||||
self.sessions
|
||||
.get(session_id)
|
||||
.map(|s| s.entries.len())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get the number of active sessions.
|
||||
pub fn session_count(&self) -> usize {
|
||||
self.sessions.len()
|
||||
}
|
||||
|
||||
/// Remove a session.
|
||||
pub fn remove_session(&mut self, session_id: &str) -> bool {
|
||||
self.sessions.remove(session_id).is_some()
|
||||
}
|
||||
|
||||
/// Get access to the Stage I encoder for direct operations.
|
||||
pub fn stage_i_encoder(&self) -> &StageIEncoder {
|
||||
&self.stage_i
|
||||
}
|
||||
|
||||
/// Get access to the Stage II encoder for direct operations.
|
||||
pub fn stage_ii_encoder(&self) -> &StageIIEncoder {
|
||||
&self.stage_ii
|
||||
}
|
||||
|
||||
/// Get access to the Stage IV encoder for direct operations.
|
||||
pub fn stage_iv_encoder(&self) -> &StageIVEncoder {
|
||||
&self.stage_iv
|
||||
}
|
||||
|
||||
/// Get access to the Stage V engine for direct operations.
|
||||
pub fn stage_v_engine(&self) -> &StageVEngine {
|
||||
&self.stage_v
|
||||
}
|
||||
|
||||
/// Get access to the Stage VI modeler for direct operations.
|
||||
pub fn stage_vi_modeler(&self) -> &StageVIModeler {
|
||||
&self.stage_vi
|
||||
}
|
||||
|
||||
/// Internal: add an entry to a session.
|
||||
fn add_entry(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
stage: u8,
|
||||
embedding: Vec<f32>,
|
||||
metadata: HashMap<String, serde_json::Value>,
|
||||
) -> CrvResult<()> {
|
||||
let session = self
|
||||
.sessions
|
||||
.get_mut(session_id)
|
||||
.ok_or_else(|| CrvError::SessionNotFound(session_id.to_string()))?;
|
||||
|
||||
let entry_index = session.entries.iter().filter(|e| e.stage == stage).count();
|
||||
|
||||
session.entries.push(SessionEntry {
|
||||
embedding,
|
||||
stage,
|
||||
entry_index,
|
||||
metadata,
|
||||
timestamp_ms: 0,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> CrvConfig {
|
||||
CrvConfig {
|
||||
dimensions: 32,
|
||||
convergence_threshold: 0.5,
|
||||
..CrvConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_creation() {
|
||||
let config = test_config();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
manager
|
||||
.create_session("sess-1".to_string(), "1234-5678".to_string())
|
||||
.unwrap();
|
||||
assert_eq!(manager.session_count(), 1);
|
||||
assert_eq!(manager.session_entry_count("sess-1"), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_stage_i() {
|
||||
let config = test_config();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
manager
|
||||
.create_session("sess-1".to_string(), "1234-5678".to_string())
|
||||
.unwrap();
|
||||
|
||||
let data = StageIData {
|
||||
stroke: vec![(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)],
|
||||
spontaneous_descriptor: "angular".to_string(),
|
||||
classification: GestaltType::Manmade,
|
||||
confidence: 0.9,
|
||||
};
|
||||
|
||||
let emb = manager.add_stage_i("sess-1", &data).unwrap();
|
||||
assert_eq!(emb.len(), 32);
|
||||
assert_eq!(manager.session_entry_count("sess-1"), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_stage_ii() {
|
||||
let config = test_config();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
manager
|
||||
.create_session("sess-1".to_string(), "coord-1".to_string())
|
||||
.unwrap();
|
||||
|
||||
let data = StageIIData {
|
||||
impressions: vec![
|
||||
(SensoryModality::Texture, "rough".to_string()),
|
||||
(SensoryModality::Color, "gray".to_string()),
|
||||
],
|
||||
feature_vector: None,
|
||||
};
|
||||
|
||||
let emb = manager.add_stage_ii("sess-1", &data).unwrap();
|
||||
assert_eq!(emb.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_session_flow() {
|
||||
let config = test_config();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
manager
|
||||
.create_session("sess-1".to_string(), "coord-1".to_string())
|
||||
.unwrap();
|
||||
|
||||
// Stage I
|
||||
let s1 = StageIData {
|
||||
stroke: vec![(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)],
|
||||
spontaneous_descriptor: "angular".to_string(),
|
||||
classification: GestaltType::Manmade,
|
||||
confidence: 0.9,
|
||||
};
|
||||
manager.add_stage_i("sess-1", &s1).unwrap();
|
||||
|
||||
// Stage II
|
||||
let s2 = StageIIData {
|
||||
impressions: vec![
|
||||
(SensoryModality::Texture, "rough stone".to_string()),
|
||||
(SensoryModality::Temperature, "cold".to_string()),
|
||||
],
|
||||
feature_vector: None,
|
||||
};
|
||||
manager.add_stage_ii("sess-1", &s2).unwrap();
|
||||
|
||||
// Stage IV
|
||||
let s4 = StageIVData {
|
||||
emotional_impact: vec![("solemn".to_string(), 0.6)],
|
||||
tangibles: vec!["stone blocks".to_string()],
|
||||
intangibles: vec!["ancient".to_string()],
|
||||
aol_detections: vec![],
|
||||
};
|
||||
manager.add_stage_iv("sess-1", &s4).unwrap();
|
||||
|
||||
assert_eq!(manager.session_entry_count("sess-1"), 3);
|
||||
|
||||
// Get all entries
|
||||
let entries = manager.get_session_embeddings("sess-1").unwrap();
|
||||
assert_eq!(entries.len(), 3);
|
||||
assert_eq!(entries[0].stage, 1);
|
||||
assert_eq!(entries[1].stage, 2);
|
||||
assert_eq!(entries[2].stage, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_session() {
|
||||
let config = test_config();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
manager
|
||||
.create_session("sess-1".to_string(), "coord-1".to_string())
|
||||
.unwrap();
|
||||
|
||||
let result = manager.create_session("sess-1".to_string(), "coord-2".to_string());
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_not_found() {
|
||||
let config = test_config();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
let s1 = StageIData {
|
||||
stroke: vec![(0.0, 0.0), (1.0, 1.0)],
|
||||
spontaneous_descriptor: "test".to_string(),
|
||||
classification: GestaltType::Natural,
|
||||
confidence: 0.5,
|
||||
};
|
||||
|
||||
let result = manager.add_stage_i("nonexistent", &s1);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_session() {
|
||||
let config = test_config();
|
||||
let mut manager = CrvSessionManager::new(config);
|
||||
|
||||
manager
|
||||
.create_session("sess-1".to_string(), "coord-1".to_string())
|
||||
.unwrap();
|
||||
assert_eq!(manager.session_count(), 1);
|
||||
|
||||
assert!(manager.remove_session("sess-1"));
|
||||
assert_eq!(manager.session_count(), 0);
|
||||
|
||||
assert!(!manager.remove_session("sess-1"));
|
||||
}
|
||||
}
|
||||
364
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_i.rs
Normal file
364
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_i.rs
Normal file
@@ -0,0 +1,364 @@
|
||||
//! Stage I Encoder: Ideogram Gestalts via Poincaré Ball Embeddings
|
||||
//!
|
||||
//! CRV Stage I captures gestalt primitives (manmade, natural, movement, energy,
|
||||
//! water, land) through ideogram traces. The hierarchical taxonomy of gestalts
|
||||
//! maps naturally to hyperbolic space, where the Poincaré ball model encodes
|
||||
//! tree-like structures with exponentially less distortion than Euclidean space.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! Ideogram stroke traces are converted to fixed-dimension feature vectors,
|
||||
//! then projected into the Poincaré ball. Gestalt classification uses hyperbolic
|
||||
//! distance to prototype embeddings for each gestalt type.
|
||||
|
||||
use crate::error::{CrvError, CrvResult};
|
||||
use crate::types::{CrvConfig, GestaltType, StageIData};
|
||||
use ruvector_attention::hyperbolic::{
|
||||
exp_map, frechet_mean, log_map, mobius_add, poincare_distance, project_to_ball,
|
||||
};
|
||||
|
||||
/// Stage I encoder using Poincaré ball hyperbolic embeddings.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StageIEncoder {
|
||||
/// Embedding dimensionality.
|
||||
dim: usize,
|
||||
/// Poincaré ball curvature (positive).
|
||||
curvature: f32,
|
||||
/// Prototype embeddings for each gestalt type in the Poincaré ball.
|
||||
/// Indexed by `GestaltType::index()`.
|
||||
prototypes: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
impl StageIEncoder {
|
||||
/// Create a new Stage I encoder with default gestalt prototypes.
|
||||
pub fn new(config: &CrvConfig) -> Self {
|
||||
let dim = config.dimensions;
|
||||
let curvature = config.curvature;
|
||||
|
||||
// Initialize gestalt prototypes as points in the Poincaré ball.
|
||||
// Each prototype is placed at a distinct region of the ball,
|
||||
// with hierarchical relationships preserved by hyperbolic distance.
|
||||
let prototypes = Self::init_prototypes(dim, curvature);
|
||||
|
||||
Self {
|
||||
dim,
|
||||
curvature,
|
||||
prototypes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize gestalt prototype embeddings in the Poincaré ball.
|
||||
///
|
||||
/// Places each gestalt type at a distinct angular position with
|
||||
/// controlled radial distance from the origin. The hierarchical
|
||||
/// structure (root → gestalt types → sub-types) is preserved
|
||||
/// by the exponential volume growth of hyperbolic space.
|
||||
fn init_prototypes(dim: usize, curvature: f32) -> Vec<Vec<f32>> {
|
||||
let num_types = GestaltType::all().len();
|
||||
let mut prototypes = Vec::with_capacity(num_types);
|
||||
|
||||
for gestalt in GestaltType::all() {
|
||||
let idx = gestalt.index();
|
||||
// Place each prototype along a different axis direction
|
||||
// with a moderate radial distance (0.3-0.5 of ball radius).
|
||||
let mut proto = vec![0.0f32; dim];
|
||||
|
||||
// Use multiple dimensions to spread prototypes apart
|
||||
let base_dim = idx * (dim / num_types);
|
||||
let spread = dim / num_types;
|
||||
|
||||
for d in 0..spread.min(dim - base_dim) {
|
||||
let angle = std::f32::consts::PI * 2.0 * (d as f32) / (spread as f32);
|
||||
proto[base_dim + d] = 0.3 * angle.cos() / (spread as f32).sqrt();
|
||||
}
|
||||
|
||||
// Project to ball to ensure it's inside
|
||||
proto = project_to_ball(&proto, curvature, 1e-7);
|
||||
prototypes.push(proto);
|
||||
}
|
||||
|
||||
prototypes
|
||||
}
|
||||
|
||||
/// Encode an ideogram stroke trace into a fixed-dimension feature vector.
|
||||
///
|
||||
/// Extracts geometric features from the stroke: curvature statistics,
|
||||
/// velocity profile, angular distribution, and bounding box ratios.
|
||||
pub fn encode_stroke(&self, stroke: &[(f32, f32)]) -> CrvResult<Vec<f32>> {
|
||||
if stroke.is_empty() {
|
||||
return Err(CrvError::EmptyInput("Stroke trace is empty".to_string()));
|
||||
}
|
||||
|
||||
let mut features = vec![0.0f32; self.dim];
|
||||
|
||||
// Feature 1: Stroke statistics (first few dimensions)
|
||||
let n = stroke.len() as f32;
|
||||
let (cx, cy) = stroke
|
||||
.iter()
|
||||
.fold((0.0, 0.0), |(sx, sy), &(x, y)| (sx + x, sy + y));
|
||||
features[0] = cx / n; // centroid x
|
||||
features[1] = cy / n; // centroid y
|
||||
|
||||
// Feature 2: Bounding box aspect ratio
|
||||
let (min_x, max_x) = stroke
|
||||
.iter()
|
||||
.map(|p| p.0)
|
||||
.fold((f32::MAX, f32::MIN), |(mn, mx), v| (mn.min(v), mx.max(v)));
|
||||
let (min_y, max_y) = stroke
|
||||
.iter()
|
||||
.map(|p| p.1)
|
||||
.fold((f32::MAX, f32::MIN), |(mn, mx), v| (mn.min(v), mx.max(v)));
|
||||
let width = (max_x - min_x).max(1e-6);
|
||||
let height = (max_y - min_y).max(1e-6);
|
||||
features[2] = width / height; // aspect ratio
|
||||
|
||||
// Feature 3: Total path length (normalized)
|
||||
let mut path_length = 0.0f32;
|
||||
for i in 1..stroke.len() {
|
||||
let dx = stroke[i].0 - stroke[i - 1].0;
|
||||
let dy = stroke[i].1 - stroke[i - 1].1;
|
||||
path_length += (dx * dx + dy * dy).sqrt();
|
||||
}
|
||||
features[3] = path_length / (width + height).max(1e-6);
|
||||
|
||||
// Feature 4: Angular distribution (segment angles)
|
||||
if stroke.len() >= 3 {
|
||||
let num_angle_bins = 8.min(self.dim.saturating_sub(4));
|
||||
for i in 1..stroke.len().saturating_sub(1) {
|
||||
let dx1 = stroke[i].0 - stroke[i - 1].0;
|
||||
let dy1 = stroke[i].1 - stroke[i - 1].1;
|
||||
let dx2 = stroke[i + 1].0 - stroke[i].0;
|
||||
let dy2 = stroke[i + 1].1 - stroke[i].1;
|
||||
let angle = dy1.atan2(dx1) - dy2.atan2(dx2);
|
||||
let bin = ((angle + std::f32::consts::PI) / (2.0 * std::f32::consts::PI)
|
||||
* num_angle_bins as f32) as usize;
|
||||
let bin = bin.min(num_angle_bins - 1);
|
||||
if 4 + bin < self.dim {
|
||||
features[4 + bin] += 1.0 / (stroke.len() as f32 - 2.0).max(1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Feature 5: Curvature variance (spread across remaining dimensions)
|
||||
if stroke.len() >= 3 {
|
||||
let mut curvatures = Vec::new();
|
||||
for i in 1..stroke.len() - 1 {
|
||||
let dx1 = stroke[i].0 - stroke[i - 1].0;
|
||||
let dy1 = stroke[i].1 - stroke[i - 1].1;
|
||||
let dx2 = stroke[i + 1].0 - stroke[i].0;
|
||||
let dy2 = stroke[i + 1].1 - stroke[i].1;
|
||||
let cross = dx1 * dy2 - dy1 * dx2;
|
||||
let ds1 = (dx1 * dx1 + dy1 * dy1).sqrt().max(1e-6);
|
||||
let ds2 = (dx2 * dx2 + dy2 * dy2).sqrt().max(1e-6);
|
||||
curvatures.push(cross / (ds1 * ds2));
|
||||
}
|
||||
if !curvatures.is_empty() {
|
||||
let mean_k: f32 = curvatures.iter().sum::<f32>() / curvatures.len() as f32;
|
||||
let var_k: f32 = curvatures.iter().map(|k| (k - mean_k).powi(2)).sum::<f32>()
|
||||
/ curvatures.len() as f32;
|
||||
if 12 < self.dim {
|
||||
features[12] = mean_k;
|
||||
}
|
||||
if 13 < self.dim {
|
||||
features[13] = var_k;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize the feature vector
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-6 {
|
||||
let scale = 0.4 / norm; // keep within ball
|
||||
for f in &mut features {
|
||||
*f *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(features)
|
||||
}
|
||||
|
||||
/// Encode complete Stage I data into a Poincaré ball embedding.
|
||||
///
|
||||
/// Combines stroke features with the gestalt prototype via Möbius addition,
|
||||
/// producing a vector that encodes both the raw ideogram trace and its
|
||||
/// gestalt classification in hyperbolic space.
|
||||
pub fn encode(&self, data: &StageIData) -> CrvResult<Vec<f32>> {
|
||||
let stroke_features = self.encode_stroke(&data.stroke)?;
|
||||
|
||||
// Get the prototype for the classified gestalt type
|
||||
let prototype = &self.prototypes[data.classification.index()];
|
||||
|
||||
// Combine stroke features with gestalt prototype via Möbius addition.
|
||||
// This places the encoded vector near the gestalt prototype in
|
||||
// hyperbolic space, with the stroke features providing the offset.
|
||||
let combined = mobius_add(&stroke_features, prototype, self.curvature);
|
||||
|
||||
// Weight by confidence
|
||||
let weighted: Vec<f32> = combined
|
||||
.iter()
|
||||
.map(|&v| v * data.confidence + stroke_features[0] * (1.0 - data.confidence))
|
||||
.collect();
|
||||
|
||||
Ok(project_to_ball(&weighted, self.curvature, 1e-7))
|
||||
}
|
||||
|
||||
/// Classify a stroke embedding into a gestalt type by finding the
|
||||
/// nearest prototype in hyperbolic space.
|
||||
pub fn classify(&self, embedding: &[f32]) -> CrvResult<(GestaltType, f32)> {
|
||||
if embedding.len() != self.dim {
|
||||
return Err(CrvError::DimensionMismatch {
|
||||
expected: self.dim,
|
||||
actual: embedding.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut best_type = GestaltType::Manmade;
|
||||
let mut best_distance = f32::MAX;
|
||||
|
||||
for gestalt in GestaltType::all() {
|
||||
let proto = &self.prototypes[gestalt.index()];
|
||||
let dist = poincare_distance(embedding, proto, self.curvature);
|
||||
if dist < best_distance {
|
||||
best_distance = dist;
|
||||
best_type = *gestalt;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert distance to confidence (closer = higher confidence)
|
||||
let confidence = (-best_distance).exp();
|
||||
|
||||
Ok((best_type, confidence))
|
||||
}
|
||||
|
||||
/// Compute the Fréchet mean of multiple Stage I embeddings.
|
||||
///
|
||||
/// Useful for finding the consensus gestalt across multiple sessions
|
||||
/// targeting the same coordinate.
|
||||
pub fn consensus(&self, embeddings: &[&[f32]]) -> CrvResult<Vec<f32>> {
|
||||
if embeddings.is_empty() {
|
||||
return Err(CrvError::EmptyInput(
|
||||
"No embeddings for consensus".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(frechet_mean(embeddings, None, self.curvature, 50, 1e-5))
|
||||
}
|
||||
|
||||
/// Compute pairwise hyperbolic distance between two Stage I embeddings.
|
||||
pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
poincare_distance(a, b, self.curvature)
|
||||
}
|
||||
|
||||
/// Get the prototype embedding for a gestalt type.
|
||||
pub fn prototype(&self, gestalt: GestaltType) -> &[f32] {
|
||||
&self.prototypes[gestalt.index()]
|
||||
}
|
||||
|
||||
/// Map an embedding to tangent space at the origin for Euclidean operations.
|
||||
pub fn to_tangent(&self, embedding: &[f32]) -> Vec<f32> {
|
||||
let origin = vec![0.0f32; self.dim];
|
||||
log_map(embedding, &origin, self.curvature)
|
||||
}
|
||||
|
||||
/// Map a tangent vector back to the Poincaré ball.
|
||||
pub fn from_tangent(&self, tangent: &[f32]) -> Vec<f32> {
|
||||
let origin = vec![0.0f32; self.dim];
|
||||
exp_map(tangent, &origin, self.curvature)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> CrvConfig {
|
||||
CrvConfig {
|
||||
dimensions: 32,
|
||||
curvature: 1.0,
|
||||
..CrvConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoder_creation() {
|
||||
let config = test_config();
|
||||
let encoder = StageIEncoder::new(&config);
|
||||
assert_eq!(encoder.dim, 32);
|
||||
assert_eq!(encoder.prototypes.len(), 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stroke_encoding() {
|
||||
let config = test_config();
|
||||
let encoder = StageIEncoder::new(&config);
|
||||
|
||||
let stroke = vec![(0.0, 0.0), (1.0, 0.5), (2.0, 1.0), (3.0, 0.5), (4.0, 0.0)];
|
||||
let embedding = encoder.encode_stroke(&stroke).unwrap();
|
||||
assert_eq!(embedding.len(), 32);
|
||||
|
||||
// Should be inside the Poincaré ball
|
||||
let norm_sq: f32 = embedding.iter().map(|x| x * x).sum();
|
||||
assert!(norm_sq < 1.0 / config.curvature);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_encode() {
|
||||
let config = test_config();
|
||||
let encoder = StageIEncoder::new(&config);
|
||||
|
||||
let data = StageIData {
|
||||
stroke: vec![(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)],
|
||||
spontaneous_descriptor: "angular".to_string(),
|
||||
classification: GestaltType::Manmade,
|
||||
confidence: 0.9,
|
||||
};
|
||||
|
||||
let embedding = encoder.encode(&data).unwrap();
|
||||
assert_eq!(embedding.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classification() {
|
||||
let config = test_config();
|
||||
let encoder = StageIEncoder::new(&config);
|
||||
|
||||
// Encode and classify should round-trip for strong prototypes
|
||||
let proto = encoder.prototype(GestaltType::Energy).to_vec();
|
||||
let (classified, confidence) = encoder.classify(&proto).unwrap();
|
||||
assert_eq!(classified, GestaltType::Energy);
|
||||
assert!(confidence > 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_distance_symmetry() {
|
||||
let config = test_config();
|
||||
let encoder = StageIEncoder::new(&config);
|
||||
|
||||
let a = encoder.prototype(GestaltType::Manmade);
|
||||
let b = encoder.prototype(GestaltType::Natural);
|
||||
|
||||
let d_ab = encoder.distance(a, b);
|
||||
let d_ba = encoder.distance(b, a);
|
||||
|
||||
assert!((d_ab - d_ba).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tangent_roundtrip() {
|
||||
let config = test_config();
|
||||
let encoder = StageIEncoder::new(&config);
|
||||
|
||||
let proto = encoder.prototype(GestaltType::Water).to_vec();
|
||||
let tangent = encoder.to_tangent(&proto);
|
||||
let recovered = encoder.from_tangent(&tangent);
|
||||
|
||||
// Should approximately round-trip
|
||||
let error: f32 = proto
|
||||
.iter()
|
||||
.zip(&recovered)
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.sum::<f32>()
|
||||
/ proto.len() as f32;
|
||||
assert!(error < 0.1);
|
||||
}
|
||||
}
|
||||
268
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_ii.rs
Normal file
268
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_ii.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
//! Stage II Encoder: Sensory Data via Multi-Head Attention Vectors
|
||||
//!
|
||||
//! CRV Stage II captures sensory impressions (textures, colors, temperatures,
|
||||
//! sounds, etc.). Each sensory modality is encoded as a separate attention head,
|
||||
//! with the multi-head mechanism combining them into a unified 384-dimensional
|
||||
//! representation.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! Sensory descriptors are hashed into feature vectors per modality, then
|
||||
//! processed through multi-head attention where each head specializes in
|
||||
//! a different sensory channel.
|
||||
|
||||
use crate::error::{CrvError, CrvResult};
|
||||
use crate::types::{CrvConfig, SensoryModality, StageIIData};
|
||||
use ruvector_attention::traits::Attention;
|
||||
use ruvector_attention::MultiHeadAttention;
|
||||
|
||||
/// Number of sensory modality heads.
|
||||
const NUM_MODALITIES: usize = 8;
|
||||
|
||||
/// Stage II encoder using multi-head attention for sensory fusion.
|
||||
pub struct StageIIEncoder {
|
||||
/// Embedding dimensionality.
|
||||
dim: usize,
|
||||
/// Multi-head attention mechanism (one head per modality).
|
||||
attention: MultiHeadAttention,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for StageIIEncoder {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("StageIIEncoder")
|
||||
.field("dim", &self.dim)
|
||||
.field("attention", &"MultiHeadAttention { .. }")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl StageIIEncoder {
|
||||
/// Create a new Stage II encoder.
|
||||
pub fn new(config: &CrvConfig) -> Self {
|
||||
let dim = config.dimensions;
|
||||
// Ensure dim is divisible by NUM_MODALITIES
|
||||
let effective_heads = if dim % NUM_MODALITIES == 0 {
|
||||
NUM_MODALITIES
|
||||
} else {
|
||||
// Fall back to a divisor
|
||||
let mut h = NUM_MODALITIES;
|
||||
while dim % h != 0 && h > 1 {
|
||||
h -= 1;
|
||||
}
|
||||
h
|
||||
};
|
||||
|
||||
let attention = MultiHeadAttention::new(dim, effective_heads);
|
||||
|
||||
Self { dim, attention }
|
||||
}
|
||||
|
||||
/// Encode a sensory descriptor string into a feature vector.
|
||||
///
|
||||
/// Uses a deterministic hash-based encoding to convert text descriptors
|
||||
/// into fixed-dimension vectors. Each modality gets a distinct subspace.
|
||||
fn encode_descriptor(&self, modality: SensoryModality, descriptor: &str) -> Vec<f32> {
|
||||
let mut features = vec![0.0f32; self.dim];
|
||||
let modality_offset = modality_index(modality) * (self.dim / NUM_MODALITIES.max(1));
|
||||
let subspace_size = self.dim / NUM_MODALITIES.max(1);
|
||||
|
||||
// Simple deterministic hash encoding
|
||||
let bytes = descriptor.as_bytes();
|
||||
for (i, &byte) in bytes.iter().enumerate() {
|
||||
let dim_idx = modality_offset + (i % subspace_size);
|
||||
if dim_idx < self.dim {
|
||||
// Distribute byte values across the subspace with varied phases
|
||||
let phase = (i as f32) * 0.618_034; // golden ratio
|
||||
features[dim_idx] += (byte as f32 / 255.0) * (phase * std::f32::consts::PI).cos();
|
||||
}
|
||||
}
|
||||
|
||||
// Add modality-specific bias
|
||||
if modality_offset < self.dim {
|
||||
features[modality_offset] += 1.0;
|
||||
}
|
||||
|
||||
// L2 normalize
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-6 {
|
||||
for f in &mut features {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Encode Stage II data into a unified sensory embedding.
|
||||
///
|
||||
/// Each sensory impression becomes a key-value pair in the attention
|
||||
/// mechanism. A learned query (based on the modality distribution)
|
||||
/// attends over all impressions to produce the fused output.
|
||||
pub fn encode(&self, data: &StageIIData) -> CrvResult<Vec<f32>> {
|
||||
if data.impressions.is_empty() {
|
||||
return Err(CrvError::EmptyInput(
|
||||
"No sensory impressions".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// If a pre-computed feature vector exists, use it
|
||||
if let Some(ref fv) = data.feature_vector {
|
||||
if fv.len() == self.dim {
|
||||
return Ok(fv.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Encode each impression into a feature vector
|
||||
let encoded: Vec<Vec<f32>> = data
|
||||
.impressions
|
||||
.iter()
|
||||
.map(|(modality, descriptor)| self.encode_descriptor(*modality, descriptor))
|
||||
.collect();
|
||||
|
||||
// Build query from modality distribution
|
||||
let query = self.build_modality_query(&data.impressions);
|
||||
|
||||
let keys: Vec<&[f32]> = encoded.iter().map(|v| v.as_slice()).collect();
|
||||
let values: Vec<&[f32]> = encoded.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self.attention.compute(&query, &keys, &values)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Build a query vector from the distribution of modalities present.
|
||||
fn build_modality_query(&self, impressions: &[(SensoryModality, String)]) -> Vec<f32> {
|
||||
let mut query = vec![0.0f32; self.dim];
|
||||
let subspace_size = self.dim / NUM_MODALITIES.max(1);
|
||||
|
||||
// Count modality occurrences
|
||||
let mut counts = [0usize; NUM_MODALITIES];
|
||||
for (modality, _) in impressions {
|
||||
let idx = modality_index(*modality);
|
||||
if idx < NUM_MODALITIES {
|
||||
counts[idx] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Encode counts as the query
|
||||
let total: f32 = counts.iter().sum::<usize>() as f32;
|
||||
for (m, &count) in counts.iter().enumerate() {
|
||||
let weight = count as f32 / total.max(1.0);
|
||||
let offset = m * subspace_size;
|
||||
for d in 0..subspace_size.min(self.dim - offset) {
|
||||
query[offset + d] = weight * (1.0 + d as f32 * 0.01);
|
||||
}
|
||||
}
|
||||
|
||||
// L2 normalize
|
||||
let norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-6 {
|
||||
for f in &mut query {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
query
|
||||
}
|
||||
|
||||
/// Compute similarity between two Stage II embeddings.
|
||||
pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm_a < 1e-6 || norm_b < 1e-6 {
|
||||
return 0.0;
|
||||
}
|
||||
dot / (norm_a * norm_b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Map sensory modality to index.
|
||||
fn modality_index(m: SensoryModality) -> usize {
|
||||
match m {
|
||||
SensoryModality::Texture => 0,
|
||||
SensoryModality::Color => 1,
|
||||
SensoryModality::Temperature => 2,
|
||||
SensoryModality::Sound => 3,
|
||||
SensoryModality::Smell => 4,
|
||||
SensoryModality::Taste => 5,
|
||||
SensoryModality::Dimension => 6,
|
||||
SensoryModality::Luminosity => 7,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> CrvConfig {
|
||||
CrvConfig {
|
||||
dimensions: 32, // 32 / 8 = 4 dims per head
|
||||
..CrvConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoder_creation() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIEncoder::new(&config);
|
||||
assert_eq!(encoder.dim, 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_descriptor_encoding() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIEncoder::new(&config);
|
||||
|
||||
let v = encoder.encode_descriptor(SensoryModality::Texture, "rough grainy");
|
||||
assert_eq!(v.len(), 32);
|
||||
|
||||
// Should be normalized
|
||||
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
assert!((norm - 1.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_encode() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIEncoder::new(&config);
|
||||
|
||||
let data = StageIIData {
|
||||
impressions: vec![
|
||||
(SensoryModality::Texture, "rough".to_string()),
|
||||
(SensoryModality::Color, "blue-gray".to_string()),
|
||||
(SensoryModality::Temperature, "cold".to_string()),
|
||||
],
|
||||
feature_vector: None,
|
||||
};
|
||||
|
||||
let embedding = encoder.encode(&data).unwrap();
|
||||
assert_eq!(embedding.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_similarity() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIEncoder::new(&config);
|
||||
|
||||
let a = vec![1.0; 32];
|
||||
let b = vec![1.0; 32];
|
||||
let sim = encoder.similarity(&a, &b);
|
||||
assert!((sim - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_impressions() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIEncoder::new(&config);
|
||||
|
||||
let data = StageIIData {
|
||||
impressions: vec![],
|
||||
feature_vector: None,
|
||||
};
|
||||
|
||||
assert!(encoder.encode(&data).is_err());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,282 @@
|
||||
//! Stage III Encoder: Dimensional Data via GNN Graph Topology
|
||||
//!
|
||||
//! CRV Stage III captures spatial sketches and geometric relationships.
|
||||
//! These naturally form a graph where sketch elements are nodes and spatial
|
||||
//! relationships are edges. The GNN layer learns to propagate spatial
|
||||
//! context through the graph, producing an embedding that captures the
|
||||
//! full dimensional structure of the target.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! Sketch elements → node features, spatial relationships → edge weights.
|
||||
//! A GNN forward pass aggregates neighborhood information to produce
|
||||
//! a graph-level embedding.
|
||||
|
||||
use crate::error::{CrvError, CrvResult};
|
||||
use crate::types::{CrvConfig, GeometricKind, SpatialRelationType, StageIIIData};
|
||||
use ruvector_gnn::layer::RuvectorLayer;
|
||||
use ruvector_gnn::search::cosine_similarity;
|
||||
|
||||
/// Stage III encoder using GNN graph topology.
|
||||
#[derive(Debug)]
|
||||
pub struct StageIIIEncoder {
|
||||
/// Embedding dimensionality.
|
||||
dim: usize,
|
||||
/// GNN layer for spatial message passing.
|
||||
gnn_layer: RuvectorLayer,
|
||||
}
|
||||
|
||||
impl StageIIIEncoder {
|
||||
/// Create a new Stage III encoder.
|
||||
pub fn new(config: &CrvConfig) -> Self {
|
||||
let dim = config.dimensions;
|
||||
// Single GNN layer: input_dim -> hidden_dim, 1 head
|
||||
let gnn_layer = RuvectorLayer::new(dim, dim, 1, 0.0)
|
||||
.expect("ruvector-crv: valid GNN layer config (dim, dim, 1 head, 0.0 dropout)");
|
||||
|
||||
Self { dim, gnn_layer }
|
||||
}
|
||||
|
||||
/// Encode a sketch element into a node feature vector.
|
||||
fn encode_element(&self, label: &str, kind: GeometricKind, position: (f32, f32), scale: Option<f32>) -> Vec<f32> {
|
||||
let mut features = vec![0.0f32; self.dim];
|
||||
|
||||
// Geometric kind encoding (one-hot style in first 8 dims)
|
||||
let kind_idx = match kind {
|
||||
GeometricKind::Point => 0,
|
||||
GeometricKind::Line => 1,
|
||||
GeometricKind::Curve => 2,
|
||||
GeometricKind::Rectangle => 3,
|
||||
GeometricKind::Circle => 4,
|
||||
GeometricKind::Triangle => 5,
|
||||
GeometricKind::Polygon => 6,
|
||||
GeometricKind::Freeform => 7,
|
||||
};
|
||||
if kind_idx < self.dim {
|
||||
features[kind_idx] = 1.0;
|
||||
}
|
||||
|
||||
// Position encoding (normalized)
|
||||
if 8 < self.dim {
|
||||
features[8] = position.0;
|
||||
}
|
||||
if 9 < self.dim {
|
||||
features[9] = position.1;
|
||||
}
|
||||
|
||||
// Scale encoding
|
||||
if let Some(s) = scale {
|
||||
if 10 < self.dim {
|
||||
features[10] = s;
|
||||
}
|
||||
}
|
||||
|
||||
// Label hash encoding (spread across remaining dims)
|
||||
for (i, byte) in label.bytes().enumerate() {
|
||||
let idx = 11 + (i % (self.dim.saturating_sub(11)));
|
||||
if idx < self.dim {
|
||||
features[idx] += (byte as f32 / 255.0) * 0.5;
|
||||
}
|
||||
}
|
||||
|
||||
// L2 normalize
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-6 {
|
||||
for f in &mut features {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Compute edge weight from spatial relationship type.
|
||||
fn relationship_weight(relation: SpatialRelationType) -> f32 {
|
||||
match relation {
|
||||
SpatialRelationType::Adjacent => 0.8,
|
||||
SpatialRelationType::Contains => 0.9,
|
||||
SpatialRelationType::Above => 0.6,
|
||||
SpatialRelationType::Below => 0.6,
|
||||
SpatialRelationType::Inside => 0.95,
|
||||
SpatialRelationType::Surrounding => 0.85,
|
||||
SpatialRelationType::Connected => 0.7,
|
||||
SpatialRelationType::Separated => 0.3,
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode Stage III data into a graph-level embedding.
|
||||
///
|
||||
/// Builds a graph from sketch elements and relationships,
|
||||
/// runs GNN message passing, then aggregates node embeddings
|
||||
/// into a single graph-level vector.
|
||||
pub fn encode(&self, data: &StageIIIData) -> CrvResult<Vec<f32>> {
|
||||
if data.sketch_elements.is_empty() {
|
||||
return Err(CrvError::EmptyInput(
|
||||
"No sketch elements".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Build label → index mapping
|
||||
let label_to_idx: std::collections::HashMap<&str, usize> = data
|
||||
.sketch_elements
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, elem)| (elem.label.as_str(), i))
|
||||
.collect();
|
||||
|
||||
// Encode each element as a node feature vector
|
||||
let node_features: Vec<Vec<f32>> = data
|
||||
.sketch_elements
|
||||
.iter()
|
||||
.map(|elem| {
|
||||
self.encode_element(&elem.label, elem.kind, elem.position, elem.scale)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// For each node, collect neighbor embeddings and edge weights
|
||||
// based on the spatial relationships
|
||||
let mut aggregated = vec![vec![0.0f32; self.dim]; node_features.len()];
|
||||
|
||||
for (node_idx, node_feat) in node_features.iter().enumerate() {
|
||||
let label = &data.sketch_elements[node_idx].label;
|
||||
|
||||
// Find all relationships involving this node
|
||||
let mut neighbor_feats = Vec::new();
|
||||
let mut edge_weights = Vec::new();
|
||||
|
||||
for rel in &data.relationships {
|
||||
if rel.from == *label {
|
||||
if let Some(&neighbor_idx) = label_to_idx.get(rel.to.as_str()) {
|
||||
neighbor_feats.push(node_features[neighbor_idx].clone());
|
||||
edge_weights.push(Self::relationship_weight(rel.relation) * rel.strength);
|
||||
}
|
||||
} else if rel.to == *label {
|
||||
if let Some(&neighbor_idx) = label_to_idx.get(rel.from.as_str()) {
|
||||
neighbor_feats.push(node_features[neighbor_idx].clone());
|
||||
edge_weights.push(Self::relationship_weight(rel.relation) * rel.strength);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GNN forward pass for this node
|
||||
aggregated[node_idx] =
|
||||
self.gnn_layer
|
||||
.forward(node_feat, &neighbor_feats, &edge_weights);
|
||||
}
|
||||
|
||||
// Aggregate into graph-level embedding via mean pooling
|
||||
let mut graph_embedding = vec![0.0f32; self.dim];
|
||||
for node_emb in &aggregated {
|
||||
for (i, &v) in node_emb.iter().enumerate() {
|
||||
if i < self.dim {
|
||||
graph_embedding[i] += v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let n = aggregated.len() as f32;
|
||||
for v in &mut graph_embedding {
|
||||
*v /= n;
|
||||
}
|
||||
|
||||
Ok(graph_embedding)
|
||||
}
|
||||
|
||||
/// Compute similarity between two Stage III embeddings.
|
||||
pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
cosine_similarity(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{SketchElement, SpatialRelationship};
|
||||
|
||||
fn test_config() -> CrvConfig {
|
||||
CrvConfig {
|
||||
dimensions: 32,
|
||||
..CrvConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoder_creation() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIIEncoder::new(&config);
|
||||
assert_eq!(encoder.dim, 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_element_encoding() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIIEncoder::new(&config);
|
||||
|
||||
let features = encoder.encode_element(
|
||||
"building",
|
||||
GeometricKind::Rectangle,
|
||||
(0.5, 0.3),
|
||||
Some(2.0),
|
||||
);
|
||||
assert_eq!(features.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_encode() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIIEncoder::new(&config);
|
||||
|
||||
let data = StageIIIData {
|
||||
sketch_elements: vec![
|
||||
SketchElement {
|
||||
label: "tower".to_string(),
|
||||
kind: GeometricKind::Rectangle,
|
||||
position: (0.5, 0.8),
|
||||
scale: Some(3.0),
|
||||
},
|
||||
SketchElement {
|
||||
label: "base".to_string(),
|
||||
kind: GeometricKind::Rectangle,
|
||||
position: (0.5, 0.2),
|
||||
scale: Some(5.0),
|
||||
},
|
||||
SketchElement {
|
||||
label: "path".to_string(),
|
||||
kind: GeometricKind::Line,
|
||||
position: (0.3, 0.1),
|
||||
scale: None,
|
||||
},
|
||||
],
|
||||
relationships: vec![
|
||||
SpatialRelationship {
|
||||
from: "tower".to_string(),
|
||||
to: "base".to_string(),
|
||||
relation: SpatialRelationType::Above,
|
||||
strength: 0.9,
|
||||
},
|
||||
SpatialRelationship {
|
||||
from: "path".to_string(),
|
||||
to: "base".to_string(),
|
||||
relation: SpatialRelationType::Adjacent,
|
||||
strength: 0.7,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let embedding = encoder.encode(&data).unwrap();
|
||||
assert_eq!(embedding.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_elements() {
|
||||
let config = test_config();
|
||||
let encoder = StageIIIEncoder::new(&config);
|
||||
|
||||
let data = StageIIIData {
|
||||
sketch_elements: vec![],
|
||||
relationships: vec![],
|
||||
};
|
||||
|
||||
assert!(encoder.encode(&data).is_err());
|
||||
}
|
||||
}
|
||||
339
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_iv.rs
Normal file
339
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_iv.rs
Normal file
@@ -0,0 +1,339 @@
|
||||
//! Stage IV Encoder: Emotional/AOL Data via SNN Temporal Encoding
|
||||
//!
|
||||
//! CRV Stage IV captures emotional impacts, tangibles, intangibles, and
|
||||
//! analytical overlay (AOL) detections. The spiking neural network (SNN)
|
||||
//! temporal encoding naturally models the signal-vs-noise discrimination
|
||||
//! that Stage IV demands:
|
||||
//!
|
||||
//! - High-frequency spike bursts correlate with AOL contamination
|
||||
//! - Sustained low-frequency patterns indicate clean signal line data
|
||||
//! - The refractory period prevents AOL cascade (analytical runaway)
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! Emotional intensity timeseries → SNN input currents.
|
||||
//! Network spike rate analysis detects AOL events.
|
||||
//! The embedding captures both the clean signal and AOL separation.
|
||||
|
||||
use crate::error::CrvResult;
|
||||
use crate::types::{AOLDetection, CrvConfig, StageIVData};
|
||||
use ruvector_mincut::snn::{LayerConfig, NetworkConfig, NeuronConfig, SpikingNetwork};
|
||||
|
||||
/// Stage IV encoder using spiking neural network temporal encoding.
|
||||
#[derive(Debug)]
|
||||
pub struct StageIVEncoder {
|
||||
/// Embedding dimensionality.
|
||||
dim: usize,
|
||||
/// AOL detection threshold (spike rate above this = likely AOL).
|
||||
aol_threshold: f32,
|
||||
/// SNN time step.
|
||||
dt: f64,
|
||||
/// Refractory period for AOL cascade prevention.
|
||||
refractory_period: f64,
|
||||
}
|
||||
|
||||
impl StageIVEncoder {
|
||||
/// Create a new Stage IV encoder.
|
||||
pub fn new(config: &CrvConfig) -> Self {
|
||||
Self {
|
||||
dim: config.dimensions,
|
||||
aol_threshold: config.aol_threshold,
|
||||
dt: config.snn_dt,
|
||||
refractory_period: config.refractory_period_ms,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a spiking network configured for emotional signal processing.
|
||||
///
|
||||
/// The network has 3 layers:
|
||||
/// - Input: receives emotional intensity as current
|
||||
/// - Hidden: processes temporal patterns
|
||||
/// - Output: produces the embedding dimensions
|
||||
fn create_network(&self, input_size: usize) -> SpikingNetwork {
|
||||
let hidden_size = (input_size * 2).max(16).min(128);
|
||||
let output_size = self.dim.min(64); // SNN output, will be expanded
|
||||
|
||||
let neuron_config = NeuronConfig {
|
||||
tau_membrane: 20.0,
|
||||
v_rest: 0.0,
|
||||
v_reset: 0.0,
|
||||
threshold: 1.0,
|
||||
t_refrac: self.refractory_period,
|
||||
resistance: 1.0,
|
||||
threshold_adapt: 0.1,
|
||||
tau_threshold: 100.0,
|
||||
homeostatic: true,
|
||||
target_rate: 0.01,
|
||||
tau_homeostatic: 1000.0,
|
||||
};
|
||||
|
||||
let config = NetworkConfig {
|
||||
layers: vec![
|
||||
LayerConfig::new(input_size).with_neuron_config(neuron_config.clone()),
|
||||
LayerConfig::new(hidden_size)
|
||||
.with_neuron_config(neuron_config.clone())
|
||||
.with_recurrence(),
|
||||
LayerConfig::new(output_size).with_neuron_config(neuron_config),
|
||||
],
|
||||
stdp_config: Default::default(),
|
||||
dt: self.dt,
|
||||
winner_take_all: false,
|
||||
wta_strength: 0.0,
|
||||
};
|
||||
|
||||
SpikingNetwork::new(config)
|
||||
}
|
||||
|
||||
/// Encode emotional intensity values into SNN input currents.
|
||||
fn emotional_to_currents(intensities: &[(String, f32)]) -> Vec<f64> {
|
||||
intensities
|
||||
.iter()
|
||||
.map(|(_, intensity)| *intensity as f64 * 5.0) // Scale to reasonable current
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Analyze spike output to detect AOL events.
|
||||
///
|
||||
/// High spike rate in a short window indicates the analytical mind
|
||||
/// is overriding the signal line (AOL contamination).
|
||||
fn detect_aol(
|
||||
&self,
|
||||
spike_rates: &[f64],
|
||||
window_ms: f64,
|
||||
) -> Vec<AOLDetection> {
|
||||
let mut detections = Vec::new();
|
||||
let threshold = self.aol_threshold as f64;
|
||||
|
||||
for (i, &rate) in spike_rates.iter().enumerate() {
|
||||
if rate > threshold {
|
||||
detections.push(AOLDetection {
|
||||
content: format!("AOL burst at timestep {}", i),
|
||||
timestamp_ms: (i as f64 * window_ms) as u64,
|
||||
flagged: rate > threshold * 1.5, // Auto-flag strong AOL
|
||||
anomaly_score: (rate / threshold).min(1.0) as f32,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
detections
|
||||
}
|
||||
|
||||
/// Encode Stage IV data into a temporal embedding.
|
||||
///
|
||||
/// Runs the SNN on emotional intensity data, analyzes spike patterns
|
||||
/// for AOL contamination, and produces a combined embedding that
|
||||
/// captures both clean signal and AOL separation.
|
||||
pub fn encode(&self, data: &StageIVData) -> CrvResult<Vec<f32>> {
|
||||
// Build input from emotional impact data
|
||||
let input_size = data.emotional_impact.len().max(1);
|
||||
let currents = Self::emotional_to_currents(&data.emotional_impact);
|
||||
|
||||
if currents.is_empty() {
|
||||
// Fall back to text-based encoding if no emotional intensity data
|
||||
return self.encode_from_text(data);
|
||||
}
|
||||
|
||||
// Run SNN simulation
|
||||
let mut network = self.create_network(input_size);
|
||||
let num_steps = 100; // 100ms simulation
|
||||
let mut spike_counts = vec![0usize; network.layer_size(network.num_layers() - 1)];
|
||||
let mut step_rates = Vec::new();
|
||||
|
||||
for step in 0..num_steps {
|
||||
// Inject currents (modulated by step for temporal variation)
|
||||
let modulated: Vec<f64> = currents
|
||||
.iter()
|
||||
.map(|&c| c * (1.0 + 0.3 * ((step as f64 * 0.1).sin())))
|
||||
.collect();
|
||||
network.inject_current(&modulated);
|
||||
|
||||
let spikes = network.step();
|
||||
for spike in &spikes {
|
||||
if spike.neuron_id < spike_counts.len() {
|
||||
spike_counts[spike.neuron_id] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Track rate per window
|
||||
if step % 10 == 9 {
|
||||
let rate = spikes.len() as f64 / 10.0;
|
||||
step_rates.push(rate);
|
||||
}
|
||||
}
|
||||
|
||||
// Build embedding from spike counts and output activities
|
||||
let output = network.get_output();
|
||||
let mut embedding = vec![0.0f32; self.dim];
|
||||
|
||||
// First portion: spike count features
|
||||
let spike_dims = spike_counts.len().min(self.dim / 3);
|
||||
let max_count = *spike_counts.iter().max().unwrap_or(&1) as f32;
|
||||
for (i, &count) in spike_counts.iter().take(spike_dims).enumerate() {
|
||||
embedding[i] = count as f32 / max_count.max(1.0);
|
||||
}
|
||||
|
||||
// Second portion: membrane potential output
|
||||
let pot_offset = self.dim / 3;
|
||||
let pot_dims = output.len().min(self.dim / 3);
|
||||
for (i, &v) in output.iter().take(pot_dims).enumerate() {
|
||||
if pot_offset + i < self.dim {
|
||||
embedding[pot_offset + i] = v as f32;
|
||||
}
|
||||
}
|
||||
|
||||
// Third portion: text-derived features from tangibles/intangibles
|
||||
let text_offset = 2 * self.dim / 3;
|
||||
self.encode_text_features(data, &mut embedding[text_offset..]);
|
||||
|
||||
// Encode AOL information
|
||||
let aol_detections = self.detect_aol(&step_rates, 10.0);
|
||||
let aol_count = (aol_detections.len() + data.aol_detections.len()) as f32;
|
||||
if self.dim > 2 {
|
||||
// Store AOL contamination level in last dimension
|
||||
embedding[self.dim - 1] = (aol_count / num_steps as f32).min(1.0);
|
||||
}
|
||||
|
||||
// L2 normalize
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-6 {
|
||||
for f in &mut embedding {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Text-based encoding fallback when no intensity timeseries is available.
|
||||
fn encode_from_text(&self, data: &StageIVData) -> CrvResult<Vec<f32>> {
|
||||
let mut embedding = vec![0.0f32; self.dim];
|
||||
self.encode_text_features(data, &mut embedding);
|
||||
|
||||
// L2 normalize
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-6 {
|
||||
for f in &mut embedding {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Encode text descriptors (tangibles, intangibles) into feature slots.
|
||||
fn encode_text_features(&self, data: &StageIVData, features: &mut [f32]) {
|
||||
if features.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Hash tangibles
|
||||
for (i, tangible) in data.tangibles.iter().enumerate() {
|
||||
for (j, byte) in tangible.bytes().enumerate() {
|
||||
let idx = (i * 7 + j) % features.len();
|
||||
features[idx] += (byte as f32 / 255.0) * 0.3;
|
||||
}
|
||||
}
|
||||
|
||||
// Hash intangibles
|
||||
for (i, intangible) in data.intangibles.iter().enumerate() {
|
||||
for (j, byte) in intangible.bytes().enumerate() {
|
||||
let idx = (i * 11 + j + features.len() / 2) % features.len();
|
||||
features[idx] += (byte as f32 / 255.0) * 0.3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the AOL anomaly score for a given Stage IV embedding.
|
||||
///
|
||||
/// Higher values indicate more AOL contamination.
|
||||
pub fn aol_score(&self, embedding: &[f32]) -> f32 {
|
||||
if embedding.len() >= self.dim && self.dim > 2 {
|
||||
embedding[self.dim - 1].abs()
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> CrvConfig {
|
||||
CrvConfig {
|
||||
dimensions: 32,
|
||||
aol_threshold: 0.7,
|
||||
refractory_period_ms: 50.0,
|
||||
snn_dt: 1.0,
|
||||
..CrvConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoder_creation() {
|
||||
let config = test_config();
|
||||
let encoder = StageIVEncoder::new(&config);
|
||||
assert_eq!(encoder.dim, 32);
|
||||
assert_eq!(encoder.aol_threshold, 0.7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_text_only_encode() {
|
||||
let config = test_config();
|
||||
let encoder = StageIVEncoder::new(&config);
|
||||
|
||||
let data = StageIVData {
|
||||
emotional_impact: vec![],
|
||||
tangibles: vec!["metal".to_string(), "concrete".to_string()],
|
||||
intangibles: vec!["historical significance".to_string()],
|
||||
aol_detections: vec![],
|
||||
};
|
||||
|
||||
let embedding = encoder.encode(&data).unwrap();
|
||||
assert_eq!(embedding.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_encode_with_snn() {
|
||||
let config = test_config();
|
||||
let encoder = StageIVEncoder::new(&config);
|
||||
|
||||
let data = StageIVData {
|
||||
emotional_impact: vec![
|
||||
("awe".to_string(), 0.8),
|
||||
("unease".to_string(), 0.3),
|
||||
("curiosity".to_string(), 0.6),
|
||||
],
|
||||
tangibles: vec!["stone wall".to_string()],
|
||||
intangibles: vec!["ancient purpose".to_string()],
|
||||
aol_detections: vec![AOLDetection {
|
||||
content: "looks like a castle".to_string(),
|
||||
timestamp_ms: 500,
|
||||
flagged: true,
|
||||
anomaly_score: 0.8,
|
||||
}],
|
||||
};
|
||||
|
||||
let embedding = encoder.encode(&data).unwrap();
|
||||
assert_eq!(embedding.len(), 32);
|
||||
|
||||
// Should be normalized
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
assert!((norm - 1.0).abs() < 0.1 || norm < 0.01); // normalized or near-zero
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aol_detection() {
|
||||
let config = test_config();
|
||||
let encoder = StageIVEncoder::new(&config);
|
||||
|
||||
let rates = vec![0.1, 0.2, 0.9, 0.95, 0.3, 0.1];
|
||||
let detections = encoder.detect_aol(&rates, 10.0);
|
||||
|
||||
// Should detect the high-rate windows as AOL
|
||||
assert!(detections.len() >= 2);
|
||||
for d in &detections {
|
||||
assert!(d.anomaly_score > 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
222
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_v.rs
Normal file
222
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_v.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
//! Stage V: Interrogation via Differentiable Search with Soft Attention
|
||||
//!
|
||||
//! CRV Stage V involves probing the signal line by asking targeted questions
|
||||
//! about specific aspects of the target, then cross-referencing results
|
||||
//! across all accumulated data from Stages I-IV.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! Uses `ruvector_gnn::search::differentiable_search` to find the most
|
||||
//! relevant data entries for each probe query, with soft attention weights
|
||||
//! providing a continuous similarity measure rather than hard thresholds.
|
||||
//! This enables gradient-based refinement of probe queries.
|
||||
|
||||
use crate::error::{CrvError, CrvResult};
|
||||
use crate::types::{CrossReference, CrvConfig, SignalLineProbe, StageVData};
|
||||
use ruvector_gnn::search::{cosine_similarity, differentiable_search};
|
||||
|
||||
/// Stage V interrogation engine using differentiable search.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StageVEngine {
|
||||
/// Embedding dimensionality.
|
||||
dim: usize,
|
||||
/// Temperature for differentiable search softmax.
|
||||
temperature: f32,
|
||||
}
|
||||
|
||||
impl StageVEngine {
|
||||
/// Create a new Stage V engine.
|
||||
pub fn new(config: &CrvConfig) -> Self {
|
||||
Self {
|
||||
dim: config.dimensions,
|
||||
temperature: config.search_temperature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Probe the accumulated session embeddings with a query.
|
||||
///
|
||||
/// Performs differentiable search over the given candidate embeddings,
|
||||
/// returning soft attention weights and top-k candidates.
|
||||
pub fn probe(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
candidates: &[Vec<f32>],
|
||||
k: usize,
|
||||
) -> CrvResult<SignalLineProbe> {
|
||||
if candidates.is_empty() {
|
||||
return Err(CrvError::EmptyInput(
|
||||
"No candidates for probing".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let (top_candidates, attention_weights) =
|
||||
differentiable_search(query_embedding, candidates, k, self.temperature);
|
||||
|
||||
Ok(SignalLineProbe {
|
||||
query: String::new(), // Caller sets the text
|
||||
target_stage: 0, // Caller sets the stage
|
||||
attention_weights,
|
||||
top_candidates,
|
||||
})
|
||||
}
|
||||
|
||||
/// Cross-reference entries across stages to find correlations.
|
||||
///
|
||||
/// For each entry in `from_entries`, finds the most similar entries
|
||||
/// in `to_entries` using cosine similarity, producing cross-references
|
||||
/// above the given threshold.
|
||||
pub fn cross_reference(
|
||||
&self,
|
||||
from_stage: u8,
|
||||
from_entries: &[Vec<f32>],
|
||||
to_stage: u8,
|
||||
to_entries: &[Vec<f32>],
|
||||
threshold: f32,
|
||||
) -> Vec<CrossReference> {
|
||||
let mut refs = Vec::new();
|
||||
|
||||
for (from_idx, from_emb) in from_entries.iter().enumerate() {
|
||||
for (to_idx, to_emb) in to_entries.iter().enumerate() {
|
||||
if from_emb.len() == to_emb.len() {
|
||||
let score = cosine_similarity(from_emb, to_emb);
|
||||
if score >= threshold {
|
||||
refs.push(CrossReference {
|
||||
from_stage,
|
||||
from_entry: from_idx,
|
||||
to_stage,
|
||||
to_entry: to_idx,
|
||||
score,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
refs.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
refs
|
||||
}
|
||||
|
||||
/// Encode Stage V data into a combined interrogation embedding.
|
||||
///
|
||||
/// Aggregates the attention weights from all probes to produce
|
||||
/// a unified view of which aspects of the target were most
|
||||
/// responsive to interrogation.
|
||||
pub fn encode(&self, data: &StageVData, all_embeddings: &[Vec<f32>]) -> CrvResult<Vec<f32>> {
|
||||
if data.probes.is_empty() {
|
||||
return Err(CrvError::EmptyInput("No probes in Stage V data".to_string()));
|
||||
}
|
||||
|
||||
let mut embedding = vec![0.0f32; self.dim];
|
||||
|
||||
// Weight each candidate embedding by its attention weight across all probes
|
||||
for probe in &data.probes {
|
||||
for (&candidate_idx, &weight) in probe
|
||||
.top_candidates
|
||||
.iter()
|
||||
.zip(probe.attention_weights.iter())
|
||||
{
|
||||
if candidate_idx < all_embeddings.len() {
|
||||
let emb = &all_embeddings[candidate_idx];
|
||||
for (i, &v) in emb.iter().enumerate() {
|
||||
if i < self.dim {
|
||||
embedding[i] += v * weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize by number of probes
|
||||
let num_probes = data.probes.len() as f32;
|
||||
for v in &mut embedding {
|
||||
*v /= num_probes;
|
||||
}
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Compute the interrogation signal strength for a given embedding.
|
||||
///
|
||||
/// Higher values indicate more responsive signal line data.
|
||||
pub fn signal_strength(&self, embedding: &[f32]) -> f32 {
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
norm
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> CrvConfig {
|
||||
CrvConfig {
|
||||
dimensions: 8,
|
||||
search_temperature: 1.0,
|
||||
..CrvConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_engine_creation() {
|
||||
let config = test_config();
|
||||
let engine = StageVEngine::new(&config);
|
||||
assert_eq!(engine.dim, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_probe() {
|
||||
let config = test_config();
|
||||
let engine = StageVEngine::new(&config);
|
||||
|
||||
let query = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
|
||||
let candidates = vec![
|
||||
vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], // exact match
|
||||
vec![0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], // partial
|
||||
vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], // orthogonal
|
||||
];
|
||||
|
||||
let probe = engine.probe(&query, &candidates, 2).unwrap();
|
||||
assert_eq!(probe.top_candidates.len(), 2);
|
||||
assert_eq!(probe.attention_weights.len(), 2);
|
||||
// Best match should be first
|
||||
assert_eq!(probe.top_candidates[0], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_reference() {
|
||||
let config = test_config();
|
||||
let engine = StageVEngine::new(&config);
|
||||
|
||||
let from = vec![
|
||||
vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
];
|
||||
let to = vec![
|
||||
vec![0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], // similar to from[0]
|
||||
vec![0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], // different
|
||||
];
|
||||
|
||||
let refs = engine.cross_reference(1, &from, 2, &to, 0.5);
|
||||
assert!(!refs.is_empty());
|
||||
assert_eq!(refs[0].from_stage, 1);
|
||||
assert_eq!(refs[0].to_stage, 2);
|
||||
assert!(refs[0].score > 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_probe() {
|
||||
let config = test_config();
|
||||
let engine = StageVEngine::new(&config);
|
||||
|
||||
let query = vec![1.0; 8];
|
||||
let candidates: Vec<Vec<f32>> = vec![];
|
||||
|
||||
assert!(engine.probe(&query, &candidates, 5).is_err());
|
||||
}
|
||||
}
|
||||
387
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_vi.rs
Normal file
387
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/stage_vi.rs
Normal file
@@ -0,0 +1,387 @@
|
||||
//! Stage VI: Composite Modeling via MinCut Partitioning
|
||||
//!
|
||||
//! CRV Stage VI builds a composite 3D model from all accumulated session data.
|
||||
//! The MinCut algorithm identifies natural cluster boundaries in the session
|
||||
//! graph, separating distinct target aspects that emerged across stages.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! All session embeddings form nodes in a weighted graph, with edge weights
|
||||
//! derived from cosine similarity. MinCut partitioning finds the natural
|
||||
//! separations between target aspects, producing distinct partitions that
|
||||
//! represent different facets of the target.
|
||||
|
||||
use crate::error::{CrvError, CrvResult};
|
||||
use crate::types::{CrvConfig, StageVIData, TargetPartition};
|
||||
use ruvector_gnn::search::cosine_similarity;
|
||||
use ruvector_mincut::prelude::*;
|
||||
|
||||
/// Stage VI composite modeler using MinCut partitioning.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StageVIModeler {
|
||||
/// Embedding dimensionality.
|
||||
dim: usize,
|
||||
/// Minimum edge weight to create an edge (similarity threshold).
|
||||
edge_threshold: f32,
|
||||
}
|
||||
|
||||
impl StageVIModeler {
|
||||
/// Create a new Stage VI modeler.
|
||||
pub fn new(config: &CrvConfig) -> Self {
|
||||
Self {
|
||||
dim: config.dimensions,
|
||||
edge_threshold: 0.2, // Low threshold to capture weak relationships too
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a similarity graph from session embeddings.
|
||||
///
|
||||
/// Each embedding becomes a vertex. Edges are created between
|
||||
/// pairs with cosine similarity above the threshold, with
|
||||
/// edge weight equal to the similarity score.
|
||||
fn build_similarity_graph(&self, embeddings: &[Vec<f32>]) -> Vec<(u64, u64, f64)> {
|
||||
let n = embeddings.len();
|
||||
let mut edges = Vec::new();
|
||||
|
||||
for i in 0..n {
|
||||
for j in (i + 1)..n {
|
||||
if embeddings[i].len() == embeddings[j].len() && !embeddings[i].is_empty() {
|
||||
let sim = cosine_similarity(&embeddings[i], &embeddings[j]);
|
||||
if sim > self.edge_threshold {
|
||||
edges.push((i as u64 + 1, j as u64 + 1, sim as f64));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
edges
|
||||
}
|
||||
|
||||
/// Compute centroid of a set of embeddings.
|
||||
fn compute_centroid(&self, embeddings: &[&[f32]]) -> Vec<f32> {
|
||||
if embeddings.is_empty() {
|
||||
return vec![0.0; self.dim];
|
||||
}
|
||||
|
||||
let mut centroid = vec![0.0f32; self.dim];
|
||||
for emb in embeddings {
|
||||
for (i, &v) in emb.iter().enumerate() {
|
||||
if i < self.dim {
|
||||
centroid[i] += v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let n = embeddings.len() as f32;
|
||||
for v in &mut centroid {
|
||||
*v /= n;
|
||||
}
|
||||
|
||||
centroid
|
||||
}
|
||||
|
||||
/// Partition session embeddings into target aspects using MinCut.
|
||||
///
|
||||
/// Returns the MinCut-based partition assignments and centroids.
|
||||
pub fn partition(
|
||||
&self,
|
||||
embeddings: &[Vec<f32>],
|
||||
stage_labels: &[(u8, usize)], // (stage, entry_index) for each embedding
|
||||
) -> CrvResult<StageVIData> {
|
||||
if embeddings.len() < 2 {
|
||||
// With fewer than 2 embeddings, return a single partition
|
||||
let centroid = if embeddings.is_empty() {
|
||||
vec![0.0; self.dim]
|
||||
} else {
|
||||
embeddings[0].clone()
|
||||
};
|
||||
|
||||
return Ok(StageVIData {
|
||||
partitions: vec![TargetPartition {
|
||||
label: "primary".to_string(),
|
||||
member_entries: stage_labels.to_vec(),
|
||||
centroid,
|
||||
separation_strength: 0.0,
|
||||
}],
|
||||
composite_description: "Single-aspect target".to_string(),
|
||||
partition_confidence: vec![1.0],
|
||||
});
|
||||
}
|
||||
|
||||
// Build similarity graph
|
||||
let edges = self.build_similarity_graph(embeddings);
|
||||
|
||||
if edges.is_empty() {
|
||||
// No significant similarities found - each embedding is its own partition
|
||||
let partitions: Vec<TargetPartition> = embeddings
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, emb)| TargetPartition {
|
||||
label: format!("aspect-{}", i),
|
||||
member_entries: if i < stage_labels.len() {
|
||||
vec![stage_labels[i]]
|
||||
} else {
|
||||
vec![]
|
||||
},
|
||||
centroid: emb.clone(),
|
||||
separation_strength: 1.0,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let n = partitions.len();
|
||||
return Ok(StageVIData {
|
||||
partitions,
|
||||
composite_description: format!("{} disconnected aspects", n),
|
||||
partition_confidence: vec![0.5; n],
|
||||
});
|
||||
}
|
||||
|
||||
// Build MinCut structure
|
||||
let mincut_result = MinCutBuilder::new()
|
||||
.exact()
|
||||
.with_edges(edges.clone())
|
||||
.build();
|
||||
|
||||
let mincut = match mincut_result {
|
||||
Ok(mc) => mc,
|
||||
Err(_) => {
|
||||
// Fallback: single partition
|
||||
let centroid = self.compute_centroid(
|
||||
&embeddings.iter().map(|e| e.as_slice()).collect::<Vec<_>>(),
|
||||
);
|
||||
return Ok(StageVIData {
|
||||
partitions: vec![TargetPartition {
|
||||
label: "composite".to_string(),
|
||||
member_entries: stage_labels.to_vec(),
|
||||
centroid,
|
||||
separation_strength: 0.0,
|
||||
}],
|
||||
composite_description: "Unified composite model".to_string(),
|
||||
partition_confidence: vec![0.8],
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let cut_value = mincut.min_cut_value();
|
||||
|
||||
// Use the MinCut value to determine partition boundary.
|
||||
// We partition into two groups based on connectivity:
|
||||
// vertices more connected to the "left" side vs "right" side.
|
||||
let n = embeddings.len();
|
||||
|
||||
// Simple 2-partition based on similarity to first vs last embedding
|
||||
let (group_a, group_b) = self.bisect_by_similarity(embeddings);
|
||||
|
||||
let centroid_a = self.compute_centroid(
|
||||
&group_a.iter().map(|&i| embeddings[i].as_slice()).collect::<Vec<_>>(),
|
||||
);
|
||||
let centroid_b = self.compute_centroid(
|
||||
&group_b.iter().map(|&i| embeddings[i].as_slice()).collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let members_a: Vec<(u8, usize)> = group_a
|
||||
.iter()
|
||||
.filter_map(|&i| stage_labels.get(i).copied())
|
||||
.collect();
|
||||
let members_b: Vec<(u8, usize)> = group_b
|
||||
.iter()
|
||||
.filter_map(|&i| stage_labels.get(i).copied())
|
||||
.collect();
|
||||
|
||||
let partitions = vec![
|
||||
TargetPartition {
|
||||
label: "primary-aspect".to_string(),
|
||||
member_entries: members_a,
|
||||
centroid: centroid_a,
|
||||
separation_strength: cut_value as f32,
|
||||
},
|
||||
TargetPartition {
|
||||
label: "secondary-aspect".to_string(),
|
||||
member_entries: members_b,
|
||||
centroid: centroid_b,
|
||||
separation_strength: cut_value as f32,
|
||||
},
|
||||
];
|
||||
|
||||
// Confidence based on separation strength
|
||||
let total_edges = edges.len() as f32;
|
||||
let conf = if total_edges > 0.0 {
|
||||
(cut_value as f32 / total_edges).min(1.0)
|
||||
} else {
|
||||
0.5
|
||||
};
|
||||
|
||||
Ok(StageVIData {
|
||||
partitions,
|
||||
composite_description: format!(
|
||||
"Bisected composite: {} embeddings, cut value {:.3}",
|
||||
n, cut_value
|
||||
),
|
||||
partition_confidence: vec![conf, conf],
|
||||
})
|
||||
}
|
||||
|
||||
/// Bisect embeddings into two groups by maximizing inter-group dissimilarity.
|
||||
///
|
||||
/// Uses a greedy approach: pick the two most dissimilar embeddings as seeds,
|
||||
/// then assign each remaining embedding to the nearer seed.
|
||||
fn bisect_by_similarity(&self, embeddings: &[Vec<f32>]) -> (Vec<usize>, Vec<usize>) {
|
||||
let n = embeddings.len();
|
||||
if n <= 1 {
|
||||
return ((0..n).collect(), vec![]);
|
||||
}
|
||||
|
||||
// Find the two most dissimilar embeddings
|
||||
let mut min_sim = f32::MAX;
|
||||
let mut seed_a = 0;
|
||||
let mut seed_b = 1;
|
||||
|
||||
for i in 0..n {
|
||||
for j in (i + 1)..n {
|
||||
if embeddings[i].len() == embeddings[j].len() && !embeddings[i].is_empty() {
|
||||
let sim = cosine_similarity(&embeddings[i], &embeddings[j]);
|
||||
if sim < min_sim {
|
||||
min_sim = sim;
|
||||
seed_a = i;
|
||||
seed_b = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut group_a = vec![seed_a];
|
||||
let mut group_b = vec![seed_b];
|
||||
|
||||
for i in 0..n {
|
||||
if i == seed_a || i == seed_b {
|
||||
continue;
|
||||
}
|
||||
|
||||
let sim_a = if embeddings[i].len() == embeddings[seed_a].len() {
|
||||
cosine_similarity(&embeddings[i], &embeddings[seed_a])
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let sim_b = if embeddings[i].len() == embeddings[seed_b].len() {
|
||||
cosine_similarity(&embeddings[i], &embeddings[seed_b])
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if sim_a >= sim_b {
|
||||
group_a.push(i);
|
||||
} else {
|
||||
group_b.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
(group_a, group_b)
|
||||
}
|
||||
|
||||
/// Encode the Stage VI partition result into a single embedding.
|
||||
///
|
||||
/// Produces a weighted combination of partition centroids.
|
||||
pub fn encode(&self, data: &StageVIData) -> CrvResult<Vec<f32>> {
|
||||
if data.partitions.is_empty() {
|
||||
return Err(CrvError::EmptyInput("No partitions".to_string()));
|
||||
}
|
||||
|
||||
let mut embedding = vec![0.0f32; self.dim];
|
||||
let mut total_weight = 0.0f32;
|
||||
|
||||
for (partition, &confidence) in data.partitions.iter().zip(data.partition_confidence.iter()) {
|
||||
let weight = confidence * partition.member_entries.len() as f32;
|
||||
for (i, &v) in partition.centroid.iter().enumerate() {
|
||||
if i < self.dim {
|
||||
embedding[i] += v * weight;
|
||||
}
|
||||
}
|
||||
total_weight += weight;
|
||||
}
|
||||
|
||||
if total_weight > 1e-6 {
|
||||
for v in &mut embedding {
|
||||
*v /= total_weight;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> CrvConfig {
|
||||
CrvConfig {
|
||||
dimensions: 8,
|
||||
..CrvConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modeler_creation() {
|
||||
let config = test_config();
|
||||
let modeler = StageVIModeler::new(&config);
|
||||
assert_eq!(modeler.dim, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partition_single() {
|
||||
let config = test_config();
|
||||
let modeler = StageVIModeler::new(&config);
|
||||
|
||||
let embeddings = vec![vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]];
|
||||
let labels = vec![(1, 0)];
|
||||
|
||||
let result = modeler.partition(&embeddings, &labels).unwrap();
|
||||
assert_eq!(result.partitions.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partition_two_clusters() {
|
||||
let config = test_config();
|
||||
let modeler = StageVIModeler::new(&config);
|
||||
|
||||
// Two clearly separated clusters
|
||||
let embeddings = vec![
|
||||
vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
vec![0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
|
||||
vec![0.0, 0.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0],
|
||||
];
|
||||
let labels = vec![(1, 0), (2, 0), (3, 0), (4, 0)];
|
||||
|
||||
let result = modeler.partition(&embeddings, &labels).unwrap();
|
||||
assert_eq!(result.partitions.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_partitions() {
|
||||
let config = test_config();
|
||||
let modeler = StageVIModeler::new(&config);
|
||||
|
||||
let data = StageVIData {
|
||||
partitions: vec![
|
||||
TargetPartition {
|
||||
label: "a".to_string(),
|
||||
member_entries: vec![(1, 0), (2, 0)],
|
||||
centroid: vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
separation_strength: 0.5,
|
||||
},
|
||||
TargetPartition {
|
||||
label: "b".to_string(),
|
||||
member_entries: vec![(3, 0)],
|
||||
centroid: vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
separation_strength: 0.5,
|
||||
},
|
||||
],
|
||||
composite_description: "test".to_string(),
|
||||
partition_confidence: vec![0.8, 0.6],
|
||||
};
|
||||
|
||||
let embedding = modeler.encode(&data).unwrap();
|
||||
assert_eq!(embedding.len(), 8);
|
||||
}
|
||||
}
|
||||
360
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/types.rs
Normal file
360
rust-port/wifi-densepose-rs/patches/ruvector-crv/src/types.rs
Normal file
@@ -0,0 +1,360 @@
|
||||
//! Core types for the CRV (Coordinate Remote Viewing) protocol.
|
||||
//!
|
||||
//! Defines the data structures for the 6-stage CRV signal line methodology,
|
||||
//! session management, and analytical overlay (AOL) detection.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Unique identifier for a CRV session.
|
||||
pub type SessionId = String;
|
||||
|
||||
/// Unique identifier for a target coordinate.
|
||||
pub type TargetCoordinate = String;
|
||||
|
||||
/// Unique identifier for a stage data entry.
|
||||
pub type EntryId = String;
|
||||
|
||||
/// Classification of gestalt primitives in Stage I.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum GestaltType {
|
||||
/// Human-made structures, artifacts
|
||||
Manmade,
|
||||
/// Organic, natural formations
|
||||
Natural,
|
||||
/// Dynamic, kinetic signals
|
||||
Movement,
|
||||
/// Thermal, electromagnetic, force
|
||||
Energy,
|
||||
/// Aqueous, fluid, wet
|
||||
Water,
|
||||
/// Solid, terrain, geological
|
||||
Land,
|
||||
}
|
||||
|
||||
impl GestaltType {
|
||||
/// Returns all gestalt types for iteration.
|
||||
pub fn all() -> &'static [GestaltType] {
|
||||
&[
|
||||
GestaltType::Manmade,
|
||||
GestaltType::Natural,
|
||||
GestaltType::Movement,
|
||||
GestaltType::Energy,
|
||||
GestaltType::Water,
|
||||
GestaltType::Land,
|
||||
]
|
||||
}
|
||||
|
||||
/// Returns the index of this gestalt type in the canonical ordering.
|
||||
pub fn index(&self) -> usize {
|
||||
match self {
|
||||
GestaltType::Manmade => 0,
|
||||
GestaltType::Natural => 1,
|
||||
GestaltType::Movement => 2,
|
||||
GestaltType::Energy => 3,
|
||||
GestaltType::Water => 4,
|
||||
GestaltType::Land => 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stage I data: Ideogram traces and gestalt classifications.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StageIData {
|
||||
/// Raw ideogram stroke trace as a sequence of (x, y) coordinates.
|
||||
pub stroke: Vec<(f32, f32)>,
|
||||
/// First spontaneous descriptor word.
|
||||
pub spontaneous_descriptor: String,
|
||||
/// Classified gestalt type.
|
||||
pub classification: GestaltType,
|
||||
/// Confidence in the classification (0.0 - 1.0).
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
/// Sensory modality for Stage II data.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum SensoryModality {
|
||||
/// Surface textures (smooth, rough, grainy, etc.)
|
||||
Texture,
|
||||
/// Visual colors and patterns
|
||||
Color,
|
||||
/// Thermal impressions (hot, cold, warm)
|
||||
Temperature,
|
||||
/// Auditory impressions
|
||||
Sound,
|
||||
/// Olfactory impressions
|
||||
Smell,
|
||||
/// Taste impressions
|
||||
Taste,
|
||||
/// Size/scale impressions (large, small, vast)
|
||||
Dimension,
|
||||
/// Luminosity (bright, dark, glowing)
|
||||
Luminosity,
|
||||
}
|
||||
|
||||
/// Stage II data: Sensory impressions.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StageIIData {
|
||||
/// Sensory impressions as modality-descriptor pairs.
|
||||
pub impressions: Vec<(SensoryModality, String)>,
|
||||
/// Raw sensory feature vector (encoded from descriptors).
|
||||
pub feature_vector: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
/// Stage III data: Dimensional and spatial relationships.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StageIIIData {
|
||||
/// Spatial sketch as a set of named geometric primitives.
|
||||
pub sketch_elements: Vec<SketchElement>,
|
||||
/// Spatial relationships between elements.
|
||||
pub relationships: Vec<SpatialRelationship>,
|
||||
}
|
||||
|
||||
/// A geometric element in a Stage III sketch.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SketchElement {
|
||||
/// Unique label for this element.
|
||||
pub label: String,
|
||||
/// Type of geometric primitive.
|
||||
pub kind: GeometricKind,
|
||||
/// Position in sketch space (x, y).
|
||||
pub position: (f32, f32),
|
||||
/// Optional size/scale.
|
||||
pub scale: Option<f32>,
|
||||
}
|
||||
|
||||
/// Types of geometric primitives.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GeometricKind {
|
||||
Point,
|
||||
Line,
|
||||
Curve,
|
||||
Rectangle,
|
||||
Circle,
|
||||
Triangle,
|
||||
Polygon,
|
||||
Freeform,
|
||||
}
|
||||
|
||||
/// Spatial relationship between two sketch elements.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SpatialRelationship {
|
||||
/// Source element label.
|
||||
pub from: String,
|
||||
/// Target element label.
|
||||
pub to: String,
|
||||
/// Relationship type.
|
||||
pub relation: SpatialRelationType,
|
||||
/// Strength of the relationship (0.0 - 1.0).
|
||||
pub strength: f32,
|
||||
}
|
||||
|
||||
/// Types of spatial relationships.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum SpatialRelationType {
|
||||
Adjacent,
|
||||
Contains,
|
||||
Above,
|
||||
Below,
|
||||
Inside,
|
||||
Surrounding,
|
||||
Connected,
|
||||
Separated,
|
||||
}
|
||||
|
||||
/// Stage IV data: Emotional, aesthetic, and intangible impressions.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StageIVData {
|
||||
/// Emotional impact descriptors with intensity.
|
||||
pub emotional_impact: Vec<(String, f32)>,
|
||||
/// Tangible object impressions.
|
||||
pub tangibles: Vec<String>,
|
||||
/// Intangible concept impressions (purpose, function, significance).
|
||||
pub intangibles: Vec<String>,
|
||||
/// Analytical overlay detections with timestamps.
|
||||
pub aol_detections: Vec<AOLDetection>,
|
||||
}
|
||||
|
||||
/// An analytical overlay (AOL) detection event.
|
||||
///
|
||||
/// AOL occurs when the viewer's analytical mind attempts to assign
|
||||
/// a known label/concept to incoming signal line data, potentially
|
||||
/// contaminating the raw perception.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AOLDetection {
|
||||
/// The AOL content (what the viewer's mind jumped to).
|
||||
pub content: String,
|
||||
/// Timestamp within the session (milliseconds from start).
|
||||
pub timestamp_ms: u64,
|
||||
/// Whether it was flagged and set aside ("AOL break").
|
||||
pub flagged: bool,
|
||||
/// Anomaly score from spike rate analysis (0.0 - 1.0).
|
||||
/// Higher scores indicate stronger AOL contamination.
|
||||
pub anomaly_score: f32,
|
||||
}
|
||||
|
||||
/// Stage V data: Interrogation and cross-referencing results.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StageVData {
|
||||
/// Probe queries and their results.
|
||||
pub probes: Vec<SignalLineProbe>,
|
||||
/// Cross-references to data from earlier stages.
|
||||
pub cross_references: Vec<CrossReference>,
|
||||
}
|
||||
|
||||
/// A signal line probe query.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SignalLineProbe {
|
||||
/// The question or aspect being probed.
|
||||
pub query: String,
|
||||
/// Stage being interrogated.
|
||||
pub target_stage: u8,
|
||||
/// Resulting soft attention weights over candidates.
|
||||
pub attention_weights: Vec<f32>,
|
||||
/// Top-k candidate indices from differentiable search.
|
||||
pub top_candidates: Vec<usize>,
|
||||
}
|
||||
|
||||
/// A cross-reference between stage data entries.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrossReference {
|
||||
/// Source stage number.
|
||||
pub from_stage: u8,
|
||||
/// Source entry index.
|
||||
pub from_entry: usize,
|
||||
/// Target stage number.
|
||||
pub to_stage: u8,
|
||||
/// Target entry index.
|
||||
pub to_entry: usize,
|
||||
/// Similarity/relevance score.
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
/// Stage VI data: Composite 3D model from accumulated session data.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StageVIData {
|
||||
/// Cluster partitions discovered by MinCut.
|
||||
pub partitions: Vec<TargetPartition>,
|
||||
/// Overall composite descriptor.
|
||||
pub composite_description: String,
|
||||
/// Confidence scores per partition.
|
||||
pub partition_confidence: Vec<f32>,
|
||||
}
|
||||
|
||||
/// A partition of the target, representing a distinct aspect or component.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TargetPartition {
|
||||
/// Human-readable label for this partition.
|
||||
pub label: String,
|
||||
/// Stage data entry indices that belong to this partition.
|
||||
pub member_entries: Vec<(u8, usize)>,
|
||||
/// Centroid embedding of this partition.
|
||||
pub centroid: Vec<f32>,
|
||||
/// MinCut value separating this partition from others.
|
||||
pub separation_strength: f32,
|
||||
}
|
||||
|
||||
/// A complete CRV session entry stored in the database.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrvSessionEntry {
|
||||
/// Session identifier.
|
||||
pub session_id: SessionId,
|
||||
/// Target coordinate.
|
||||
pub coordinate: TargetCoordinate,
|
||||
/// CRV stage (1-6).
|
||||
pub stage: u8,
|
||||
/// Embedding vector for this entry.
|
||||
pub embedding: Vec<f32>,
|
||||
/// Arbitrary metadata.
|
||||
pub metadata: HashMap<String, serde_json::Value>,
|
||||
/// Timestamp in milliseconds.
|
||||
pub timestamp_ms: u64,
|
||||
}
|
||||
|
||||
/// Configuration for CRV session processing.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrvConfig {
|
||||
/// Embedding dimensionality.
|
||||
pub dimensions: usize,
|
||||
/// Curvature for Poincare ball (Stage I). Positive value.
|
||||
pub curvature: f32,
|
||||
/// AOL anomaly detection threshold (Stage IV).
|
||||
pub aol_threshold: f32,
|
||||
/// SNN refractory period in ms (Stage IV).
|
||||
pub refractory_period_ms: f64,
|
||||
/// SNN time step in ms (Stage IV).
|
||||
pub snn_dt: f64,
|
||||
/// Differentiable search temperature (Stage V).
|
||||
pub search_temperature: f32,
|
||||
/// Convergence threshold for cross-session matching.
|
||||
pub convergence_threshold: f32,
|
||||
}
|
||||
|
||||
impl Default for CrvConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
dimensions: 384,
|
||||
curvature: 1.0,
|
||||
aol_threshold: 0.7,
|
||||
refractory_period_ms: 50.0,
|
||||
snn_dt: 1.0,
|
||||
search_temperature: 1.0,
|
||||
convergence_threshold: 0.75,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a convergence analysis across multiple sessions.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ConvergenceResult {
|
||||
/// Session pairs that converged.
|
||||
pub session_pairs: Vec<(SessionId, SessionId)>,
|
||||
/// Convergence scores per pair.
|
||||
pub scores: Vec<f32>,
|
||||
/// Stages where convergence was strongest.
|
||||
pub convergent_stages: Vec<u8>,
|
||||
/// Merged embedding representing the consensus signal.
|
||||
pub consensus_embedding: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_gestalt_type_all() {
|
||||
let all = GestaltType::all();
|
||||
assert_eq!(all.len(), 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gestalt_type_index() {
|
||||
assert_eq!(GestaltType::Manmade.index(), 0);
|
||||
assert_eq!(GestaltType::Land.index(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = CrvConfig::default();
|
||||
assert_eq!(config.dimensions, 384);
|
||||
assert_eq!(config.curvature, 1.0);
|
||||
assert_eq!(config.aol_threshold, 0.7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_entry_serialization() {
|
||||
let entry = CrvSessionEntry {
|
||||
session_id: "sess-001".to_string(),
|
||||
coordinate: "1234-5678".to_string(),
|
||||
stage: 1,
|
||||
embedding: vec![0.1, 0.2, 0.3],
|
||||
metadata: HashMap::new(),
|
||||
timestamp_ms: 1000,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&entry).unwrap();
|
||||
let deserialized: CrvSessionEntry = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.session_id, "sess-001");
|
||||
assert_eq!(deserialized.stage, 1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user