Compare commits
7 Commits
feat/rust-
...
feat/windo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e245ca8a4 | ||
|
|
45f0304d52 | ||
|
|
4cabffa726 | ||
|
|
3e06970428 | ||
|
|
add9f192aa | ||
|
|
fc409dfd6a | ||
|
|
1192de951a |
138
.dockerignore
138
.dockerignore
@@ -1,132 +1,8 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
.gitattributes
|
||||
|
||||
# Documentation
|
||||
*.md
|
||||
docs/
|
||||
references/
|
||||
plans/
|
||||
|
||||
# Development files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# Virtual environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Testing
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
.pytest_cache/
|
||||
htmlcov/
|
||||
.nox/
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# Environments
|
||||
.env.local
|
||||
.env.development
|
||||
.env.test
|
||||
.env.production
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
target/
|
||||
.git/
|
||||
*.log
|
||||
|
||||
# Runtime data
|
||||
pids/
|
||||
*.pid
|
||||
*.seed
|
||||
*.pid.lock
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
.tmp/
|
||||
|
||||
# OS generated files
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# IDE
|
||||
*.sublime-project
|
||||
*.sublime-workspace
|
||||
|
||||
# Deployment
|
||||
docker-compose*.yml
|
||||
Dockerfile*
|
||||
.dockerignore
|
||||
k8s/
|
||||
terraform/
|
||||
ansible/
|
||||
monitoring/
|
||||
logging/
|
||||
|
||||
# CI/CD
|
||||
.github/
|
||||
.gitlab-ci.yml
|
||||
|
||||
# Models (exclude large model files from build context)
|
||||
*.pth
|
||||
*.pt
|
||||
*.onnx
|
||||
models/*.bin
|
||||
models/*.safetensors
|
||||
|
||||
# Data files
|
||||
data/
|
||||
*.csv
|
||||
*.json
|
||||
*.parquet
|
||||
|
||||
# Backup files
|
||||
*.bak
|
||||
*.backup
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.env
|
||||
node_modules/
|
||||
.claude/
|
||||
|
||||
104
Dockerfile
104
Dockerfile
@@ -1,104 +0,0 @@
|
||||
# Multi-stage build for WiFi-DensePose production deployment
|
||||
FROM python:3.11-slim as base
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PIP_NO_CACHE_DIR=1 \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
curl \
|
||||
git \
|
||||
libopencv-dev \
|
||||
python3-opencv \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create app user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
# Set work directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy requirements first for better caching
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Development stage
|
||||
FROM base as development
|
||||
|
||||
# Install development dependencies
|
||||
RUN pip install --no-cache-dir \
|
||||
pytest \
|
||||
pytest-asyncio \
|
||||
pytest-mock \
|
||||
pytest-benchmark \
|
||||
black \
|
||||
flake8 \
|
||||
mypy
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Change ownership to app user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Development command
|
||||
CMD ["uvicorn", "v1.src.api.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
|
||||
# Production stage
|
||||
FROM base as production
|
||||
|
||||
# Copy only necessary files
|
||||
COPY requirements.txt .
|
||||
COPY v1/src/ ./v1/src/
|
||||
COPY assets/ ./assets/
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /app/logs /app/data /app/models
|
||||
|
||||
# Change ownership to app user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Production command
|
||||
CMD ["uvicorn", "v1.src.api.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||
|
||||
# Testing stage
|
||||
FROM development as testing
|
||||
|
||||
# Copy test files
|
||||
COPY v1/tests/ ./v1/tests/
|
||||
|
||||
# Run tests
|
||||
RUN python -m pytest v1/tests/ -v
|
||||
|
||||
# Security scanning stage
|
||||
FROM production as security
|
||||
|
||||
# Install security scanning tools
|
||||
USER root
|
||||
RUN pip install --no-cache-dir safety bandit
|
||||
|
||||
# Run security scans
|
||||
RUN safety check
|
||||
RUN bandit -r v1/src/ -f json -o /tmp/bandit-report.json
|
||||
|
||||
USER appuser
|
||||
BIN
assets/screenshot.png
Normal file
BIN
assets/screenshot.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 401 KiB |
@@ -1,306 +0,0 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
wifi-densepose:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: production
|
||||
image: wifi-densepose:latest
|
||||
container_name: wifi-densepose-prod
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- wifi_densepose_logs:/app/logs
|
||||
- wifi_densepose_data:/app/data
|
||||
- wifi_densepose_models:/app/models
|
||||
environment:
|
||||
- ENVIRONMENT=production
|
||||
- DEBUG=false
|
||||
- LOG_LEVEL=info
|
||||
- RELOAD=false
|
||||
- WORKERS=4
|
||||
- ENABLE_TEST_ENDPOINTS=false
|
||||
- ENABLE_AUTHENTICATION=true
|
||||
- ENABLE_RATE_LIMITING=true
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
- REDIS_URL=${REDIS_URL}
|
||||
- SECRET_KEY=${SECRET_KEY}
|
||||
- JWT_SECRET=${JWT_SECRET}
|
||||
- ALLOWED_HOSTS=${ALLOWED_HOSTS}
|
||||
secrets:
|
||||
- db_password
|
||||
- redis_password
|
||||
- jwt_secret
|
||||
- api_key
|
||||
deploy:
|
||||
replicas: 3
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 5s
|
||||
max_attempts: 3
|
||||
window: 120s
|
||||
update_config:
|
||||
parallelism: 1
|
||||
delay: 10s
|
||||
failure_action: rollback
|
||||
monitor: 60s
|
||||
max_failure_ratio: 0.3
|
||||
rollback_config:
|
||||
parallelism: 1
|
||||
delay: 0s
|
||||
failure_action: pause
|
||||
monitor: 60s
|
||||
max_failure_ratio: 0.3
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2.0'
|
||||
memory: 4G
|
||||
reservations:
|
||||
cpus: '1.0'
|
||||
memory: 2G
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
- monitoring-network
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
container_name: wifi-densepose-postgres-prod
|
||||
environment:
|
||||
- POSTGRES_DB=${POSTGRES_DB}
|
||||
- POSTGRES_USER=${POSTGRES_USER}
|
||||
- POSTGRES_PASSWORD_FILE=/run/secrets/db_password
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./scripts/init-db.sql:/docker-entrypoint-initdb.d/init-db.sql
|
||||
- ./backups:/backups
|
||||
secrets:
|
||||
- db_password
|
||||
deploy:
|
||||
replicas: 1
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 5s
|
||||
max_attempts: 3
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1.0'
|
||||
memory: 2G
|
||||
reservations:
|
||||
cpus: '0.5'
|
||||
memory: 1G
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: wifi-densepose-redis-prod
|
||||
command: redis-server --appendonly yes --requirepass-file /run/secrets/redis_password
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
secrets:
|
||||
- redis_password
|
||||
deploy:
|
||||
replicas: 1
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 5s
|
||||
max_attempts: 3
|
||||
resources:
|
||||
limits:
|
||||
cpus: '0.5'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 512M
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
nginx:
|
||||
image: nginx:alpine
|
||||
container_name: wifi-densepose-nginx-prod
|
||||
volumes:
|
||||
- ./nginx/nginx.prod.conf:/etc/nginx/nginx.conf
|
||||
- ./nginx/ssl:/etc/nginx/ssl
|
||||
- nginx_logs:/var/log/nginx
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
deploy:
|
||||
replicas: 2
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 5s
|
||||
max_attempts: 3
|
||||
resources:
|
||||
limits:
|
||||
cpus: '0.5'
|
||||
memory: 512M
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 256M
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
depends_on:
|
||||
- wifi-densepose
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
container_name: wifi-densepose-prometheus-prod
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
- '--storage.tsdb.path=/prometheus'
|
||||
- '--web.console.libraries=/etc/prometheus/console_libraries'
|
||||
- '--web.console.templates=/etc/prometheus/consoles'
|
||||
- '--storage.tsdb.retention.time=15d'
|
||||
- '--web.enable-lifecycle'
|
||||
- '--web.enable-admin-api'
|
||||
volumes:
|
||||
- ./monitoring/prometheus-config.yml:/etc/prometheus/prometheus.yml
|
||||
- ./monitoring/alerting-rules.yml:/etc/prometheus/alerting-rules.yml
|
||||
- prometheus_data:/prometheus
|
||||
deploy:
|
||||
replicas: 1
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 5s
|
||||
max_attempts: 3
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1.0'
|
||||
memory: 2G
|
||||
reservations:
|
||||
cpus: '0.5'
|
||||
memory: 1G
|
||||
networks:
|
||||
- monitoring-network
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:9090/-/healthy"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
container_name: wifi-densepose-grafana-prod
|
||||
environment:
|
||||
- GF_SECURITY_ADMIN_PASSWORD_FILE=/run/secrets/grafana_password
|
||||
- GF_USERS_ALLOW_SIGN_UP=false
|
||||
- GF_INSTALL_PLUGINS=grafana-piechart-panel
|
||||
volumes:
|
||||
- grafana_data:/var/lib/grafana
|
||||
- ./monitoring/grafana-dashboard.json:/etc/grafana/provisioning/dashboards/dashboard.json
|
||||
- ./monitoring/grafana-datasources.yml:/etc/grafana/provisioning/datasources/datasources.yml
|
||||
secrets:
|
||||
- grafana_password
|
||||
deploy:
|
||||
replicas: 1
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 5s
|
||||
max_attempts: 3
|
||||
resources:
|
||||
limits:
|
||||
cpus: '0.5'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 512M
|
||||
networks:
|
||||
- monitoring-network
|
||||
depends_on:
|
||||
- prometheus
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000/api/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
redis_data:
|
||||
driver: local
|
||||
prometheus_data:
|
||||
driver: local
|
||||
grafana_data:
|
||||
driver: local
|
||||
wifi_densepose_logs:
|
||||
driver: local
|
||||
wifi_densepose_data:
|
||||
driver: local
|
||||
wifi_densepose_models:
|
||||
driver: local
|
||||
nginx_logs:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
wifi-densepose-network:
|
||||
driver: overlay
|
||||
attachable: true
|
||||
monitoring-network:
|
||||
driver: overlay
|
||||
attachable: true
|
||||
|
||||
secrets:
|
||||
db_password:
|
||||
external: true
|
||||
redis_password:
|
||||
external: true
|
||||
jwt_secret:
|
||||
external: true
|
||||
api_key:
|
||||
external: true
|
||||
grafana_password:
|
||||
external: true
|
||||
@@ -1,141 +0,0 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
wifi-densepose:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: development
|
||||
container_name: wifi-densepose-dev
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- .:/app
|
||||
- wifi_densepose_logs:/app/logs
|
||||
- wifi_densepose_data:/app/data
|
||||
- wifi_densepose_models:/app/models
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- DEBUG=true
|
||||
- LOG_LEVEL=debug
|
||||
- RELOAD=true
|
||||
- ENABLE_TEST_ENDPOINTS=true
|
||||
- ENABLE_AUTHENTICATION=false
|
||||
- ENABLE_RATE_LIMITING=false
|
||||
- DATABASE_URL=postgresql://wifi_user:wifi_pass@postgres:5432/wifi_densepose
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
depends_on:
|
||||
- postgres
|
||||
- redis
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
container_name: wifi-densepose-postgres
|
||||
environment:
|
||||
- POSTGRES_DB=wifi_densepose
|
||||
- POSTGRES_USER=wifi_user
|
||||
- POSTGRES_PASSWORD=wifi_pass
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./scripts/init-db.sql:/docker-entrypoint-initdb.d/init-db.sql
|
||||
ports:
|
||||
- "5432:5432"
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U wifi_user -d wifi_densepose"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: wifi-densepose-redis
|
||||
command: redis-server --appendonly yes --requirepass redis_pass
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
ports:
|
||||
- "6379:6379"
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
container_name: wifi-densepose-prometheus
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
- '--storage.tsdb.path=/prometheus'
|
||||
- '--web.console.libraries=/etc/prometheus/console_libraries'
|
||||
- '--web.console.templates=/etc/prometheus/consoles'
|
||||
- '--storage.tsdb.retention.time=200h'
|
||||
- '--web.enable-lifecycle'
|
||||
volumes:
|
||||
- ./monitoring/prometheus-config.yml:/etc/prometheus/prometheus.yml
|
||||
- prometheus_data:/prometheus
|
||||
ports:
|
||||
- "9090:9090"
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
restart: unless-stopped
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
container_name: wifi-densepose-grafana
|
||||
environment:
|
||||
- GF_SECURITY_ADMIN_PASSWORD=admin
|
||||
- GF_USERS_ALLOW_SIGN_UP=false
|
||||
volumes:
|
||||
- grafana_data:/var/lib/grafana
|
||||
- ./monitoring/grafana-dashboard.json:/etc/grafana/provisioning/dashboards/dashboard.json
|
||||
- ./monitoring/grafana-datasources.yml:/etc/grafana/provisioning/datasources/datasources.yml
|
||||
ports:
|
||||
- "3000:3000"
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- prometheus
|
||||
|
||||
nginx:
|
||||
image: nginx:alpine
|
||||
container_name: wifi-densepose-nginx
|
||||
volumes:
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
- ./nginx/ssl:/etc/nginx/ssl
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
networks:
|
||||
- wifi-densepose-network
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- wifi-densepose
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
redis_data:
|
||||
prometheus_data:
|
||||
grafana_data:
|
||||
wifi_densepose_logs:
|
||||
wifi_densepose_data:
|
||||
wifi_densepose_models:
|
||||
|
||||
networks:
|
||||
wifi-densepose-network:
|
||||
driver: bridge
|
||||
9
docker/.dockerignore
Normal file
9
docker/.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
||||
target/
|
||||
.git/
|
||||
*.md
|
||||
*.log
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.env
|
||||
node_modules/
|
||||
.claude/
|
||||
29
docker/Dockerfile.python
Normal file
29
docker/Dockerfile.python
Normal file
@@ -0,0 +1,29 @@
|
||||
# WiFi-DensePose Python Sensing Pipeline
|
||||
# RSSI-based presence/motion detection + WebSocket server
|
||||
|
||||
FROM python:3.11-slim-bookworm
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python dependencies
|
||||
COPY v1/requirements-lock.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r requirements.txt \
|
||||
&& pip install --no-cache-dir websockets uvicorn fastapi
|
||||
|
||||
# Copy application code
|
||||
COPY v1/ /app/v1/
|
||||
COPY ui/ /app/ui/
|
||||
|
||||
# Copy sensing modules
|
||||
COPY v1/src/sensing/ /app/v1/src/sensing/
|
||||
|
||||
EXPOSE 8765
|
||||
EXPOSE 8080
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
CMD ["python", "-m", "v1.src.sensing.ws_server"]
|
||||
46
docker/Dockerfile.rust
Normal file
46
docker/Dockerfile.rust
Normal file
@@ -0,0 +1,46 @@
|
||||
# WiFi-DensePose Rust Sensing Server
|
||||
# Includes RuVector signal intelligence crates
|
||||
# Multi-stage build for minimal final image
|
||||
|
||||
# Stage 1: Build
|
||||
FROM rust:1.85-bookworm AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# Copy workspace files
|
||||
COPY rust-port/wifi-densepose-rs/Cargo.toml rust-port/wifi-densepose-rs/Cargo.lock ./
|
||||
COPY rust-port/wifi-densepose-rs/crates/ ./crates/
|
||||
|
||||
# Copy vendored RuVector crates
|
||||
COPY vendor/ruvector/ /build/vendor/ruvector/
|
||||
|
||||
# Build release binary
|
||||
RUN cargo build --release -p wifi-densepose-sensing-server 2>&1 \
|
||||
&& strip target/release/sensing-server
|
||||
|
||||
# Stage 2: Runtime
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy binary
|
||||
COPY --from=builder /build/target/release/sensing-server /app/sensing-server
|
||||
|
||||
# Copy UI assets
|
||||
COPY ui/ /app/ui/
|
||||
|
||||
# HTTP API
|
||||
EXPOSE 3000
|
||||
# WebSocket
|
||||
EXPOSE 3001
|
||||
# ESP32 UDP
|
||||
EXPOSE 5005/udp
|
||||
|
||||
ENV RUST_LOG=info
|
||||
|
||||
ENTRYPOINT ["/app/sensing-server"]
|
||||
CMD ["--source", "simulated", "--tick-ms", "100", "--ui-path", "/app/ui"]
|
||||
26
docker/docker-compose.yml
Normal file
26
docker/docker-compose.yml
Normal file
@@ -0,0 +1,26 @@
|
||||
version: "3.9"
|
||||
|
||||
services:
|
||||
sensing-server:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile.rust
|
||||
image: ruvnet/wifi-densepose:latest
|
||||
ports:
|
||||
- "3000:3000" # REST API
|
||||
- "3001:3001" # WebSocket
|
||||
- "5005:5005/udp" # ESP32 UDP
|
||||
environment:
|
||||
- RUST_LOG=info
|
||||
command: ["--source", "simulated", "--tick-ms", "100", "--ui-path", "/app/ui"]
|
||||
|
||||
python-sensing:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile.python
|
||||
image: ruvnet/wifi-densepose:python
|
||||
ports:
|
||||
- "8765:8765" # WebSocket
|
||||
- "8080:8080" # UI
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
BIN
docker/wifi-densepose-v1.rvf
Normal file
BIN
docker/wifi-densepose-v1.rvf
Normal file
Binary file not shown.
1092
docs/adr/ADR-021-vital-sign-detection-rvdna-pipeline.md
Normal file
1092
docs/adr/ADR-021-vital-sign-detection-rvdna-pipeline.md
Normal file
File diff suppressed because it is too large
Load Diff
1357
docs/adr/ADR-022-windows-wifi-enhanced-fidelity-ruvector.md
Normal file
1357
docs/adr/ADR-022-windows-wifi-enhanced-fidelity-ruvector.md
Normal file
File diff suppressed because it is too large
Load Diff
825
docs/adr/ADR-023-trained-densepose-model-ruvector-pipeline.md
Normal file
825
docs/adr/ADR-023-trained-densepose-model-ruvector-pipeline.md
Normal file
@@ -0,0 +1,825 @@
|
||||
# ADR-023: Trained DensePose Model with RuVector Signal Intelligence Pipeline
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Status** | Proposed |
|
||||
| **Date** | 2026-02-28 |
|
||||
| **Deciders** | ruv |
|
||||
| **Relates to** | ADR-003 (RVF Cognitive Containers), ADR-005 (SONA Self-Learning), ADR-015 (Public Dataset Strategy), ADR-016 (RuVector Integration), ADR-017 (RuVector-Signal-MAT), ADR-020 (Rust AI Migration), ADR-021 (Vital Sign Detection) |
|
||||
|
||||
## Context
|
||||
|
||||
### The Gap Between Sensing and DensePose
|
||||
|
||||
The WiFi-DensePose system currently operates in two distinct modes:
|
||||
|
||||
1. **WiFi CSI sensing** (working): ESP32 streams CSI frames → Rust aggregator → feature extraction → presence/motion classification. 41 tests passing, verified at ~20 Hz with real hardware.
|
||||
|
||||
2. **Heuristic pose derivation** (working but approximate): The Rust sensing server generates 17 COCO keypoints from WiFi signal properties using hand-crafted rules (`derive_pose_from_sensing()` in `sensing-server/src/main.rs`). This is not a trained model — keypoint positions are derived from signal amplitude, phase variance, and motion metrics rather than learned from labeled data.
|
||||
|
||||
Neither mode produces **DensePose-quality** body surface estimation. The CMU "DensePose From WiFi" paper (arXiv:2301.00250) demonstrated that a neural network trained on paired WiFi CSI + camera pose data can produce dense body surface UV coordinates from WiFi alone. However, that approach requires:
|
||||
|
||||
- **Environment-specific training**: The model must be trained or fine-tuned for each deployment environment because CSI multipath patterns are environment-dependent.
|
||||
- **Paired training data**: Simultaneous WiFi CSI captures + ground-truth pose annotations (or a camera-based teacher model generating pseudo-labels).
|
||||
- **Substantial compute**: Training a modality translation network + DensePose head requires GPU time (hours to days depending on dataset size).
|
||||
|
||||
### What Exists in the Codebase
|
||||
|
||||
The Rust workspace already has the complete model architecture ready for training:
|
||||
|
||||
| Component | Crate | File | Status |
|
||||
|-----------|-------|------|--------|
|
||||
| `WiFiDensePoseModel` | `wifi-densepose-train` | `model.rs` | Implemented (random weights) |
|
||||
| `ModalityTranslator` | `wifi-densepose-train` | `model.rs` | Implemented with RuVector attention |
|
||||
| `KeypointHead` | `wifi-densepose-train` | `model.rs` | Implemented (17 COCO heatmaps) |
|
||||
| `DensePoseHead` | `wifi-densepose-nn` | `densepose.rs` | Implemented (25 parts + 48 UV) |
|
||||
| `WiFiDensePoseLoss` | `wifi-densepose-train` | `losses.rs` | Implemented (keypoint + part + UV + transfer) |
|
||||
| `MmFiDataset` loader | `wifi-densepose-train` | `dataset.rs` | Planned (ADR-015) |
|
||||
| `WiFiDensePosePipeline` | `wifi-densepose-nn` | `inference.rs` | Implemented (generic over Backend) |
|
||||
| Training proof verification | `wifi-densepose-train` | `proof.rs` | Implemented (deterministic hash) |
|
||||
| Subcarrier resampling (114→56) | `wifi-densepose-train` | `subcarrier.rs` | Planned (ADR-016) |
|
||||
|
||||
### RuVector Crates Available
|
||||
|
||||
The `vendor/ruvector/` subtree provides 90+ crates. The following are directly relevant to a trained DensePose pipeline:
|
||||
|
||||
**Already integrated (5 crates, ADR-016):**
|
||||
|
||||
| Crate | Algorithm | Current Use |
|
||||
|-------|-----------|-------------|
|
||||
| `ruvector-mincut` | Subpolynomial dynamic min-cut O(n^{o(1)}) | Multi-person assignment in `metrics.rs` |
|
||||
| `ruvector-attn-mincut` | Attention-gated min-cut | Noise-suppressed spectrogram in `model.rs` |
|
||||
| `ruvector-attention` | Scaled dot-product + geometric attention | Spatial decoder in `model.rs` |
|
||||
| `ruvector-solver` | Sparse Neumann solver O(√n) | Subcarrier resampling in `subcarrier.rs` |
|
||||
| `ruvector-temporal-tensor` | Tiered temporal compression | CSI frame buffering in `dataset.rs` |
|
||||
|
||||
**Newly proposed for DensePose pipeline (6 additional crates):**
|
||||
|
||||
| Crate | Description | Proposed Use |
|
||||
|-------|-------------|-------------|
|
||||
| `ruvector-gnn` | Graph neural network on HNSW topology | Spatial body-graph reasoning |
|
||||
| `ruvector-graph-transformer` | Proof-gated graph transformer (8 modules) | CSI-to-pose cross-attention |
|
||||
| `ruvector-sparse-inference` | PowerInfer-style sparse inference engine | Edge deployment with neuron activation sparsity |
|
||||
| `ruvector-sona` | Self-Optimizing Neural Architecture (LoRA + EWC++) | Online environment adaptation |
|
||||
| `ruvector-fpga-transformer` | FPGA-optimized transformer | Hardware-accelerated inference path |
|
||||
| `ruvector-math` | Optimal transport, information geometry | Domain adaptation loss functions |
|
||||
|
||||
### RVF Container Format
|
||||
|
||||
The RuVector Format (RVF) is a segment-based binary container format designed to package
|
||||
intelligence artifacts — embeddings, HNSW indexes, quantized weights, WASM runtimes, witness
|
||||
proofs, and metadata — into a single self-contained file. Key properties:
|
||||
|
||||
- **64-byte segment headers** (`SegmentHeader`, magic `0x52564653` "RVFS") with type discriminator, content hash, compression, and timestamp
|
||||
- **Progressive loading**: Layer A (entry points, <5ms) → Layer B (hot adjacency, 100ms–1s) → Layer C (full graph, seconds)
|
||||
- **20+ segment types**: `Vec` (embeddings), `Index` (HNSW), `Overlay` (min-cut witnesses), `Quant` (codebooks), `Witness` (proof-of-computation), `Wasm` (self-bootstrapping runtime), `Dashboard` (embedded UI), `AggregateWeights` (federated SONA deltas), `Crypto` (Ed25519 signatures), and more
|
||||
- **Temperature-tiered quantization** (`rvf-quant`): f32 / f16 / u8 / binary per-segment, with SIMD-accelerated distance computation
|
||||
- **AGI Cognitive Container** (`agi_container.rs`): packages kernel + WASM + world model + orchestrator + evaluation harness + witness chains into a single deployable file
|
||||
|
||||
The trained DensePose model will be packaged as an `.rvf` container, making it a single
|
||||
self-contained artifact that includes model weights, HNSW-indexed embedding tables, min-cut
|
||||
graph overlays, quantization codebooks, SONA adaptation deltas, and the WASM inference
|
||||
runtime — deployable to any host without external dependencies.
|
||||
|
||||
## Decision
|
||||
|
||||
Implement a fully trained DensePose model using RuVector signal intelligence as the backbone signal processing layer, packaged in the RVF container format. The pipeline has three stages: (1) offline training on public datasets, (2) teacher-student distillation for DensePose UV labels, and (3) online SONA adaptation for environment-specific fine-tuning. The trained model, its embeddings, indexes, and adaptation state are serialized into a single `.rvf` file.
|
||||
|
||||
### Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ TRAINED DENSEPOSE PIPELINE │
|
||||
│ │
|
||||
│ ┌─────────────┐ ┌──────────────────────┐ ┌──────────────────────┐ │
|
||||
│ │ ESP32 CSI │ │ RuVector Signal │ │ Trained Neural │ │
|
||||
│ │ Raw I/Q │───▶│ Intelligence Layer │───▶│ Network │ │
|
||||
│ │ [ant×sub×T] │ │ (preprocessing) │ │ (inference) │ │
|
||||
│ └─────────────┘ └──────────────────────┘ └──────────────────────┘ │
|
||||
│ │ │ │
|
||||
│ ┌─────────┴─────────┐ ┌────────┴────────┐ │
|
||||
│ │ 5 RuVector crates │ │ 6 RuVector │ │
|
||||
│ │ (signal processing)│ │ crates (neural) │ │
|
||||
│ └───────────────────┘ └─────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌──────────────────────────┘ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────────┐ │
|
||||
│ │ Outputs │ │
|
||||
│ │ • 17 COCO keypoints [B,17,H,W] │ │
|
||||
│ │ • 25 body parts [B,25,H,W] │ │
|
||||
│ │ • 48 UV coords [B,48,H,W] │ │
|
||||
│ │ • Confidence scores │ │
|
||||
│ └──────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Stage 1: RuVector Signal Preprocessing Layer
|
||||
|
||||
Raw CSI frames from ESP32 (56–192 subcarriers × N antennas × T time frames) are processed through the RuVector signal intelligence stack before entering the neural network. This replaces hand-crafted feature extraction with learned, graph-aware preprocessing.
|
||||
|
||||
```
|
||||
Raw CSI [ant, sub, T]
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ 1. ruvector-attn-mincut: gate_spectrogram() │
|
||||
│ Input: Q=amplitude, K=phase, V=combined │
|
||||
│ Effect: Suppress multipath noise, keep motion- │
|
||||
│ relevant subcarrier paths │
|
||||
│ Output: Gated spectrogram [ant, sub', T] │
|
||||
├─────────────────────────────────────────────────────┤
|
||||
│ 2. ruvector-mincut: mincut_subcarrier_partition() │
|
||||
│ Input: Subcarrier coherence graph │
|
||||
│ Effect: Partition into sensitive (motion- │
|
||||
│ responsive) vs insensitive (static) │
|
||||
│ Output: Partition mask + per-subcarrier weights │
|
||||
├─────────────────────────────────────────────────────┤
|
||||
│ 3. ruvector-attention: attention_weighted_bvp() │
|
||||
│ Input: Gated spectrogram + partition weights │
|
||||
│ Effect: Compute body velocity profile with │
|
||||
│ sensitivity-weighted attention │
|
||||
│ Output: BVP feature vector [D_bvp] │
|
||||
├─────────────────────────────────────────────────────┤
|
||||
│ 4. ruvector-solver: solve_fresnel_geometry() │
|
||||
│ Input: Amplitude + known TX/RX positions │
|
||||
│ Effect: Estimate TX-body-RX ellipsoid distances │
|
||||
│ Output: Fresnel geometry features [D_fresnel] │
|
||||
├─────────────────────────────────────────────────────┤
|
||||
│ 5. ruvector-temporal-tensor: compress + buffer │
|
||||
│ Input: Temporal CSI window (100 frames) │
|
||||
│ Effect: Tiered quantization (hot/warm/cold) │
|
||||
│ Output: Compressed tensor, 50-75% memory saving │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
Feature tensor [B, T*tx*rx, sub] (preprocessed, noise-suppressed)
|
||||
```
|
||||
|
||||
### Stage 2: Neural Network Architecture
|
||||
|
||||
The neural network follows the CMU teacher-student architecture with RuVector enhancements at three critical points.
|
||||
|
||||
#### 2a. ModalityTranslator (CSI → Visual Feature Space)
|
||||
|
||||
```
|
||||
CSI features [B, T*tx*rx, sub]
|
||||
│
|
||||
├──amplitude──┐
|
||||
│ ├─► Encoder (Conv1D stack, 64→128→256)
|
||||
└──phase──────┘ │
|
||||
▼
|
||||
┌──────────────────────────────┐
|
||||
│ ruvector-graph-transformer │
|
||||
│ │
|
||||
│ Treat antenna-pair×time as │
|
||||
│ graph nodes. Edges connect │
|
||||
│ spatially adjacent antenna │
|
||||
│ pairs and temporally │
|
||||
│ adjacent frames. │
|
||||
│ │
|
||||
│ Proof-gated attention: │
|
||||
│ Each layer verifies that │
|
||||
│ attention weights satisfy │
|
||||
│ physical constraints │
|
||||
│ (Fresnel ellipsoid bounds) │
|
||||
└──────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
Decoder (ConvTranspose2d stack, 256→128→64→3)
|
||||
│
|
||||
▼
|
||||
Visual features [B, 3, 48, 48]
|
||||
```
|
||||
|
||||
**RuVector enhancement**: Replace standard multi-head self-attention in the bottleneck with `ruvector-graph-transformer`. The graph structure encodes the physical antenna topology — nodes that are closer in space (adjacent ESP32 nodes in the mesh) or time (consecutive frames) have stronger edge weights. This injects domain-specific inductive bias that standard attention lacks.
|
||||
|
||||
#### 2b. GNN Body Graph Reasoning
|
||||
|
||||
```
|
||||
Visual features [B, 3, 48, 48]
|
||||
│
|
||||
▼
|
||||
ResNet18 backbone → feature maps [B, 256, 12, 12]
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────┐
|
||||
│ ruvector-gnn: Body Graph Network │
|
||||
│ │
|
||||
│ 17 COCO keypoints as graph nodes │
|
||||
│ Edges: anatomical connections │
|
||||
│ (shoulder→elbow, hip→knee, etc.) │
|
||||
│ │
|
||||
│ GNN message passing (3 rounds): │
|
||||
│ h_i^{l+1} = σ(W·h_i^l + Σ_j α_ij·h_j)│
|
||||
│ α_ij = attention(h_i, h_j, edge_ij) │
|
||||
│ │
|
||||
│ Enforces anatomical constraints: │
|
||||
│ - Limb length ratios │
|
||||
│ - Joint angle limits │
|
||||
│ - Left-right symmetry priors │
|
||||
└─────────────────────────────────────────┘
|
||||
│
|
||||
├──────────────────┬──────────────────┐
|
||||
▼ ▼ ▼
|
||||
KeypointHead DensePoseHead ConfidenceHead
|
||||
[B,17,H,W] [B,25+48,H,W] [B,1]
|
||||
heatmaps parts + UV quality score
|
||||
```
|
||||
|
||||
**RuVector enhancement**: `ruvector-gnn` replaces the flat spatial decoder with a graph neural network that operates on the human body graph. WiFi CSI is inherently noisy — GNN message passing between anatomically connected joints enforces that predicted keypoints maintain plausible body structure even when individual joint predictions are uncertain.
|
||||
|
||||
#### 2c. Sparse Inference for Edge Deployment
|
||||
|
||||
```
|
||||
Trained model weights (full precision)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────┐
|
||||
│ ruvector-sparse-inference │
|
||||
│ │
|
||||
│ PowerInfer-style activation sparsity: │
|
||||
│ - Profile neuron activation frequency │
|
||||
│ - Partition into hot (always active, 20%) │
|
||||
│ and cold (conditionally active, 80%) │
|
||||
│ - Hot neurons: GPU/SIMD fast path │
|
||||
│ - Cold neurons: sparse lookup on demand │
|
||||
│ │
|
||||
│ Quantization: │
|
||||
│ - Backbone: INT8 (4x memory reduction) │
|
||||
│ - DensePose head: FP16 (2x reduction) │
|
||||
│ - ModalityTranslator: FP16 │
|
||||
│ │
|
||||
│ Target: <50ms inference on ESP32-S3 │
|
||||
│ <10ms on x86 with AVX2 │
|
||||
└─────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Stage 3: Training Pipeline
|
||||
|
||||
#### 3a. Dataset Loading and Preprocessing
|
||||
|
||||
Primary dataset: **MM-Fi** (NeurIPS 2023) — 40 subjects, 27 actions, 114 subcarriers, 3 RX antennas, 17 COCO keypoints + DensePose UV annotations.
|
||||
|
||||
Secondary dataset: **Wi-Pose** — 12 subjects, 12 actions, 30 subcarriers, 3×3 antenna array, 18 keypoints.
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────┐
|
||||
│ Data Loading Pipeline │
|
||||
│ │
|
||||
│ MM-Fi .npy ──► Resample 114→56 subcarriers ──┐ │
|
||||
│ (ruvector-solver NeumannSolver) │ │
|
||||
│ ├──► Batch│
|
||||
│ Wi-Pose .mat ──► Zero-pad 30→56 subcarriers ──┘ [B,T*│
|
||||
│ ant, │
|
||||
│ Phase sanitize ──► Hampel filter ──► unwrap sub] │
|
||||
│ (wifi-densepose-signal::phase_sanitizer) │
|
||||
│ │
|
||||
│ Temporal buffer ──► ruvector-temporal-tensor │
|
||||
│ (100 frames/sample, tiered quantization) │
|
||||
└──────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### 3b. Teacher-Student DensePose Labels
|
||||
|
||||
For samples with 3D keypoints but no DensePose UV maps:
|
||||
|
||||
1. Run Detectron2 DensePose R-CNN on paired RGB frames (one-time preprocessing step on GPU workstation)
|
||||
2. Generate `(part_labels [H,W], u_coords [H,W], v_coords [H,W])` pseudo-labels
|
||||
3. Cache as `.npy` alongside original data
|
||||
4. Teacher model is discarded after label generation — inference uses WiFi only
|
||||
|
||||
#### 3c. Loss Function
|
||||
|
||||
```rust
|
||||
L_total = λ_kp · L_keypoint // MSE on predicted vs GT heatmaps
|
||||
+ λ_part · L_part // Cross-entropy on 25-class body part segmentation
|
||||
+ λ_uv · L_uv // Smooth L1 on UV coordinate regression
|
||||
+ λ_xfer · L_transfer // MSE between CSI features and teacher visual features
|
||||
+ λ_ot · L_ot // Optimal transport regularization (ruvector-math)
|
||||
+ λ_graph · L_graph // GNN edge consistency loss (ruvector-gnn)
|
||||
```
|
||||
|
||||
**RuVector enhancement**: `ruvector-math` provides optimal transport (Wasserstein distance) as a regularization term. This penalizes predicted body part distributions that are far from the ground truth in the Wasserstein metric, which is more geometrically meaningful than pixel-wise cross-entropy for spatial body part segmentation.
|
||||
|
||||
#### 3d. Training Configuration
|
||||
|
||||
| Parameter | Value | Rationale |
|
||||
|-----------|-------|-----------|
|
||||
| Optimizer | AdamW | Weight decay regularization |
|
||||
| Learning rate | 1e-3, cosine decay to 1e-5 | Standard for modality translation |
|
||||
| Batch size | 32 | Fits in 24GB GPU VRAM |
|
||||
| Epochs | 100 | With early stopping (patience=15) |
|
||||
| Warmup | 5 epochs | Linear LR warmup |
|
||||
| Train/val split | Subjects 1-32 / 33-40 | Subject-disjoint for generalization |
|
||||
| Augmentation | Time-shift ±5 frames, amplitude noise ±2dB, antenna dropout 10% | CSI-domain augmentations |
|
||||
| Hardware | Single RTX 3090 or A100 | ~8 hours on A100 |
|
||||
| Checkpoint | Every epoch, keep best-by-validation-PCK | Deterministic seed |
|
||||
|
||||
#### 3e. Metrics
|
||||
|
||||
| Metric | Target | Description |
|
||||
|--------|--------|-------------|
|
||||
| PCK@0.2 | >70% on MM-Fi val | Percentage of correct keypoints (threshold = 0.2 × torso diameter) |
|
||||
| OKS mAP | >0.50 on MM-Fi val | Object Keypoint Similarity, COCO-standard |
|
||||
| DensePose GPS | >0.30 on MM-Fi val | Geodesic Point Similarity for UV accuracy |
|
||||
| Inference latency | <50ms per frame | Worst-case bound (ESP32-S3 with sparse inference); x86 targets ~10–15ms (see Stages 2c, 5) |
|
||||
| Model size | <25MB (FP16) | Suitable for edge deployment |
|
||||
|
||||
### Stage 4: Online Adaptation with SONA
|
||||
|
||||
After offline training produces a base model, SONA enables continuous adaptation to new environments without retraining from scratch.
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────┐
|
||||
│ SONA Online Adaptation Loop │
|
||||
│ │
|
||||
│ Base model (frozen weights W) │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────┐ │
|
||||
│ │ LoRA Adaptation Matrices │ │
|
||||
│ │ W_effective = W + α · A·B │ │
|
||||
│ │ │ │
|
||||
│ │ Rank r=4 for translator layers │ │
|
||||
│ │ Rank r=2 for backbone layers │ │
|
||||
│ │ Rank r=8 for DensePose head │ │
|
||||
│ │ │ │
|
||||
│ │ Total trainable params: ~50K │ │
|
||||
│ │ (vs ~5M frozen base) │ │
|
||||
│ └──────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────┐ │
|
||||
│ │ EWC++ Regularizer │ │
|
||||
│ │ L = L_task + λ·Σ F_i(θ-θ*)² │ │
|
||||
│ │ │ │
|
||||
│ │ Prevents forgetting base model │ │
|
||||
│ │ knowledge when adapting to new │ │
|
||||
│ │ environment │ │
|
||||
│ └──────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ Adaptation triggers: │
|
||||
│ • First deployment in new room │
|
||||
│ • PCK drops below threshold (drift detection) │
|
||||
│ • User manually initiates calibration │
|
||||
│ • Furniture/layout change detected (CSI baseline shift) │
|
||||
│ │
|
||||
│ Adaptation data: │
|
||||
│ • Self-supervised: temporal consistency loss │
|
||||
│ (pose at t should be similar to t-1 for slow motion) │
|
||||
│ • Semi-supervised: user confirmation of presence/count │
|
||||
│ • Optional: brief camera calibration session (5 min) │
|
||||
│ │
|
||||
│ Convergence: 10-50 gradient steps, <5 seconds on CPU │
|
||||
└──────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Stage 5: Inference Pipeline (Production)
|
||||
|
||||
```
|
||||
ESP32 CSI (UDP :5005)
|
||||
│
|
||||
▼
|
||||
Rust Axum server (HTTP :3000, WebSocket :3001)
|
||||
│
|
||||
├─► RuVector signal preprocessing (Stage 1)
|
||||
│ 5 crates, ~2ms per frame
|
||||
│
|
||||
├─► ONNX Runtime inference (Stage 2)
|
||||
│ Quantized model, ~10ms per frame
|
||||
│ OR ruvector-sparse-inference, ~8ms per frame
|
||||
│
|
||||
├─► GNN post-processing (ruvector-gnn)
|
||||
│ Anatomical constraint enforcement, ~1ms
|
||||
│
|
||||
├─► SONA adaptation check (Stage 4)
|
||||
│ <0.05ms per frame (gradient accumulation only)
|
||||
│
|
||||
└─► Output: DensePose results
|
||||
│
|
||||
├──► /api/v1/stream/pose (WebSocket, 17 keypoints)
|
||||
├──► /api/v1/pose/current (REST, full DensePose)
|
||||
└──► /ws/sensing (WebSocket, raw + processed)
|
||||
```
|
||||
|
||||
Total inference budget: **<15ms per frame** at 20 Hz on x86, **<50ms** on ESP32-S3 (with sparse inference).
|
||||
|
||||
### Stage 6: RVF Model Container Format
|
||||
|
||||
The trained model is packaged as a single `.rvf` file that contains everything needed for
|
||||
inference — no external weight files, no ONNX runtime, no Python dependencies.
|
||||
|
||||
#### RVF DensePose Container Layout
|
||||
|
||||
```
|
||||
wifi-densepose-v1.rvf (single file, ~15-30 MB)
|
||||
┌───────────────────────────────────────────────────────────────┐
|
||||
│ SEGMENT 0: Manifest (0x05) │
|
||||
│ ├── Model ID: "wifi-densepose-v1.0" │
|
||||
│ ├── Training dataset: "mmfi-v1+wipose-v1" │
|
||||
│ ├── Training config hash: SHA-256 │
|
||||
│ ├── Target hardware: x86_64, aarch64, wasm32 │
|
||||
│ ├── Segment directory (offsets to all segments) │
|
||||
│ └── Level-1 TLV manifest with metadata tags │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 1: Vec (0x01) — Model Weight Embeddings │
|
||||
│ ├── ModalityTranslator weights [64→128→256→3, Conv1D+ConvT] │
|
||||
│ ├── ResNet18 backbone weights [3→64→128→256, residual blocks] │
|
||||
│ ├── KeypointHead weights [256→17, deconv layers] │
|
||||
│ ├── DensePoseHead weights [256→25+48, deconv layers] │
|
||||
│ ├── GNN body graph weights [3 message-passing rounds] │
|
||||
│ └── Graph transformer attention weights [proof-gated layers] │
|
||||
│ Format: flat f32 vectors, 768-dim per weight tensor │
|
||||
│ Total: ~5M parameters → ~20MB f32, ~10MB f16, ~5MB INT8 │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 2: Index (0x02) — HNSW Embedding Index │
|
||||
│ ├── Layer A: Entry points + coarse routing centroids │
|
||||
│ │ (loaded first, <5ms, enables approximate search) │
|
||||
│ ├── Layer B: Hot region adjacency for frequently │
|
||||
│ │ accessed weight clusters (100ms load) │
|
||||
│ └── Layer C: Full adjacency graph for exact nearest │
|
||||
│ neighbor lookup across all weight partitions │
|
||||
│ Use: Fast weight lookup for sparse inference — │
|
||||
│ only load hot neurons, skip cold neurons via HNSW routing │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 3: Overlay (0x03) — Dynamic Min-Cut Graph │
|
||||
│ ├── Subcarrier partition graph (sensitive vs insensitive) │
|
||||
│ ├── Min-cut witnesses from ruvector-mincut │
|
||||
│ ├── Antenna topology graph (ESP32 mesh spatial layout) │
|
||||
│ └── Body skeleton graph (17 COCO joints, 16 edges) │
|
||||
│ Use: Pre-computed graph structures loaded at init time. │
|
||||
│ Dynamic updates via ruvector-mincut insert/delete_edge │
|
||||
│ as environment changes (furniture moves, new obstacles) │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 4: Quant (0x06) — Quantization Codebooks │
|
||||
│ ├── INT8 codebook for backbone (4x memory reduction) │
|
||||
│ ├── FP16 scale factors for translator + heads │
|
||||
│ ├── Binary quantization tables for SIMD distance compute │
|
||||
│ └── Per-layer calibration statistics (min, max, zero-point) │
|
||||
│ Use: rvf-quant temperature-tiered quantization — │
|
||||
│ hot layers stay f16, warm layers u8, cold layers binary │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 5: Witness (0x0A) — Training Proof Chain │
|
||||
│ ├── Deterministic training proof (seed, loss curve, hash) │
|
||||
│ ├── Dataset provenance (MM-Fi commit hash, download URL) │
|
||||
│ ├── Validation metrics (PCK@0.2, OKS mAP, GPS scores) │
|
||||
│ ├── Ed25519 signature over weight hash │
|
||||
│ └── Attestation: training hardware, duration, config │
|
||||
│ Use: Verifiable proof that model weights match a specific │
|
||||
│ training run. Anyone can re-run training with same seed │
|
||||
│ and verify the weight hash matches the witness. │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 6: Meta (0x07) — Model Metadata │
|
||||
│ ├── COCO keypoint names and skeleton connectivity │
|
||||
│ ├── DensePose body part labels (24 parts + background) │
|
||||
│ ├── UV coordinate range and resolution │
|
||||
│ ├── Input normalization statistics (mean, std per subcarrier)│
|
||||
│ ├── RuVector crate versions used during training │
|
||||
│ └── Environment calibration profiles (named, per-room) │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 7: AggregateWeights (0x36) — SONA LoRA Deltas │
|
||||
│ ├── Per-environment LoRA adaptation matrices (A, B per layer)│
|
||||
│ ├── EWC++ Fisher information diagonal │
|
||||
│ ├── Optimal θ* reference parameters │
|
||||
│ ├── Adaptation round count and convergence metrics │
|
||||
│ └── Named profiles: "lab-a", "living-room", "office-3f" │
|
||||
│ Use: Multiple environment adaptations stored in one file. │
|
||||
│ Server loads the matching profile or creates a new one. │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 8: Profile (0x0B) — RVDNA Domain Profile │
|
||||
│ ├── Domain: "wifi-csi-densepose" │
|
||||
│ ├── Input spec: [B, T*ant, sub] CSI tensor format │
|
||||
│ ├── Output spec: keypoints [B,17,H,W], parts [B,25,H,W], │
|
||||
│ │ UV [B,48,H,W], confidence [B,1] │
|
||||
│ ├── Hardware requirements: min RAM, recommended GPU │
|
||||
│ └── Supported data sources: esp32, wifi-rssi, simulation │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 9: Crypto (0x0C) — Signature and Keys │
|
||||
│ ├── Ed25519 public key for model publisher │
|
||||
│ ├── Signature over all segment content hashes │
|
||||
│ └── Certificate chain (optional, for enterprise deployment) │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 10: Wasm (0x10) — Self-Bootstrapping Runtime │
|
||||
│ ├── Compiled WASM inference engine │
|
||||
│ │ (ruvector-sparse-inference-wasm) │
|
||||
│ ├── WASM microkernel for RVF segment parsing │
|
||||
│ └── Browser-compatible: load .rvf → run inference in-browser │
|
||||
│ Use: The .rvf file is fully self-contained — a WASM host │
|
||||
│ can execute inference without any external dependencies. │
|
||||
├───────────────────────────────────────────────────────────────┤
|
||||
│ SEGMENT 11: Dashboard (0x11) — Embedded Visualization │
|
||||
│ ├── Three.js-based pose visualization (HTML/JS/CSS) │
|
||||
│ ├── Gaussian splat renderer for signal field │
|
||||
│ └── Served at http://localhost:8080/ when model is loaded │
|
||||
│ Use: Open the .rvf file → get a working UI with no install │
|
||||
└───────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### RVF Loading Sequence
|
||||
|
||||
```
|
||||
1. Read tail → find_latest_manifest() → SegmentDirectory
|
||||
2. Load Manifest (seg 0) → validate magic, version, model ID
|
||||
3. Load Profile (seg 8) → verify input/output spec compatibility
|
||||
4. Load Crypto (seg 9) → verify Ed25519 signature chain
|
||||
5. Load Quant (seg 4) → prepare quantization codebooks
|
||||
6. Load Index Layer A (seg 2) → entry points ready (<5ms)
|
||||
↓ (inference available at reduced accuracy)
|
||||
7. Load Vec (seg 1) → hot weight partitions via Layer A routing
|
||||
8. Load Index Layer B (seg 2) → hot adjacency ready (100ms)
|
||||
↓ (inference at full accuracy for common poses)
|
||||
9. Load Overlay (seg 3) → min-cut graphs, body skeleton
|
||||
10. Load AggregateWeights (seg 7) → apply matching SONA profile
|
||||
11. Load Index Layer C (seg 2) → complete graph loaded
|
||||
↓ (full inference with all weight partitions)
|
||||
12. Load Wasm (seg 10) → WASM runtime available (optional)
|
||||
13. Load Dashboard (seg 11) → UI served (optional)
|
||||
```
|
||||
|
||||
**Progressive availability**: Inference begins after step 6 (~5ms) with approximate
|
||||
results. Full accuracy is reached by step 9 (~500ms). This enables instant startup
|
||||
with gradually improving quality — critical for real-time applications.
|
||||
|
||||
#### RVF Build Pipeline
|
||||
|
||||
After training completes, the model is packaged into an `.rvf` file:
|
||||
|
||||
```bash
|
||||
# Build the RVF container from trained checkpoint
|
||||
cargo run -p wifi-densepose-train --bin build-rvf -- \
|
||||
--checkpoint checkpoints/best-pck.pt \
|
||||
--quantize int8,fp16 \
|
||||
--hnsw-build \
|
||||
--sign --key model-signing-key.pem \
|
||||
--include-wasm \
|
||||
--include-dashboard ../../ui \
|
||||
--output wifi-densepose-v1.rvf
|
||||
|
||||
# Verify the built container
|
||||
cargo run -p wifi-densepose-train --bin verify-rvf -- \
|
||||
--input wifi-densepose-v1.rvf \
|
||||
--verify-signature \
|
||||
--verify-witness \
|
||||
--benchmark-inference
|
||||
```
|
||||
|
||||
#### RVF Runtime Integration
|
||||
|
||||
The sensing server loads the `.rvf` container at startup:
|
||||
|
||||
```bash
|
||||
# Load model from RVF container
|
||||
./target/release/sensing-server \
|
||||
--model wifi-densepose-v1.rvf \
|
||||
--source auto \
|
||||
--ui-from-rvf # serve Dashboard segment instead of --ui-path
|
||||
```
|
||||
|
||||
```rust
|
||||
// In sensing-server/src/main.rs
|
||||
use rvf_runtime::RvfContainer;
|
||||
use rvf_index::layers::IndexLayer;
|
||||
use rvf_quant::QuantizedVec;
|
||||
|
||||
let container = RvfContainer::open("wifi-densepose-v1.rvf")?;
|
||||
|
||||
// Progressive load: Layer A first for instant startup
|
||||
let index = container.load_index(IndexLayer::A)?;
|
||||
let weights = container.load_vec_hot(&index)?; // hot partitions only
|
||||
|
||||
// Full load in background
|
||||
tokio::spawn(async move {
|
||||
container.load_index(IndexLayer::B).await?;
|
||||
container.load_index(IndexLayer::C).await?;
|
||||
container.load_vec_cold().await?; // remaining partitions
|
||||
});
|
||||
|
||||
// SONA environment adaptation
|
||||
let sona_deltas = container.load_aggregate_weights("office-3f")?;
|
||||
model.apply_lora_deltas(&sona_deltas);
|
||||
|
||||
// Serve embedded dashboard
|
||||
let dashboard = container.load_dashboard()?;
|
||||
// Mount at /ui/* routes in Axum
|
||||
```
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Phase 1: Dataset Loaders (2 weeks)
|
||||
|
||||
- Implement `MmFiDataset` in `wifi-densepose-train/src/dataset.rs`
|
||||
- Read MM-Fi `.npy` files with antenna correction (1TX/3RX → 3×3 zero-padding)
|
||||
- Subcarrier resampling 114→56 via `ruvector-solver::NeumannSolver`
|
||||
- Phase sanitization via `wifi-densepose-signal::phase_sanitizer`
|
||||
- Implement `WiPoseDataset` for secondary dataset
|
||||
- Temporal windowing with `ruvector-temporal-tensor`
|
||||
- **Deliverable**: `cargo test -p wifi-densepose-train` with dataset loading tests
|
||||
|
||||
### Phase 2: Graph Transformer Integration (2 weeks)
|
||||
|
||||
- Add `ruvector-graph-transformer` dependency to `wifi-densepose-train`
|
||||
- Replace bottleneck self-attention in `ModalityTranslator` with proof-gated graph transformer
|
||||
- Build antenna topology graph (nodes = antenna pairs, edges = spatial/temporal proximity)
|
||||
- Add `ruvector-gnn` dependency for body graph reasoning
|
||||
- Build COCO body skeleton graph (17 nodes, 16 anatomical edges)
|
||||
- Implement GNN message passing in spatial decoder
|
||||
- **Deliverable**: Model forward pass produces correct output shapes with graph layers
|
||||
|
||||
### Phase 3: Teacher-Student Label Generation (1 week)
|
||||
|
||||
- Python script using Detectron2 DensePose to generate UV pseudo-labels from MM-Fi RGB frames
|
||||
- Cache labels as `.npy` for Rust loader consumption
|
||||
- Validate label quality on a random subset (visual inspection)
|
||||
- **Deliverable**: Complete UV label set for MM-Fi training split
|
||||
|
||||
### Phase 4: Training Loop (3 weeks)
|
||||
|
||||
- Implement `WiFiDensePoseTrainer` with full loss function (6 terms)
|
||||
- Add `ruvector-math` optimal transport loss term
|
||||
- Integrate GNN edge consistency loss
|
||||
- Training loop with cosine LR schedule, early stopping, checkpointing
|
||||
- Validation metrics: PCK@0.2, OKS mAP, DensePose GPS
|
||||
- Deterministic proof verification (`proof.rs`) with weight hash
|
||||
- **Deliverable**: Trained model checkpoint achieving PCK@0.2 >70% on MM-Fi validation
|
||||
|
||||
### Phase 5: SONA Online Adaptation (2 weeks)
|
||||
|
||||
- Integrate `ruvector-sona` into inference pipeline
|
||||
- Implement LoRA injection at translator, backbone, and DensePose head layers
|
||||
- Implement EWC++ Fisher information computation and regularization
|
||||
- Self-supervised temporal consistency loss for unsupervised adaptation
|
||||
- Calibration mode: 5-minute camera session for supervised fine-tuning
|
||||
- Drift detection: monitor rolling PCK on temporal consistency proxy
|
||||
- **Deliverable**: Adaptation converges in <50 gradient steps, PCK recovers within 10% of base
|
||||
|
||||
### Phase 6: Sparse Inference and Edge Deployment (2 weeks)
|
||||
|
||||
- Profile neuron activation frequencies on validation set
|
||||
- Apply `ruvector-sparse-inference` hot/cold neuron partitioning
|
||||
- INT8 quantization for backbone, FP16 for heads
|
||||
- ONNX export with quantized weights
|
||||
- Benchmark on x86 (target: <10ms) and ARM (target: <50ms)
|
||||
- WASM export via `ruvector-sparse-inference-wasm` for browser inference
|
||||
- **Deliverable**: Quantized ONNX model, benchmark results, WASM binary
|
||||
|
||||
### Phase 7: RVF Container Build Pipeline (2 weeks)
|
||||
|
||||
- Implement `build-rvf` binary in `wifi-densepose-train`
|
||||
- Serialize trained weights into `Vec` segment (SegmentType::Vec, 0x01)
|
||||
- Build HNSW index over weight partitions for sparse inference (SegmentType::Index, 0x02)
|
||||
- Serialize min-cut graph overlays: subcarrier partition, antenna topology, body skeleton (SegmentType::Overlay, 0x03)
|
||||
- Generate quantization codebooks via `rvf-quant` (SegmentType::Quant, 0x06)
|
||||
- Write training proof witness with Ed25519 signature (SegmentType::Witness, 0x0A)
|
||||
- Store model metadata, COCO keypoint schema, normalization stats (SegmentType::Meta, 0x07)
|
||||
- Store SONA LoRA adaptation deltas per environment (SegmentType::AggregateWeights, 0x36)
|
||||
- Write RVDNA domain profile for WiFi CSI DensePose (SegmentType::Profile, 0x0B)
|
||||
- Optionally embed WASM inference runtime (SegmentType::Wasm, 0x10)
|
||||
- Optionally embed Three.js dashboard (SegmentType::Dashboard, 0x11)
|
||||
- Build Level-1 manifest and segment directory (SegmentType::Manifest, 0x05)
|
||||
- Implement `verify-rvf` binary for container validation
|
||||
- **Deliverable**: `wifi-densepose-v1.rvf` single-file container, verifiable and self-contained
|
||||
|
||||
### Phase 8: Integration with Sensing Server (1 week)
|
||||
|
||||
- Load `.rvf` container in `wifi-densepose-sensing-server` via `rvf-runtime`
|
||||
- Progressive loading: Layer A first for instant startup, full graph in background
|
||||
- Replace `derive_pose_from_sensing()` heuristic with trained model inference
|
||||
- Add `--model` CLI flag accepting `.rvf` path (or legacy `.onnx`)
|
||||
- Apply SONA LoRA deltas from `AggregateWeights` segment based on `--env` flag
|
||||
- Serve embedded Dashboard segment at `/ui/*` when `--ui-from-rvf` is set
|
||||
- Graceful fallback to heuristic when no model file present
|
||||
- Update WebSocket protocol to include DensePose UV data
|
||||
- **Deliverable**: Sensing server serves trained model from single `.rvf` file
|
||||
|
||||
## File Changes
|
||||
|
||||
### New Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `rust-port/.../wifi-densepose-train/src/dataset_mmfi.rs` | MM-Fi dataset loader with subcarrier resampling |
|
||||
| `rust-port/.../wifi-densepose-train/src/dataset_wipose.rs` | Wi-Pose dataset loader |
|
||||
| `rust-port/.../wifi-densepose-train/src/graph_transformer.rs` | Graph transformer integration |
|
||||
| `rust-port/.../wifi-densepose-train/src/body_gnn.rs` | GNN body graph reasoning |
|
||||
| `rust-port/.../wifi-densepose-train/src/adaptation.rs` | SONA LoRA + EWC++ adaptation |
|
||||
| `rust-port/.../wifi-densepose-train/src/trainer.rs` | Training loop with multi-term loss |
|
||||
| `scripts/generate_densepose_labels.py` | Teacher-student UV label generation |
|
||||
| `scripts/benchmark_inference.py` | Inference latency benchmarking |
|
||||
| `rust-port/.../wifi-densepose-train/src/rvf_builder.rs` | RVF container build pipeline |
|
||||
| `rust-port/.../wifi-densepose-train/src/bin/build_rvf.rs` | CLI binary for building `.rvf` containers |
|
||||
| `rust-port/.../wifi-densepose-train/src/bin/verify_rvf.rs` | CLI binary for verifying `.rvf` containers |
|
||||
|
||||
### Modified Files
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `rust-port/.../wifi-densepose-train/Cargo.toml` | Add ruvector-gnn, graph-transformer, sona, sparse-inference, math, rvf-types, rvf-wire, rvf-manifest, rvf-index, rvf-quant, rvf-crypto, rvf-runtime deps |
|
||||
| `rust-port/.../wifi-densepose-train/src/model.rs` | Integrate graph transformer + GNN layers |
|
||||
| `rust-port/.../wifi-densepose-train/src/losses.rs` | Add optimal transport + GNN edge consistency loss terms |
|
||||
| `rust-port/.../wifi-densepose-train/src/config.rs` | Add training hyperparameters for new components |
|
||||
| `rust-port/.../sensing-server/Cargo.toml` | Add rvf-runtime, rvf-types, rvf-index, rvf-quant deps |
|
||||
| `rust-port/.../sensing-server/src/main.rs` | Add `--model` flag, load `.rvf` container, progressive startup, serve embedded dashboard |
|
||||
|
||||
## Consequences
|
||||
|
||||
### Positive
|
||||
|
||||
- **Trained model produces accurate DensePose**: Moves from heuristic keypoints to learned body surface estimation backed by public dataset evaluation
|
||||
- **RuVector signal intelligence is a differentiator**: Graph transformers on antenna topology and GNN body reasoning are novel — to our knowledge, no prior WiFi pose system applies these techniques
|
||||
- **SONA enables zero-shot deployment**: New environments don't require full retraining — LoRA adaptation with <50 gradient steps converges in seconds
|
||||
- **Sparse inference enables edge deployment**: PowerInfer-style neuron partitioning brings DensePose inference to ESP32-class hardware
|
||||
- **Graceful degradation**: Server falls back to heuristic pose when no model file is present — existing functionality is preserved
|
||||
- **Single-file deployment via RVF**: Trained model, embeddings, HNSW index, quantization codebooks, SONA adaptation profiles, WASM runtime, and dashboard UI packaged in one `.rvf` file — deploy by copying a single file
|
||||
- **Progressive loading**: RVF Layer A loads in <5ms for instant startup; full accuracy reached in ~500ms as remaining segments load
|
||||
- **Verifiable provenance**: RVF Witness segment contains deterministic training proof with Ed25519 signature — anyone can re-run training and verify weight hash
|
||||
- **Self-bootstrapping**: RVF Wasm segment enables browser-based inference with no server-side dependencies
|
||||
- **Open evaluation**: PCK, OKS, GPS metrics on public MM-Fi dataset provide reproducible, comparable results
|
||||
|
||||
### Negative
|
||||
|
||||
- **Training requires GPU**: Initial model training needs RTX 3090 or better (~8 hours on A100). Not all developers will have access.
|
||||
- **Teacher-student label generation requires Detectron2**: One-time Python + CUDA dependency for generating UV pseudo-labels from RGB frames
|
||||
- **MM-Fi CC BY-NC license**: Weights trained on MM-Fi cannot be used commercially without collecting proprietary data
|
||||
- **Environment-specific adaptation still required**: SONA reduces the burden but a brief calibration session in each new environment is still recommended for best accuracy
|
||||
- **6 additional RuVector crate dependencies**: Increases compile time and binary size. Mitigated by feature flags (e.g., `--features trained-model`).
|
||||
- **Model size on disk**: ~25MB (FP16) or ~12MB (INT8). Acceptable for server deployment, may need further pruning for WASM.
|
||||
|
||||
### Risks and Mitigations
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|------------|
|
||||
| MM-Fi 114→56 interpolation loses accuracy | Train at native 114 as alternative; ESP32 mesh can collect 56-sub data natively |
|
||||
| GNN overfits to training body types | Augment with diverse body proportions; Wi-Pose adds subject diversity |
|
||||
| SONA adaptation diverges in adversarial environments | EWC++ regularization caps parameter drift; rollback to base weights on detection |
|
||||
| Sparse inference degrades accuracy | Benchmark INT8 vs FP16 vs FP32; fall back to full precision if quality drops |
|
||||
| Training proof hash changes with RuVector version updates | Pin ruvector crate versions in Cargo.toml; regenerate hash on version bumps |
|
||||
|
||||
## References
|
||||
|
||||
- Geng et al., "DensePose From WiFi" (CMU, arXiv:2301.00250, 2023)
|
||||
- Yang et al., "MM-Fi: Multi-Modal Non-Intrusive 4D Human Dataset" (NeurIPS 2023, arXiv:2305.10345)
|
||||
- Hu et al., "LoRA: Low-Rank Adaptation of Large Language Models" (ICLR 2022)
|
||||
- Kirkpatrick et al., "Overcoming Catastrophic Forgetting in Neural Networks" (PNAS, 2017)
|
||||
- Song et al., "PowerInfer: Fast Large Language Model Serving with a Consumer-grade GPU" (2024)
|
||||
- ADR-005: SONA Self-Learning for Pose Estimation
|
||||
- ADR-015: Public Dataset Strategy for Trained Pose Estimation Model
|
||||
- ADR-016: RuVector Integration for Training Pipeline
|
||||
- ADR-020: Migrate AI/Model Inference to Rust with RuVector and ONNX Runtime
|
||||
|
||||
## Appendix A: RuQu Consideration
|
||||
|
||||
**ruQu** ("Classical nervous system for quantum machines") provides real-time coherence
|
||||
assessment via dynamic min-cut. While primarily designed for quantum error correction
|
||||
(syndrome decoding, surface code arbitration), its core primitive — the `CoherenceGate` —
|
||||
is architecturally relevant to WiFi CSI processing:
|
||||
|
||||
- **CoherenceGate** uses `ruvector-mincut` to make real-time gate/pass decisions on
|
||||
signal streams based on structural coherence thresholds. In quantum computing, this
|
||||
gates qubit syndrome streams. For WiFi CSI, the same mechanism could gate CSI
|
||||
subcarrier streams — passing only subcarriers whose coherence (phase stability across
|
||||
antennas) exceeds a dynamic threshold.
|
||||
|
||||
- **Syndrome filtering** (`filters.rs`) implements Kalman-like adaptive filters that
|
||||
could be repurposed for CSI noise filtering — treating each subcarrier's amplitude
|
||||
drift as a "syndrome" stream.
|
||||
|
||||
- **Min-cut gated transformer** integration (optional feature) provides coherence-optimized
|
||||
attention with 50% FLOP reduction — directly applicable to the `ModalityTranslator`
|
||||
bottleneck.
|
||||
|
||||
**Decision**: ruQu is not included in the initial pipeline (Phase 1-8) but is marked as a
|
||||
**Phase 9 exploration** candidate for coherence-gated CSI filtering. The CoherenceGate
|
||||
primitive maps naturally to subcarrier quality assessment, and the integration path is
|
||||
clean since ruQu already depends on `ruvector-mincut`.
|
||||
|
||||
## Appendix B: Training Data Strategy
|
||||
|
||||
The pipeline supports three data sources for training, used in combination:
|
||||
|
||||
| Source | Subcarriers | Pose Labels | Volume | Cost | When |
|
||||
|--------|-------------|-------------|--------|------|------|
|
||||
| **MM-Fi** (public) | 114 → 56 (interpolated) | 17 COCO + DensePose UV | 40 subjects, 320K frames | Free (CC BY-NC) | Phase 1 — bootstrap |
|
||||
| **Wi-Pose** (public) | 30 → 56 (zero-padded) | 18 keypoints | 12 subjects, 166K packets | Free (research) | Phase 1 — diversity |
|
||||
| **ESP32 self-collected** | 56 (native) | Teacher-student from camera | Unlimited, environment-specific | Hardware only ($54) | Phase 4+ — fine-tuning |
|
||||
|
||||
**Recommended approach: Both public + ESP32 data.**
|
||||
|
||||
1. **Pre-train on MM-Fi + Wi-Pose** (public data, Phase 1-4): Provides the base model
|
||||
with diverse subjects and actions. The 114→56 subcarrier interpolation is acceptable
|
||||
for learning general CSI-to-pose mappings.
|
||||
|
||||
2. **Fine-tune on ESP32 self-collected data** (Phase 5+, SONA adaptation): Collect
|
||||
5-30 minutes of paired ESP32 CSI + camera data in each target environment. The camera
|
||||
serves as the teacher model (Detectron2 generates pseudo-labels). SONA LoRA adaptation
|
||||
takes <50 gradient steps to converge.
|
||||
|
||||
3. **Continuous adaptation** (runtime): SONA's self-supervised temporal consistency loss
|
||||
refines the model without any camera, using the assumption that poses change smoothly
|
||||
over short time windows.
|
||||
|
||||
This three-tier strategy gives you:
|
||||
- A working model from day one (public data)
|
||||
- Environment-specific accuracy (ESP32 fine-tuning)
|
||||
- Ongoing drift correction (SONA runtime adaptation)
|
||||
20
rust-port/wifi-densepose-rs/Cargo.lock
generated
20
rust-port/wifi-densepose-rs/Cargo.lock
generated
@@ -4110,10 +4110,12 @@ dependencies = [
|
||||
"futures-util",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"wifi-densepose-wifiscan",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4175,6 +4177,15 @@ dependencies = [
|
||||
"wifi-densepose-signal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-vitals"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-wasm"
|
||||
version = "0.1.0"
|
||||
@@ -4197,6 +4208,15 @@ dependencies = [
|
||||
"wifi-densepose-mat",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-wifiscan"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
||||
@@ -13,6 +13,8 @@ members = [
|
||||
"crates/wifi-densepose-mat",
|
||||
"crates/wifi-densepose-train",
|
||||
"crates/wifi-densepose-sensing-server",
|
||||
"crates/wifi-densepose-wifiscan",
|
||||
"crates/wifi-densepose-vitals",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -107,6 +109,7 @@ ruvector-temporal-tensor = "2.0.4"
|
||||
ruvector-solver = "2.0.4"
|
||||
ruvector-attention = "2.0.4"
|
||||
|
||||
|
||||
# Internal crates
|
||||
wifi-densepose-core = { path = "crates/wifi-densepose-core" }
|
||||
wifi-densepose-signal = { path = "crates/wifi-densepose-signal" }
|
||||
|
||||
@@ -5,6 +5,10 @@ edition.workspace = true
|
||||
description = "Lightweight Axum server for WiFi sensing UI with RuVector signal processing"
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "wifi_densepose_sensing_server"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "sensing-server"
|
||||
path = "src/main.rs"
|
||||
@@ -29,3 +33,9 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
# CLI
|
||||
clap = { workspace = true }
|
||||
|
||||
# Multi-BSSID WiFi scanning pipeline (ADR-022 Phase 3)
|
||||
wifi-densepose-wifiscan = { path = "../wifi-densepose-wifiscan" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.10"
|
||||
|
||||
@@ -0,0 +1,850 @@
|
||||
//! Dataset loaders for WiFi-to-DensePose training pipeline (ADR-023 Phase 1).
|
||||
//!
|
||||
//! Provides unified data loading for MM-Fi (NeurIPS 2023) and Wi-Pose datasets,
|
||||
//! with from-scratch .npy/.mat v5 parsers, subcarrier resampling, and a unified
|
||||
//! `DataPipeline` for normalized, windowed training samples.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
// ── Error type ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Error type for dataset loading and parsing.
///
/// Covers I/O failures, malformed file contents, missing files or fields,
/// and dimension mismatches.
#[derive(Debug)]
pub enum DatasetError {
    Io(io::Error),
    Format(String),
    Missing(String),
    Shape(String),
}

impl fmt::Display for DatasetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as "<label>: <detail>".
        let (label, detail): (&str, &dyn fmt::Display) = match self {
            Self::Io(e) => ("I/O error", e),
            Self::Format(msg) => ("format error", msg),
            Self::Missing(what) => ("missing", what),
            Self::Shape(msg) => ("shape error", msg),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for DatasetError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Only the I/O variant wraps an underlying error.
        match self {
            Self::Io(inner) => Some(inner),
            _ => None,
        }
    }
}

impl From<io::Error> for DatasetError {
    fn from(err: io::Error) -> Self {
        DatasetError::Io(err)
    }
}

/// Module-local result alias defaulting the error to [`DatasetError`].
pub type Result<T> = std::result::Result<T, DatasetError>;
|
||||
|
||||
// ── NpyArray ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Dense array from .npy: flat f32 data with shape metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NpyArray {
    // Dimension sizes in row-major (C) order, e.g. [frames, subcarriers].
    pub shape: Vec<usize>,
    // Flattened element data; length should equal shape.iter().product().
    pub data: Vec<f32>,
}

impl NpyArray {
    /// Total number of stored elements (length of the flat buffer).
    pub fn len(&self) -> usize { self.data.len() }
    /// True when the array holds no elements.
    pub fn is_empty(&self) -> bool { self.data.is_empty() }
    /// Number of dimensions (rank), i.e. `shape.len()`.
    pub fn ndim(&self) -> usize { self.shape.len() }
}
|
||||
|
||||
// ── NpyReader ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Minimal NumPy .npy format reader (f32/f64, v1/v2).
|
||||
pub struct NpyReader;
|
||||
|
||||
impl NpyReader {
|
||||
pub fn read_file(path: &Path) -> Result<NpyArray> {
|
||||
Self::parse(&std::fs::read(path)?)
|
||||
}
|
||||
|
||||
pub fn parse(buf: &[u8]) -> Result<NpyArray> {
|
||||
if buf.len() < 10 { return Err(DatasetError::Format("file too small for .npy".into())); }
|
||||
if &buf[0..6] != b"\x93NUMPY" {
|
||||
return Err(DatasetError::Format("missing .npy magic".into()));
|
||||
}
|
||||
let major = buf[6];
|
||||
let (header_len, header_start) = match major {
|
||||
1 => (u16::from_le_bytes([buf[8], buf[9]]) as usize, 10usize),
|
||||
2 | 3 => {
|
||||
if buf.len() < 12 { return Err(DatasetError::Format("truncated v2 header".into())); }
|
||||
(u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]) as usize, 12)
|
||||
}
|
||||
_ => return Err(DatasetError::Format(format!("unsupported .npy version {major}"))),
|
||||
};
|
||||
let header_end = header_start + header_len;
|
||||
if header_end > buf.len() { return Err(DatasetError::Format("header past EOF".into())); }
|
||||
let hdr = std::str::from_utf8(&buf[header_start..header_end])
|
||||
.map_err(|_| DatasetError::Format("non-UTF8 header".into()))?;
|
||||
|
||||
let dtype = Self::extract_field(hdr, "descr")?;
|
||||
let is_f64 = dtype.contains("f8") || dtype.contains("float64");
|
||||
let is_f32 = dtype.contains("f4") || dtype.contains("float32");
|
||||
let is_big = dtype.starts_with('>');
|
||||
if !is_f32 && !is_f64 {
|
||||
return Err(DatasetError::Format(format!("unsupported dtype '{dtype}'")));
|
||||
}
|
||||
let fortran = Self::extract_field(hdr, "fortran_order")
|
||||
.unwrap_or_else(|_| "False".into()).contains("True");
|
||||
let shape = Self::parse_shape(hdr)?;
|
||||
let elem_sz: usize = if is_f64 { 8 } else { 4 };
|
||||
let total: usize = shape.iter().product::<usize>().max(1);
|
||||
if header_end + total * elem_sz > buf.len() {
|
||||
return Err(DatasetError::Format("data truncated".into()));
|
||||
}
|
||||
let raw = &buf[header_end..header_end + total * elem_sz];
|
||||
let mut data: Vec<f32> = if is_f64 {
|
||||
raw.chunks_exact(8).map(|c| {
|
||||
let v = if is_big { f64::from_be_bytes(c.try_into().unwrap()) }
|
||||
else { f64::from_le_bytes(c.try_into().unwrap()) };
|
||||
v as f32
|
||||
}).collect()
|
||||
} else {
|
||||
raw.chunks_exact(4).map(|c| {
|
||||
if is_big { f32::from_be_bytes(c.try_into().unwrap()) }
|
||||
else { f32::from_le_bytes(c.try_into().unwrap()) }
|
||||
}).collect()
|
||||
};
|
||||
if fortran && shape.len() == 2 {
|
||||
let (r, c) = (shape[0], shape[1]);
|
||||
let mut cd = vec![0.0f32; data.len()];
|
||||
for ri in 0..r { for ci in 0..c { cd[ri*c+ci] = data[ci*r+ri]; } }
|
||||
data = cd;
|
||||
}
|
||||
let shape = if shape.is_empty() { vec![1] } else { shape };
|
||||
Ok(NpyArray { shape, data })
|
||||
}
|
||||
|
||||
fn extract_field(hdr: &str, field: &str) -> Result<String> {
|
||||
for pat in &[format!("'{field}': "), format!("'{field}':"), format!("\"{field}\": ")] {
|
||||
if let Some(s) = hdr.find(pat.as_str()) {
|
||||
let rest = &hdr[s + pat.len()..];
|
||||
let end = rest.find(',').or_else(|| rest.find('}')).unwrap_or(rest.len());
|
||||
return Ok(rest[..end].trim().trim_matches('\'').trim_matches('"').into());
|
||||
}
|
||||
}
|
||||
Err(DatasetError::Format(format!("field '{field}' not found")))
|
||||
}
|
||||
|
||||
fn parse_shape(hdr: &str) -> Result<Vec<usize>> {
|
||||
let si = hdr.find("'shape'").or_else(|| hdr.find("\"shape\""))
|
||||
.ok_or_else(|| DatasetError::Format("no 'shape'".into()))?;
|
||||
let rest = &hdr[si..];
|
||||
let ps = rest.find('(').ok_or_else(|| DatasetError::Format("no '('".into()))?;
|
||||
let pe = rest[ps..].find(')').ok_or_else(|| DatasetError::Format("no ')'".into()))?;
|
||||
let inner = rest[ps+1..ps+pe].trim();
|
||||
if inner.is_empty() { return Ok(vec![]); }
|
||||
inner.split(',').map(|s| s.trim()).filter(|s| !s.is_empty())
|
||||
.map(|s| s.parse::<usize>().map_err(|_| DatasetError::Format(format!("bad dim: '{s}'"))))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── MatReader ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Minimal MATLAB .mat v5 reader for numeric arrays.
|
||||
pub struct MatReader;
|
||||
|
||||
const MI_INT8: u32 = 1;
|
||||
#[allow(dead_code)] const MI_UINT8: u32 = 2;
|
||||
#[allow(dead_code)] const MI_INT16: u32 = 3;
|
||||
#[allow(dead_code)] const MI_UINT16: u32 = 4;
|
||||
const MI_INT32: u32 = 5;
|
||||
const MI_UINT32: u32 = 6;
|
||||
const MI_SINGLE: u32 = 7;
|
||||
const MI_DOUBLE: u32 = 9;
|
||||
const MI_MATRIX: u32 = 14;
|
||||
|
||||
impl MatReader {
|
||||
pub fn read_file(path: &Path) -> Result<HashMap<String, NpyArray>> {
|
||||
Self::parse(&std::fs::read(path)?)
|
||||
}
|
||||
|
||||
pub fn parse(buf: &[u8]) -> Result<HashMap<String, NpyArray>> {
|
||||
if buf.len() < 128 { return Err(DatasetError::Format("too small for .mat v5".into())); }
|
||||
let swap = u16::from_le_bytes([buf[126], buf[127]]) == 0x4D49;
|
||||
let mut result = HashMap::new();
|
||||
let mut off = 128;
|
||||
while off + 8 <= buf.len() {
|
||||
let (dt, ds, ts) = Self::read_tag(buf, off, swap)?;
|
||||
let el_start = off + ts;
|
||||
let el_end = el_start + ds;
|
||||
if el_end > buf.len() { break; }
|
||||
if dt == MI_MATRIX {
|
||||
if let Ok((n, a)) = Self::parse_matrix(&buf[el_start..el_end], swap) {
|
||||
result.insert(n, a);
|
||||
}
|
||||
}
|
||||
off = (el_end + 7) & !7;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn read_tag(buf: &[u8], off: usize, swap: bool) -> Result<(u32, usize, usize)> {
|
||||
if off + 4 > buf.len() { return Err(DatasetError::Format("truncated tag".into())); }
|
||||
let raw = Self::u32(buf, off, swap);
|
||||
let upper = (raw >> 16) & 0xFFFF;
|
||||
if upper != 0 && upper <= 4 { return Ok((raw & 0xFFFF, upper as usize, 4)); }
|
||||
if off + 8 > buf.len() { return Err(DatasetError::Format("truncated tag".into())); }
|
||||
Ok((raw, Self::u32(buf, off + 4, swap) as usize, 8))
|
||||
}
|
||||
|
||||
fn parse_matrix(buf: &[u8], swap: bool) -> Result<(String, NpyArray)> {
|
||||
let (mut name, mut shape, mut data) = (String::new(), Vec::new(), Vec::new());
|
||||
let mut off = 0;
|
||||
while off + 4 <= buf.len() {
|
||||
let (st, ss, ts) = Self::read_tag(buf, off, swap)?;
|
||||
let ss_start = off + ts;
|
||||
let ss_end = (ss_start + ss).min(buf.len());
|
||||
match st {
|
||||
MI_UINT32 if shape.is_empty() && ss == 8 => {}
|
||||
MI_INT32 if shape.is_empty() => {
|
||||
for i in 0..ss / 4 { shape.push(Self::i32(buf, ss_start + i*4, swap) as usize); }
|
||||
}
|
||||
MI_INT8 if name.is_empty() && ss_end <= buf.len() => {
|
||||
name = String::from_utf8_lossy(&buf[ss_start..ss_end])
|
||||
.trim_end_matches('\0').to_string();
|
||||
}
|
||||
MI_DOUBLE => {
|
||||
for i in 0..ss / 8 {
|
||||
let p = ss_start + i * 8;
|
||||
if p + 8 <= buf.len() { data.push(Self::f64(buf, p, swap) as f32); }
|
||||
}
|
||||
}
|
||||
MI_SINGLE => {
|
||||
for i in 0..ss / 4 {
|
||||
let p = ss_start + i * 4;
|
||||
if p + 4 <= buf.len() { data.push(Self::f32(buf, p, swap)); }
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
off = (ss_end + 7) & !7;
|
||||
}
|
||||
if name.is_empty() { name = "unnamed".into(); }
|
||||
if shape.is_empty() && !data.is_empty() { shape = vec![data.len()]; }
|
||||
// Transpose column-major to row-major for 2D
|
||||
if shape.len() == 2 {
|
||||
let (r, c) = (shape[0], shape[1]);
|
||||
if r * c == data.len() {
|
||||
let mut cd = vec![0.0f32; data.len()];
|
||||
for ri in 0..r { for ci in 0..c { cd[ri*c+ci] = data[ci*r+ri]; } }
|
||||
data = cd;
|
||||
}
|
||||
}
|
||||
Ok((name, NpyArray { shape, data }))
|
||||
}
|
||||
|
||||
fn u32(b: &[u8], o: usize, s: bool) -> u32 {
|
||||
let v = [b[o], b[o+1], b[o+2], b[o+3]];
|
||||
if s { u32::from_be_bytes(v) } else { u32::from_le_bytes(v) }
|
||||
}
|
||||
fn i32(b: &[u8], o: usize, s: bool) -> i32 {
|
||||
let v = [b[o], b[o+1], b[o+2], b[o+3]];
|
||||
if s { i32::from_be_bytes(v) } else { i32::from_le_bytes(v) }
|
||||
}
|
||||
fn f64(b: &[u8], o: usize, s: bool) -> f64 {
|
||||
let v: [u8; 8] = b[o..o+8].try_into().unwrap();
|
||||
if s { f64::from_be_bytes(v) } else { f64::from_le_bytes(v) }
|
||||
}
|
||||
fn f32(b: &[u8], o: usize, s: bool) -> f32 {
|
||||
let v = [b[o], b[o+1], b[o+2], b[o+3]];
|
||||
if s { f32::from_be_bytes(v) } else { f32::from_le_bytes(v) }
|
||||
}
|
||||
}
|
||||
|
||||
// ── Core data types ──────────────────────────────────────────────────────────
|
||||
|
||||
/// A single CSI (Channel State Information) sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CsiSample {
    // Per-subcarrier amplitude values for one frame.
    pub amplitude: Vec<f32>,
    // Per-subcarrier phase values in radians; filled with zeros when the
    // source dataset ships amplitude only (see the MM-Fi loader).
    pub phase: Vec<f32>,
    // Capture timestamp in milliseconds; synthetic for datasets without
    // native timestamps (the MM-Fi loader assigns frame_index * 50).
    pub timestamp_ms: u64,
}
|
||||
|
||||
/// UV coordinate map for a body part in DensePose representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BodyPartUV {
    // DensePose body-part index.
    pub part_id: u8,
    // U surface coordinates; parallel to `v_coords`.
    pub u_coords: Vec<f32>,
    // V surface coordinates; parallel to `u_coords`.
    pub v_coords: Vec<f32>,
}
|
||||
|
||||
/// Pose label: 17 COCO keypoints + optional DensePose body-part UVs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoseLabel {
    // One (a, b, c) triple per keypoint in COCO order — presumably
    // (x, y, score-or-visibility); confirm ordering with the label exporter.
    pub keypoints: [(f32, f32, f32); 17],
    // DensePose UV maps; empty when only keypoints are available.
    pub body_parts: Vec<BodyPartUV>,
    // Overall label confidence — presumably the teacher model's score in
    // [0, 1]; confirm against the label-generation script.
    pub confidence: f32,
}

impl Default for PoseLabel {
    /// All-zero keypoints, no body parts, zero confidence — the placeholder
    /// used when a frame has no corresponding label data.
    fn default() -> Self {
        Self { keypoints: [(0.0, 0.0, 0.0); 17], body_parts: Vec::new(), confidence: 0.0 }
    }
}
|
||||
|
||||
// ── SubcarrierResampler ──────────────────────────────────────────────────────
|
||||
|
||||
/// Resamples subcarrier data via linear interpolation or zero-padding.
pub struct SubcarrierResampler;

impl SubcarrierResampler {
    /// Resample: passthrough if equal, zero-pad if upsampling, interpolate if downsampling.
    pub fn resample(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        use std::cmp::Ordering;
        // Degenerate sizes pass the data through untouched.
        if from == 0 || to == 0 {
            return input.to_vec();
        }
        match from.cmp(&to) {
            Ordering::Equal => input.to_vec(),
            Ordering::Less => Self::zero_pad(input, from, to),
            Ordering::Greater => Self::interpolate(input, from, to),
        }
    }

    /// Resample phase data with unwrapping before interpolation, so the
    /// interpolation never crosses a ±π discontinuity; results are
    /// re-wrapped into [-π, π] afterwards.
    pub fn resample_phase(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        if from == to || from == 0 || to == 0 {
            return input.to_vec();
        }
        let continuous = Self::phase_unwrap(input);
        let stretched = if from < to {
            Self::zero_pad(&continuous, from, to)
        } else {
            Self::interpolate(&continuous, from, to)
        };
        let pi = std::f32::consts::PI;
        let wrap = |p: f32| {
            let mut w = p % (2.0 * pi);
            if w > pi { w -= 2.0 * pi; }
            if w < -pi { w += 2.0 * pi; }
            w
        };
        stretched.into_iter().map(wrap).collect()
    }

    /// Center the input values inside a zero-filled buffer of length `to`.
    fn zero_pad(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        let left = (to - from) / 2;
        let mut padded = vec![0.0f32; to];
        for (i, &value) in input.iter().take(from).enumerate() {
            if let Some(slot) = padded.get_mut(left + i) {
                *slot = value;
            }
        }
        padded
    }

    /// Linear resampling over the first `from` (or fewer) input values.
    fn interpolate(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        let len = input.len().min(from);
        // With 0 or 1 usable samples, replicate the first value (or zero).
        if len <= 1 {
            let fill = input.first().copied().unwrap_or(0.0);
            return vec![fill; to];
        }
        let mut out = Vec::with_capacity(to);
        for i in 0..to {
            // Map output index i onto the input axis in f64 for precision.
            let pos = i as f64 * (len - 1) as f64 / (to - 1).max(1) as f64;
            let lower = pos.floor() as usize;
            let upper = (lower + 1).min(len - 1);
            let frac = (pos - lower as f64) as f32;
            out.push(input[lower] * (1.0 - frac) + input[upper] * frac);
        }
        out
    }

    /// Cumulative phase unwrap: removes 2π jumps between neighbouring samples.
    fn phase_unwrap(phase: &[f32]) -> Vec<f32> {
        let pi = std::f32::consts::PI;
        if phase.is_empty() {
            return Vec::new();
        }
        let mut unwrapped = Vec::with_capacity(phase.len());
        unwrapped.push(phase[0]);
        for pair in phase.windows(2) {
            let mut delta = pair[1] - pair[0];
            while delta > pi { delta -= 2.0 * pi; }
            while delta < -pi { delta += 2.0 * pi; }
            let prev = *unwrapped.last().expect("non-empty by construction");
            unwrapped.push(prev + delta);
        }
        unwrapped
    }
}
|
||||
|
||||
// ── MmFiDataset ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// MM-Fi (NeurIPS 2023) dataset loader with 56 subcarriers and 17 COCO keypoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MmFiDataset {
    /// Per-frame CSI samples (amplitude + phase, resampled to `n_subcarriers`).
    pub csi_frames: Vec<CsiSample>,
    /// One pose label per CSI frame; pushed in lockstep with `csi_frames`.
    pub labels: Vec<PoseLabel>,
    /// Nominal capture rate in Hz (`load_from_directory` sets 20.0).
    pub sample_rate_hz: f32,
    /// Subcarrier count of each frame's amplitude/phase vectors.
    pub n_subcarriers: usize,
}
|
||||
|
||||
impl MmFiDataset {
|
||||
pub const SUBCARRIERS: usize = 56;
|
||||
|
||||
/// Load from directory with csi_amplitude.npy/csi.npy and labels.npy/keypoints.npy.
|
||||
pub fn load_from_directory(path: &Path) -> Result<Self> {
|
||||
if !path.is_dir() {
|
||||
return Err(DatasetError::Missing(format!("directory not found: {}", path.display())));
|
||||
}
|
||||
let amp = NpyReader::read_file(&Self::find(path, &["csi_amplitude.npy", "csi.npy"])?)?;
|
||||
let n = amp.shape.first().copied().unwrap_or(0);
|
||||
let raw_sc = if amp.shape.len() >= 2 { amp.shape[1] } else { amp.data.len() / n.max(1) };
|
||||
let phase_arr = Self::find(path, &["csi_phase.npy"]).ok()
|
||||
.and_then(|p| NpyReader::read_file(&p).ok());
|
||||
let lab = NpyReader::read_file(&Self::find(path, &["labels.npy", "keypoints.npy"])?)?;
|
||||
|
||||
let mut csi_frames = Vec::with_capacity(n);
|
||||
let mut labels = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
let s = i * raw_sc;
|
||||
if s + raw_sc > amp.data.len() { break; }
|
||||
let amplitude = SubcarrierResampler::resample(&.data[s..s+raw_sc], raw_sc, Self::SUBCARRIERS);
|
||||
let phase = phase_arr.as_ref().map(|pa| {
|
||||
let ps = i * raw_sc;
|
||||
if ps + raw_sc <= pa.data.len() {
|
||||
SubcarrierResampler::resample_phase(&pa.data[ps..ps+raw_sc], raw_sc, Self::SUBCARRIERS)
|
||||
} else { vec![0.0; Self::SUBCARRIERS] }
|
||||
}).unwrap_or_else(|| vec![0.0; Self::SUBCARRIERS]);
|
||||
|
||||
csi_frames.push(CsiSample { amplitude, phase, timestamp_ms: i as u64 * 50 });
|
||||
|
||||
let ks = i * 17 * 3;
|
||||
let label = if ks + 51 <= lab.data.len() {
|
||||
let d = &lab.data[ks..ks + 51];
|
||||
let mut kp = [(0.0f32, 0.0, 0.0); 17];
|
||||
for k in 0..17 { kp[k] = (d[k*3], d[k*3+1], d[k*3+2]); }
|
||||
PoseLabel { keypoints: kp, body_parts: Vec::new(), confidence: 1.0 }
|
||||
} else { PoseLabel::default() };
|
||||
labels.push(label);
|
||||
}
|
||||
Ok(Self { csi_frames, labels, sample_rate_hz: 20.0, n_subcarriers: Self::SUBCARRIERS })
|
||||
}
|
||||
|
||||
pub fn resample_subcarriers(&mut self, from: usize, to: usize) {
|
||||
for f in &mut self.csi_frames {
|
||||
f.amplitude = SubcarrierResampler::resample(&f.amplitude, from, to);
|
||||
f.phase = SubcarrierResampler::resample_phase(&f.phase, from, to);
|
||||
}
|
||||
self.n_subcarriers = to;
|
||||
}
|
||||
|
||||
pub fn iter_windows(&self, ws: usize, stride: usize) -> impl Iterator<Item = (&[CsiSample], &[PoseLabel])> {
|
||||
let stride = stride.max(1);
|
||||
let n = self.csi_frames.len();
|
||||
(0..n).step_by(stride).filter(move |&s| s + ws <= n)
|
||||
.map(move |s| (&self.csi_frames[s..s+ws], &self.labels[s..s+ws]))
|
||||
}
|
||||
|
||||
pub fn split_train_val(self, ratio: f32) -> (Self, Self) {
|
||||
let split = (self.csi_frames.len() as f32 * ratio.clamp(0.0, 1.0)) as usize;
|
||||
let (tc, vc) = self.csi_frames.split_at(split);
|
||||
let (tl, vl) = self.labels.split_at(split);
|
||||
let mk = |c: &[CsiSample], l: &[PoseLabel]| Self {
|
||||
csi_frames: c.to_vec(), labels: l.to_vec(),
|
||||
sample_rate_hz: self.sample_rate_hz, n_subcarriers: self.n_subcarriers,
|
||||
};
|
||||
(mk(tc, tl), mk(vc, vl))
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize { self.csi_frames.len() }
|
||||
pub fn is_empty(&self) -> bool { self.csi_frames.is_empty() }
|
||||
pub fn get(&self, idx: usize) -> Option<(&CsiSample, &PoseLabel)> {
|
||||
self.csi_frames.get(idx).zip(self.labels.get(idx))
|
||||
}
|
||||
|
||||
fn find(dir: &Path, names: &[&str]) -> Result<PathBuf> {
|
||||
for n in names { let p = dir.join(n); if p.exists() { return Ok(p); } }
|
||||
Err(DatasetError::Missing(format!("none of {names:?} in {}", dir.display())))
|
||||
}
|
||||
}
|
||||
|
||||
// ── WiPoseDataset ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Wi-Pose dataset loader: .mat v5, 30 subcarriers (-> 56), 18 keypoints (-> 17 COCO).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WiPoseDataset {
    /// Per-frame CSI samples (amplitude resampled to 56; phase is all zeros).
    pub csi_frames: Vec<CsiSample>,
    /// One COCO-17 pose label per frame.
    pub labels: Vec<PoseLabel>,
    /// Nominal capture rate in Hz (`load_from_mat` sets 10.0).
    pub sample_rate_hz: f32,
    /// Subcarrier count of each stored frame.
    pub n_subcarriers: usize,
}
|
||||
|
||||
impl WiPoseDataset {
    /// Subcarrier count per frame in the raw .mat data.
    pub const RAW_SUBCARRIERS: usize = 30;
    /// Common subcarrier width frames are resampled up to.
    pub const TARGET_SUBCARRIERS: usize = 56;
    /// Keypoints per frame as stored in the dataset.
    pub const RAW_KEYPOINTS: usize = 18;
    /// Keypoints after mapping to the COCO-17 layout.
    pub const COCO_KEYPOINTS: usize = 17;

    /// Load CSI and pose labels from a MATLAB .mat file.
    ///
    /// The CSI variable is looked up under "csi"/"csi_data"/"CSI" and labels
    /// (optional) under "keypoints"/"labels"/"pose". Each frame's amplitude is
    /// resampled to `TARGET_SUBCARRIERS`; frames without a matching label get
    /// `PoseLabel::default()`.
    ///
    /// # Errors
    /// Propagates .mat read failures; returns `DatasetError::Missing` when no
    /// CSI variable is present.
    pub fn load_from_mat(path: &Path) -> Result<Self> {
        let arrays = MatReader::read_file(path)?;
        let csi = arrays.get("csi").or_else(|| arrays.get("csi_data")).or_else(|| arrays.get("CSI"))
            .ok_or_else(|| DatasetError::Missing("no CSI variable in .mat".into()))?;
        let n = csi.shape.first().copied().unwrap_or(0);
        // Subcarriers per frame: second dim when available, else the documented 30.
        let raw = if csi.shape.len() >= 2 { csi.shape[1] } else { Self::RAW_SUBCARRIERS };
        let lab = arrays.get("keypoints").or_else(|| arrays.get("labels")).or_else(|| arrays.get("pose"));

        let mut csi_frames = Vec::with_capacity(n);
        let mut labels = Vec::with_capacity(n);
        for i in 0..n {
            let s = i * raw;
            if s + raw > csi.data.len() { break; }
            let amp = SubcarrierResampler::resample(&csi.data[s..s+raw], raw, Self::TARGET_SUBCARRIERS);
            // No phase in this dataset; 10 Hz -> 100 ms per frame.
            csi_frames.push(CsiSample { amplitude: amp, phase: vec![0.0; Self::TARGET_SUBCARRIERS], timestamp_ms: i as u64 * 100 });
            let label = lab.and_then(|la| {
                // 18 keypoints x 3 floats per frame.
                let ks = i * Self::RAW_KEYPOINTS * 3;
                if ks + Self::RAW_KEYPOINTS * 3 <= la.data.len() {
                    Some(Self::map_18_to_17(&la.data[ks..ks + Self::RAW_KEYPOINTS * 3]))
                } else { None }
            }).unwrap_or_default();
            labels.push(label);
        }
        Ok(Self { csi_frames, labels, sample_rate_hz: 10.0, n_subcarriers: Self::TARGET_SUBCARRIERS })
    }

    /// Map 18 keypoints to 17 COCO: keep index 0 (nose), drop index 1, map 2..18 -> 1..16.
    fn map_18_to_17(data: &[f32]) -> PoseLabel {
        let mut kp = [(0.0f32, 0.0, 0.0); 17];
        // Short input leaves every keypoint at (0, 0, 0).
        if data.len() >= 18 * 3 {
            kp[0] = (data[0], data[1], data[2]);
            // Source keypoint i+1 lands at COCO slot i (source slot 1 is dropped).
            for i in 1..17 { let s = (i + 1) * 3; kp[i] = (data[s], data[s+1], data[s+2]); }
        }
        PoseLabel { keypoints: kp, body_parts: Vec::new(), confidence: 1.0 }
    }

    /// Number of loaded frames.
    pub fn len(&self) -> usize { self.csi_frames.len() }
    /// True when no frames are loaded.
    pub fn is_empty(&self) -> bool { self.csi_frames.is_empty() }
}
|
||||
|
||||
// ── DataPipeline ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Where training data is loaded from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataSource {
    /// MM-Fi directory of .npy files.
    MmFi(PathBuf),
    /// Wi-Pose .mat file.
    WiPose(PathBuf),
    /// Several sources loaded in order and concatenated.
    Combined(Vec<DataSource>),
}
|
||||
|
||||
/// Configuration for [`DataPipeline`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataConfig {
    /// Dataset(s) to load.
    pub source: DataSource,
    /// Frames per training window.
    pub window_size: usize,
    /// Step between consecutive window starts (clamped to >= 1 at use sites).
    pub stride: usize,
    /// Target subcarrier width for loaded frames.
    pub target_subcarriers: usize,
    /// Whether to z-score samples per subcarrier after loading.
    pub normalize: bool,
}
|
||||
|
||||
impl Default for DataConfig {
|
||||
fn default() -> Self {
|
||||
Self { source: DataSource::Combined(Vec::new()), window_size: 10, stride: 5,
|
||||
target_subcarriers: 56, normalize: true }
|
||||
}
|
||||
}
|
||||
|
||||
/// One windowed training example produced by the pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
    /// Amplitude values, one inner vec per frame in the window.
    pub csi_window: Vec<Vec<f32>>,
    /// Label of the window's centre frame.
    pub pose_label: PoseLabel,
    /// Originating dataset tag ("mmfi" / "wipose" at current call sites).
    pub source: &'static str,
}
|
||||
|
||||
/// Unified pipeline: loads, resamples, windows, and normalizes training data.
///
/// Construct with [`DataPipeline::new`], then call `load`.
pub struct DataPipeline { config: DataConfig }
|
||||
|
||||
impl DataPipeline {
    /// Wrap a configuration; no I/O happens until `load` is called.
    pub fn new(config: DataConfig) -> Self { Self { config } }

    /// Load all configured sources, window them, and (optionally) z-normalize.
    ///
    /// # Errors
    /// Propagates loading failures from the underlying dataset readers.
    pub fn load(&self) -> Result<Vec<TrainingSample>> {
        let mut out = Vec::new();
        self.load_source(&self.config.source, &mut out)?;
        if self.config.normalize && !out.is_empty() { Self::normalize_samples(&mut out); }
        Ok(out)
    }

    /// Load one source into `out`; `Combined` recurses into its children.
    fn load_source(&self, src: &DataSource, out: &mut Vec<TrainingSample>) -> Result<()> {
        match src {
            DataSource::MmFi(p) => {
                let mut ds = MmFiDataset::load_from_directory(p)?;
                // Resample MM-Fi frames to the configured width when needed.
                if ds.n_subcarriers != self.config.target_subcarriers {
                    let f = ds.n_subcarriers;
                    ds.resample_subcarriers(f, self.config.target_subcarriers);
                }
                self.extract_windows(&ds.csi_frames, &ds.labels, "mmfi", out);
            }
            DataSource::WiPose(p) => {
                // NOTE(review): unlike the MmFi branch, WiPose frames are NOT
                // resampled to target_subcarriers here; the loader always emits
                // 56 — confirm this is acceptable when target != 56.
                let ds = WiPoseDataset::load_from_mat(p)?;
                self.extract_windows(&ds.csi_frames, &ds.labels, "wipose", out);
            }
            DataSource::Combined(srcs) => { for s in srcs { self.load_source(s, out)?; } }
        }
        Ok(())
    }

    /// Slide a window over `frames`, labelling each window with the label of
    /// its middle frame.
    fn extract_windows(&self, frames: &[CsiSample], labels: &[PoseLabel],
                       source: &'static str, out: &mut Vec<TrainingSample>) {
        let (ws, stride) = (self.config.window_size, self.config.stride.max(1));
        let mut s = 0;
        while s + ws <= frames.len() {
            let window: Vec<Vec<f32>> = frames[s..s+ws].iter().map(|f| f.amplitude.clone()).collect();
            // Centre-frame label as the window's target.
            let label = labels.get(s + ws / 2).cloned().unwrap_or_default();
            out.push(TrainingSample { csi_window: window, pose_label: label, source });
            s += stride;
        }
    }

    /// Z-score each subcarrier across all frames of all samples, in place.
    fn normalize_samples(samples: &mut [TrainingSample]) {
        // Feature width taken from the first frame; frames beyond that width
        // are only normalised up to `ns` entries.
        let ns = samples.first().and_then(|s| s.csi_window.first()).map(|f| f.len()).unwrap_or(0);
        if ns == 0 { return; }
        // Accumulate sum and sum-of-squares per subcarrier in f64 for stability.
        let (mut sum, mut sq) = (vec![0.0f64; ns], vec![0.0f64; ns]);
        let mut cnt = 0u64;
        for s in samples.iter() {
            for f in &s.csi_window {
                for (j, &v) in f.iter().enumerate().take(ns) {
                    let v = v as f64; sum[j] += v; sq[j] += v * v;
                }
                cnt += 1;
            }
        }
        if cnt == 0 { return; }
        let mean: Vec<f64> = sum.iter().map(|s| s / cnt as f64).collect();
        // std = sqrt(E[x^2] - E[x]^2), floored at 1e-8 to avoid division by zero.
        let std: Vec<f64> = sq.iter().zip(mean.iter())
            .map(|(&s, &m)| (s / cnt as f64 - m * m).max(0.0).sqrt().max(1e-8)).collect();
        for s in samples.iter_mut() {
            for f in &mut s.csi_window {
                for (j, v) in f.iter_mut().enumerate().take(ns) {
                    *v = ((*v as f64 - mean[j]) / std[j]) as f32;
                }
            }
        }
    }
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialise `data` as a minimal little-endian f32 .npy buffer (v1.0 header).
    fn make_npy_f32(shape: &[usize], data: &[f32]) -> Vec<u8> {
        // NumPy spells 1-D shapes "(n,)" and n-D shapes "(a, b, ...)".
        let ss = if shape.len() == 1 { format!("({},)", shape[0]) }
        else { format!("({})", shape.iter().map(|d| d.to_string()).collect::<Vec<_>>().join(", ")) };
        let hdr = format!("{{'descr': '<f4', 'fortran_order': False, 'shape': {ss}, }}");
        let total = 10 + hdr.len();
        // Header (magic + length + dict) is space-padded to a 64-byte boundary.
        let padded = ((total + 63) / 64) * 64;
        let hl = padded - 10;
        let mut buf = Vec::new();
        buf.extend_from_slice(b"\x93NUMPY\x01\x00");
        buf.extend_from_slice(&(hl as u16).to_le_bytes());
        buf.extend_from_slice(hdr.as_bytes());
        buf.resize(10 + hl, b' ');
        for &v in data { buf.extend_from_slice(&v.to_le_bytes()); }
        buf
    }

    /// Same as `make_npy_f32` but with an '<f8' (little-endian f64) payload.
    fn make_npy_f64(shape: &[usize], data: &[f64]) -> Vec<u8> {
        let ss = if shape.len() == 1 { format!("({},)", shape[0]) }
        else { format!("({})", shape.iter().map(|d| d.to_string()).collect::<Vec<_>>().join(", ")) };
        let hdr = format!("{{'descr': '<f8', 'fortran_order': False, 'shape': {ss}, }}");
        let total = 10 + hdr.len();
        let padded = ((total + 63) / 64) * 64;
        let hl = padded - 10;
        let mut buf = Vec::new();
        buf.extend_from_slice(b"\x93NUMPY\x01\x00");
        buf.extend_from_slice(&(hl as u16).to_le_bytes());
        buf.extend_from_slice(hdr.as_bytes());
        buf.resize(10 + hl, b' ');
        for &v in data { buf.extend_from_slice(&v.to_le_bytes()); }
        buf
    }

    // 1-D header parses with correct shape, ndim, length, and values.
    #[test]
    fn npy_header_parse_1d() {
        let buf = make_npy_f32(&[5], &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let arr = NpyReader::parse(&buf).unwrap();
        assert_eq!(arr.shape, vec![5]);
        assert_eq!(arr.ndim(), 1);
        assert_eq!(arr.len(), 5);
        assert!((arr.data[0] - 1.0).abs() < f32::EPSILON);
        assert!((arr.data[4] - 5.0).abs() < f32::EPSILON);
    }

    // 2-D header parses and reports the flattened element count.
    #[test]
    fn npy_header_parse_2d() {
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let buf = make_npy_f32(&[3, 4], &data);
        let arr = NpyReader::parse(&buf).unwrap();
        assert_eq!(arr.shape, vec![3, 4]);
        assert_eq!(arr.ndim(), 2);
        assert_eq!(arr.len(), 12);
    }

    // 3-D f64 payload parses; values are converted and preserved.
    #[test]
    fn npy_header_parse_3d() {
        let data: Vec<f64> = (0..24).map(|i| i as f64 * 0.5).collect();
        let buf = make_npy_f64(&[2, 3, 4], &data);
        let arr = NpyReader::parse(&buf).unwrap();
        assert_eq!(arr.shape, vec![2, 3, 4]);
        assert_eq!(arr.ndim(), 3);
        assert_eq!(arr.len(), 24);
        assert!((arr.data[23] - 11.5).abs() < 1e-5);
    }

    // Equal from/to counts return the input unchanged.
    #[test]
    fn subcarrier_resample_passthrough() {
        let input: Vec<f32> = (0..56).map(|i| i as f32).collect();
        let output = SubcarrierResampler::resample(&input, 56, 56);
        assert_eq!(output, input);
    }

    // Upsampling 30 -> 56 centres the data with 13 zeros on each side.
    #[test]
    fn subcarrier_resample_upsample() {
        let input: Vec<f32> = (0..30).map(|i| (i + 1) as f32).collect();
        let out = SubcarrierResampler::resample(&input, 30, 56);
        assert_eq!(out.len(), 56);
        // pad_left = 13, leading zeros
        for i in 0..13 { assert!(out[i].abs() < f32::EPSILON, "expected zero at {i}"); }
        // original data in middle
        for i in 0..30 { assert!((out[13+i] - input[i]).abs() < f32::EPSILON); }
        // trailing zeros
        for i in 43..56 { assert!(out[i].abs() < f32::EPSILON, "expected zero at {i}"); }
    }

    // Downsampling a ramp keeps endpoints and monotonicity.
    #[test]
    fn subcarrier_resample_downsample() {
        let input: Vec<f32> = (0..114).map(|i| i as f32).collect();
        let out = SubcarrierResampler::resample(&input, 114, 56);
        assert_eq!(out.len(), 56);
        assert!((out[0]).abs() < f32::EPSILON);
        assert!((out[55] - 113.0).abs() < 0.1);
        for i in 1..56 { assert!(out[i] >= out[i-1], "not monotonic at {i}"); }
    }

    // Interpolating a constant signal keeps the constant everywhere.
    #[test]
    fn subcarrier_resample_preserves_dc() {
        let out = SubcarrierResampler::resample(&vec![42.0f32; 114], 114, 56);
        assert_eq!(out.len(), 56);
        for (i, &v) in out.iter().enumerate() {
            assert!((v - 42.0).abs() < 1e-5, "DC not preserved at {i}: {v}");
        }
    }

    // CsiSample stores amplitude and phase of the same width.
    #[test]
    fn mmfi_sample_structure() {
        let s = CsiSample { amplitude: vec![0.0; 56], phase: vec![0.0; 56], timestamp_ms: 100 };
        assert_eq!(s.amplitude.len(), 56);
        assert_eq!(s.phase.len(), 56);
    }

    // 30 -> 56 zero-padding places data at indices 13..=42.
    #[test]
    fn wipose_zero_pad() {
        let raw: Vec<f32> = (1..=30).map(|i| i as f32).collect();
        let p = SubcarrierResampler::resample(&raw, 30, 56);
        assert_eq!(p.len(), 56);
        assert!(p[0].abs() < f32::EPSILON);
        assert!((p[13] - 1.0).abs() < f32::EPSILON);
        assert!((p[42] - 30.0).abs() < f32::EPSILON);
        assert!(p[55].abs() < f32::EPSILON);
    }

    // 18 -> 17 mapping keeps the nose and drops source keypoint 1.
    #[test]
    fn wipose_keypoint_mapping() {
        let mut kp = vec![0.0f32; 18 * 3];
        kp[0] = 1.0; kp[1] = 2.0; kp[2] = 1.0; // nose
        kp[3] = 99.0; kp[4] = 99.0; kp[5] = 99.0; // extra (dropped)
        kp[6] = 3.0; kp[7] = 4.0; kp[8] = 1.0; // left eye -> COCO 1
        let label = WiPoseDataset::map_18_to_17(&kp);
        assert_eq!(label.keypoints.len(), 17);
        assert!((label.keypoints[0].0 - 1.0).abs() < f32::EPSILON);
        assert!((label.keypoints[1].0 - 3.0).abs() < f32::EPSILON); // not 99
    }

    // An 80/20 split of 100 frames yields exactly 80 + 20.
    #[test]
    fn train_val_split_ratio() {
        let mk = |n: usize| MmFiDataset {
            csi_frames: (0..n).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(),
            labels: (0..n).map(|_| PoseLabel::default()).collect(),
            sample_rate_hz: 20.0, n_subcarriers: 56,
        };
        let (train, val) = mk(100).split_train_val(0.8);
        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
        assert_eq!(train.len() + val.len(), 100);
    }

    // Window counts for non-overlapping and stride-1 sliding windows.
    #[test]
    fn sliding_window_count() {
        let ds = MmFiDataset {
            csi_frames: (0..20).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(),
            labels: (0..20).map(|_| PoseLabel::default()).collect(),
            sample_rate_hz: 20.0, n_subcarriers: 56,
        };
        assert_eq!(ds.iter_windows(5, 5).count(), 4);
        assert_eq!(ds.iter_windows(5, 1).count(), 16);
    }

    // Stride < window size produces overlapping windows with shared frames.
    #[test]
    fn sliding_window_overlap() {
        let ds = MmFiDataset {
            csi_frames: (0..10).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(),
            labels: (0..10).map(|_| PoseLabel::default()).collect(),
            sample_rate_hz: 20.0, n_subcarriers: 56,
        };
        let w: Vec<_> = ds.iter_windows(4, 2).collect();
        assert_eq!(w.len(), 4);
        assert!((w[0].0[0].amplitude[0]).abs() < f32::EPSILON);
        assert!((w[1].0[0].amplitude[0] - 2.0).abs() < f32::EPSILON);
        assert_eq!(w[0].0[2].amplitude[0], w[1].0[0].amplitude[0]); // overlap
    }

    // After z-normalisation each subcarrier has mean ~0 and std ~1.
    #[test]
    fn data_pipeline_normalize() {
        let mut samples = vec![
            TrainingSample { csi_window: vec![vec![10.0, 20.0, 30.0]; 2], pose_label: PoseLabel::default(), source: "test" },
            TrainingSample { csi_window: vec![vec![30.0, 40.0, 50.0]; 2], pose_label: PoseLabel::default(), source: "test" },
        ];
        DataPipeline::normalize_samples(&mut samples);
        for j in 0..3 {
            let (mut s, mut c) = (0.0f64, 0u64);
            for sam in &samples { for f in &sam.csi_window { s += f[j] as f64; c += 1; } }
            assert!(( s / c as f64).abs() < 1e-5, "mean not ~0 for sub {j}");
            let mut vs = 0.0f64;
            let m = s / c as f64;
            for sam in &samples { for f in &sam.csi_window { vs += (f[j] as f64 - m).powi(2); } }
            assert!(((vs / c as f64).sqrt() - 1.0).abs() < 0.1, "std not ~1 for sub {j}");
        }
    }

    // Default PoseLabel is all-zero keypoints with zero confidence.
    #[test]
    fn pose_label_default() {
        let l = PoseLabel::default();
        assert_eq!(l.keypoints.len(), 17);
        assert!(l.body_parts.is_empty());
        assert!(l.confidence.abs() < f32::EPSILON);
        for (i, kp) in l.keypoints.iter().enumerate() {
            assert!(kp.0.abs() < f32::EPSILON && kp.1.abs() < f32::EPSILON, "kp {i} not zero");
        }
    }

    // BodyPartUV survives a JSON serialise/deserialise round trip.
    #[test]
    fn body_part_uv_round_trip() {
        let bpu = BodyPartUV { part_id: 5, u_coords: vec![0.1, 0.2, 0.3], v_coords: vec![0.4, 0.5, 0.6] };
        let json = serde_json::to_string(&bpu).unwrap();
        let r: BodyPartUV = serde_json::from_str(&json).unwrap();
        assert_eq!(r.part_id, 5);
        assert_eq!(r.u_coords.len(), 3);
        assert!((r.u_coords[0] - 0.1).abs() < f32::EPSILON);
        assert!((r.v_coords[2] - 0.6).abs() < f32::EPSILON);
    }

    // Windows extracted from two sources append into one list, keeping
    // source tags and data in arrival order.
    #[test]
    fn combined_source_merges_datasets() {
        let mk = |n: usize, base: f32| -> (Vec<CsiSample>, Vec<PoseLabel>) {
            let f: Vec<CsiSample> = (0..n).map(|i| CsiSample { amplitude: vec![base + i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 * 50 }).collect();
            let l: Vec<PoseLabel> = (0..n).map(|_| PoseLabel::default()).collect();
            (f, l)
        };
        let pipe = DataPipeline::new(DataConfig { source: DataSource::Combined(Vec::new()),
            window_size: 3, stride: 1, target_subcarriers: 56, normalize: false });
        let mut all = Vec::new();
        let (fa, la) = mk(5, 0.0);
        pipe.extract_windows(&fa, &la, "mmfi", &mut all);
        assert_eq!(all.len(), 3);
        let (fb, lb) = mk(4, 100.0);
        pipe.extract_windows(&fb, &lb, "wipose", &mut all);
        assert_eq!(all.len(), 5);
        assert_eq!(all[0].source, "mmfi");
        assert_eq!(all[3].source, "wipose");
        assert!(all[0].csi_window[0][0] < 10.0);
        assert!(all[4].csi_window[0][0] > 90.0);
    }
}
|
||||
@@ -0,0 +1,855 @@
|
||||
//! Graph Transformer + GNN for WiFi CSI-to-Pose estimation (ADR-023 Phase 2).
|
||||
//!
|
||||
//! Cross-attention bottleneck between antenna-space CSI features and COCO 17-keypoint
|
||||
//! body graph, followed by GCN message passing. All math is pure `std`.
|
||||
|
||||
/// Xorshift64 PRNG for deterministic weight initialization.
|
||||
#[derive(Debug, Clone)]
|
||||
struct Rng64 { state: u64 }
|
||||
|
||||
impl Rng64 {
|
||||
fn new(seed: u64) -> Self {
|
||||
Self { state: if seed == 0 { 0xDEAD_BEEF_CAFE_1234 } else { seed } }
|
||||
}
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
let mut x = self.state;
|
||||
x ^= x << 13; x ^= x >> 7; x ^= x << 17;
|
||||
self.state = x; x
|
||||
}
|
||||
/// Uniform f32 in (-1, 1).
|
||||
fn next_f32(&mut self) -> f32 {
|
||||
let f = (self.next_u64() >> 11) as f32 / (1u64 << 53) as f32;
|
||||
f * 2.0 - 1.0
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
/// Rectified linear unit: passes positive values through, clamps the rest to 0.
fn relu(x: f32) -> f32 {
    match x > 0.0 {
        true => x,
        false => 0.0,
    }
}
|
||||
|
||||
#[inline]
/// Logistic sigmoid, branching on sign so exp() never overflows:
/// large negative x uses e^x / (1 + e^x), everything else 1 / (1 + e^-x).
fn sigmoid(x: f32) -> f32 {
    if x < 0.0 {
        let e = x.exp();
        e / (1.0 + e)
    } else {
        1.0 / (1.0 + (-x).exp())
    }
}
|
||||
|
||||
/// Numerically stable softmax. Writes normalised weights into `out`.
/// The running maximum is subtracted before exponentiation so large scores
/// cannot overflow; a near-zero total yields all-zero weights.
fn softmax(scores: &[f32], out: &mut [f32]) {
    debug_assert_eq!(scores.len(), out.len());
    if scores.is_empty() { return; }
    let peak = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0f32;
    for (slot, &raw) in out.iter_mut().zip(scores.iter()) {
        let e = (raw - peak).exp();
        *slot = e;
        total += e;
    }
    let norm = if total > 1e-10 { total.recip() } else { 0.0 };
    for slot in out.iter_mut() {
        *slot *= norm;
    }
}
|
||||
|
||||
// ── Linear layer ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Dense linear transformation y = Wx + b (row-major weights).
#[derive(Debug, Clone)]
pub struct Linear {
    // Input dimensionality (length of x).
    in_features: usize,
    // Output dimensionality (length of y).
    out_features: usize,
    // weights[o] is the row producing output feature o.
    weights: Vec<Vec<f32>>,
    // One bias per output feature.
    bias: Vec<f32>,
}
|
||||
|
||||
impl Linear {
|
||||
/// Xavier/Glorot uniform init with default seed.
|
||||
pub fn new(in_features: usize, out_features: usize) -> Self {
|
||||
Self::with_seed(in_features, out_features, 42)
|
||||
}
|
||||
/// Xavier/Glorot uniform init with explicit seed.
|
||||
pub fn with_seed(in_features: usize, out_features: usize, seed: u64) -> Self {
|
||||
let mut rng = Rng64::new(seed);
|
||||
let limit = (6.0 / (in_features + out_features) as f32).sqrt();
|
||||
let weights = (0..out_features)
|
||||
.map(|_| (0..in_features).map(|_| rng.next_f32() * limit).collect())
|
||||
.collect();
|
||||
Self { in_features, out_features, weights, bias: vec![0.0; out_features] }
|
||||
}
|
||||
/// All-zero weights (for testing).
|
||||
pub fn zeros(in_features: usize, out_features: usize) -> Self {
|
||||
Self {
|
||||
in_features, out_features,
|
||||
weights: vec![vec![0.0; in_features]; out_features],
|
||||
bias: vec![0.0; out_features],
|
||||
}
|
||||
}
|
||||
/// Forward pass: y = Wx + b.
|
||||
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
assert_eq!(input.len(), self.in_features,
|
||||
"Linear input mismatch: expected {}, got {}", self.in_features, input.len());
|
||||
let mut out = vec![0.0f32; self.out_features];
|
||||
for (i, row) in self.weights.iter().enumerate() {
|
||||
let mut s = self.bias[i];
|
||||
for (w, x) in row.iter().zip(input) { s += w * x; }
|
||||
out[i] = s;
|
||||
}
|
||||
out
|
||||
}
|
||||
pub fn weights(&self) -> &[Vec<f32>] { &self.weights }
|
||||
pub fn set_weights(&mut self, w: Vec<Vec<f32>>) {
|
||||
assert_eq!(w.len(), self.out_features);
|
||||
for row in &w { assert_eq!(row.len(), self.in_features); }
|
||||
self.weights = w;
|
||||
}
|
||||
pub fn set_bias(&mut self, b: Vec<f32>) {
|
||||
assert_eq!(b.len(), self.out_features);
|
||||
self.bias = b;
|
||||
}
|
||||
|
||||
/// Push all weights (row-major) then bias into a flat vec.
|
||||
pub fn flatten_into(&self, out: &mut Vec<f32>) {
|
||||
for row in &self.weights {
|
||||
out.extend_from_slice(row);
|
||||
}
|
||||
out.extend_from_slice(&self.bias);
|
||||
}
|
||||
|
||||
/// Restore from a flat slice. Returns (Self, number of f32s consumed).
|
||||
pub fn unflatten_from(data: &[f32], in_f: usize, out_f: usize) -> (Self, usize) {
|
||||
let n = in_f * out_f + out_f;
|
||||
assert!(data.len() >= n, "unflatten_from: need {n} floats, got {}", data.len());
|
||||
let mut weights = Vec::with_capacity(out_f);
|
||||
for r in 0..out_f {
|
||||
let start = r * in_f;
|
||||
weights.push(data[start..start + in_f].to_vec());
|
||||
}
|
||||
let bias = data[in_f * out_f..n].to_vec();
|
||||
(Self { in_features: in_f, out_features: out_f, weights, bias }, n)
|
||||
}
|
||||
|
||||
/// Total number of trainable parameters.
|
||||
pub fn param_count(&self) -> usize {
|
||||
self.in_features * self.out_features + self.out_features
|
||||
}
|
||||
}
|
||||
|
||||
// ── AntennaGraph ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Spatial topology graph over TX-RX antenna pairs. Nodes = pairs, edges connect
/// pairs sharing a TX or RX antenna.
#[derive(Debug, Clone)]
pub struct AntennaGraph {
    n_tx: usize, n_rx: usize, n_pairs: usize,
    adjacency: Vec<Vec<f32>>,
}

impl AntennaGraph {
    /// Build antenna graph. pair_id = tx * n_rx + rx. Adjacent if shared TX or RX.
    /// Every node also carries a self-loop (adjacency[i][i] = 1).
    pub fn new(n_tx: usize, n_rx: usize) -> Self {
        let n_pairs = n_tx * n_rx;
        let mut adjacency = vec![vec![0.0f32; n_pairs]; n_pairs];
        for i in 0..n_pairs {
            for j in 0..n_pairs {
                // Decode pair id -> (tx, rx) and connect on any shared antenna.
                let shares_tx = i / n_rx == j / n_rx;
                let shares_rx = i % n_rx == j % n_rx;
                if i == j || shares_tx || shares_rx {
                    adjacency[i][j] = 1.0;
                }
            }
        }
        Self { n_tx, n_rx, n_pairs, adjacency }
    }

    /// Number of graph nodes (= antenna pairs).
    pub fn n_nodes(&self) -> usize { self.n_pairs }
    /// Dense 0/1 adjacency matrix (with self-loops).
    pub fn adjacency_matrix(&self) -> &Vec<Vec<f32>> { &self.adjacency }
    /// Transmit antenna count.
    pub fn n_tx(&self) -> usize { self.n_tx }
    /// Receive antenna count.
    pub fn n_rx(&self) -> usize { self.n_rx }
}
|
||||
|
||||
// ── BodyGraph ────────────────────────────────────────────────────────────
|
||||
|
||||
/// COCO 17-keypoint skeleton graph with 16 anatomical edges.
///
/// Indices: 0=nose 1=l_eye 2=r_eye 3=l_ear 4=r_ear 5=l_shoulder 6=r_shoulder
/// 7=l_elbow 8=r_elbow 9=l_wrist 10=r_wrist 11=l_hip 12=r_hip 13=l_knee
/// 14=r_knee 15=l_ankle 16=r_ankle
#[derive(Debug, Clone)]
pub struct BodyGraph {
    adjacency: [[f32; 17]; 17],
    edges: Vec<(usize, usize)>,
}

pub const COCO_KEYPOINT_NAMES: [&str; 17] = [
    "nose","left_eye","right_eye","left_ear","right_ear",
    "left_shoulder","right_shoulder","left_elbow","right_elbow",
    "left_wrist","right_wrist","left_hip","right_hip",
    "left_knee","right_knee","left_ankle","right_ankle",
];

// Undirected skeleton bones, as (keypoint, keypoint) index pairs.
const COCO_EDGES: [(usize, usize); 16] = [
    (0,1),(0,2),(1,3),(2,4),(5,6),(5,7),(7,9),(6,8),
    (8,10),(5,11),(6,12),(11,12),(11,13),(13,15),(12,14),(14,16),
];

impl BodyGraph {
    /// Build the skeleton adjacency: self-loops on every joint plus one
    /// symmetric entry per anatomical edge.
    pub fn new() -> Self {
        let mut adjacency = [[0.0f32; 17]; 17];
        for (i, row) in adjacency.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        for &(a, b) in COCO_EDGES.iter() {
            adjacency[a][b] = 1.0;
            adjacency[b][a] = 1.0;
        }
        Self { adjacency, edges: COCO_EDGES.to_vec() }
    }

    /// Dense 0/1 adjacency (with self-loops).
    pub fn adjacency_matrix(&self) -> &[[f32; 17]; 17] { &self.adjacency }
    /// The 16 skeleton edges as index pairs.
    pub fn edge_list(&self) -> &Vec<(usize, usize)> { &self.edges }
    /// Always 17 (COCO keypoints).
    pub fn n_nodes(&self) -> usize { 17 }
    /// Number of skeleton edges.
    pub fn n_edges(&self) -> usize { self.edges.len() }

    /// Degree of each node (including self-loop).
    pub fn degrees(&self) -> [f32; 17] {
        let mut deg = [0.0f32; 17];
        for (d, row) in deg.iter_mut().zip(self.adjacency.iter()) {
            *d = row.iter().sum();
        }
        deg
    }

    /// Symmetric normalised adjacency D^{-1/2} A D^{-1/2}.
    pub fn normalized_adjacency(&self) -> [[f32; 17]; 17] {
        let deg = self.degrees();
        // Zero-degree nodes (impossible here, every node has a self-loop)
        // would map to 0 rather than divide by zero.
        let inv_sqrt: Vec<f32> = deg.iter()
            .map(|&d| if d > 0.0 { 1.0 / d.sqrt() } else { 0.0 })
            .collect();
        let mut norm = [[0.0f32; 17]; 17];
        for i in 0..17 {
            for j in 0..17 {
                norm[i][j] = inv_sqrt[i] * self.adjacency[i][j] * inv_sqrt[j];
            }
        }
        norm
    }
}
|
||||
|
||||
// Default body graph is the standard COCO-17 skeleton built by `new`.
impl Default for BodyGraph { fn default() -> Self { Self::new() } }
|
||||
|
||||
// ── CrossAttention ───────────────────────────────────────────────────────
|
||||
|
||||
/// Multi-head scaled dot-product cross-attention.
/// Attn(Q,K,V) = softmax(QK^T / sqrt(d_k)) V, split into n_heads.
#[derive(Debug, Clone)]
pub struct CrossAttention {
    // Model width (also each projection's width); d_k = d_model / n_heads.
    d_model: usize, n_heads: usize, d_k: usize,
    // Query / key / value / output projections, each d_model x d_model.
    w_q: Linear, w_k: Linear, w_v: Linear, w_o: Linear,
}
|
||||
|
||||
impl CrossAttention {
    /// Build with Xavier-initialised projections.
    ///
    /// # Panics
    /// When `n_heads` does not divide `d_model`.
    pub fn new(d_model: usize, n_heads: usize) -> Self {
        assert!(d_model % n_heads == 0,
            "d_model ({d_model}) must be divisible by n_heads ({n_heads})");
        let d_k = d_model / n_heads;
        // Distinct consecutive seeds keep the four projections different
        // from each other while staying deterministic.
        let s = 123u64;
        Self { d_model, n_heads, d_k,
            w_q: Linear::with_seed(d_model, d_model, s),
            w_k: Linear::with_seed(d_model, d_model, s+1),
            w_v: Linear::with_seed(d_model, d_model, s+2),
            w_o: Linear::with_seed(d_model, d_model, s+3),
        }
    }
    /// query [n_q, d_model], key/value [n_kv, d_model] -> [n_q, d_model].
    ///
    /// NOTE(review): n_kv is taken from `key.len()`; a `value` shorter than
    /// `key` panics on indexing, a longer one is silently truncated —
    /// confirm callers always pass equal-length key/value.
    pub fn forward(&self, query: &[Vec<f32>], key: &[Vec<f32>], value: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let (n_q, n_kv) = (query.len(), key.len());
        // Empty query or key/value: zero output of matching height.
        if n_q == 0 || n_kv == 0 { return vec![vec![0.0; self.d_model]; n_q]; }

        // Project every row once up front.
        let q_proj: Vec<Vec<f32>> = query.iter().map(|q| self.w_q.forward(q)).collect();
        let k_proj: Vec<Vec<f32>> = key.iter().map(|k| self.w_k.forward(k)).collect();
        let v_proj: Vec<Vec<f32>> = value.iter().map(|v| self.w_v.forward(v)).collect();

        // Scaled dot-product: scores are divided by sqrt(d_k).
        let scale = (self.d_k as f32).sqrt();
        let mut output = vec![vec![0.0f32; self.d_model]; n_q];

        for qi in 0..n_q {
            // Per-head outputs are concatenated back up to d_model width.
            let mut concat = Vec::with_capacity(self.d_model);
            for h in 0..self.n_heads {
                // Head h owns the contiguous feature slice [h*d_k, (h+1)*d_k).
                let (start, end) = (h * self.d_k, (h + 1) * self.d_k);
                let q_h = &q_proj[qi][start..end];
                let mut scores = vec![0.0f32; n_kv];
                for ki in 0..n_kv {
                    let dot: f32 = q_h.iter().zip(&k_proj[ki][start..end]).map(|(a,b)| a*b).sum();
                    scores[ki] = dot / scale;
                }
                let mut wts = vec![0.0f32; n_kv];
                softmax(&scores, &mut wts);
                // Attention-weighted sum over the value head slices.
                let mut head_out = vec![0.0f32; self.d_k];
                for ki in 0..n_kv {
                    for (o, &v) in head_out.iter_mut().zip(&v_proj[ki][start..end]) {
                        *o += wts[ki] * v;
                    }
                }
                concat.extend_from_slice(&head_out);
            }
            // Output projection mixes information across heads.
            output[qi] = self.w_o.forward(&concat);
        }
        output
    }
    /// Model width.
    pub fn d_model(&self) -> usize { self.d_model }
    /// Number of attention heads.
    pub fn n_heads(&self) -> usize { self.n_heads }

    /// Push all cross-attention weights (w_q, w_k, w_v, w_o) into flat vec.
    pub fn flatten_into(&self, out: &mut Vec<f32>) {
        self.w_q.flatten_into(out);
        self.w_k.flatten_into(out);
        self.w_v.flatten_into(out);
        self.w_o.flatten_into(out);
    }

    /// Restore cross-attention weights from flat slice. Returns (Self, consumed).
    /// Order must mirror `flatten_into`: w_q, w_k, w_v, w_o.
    pub fn unflatten_from(data: &[f32], d_model: usize, n_heads: usize) -> (Self, usize) {
        let mut offset = 0;
        let (w_q, n) = Linear::unflatten_from(&data[offset..], d_model, d_model);
        offset += n;
        let (w_k, n) = Linear::unflatten_from(&data[offset..], d_model, d_model);
        offset += n;
        let (w_v, n) = Linear::unflatten_from(&data[offset..], d_model, d_model);
        offset += n;
        let (w_o, n) = Linear::unflatten_from(&data[offset..], d_model, d_model);
        offset += n;
        let d_k = d_model / n_heads;
        (Self { d_model, n_heads, d_k, w_q, w_k, w_v, w_o }, offset)
    }

    /// Total trainable params in cross-attention.
    pub fn param_count(&self) -> usize {
        self.w_q.param_count() + self.w_k.param_count()
            + self.w_v.param_count() + self.w_o.param_count()
    }
}
|
||||
|
||||
// ── GraphMessagePassing ──────────────────────────────────────────────────
|
||||
|
||||
/// GCN layer: H' = ReLU(A_norm H W) where A_norm = D^{-1/2} A D^{-1/2}.
#[derive(Debug, Clone)]
pub struct GraphMessagePassing {
    /// Input feature width per node.
    pub(crate) in_features: usize,
    /// Output feature width per node.
    pub(crate) out_features: usize,
    /// Shared linear transform W applied after neighborhood aggregation.
    pub(crate) weight: Linear,
    /// Precomputed normalized adjacency for the fixed 17-node body graph.
    norm_adj: [[f32; 17]; 17],
}

impl GraphMessagePassing {
    /// Build a layer over `graph`'s normalized adjacency.
    ///
    /// NOTE(review): the weight seed is hard-coded to 777, so any two layers
    /// with identical dimensions start out with identical weights — confirm
    /// this symmetry is intended (see `GnnStack::new`).
    pub fn new(in_features: usize, out_features: usize, graph: &BodyGraph) -> Self {
        Self { in_features, out_features,
            weight: Linear::with_seed(in_features, out_features, 777),
            norm_adj: graph.normalized_adjacency() }
    }
    /// node_features [17, in_features] -> [17, out_features].
    ///
    /// # Panics
    /// Panics if `node_features` does not have exactly 17 rows.
    pub fn forward(&self, node_features: &[Vec<f32>]) -> Vec<Vec<f32>> {
        assert_eq!(node_features.len(), 17, "expected 17 nodes, got {}", node_features.len());
        // Aggregate: agg = A_norm * H; near-zero adjacency entries are skipped.
        let mut agg = vec![vec![0.0f32; self.in_features]; 17];
        for i in 0..17 { for j in 0..17 {
            let a = self.norm_adj[i][j];
            if a.abs() > 1e-10 {
                for (ag, &f) in agg[i].iter_mut().zip(&node_features[j]) { *ag += a * f; }
            }
        }}
        // Transform + nonlinearity: ReLU(agg * W).
        agg.iter().map(|a| self.weight.forward(a).into_iter().map(relu).collect()).collect()
    }
    /// Input feature width per node.
    pub fn in_features(&self) -> usize { self.in_features }
    /// Output feature width per node.
    pub fn out_features(&self) -> usize { self.out_features }

    /// Push all layer weights into a flat vec.
    pub fn flatten_into(&self, out: &mut Vec<f32>) {
        self.weight.flatten_into(out);
    }

    /// Restore from a flat slice. Returns number of f32s consumed.
    pub fn unflatten_from(&mut self, data: &[f32]) -> usize {
        let (lin, consumed) = Linear::unflatten_from(data, self.in_features, self.out_features);
        self.weight = lin;
        consumed
    }

    /// Total trainable params in this GCN layer.
    pub fn param_count(&self) -> usize { self.weight.param_count() }
}
|
||||
|
||||
/// Stack of GCN layers.
///
/// The first layer maps `in_f -> out_f`; every subsequent layer is
/// `out_f -> out_f` (see `GnnStack::new`).
#[derive(Debug, Clone)]
pub struct GnnStack { pub(crate) layers: Vec<GraphMessagePassing> }
|
||||
|
||||
impl GnnStack {
|
||||
pub fn new(in_f: usize, out_f: usize, n: usize, g: &BodyGraph) -> Self {
|
||||
assert!(n >= 1);
|
||||
let mut layers = vec![GraphMessagePassing::new(in_f, out_f, g)];
|
||||
for _ in 1..n { layers.push(GraphMessagePassing::new(out_f, out_f, g)); }
|
||||
Self { layers }
|
||||
}
|
||||
pub fn forward(&self, feats: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
let mut h = feats.to_vec();
|
||||
for l in &self.layers { h = l.forward(&h); }
|
||||
h
|
||||
}
|
||||
/// Push all GNN weights into a flat vec.
|
||||
pub fn flatten_into(&self, out: &mut Vec<f32>) {
|
||||
for l in &self.layers { l.flatten_into(out); }
|
||||
}
|
||||
/// Restore GNN weights from flat slice. Returns number of f32s consumed.
|
||||
pub fn unflatten_from(&mut self, data: &[f32]) -> usize {
|
||||
let mut offset = 0;
|
||||
for l in &mut self.layers {
|
||||
offset += l.unflatten_from(&data[offset..]);
|
||||
}
|
||||
offset
|
||||
}
|
||||
/// Total trainable params across all GCN layers.
|
||||
pub fn param_count(&self) -> usize {
|
||||
self.layers.iter().map(|l| l.param_count()).sum()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Transformer config / output / pipeline ───────────────────────────────
|
||||
|
||||
/// Configuration for the CSI-to-Pose transformer.
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    /// Number of CSI subcarriers per antenna-pair feature row.
    pub n_subcarriers: usize,
    /// Number of body keypoints to predict (the body graph assumes 17).
    pub n_keypoints: usize,
    /// Embedding width used throughout the model.
    pub d_model: usize,
    /// Attention heads in the cross-attention block.
    pub n_heads: usize,
    /// Number of stacked GCN layers.
    pub n_gnn_layers: usize,
}

impl Default for TransformerConfig {
    /// 17-keypoint default with a 64-wide, 4-head model over 56 subcarriers.
    fn default() -> Self {
        Self { n_subcarriers: 56, n_keypoints: 17, d_model: 64, n_heads: 4, n_gnn_layers: 2 }
    }
}
|
||||
|
||||
/// Output of the CSI-to-Pose transformer.
///
/// The three vectors are parallel: entry `i` of each belongs to keypoint `i`.
#[derive(Debug, Clone)]
pub struct PoseOutput {
    /// Predicted (x, y, z) per keypoint.
    pub keypoints: Vec<(f32, f32, f32)>,
    /// Per-keypoint confidence in [0, 1].
    pub confidences: Vec<f32>,
    /// Per-keypoint GNN features for downstream use.
    pub body_part_features: Vec<Vec<f32>>,
}
|
||||
|
||||
/// Full CSI-to-Pose pipeline: CSI embed -> cross-attention -> GNN -> regression heads.
#[derive(Debug, Clone)]
pub struct CsiToPoseTransformer {
    config: TransformerConfig,
    /// Projects a row of CSI subcarrier features into d_model space.
    csi_embed: Linear,
    /// Learned query vector per keypoint (each of width d_model).
    keypoint_queries: Vec<Vec<f32>>,
    /// Keypoint queries attend over the embedded CSI features.
    cross_attn: CrossAttention,
    /// Refines per-keypoint features along the body graph.
    gnn: GnnStack,
    /// d_model -> 3 regression head for (x, y, z).
    xyz_head: Linear,
    /// d_model -> 1 head, squashed by sigmoid into a confidence.
    conf_head: Linear,
}
|
||||
|
||||
impl CsiToPoseTransformer {
|
||||
/// Build a transformer with seeded pseudo-random weights.
pub fn new(config: TransformerConfig) -> Self {
    let d = config.d_model;
    let bg = BodyGraph::new();
    let mut rng = Rng64::new(999);
    // Xavier-style limit from fan_in + fan_out of the query table.
    let limit = (6.0 / (config.n_keypoints + d) as f32).sqrt();
    // NOTE(review): queries are drawn as next_f32() * limit; if next_f32 is
    // in [0, 1) this is one-sided rather than symmetric in (-limit, limit) —
    // confirm that is intended.
    let kq: Vec<Vec<f32>> = (0..config.n_keypoints)
        .map(|_| (0..d).map(|_| rng.next_f32() * limit).collect()).collect();
    Self {
        csi_embed: Linear::with_seed(config.n_subcarriers, d, 500),
        keypoint_queries: kq,
        cross_attn: CrossAttention::new(d, config.n_heads),
        gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg),
        xyz_head: Linear::with_seed(d, 3, 600),
        conf_head: Linear::with_seed(d, 1, 700),
        config,
    }
}
|
||||
/// Construct with zero-initialized weights (faster than Xavier init).
/// Use with `unflatten_weights()` when you plan to overwrite all weights.
///
/// NOTE(review): `cross_attn` and `gnn` still use their seeded constructors,
/// so the result is only partially zeroed; harmless when every weight is
/// overwritten afterwards, but confirm no caller relies on true zeros.
pub fn zeros(config: TransformerConfig) -> Self {
    let d = config.d_model;
    let bg = BodyGraph::new();
    // Queries start at zero and are expected to be restored from a flat slice.
    let kq = vec![vec![0.0f32; d]; config.n_keypoints];
    Self {
        csi_embed: Linear::zeros(config.n_subcarriers, d),
        keypoint_queries: kq,
        cross_attn: CrossAttention::new(d, config.n_heads), // small; kept for correct structure
        gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg),
        xyz_head: Linear::zeros(d, 3),
        conf_head: Linear::zeros(d, 1),
        config,
    }
}
|
||||
|
||||
/// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints.
pub fn forward(&self, csi_features: &[Vec<f32>]) -> PoseOutput {
    // 1) Embed each antenna-pair CSI row into d_model space.
    let embedded: Vec<Vec<f32>> = csi_features.iter()
        .map(|f| self.csi_embed.forward(f)).collect();
    // 2) Keypoint queries attend over the embedded CSI tokens (keys == values).
    let attended = self.cross_attn.forward(&self.keypoint_queries, &embedded, &embedded);
    // 3) Refine per-keypoint features along the body graph.
    let gnn_out = self.gnn.forward(&attended);
    // 4) Regress coordinates and confidence per keypoint.
    let mut kps = Vec::with_capacity(self.config.n_keypoints);
    let mut confs = Vec::with_capacity(self.config.n_keypoints);
    for nf in &gnn_out {
        let xyz = self.xyz_head.forward(nf);
        kps.push((xyz[0], xyz[1], xyz[2]));
        confs.push(sigmoid(self.conf_head.forward(nf)[0]));
    }
    PoseOutput { keypoints: kps, confidences: confs, body_part_features: gnn_out }
}
|
||||
/// Borrow the configuration this transformer was built with.
pub fn config(&self) -> &TransformerConfig { &self.config }
|
||||
|
||||
/// Collect all trainable parameters into a flat vec.
|
||||
///
|
||||
/// Layout: csi_embed | keypoint_queries (flat) | cross_attn | gnn | xyz_head | conf_head
|
||||
pub fn flatten_weights(&self) -> Vec<f32> {
|
||||
let mut out = Vec::with_capacity(self.param_count());
|
||||
self.csi_embed.flatten_into(&mut out);
|
||||
for kq in &self.keypoint_queries {
|
||||
out.extend_from_slice(kq);
|
||||
}
|
||||
self.cross_attn.flatten_into(&mut out);
|
||||
self.gnn.flatten_into(&mut out);
|
||||
self.xyz_head.flatten_into(&mut out);
|
||||
self.conf_head.flatten_into(&mut out);
|
||||
out
|
||||
}
|
||||
|
||||
/// Restore all trainable parameters from a flat slice.
///
/// Expects the exact layout produced by `flatten_weights()`.
///
/// # Errors
/// Returns `Err` if `params.len()` differs from `param_count()`.
pub fn unflatten_weights(&mut self, params: &[f32]) -> Result<(), String> {
    let expected = self.param_count();
    if params.len() != expected {
        return Err(format!("expected {expected} params, got {}", params.len()));
    }
    let mut offset = 0;

    // csi_embed
    let (embed, n) = Linear::unflatten_from(&params[offset..],
        self.config.n_subcarriers, self.config.d_model);
    self.csi_embed = embed;
    offset += n;

    // keypoint_queries: one d_model-wide row per keypoint, row-major.
    let d = self.config.d_model;
    for kq in &mut self.keypoint_queries {
        kq.copy_from_slice(&params[offset..offset + d]);
        offset += d;
    }

    // cross_attn
    let (ca, n) = CrossAttention::unflatten_from(&params[offset..],
        self.config.d_model, self.cross_attn.n_heads());
    self.cross_attn = ca;
    offset += n;

    // gnn
    let n = self.gnn.unflatten_from(&params[offset..]);
    offset += n;

    // xyz_head
    let (xyz, n) = Linear::unflatten_from(&params[offset..], self.config.d_model, 3);
    self.xyz_head = xyz;
    offset += n;

    // conf_head
    let (conf, n) = Linear::unflatten_from(&params[offset..], self.config.d_model, 1);
    self.conf_head = conf;
    offset += n;

    // The length check above guarantees we consumed everything exactly.
    debug_assert_eq!(offset, expected);
    Ok(())
}
|
||||
|
||||
/// Total number of trainable parameters.
///
/// Must stay in sync with `flatten_weights` / `unflatten_weights`.
pub fn param_count(&self) -> usize {
    self.csi_embed.param_count()
        + self.config.n_keypoints * self.config.d_model // keypoint queries
        + self.cross_attn.param_count()
        + self.gnn.param_count()
        + self.xyz_head.param_count()
        + self.conf_head.param_count()
}
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Shape/behavior unit tests for graphs, attention, GNN layers, the full
    // pipeline, and weight (de)serialization roundtrips.
    use super::*;

    #[test]
    fn body_graph_has_17_nodes() {
        assert_eq!(BodyGraph::new().n_nodes(), 17);
    }

    #[test]
    fn body_graph_has_16_edges() {
        let g = BodyGraph::new();
        assert_eq!(g.n_edges(), 16);
        assert_eq!(g.edge_list().len(), 16);
    }

    #[test]
    fn body_graph_adjacency_symmetric() {
        let bg = BodyGraph::new();
        let adj = bg.adjacency_matrix();
        for i in 0..17 { for j in 0..17 {
            assert_eq!(adj[i][j], adj[j][i], "asymmetric at ({i},{j})");
        }}
    }

    #[test]
    fn body_graph_self_loops_and_specific_edges() {
        let bg = BodyGraph::new();
        let adj = bg.adjacency_matrix();
        // GCN-style self loops on every joint.
        for i in 0..17 { assert_eq!(adj[i][i], 1.0); }
        assert_eq!(adj[0][1], 1.0); // nose-left_eye
        assert_eq!(adj[5][6], 1.0); // l_shoulder-r_shoulder
        assert_eq!(adj[14][16], 1.0); // r_knee-r_ankle
        assert_eq!(adj[0][15], 0.0); // nose should NOT connect to l_ankle
    }

    #[test]
    fn antenna_graph_node_count() {
        assert_eq!(AntennaGraph::new(3, 3).n_nodes(), 9);
    }

    #[test]
    fn antenna_graph_adjacency() {
        let ag = AntennaGraph::new(2, 2);
        let adj = ag.adjacency_matrix();
        assert_eq!(adj[0][1], 1.0); // share tx=0
        assert_eq!(adj[0][2], 1.0); // share rx=0
        assert_eq!(adj[0][3], 0.0); // share neither
    }

    #[test]
    fn cross_attention_output_shape() {
        let ca = CrossAttention::new(16, 4);
        let out = ca.forward(&vec![vec![0.5; 16]; 5], &vec![vec![0.3; 16]; 3], &vec![vec![0.7; 16]; 3]);
        assert_eq!(out.len(), 5);
        for r in &out { assert_eq!(r.len(), 16); }
    }

    #[test]
    fn cross_attention_single_head_vs_multi() {
        // Head count changes the internal split but never the output shape.
        let (q, k, v) = (vec![vec![1.0f32; 8]; 2], vec![vec![0.5; 8]; 3], vec![vec![0.5; 8]; 3]);
        let o1 = CrossAttention::new(8, 1).forward(&q, &k, &v);
        let o2 = CrossAttention::new(8, 2).forward(&q, &k, &v);
        assert_eq!(o1.len(), o2.len());
        assert_eq!(o1[0].len(), o2[0].len());
    }

    #[test]
    fn scaled_dot_product_softmax_sums_to_one() {
        let scores = vec![1.0f32, 2.0, 3.0, 0.5];
        let mut w = vec![0.0f32; 4];
        softmax(&scores, &mut w);
        assert!((w.iter().sum::<f32>() - 1.0).abs() < 1e-5);
        for &wi in &w { assert!(wi > 0.0); }
        // Largest logit gets the largest weight.
        assert!(w[2] > w[0] && w[2] > w[1] && w[2] > w[3]);
    }

    #[test]
    fn gnn_message_passing_shape() {
        let g = BodyGraph::new();
        let out = GraphMessagePassing::new(32, 16, &g).forward(&vec![vec![1.0; 32]; 17]);
        assert_eq!(out.len(), 17);
        for r in &out { assert_eq!(r.len(), 16); }
    }

    #[test]
    fn gnn_preserves_isolated_node() {
        // One hop of message passing should keep energy concentrated near the
        // only node carrying signal.
        let g = BodyGraph::new();
        let gmp = GraphMessagePassing::new(8, 8, &g);
        let mut feats: Vec<Vec<f32>> = vec![vec![0.0; 8]; 17];
        feats[0] = vec![1.0; 8]; // only nose has signal
        let out = gmp.forward(&feats);
        let ankle_e: f32 = out[15].iter().map(|x| x*x).sum();
        let nose_e: f32 = out[0].iter().map(|x| x*x).sum();
        assert!(nose_e > ankle_e, "nose ({nose_e}) should > ankle ({ankle_e})");
    }

    #[test]
    fn linear_layer_output_size() {
        assert_eq!(Linear::new(10, 5).forward(&vec![1.0; 10]).len(), 5);
    }

    #[test]
    fn linear_layer_zero_weights() {
        let out = Linear::zeros(4, 3).forward(&[1.0, 2.0, 3.0, 4.0]);
        for &v in &out { assert_eq!(v, 0.0); }
    }

    #[test]
    fn linear_layer_set_weights_identity() {
        let mut lin = Linear::zeros(2, 2);
        lin.set_weights(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        let out = lin.forward(&[3.0, 7.0]);
        assert!((out[0] - 3.0).abs() < 1e-6 && (out[1] - 7.0).abs() < 1e-6);
    }

    #[test]
    fn transformer_config_defaults() {
        let c = TransformerConfig::default();
        assert_eq!((c.n_subcarriers, c.n_keypoints, c.d_model, c.n_heads, c.n_gnn_layers),
            (56, 17, 64, 4, 2));
    }

    #[test]
    fn transformer_forward_output_17_keypoints() {
        let t = CsiToPoseTransformer::new(TransformerConfig {
            n_subcarriers: 16, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 1,
        });
        let out = t.forward(&vec![vec![0.5; 16]; 4]);
        assert_eq!(out.keypoints.len(), 17);
        assert_eq!(out.confidences.len(), 17);
        assert_eq!(out.body_part_features.len(), 17);
    }

    #[test]
    fn transformer_keypoints_are_finite() {
        let t = CsiToPoseTransformer::new(TransformerConfig {
            n_subcarriers: 8, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 2,
        });
        let out = t.forward(&vec![vec![1.0; 8]; 6]);
        for (i, &(x, y, z)) in out.keypoints.iter().enumerate() {
            assert!(x.is_finite() && y.is_finite() && z.is_finite(), "kp {i} not finite");
        }
        for (i, &c) in out.confidences.iter().enumerate() {
            assert!(c.is_finite() && (0.0..=1.0).contains(&c), "conf {i} invalid: {c}");
        }
    }

    #[test]
    fn relu_activation() {
        assert_eq!(relu(-5.0), 0.0);
        assert_eq!(relu(-0.001), 0.0);
        assert_eq!(relu(0.0), 0.0);
        assert_eq!(relu(3.14), 3.14);
        assert_eq!(relu(100.0), 100.0);
    }

    #[test]
    fn sigmoid_bounds() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(100.0) > 0.999);
        assert!(sigmoid(-100.0) < 0.001);
    }

    #[test]
    fn deterministic_rng_and_linear() {
        // Same seed -> identical stream and identical layer weights.
        let (mut r1, mut r2) = (Rng64::new(42), Rng64::new(42));
        for _ in 0..100 { assert_eq!(r1.next_u64(), r2.next_u64()); }
        let inp = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(Linear::with_seed(4, 3, 99).forward(&inp),
            Linear::with_seed(4, 3, 99).forward(&inp));
    }

    #[test]
    fn body_graph_normalized_adjacency_finite() {
        let norm = BodyGraph::new().normalized_adjacency();
        for i in 0..17 {
            let s: f32 = norm[i].iter().sum();
            assert!(s.is_finite() && s > 0.0, "row {i} sum={s}");
        }
    }

    #[test]
    fn cross_attention_empty_keys() {
        // Attention over an empty key set is defined as the zero vector.
        let out = CrossAttention::new(8, 2).forward(
            &vec![vec![1.0; 8]; 3], &vec![], &vec![]);
        assert_eq!(out.len(), 3);
        for r in &out { for &v in r { assert_eq!(v, 0.0); } }
    }

    #[test]
    fn softmax_edge_cases() {
        // Single logit collapses to weight 1.
        let mut w1 = vec![0.0f32; 1];
        softmax(&[42.0], &mut w1);
        assert!((w1[0] - 1.0).abs() < 1e-6);

        // Large logits must not overflow to NaN/inf.
        let mut w3 = vec![0.0f32; 3];
        softmax(&[1000.0, 1001.0, 999.0], &mut w3);
        let sum: f32 = w3.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        for &wi in &w3 { assert!(wi.is_finite()); }
    }

    // ── Weight serialization integration tests ────────────────────────

    #[test]
    fn linear_flatten_unflatten_roundtrip() {
        let lin = Linear::with_seed(8, 4, 42);
        let mut flat = Vec::new();
        lin.flatten_into(&mut flat);
        assert_eq!(flat.len(), lin.param_count());
        let (restored, consumed) = Linear::unflatten_from(&flat, 8, 4);
        assert_eq!(consumed, flat.len());
        let inp = vec![1.0f32; 8];
        assert_eq!(lin.forward(&inp), restored.forward(&inp));
    }

    #[test]
    fn cross_attention_flatten_unflatten_roundtrip() {
        let ca = CrossAttention::new(16, 4);
        let mut flat = Vec::new();
        ca.flatten_into(&mut flat);
        assert_eq!(flat.len(), ca.param_count());
        let (restored, consumed) = CrossAttention::unflatten_from(&flat, 16, 4);
        assert_eq!(consumed, flat.len());
        let q = vec![vec![0.5f32; 16]; 3];
        let k = vec![vec![0.3f32; 16]; 5];
        let v = vec![vec![0.7f32; 16]; 5];
        let orig = ca.forward(&q, &k, &v);
        let rest = restored.forward(&q, &k, &v);
        for (a, b) in orig.iter().zip(rest.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-6, "mismatch: {x} vs {y}");
            }
        }
    }

    #[test]
    fn transformer_weight_roundtrip() {
        let config = TransformerConfig {
            n_subcarriers: 16, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 1,
        };
        let t = CsiToPoseTransformer::new(config.clone());
        let weights = t.flatten_weights();
        assert_eq!(weights.len(), t.param_count());

        let mut t2 = CsiToPoseTransformer::new(config);
        t2.unflatten_weights(&weights).expect("unflatten should succeed");

        // Forward pass should produce identical results
        let csi = vec![vec![0.5f32; 16]; 4];
        let out1 = t.forward(&csi);
        let out2 = t2.forward(&csi);
        for (a, b) in out1.keypoints.iter().zip(out2.keypoints.iter()) {
            assert!((a.0 - b.0).abs() < 1e-6);
            assert!((a.1 - b.1).abs() < 1e-6);
            assert!((a.2 - b.2).abs() < 1e-6);
        }
        for (a, b) in out1.confidences.iter().zip(out2.confidences.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn transformer_param_count_positive() {
        let t = CsiToPoseTransformer::new(TransformerConfig::default());
        assert!(t.param_count() > 1000, "expected many params, got {}", t.param_count());
        let flat = t.flatten_weights();
        assert_eq!(flat.len(), t.param_count());
    }

    #[test]
    fn gnn_stack_flatten_unflatten() {
        let bg = BodyGraph::new();
        let gnn = GnnStack::new(8, 8, 2, &bg);
        let mut flat = Vec::new();
        gnn.flatten_into(&mut flat);
        assert_eq!(flat.len(), gnn.param_count());

        let mut gnn2 = GnnStack::new(8, 8, 2, &bg);
        let consumed = gnn2.unflatten_from(&flat);
        assert_eq!(consumed, flat.len());

        let feats = vec![vec![1.0f32; 8]; 17];
        let o1 = gnn.forward(&feats);
        let o2 = gnn2.forward(&feats);
        for (a, b) in o1.iter().zip(o2.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-6);
            }
        }
    }
}
|
||||
@@ -0,0 +1,14 @@
|
||||
//! WiFi-DensePose Sensing Server library.
//!
//! This crate provides:
//! - Vital sign detection from WiFi CSI amplitude data
//! - RVF (RuVector Format) binary container for model weights
//! - A CSI-to-Pose graph transformer (`graph_transformer`)

pub mod vital_signs;
pub mod rvf_container;
pub mod rvf_pipeline;
pub mod graph_transformer;
pub mod trainer;
pub mod dataset;
pub mod sona;
pub mod sparse_inference;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,914 @@
|
||||
//! Standalone RVF container builder and reader for WiFi-DensePose model packaging.
|
||||
//!
|
||||
//! Implements the RVF binary format (64-byte segment headers + payload) without
|
||||
//! depending on the `rvf-wire` crate. Supports building `.rvf` files that package
|
||||
//! model weights, metadata, and configuration into a single binary container.
|
||||
//!
|
||||
//! Wire format per segment:
|
||||
//! - 64-byte header (see `SegmentHeader`)
|
||||
//! - N-byte payload
|
||||
//! - Zero-padding to next 64-byte boundary
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::Write;
|
||||
|
||||
// ── RVF format constants ────────────────────────────────────────────────────
|
||||
|
||||
/// Segment header magic: "RVFS" as big-endian u32 = 0x52564653.
const SEGMENT_MAGIC: u32 = 0x5256_4653;
/// Current segment format version.
const SEGMENT_VERSION: u8 = 1;
/// All segments are 64-byte aligned.
const SEGMENT_ALIGNMENT: usize = 64;
/// Fixed header size in bytes.
const SEGMENT_HEADER_SIZE: usize = 64;

// ── Segment type discriminators (subset relevant to DensePose models) ───────
// NOTE(review): these values mirror the `rvf-types` crate's discriminators —
// keep them in sync if the upstream format evolves.

/// Raw vector payloads (model weight embeddings).
const SEG_VEC: u8 = 0x01;
/// Segment directory / manifest.
const SEG_MANIFEST: u8 = 0x05;
/// Quantization dictionaries and codebooks.
const SEG_QUANT: u8 = 0x06;
/// Arbitrary key-value metadata (JSON).
const SEG_META: u8 = 0x07;
/// Capability manifests, proof of computation, audit trails.
const SEG_WITNESS: u8 = 0x0A;
/// Domain profile declarations.
const SEG_PROFILE: u8 = 0x0B;
|
||||
|
||||
// ── Pure-Rust CRC32 (IEEE 802.3 polynomial) ────────────────────────────────
|
||||
|
||||
/// CRC32 lookup table, computed at compile time via the IEEE 802.3 polynomial
/// 0xEDB88320 (bit-reversed representation of 0x04C11DB7).
const CRC32_TABLE: [u32; 256] = {
    let mut table = [0u32; 256];
    let mut byte = 0usize;
    // `for` is not allowed in const contexts, hence the `while` loops.
    while byte < 256 {
        let mut crc = byte as u32;
        let mut bit = 0;
        while bit < 8 {
            crc = if crc & 1 != 0 { (crc >> 1) ^ 0xEDB8_8320 } else { crc >> 1 };
            bit += 1;
        }
        table[byte] = crc;
        byte += 1;
    }
    table
};

/// Compute CRC32 (IEEE) over the given byte slice.
fn crc32(data: &[u8]) -> u32 {
    // Standard reflected CRC-32: init 0xFFFFFFFF, table-driven, final XOR.
    let folded = data.iter().fold(0xFFFF_FFFFu32, |crc, &b| {
        (crc >> 8) ^ CRC32_TABLE[((crc ^ b as u32) & 0xFF) as usize]
    });
    folded ^ 0xFFFF_FFFF
}

/// Produce a 16-byte content hash field from CRC32.
/// The 4-byte CRC is stored in the first 4 bytes (little-endian), remaining
/// 12 bytes are zeroed.
fn crc32_content_hash(data: &[u8]) -> [u8; 16] {
    let mut hash = [0u8; 16];
    hash[..4].copy_from_slice(&crc32(data).to_le_bytes());
    hash
}
|
||||
|
||||
// ── Segment header (mirrors rvf-types SegmentHeader layout) ─────────────────
|
||||
|
||||
/// 64-byte segment header matching the RVF wire format exactly.
///
/// Field offsets:
/// - 0x00: magic (u32)
/// - 0x04: version (u8)
/// - 0x05: seg_type (u8)
/// - 0x06: flags (u16)
/// - 0x08: segment_id (u64)
/// - 0x10: payload_length (u64)
/// - 0x18: timestamp_ns (u64)
/// - 0x20: checksum_algo (u8)
/// - 0x21: compression (u8)
/// - 0x22: reserved_0 (u16)
/// - 0x24: reserved_1 (u32)
/// - 0x28: content_hash ([u8; 16])
/// - 0x38: uncompressed_len (u32)
/// - 0x3C: alignment_pad (u32)
#[derive(Clone, Debug)]
pub struct SegmentHeader {
    pub magic: u32,
    pub version: u8,
    pub seg_type: u8,
    pub flags: u16,
    pub segment_id: u64,
    pub payload_length: u64,
    pub timestamp_ns: u64,
    pub checksum_algo: u8,
    pub compression: u8,
    pub reserved_0: u16,
    pub reserved_1: u32,
    pub content_hash: [u8; 16],
    pub uncompressed_len: u32,
    pub alignment_pad: u32,
}

impl SegmentHeader {
    /// Create a new header with the given type and segment ID.
    /// Length/hash/timestamp fields start zeroed and are presumably filled in
    /// before serialization — TODO confirm against `push_segment`.
    fn new(seg_type: u8, segment_id: u64) -> Self {
        Self {
            magic: SEGMENT_MAGIC,
            version: SEGMENT_VERSION,
            seg_type,
            flags: 0,
            segment_id,
            payload_length: 0,
            timestamp_ns: 0,
            checksum_algo: 0, // CRC32
            compression: 0,
            reserved_0: 0,
            reserved_1: 0,
            content_hash: [0u8; 16],
            uncompressed_len: 0,
            alignment_pad: 0,
        }
    }

    /// Serialize the header into exactly 64 bytes (little-endian).
    fn to_bytes(&self) -> [u8; 64] {
        let mut buf = [0u8; 64];
        buf[0x00..0x04].copy_from_slice(&self.magic.to_le_bytes());
        buf[0x04] = self.version;
        buf[0x05] = self.seg_type;
        buf[0x06..0x08].copy_from_slice(&self.flags.to_le_bytes());
        buf[0x08..0x10].copy_from_slice(&self.segment_id.to_le_bytes());
        buf[0x10..0x18].copy_from_slice(&self.payload_length.to_le_bytes());
        buf[0x18..0x20].copy_from_slice(&self.timestamp_ns.to_le_bytes());
        buf[0x20] = self.checksum_algo;
        buf[0x21] = self.compression;
        buf[0x22..0x24].copy_from_slice(&self.reserved_0.to_le_bytes());
        buf[0x24..0x28].copy_from_slice(&self.reserved_1.to_le_bytes());
        buf[0x28..0x38].copy_from_slice(&self.content_hash);
        buf[0x38..0x3C].copy_from_slice(&self.uncompressed_len.to_le_bytes());
        buf[0x3C..0x40].copy_from_slice(&self.alignment_pad.to_le_bytes());
        buf
    }

    /// Deserialize a header from exactly 64 bytes (little-endian).
    ///
    /// No validation is performed here; callers should check `magic` and
    /// `version` themselves.
    fn from_bytes(data: &[u8; 64]) -> Self {
        let mut content_hash = [0u8; 16];
        content_hash.copy_from_slice(&data[0x28..0x38]);

        Self {
            magic: u32::from_le_bytes([data[0], data[1], data[2], data[3]]),
            version: data[0x04],
            seg_type: data[0x05],
            flags: u16::from_le_bytes([data[0x06], data[0x07]]),
            // try_into().unwrap() cannot fail: each range is a fixed-width
            // slice of the 64-byte input array.
            segment_id: u64::from_le_bytes(data[0x08..0x10].try_into().unwrap()),
            payload_length: u64::from_le_bytes(data[0x10..0x18].try_into().unwrap()),
            timestamp_ns: u64::from_le_bytes(data[0x18..0x20].try_into().unwrap()),
            checksum_algo: data[0x20],
            compression: data[0x21],
            reserved_0: u16::from_le_bytes([data[0x22], data[0x23]]),
            reserved_1: u32::from_le_bytes(data[0x24..0x28].try_into().unwrap()),
            content_hash,
            uncompressed_len: u32::from_le_bytes(data[0x38..0x3C].try_into().unwrap()),
            alignment_pad: u32::from_le_bytes(data[0x3C..0x40].try_into().unwrap()),
        }
    }
}
|
||||
|
||||
// ── Vital sign detector config ──────────────────────────────────────────────
|
||||
|
||||
/// Configuration for the WiFi-based vital sign detector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VitalSignConfig {
    /// Breathing rate band low bound (Hz).
    pub breathing_low_hz: f64,
    /// Breathing rate band high bound (Hz).
    pub breathing_high_hz: f64,
    /// Heart rate band low bound (Hz).
    pub heartrate_low_hz: f64,
    /// Heart rate band high bound (Hz).
    pub heartrate_high_hz: f64,
    /// Minimum subcarrier count for valid detection.
    pub min_subcarriers: u32,
    /// Window size in samples for spectral analysis.
    pub window_size: u32,
    /// Confidence threshold (0.0 - 1.0).
    pub confidence_threshold: f64,
}

impl Default for VitalSignConfig {
    fn default() -> Self {
        Self {
            breathing_low_hz: 0.1,   // 6 breaths/min
            breathing_high_hz: 0.5,  // 30 breaths/min
            heartrate_low_hz: 0.8,   // 48 bpm
            heartrate_high_hz: 2.0,  // 120 bpm
            min_subcarriers: 52,
            window_size: 512,
            confidence_threshold: 0.6,
        }
    }
}
|
||||
|
||||
// ── RVF container info (returned by the REST API) ───────────────────────────
|
||||
|
||||
/// Summary of a loaded RVF container, exposed via `/api/v1/model/info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RvfContainerInfo {
    /// Number of segments found in the container.
    pub segment_count: usize,
    /// Total container size in bytes.
    pub total_size: usize,
    /// Parsed manifest JSON, when a manifest segment is present.
    pub manifest: Option<serde_json::Value>,
    /// Parsed metadata JSON, when a metadata segment is present.
    pub metadata: Option<serde_json::Value>,
    /// Whether a weights (Vec) segment is present.
    pub has_weights: bool,
    /// Whether a vital-sign config (Profile) segment is present.
    pub has_vital_config: bool,
    /// Whether a quantization (Quant) segment is present.
    pub has_quant_info: bool,
    /// Whether a witness segment is present.
    pub has_witness: bool,
}
|
||||
|
||||
// ── RVF Builder ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Builds an RVF container by accumulating segments and serializing them
/// into the binary format: `[header(64) | payload | padding]*`.
pub struct RvfBuilder {
    /// Accumulated (header, payload) pairs, written out in insertion order.
    segments: Vec<(SegmentHeader, Vec<u8>)>,
    /// Next segment id to assign — presumably consumed by `push_segment`;
    /// TODO confirm against its definition.
    next_id: u64,
}
|
||||
|
||||
impl RvfBuilder {
|
||||
/// Create a new empty builder.
pub fn new() -> Self {
    Self {
        segments: Vec::new(),
        next_id: 0,
    }
}

/// Add a manifest segment with model metadata.
///
/// NOTE(review): serialization failure silently produces an empty payload
/// (`unwrap_or_default`) — confirm best-effort is intended.
pub fn add_manifest(&mut self, model_id: &str, version: &str, description: &str) {
    let manifest = serde_json::json!({
        "model_id": model_id,
        "version": version,
        "description": description,
        "format": "wifi-densepose-rvf",
        "created_at": chrono::Utc::now().to_rfc3339(),
    });
    let payload = serde_json::to_vec(&manifest).unwrap_or_default();
    self.push_segment(SEG_MANIFEST, &payload);
}
|
||||
|
||||
/// Add model weights as a Vec segment. Weights are serialized as
|
||||
/// little-endian f32 values.
|
||||
pub fn add_weights(&mut self, weights: &[f32]) {
|
||||
let mut payload = Vec::with_capacity(weights.len() * 4);
|
||||
for &w in weights {
|
||||
payload.extend_from_slice(&w.to_le_bytes());
|
||||
}
|
||||
self.push_segment(SEG_VEC, &payload);
|
||||
}
|
||||
|
||||
/// Add metadata (arbitrary JSON key-value pairs).
pub fn add_metadata(&mut self, metadata: &serde_json::Value) {
    // Best-effort: a failed serialization yields an empty payload.
    let payload = serde_json::to_vec(metadata).unwrap_or_default();
    self.push_segment(SEG_META, &payload);
}
|
||||
|
||||
/// Add vital sign detector configuration as a Profile segment.
pub fn add_vital_config(&mut self, config: &VitalSignConfig) {
    // Best-effort: a failed serialization yields an empty payload.
    let payload = serde_json::to_vec(config).unwrap_or_default();
    self.push_segment(SEG_PROFILE, &payload);
}

/// Add quantization info as a Quant segment.
pub fn add_quant_info(&mut self, quant_type: &str, scale: f32, zero_point: i32) {
    let info = serde_json::json!({
        "quant_type": quant_type,
        "scale": scale,
        "zero_point": zero_point,
    });
    let payload = serde_json::to_vec(&info).unwrap_or_default();
    self.push_segment(SEG_QUANT, &payload);
}

/// Add a raw segment with arbitrary type and payload.
/// Used by `rvf_pipeline` for extended segment types.
pub fn add_raw_segment(&mut self, seg_type: u8, payload: &[u8]) {
    self.push_segment(seg_type, payload);
}

/// Add witness/proof data as a Witness segment.
pub fn add_witness(&mut self, training_hash: &str, metrics: &serde_json::Value) {
    let witness = serde_json::json!({
        "training_hash": training_hash,
        "metrics": metrics,
    });
    let payload = serde_json::to_vec(&witness).unwrap_or_default();
    self.push_segment(SEG_WITNESS, &payload);
}
|
||||
|
||||
/// Build the final `.rvf` file as a byte vector.
|
||||
pub fn build(&self) -> Vec<u8> {
|
||||
let total: usize = self
|
||||
.segments
|
||||
.iter()
|
||||
.map(|(_, p)| align_up(SEGMENT_HEADER_SIZE + p.len()))
|
||||
.sum();
|
||||
|
||||
let mut buf = Vec::with_capacity(total);
|
||||
for (header, payload) in &self.segments {
|
||||
buf.extend_from_slice(&header.to_bytes());
|
||||
buf.extend_from_slice(payload);
|
||||
// Zero-pad to the next 64-byte boundary
|
||||
let written = SEGMENT_HEADER_SIZE + payload.len();
|
||||
let target = align_up(written);
|
||||
let pad = target - written;
|
||||
buf.extend(std::iter::repeat(0u8).take(pad));
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Write the container to a file.
|
||||
pub fn write_to_file(&self, path: &std::path::Path) -> std::io::Result<()> {
|
||||
let data = self.build();
|
||||
let mut file = std::fs::File::create(path)?;
|
||||
file.write_all(&data)?;
|
||||
file.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── internal helpers ────────────────────────────────────────────────────
|
||||
|
||||
fn push_segment(&mut self, seg_type: u8, payload: &[u8]) {
|
||||
let id = self.next_id;
|
||||
self.next_id += 1;
|
||||
|
||||
let content_hash = crc32_content_hash(payload);
|
||||
let raw = SEGMENT_HEADER_SIZE + payload.len();
|
||||
let aligned = align_up(raw);
|
||||
let pad = (aligned - raw) as u32;
|
||||
|
||||
let now_ns = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos() as u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let header = SegmentHeader {
|
||||
magic: SEGMENT_MAGIC,
|
||||
version: SEGMENT_VERSION,
|
||||
seg_type,
|
||||
flags: 0,
|
||||
segment_id: id,
|
||||
payload_length: payload.len() as u64,
|
||||
timestamp_ns: now_ns,
|
||||
checksum_algo: 0, // CRC32
|
||||
compression: 0,
|
||||
reserved_0: 0,
|
||||
reserved_1: 0,
|
||||
content_hash,
|
||||
uncompressed_len: 0,
|
||||
alignment_pad: pad,
|
||||
};
|
||||
|
||||
self.segments.push((header, payload.to_vec()));
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RvfBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Round `size` up to the next multiple of `SEGMENT_ALIGNMENT` (64).
|
||||
fn align_up(size: usize) -> usize {
|
||||
(size + SEGMENT_ALIGNMENT - 1) & !(SEGMENT_ALIGNMENT - 1)
|
||||
}
|
||||
|
||||
// ── RVF Reader ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Reads and parses an RVF container from bytes, providing access to
/// individual segments.
#[derive(Debug)]
pub struct RvfReader {
    // Parsed segments (header + payload) in file order.
    segments: Vec<(SegmentHeader, Vec<u8>)>,
    // Byte length of the original container buffer this reader was parsed from.
    raw_size: usize,
}
|
||||
|
||||
impl RvfReader {
    /// Parse an RVF container from a byte slice.
    ///
    /// Walks the buffer segment by segment, validating each header's magic
    /// and version and verifying the CRC32 content hash of every payload.
    ///
    /// # Errors
    /// Returns a descriptive message on a bad magic, an unsupported version,
    /// a payload that runs past the end of the buffer, or a hash mismatch.
    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
        let mut segments = Vec::new();
        let mut offset = 0;

        // Stop when fewer than one full header's worth of bytes remain; any
        // trailing bytes shorter than a header are ignored.
        while offset + SEGMENT_HEADER_SIZE <= data.len() {
            // Read the 64-byte header
            let header_bytes: &[u8; 64] = data[offset..offset + 64]
                .try_into()
                .map_err(|_| "truncated header".to_string())?;

            let header = SegmentHeader::from_bytes(header_bytes);

            // Validate magic
            if header.magic != SEGMENT_MAGIC {
                return Err(format!(
                    "invalid magic at offset {offset}: expected 0x{SEGMENT_MAGIC:08X}, \
                    got 0x{:08X}",
                    header.magic
                ));
            }

            // Validate version
            if header.version != SEGMENT_VERSION {
                return Err(format!(
                    "unsupported version at offset {offset}: expected {SEGMENT_VERSION}, \
                    got {}",
                    header.version
                ));
            }

            let payload_len = header.payload_length as usize;
            let payload_start = offset + SEGMENT_HEADER_SIZE;
            let payload_end = payload_start + payload_len;

            // Bounds check before slicing so a corrupt length cannot panic.
            if payload_end > data.len() {
                return Err(format!(
                    "truncated payload at offset {offset}: need {payload_len} bytes, \
                    only {} available",
                    data.len() - payload_start
                ));
            }

            let payload = data[payload_start..payload_end].to_vec();

            // Verify CRC32 content hash
            let expected_hash = crc32_content_hash(&payload);
            if expected_hash != header.content_hash {
                return Err(format!(
                    "content hash mismatch at segment {} (offset {offset})",
                    header.segment_id
                ));
            }

            segments.push((header, payload));

            // Advance past header + payload + padding to next 64-byte boundary
            let raw = SEGMENT_HEADER_SIZE + payload_len;
            offset += align_up(raw);
        }

        Ok(Self {
            segments,
            raw_size: data.len(),
        })
    }

    /// Read an RVF container from a file.
    ///
    /// # Errors
    /// Returns a message if the file cannot be read or fails to parse.
    pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
        let data = std::fs::read(path)
            .map_err(|e| format!("failed to read {}: {e}", path.display()))?;
        Self::from_bytes(&data)
    }

    /// Find the first segment with the given type and return its payload.
    pub fn find_segment(&self, seg_type: u8) -> Option<&[u8]> {
        self.segments
            .iter()
            .find(|(h, _)| h.seg_type == seg_type)
            .map(|(_, p)| p.as_slice())
    }

    /// Parse and return the manifest JSON, if present and valid JSON.
    pub fn manifest(&self) -> Option<serde_json::Value> {
        self.find_segment(SEG_MANIFEST)
            .and_then(|data| serde_json::from_slice(data).ok())
    }

    /// Decode and return model weights from the Vec segment, if present.
    ///
    /// Returns `None` if the segment is missing or its length is not a
    /// multiple of 4 (each weight is a little-endian f32).
    pub fn weights(&self) -> Option<Vec<f32>> {
        let data = self.find_segment(SEG_VEC)?;
        if data.len() % 4 != 0 {
            return None;
        }
        let weights: Vec<f32> = data
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Some(weights)
    }

    /// Parse and return the metadata JSON, if present and valid JSON.
    pub fn metadata(&self) -> Option<serde_json::Value> {
        self.find_segment(SEG_META)
            .and_then(|data| serde_json::from_slice(data).ok())
    }

    /// Parse and return the vital sign config, if present and valid JSON.
    pub fn vital_config(&self) -> Option<VitalSignConfig> {
        self.find_segment(SEG_PROFILE)
            .and_then(|data| serde_json::from_slice(data).ok())
    }

    /// Parse and return the quantization info, if present and valid JSON.
    pub fn quant_info(&self) -> Option<serde_json::Value> {
        self.find_segment(SEG_QUANT)
            .and_then(|data| serde_json::from_slice(data).ok())
    }

    /// Parse and return the witness data, if present and valid JSON.
    pub fn witness(&self) -> Option<serde_json::Value> {
        self.find_segment(SEG_WITNESS)
            .and_then(|data| serde_json::from_slice(data).ok())
    }

    /// Number of segments in the container.
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }

    /// Total byte size of the original container data.
    pub fn total_size(&self) -> usize {
        self.raw_size
    }

    /// Build a summary info struct for the REST API.
    pub fn info(&self) -> RvfContainerInfo {
        RvfContainerInfo {
            segment_count: self.segment_count(),
            total_size: self.total_size(),
            manifest: self.manifest(),
            metadata: self.metadata(),
            has_weights: self.find_segment(SEG_VEC).is_some(),
            has_vital_config: self.find_segment(SEG_PROFILE).is_some(),
            has_quant_info: self.find_segment(SEG_QUANT).is_some(),
            has_witness: self.find_segment(SEG_WITNESS).is_some(),
        }
    }

    /// Return an iterator over all segment headers and their payloads,
    /// in file order.
    pub fn segments(&self) -> impl Iterator<Item = (&SegmentHeader, &[u8])> {
        self.segments.iter().map(|(h, p)| (h, p.as_slice()))
    }
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn crc32_known_values() {
        // "hello" CRC32 (IEEE) = 0x3610A686
        let c = crc32(b"hello");
        assert_eq!(c, 0x3610_A686);
    }

    #[test]
    fn crc32_empty() {
        // CRC32 of the empty message is 0.
        let c = crc32(b"");
        assert_eq!(c, 0x0000_0000);
    }

    #[test]
    fn header_round_trip() {
        let header = SegmentHeader::new(SEG_MANIFEST, 42);
        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), 64);
        let parsed = SegmentHeader::from_bytes(&bytes);
        assert_eq!(parsed.magic, SEGMENT_MAGIC);
        assert_eq!(parsed.version, SEGMENT_VERSION);
        assert_eq!(parsed.seg_type, SEG_MANIFEST);
        assert_eq!(parsed.segment_id, 42);
    }

    #[test]
    fn header_size_is_64() {
        let header = SegmentHeader::new(0x01, 0);
        assert_eq!(header.to_bytes().len(), 64);
    }

    // Pins the exact on-disk byte layout of the header so the format cannot
    // drift silently.
    #[test]
    fn header_field_offsets() {
        let mut header = SegmentHeader::new(SEG_VEC, 0x1234_5678_9ABC_DEF0);
        header.flags = 0x0009; // COMPRESSED | SEALED
        header.payload_length = 0xAABB_CCDD_EEFF_0011;
        let bytes = header.to_bytes();

        // Magic at offset 0x00
        assert_eq!(
            u32::from_le_bytes(bytes[0x00..0x04].try_into().unwrap()),
            SEGMENT_MAGIC
        );
        // Version at 0x04
        assert_eq!(bytes[0x04], SEGMENT_VERSION);
        // seg_type at 0x05
        assert_eq!(bytes[0x05], SEG_VEC);
        // flags at 0x06
        assert_eq!(
            u16::from_le_bytes(bytes[0x06..0x08].try_into().unwrap()),
            0x0009
        );
        // segment_id at 0x08
        assert_eq!(
            u64::from_le_bytes(bytes[0x08..0x10].try_into().unwrap()),
            0x1234_5678_9ABC_DEF0
        );
        // payload_length at 0x10
        assert_eq!(
            u64::from_le_bytes(bytes[0x10..0x18].try_into().unwrap()),
            0xAABB_CCDD_EEFF_0011
        );
    }

    #[test]
    fn build_empty_container() {
        // An empty builder must round-trip to an empty, valid container.
        let builder = RvfBuilder::new();
        let data = builder.build();
        assert!(data.is_empty());

        let reader = RvfReader::from_bytes(&data).unwrap();
        assert_eq!(reader.segment_count(), 0);
        assert_eq!(reader.total_size(), 0);
    }

    #[test]
    fn manifest_round_trip() {
        let mut builder = RvfBuilder::new();
        builder.add_manifest("test-model", "1.0.0", "A test model");
        let data = builder.build();

        assert_eq!(data.len() % SEGMENT_ALIGNMENT, 0);

        let reader = RvfReader::from_bytes(&data).unwrap();
        assert_eq!(reader.segment_count(), 1);

        let manifest = reader.manifest().expect("manifest should be present");
        assert_eq!(manifest["model_id"], "test-model");
        assert_eq!(manifest["version"], "1.0.0");
        assert_eq!(manifest["description"], "A test model");
    }

    #[test]
    fn weights_round_trip() {
        // Includes extremes to exercise exact bit-level f32 round-tripping.
        let weights: Vec<f32> = vec![1.0, -2.5, 3.14, 0.0, f32::MAX, f32::MIN];

        let mut builder = RvfBuilder::new();
        builder.add_weights(&weights);
        let data = builder.build();

        let reader = RvfReader::from_bytes(&data).unwrap();
        let decoded = reader.weights().expect("weights should be present");
        assert_eq!(decoded.len(), weights.len());
        for (a, b) in decoded.iter().zip(weights.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    #[test]
    fn metadata_round_trip() {
        let meta = serde_json::json!({
            "task": "wifi-densepose",
            "input_dim": 56,
            "output_dim": 17,
            "hidden_layers": [128, 64],
        });

        let mut builder = RvfBuilder::new();
        builder.add_metadata(&meta);
        let data = builder.build();

        let reader = RvfReader::from_bytes(&data).unwrap();
        let decoded = reader.metadata().expect("metadata should be present");
        assert_eq!(decoded["task"], "wifi-densepose");
        assert_eq!(decoded["input_dim"], 56);
    }

    #[test]
    fn vital_config_round_trip() {
        let config = VitalSignConfig {
            breathing_low_hz: 0.15,
            breathing_high_hz: 0.45,
            heartrate_low_hz: 0.9,
            heartrate_high_hz: 1.8,
            min_subcarriers: 64,
            window_size: 1024,
            confidence_threshold: 0.7,
        };

        let mut builder = RvfBuilder::new();
        builder.add_vital_config(&config);
        let data = builder.build();

        let reader = RvfReader::from_bytes(&data).unwrap();
        let decoded = reader.vital_config().expect("vital config should be present");
        assert!((decoded.breathing_low_hz - 0.15).abs() < f64::EPSILON);
        assert_eq!(decoded.min_subcarriers, 64);
        assert_eq!(decoded.window_size, 1024);
    }

    #[test]
    fn quant_info_round_trip() {
        let mut builder = RvfBuilder::new();
        builder.add_quant_info("int8", 0.0078125, -128);
        let data = builder.build();

        let reader = RvfReader::from_bytes(&data).unwrap();
        let qi = reader.quant_info().expect("quant info should be present");
        assert_eq!(qi["quant_type"], "int8");
        assert_eq!(qi["zero_point"], -128);
    }

    #[test]
    fn witness_round_trip() {
        let metrics = serde_json::json!({
            "accuracy": 0.95,
            "loss": 0.032,
            "epochs": 100,
        });

        let mut builder = RvfBuilder::new();
        builder.add_witness("sha256:abcdef1234567890", &metrics);
        let data = builder.build();

        let reader = RvfReader::from_bytes(&data).unwrap();
        let w = reader.witness().expect("witness should be present");
        assert_eq!(w["training_hash"], "sha256:abcdef1234567890");
        assert_eq!(w["metrics"]["accuracy"], 0.95);
    }

    // End-to-end: every segment type written and read back from one container.
    #[test]
    fn full_container_round_trip() {
        let mut builder = RvfBuilder::new();

        builder.add_manifest("wifi-densepose-v1", "0.1.0", "WiFi DensePose model");
        builder.add_weights(&[0.1, 0.2, 0.3, -0.5, 1.0]);
        builder.add_metadata(&serde_json::json!({
            "architecture": "mlp",
            "input_dim": 56,
        }));
        builder.add_vital_config(&VitalSignConfig::default());
        builder.add_quant_info("fp32", 1.0, 0);
        builder.add_witness("sha256:deadbeef", &serde_json::json!({"loss": 0.01}));

        let data = builder.build();

        // Every segment starts at a 64-byte boundary
        assert_eq!(data.len() % SEGMENT_ALIGNMENT, 0);

        let reader = RvfReader::from_bytes(&data).unwrap();
        assert_eq!(reader.segment_count(), 6);

        // All segments present
        assert!(reader.manifest().is_some());
        assert!(reader.weights().is_some());
        assert!(reader.metadata().is_some());
        assert!(reader.vital_config().is_some());
        assert!(reader.quant_info().is_some());
        assert!(reader.witness().is_some());

        // Verify weights data
        let w = reader.weights().unwrap();
        assert_eq!(w.len(), 5);
        assert!((w[0] - 0.1).abs() < f32::EPSILON);
        assert!((w[3] - (-0.5)).abs() < f32::EPSILON);

        // Info struct for API
        let info = reader.info();
        assert_eq!(info.segment_count, 6);
        assert!(info.has_weights);
        assert!(info.has_vital_config);
        assert!(info.has_quant_info);
        assert!(info.has_witness);
    }

    #[test]
    fn file_round_trip() {
        let dir = std::env::temp_dir().join("rvf_test");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("test_model.rvf");

        let mut builder = RvfBuilder::new();
        builder.add_manifest("file-test", "1.0.0", "File I/O test");
        builder.add_weights(&[42.0, -1.0]);
        builder.write_to_file(&path).unwrap();

        let reader = RvfReader::from_file(&path).unwrap();
        assert_eq!(reader.segment_count(), 2);

        let manifest = reader.manifest().unwrap();
        assert_eq!(manifest["model_id"], "file-test");

        let w = reader.weights().unwrap();
        assert_eq!(w.len(), 2);
        assert!((w[0] - 42.0).abs() < f32::EPSILON);

        // Cleanup (best-effort: remove_dir fails if other tests left files)
        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_dir(&dir);
    }

    #[test]
    fn invalid_magic_rejected() {
        let mut data = vec![0u8; 128];
        // Write bad magic
        data[0..4].copy_from_slice(&0xDEADBEEFu32.to_le_bytes());
        let result = RvfReader::from_bytes(&data);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("invalid magic"));
    }

    #[test]
    fn truncated_payload_rejected() {
        let mut builder = RvfBuilder::new();
        builder.add_metadata(&serde_json::json!({"key": "a]long value that goes beyond the header boundary for sure to make truncation detectable"}));
        let data = builder.build();

        // Keep the header plus only 5 payload bytes so the declared payload
        // length runs past the end of the buffer.
        let cut = SEGMENT_HEADER_SIZE + 5;
        let truncated = &data[..cut];
        let result = RvfReader::from_bytes(truncated);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("truncated payload"));
    }

    #[test]
    fn content_hash_integrity() {
        let mut builder = RvfBuilder::new();
        builder.add_metadata(&serde_json::json!({"key": "value"}));
        let mut data = builder.build();

        // Corrupt one byte in the payload area (after the 64-byte header)
        if data.len() > 65 {
            data[65] ^= 0xFF;
            let result = RvfReader::from_bytes(&data);
            assert!(result.is_err());
            assert!(result.unwrap_err().contains("hash mismatch"));
        }
    }

    // Exercises padding around the 64-byte boundary edge cases (63/64/65...).
    #[test]
    fn alignment_for_various_payload_sizes() {
        for payload_size in [0, 1, 10, 63, 64, 65, 127, 128, 256, 1000] {
            let payload = vec![0xABu8; payload_size];
            let mut builder = RvfBuilder::new();
            builder.push_segment(SEG_META, &payload);
            let data = builder.build();
            assert_eq!(
                data.len() % SEGMENT_ALIGNMENT,
                0,
                "not aligned for payload_size={payload_size}"
            );
        }
    }

    #[test]
    fn segment_ids_are_monotonic() {
        let mut builder = RvfBuilder::new();
        builder.add_manifest("m", "1", "d");
        builder.add_weights(&[1.0]);
        builder.add_metadata(&serde_json::json!({}));

        let data = builder.build();
        let reader = RvfReader::from_bytes(&data).unwrap();

        let ids: Vec<u64> = reader.segments().map(|(h, _)| h.segment_id).collect();
        assert_eq!(ids, vec![0, 1, 2]);
    }

    #[test]
    fn empty_weights() {
        // A zero-length Vec segment must still decode to Some(empty vec).
        let mut builder = RvfBuilder::new();
        builder.add_weights(&[]);
        let data = builder.build();

        let reader = RvfReader::from_bytes(&data).unwrap();
        let w = reader.weights().unwrap();
        assert!(w.is_empty());
    }

    #[test]
    fn info_reports_correctly() {
        let mut builder = RvfBuilder::new();
        builder.add_manifest("info-test", "2.0", "info test");
        builder.add_weights(&[1.0, 2.0, 3.0]);
        let data = builder.build();

        let reader = RvfReader::from_bytes(&data).unwrap();
        let info = reader.info();
        assert_eq!(info.segment_count, 2);
        assert!(info.total_size > 0);
        assert!(info.manifest.is_some());
        assert!(info.has_weights);
        assert!(!info.has_vital_config);
        assert!(!info.has_quant_info);
        assert!(!info.has_witness);
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,639 @@
|
||||
//! SONA online adaptation: LoRA + EWC++ for WiFi-DensePose (ADR-023 Phase 5).
|
||||
//!
|
||||
//! Enables rapid low-parameter adaptation to changing WiFi environments without
|
||||
//! catastrophic forgetting. All arithmetic uses `f32`, no external dependencies.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
|
||||
// ── LoRA Adapter ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Low-Rank Adaptation layer storing factorised delta `scale * A * B`.
///
/// `A` has shape `(in_features, rank)` and `B` has shape
/// `(rank, out_features)`; both start at zero, so a fresh adapter
/// contributes no delta.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    pub a: Vec<Vec<f32>>, // (in_features, rank)
    pub b: Vec<Vec<f32>>, // (rank, out_features)
    pub scale: f32, // alpha / rank
    pub in_features: usize,
    pub out_features: usize,
    pub rank: usize,
}

impl LoraAdapter {
    /// Create a zero-initialised adapter; `scale` is `alpha / max(rank, 1)`.
    pub fn new(in_features: usize, out_features: usize, rank: usize, alpha: f32) -> Self {
        let scale = alpha / rank.max(1) as f32;
        Self {
            a: vec![vec![0.0f32; rank]; in_features],
            b: vec![vec![0.0f32; out_features]; rank],
            scale,
            in_features,
            out_features,
            rank,
        }
    }

    /// Compute `scale * input * A * B`, returning a vector of length
    /// `out_features`.
    ///
    /// # Panics
    /// Panics if `input.len() != in_features`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.in_features);
        // hidden = input * A, length = rank
        let hidden: Vec<f32> = (0..self.rank)
            .map(|r| {
                input
                    .iter()
                    .zip(self.a.iter())
                    .map(|(&x, a_row)| x * a_row[r])
                    .sum()
            })
            .collect();
        // output = scale * (hidden * B)
        (0..self.out_features)
            .map(|j| {
                let acc: f32 = hidden
                    .iter()
                    .zip(self.b.iter())
                    .map(|(&h, b_row)| h * b_row[j])
                    .sum();
                acc * self.scale
            })
            .collect()
    }

    /// Full delta weight matrix `scale * A * B`, shape
    /// `(in_features, out_features)`.
    pub fn delta_weights(&self) -> Vec<Vec<f32>> {
        (0..self.in_features)
            .map(|i| {
                (0..self.out_features)
                    .map(|j| {
                        let dot: f32 = (0..self.rank).map(|r| self.a[i][r] * self.b[r][j]).sum();
                        dot * self.scale
                    })
                    .collect()
            })
            .collect()
    }

    /// Add the LoRA delta to `base_weights` in place.
    pub fn merge_into(&self, base_weights: &mut [Vec<f32>]) {
        for (base_row, delta_row) in base_weights.iter_mut().zip(self.delta_weights()) {
            for (w, d) in base_row.iter_mut().zip(delta_row) {
                *w += d;
            }
        }
    }

    /// Subtract the LoRA delta from `base_weights` in place (inverse of
    /// [`Self::merge_into`]).
    pub fn unmerge_from(&self, base_weights: &mut [Vec<f32>]) {
        for (base_row, delta_row) in base_weights.iter_mut().zip(self.delta_weights()) {
            for (w, d) in base_row.iter_mut().zip(delta_row) {
                *w -= d;
            }
        }
    }

    /// Trainable parameter count: `rank * (in_features + out_features)`.
    pub fn n_params(&self) -> usize {
        (self.in_features + self.out_features) * self.rank
    }

    /// Reset A and B to zero, making the adapter a no-op again.
    pub fn reset(&mut self) {
        self.a.iter_mut().flatten().for_each(|v| *v = 0.0);
        self.b.iter_mut().flatten().for_each(|v| *v = 0.0);
    }
}
|
||||
|
||||
// ── EWC++ Regularizer ───────────────────────────────────────────────────────
|
||||
|
||||
/// Elastic Weight Consolidation++ regularizer with running Fisher average.
///
/// Penalises parameter drift away from a consolidated reference point,
/// weighted per-parameter by an online diagonal Fisher estimate.
#[derive(Debug, Clone)]
pub struct EwcRegularizer {
    pub lambda: f32,
    pub decay: f32,
    pub fisher_diag: Vec<f32>,
    pub reference_params: Vec<f32>,
}

impl EwcRegularizer {
    /// Create a regularizer with no Fisher estimate or reference point yet;
    /// until both exist the penalty is zero.
    pub fn new(lambda: f32, decay: f32) -> Self {
        Self {
            lambda,
            decay,
            fisher_diag: Vec::new(),
            reference_params: Vec::new(),
        }
    }

    /// Diagonal Fisher via numerical central differences: `F_i = grad_i^2`,
    /// averaged over `n_samples` evaluations (at least one).
    pub fn compute_fisher(params: &[f32], loss_fn: impl Fn(&[f32]) -> f32, n_samples: usize) -> Vec<f32> {
        const EPS: f32 = 1e-4;
        let samples = n_samples.max(1);
        let mut fisher = vec![0.0f32; params.len()];
        // Probe vector is perturbed one coordinate at a time and restored.
        let mut probe = params.to_vec();
        for _ in 0..samples {
            for (i, slot) in fisher.iter_mut().enumerate() {
                let saved = probe[i];
                probe[i] = saved + EPS;
                let loss_plus = loss_fn(&probe);
                probe[i] = saved - EPS;
                let loss_minus = loss_fn(&probe);
                probe[i] = saved;
                let grad = (loss_plus - loss_minus) / (2.0 * EPS);
                *slot += grad * grad;
            }
        }
        fisher.iter_mut().for_each(|f| *f /= samples as f32);
        fisher
    }

    /// Online update: `F = decay * F_old + (1 - decay) * F_new`.
    /// The first call simply adopts `new_fisher`.
    ///
    /// # Panics
    /// Panics if a later call supplies a different-length Fisher vector.
    pub fn update_fisher(&mut self, new_fisher: &[f32]) {
        if self.fisher_diag.is_empty() {
            self.fisher_diag = new_fisher.to_vec();
            return;
        }
        assert_eq!(self.fisher_diag.len(), new_fisher.len());
        for (old, &fresh) in self.fisher_diag.iter_mut().zip(new_fisher.iter()) {
            *old = self.decay * *old + (1.0 - self.decay) * fresh;
        }
    }

    /// Penalty: `0.5 * lambda * sum(F_i * (theta_i - theta_i*)^2)`.
    /// Zero until both Fisher and reference are available.
    pub fn penalty(&self, current_params: &[f32]) -> f32 {
        if self.reference_params.is_empty() || self.fisher_diag.is_empty() {
            return 0.0;
        }
        // Clamp to the shortest vector so mismatched lengths cannot panic.
        let n = current_params
            .len()
            .min(self.reference_params.len())
            .min(self.fisher_diag.len());
        let accum: f32 = (0..n)
            .map(|i| {
                let diff = current_params[i] - self.reference_params[i];
                self.fisher_diag[i] * diff * diff
            })
            .sum();
        0.5 * self.lambda * accum
    }

    /// Gradient of penalty: `lambda * F_i * (theta_i - theta_i*)`;
    /// all-zero until both Fisher and reference are available.
    pub fn penalty_gradient(&self, current_params: &[f32]) -> Vec<f32> {
        let mut grad = vec![0.0f32; current_params.len()];
        if self.reference_params.is_empty() || self.fisher_diag.is_empty() {
            return grad;
        }
        let n = current_params
            .len()
            .min(self.reference_params.len())
            .min(self.fisher_diag.len());
        for i in 0..n {
            grad[i] = self.lambda * self.fisher_diag[i] * (current_params[i] - self.reference_params[i]);
        }
        grad
    }

    /// Save the current params as the new reference point.
    pub fn consolidate(&mut self, params: &[f32]) {
        self.reference_params = params.to_vec();
    }
}
|
||||
|
||||
// ── Configuration & Types ───────────────────────────────────────────────────
|
||||
|
||||
/// SONA adaptation configuration.
#[derive(Debug, Clone)]
pub struct SonaConfig {
    // Rank of the LoRA factorisation.
    pub lora_rank: usize,
    // LoRA alpha; the adapter's effective scale is alpha / max(rank, 1).
    pub lora_alpha: f32,
    // EWC penalty strength (lambda).
    pub ewc_lambda: f32,
    // Decay factor for the running Fisher average.
    pub ewc_decay: f32,
    // Learning rate for the online gradient-descent loop.
    pub adaptation_lr: f32,
    // Maximum number of adaptation steps per run.
    pub max_steps: usize,
    // Stop early when the absolute loss change falls below this threshold.
    pub convergence_threshold: f32,
    // Weight for a temporal-consistency term.
    // NOTE(review): not referenced in the visible adaptation loop — confirm use.
    pub temporal_consistency_weight: f32,
}

impl Default for SonaConfig {
    fn default() -> Self {
        Self {
            lora_rank: 4, lora_alpha: 8.0, ewc_lambda: 5000.0, ewc_decay: 0.99,
            adaptation_lr: 0.001, max_steps: 50, convergence_threshold: 1e-4,
            temporal_consistency_weight: 0.1,
        }
    }
}
|
||||
|
||||
/// Single training sample for online adaptation.
#[derive(Debug, Clone)]
pub struct AdaptationSample {
    // Input CSI feature vector.
    pub csi_features: Vec<f32>,
    // Desired model output for this input.
    pub target: Vec<f32>,
}
|
||||
|
||||
/// Result of a SONA adaptation run.
#[derive(Debug, Clone)]
pub struct AdaptationResult {
    // Base parameters with the learned LoRA delta applied.
    pub adapted_params: Vec<f32>,
    // Number of gradient steps actually taken (0 when no samples were given).
    pub steps_taken: usize,
    // Last total loss observed (data loss + EWC penalty).
    pub final_loss: f32,
    // True if the loss change dropped below the convergence threshold
    // before `max_steps` was reached.
    pub converged: bool,
    // EWC penalty evaluated at the adapted parameters.
    pub ewc_penalty: f32,
}
|
||||
|
||||
/// Saved environment-specific adaptation profile.
///
/// Snapshot of the LoRA factors and EWC state, sufficient to restore an
/// adapter to a previously learned environment.
#[derive(Debug, Clone)]
pub struct SonaProfile {
    // Human-readable profile name.
    pub name: String,
    // LoRA A matrix, shape (in_features, rank).
    pub lora_a: Vec<Vec<f32>>,
    // LoRA B matrix, shape (rank, out_features).
    pub lora_b: Vec<Vec<f32>>,
    // Diagonal Fisher information estimate.
    pub fisher_diag: Vec<f32>,
    // Reference (consolidated) parameter vector for the EWC penalty.
    pub reference_params: Vec<f32>,
    // Number of adaptation runs completed when the profile was saved.
    pub adaptation_count: usize,
}
|
||||
|
||||
// ── SONA Adapter ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Full SONA system: LoRA adapter + EWC++ regularizer for online adaptation.
#[derive(Debug, Clone)]
pub struct SonaAdapter {
    // Hyperparameters controlling LoRA and EWC behaviour.
    pub config: SonaConfig,
    // Low-rank delta over the flattened base parameter vector.
    pub lora: LoraAdapter,
    // EWC++ regularizer guarding against catastrophic forgetting.
    pub ewc: EwcRegularizer,
    // Length of the flattened base parameter vector.
    pub param_count: usize,
    // Number of completed adaptation runs.
    pub adaptation_count: usize,
}
|
||||
|
||||
impl SonaAdapter {
    /// Build an adapter for a flat parameter vector of `param_count` entries.
    ///
    /// The LoRA factorization is shaped `param_count x 1` so its delta can be
    /// flattened back onto the base parameters (see `lora_delta_flat`).
    pub fn new(config: SonaConfig, param_count: usize) -> Self {
        let lora = LoraAdapter::new(param_count, 1, config.lora_rank, config.lora_alpha);
        let ewc = EwcRegularizer::new(config.ewc_lambda, config.ewc_decay);
        Self { config, lora, ewc, param_count, adaptation_count: 0 }
    }

    /// Run gradient descent with LoRA + EWC on the given samples.
    ///
    /// Only the LoRA factors are updated; `base_params` stay frozen and the
    /// returned `adapted_params` are `base + lora_delta`. Stops early when the
    /// absolute change in total loss drops below `convergence_threshold`.
    ///
    /// Panics if `base_params.len() != param_count` (constructor contract).
    pub fn adapt(&mut self, base_params: &[f32], samples: &[AdaptationSample]) -> AdaptationResult {
        assert_eq!(base_params.len(), self.param_count);
        // No data: nothing to fit — return the base parameters unchanged,
        // marked converged, with the EWC penalty evaluated at the base point.
        if samples.is_empty() {
            return AdaptationResult {
                adapted_params: base_params.to_vec(), steps_taken: 0,
                final_loss: 0.0, converged: true, ewc_penalty: self.ewc.penalty(base_params),
            };
        }
        let lr = self.config.adaptation_lr;
        let (mut prev_loss, mut steps, mut converged) = (f32::MAX, 0usize, false);
        // Dimensions are taken from the first sample; later samples are
        // truncated to these sizes inside mse_loss_grad.
        let out_dim = samples[0].target.len();
        let in_dim = samples[0].csi_features.len();

        for step in 0..self.config.max_steps {
            steps = step + 1;
            // Effective parameters for this step: base + current LoRA delta.
            let df = self.lora_delta_flat();
            let eff: Vec<f32> = base_params.iter().zip(df.iter()).map(|(&b, &d)| b + d).collect();
            let (dl, dg) = Self::mse_loss_grad(&eff, samples, in_dim, out_dim);
            let ep = self.ewc.penalty(&eff);
            let eg = self.ewc.penalty_gradient(&eff);
            let total = dl + ep;
            // Converged when the loss plateaus between consecutive steps.
            if (prev_loss - total).abs() < self.config.convergence_threshold {
                converged = true; prev_loss = total; break;
            }
            prev_loss = total;
            // Combine data + EWC gradients, truncated to the shortest length
            // so mismatched Fisher/reference sizes cannot index out of bounds.
            let gl = df.len().min(dg.len()).min(eg.len());
            let mut tg = vec![0.0f32; gl];
            for i in 0..gl { tg[i] = dg[i] + eg[i]; }
            self.update_lora(&tg, lr);
        }
        // Final adapted parameters after the last LoRA update.
        let df = self.lora_delta_flat();
        let adapted: Vec<f32> = base_params.iter().zip(df.iter()).map(|(&b, &d)| b + d).collect();
        let ewc_penalty = self.ewc.penalty(&adapted);
        self.adaptation_count += 1;
        AdaptationResult { adapted_params: adapted, steps_taken: steps, final_loss: prev_loss, converged, ewc_penalty }
    }

    /// Snapshot the current LoRA factors and EWC state under `name`.
    pub fn save_profile(&self, name: &str) -> SonaProfile {
        SonaProfile {
            name: name.to_string(), lora_a: self.lora.a.clone(), lora_b: self.lora.b.clone(),
            fisher_diag: self.ewc.fisher_diag.clone(), reference_params: self.ewc.reference_params.clone(),
            adaptation_count: self.adaptation_count,
        }
    }

    /// Restore LoRA factors and EWC state from a saved profile.
    /// The profile is assumed to come from an adapter with the same
    /// `param_count`/rank — no shape validation is performed here.
    pub fn load_profile(&mut self, profile: &SonaProfile) {
        self.lora.a = profile.lora_a.clone();
        self.lora.b = profile.lora_b.clone();
        self.ewc.fisher_diag = profile.fisher_diag.clone();
        self.ewc.reference_params = profile.reference_params.clone();
        self.adaptation_count = profile.adaptation_count;
    }

    /// Flatten the `param_count x 1` LoRA delta matrix into a vector by
    /// taking the single column of each row.
    fn lora_delta_flat(&self) -> Vec<f32> {
        self.lora.delta_weights().into_iter().map(|r| r[0]).collect()
    }

    /// Mean-squared-error loss and gradient of a linear model
    /// `pred[j] = sum_i params[j*in_dim + i] * inp[i]` over all samples.
    ///
    /// Returns `(loss / n_samples, grad)` where `grad` has `params.len()`
    /// entries; indices beyond `in_dim*out_dim` are left at zero. All index
    /// accesses are bounds-guarded so short feature/target vectors are safe.
    fn mse_loss_grad(params: &[f32], samples: &[AdaptationSample], in_dim: usize, out_dim: usize) -> (f32, Vec<f32>) {
        let n = samples.len() as f32;
        let ws = in_dim * out_dim;
        let mut grad = vec![0.0f32; params.len()];
        let mut loss = 0.0f32;
        for s in samples {
            let (inp, tgt) = (&s.csi_features, &s.target);
            let mut pred = vec![0.0f32; out_dim];
            for j in 0..out_dim {
                for i in 0..in_dim.min(inp.len()) {
                    let idx = j * in_dim + i;
                    if idx < ws && idx < params.len() { pred[j] += params[idx] * inp[i]; }
                }
            }
            for j in 0..out_dim.min(tgt.len()) {
                let e = pred[j] - tgt[j];
                loss += e * e;
                for i in 0..in_dim.min(inp.len()) {
                    let idx = j * in_dim + i;
                    // d(e^2)/d(param) = 2*e*inp; averaged over samples via /n.
                    if idx < ws && idx < grad.len() { grad[idx] += 2.0 * e * inp[i] / n; }
                }
            }
        }
        (loss / n, grad)
    }

    /// One SGD step on the LoRA factors A (in_features x rank) and
    /// B (rank x 1) using the chain rule through delta = scale * A * B.
    fn update_lora(&mut self, grad: &[f32], lr: f32) {
        let (scale, rank) = (self.lora.scale, self.lora.rank);
        // If B is all-zero (fresh zero-init), the gradient w.r.t. A would be
        // zero everywhere and training could never start — seed one entry.
        if self.lora.b.iter().all(|r| r.iter().all(|&v| v == 0.0)) && rank > 0 {
            self.lora.b[0][0] = 1.0;
        }
        // dL/dA[i][r] = grad[i] * scale * B[r][0]
        for i in 0..self.lora.in_features.min(grad.len()) {
            for r in 0..rank {
                self.lora.a[i][r] -= lr * grad[i] * scale * self.lora.b[r][0];
            }
        }
        // dL/dB[r][0] = sum_i grad[i] * scale * A[i][r]
        // NOTE(review): A was already updated above, so B's gradient uses the
        // post-step A — an intentional-looking Gauss-Seidel-style update.
        for r in 0..rank {
            let mut g = 0.0f32;
            for i in 0..self.lora.in_features.min(grad.len()) {
                g += grad[i] * scale * self.lora.a[i][r];
            }
            self.lora.b[r][0] -= lr * g;
        }
    }
}
|
||||
|
||||
// ── Environment Detector ────────────────────────────────────────────────────
|
||||
|
||||
/// CSI baseline drift information.
#[derive(Debug, Clone)]
pub struct DriftInfo {
    pub magnitude: f32,
    pub duration_frames: usize,
    pub baseline_mean: f32,
    pub current_mean: f32,
}

/// Detects environmental drift in CSI statistics (>3 sigma from baseline).
#[derive(Debug, Clone)]
pub struct EnvironmentDetector {
    window_size: usize,
    means: VecDeque<f32>,
    variances: VecDeque<f32>,
    baseline_mean: f32,
    baseline_var: f32,
    baseline_std: f32,
    baseline_set: bool,
    drift_frames: usize,
}

impl EnvironmentDetector {
    /// Create a detector with the given rolling-window size (clamped to >= 2).
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size: window_size.max(2),
            means: VecDeque::with_capacity(window_size),
            variances: VecDeque::with_capacity(window_size),
            baseline_mean: 0.0,
            baseline_var: 0.0,
            baseline_std: 0.0,
            baseline_set: false,
            drift_frames: 0,
        }
    }

    /// Feed one frame's CSI mean/variance into the rolling window.
    ///
    /// The baseline is captured automatically once the window first fills;
    /// afterwards each call extends or resets the drift-duration counter.
    pub fn update(&mut self, csi_mean: f32, csi_var: f32) {
        self.means.push_back(csi_mean);
        self.variances.push_back(csi_var);
        while self.means.len() > self.window_size {
            self.means.pop_front();
        }
        while self.variances.len() > self.window_size {
            self.variances.pop_front();
        }
        if !self.baseline_set && self.means.len() >= self.window_size {
            self.reset_baseline();
        }
        self.drift_frames = if self.drift_detected() { self.drift_frames + 1 } else { 0 };
    }

    /// True when the windowed mean deviates more than 3 sigma from baseline.
    /// With a degenerate (near-zero) baseline std, a tiny epsilon threshold
    /// is used so any real shift still registers.
    pub fn drift_detected(&self) -> bool {
        if !self.baseline_set || self.means.is_empty() {
            return false;
        }
        let deviation = (self.current_mean() - self.baseline_mean).abs();
        let threshold = if self.baseline_std > f32::EPSILON {
            3.0 * self.baseline_std
        } else {
            f32::EPSILON * 100.0
        };
        deviation > threshold
    }

    /// Recompute baseline statistics from the current window contents and
    /// clear the drift counter. No-op on an empty window.
    pub fn reset_baseline(&mut self) {
        if self.means.is_empty() {
            return;
        }
        let n = self.means.len() as f32;
        let mean = self.means.iter().sum::<f32>() / n;
        let var = self.means.iter().map(|&m| (m - mean).powi(2)).sum::<f32>() / n;
        self.baseline_mean = mean;
        self.baseline_var = var;
        self.baseline_std = var.sqrt();
        self.baseline_set = true;
        self.drift_frames = 0;
    }

    /// Snapshot of the current drift: magnitude in sigmas plus duration.
    pub fn drift_info(&self) -> DriftInfo {
        let current = self.current_mean();
        let abs_dev = (current - self.baseline_mean).abs();
        let magnitude = if self.baseline_std > f32::EPSILON {
            abs_dev / self.baseline_std
        } else if abs_dev > f32::EPSILON {
            abs_dev / f32::EPSILON
        } else {
            0.0
        };
        DriftInfo {
            magnitude,
            duration_frames: self.drift_frames,
            baseline_mean: self.baseline_mean,
            current_mean: current,
        }
    }

    /// Mean of the rolling window; 0.0 when empty.
    fn current_mean(&self) -> f32 {
        match self.means.len() {
            0 => 0.0,
            n => self.means.iter().sum::<f32>() / n as f32,
        }
    }
}
|
||||
|
||||
// ── Temporal Consistency Loss ───────────────────────────────────────────────
|
||||
|
||||
/// Penalises large velocity between consecutive outputs: `sum((c-p)^2) / dt`.
pub struct TemporalConsistencyLoss;

impl TemporalConsistencyLoss {
    /// Squared step between consecutive outputs scaled by `1/dt`.
    /// Vectors of unequal length are compared up to the shorter one;
    /// a non-positive `dt` yields 0.0.
    pub fn compute(prev_output: &[f32], curr_output: &[f32], dt: f32) -> f32 {
        if dt <= 0.0 {
            return 0.0;
        }
        let sum_sq: f32 = prev_output
            .iter()
            .zip(curr_output.iter())
            .map(|(&p, &c)| (c - p) * (c - p))
            .sum();
        sum_sq / dt
    }
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // ── LoRA adapter: shapes, zero-init, merge/unmerge, scaling ─────────────

    #[test]
    fn lora_adapter_param_count() {
        // rank * (in_features + out_features) trainable parameters.
        let lora = LoraAdapter::new(64, 32, 4, 8.0);
        assert_eq!(lora.n_params(), 4 * (64 + 32));
    }

    #[test]
    fn lora_adapter_forward_shape() {
        let lora = LoraAdapter::new(8, 4, 2, 4.0);
        assert_eq!(lora.forward(&vec![1.0f32; 8]).len(), 4);
    }

    #[test]
    fn lora_adapter_zero_init_produces_zero_delta() {
        // Freshly constructed adapters must not perturb the base weights.
        let delta = LoraAdapter::new(8, 4, 2, 4.0).delta_weights();
        assert_eq!(delta.len(), 8);
        for row in &delta { assert_eq!(row.len(), 4); for &v in row { assert_eq!(v, 0.0); } }
    }

    #[test]
    fn lora_adapter_merge_unmerge_roundtrip() {
        // merge_into followed by unmerge_from must restore the base weights
        // (up to f32 rounding).
        let mut lora = LoraAdapter::new(3, 2, 1, 2.0);
        lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0;
        lora.b[0][0] = 0.5; lora.b[0][1] = -0.5;
        let mut base = vec![vec![10.0, 20.0], vec![30.0, 40.0], vec![50.0, 60.0]];
        let orig = base.clone();
        lora.merge_into(&mut base);
        assert_ne!(base, orig);
        lora.unmerge_from(&mut base);
        for (rb, ro) in base.iter().zip(orig.iter()) {
            for (&b, &o) in rb.iter().zip(ro.iter()) {
                assert!((b - o).abs() < 1e-5, "roundtrip failed: {b} vs {o}");
            }
        }
    }

    #[test]
    fn lora_adapter_rank_1_outer_product() {
        // With rank 1 and scale 1, delta must equal the outer product a * b.
        let mut lora = LoraAdapter::new(3, 2, 1, 1.0); // scale=1
        lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0;
        lora.b[0][0] = 4.0; lora.b[0][1] = 5.0;
        let d = lora.delta_weights();
        let expected = [[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() { assert!((d[i][j] - v).abs() < 1e-6); }
        }
    }

    #[test]
    fn lora_scale_factor() {
        // scale = alpha / rank.
        assert!((LoraAdapter::new(8, 4, 4, 16.0).scale - 4.0).abs() < 1e-6);
        assert!((LoraAdapter::new(8, 4, 2, 8.0).scale - 4.0).abs() < 1e-6);
    }

    // ── EWC regularizer: Fisher estimation, penalty, gradient, decay ────────

    #[test]
    fn ewc_fisher_positive() {
        // Fisher diagonal is built from squared gradients, so never negative.
        let fisher = EwcRegularizer::compute_fisher(
            &[1.0f32, -2.0, 0.5],
            |p: &[f32]| p.iter().map(|&x| x * x).sum::<f32>(), 1,
        );
        assert_eq!(fisher.len(), 3);
        for &f in &fisher { assert!(f >= 0.0, "Fisher must be >= 0, got {f}"); }
    }

    #[test]
    fn ewc_penalty_zero_at_reference() {
        let mut ewc = EwcRegularizer::new(5000.0, 0.99);
        let p = vec![1.0, 2.0, 3.0];
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&p);
        assert!(ewc.penalty(&p).abs() < 1e-10);
    }

    #[test]
    fn ewc_penalty_positive_away_from_reference() {
        // Unit Fisher, unit displacement per dim: 0.5 * lambda * sum(1) = 7500.
        let mut ewc = EwcRegularizer::new(5000.0, 0.99);
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&[1.0, 2.0, 3.0]);
        let pen = ewc.penalty(&[2.0, 3.0, 4.0]);
        assert!(pen > 0.0); // 0.5 * 5000 * 3 = 7500
        assert!((pen - 7500.0).abs() < 1e-3, "expected ~7500, got {pen}");
    }

    #[test]
    fn ewc_penalty_gradient_direction() {
        // Gradient must point away from the reference (same sign as c - r).
        let mut ewc = EwcRegularizer::new(100.0, 0.99);
        let r = vec![1.0, 2.0, 3.0];
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&r);
        let c = vec![2.0, 4.0, 5.0];
        let grad = ewc.penalty_gradient(&c);
        for (i, &g) in grad.iter().enumerate() {
            assert!(g * (c[i] - r[i]) > 0.0, "gradient[{i}] wrong sign");
        }
    }

    #[test]
    fn ewc_online_update_decays() {
        // Online Fisher: new = decay * old + (1 - decay) * observation.
        let mut ewc = EwcRegularizer::new(1.0, 0.5);
        ewc.update_fisher(&[10.0, 20.0]);
        assert!((ewc.fisher_diag[0] - 10.0).abs() < 1e-6);
        ewc.update_fisher(&[0.0, 0.0]);
        assert!((ewc.fisher_diag[0] - 5.0).abs() < 1e-6); // 0.5*10 + 0.5*0
        assert!((ewc.fisher_diag[1] - 10.0).abs() < 1e-6); // 0.5*20 + 0.5*0
    }

    #[test]
    fn ewc_consolidate_updates_reference() {
        let mut ewc = EwcRegularizer::new(1.0, 0.99);
        ewc.consolidate(&[1.0, 2.0]);
        assert_eq!(ewc.reference_params, vec![1.0, 2.0]);
        ewc.consolidate(&[3.0, 4.0]);
        assert_eq!(ewc.reference_params, vec![3.0, 4.0]);
    }

    // ── SONA adapter: defaults, convergence, step cap, profiles ─────────────

    #[test]
    fn sona_config_defaults() {
        let c = SonaConfig::default();
        assert_eq!(c.lora_rank, 4);
        assert!((c.lora_alpha - 8.0).abs() < 1e-6);
        assert!((c.ewc_lambda - 5000.0).abs() < 1e-3);
        assert!((c.ewc_decay - 0.99).abs() < 1e-6);
        assert!((c.adaptation_lr - 0.001).abs() < 1e-6);
        assert_eq!(c.max_steps, 50);
        assert!((c.convergence_threshold - 1e-4).abs() < 1e-8);
        assert!((c.temporal_consistency_weight - 0.1).abs() < 1e-6);
    }

    #[test]
    fn sona_adapter_converges_on_simple_task() {
        // Fit y = 2x with a single scalar weight starting at 0; the loss
        // should drop well below the initial value within 200 steps.
        let cfg = SonaConfig {
            lora_rank: 1, lora_alpha: 1.0, ewc_lambda: 0.0, ewc_decay: 0.99,
            adaptation_lr: 0.01, max_steps: 200, convergence_threshold: 1e-6,
            temporal_consistency_weight: 0.0,
        };
        let mut adapter = SonaAdapter::new(cfg, 1);
        let samples: Vec<_> = (1..=5).map(|i| {
            let x = i as f32;
            AdaptationSample { csi_features: vec![x], target: vec![2.0 * x] }
        }).collect();
        let r = adapter.adapt(&[0.0f32], &samples);
        assert!(r.final_loss < 1.0, "loss should decrease, got {}", r.final_loss);
        assert!(r.steps_taken > 0);
    }

    #[test]
    fn sona_adapter_respects_max_steps() {
        // convergence_threshold = 0 disables early exit, so exactly
        // max_steps iterations must run.
        let cfg = SonaConfig { max_steps: 5, convergence_threshold: 0.0, ..SonaConfig::default() };
        let mut a = SonaAdapter::new(cfg, 4);
        let s = vec![AdaptationSample { csi_features: vec![1.0, 0.0, 0.0, 0.0], target: vec![1.0] }];
        assert_eq!(a.adapt(&[0.0; 4], &s).steps_taken, 5);
    }

    #[test]
    fn sona_profile_save_load_roundtrip() {
        let mut a = SonaAdapter::new(SonaConfig::default(), 8);
        a.lora.a[0][0] = 1.5; a.lora.b[0][0] = -0.3;
        a.ewc.fisher_diag = vec![1.0, 2.0, 3.0];
        a.ewc.reference_params = vec![0.1, 0.2, 0.3];
        a.adaptation_count = 42;
        let p = a.save_profile("test-env");
        assert_eq!(p.name, "test-env");
        assert_eq!(p.adaptation_count, 42);
        let mut a2 = SonaAdapter::new(SonaConfig::default(), 8);
        a2.load_profile(&p);
        assert!((a2.lora.a[0][0] - 1.5).abs() < 1e-6);
        assert!((a2.lora.b[0][0] - (-0.3)).abs() < 1e-6);
        assert_eq!(a2.ewc.fisher_diag.len(), 3);
        assert!((a2.ewc.fisher_diag[2] - 3.0).abs() < 1e-6);
        assert_eq!(a2.adaptation_count, 42);
    }

    // ── Environment drift detection ─────────────────────────────────────────

    #[test]
    fn environment_detector_no_drift_initially() {
        assert!(!EnvironmentDetector::new(10).drift_detected());
    }

    #[test]
    fn environment_detector_detects_large_shift() {
        let mut d = EnvironmentDetector::new(10);
        for _ in 0..10 { d.update(10.0, 0.1); }
        assert!(!d.drift_detected());
        for _ in 0..10 { d.update(50.0, 0.1); }
        assert!(d.drift_detected());
        assert!(d.drift_info().magnitude > 3.0, "magnitude = {}", d.drift_info().magnitude);
    }

    #[test]
    fn environment_detector_reset_baseline() {
        // Adopting the drifted window as the new baseline clears the drift.
        let mut d = EnvironmentDetector::new(10);
        for _ in 0..10 { d.update(10.0, 0.1); }
        for _ in 0..10 { d.update(50.0, 0.1); }
        assert!(d.drift_detected());
        d.reset_baseline();
        assert!(!d.drift_detected());
    }

    // ── Temporal consistency loss ───────────────────────────────────────────

    #[test]
    fn temporal_consistency_zero_for_static() {
        let o = vec![1.0, 2.0, 3.0];
        assert!(TemporalConsistencyLoss::compute(&o, &o, 0.033).abs() < 1e-10);
    }
}
|
||||
@@ -0,0 +1,753 @@
|
||||
//! Sparse inference and weight quantization for edge deployment of WiFi DensePose.
|
||||
//!
|
||||
//! Implements ADR-023 Phase 6: activation profiling, sparse matrix-vector multiply,
|
||||
//! INT8/FP16 quantization, and a full sparse inference engine. Pure Rust, no deps.
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
// ── Neuron Profiler ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Tracks per-neuron activation frequency to partition hot vs cold neurons.
pub struct NeuronProfiler {
    activation_counts: Vec<u64>,
    samples: usize,
    n_neurons: usize,
}

impl NeuronProfiler {
    /// Start with zero counts for `n_neurons` neurons and no samples seen.
    pub fn new(n_neurons: usize) -> Self {
        Self {
            activation_counts: vec![0; n_neurons],
            samples: 0,
            n_neurons,
        }
    }

    /// Record an activation; values > 0 count as "active".
    /// Out-of-range indices are ignored.
    pub fn record_activation(&mut self, neuron_idx: usize, activation: f32) {
        if activation > 0.0 {
            if let Some(count) = self.activation_counts.get_mut(neuron_idx) {
                *count += 1;
            }
        }
    }

    /// Mark end of one profiling sample (call after recording all neurons).
    pub fn end_sample(&mut self) {
        self.samples += 1;
    }

    /// Fraction of samples where the neuron fired (activation > 0).
    /// Returns 0.0 for unknown neurons or before any sample has finished.
    pub fn activation_frequency(&self, neuron_idx: usize) -> f32 {
        match self.activation_counts.get(neuron_idx) {
            Some(&count) if self.samples > 0 => count as f32 / self.samples as f32,
            _ => 0.0,
        }
    }

    /// Split neurons into (hot, cold) by activation frequency threshold.
    /// Both halves keep ascending index order.
    pub fn partition_hot_cold(&self, hot_threshold: f32) -> (Vec<usize>, Vec<usize>) {
        (0..self.n_neurons).partition(|&i| self.activation_frequency(i) >= hot_threshold)
    }

    /// Top-k most frequently activated neuron indices.
    /// Stable sort keeps lower indices first among ties.
    pub fn top_k_neurons(&self, k: usize) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..self.n_neurons).collect();
        indices.sort_by(|&a, &b| {
            self.activation_frequency(b)
                .partial_cmp(&self.activation_frequency(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        indices.truncate(k);
        indices
    }

    /// Fraction of neurons with activation frequency < 0.1.
    pub fn sparsity_ratio(&self) -> f32 {
        if self.n_neurons == 0 || self.samples == 0 {
            return 0.0;
        }
        let n_cold = (0..self.n_neurons)
            .filter(|&i| self.activation_frequency(i) < 0.1)
            .count();
        n_cold as f32 / self.n_neurons as f32
    }

    /// Number of completed profiling samples.
    pub fn total_samples(&self) -> usize {
        self.samples
    }
}
|
||||
|
||||
// ── Sparse Linear Layer ──────────────────────────────────────────────────────
|
||||
|
||||
/// Linear layer that only computes output rows for "hot" neurons.
pub struct SparseLinear {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
    hot_neurons: Vec<usize>,
    n_outputs: usize,
    n_inputs: usize,
}

impl SparseLinear {
    /// Wrap a dense weight matrix (one row per output) with a hot-row set.
    pub fn new(weights: Vec<Vec<f32>>, bias: Vec<f32>, hot_neurons: Vec<usize>) -> Self {
        let n_outputs = weights.len();
        let n_inputs = weights.first().map_or(0, |row| row.len());
        Self { weights, bias, hot_neurons, n_outputs, n_inputs }
    }

    /// Sparse forward: only compute hot rows; cold outputs are 0.
    /// Hot indices beyond the output count are skipped.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0f32; self.n_outputs];
        for &row_idx in &self.hot_neurons {
            if row_idx < self.n_outputs {
                output[row_idx] = dot_bias(&self.weights[row_idx], input, self.bias[row_idx]);
            }
        }
        output
    }

    /// Dense forward: compute all rows.
    pub fn forward_full(&self, input: &[f32]) -> Vec<f32> {
        let mut output = Vec::with_capacity(self.n_outputs);
        for r in 0..self.n_outputs {
            output.push(dot_bias(&self.weights[r], input, self.bias[r]));
        }
        output
    }

    /// Replace the hot-neuron set.
    pub fn set_hot_neurons(&mut self, hot: Vec<usize>) {
        self.hot_neurons = hot;
    }

    /// Fraction of neurons in the hot set.
    pub fn density(&self) -> f32 {
        if self.n_outputs == 0 {
            0.0
        } else {
            self.hot_neurons.len() as f32 / self.n_outputs as f32
        }
    }

    /// Multiply-accumulate ops saved vs dense.
    pub fn n_flops_saved(&self) -> usize {
        self.n_outputs.saturating_sub(self.hot_neurons.len()) * self.n_inputs
    }
}

/// Dot product of `row` and `input` (truncated to the shorter) plus `bias`.
fn dot_bias(row: &[f32], input: &[f32], bias: f32) -> f32 {
    row.iter()
        .zip(input.iter())
        .fold(bias, |acc, (&w, &x)| acc + w * x)
}
|
||||
|
||||
// ── Quantization ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Quantization mode.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantMode { F32, F16, Int8Symmetric, Int8Asymmetric, Int4 }

/// Quantization configuration.
#[derive(Debug, Clone)]
pub struct QuantConfig { pub mode: QuantMode, pub calibration_samples: usize }

impl Default for QuantConfig {
    /// INT8 symmetric with 100 calibration samples.
    fn default() -> Self { Self { mode: QuantMode::Int8Symmetric, calibration_samples: 100 } }
}

/// Quantized weight storage.
///
/// In asymmetric mode, `data` and `zero_point` hold u8 codes bit-reinterpreted
/// as i8; reinterpret with `as u8` before doing arithmetic on them.
#[derive(Debug, Clone)]
pub struct QuantizedWeights {
    pub data: Vec<i8>,
    pub scale: f32,
    pub zero_point: i8,
    pub mode: QuantMode,
}

pub struct Quantizer;

impl Quantizer {
    /// Symmetric INT8: zero maps to 0, scale = max(|w|)/127.
    pub fn quantize_symmetric(weights: &[f32]) -> QuantizedWeights {
        if weights.is_empty() {
            return QuantizedWeights {
                data: vec![],
                scale: 1.0,
                zero_point: 0,
                mode: QuantMode::Int8Symmetric,
            };
        }
        let max_abs = weights.iter().fold(0.0f32, |m, w| m.max(w.abs()));
        // Guard against an all-zero tensor producing a zero scale.
        let scale = if max_abs < f32::EPSILON { 1.0 } else { max_abs / 127.0 };
        let data: Vec<i8> = weights
            .iter()
            .map(|&w| (w / scale).round().clamp(-127.0, 127.0) as i8)
            .collect();
        QuantizedWeights { data, scale, zero_point: 0, mode: QuantMode::Int8Symmetric }
    }

    /// Asymmetric INT8: maps [min,max] to [0,255].
    /// Codes are u8 values stored bit-for-bit inside the i8 `data` vector.
    pub fn quantize_asymmetric(weights: &[f32]) -> QuantizedWeights {
        if weights.is_empty() {
            return QuantizedWeights {
                data: vec![],
                scale: 1.0,
                zero_point: 0,
                mode: QuantMode::Int8Asymmetric,
            };
        }
        let mut lo = f32::INFINITY;
        let mut hi = f32::NEG_INFINITY;
        for &w in weights {
            lo = lo.min(w);
            hi = hi.max(w);
        }
        let range = hi - lo;
        let degenerate = range < f32::EPSILON;
        let scale = if degenerate { 1.0 } else { range / 255.0 };
        let zero_point = if degenerate {
            0u8
        } else {
            (-lo / scale).round().clamp(0.0, 255.0) as u8
        };
        let data: Vec<i8> = weights
            .iter()
            .map(|&w| ((w - lo) / scale).round().clamp(0.0, 255.0) as u8 as i8)
            .collect();
        QuantizedWeights { data, scale, zero_point: zero_point as i8, mode: QuantMode::Int8Asymmetric }
    }

    /// Reconstruct approximate f32 values from quantized weights.
    pub fn dequantize(qw: &QuantizedWeights) -> Vec<f32> {
        match qw.mode {
            QuantMode::Int8Asymmetric => {
                // Undo the u8-in-i8 reinterpretation before shifting by zp.
                let zp = qw.zero_point as u8 as f32;
                qw.data.iter().map(|&q| (q as u8 as f32 - zp) * qw.scale).collect()
            }
            // Symmetric and all other modes decode as signed code * scale.
            _ => qw.data.iter().map(|&q| f32::from(q) * qw.scale).collect(),
        }
    }

    /// MSE between original and quantized weights.
    /// Returns `f32::MAX` on length mismatch or empty input.
    pub fn quantization_error(original: &[f32], quantized: &QuantizedWeights) -> f32 {
        let restored = Self::dequantize(quantized);
        if original.len() != restored.len() || original.is_empty() {
            return f32::MAX;
        }
        let sse: f32 = original
            .iter()
            .zip(restored.iter())
            .map(|(&o, &d)| (o - d) * (o - d))
            .sum();
        sse / original.len() as f32
    }

    /// Convert f32 to IEEE 754 half-precision (u16).
    pub fn f16_quantize(weights: &[f32]) -> Vec<u16> {
        weights.iter().map(|&w| f32_to_f16(w)).collect()
    }

    /// Convert FP16 (u16) back to f32.
    pub fn f16_dequantize(data: &[u16]) -> Vec<f32> {
        data.iter().map(|&h| f16_to_f32(h)).collect()
    }
}

// ── FP16 bit manipulation ────────────────────────────────────────────────────

/// f32 -> IEEE 754 binary16 bits. Mantissa is truncated (no rounding);
/// f32 subnormals flush to signed zero.
fn f32_to_f16(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = (bits >> 31) & 1;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    // Inf or NaN: keep the class; force a quiet-NaN payload bit for NaN.
    if exponent == 0xFF {
        let half_mantissa = if mantissa != 0 { 0x0200 } else { 0 };
        return ((sign << 15) | 0x7C00 | half_mantissa) as u16;
    }
    // f32 zero/subnormal is below half's range -> signed zero.
    if exponent == 0 {
        return (sign << 15) as u16;
    }

    let half_exp = exponent - 127 + 15; // re-bias to half's exponent
    if half_exp >= 31 {
        return ((sign << 15) | 0x7C00) as u16; // overflow -> Inf
    }
    if half_exp <= 0 {
        if half_exp < -10 {
            return (sign << 15) as u16; // underflow -> signed zero
        }
        // Half subnormal: shift the implicit-1 mantissa into position.
        let full = mantissa | 0x0080_0000;
        return ((sign << 15) | (full >> (13 + 1 - half_exp))) as u16;
    }
    ((sign << 15) | ((half_exp as u32) << 10) | (mantissa >> 13)) as u16
}

/// IEEE 754 binary16 bits -> f32 (exact: every half value is representable).
fn f16_to_f32(h: u16) -> f32 {
    let sign = ((h >> 15) & 1) as u32;
    let exponent = ((h >> 10) & 0x1F) as u32;
    let mantissa = (h & 0x03FF) as u32;

    // Inf / NaN.
    if exponent == 0x1F {
        let bits = if mantissa != 0 {
            (sign << 31) | 0x7F80_0000 | (mantissa << 13)
        } else {
            (sign << 31) | 0x7F80_0000
        };
        return f32::from_bits(bits);
    }
    // Zero or half subnormal: renormalize into an f32 normal.
    if exponent == 0 {
        if mantissa == 0 {
            return f32::from_bits(sign << 31);
        }
        let mut m = mantissa;
        let mut e: i32 = -14;
        while m & 0x0400 == 0 {
            m <<= 1;
            e -= 1;
        }
        m &= 0x03FF;
        return f32::from_bits((sign << 31) | (((e + 127) as u32) << 23) | (m << 13));
    }
    f32::from_bits((sign << 31) | ((exponent as i32 - 15 + 127) as u32) << 23 | (mantissa << 13))
}
|
||||
|
||||
// ── Sparse Model ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for the sparse inference engine.
#[derive(Debug, Clone)]
pub struct SparseConfig {
    // Minimum activation frequency for a neuron to be kept "hot".
    pub hot_threshold: f32,
    // Weight quantization scheme used by `SparseModel::apply_quantization`.
    pub quant_mode: QuantMode,
    // Maximum number of inputs consumed by `SparseModel::profile`.
    pub profile_frames: usize,
}

impl Default for SparseConfig {
    /// Defaults: hot = fires in >= 50% of samples, INT8 symmetric, 100 frames.
    fn default() -> Self { Self { hot_threshold: 0.5, quant_mode: QuantMode::Int8Symmetric, profile_frames: 100 } }
}
|
||||
|
||||
/// One dense layer of the model plus its optional sparse/quantized views.
#[allow(dead_code)]
struct ModelLayer {
    // Layer name; stored for identification, not read elsewhere in this file.
    name: String,
    // Dense weight matrix, one row per output neuron.
    weights: Vec<Vec<f32>>,
    // Per-output bias terms.
    bias: Vec<f32>,
    // Sparse view built by `SparseModel::apply_sparsity`.
    sparse: Option<SparseLinear>,
    // Per-neuron activation statistics filled by `SparseModel::profile`.
    profiler: NeuronProfiler,
    // When true, `forward` routes through `sparse`.
    is_sparse: bool,
    /// Quantized weights per row (populated by apply_quantization).
    quantized: Option<Vec<QuantizedWeights>>,
    /// Whether to use quantized weights for forward pass.
    use_quantized: bool,
}
|
||||
|
||||
impl ModelLayer {
|
||||
fn new(name: &str, weights: Vec<Vec<f32>>, bias: Vec<f32>) -> Self {
|
||||
let n = weights.len();
|
||||
Self {
|
||||
name: name.into(), weights, bias, sparse: None,
|
||||
profiler: NeuronProfiler::new(n), is_sparse: false,
|
||||
quantized: None, use_quantized: false,
|
||||
}
|
||||
}
|
||||
fn forward_dense(&self, input: &[f32]) -> Vec<f32> {
|
||||
if self.use_quantized {
|
||||
if let Some(ref qrows) = self.quantized {
|
||||
return self.forward_quantized(input, qrows);
|
||||
}
|
||||
}
|
||||
self.weights.iter().enumerate().map(|(r, row)| dot_bias(row, input, self.bias[r])).collect()
|
||||
}
|
||||
/// Forward using dequantized weights: val = q_val * scale (symmetric).
|
||||
fn forward_quantized(&self, input: &[f32], qrows: &[QuantizedWeights]) -> Vec<f32> {
|
||||
let n_out = qrows.len().min(self.bias.len());
|
||||
let mut out = vec![0.0f32; n_out];
|
||||
for r in 0..n_out {
|
||||
let qw = &qrows[r];
|
||||
let len = qw.data.len().min(input.len());
|
||||
let mut s = self.bias[r];
|
||||
for i in 0..len {
|
||||
let w = (qw.data[i] as f32 - qw.zero_point as f32) * qw.scale;
|
||||
s += w * input[i];
|
||||
}
|
||||
out[r] = s;
|
||||
}
|
||||
out
|
||||
}
|
||||
fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
if self.is_sparse { if let Some(ref s) = self.sparse { return s.forward(input); } }
|
||||
self.forward_dense(input)
|
||||
}
|
||||
}
|
||||
|
||||
/// Aggregate parameter/memory/FLOP statistics reported by `SparseModel::stats`.
#[derive(Debug, Clone)]
pub struct ModelStats {
    // All parameters across layers (weights + biases).
    pub total_params: usize,
    // Parameters belonging to hot rows (all params when a layer is dense).
    pub hot_params: usize,
    // Parameters belonging to cold (skipped) rows.
    pub cold_params: usize,
    // cold_params / total_params (0.0 when the model is empty).
    pub sparsity: f32,
    // Quantization mode the estimate was computed for.
    pub quant_mode: QuantMode,
    // hot_params * bytes-per-param for the quant mode.
    pub est_memory_bytes: usize,
    // Multiply-accumulate count per forward pass over hot rows.
    pub est_flops: usize,
}
|
||||
|
||||
/// Full sparse inference engine: profiling + sparsity + quantization.
pub struct SparseModel {
    // Layers in forward order.
    layers: Vec<ModelLayer>,
    // Thresholds and modes driving apply_sparsity / apply_quantization.
    config: SparseConfig,
    // Set by `profile`; `apply_sparsity` is a no-op until then.
    profiled: bool,
}
|
||||
|
||||
impl SparseModel {
    /// Create an empty model with the given sparsity/quantization config.
    pub fn new(config: SparseConfig) -> Self { Self { layers: vec![], config, profiled: false } }

    /// Append a dense layer (one weight row per output neuron).
    pub fn add_layer(&mut self, name: &str, weights: Vec<Vec<f32>>, bias: Vec<f32>) {
        self.layers.push(ModelLayer::new(name, weights, bias));
    }

    /// Profile activation frequencies over sample inputs.
    ///
    /// Runs each input through the dense layers (with ReLU between them),
    /// recording pre-ReLU outputs per neuron. At most `profile_frames`
    /// inputs are consumed.
    pub fn profile(&mut self, inputs: &[Vec<f32>]) {
        let n = inputs.len().min(self.config.profile_frames);
        for sample in inputs.iter().take(n) {
            let mut act = sample.clone();
            for layer in &mut self.layers {
                let out = layer.forward_dense(&act);
                for (i, &v) in out.iter().enumerate() { layer.profiler.record_activation(i, v); }
                layer.profiler.end_sample();
                // ReLU before feeding the next layer.
                act = out.iter().map(|&v| v.max(0.0)).collect();
            }
        }
        self.profiled = true;
    }

    /// Convert layers to sparse using profiled hot/cold partition.
    /// No-op until `profile` has run.
    pub fn apply_sparsity(&mut self) {
        if !self.profiled { return; }
        let th = self.config.hot_threshold;
        for layer in &mut self.layers {
            let (hot, _) = layer.profiler.partition_hot_cold(th);
            layer.sparse = Some(SparseLinear::new(layer.weights.clone(), layer.bias.clone(), hot));
            layer.is_sparse = true;
        }
    }

    /// Quantize weights using INT8 codebook per the config. After this call,
    /// forward() uses dequantized weights (val = (q - zero_point) * scale).
    // Non-INT8 modes fall back to symmetric INT8 here.
    pub fn apply_quantization(&mut self) {
        for layer in &mut self.layers {
            let qrows: Vec<QuantizedWeights> = layer.weights.iter().map(|row| {
                match self.config.quant_mode {
                    QuantMode::Int8Symmetric => Quantizer::quantize_symmetric(row),
                    QuantMode::Int8Asymmetric => Quantizer::quantize_asymmetric(row),
                    _ => Quantizer::quantize_symmetric(row),
                }
            }).collect();
            layer.quantized = Some(qrows);
            layer.use_quantized = true;
        }
    }

    /// Forward pass through all layers with ReLU activation.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut act = input.to_vec();
        for layer in &self.layers {
            act = layer.forward(&act).iter().map(|&v| v.max(0.0)).collect();
        }
        act
    }

    /// Number of layers added so far.
    pub fn n_layers(&self) -> usize { self.layers.len() }

    /// Parameter/memory/FLOP summary; sparse layers count only hot rows
    /// toward memory and FLOPs.
    pub fn stats(&self) -> ModelStats {
        let (mut total, mut hot, mut cold, mut flops) = (0, 0, 0, 0);
        for layer in &self.layers {
            let (no, ni) = (layer.weights.len(), layer.weights.first().map_or(0, |r| r.len()));
            // Per-layer params: weights (no*ni) plus biases (no).
            let lp = no * ni + no;
            total += lp;
            if let Some(ref s) = layer.sparse {
                let hc = s.hot_neurons.len();
                hot += hc * ni + hc;
                cold += (no - hc) * ni + (no - hc);
                flops += hc * ni;
            } else { hot += lp; flops += no * ni; }
        }
        // Bytes per parameter; Int4 is rounded up to a whole byte here.
        let bpp = match self.config.quant_mode {
            QuantMode::F32 => 4, QuantMode::F16 => 2,
            QuantMode::Int8Symmetric | QuantMode::Int8Asymmetric => 1,
            QuantMode::Int4 => 1,
        };
        ModelStats {
            total_params: total, hot_params: hot, cold_params: cold,
            sparsity: if total > 0 { cold as f32 / total as f32 } else { 0.0 },
            quant_mode: self.config.quant_mode, est_memory_bytes: hot * bpp, est_flops: flops,
        }
    }
}
|
||||
|
||||
// ── Benchmark Runner ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Latency/throughput summary from `BenchmarkRunner::benchmark_inference`.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    // Mean per-inference latency in microseconds.
    pub mean_latency_us: f64,
    // Median latency (nearest-rank percentile).
    pub p50_us: f64,
    // 99th-percentile latency.
    pub p99_us: f64,
    // Inferences per second over the whole run.
    pub throughput_fps: f64,
    // Estimated hot-weight memory from `SparseModel::stats`.
    pub memory_bytes: usize,
}

/// Dense-vs-sparse comparison from `BenchmarkRunner::compare_dense_vs_sparse`.
#[derive(Debug, Clone)]
pub struct ComparisonResult {
    // Mean dense-path latency in microseconds.
    pub dense_latency_us: f64,
    // Mean sparse-path latency in microseconds.
    pub sparse_latency_us: f64,
    // dense / sparse latency ratio (1.0 when sparse time is zero).
    pub speedup: f64,
    // MSE between dense and sparse final outputs (0.0 on shape mismatch).
    pub accuracy_loss: f32,
}
|
||||
|
||||
/// Wall-clock micro-benchmarks for SparseModel inference.
pub struct BenchmarkRunner;

impl BenchmarkRunner {
    /// Time `n` forward passes of `model` on `input` and summarize latency.
    /// Timings use `Instant` wall-clock, so results vary run to run.
    pub fn benchmark_inference(model: &SparseModel, input: &[f32], n: usize) -> BenchmarkResult {
        let mut lat = Vec::with_capacity(n);
        for _ in 0..n {
            let t = Instant::now();
            let _ = model.forward(input);
            lat.push(t.elapsed().as_micros() as f64);
        }
        // Sort ascending for percentile extraction (NaN-safe comparator).
        lat.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let sum: f64 = lat.iter().sum();
        let mean = sum / lat.len().max(1) as f64;
        let total_s = sum / 1e6;
        BenchmarkResult {
            mean_latency_us: mean,
            p50_us: pctl(&lat, 50), p99_us: pctl(&lat, 99),
            // Sub-microsecond total time reports as infinite throughput.
            throughput_fps: if total_s > 0.0 { n as f64 / total_s } else { f64::INFINITY },
            memory_bytes: model.stats().est_memory_bytes,
        }
    }

    /// Time a hand-rolled dense network (`dw` weights / `db` biases, ReLU
    /// between layers) against `sparse` on the same input, `n` reps each,
    /// and report mean latencies, speedup and output MSE.
    pub fn compare_dense_vs_sparse(
        dw: &[Vec<Vec<f32>>], db: &[Vec<f32>], sparse: &SparseModel, input: &[f32], n: usize,
    ) -> ComparisonResult {
        // Dense timing
        let mut dl = Vec::with_capacity(n);
        let mut d_out = Vec::new();
        for _ in 0..n {
            let t = Instant::now();
            let mut a = input.to_vec();
            for (w, b) in dw.iter().zip(db.iter()) {
                // Linear layer followed by ReLU.
                a = w.iter().enumerate().map(|(r, row)| dot_bias(row, &a, b[r])).collect::<Vec<_>>()
                    .iter().map(|&v| v.max(0.0)).collect();
            }
            d_out = a;
            dl.push(t.elapsed().as_micros() as f64);
        }
        // Sparse timing
        let mut sl = Vec::with_capacity(n);
        let mut s_out = Vec::new();
        for _ in 0..n {
            let t = Instant::now();
            s_out = sparse.forward(input);
            sl.push(t.elapsed().as_micros() as f64);
        }
        let dm: f64 = dl.iter().sum::<f64>() / dl.len().max(1) as f64;
        let sm: f64 = sl.iter().sum::<f64>() / sl.len().max(1) as f64;
        // MSE between the two final outputs; 0.0 when shapes disagree.
        let loss = if !d_out.is_empty() && d_out.len() == s_out.len() {
            d_out.iter().zip(s_out.iter()).map(|(d, s)| (d - s).powi(2)).sum::<f32>() / d_out.len() as f32
        } else { 0.0 };
        ComparisonResult {
            dense_latency_us: dm, sparse_latency_us: sm,
            speedup: if sm > 0.0 { dm / sm } else { 1.0 }, accuracy_loss: loss,
        }
    }
}
|
||||
|
||||
/// Nearest-rank percentile of an ascending-sorted slice (`p` in 0..=100).
/// Returns 0.0 for an empty slice; the index is clamped to the last element.
fn pctl(sorted: &[f64], p: usize) -> f64 {
    if sorted.is_empty() {
        return 0.0;
    }
    let last = sorted.len() - 1;
    let idx = (p as f64 / 100.0 * last as f64).round() as usize;
    sorted[idx.min(last)]
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Profiler starts with no samples and all-zero statistics.
    #[test]
    fn neuron_profiler_initially_empty() {
        let p = NeuronProfiler::new(10);
        assert_eq!(p.total_samples(), 0);
        assert_eq!(p.activation_frequency(0), 0.0);
        assert_eq!(p.sparsity_ratio(), 0.0);
    }

    // Frequencies are per-neuron fractions of samples with non-zero activation.
    #[test]
    fn neuron_profiler_records_activations() {
        let mut p = NeuronProfiler::new(4);
        p.record_activation(0, 1.0); p.record_activation(1, 0.5);
        p.record_activation(2, 0.1); p.record_activation(3, 0.0);
        p.end_sample();
        p.record_activation(0, 2.0); p.record_activation(1, 0.0);
        p.record_activation(2, 0.0); p.record_activation(3, 0.0);
        p.end_sample();
        assert_eq!(p.total_samples(), 2);
        assert_eq!(p.activation_frequency(0), 1.0); // active in 2/2 samples
        assert_eq!(p.activation_frequency(1), 0.5); // active in 1/2 samples
        assert_eq!(p.activation_frequency(3), 0.0); // never active
    }

    // Neurons at/above the threshold go hot; the rest go cold.
    #[test]
    fn neuron_profiler_hot_cold_partition() {
        let mut p = NeuronProfiler::new(5);
        for _ in 0..20 {
            p.record_activation(0, 1.0); p.record_activation(1, 1.0);
            p.record_activation(2, 0.0); p.record_activation(3, 0.0);
            p.record_activation(4, 0.0); p.end_sample();
        }
        let (hot, cold) = p.partition_hot_cold(0.5);
        assert!(hot.contains(&0) && hot.contains(&1));
        assert!(cold.contains(&2) && cold.contains(&3) && cold.contains(&4));
    }

    // 8 of 10 neurons never fire -> sparsity ratio 0.8.
    #[test]
    fn neuron_profiler_sparsity_ratio() {
        let mut p = NeuronProfiler::new(10);
        for _ in 0..20 {
            p.record_activation(0, 1.0); p.record_activation(1, 1.0);
            for j in 2..10 { p.record_activation(j, 0.0); }
            p.end_sample();
        }
        assert!((p.sparsity_ratio() - 0.8).abs() < f32::EPSILON);
    }

    // With every row hot, sparse forward must equal the dense path.
    #[test]
    fn sparse_linear_matches_dense() {
        let w = vec![vec![1.0,2.0,3.0], vec![4.0,5.0,6.0], vec![7.0,8.0,9.0]];
        let b = vec![0.1, 0.2, 0.3];
        let layer = SparseLinear::new(w, b, vec![0,1,2]);
        let inp = vec![1.0, 0.5, -1.0];
        let (so, do_) = (layer.forward(&inp), layer.forward_full(&inp));
        for (s, d) in so.iter().zip(do_.iter()) { assert!((s - d).abs() < 1e-6); }
    }

    // Rows not in the hot set must produce exactly 0.0 output.
    #[test]
    fn sparse_linear_skips_cold_neurons() {
        let w = vec![vec![1.0,2.0], vec![3.0,4.0], vec![5.0,6.0]];
        let layer = SparseLinear::new(w, vec![0.0;3], vec![1]);
        let out = layer.forward(&[1.0, 1.0]);
        assert_eq!(out[0], 0.0);
        assert_eq!(out[2], 0.0);
        assert!((out[1] - 7.0).abs() < 1e-6); // 3*1 + 4*1
    }

    // 2 of 4 rows hot in a 4x4 layer: 2 skipped rows * 4 weights = 8 flops saved.
    #[test]
    fn sparse_linear_flops_saved() {
        let w: Vec<Vec<f32>> = (0..4).map(|_| vec![1.0; 4]).collect();
        let layer = SparseLinear::new(w, vec![0.0;4], vec![0,2]);
        assert_eq!(layer.n_flops_saved(), 8);
        assert!((layer.density() - 0.5).abs() < f32::EPSILON);
    }

    // Symmetric INT8: scale = max_abs/127, zero_point 0, extremes map to +/-127.
    #[test]
    fn quantize_symmetric_range() {
        let qw = Quantizer::quantize_symmetric(&[-1.0, 0.0, 0.5, 1.0]);
        assert!((qw.scale - 1.0/127.0).abs() < 1e-6);
        assert_eq!(qw.zero_point, 0);
        assert_eq!(*qw.data.last().unwrap(), 127);
        assert_eq!(qw.data[0], -127);
    }

    // Symmetric quantization must map 0.0 to the integer 0 exactly.
    #[test]
    fn quantize_symmetric_zero_is_zero() {
        let qw = Quantizer::quantize_symmetric(&[-5.0, 0.0, 3.0, 5.0]);
        assert_eq!(qw.data[1], 0);
    }

    // Asymmetric UINT8 over [0, 1]: scale = range/255, zero_point at the minimum.
    #[test]
    fn quantize_asymmetric_range() {
        let qw = Quantizer::quantize_asymmetric(&[0.0, 0.5, 1.0]);
        assert!((qw.scale - 1.0/255.0).abs() < 1e-4);
        assert_eq!(qw.zero_point as u8, 0);
    }

    #[test]
    fn dequantize_round_trip_small_error() {
        let w: Vec<f32> = (-50..50).map(|i| i as f32 * 0.02).collect();
        let qw = Quantizer::quantize_symmetric(&w);
        assert!(Quantizer::quantization_error(&w, &qw) < 0.01);
    }

    // Both INT8 schemes should bound error on a sinusoidal weight pattern.
    #[test]
    fn int8_quantization_error_bounded() {
        let w: Vec<f32> = (0..256).map(|i| (i as f32 * 1.7).sin() * 2.0).collect();
        assert!(Quantizer::quantization_error(&w, &Quantizer::quantize_symmetric(&w)) < 0.01);
        assert!(Quantizer::quantization_error(&w, &Quantizer::quantize_asymmetric(&w)) < 0.01);
    }

    // f16 keeps ~3 decimal digits; 65504.0 is the largest finite f16 value.
    #[test]
    fn f16_round_trip_precision() {
        for &v in &[1.0f32, 0.5, -0.5, 3.14, 100.0, 0.001, -42.0, 65504.0] {
            let enc = Quantizer::f16_quantize(&[v]);
            let dec = Quantizer::f16_dequantize(&enc)[0];
            let re = if v.abs() > 1e-6 { ((v - dec) / v).abs() } else { (v - dec).abs() };
            assert!(re < 0.001, "f16 error for {v}: decoded={dec}, rel={re}");
        }
    }

    // Zero, +/-infinity and NaN must survive the f16 round trip.
    #[test]
    fn f16_special_values() {
        assert_eq!(Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[0.0]))[0], 0.0);
        let inf = Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::INFINITY]))[0];
        assert!(inf.is_infinite() && inf > 0.0);
        let ninf = Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::NEG_INFINITY]))[0];
        assert!(ninf.is_infinite() && ninf < 0.0);
        assert!(Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::NAN]))[0].is_nan());
    }

    // Two stacked layers with ReLU between them.
    #[test]
    fn sparse_model_add_layers() {
        let mut m = SparseModel::new(SparseConfig::default());
        m.add_layer("l1", vec![vec![1.0,2.0],vec![3.0,4.0]], vec![0.0,0.0]);
        m.add_layer("l2", vec![vec![0.5,-0.5],vec![1.0,1.0]], vec![0.1,0.2]);
        assert_eq!(m.n_layers(), 2);
        let out = m.forward(&[1.0, 1.0]);
        assert!(out[0] < 0.001); // ReLU zeros negative
        assert!((out[1] - 10.2).abs() < 0.01); // 3 + 7 + 0.2
    }

    // Profiling constant inputs leaves always-negative rows inactive -> cold.
    #[test]
    fn sparse_model_profile_and_apply() {
        let mut m = SparseModel::new(SparseConfig { hot_threshold: 0.3, ..Default::default() });
        m.add_layer("h", vec![
            vec![1.0;4], vec![0.5;4], vec![-2.0;4], vec![-1.0;4],
        ], vec![0.0;4]);
        let inp: Vec<Vec<f32>> = (0..50).map(|i| vec![1.0 + i as f32 * 0.01; 4]).collect();
        m.profile(&inp);
        m.apply_sparsity();
        let s = m.stats();
        assert!(s.cold_params > 0);
        assert!(s.sparsity > 0.0);
    }

    #[test]
    fn sparse_model_stats_report() {
        let mut m = SparseModel::new(SparseConfig::default());
        m.add_layer("fc1", vec![vec![1.0;8];16], vec![0.0;16]);
        let s = m.stats();
        assert_eq!(s.total_params, 16*8+16); // weights + biases
        assert_eq!(s.quant_mode, QuantMode::Int8Symmetric);
        assert!(s.est_flops > 0 && s.est_memory_bytes > 0);
    }

    #[test]
    fn benchmark_produces_positive_latency() {
        let mut m = SparseModel::new(SparseConfig::default());
        m.add_layer("fc1", vec![vec![1.0;4];4], vec![0.0;4]);
        let r = BenchmarkRunner::benchmark_inference(&m, &[1.0;4], 10);
        assert!(r.mean_latency_us >= 0.0 && r.throughput_fps > 0.0);
    }

    // Sanity checks only — wall-clock speedup is not asserted (too flaky).
    #[test]
    fn compare_dense_sparse_speedup() {
        let w = vec![vec![1.0f32;8];16];
        let b = vec![0.0f32;16];
        let mut pm = SparseModel::new(SparseConfig { hot_threshold: 0.5, quant_mode: QuantMode::F32, profile_frames: 20 });
        // Make the second half of the rows always-negative so they go cold.
        let mut pw: Vec<Vec<f32>> = w.clone();
        for row in pw.iter_mut().skip(8) { for v in row.iter_mut() { *v = -1.0; } }
        pm.add_layer("fc1", pw, b.clone());
        let inp: Vec<Vec<f32>> = (0..20).map(|_| vec![1.0;8]).collect();
        pm.profile(&inp); pm.apply_sparsity();
        let r = BenchmarkRunner::compare_dense_vs_sparse(&[w], &[b], &pm, &[1.0;8], 50);
        assert!(r.dense_latency_us >= 0.0 && r.sparse_latency_us >= 0.0);
        assert!(r.speedup > 0.0);
        assert!(r.accuracy_loss.is_finite());
    }

    // ── Quantization integration tests ────────────────────────────

    // forward() after apply_quantization() should stay within INT8 precision
    // of the dense result.
    #[test]
    fn apply_quantization_enables_quantized_forward() {
        let w = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![-1.0, -2.0, -3.0, -4.0],
            vec![0.5, 1.5, 2.5, 3.5],
        ];
        let b = vec![0.1, 0.2, 0.3];
        let mut m = SparseModel::new(SparseConfig {
            quant_mode: QuantMode::Int8Symmetric,
            ..Default::default()
        });
        m.add_layer("fc1", w.clone(), b.clone());

        // Before quantization: dense forward
        let input = vec![1.0, 0.5, -1.0, 0.0];
        let dense_out = m.forward(&input);

        // Apply quantization
        m.apply_quantization();

        // After quantization: should use dequantized weights
        let quant_out = m.forward(&input);

        // Output should be close to dense (within INT8 precision)
        for (d, q) in dense_out.iter().zip(quant_out.iter()) {
            let rel_err = if d.abs() > 0.01 { (d - q).abs() / d.abs() } else { (d - q).abs() };
            assert!(rel_err < 0.05, "quantized error too large: dense={d}, quant={q}, err={rel_err}");
        }
    }

    // Same check on a 2-layer model, measured as MSE instead of per-element.
    #[test]
    fn quantized_forward_accuracy_within_5_percent() {
        // Multi-layer model
        let mut m = SparseModel::new(SparseConfig {
            quant_mode: QuantMode::Int8Symmetric,
            ..Default::default()
        });
        let w1: Vec<Vec<f32>> = (0..8).map(|r| {
            (0..8).map(|c| ((r * 8 + c) as f32 * 0.17).sin() * 2.0).collect()
        }).collect();
        let b1 = vec![0.0f32; 8];
        let w2: Vec<Vec<f32>> = (0..4).map(|r| {
            (0..8).map(|c| ((r * 8 + c) as f32 * 0.23).cos() * 1.5).collect()
        }).collect();
        let b2 = vec![0.0f32; 4];
        m.add_layer("fc1", w1, b1);
        m.add_layer("fc2", w2, b2);

        let input = vec![1.0, -0.5, 0.3, 0.7, -0.2, 0.9, -0.4, 0.6];
        let dense_out = m.forward(&input);

        m.apply_quantization();
        let quant_out = m.forward(&input);

        // MSE between dense and quantized should be small
        let mse: f32 = dense_out.iter().zip(quant_out.iter())
            .map(|(d, q)| (d - q).powi(2)).sum::<f32>() / dense_out.len() as f32;
        assert!(mse < 0.5, "quantization MSE too large: {mse}");
    }
}
|
||||
@@ -0,0 +1,881 @@
|
||||
//! Training loop with multi-term loss function for WiFi DensePose (ADR-023 Phase 4).
|
||||
//!
|
||||
//! 6-term composite loss, SGD with momentum, cosine annealing LR scheduler,
|
||||
//! PCK/OKS validation metrics, numerical gradient estimation, and checkpointing.
|
||||
//! All arithmetic uses f32. No external ML framework dependencies.
|
||||
|
||||
use std::path::Path;
|
||||
use crate::graph_transformer::{CsiToPoseTransformer, TransformerConfig};
|
||||
use crate::dataset;
|
||||
|
||||
/// Standard COCO keypoint sigmas for OKS (17 keypoints).
///
/// One value per keypoint; a larger sigma widens the OKS Gaussian and so
/// gives more localization tolerance for that keypoint (used by `oks_single`).
/// Ordering presumably follows the COCO keypoint layout — confirm against
/// the dataset's keypoint indexing.
pub const COCO_KEYPOINT_SIGMAS: [f32; 17] = [
    0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062,
    0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089,
];
|
||||
|
||||
/// Symmetric keypoint pairs (left, right) indices into 17-keypoint COCO layout.
/// Used by `symmetry_loss`; indices 5-14 are the paired limb keypoints
/// (presumably shoulders, elbows, wrists, hips, knees — confirm against the
/// dataset's keypoint indexing).
const SYMMETRY_PAIRS: [(usize, usize); 5] =
    [(5, 6), (7, 8), (9, 10), (11, 12), (13, 14)];
|
||||
|
||||
/// Individual loss terms from the 6-component composite loss.
///
/// Each field holds an UNWEIGHTED term value; `composite_loss` applies
/// the per-term `LossWeights`.
#[derive(Debug, Clone, Default)]
pub struct LossComponents {
    pub keypoint: f32,  // keypoint MSE (keypoint_mse)
    pub body_part: f32, // body-part cross-entropy (body_part_cross_entropy)
    pub uv: f32,        // UV L1 regression (uv_regression_loss)
    pub temporal: f32,  // frame-to-frame consistency (temporal_consistency_loss)
    pub edge: f32,      // bone-length deviation (graph_edge_loss)
    pub symmetry: f32,  // left/right asymmetry (symmetry_loss)
}
|
||||
|
||||
/// Per-term weights for the composite loss function.
///
/// Multiplied element-wise with `LossComponents` by `composite_loss`.
#[derive(Debug, Clone)]
pub struct LossWeights {
    pub keypoint: f32,  // weight for the keypoint MSE term
    pub body_part: f32, // weight for the classification term
    pub uv: f32,        // weight for the UV regression term
    pub temporal: f32,  // weight for the temporal consistency term
    pub edge: f32,      // weight for the bone-length term
    pub symmetry: f32,  // weight for the symmetry term
}
|
||||
|
||||
impl Default for LossWeights {
|
||||
fn default() -> Self {
|
||||
Self { keypoint: 1.0, body_part: 0.5, uv: 0.5, temporal: 0.1, edge: 0.2, symmetry: 0.1 }
|
||||
}
|
||||
}
|
||||
|
||||
/// Mean squared error on keypoints (x, y, confidence).
///
/// Pairs up to the shorter slice and averages the squared component-wise
/// differences; returns 0.0 when either slice is empty.
pub fn keypoint_mse(pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)]) -> f32 {
    if pred.is_empty() || target.is_empty() {
        return 0.0;
    }
    let n = pred.len().min(target.len());
    let mut total = 0.0f32;
    // zip stops at the shorter slice, so exactly n pairs are visited.
    for (p, t) in pred.iter().zip(target.iter()) {
        let dx = p.0 - t.0;
        let dy = p.1 - t.1;
        let dc = p.2 - t.2;
        total += dx * dx + dy * dy + dc * dc;
    }
    total / n as f32
}
|
||||
|
||||
/// Cross-entropy loss for body part classification.
/// `pred` = raw logits (length `n_samples * n_parts`), `target` = class indices.
///
/// Uses a max-shifted log-sum-exp for numerical stability. Out-of-range
/// class labels contribute zero but still count in the denominator.
pub fn body_part_cross_entropy(pred: &[f32], target: &[u8], n_parts: usize) -> f32 {
    if target.is_empty() || n_parts == 0 || pred.len() < n_parts {
        return 0.0;
    }
    let n_samples = target.len().min(pred.len() / n_parts);
    if n_samples == 0 {
        return 0.0;
    }
    let mut total = 0.0f32;
    for (i, &label) in target.iter().take(n_samples).enumerate() {
        let class = label as usize;
        if class >= n_parts {
            continue; // invalid label: skipped, but still averaged over
        }
        let logits = &pred[i * n_parts..(i + 1) * n_parts];
        // Numerically stable log-sum-exp: shift by the max logit.
        let max_l = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let lse = logits.iter().map(|&l| (l - max_l).exp()).sum::<f32>().ln() + max_l;
        total += lse - logits[class];
    }
    total / n_samples as f32
}
|
||||
|
||||
/// L1 loss on UV coordinates.
///
/// Averages |pu - tu| + |pv - tv| over the shortest of the four slices;
/// returns 0.0 when any slice is empty.
pub fn uv_regression_loss(pu: &[f32], pv: &[f32], tu: &[f32], tv: &[f32]) -> f32 {
    let n = pu.len().min(pv.len()).min(tu.len()).min(tv.len());
    if n == 0 {
        return 0.0;
    }
    let mut acc = 0.0f32;
    for i in 0..n {
        acc += (pu[i] - tu[i]).abs();
        acc += (pv[i] - tv[i]).abs();
    }
    acc / n as f32
}
|
||||
|
||||
/// Temporal consistency loss: penalizes large frame-to-frame keypoint jumps.
///
/// Mean squared (x, y) displacement between consecutive frames; confidence
/// is ignored. Returns 0.0 when either frame has no keypoints.
pub fn temporal_consistency_loss(prev: &[(f32, f32, f32)], curr: &[(f32, f32, f32)]) -> f32 {
    let n = prev.len().min(curr.len());
    if n == 0 {
        return 0.0;
    }
    let total: f32 = prev
        .iter()
        .zip(curr)
        .map(|(p, c)| {
            let dx = c.0 - p.0;
            let dy = c.1 - p.1;
            dx * dx + dy * dy
        })
        .sum();
    total / n as f32
}
|
||||
|
||||
/// Graph edge loss: penalizes deviation of bone lengths from expected values.
///
/// For each (a, b) edge with both endpoints in range, accumulates the squared
/// difference between the 2D bone length and `expected[i]`, then averages.
/// Returns 0.0 when `edges` is empty, lengths mismatch, or no edge is valid.
pub fn graph_edge_loss(
    kp: &[(f32, f32, f32)], edges: &[(usize, usize)], expected: &[f32],
) -> f32 {
    if edges.is_empty() || edges.len() != expected.len() {
        return 0.0;
    }
    let mut sum = 0.0f32;
    let mut cnt = 0usize;
    for (&(a, b), &exp_len) in edges.iter().zip(expected) {
        // Skip edges referencing missing keypoints.
        let (pa, pb) = match (kp.get(a), kp.get(b)) {
            (Some(x), Some(y)) => (x, y),
            _ => continue,
        };
        let dx = pa.0 - pb.0;
        let dy = pa.1 - pb.1;
        let d = (dx * dx + dy * dy).sqrt();
        sum += (d - exp_len).powi(2);
        cnt += 1;
    }
    if cnt == 0 { 0.0 } else { sum / cnt as f32 }
}
|
||||
|
||||
/// Symmetry loss: penalizes asymmetry between left-right limb pairs.
|
||||
pub fn symmetry_loss(kp: &[(f32, f32, f32)]) -> f32 {
|
||||
if kp.len() < 15 { return 0.0; }
|
||||
let (mut sum, mut cnt) = (0.0f32, 0usize);
|
||||
for &(l, r) in &SYMMETRY_PAIRS {
|
||||
if l >= kp.len() || r >= kp.len() { continue; }
|
||||
let ld = ((kp[l].0 - kp[0].0).powi(2) + (kp[l].1 - kp[0].1).powi(2)).sqrt();
|
||||
let rd = ((kp[r].0 - kp[0].0).powi(2) + (kp[r].1 - kp[0].1).powi(2)).sqrt();
|
||||
sum += (ld - rd).powi(2);
|
||||
cnt += 1;
|
||||
}
|
||||
if cnt == 0 { 0.0 } else { sum / cnt as f32 }
|
||||
}
|
||||
|
||||
/// Weighted composite loss from individual components.
|
||||
pub fn composite_loss(c: &LossComponents, w: &LossWeights) -> f32 {
|
||||
w.keypoint * c.keypoint + w.body_part * c.body_part + w.uv * c.uv
|
||||
+ w.temporal * c.temporal + w.edge * c.edge + w.symmetry * c.symmetry
|
||||
}
|
||||
|
||||
// ── Optimizer ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// SGD optimizer with momentum and weight decay.
pub struct SgdOptimizer {
    lr: f32,            // current learning rate (mutable via set_lr)
    momentum: f32,      // velocity decay factor (mu)
    weight_decay: f32,  // L2 penalty coefficient
    velocity: Vec<f32>, // per-parameter momentum buffer, lazily sized
}

impl SgdOptimizer {
    /// New optimizer with an empty velocity buffer (allocated on first step).
    pub fn new(lr: f32, momentum: f32, weight_decay: f32) -> Self {
        Self { lr, momentum, weight_decay, velocity: Vec::new() }
    }

    /// v = mu*v + grad + wd*param; param -= lr*v
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        // (Re)allocate the momentum buffer whenever the parameter count changes.
        if self.velocity.len() != params.len() {
            self.velocity = vec![0.0; params.len()];
        }
        let (mu, wd, lr) = (self.momentum, self.weight_decay, self.lr);
        let n = params.len().min(gradients.len());
        let iter = params[..n]
            .iter_mut()
            .zip(&gradients[..n])
            .zip(&mut self.velocity[..n]);
        for ((p, &g), v) in iter {
            let adjusted = g + wd * *p; // gradient plus L2 decay term
            *v = mu * *v + adjusted;
            *p -= lr * *v;
        }
    }

    /// Update the learning rate (called by the scheduler each epoch).
    pub fn set_lr(&mut self, lr: f32) { self.lr = lr; }
    /// Snapshot of the velocity buffer (for checkpointing).
    pub fn state(&self) -> Vec<f32> { self.velocity.clone() }
    /// Restore a previously saved velocity buffer.
    pub fn load_state(&mut self, state: Vec<f32>) { self.velocity = state; }
}
|
||||
|
||||
// ── Learning rate schedulers ───────────────────────────────────────────────
|
||||
|
||||
/// Cosine annealing: decays LR from initial to min over total_steps.
pub struct CosineScheduler {
    initial_lr: f32,
    min_lr: f32,
    total_steps: usize,
}

impl CosineScheduler {
    pub fn new(initial_lr: f32, min_lr: f32, total_steps: usize) -> Self {
        Self { initial_lr, min_lr, total_steps }
    }

    /// LR for `step`; clamps to the schedule end past `total_steps`.
    /// A zero-length schedule always returns the initial LR.
    pub fn get_lr(&self, step: usize) -> f32 {
        if self.total_steps == 0 {
            return self.initial_lr;
        }
        let progress = step.min(self.total_steps) as f32 / self.total_steps as f32;
        // Half-cosine window: 1.0 at progress 0, 0.0 at progress 1.
        let window = (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0;
        self.min_lr + (self.initial_lr - self.min_lr) * window
    }
}
|
||||
|
||||
/// Warmup + cosine annealing: linear ramp 0->initial_lr then cosine decay.
pub struct WarmupCosineScheduler {
    warmup_steps: usize,
    initial_lr: f32,
    min_lr: f32,
    total_steps: usize,
}

impl WarmupCosineScheduler {
    pub fn new(warmup_steps: usize, initial_lr: f32, min_lr: f32, total_steps: usize) -> Self {
        Self { warmup_steps, initial_lr, min_lr, total_steps }
    }

    /// LR for `step`: linear warmup, then cosine decay over the rest.
    pub fn get_lr(&self, step: usize) -> f32 {
        // Warmup phase: linear ramp from 0 up to initial_lr.
        if step < self.warmup_steps {
            if self.warmup_steps == 0 {
                return self.initial_lr;
            }
            return self.initial_lr * (step as f32 / self.warmup_steps as f32);
        }
        // Cosine phase over the steps remaining after warmup.
        let cosine_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if cosine_steps == 0 {
            return self.min_lr;
        }
        let progress = (step - self.warmup_steps).min(cosine_steps) as f32 / cosine_steps as f32;
        let window = (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0;
        self.min_lr + (self.initial_lr - self.min_lr) * window
    }
}
|
||||
|
||||
// ── Validation metrics ─────────────────────────────────────────────────────
|
||||
|
||||
/// Percentage of Correct Keypoints at a distance threshold.
///
/// Only target keypoints with positive confidence are scored; a keypoint is
/// correct when its 2D distance to the target is within `thr`. Returns 0.0
/// when nothing is visible (or either slice is empty).
pub fn pck_at_threshold(pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)], thr: f32) -> f32 {
    let mut correct = 0usize;
    let mut total = 0usize;
    for (p, t) in pred.iter().zip(target) {
        if t.2 <= 0.0 {
            continue; // invisible keypoints are excluded from the score
        }
        total += 1;
        let dx = p.0 - t.0;
        let dy = p.1 - t.1;
        if (dx * dx + dy * dy).sqrt() <= thr {
            correct += 1;
        }
    }
    if total == 0 { 0.0 } else { correct as f32 / total as f32 }
}
|
||||
|
||||
/// Object Keypoint Similarity for a single instance.
///
/// Gaussian similarity exp(-d^2 / (2 * var)) with var = 2*sigma^2*area,
/// averaged over visible target keypoints (confidence > 0). Returns 0.0
/// when the inputs are empty, the area is non-positive, or nothing is visible.
pub fn oks_single(
    pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)], sigmas: &[f32], area: f32,
) -> f32 {
    let n = pred.len().min(target.len()).min(sigmas.len());
    if n == 0 || area <= 0.0 {
        return 0.0;
    }
    let mut sum = 0.0f32;
    let mut visible = 0usize;
    for i in 0..n {
        let t = target[i];
        if t.2 <= 0.0 {
            continue; // not annotated / not visible
        }
        visible += 1;
        let dx = pred[i].0 - t.0;
        let dy = pred[i].1 - t.1;
        let dsq = dx * dx + dy * dy;
        let var = 2.0 * sigmas[i] * sigmas[i] * area;
        if var > 0.0 {
            sum += (-dsq / (2.0 * var)).exp();
        }
    }
    if visible == 0 { 0.0 } else { sum / visible as f32 }
}
|
||||
|
||||
/// Mean OKS over multiple predictions (simplified mAP).
|
||||
pub fn oks_map(preds: &[Vec<(f32, f32, f32)>], targets: &[Vec<(f32, f32, f32)>]) -> f32 {
|
||||
let n = preds.len().min(targets.len());
|
||||
if n == 0 { return 0.0; }
|
||||
let s: f32 = preds.iter().zip(targets.iter()).take(n)
|
||||
.map(|(p, t)| oks_single(p, t, &COCO_KEYPOINT_SIGMAS, 1.0)).sum();
|
||||
s / n as f32
|
||||
}
|
||||
|
||||
// ── Gradient estimation ────────────────────────────────────────────────────
|
||||
|
||||
/// Central difference gradient: (f(x+eps) - f(x-eps)) / (2*eps).
///
/// Evaluates `f` twice per parameter using two reusable probe buffers,
/// restoring the perturbed coordinate after each evaluation.
pub fn estimate_gradient(f: impl Fn(&[f32]) -> f32, params: &[f32], eps: f32) -> Vec<f32> {
    let mut probe_hi = params.to_vec();
    let mut probe_lo = params.to_vec();
    let mut grad = Vec::with_capacity(params.len());
    for i in 0..params.len() {
        probe_hi[i] = params[i] + eps;
        probe_lo[i] = params[i] - eps;
        grad.push((f(&probe_hi) - f(&probe_lo)) / (2.0 * eps));
        // Restore both probes so only one coordinate differs at a time.
        probe_hi[i] = params[i];
        probe_lo[i] = params[i];
    }
    grad
}
|
||||
|
||||
/// Clip gradients by global L2 norm.
///
/// When the norm exceeds `max_norm`, rescales every element so the norm
/// equals `max_norm`; otherwise leaves the slice untouched.
pub fn clip_gradients(gradients: &mut [f32], max_norm: f32) {
    let squared_sum: f32 = gradients.iter().map(|&g| g * g).sum();
    let norm = squared_sum.sqrt();
    // norm > 0.0 guards the division; norm <= max_norm needs no change.
    if norm > max_norm && norm > 0.0 {
        let scale = max_norm / norm;
        for g in gradients.iter_mut() {
            *g *= scale;
        }
    }
}
|
||||
|
||||
// ── Training sample ────────────────────────────────────────────────────────
|
||||
|
||||
/// A single training sample (defined locally, not dependent on dataset.rs).
#[derive(Debug, Clone)]
pub struct TrainingSample {
    // CSI feature window; presumably outer Vec is frames/time and inner Vec
    // the per-frame feature vector — TODO confirm against dataset.rs.
    pub csi_features: Vec<Vec<f32>>,
    // Ground-truth keypoints as (x, y, confidence) triples.
    pub target_keypoints: Vec<(f32, f32, f32)>,
    // Per-part class ids for the body-part classification loss.
    pub target_body_parts: Vec<u8>,
    // Flattened (u, v) coordinate targets for the UV regression loss.
    pub target_uv: (Vec<f32>, Vec<f32>),
}
|
||||
|
||||
/// Convert a dataset::TrainingSample into a trainer::TrainingSample.
|
||||
pub fn from_dataset_sample(ds: &dataset::TrainingSample) -> TrainingSample {
|
||||
let csi_features = ds.csi_window.clone();
|
||||
let target_keypoints: Vec<(f32, f32, f32)> = ds.pose_label.keypoints.to_vec();
|
||||
let target_body_parts: Vec<u8> = ds.pose_label.body_parts.iter()
|
||||
.map(|bp| bp.part_id)
|
||||
.collect();
|
||||
let (tu, tv) = if ds.pose_label.body_parts.is_empty() {
|
||||
(Vec::new(), Vec::new())
|
||||
} else {
|
||||
let u: Vec<f32> = ds.pose_label.body_parts.iter()
|
||||
.flat_map(|bp| bp.u_coords.iter().copied()).collect();
|
||||
let v: Vec<f32> = ds.pose_label.body_parts.iter()
|
||||
.flat_map(|bp| bp.v_coords.iter().copied()).collect();
|
||||
(u, v)
|
||||
};
|
||||
TrainingSample { csi_features, target_keypoints, target_body_parts, target_uv: (tu, tv) }
|
||||
}
|
||||
|
||||
// ── Checkpoint ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Serializable version of EpochStats for checkpoint storage.
///
/// `LossComponents` is flattened into individual `loss_*` fields so the
/// struct is plain data for serde (see `EpochStats::to_serializable`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EpochStatsSerializable {
    pub epoch: usize, pub train_loss: f32, pub val_loss: f32,
    pub pck_02: f32, pub oks_map: f32, pub lr: f32,
    pub loss_keypoint: f32, pub loss_body_part: f32, pub loss_uv: f32,
    pub loss_temporal: f32, pub loss_edge: f32, pub loss_symmetry: f32,
}
|
||||
|
||||
/// Serializable training checkpoint.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Checkpoint {
    pub epoch: usize,              // epoch at which the checkpoint was taken
    pub params: Vec<f32>,          // flattened model parameters
    pub optimizer_state: Vec<f32>, // SGD momentum (velocity) buffer
    pub best_loss: f32,            // best validation loss seen so far
    pub metrics: EpochStatsSerializable, // stats of the checkpointed epoch
}
|
||||
|
||||
impl Checkpoint {
    /// Serialize the checkpoint as pretty-printed JSON and write it to `path`.
    ///
    /// Serde failures are wrapped as `std::io::Error` (kind `Other`) so
    /// callers deal with a single error type.
    pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> {
        let json = serde_json::to_string_pretty(self)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
        std::fs::write(path, json)
    }
    /// Read and deserialize a checkpoint previously written by `save_to_file`.
    pub fn load_from_file(path: &Path) -> std::io::Result<Self> {
        let json = std::fs::read_to_string(path)?;
        serde_json::from_str(&json)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }
}
|
||||
|
||||
/// Statistics for a single training epoch.
#[derive(Debug, Clone)]
pub struct EpochStats {
    pub epoch: usize,    // zero-based epoch index
    pub train_loss: f32, // weighted composite loss over training batches
    pub val_loss: f32,   // validation composite loss (train loss if no val set)
    // PCK metric for the epoch; the "02" name suggests a 0.2 threshold,
    // but the threshold is set inside evaluate_metrics — confirm there.
    pub pck_02: f32,
    pub oks_map: f32,    // mean OKS over evaluated samples
    pub lr: f32,         // learning rate used for this epoch
    pub loss_components: LossComponents, // unweighted per-term losses
}
|
||||
|
||||
impl EpochStats {
|
||||
fn to_serializable(&self) -> EpochStatsSerializable {
|
||||
let c = &self.loss_components;
|
||||
EpochStatsSerializable {
|
||||
epoch: self.epoch, train_loss: self.train_loss, val_loss: self.val_loss,
|
||||
pck_02: self.pck_02, oks_map: self.oks_map, lr: self.lr,
|
||||
loss_keypoint: c.keypoint, loss_body_part: c.body_part, loss_uv: c.uv,
|
||||
loss_temporal: c.temporal, loss_edge: c.edge, loss_symmetry: c.symmetry,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Final result from a complete training run.
#[derive(Debug, Clone)]
pub struct TrainingResult {
    pub best_epoch: usize,        // epoch index with the lowest validation loss
    pub best_pck: f32,            // PCK recorded at that epoch
    pub best_oks: f32,            // OKS mAP recorded at that epoch
    pub history: Vec<EpochStats>, // per-epoch stats for all epochs actually run
    pub total_time_secs: f64,     // wall-clock duration of run_training
}
|
||||
|
||||
/// Configuration for the training loop.
#[derive(Debug, Clone)]
pub struct TrainerConfig {
    pub epochs: usize,              // maximum number of epochs
    pub batch_size: usize,          // samples per gradient step (0 treated as 1)
    pub lr: f32,                    // peak learning rate
    pub momentum: f32,              // SGD momentum coefficient
    pub weight_decay: f32,          // L2 regularization strength
    pub warmup_epochs: usize,       // linear LR warmup length, in epochs
    pub min_lr: f32,                // floor of the cosine decay
    pub early_stop_patience: usize, // epochs without val improvement before stop
    // Checkpoint interval in epochs; not read inside the visible Trainer
    // methods — presumably consumed by the caller. TODO confirm.
    pub checkpoint_every: usize,
    pub loss_weights: LossWeights,  // per-term composite loss weights
}
|
||||
|
||||
impl Default for TrainerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
epochs: 100, batch_size: 32, lr: 0.01, momentum: 0.9, weight_decay: 1e-4,
|
||||
warmup_epochs: 5, min_lr: 1e-6, early_stop_patience: 10, checkpoint_every: 10,
|
||||
loss_weights: LossWeights::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Trainer ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Training loop orchestrator for WiFi DensePose pose estimation.
pub struct Trainer {
    config: TrainerConfig,
    optimizer: SgdOptimizer,
    scheduler: WarmupCosineScheduler,
    // Flattened parameter vector being optimized: the transformer's weights
    // when `transformer` is set, otherwise a 64-element synthetic vector
    // (see `new`).
    params: Vec<f32>,
    // Per-epoch stats, appended by train_epoch; its length is the epoch index.
    history: Vec<EpochStats>,
    best_val_loss: f32,
    best_epoch: usize,
    // Consecutive epochs without val-loss improvement (drives should_stop).
    epochs_without_improvement: usize,
    /// Snapshot of params at the best validation loss epoch.
    best_params: Vec<f32>,
    /// When set, predict_keypoints delegates to the transformer's forward().
    transformer: Option<CsiToPoseTransformer>,
    /// Transformer config (needed for unflatten during gradient estimation).
    transformer_config: Option<TransformerConfig>,
}
|
||||
|
||||
impl Trainer {
|
||||
pub fn new(config: TrainerConfig) -> Self {
|
||||
let optimizer = SgdOptimizer::new(config.lr, config.momentum, config.weight_decay);
|
||||
let scheduler = WarmupCosineScheduler::new(
|
||||
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
||||
);
|
||||
let params: Vec<f32> = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect();
|
||||
let best_params = params.clone();
|
||||
Self {
|
||||
config, optimizer, scheduler, params, history: Vec::new(),
|
||||
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
||||
best_params, transformer: None, transformer_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a trainer backed by the graph transformer. Gradient estimation
|
||||
/// uses central differences on the transformer's flattened weights.
|
||||
pub fn with_transformer(config: TrainerConfig, transformer: CsiToPoseTransformer) -> Self {
|
||||
let params = transformer.flatten_weights();
|
||||
let optimizer = SgdOptimizer::new(config.lr, config.momentum, config.weight_decay);
|
||||
let scheduler = WarmupCosineScheduler::new(
|
||||
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
||||
);
|
||||
let tc = transformer.config().clone();
|
||||
let best_params = params.clone();
|
||||
Self {
|
||||
config, optimizer, scheduler, params, history: Vec::new(),
|
||||
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
||||
best_params, transformer: Some(transformer), transformer_config: Some(tc),
|
||||
}
|
||||
}
|
||||
|
||||
/// Access the transformer (if any); `None` for trainers built via `new`.
pub fn transformer(&self) -> Option<&CsiToPoseTransformer> { self.transformer.as_ref() }

/// Get a mutable reference to the transformer (if any).
pub fn transformer_mut(&mut self) -> Option<&mut CsiToPoseTransformer> { self.transformer.as_mut() }

/// Return current flattened params (transformer or simple).
pub fn params(&self) -> &[f32] { &self.params }
|
||||
|
||||
/// Run one optimization epoch over `samples` and record its stats.
///
/// Per batch: snapshot params, estimate the gradient by central differences
/// over the full parameter vector, clip it to unit L2 norm, and apply one
/// momentum-SGD step. Cost therefore scales with parameter count (two loss
/// evaluations per parameter per batch).
pub fn train_epoch(&mut self, samples: &[TrainingSample]) -> EpochStats {
    // The scheduler is stepped once per epoch (epoch index = history length).
    let epoch = self.history.len();
    let lr = self.scheduler.get_lr(epoch);
    self.optimizer.set_lr(lr);

    let mut acc = LossComponents::default();
    let bs = self.config.batch_size.max(1); // guard batch_size == 0
    let nb = (samples.len() + bs - 1) / bs; // ceiling division: batch count
    let tc = self.transformer_config.clone();

    for bi in 0..nb {
        let batch = &samples[bi * bs..(bi * bs + bs).min(samples.len())];
        // Snapshot so the finite-difference probes see a fixed base point.
        let snap = self.params.clone();
        let w = self.config.loss_weights.clone();
        // Transformer-aware loss when a transformer config is present.
        let loss_fn = |p: &[f32]| {
            match &tc {
                Some(tconf) => Self::batch_loss_with_transformer(p, batch, &w, tconf),
                None => Self::batch_loss(p, batch, &w),
            }
        };
        let mut grad = estimate_gradient(loss_fn, &snap, 1e-4);
        clip_gradients(&mut grad, 1.0);
        self.optimizer.step(&mut self.params, &grad);

        // Re-evaluate per-term losses at the UPDATED params for reporting.
        let c = Self::batch_loss_components_impl(&self.params, batch, tc.as_ref());
        acc.keypoint += c.keypoint;
        acc.body_part += c.body_part;
        acc.uv += c.uv;
        acc.temporal += c.temporal;
        acc.edge += c.edge;
        acc.symmetry += c.symmetry;
    }

    // Average the accumulated per-batch components.
    if nb > 0 {
        let inv = 1.0 / nb as f32;
        acc.keypoint *= inv; acc.body_part *= inv; acc.uv *= inv;
        acc.temporal *= inv; acc.edge *= inv; acc.symmetry *= inv;
    }

    let train_loss = composite_loss(&acc, &self.config.loss_weights);
    let (pck, oks) = self.evaluate_metrics(samples);
    // val_loss provisionally mirrors train_loss; run_training overwrites it
    // when a separate validation set is available.
    let stats = EpochStats {
        epoch, train_loss, val_loss: train_loss, pck_02: pck, oks_map: oks,
        lr, loss_components: acc,
    };
    self.history.push(stats.clone());
    stats
}
|
||||
|
||||
/// True once validation loss has failed to improve for
/// `early_stop_patience` consecutive epochs.
pub fn should_stop(&self) -> bool {
    self.epochs_without_improvement >= self.config.early_stop_patience
}
|
||||
|
||||
/// Stats for the best-validation-loss epoch recorded so far, if any.
pub fn best_metrics(&self) -> Option<&EpochStats> {
    self.history.get(self.best_epoch)
}
|
||||
|
||||
/// Full training loop: up to `config.epochs` epochs with per-epoch
/// validation, best-params tracking, and early stopping.
///
/// When `val` is empty, the training loss doubles as the validation loss
/// and metrics are computed on the training set by `train_epoch`.
pub fn run_training(&mut self, train: &[TrainingSample], val: &[TrainingSample]) -> TrainingResult {
    let start = std::time::Instant::now();
    for _ in 0..self.config.epochs {
        let mut stats = self.train_epoch(train);
        let tc = self.transformer_config.clone();
        // Validation loss; falls back to the train loss without a val set.
        let val_loss = if !val.is_empty() {
            let c = Self::batch_loss_components_impl(&self.params, val, tc.as_ref());
            composite_loss(&c, &self.config.loss_weights)
        } else { stats.train_loss };
        stats.val_loss = val_loss;
        if !val.is_empty() {
            let (pck, oks) = self.evaluate_metrics(val);
            stats.pck_02 = pck;
            stats.oks_map = oks;
        }
        // Mirror the validation numbers into the stored history entry
        // (train_epoch already pushed it with provisional values).
        if let Some(last) = self.history.last_mut() {
            last.val_loss = stats.val_loss;
            last.pck_02 = stats.pck_02;
            last.oks_map = stats.oks_map;
        }
        // Track the best epoch and the early-stopping counter.
        if val_loss < self.best_val_loss {
            self.best_val_loss = val_loss;
            self.best_epoch = stats.epoch;
            self.best_params = self.params.clone();
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        if self.should_stop() { break; }
    }
    // Restore best-epoch params for checkpoint and downstream use
    self.params = self.best_params.clone();
    // Fallback stats if no epoch ever ran (epochs == 0).
    let best = self.best_metrics().cloned().unwrap_or(EpochStats {
        epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0,
        oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(),
    });
    TrainingResult {
        best_epoch: best.epoch, best_pck: best.pck_02, best_oks: best.oks_map,
        history: self.history.clone(), total_time_secs: start.elapsed().as_secs_f64(),
    }
}
|
||||
|
||||
pub fn checkpoint(&self) -> Checkpoint {
|
||||
let m = self.history.last().map(|s| s.to_serializable()).unwrap_or(
|
||||
EpochStatsSerializable {
|
||||
epoch: 0, train_loss: 0.0, val_loss: 0.0, pck_02: 0.0,
|
||||
oks_map: 0.0, lr: self.config.lr, loss_keypoint: 0.0, loss_body_part: 0.0,
|
||||
loss_uv: 0.0, loss_temporal: 0.0, loss_edge: 0.0, loss_symmetry: 0.0,
|
||||
},
|
||||
);
|
||||
Checkpoint {
|
||||
epoch: self.history.len(), params: self.params.clone(),
|
||||
optimizer_state: self.optimizer.state(), best_loss: self.best_val_loss, metrics: m,
|
||||
}
|
||||
}
|
||||
|
||||
fn batch_loss(params: &[f32], batch: &[TrainingSample], w: &LossWeights) -> f32 {
|
||||
composite_loss(&Self::batch_loss_components_impl(params, batch, None), w)
|
||||
}
|
||||
|
||||
fn batch_loss_with_transformer(
|
||||
params: &[f32], batch: &[TrainingSample], w: &LossWeights, tc: &TransformerConfig,
|
||||
) -> f32 {
|
||||
composite_loss(&Self::batch_loss_components_impl(params, batch, Some(tc)), w)
|
||||
}
|
||||
|
||||
/// Per-component losses for a batch using the baseline (non-transformer)
/// predictor; thin convenience wrapper over `batch_loss_components_impl`.
fn batch_loss_components(params: &[f32], batch: &[TrainingSample]) -> LossComponents {
    Self::batch_loss_components_impl(params, batch, None)
}
|
||||
|
||||
/// Per-component losses averaged over `batch`. When `tc` is `Some`,
/// keypoint predictions go through the graph transformer; otherwise the
/// linear baseline predictor is used.
///
/// NOTE(review): the body-part logits and UV predictions below are simple
/// param-derived probes rather than real network heads — presumably
/// placeholders; confirm intent.
/// NOTE(review): `acc.edge` is never accumulated here, so the graph-edge
/// component of the composite loss is always 0 on this path — TODO confirm
/// intentional.
/// NOTE(review): the temporal term sums over batch.len()-1 consecutive
/// pairs (batch order treated as time order) but is averaged by the full
/// batch length.
fn batch_loss_components_impl(
    params: &[f32], batch: &[TrainingSample], tc: Option<&TransformerConfig>,
) -> LossComponents {
    if batch.is_empty() { return LossComponents::default(); }
    let mut acc = LossComponents::default();
    // Previous sample's predicted keypoints, for the temporal term.
    let mut prev_kp: Option<Vec<(f32, f32, f32)>> = None;
    for sample in batch {
        let pred_kp = match tc {
            Some(tconf) => Self::predict_keypoints_transformer(params, sample, tconf),
            None => Self::predict_keypoints(params, sample),
        };
        acc.keypoint += keypoint_mse(&pred_kp, &sample.target_keypoints);
        // One row of n_parts logits per labelled point; each row depends
        // only on the leading params, scaled down (x0.1) to keep CE stable.
        let n_parts = 24usize;
        let logits: Vec<f32> = sample.target_body_parts.iter().flat_map(|_| {
            (0..n_parts).map(|j| if j < params.len() { params[j] * 0.1 } else { 0.0 })
                .collect::<Vec<f32>>()
        }).collect();
        acc.body_part += body_part_cross_entropy(&logits, &sample.target_body_parts, n_parts);
        // UV "prediction" = target plus a small param-dependent perturbation.
        let (ref tu, ref tv) = sample.target_uv;
        let pu: Vec<f32> = tu.iter().enumerate()
            .map(|(i, &u)| u + if i < params.len() { params[i] * 0.01 } else { 0.0 }).collect();
        let pv: Vec<f32> = tv.iter().enumerate()
            .map(|(i, &v)| v + if i < params.len() { params[i] * 0.01 } else { 0.0 }).collect();
        acc.uv += uv_regression_loss(&pu, &pv, tu, tv);
        if let Some(ref prev) = prev_kp {
            acc.temporal += temporal_consistency_loss(prev, &pred_kp);
        }
        acc.symmetry += symmetry_loss(&pred_kp);
        prev_kp = Some(pred_kp);
    }
    // Average each accumulated component over the batch (edge stays 0).
    let inv = 1.0 / batch.len() as f32;
    acc.keypoint *= inv; acc.body_part *= inv; acc.uv *= inv;
    acc.temporal *= inv; acc.symmetry *= inv;
    acc
}
|
||||
|
||||
/// Baseline (non-transformer) keypoint predictor: a deterministic linear
/// probe of the flattened CSI features against the parameter vector.
///
/// For each keypoint k, x/y are weighted sums of the features using params
/// indexed modulo `params.len()` (so any parameter count works), plus a
/// per-keypoint bias; confidence comes from a third param slot clamped to
/// [0, 1], defaulting to 0.5 when params are too short to cover it.
fn predict_keypoints(params: &[f32], sample: &TrainingSample) -> Vec<(f32, f32, f32)> {
    // Always emit at least the 17 COCO keypoints.
    let n_kp = sample.target_keypoints.len().max(17);
    let feats: Vec<f32> = sample.csi_features.iter().flat_map(|v| v.iter().copied()).collect();
    (0..n_kp).map(|k| {
        // Three param slots per keypoint: x-bias, y-bias, confidence.
        let base = k * 3;
        let (mut x, mut y) = (0.0f32, 0.0f32);
        // take(params.len()) also guards the modulo below: with empty
        // params the loop body never runs, so `% params.len()` can't panic.
        for (i, &f) in feats.iter().take(params.len()).enumerate() {
            let pi = (base + i) % params.len();
            x += f * params[pi] * 0.01;
            y += f * params[(pi + 1) % params.len()] * 0.01;
        }
        if base < params.len() {
            x += params[base % params.len()];
            y += params[(base + 1) % params.len()];
        }
        let c = if base + 2 < params.len() {
            params[(base + 2) % params.len()].clamp(0.0, 1.0)
        } else { 0.5 };
        (x, y, c)
    }).collect()
}
|
||||
|
||||
/// Predict keypoints using the graph transformer. Uses zero-init
|
||||
/// constructor (fast) then overwrites all weights from params.
|
||||
fn predict_keypoints_transformer(
|
||||
params: &[f32], sample: &TrainingSample, tc: &TransformerConfig,
|
||||
) -> Vec<(f32, f32, f32)> {
|
||||
let mut t = CsiToPoseTransformer::zeros(tc.clone());
|
||||
if t.unflatten_weights(params).is_err() {
|
||||
return Self::predict_keypoints(params, sample);
|
||||
}
|
||||
let output = t.forward(&sample.csi_features);
|
||||
output.keypoints
|
||||
}
|
||||
|
||||
fn evaluate_metrics(&self, samples: &[TrainingSample]) -> (f32, f32) {
|
||||
if samples.is_empty() { return (0.0, 0.0); }
|
||||
let preds: Vec<Vec<_>> = samples.iter().map(|s| {
|
||||
match &self.transformer_config {
|
||||
Some(tc) => Self::predict_keypoints_transformer(&self.params, s, tc),
|
||||
None => Self::predict_keypoints(&self.params, s),
|
||||
}
|
||||
}).collect();
|
||||
let targets: Vec<Vec<_>> = samples.iter().map(|s| s.target_keypoints.clone()).collect();
|
||||
let pck = preds.iter().zip(targets.iter())
|
||||
.map(|(p, t)| pck_at_threshold(p, t, 0.2)).sum::<f32>() / samples.len() as f32;
|
||||
(pck, oks_map(&preds, &targets))
|
||||
}
|
||||
|
||||
/// Sync the internal transformer's weights from the flat params after training.
|
||||
pub fn sync_transformer_weights(&mut self) {
|
||||
if let Some(ref mut t) = self.transformer {
|
||||
let _ = t.unflatten_weights(&self.params);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: 17 keypoints on a diagonal, shifted by `off` in both axes.
    fn mkp(off: f32) -> Vec<(f32, f32, f32)> {
        (0..17).map(|i| (i as f32 + off, i as f32 * 2.0 + off, 1.0)).collect()
    }

    // Helper: pose whose SYMMETRY_PAIRS are mirrored about x = 5.
    fn symmetric_pose() -> Vec<(f32, f32, f32)> {
        let mut kp = vec![(0.0f32, 0.0f32, 1.0f32); 17];
        kp[0] = (5.0, 5.0, 1.0);
        for &(l, r) in &SYMMETRY_PAIRS { kp[l] = (3.0, 5.0, 1.0); kp[r] = (7.0, 5.0, 1.0); }
        kp
    }

    // Helper: minimal valid training sample (4 frames x 8 CSI features).
    fn sample() -> TrainingSample {
        TrainingSample {
            csi_features: vec![vec![1.0; 8]; 4],
            target_keypoints: mkp(0.0),
            target_body_parts: vec![0, 1, 2, 3],
            target_uv: (vec![0.5; 4], vec![0.5; 4]),
        }
    }

    // -- Loss function properties --
    #[test] fn keypoint_mse_zero_for_identical() { assert_eq!(keypoint_mse(&mkp(0.0), &mkp(0.0)), 0.0); }
    #[test] fn keypoint_mse_positive_for_different() { assert!(keypoint_mse(&mkp(0.0), &mkp(1.0)) > 0.0); }
    #[test] fn keypoint_mse_symmetric() {
        let (ab, ba) = (keypoint_mse(&mkp(0.0), &mkp(1.0)), keypoint_mse(&mkp(1.0), &mkp(0.0)));
        assert!((ab - ba).abs() < 1e-6, "{ab} vs {ba}");
    }
    #[test] fn temporal_consistency_zero_for_static() {
        assert_eq!(temporal_consistency_loss(&mkp(0.0), &mkp(0.0)), 0.0);
    }
    #[test] fn temporal_consistency_positive_for_motion() {
        assert!(temporal_consistency_loss(&mkp(0.0), &mkp(1.0)) > 0.0);
    }
    #[test] fn symmetry_loss_zero_for_symmetric_pose() {
        assert!(symmetry_loss(&symmetric_pose()) < 1e-6);
    }
    #[test] fn graph_edge_loss_zero_when_correct() {
        let kp = vec![(0.0,0.0,1.0),(3.0,4.0,1.0),(6.0,0.0,1.0)];
        assert!(graph_edge_loss(&kp, &[(0,1),(1,2)], &[5.0, 5.0]) < 1e-6);
    }
    #[test] fn composite_loss_respects_weights() {
        let c = LossComponents { keypoint:1.0, body_part:1.0, uv:1.0, temporal:1.0, edge:1.0, symmetry:1.0 };
        let w1 = LossWeights { keypoint:1.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 };
        let w2 = LossWeights { keypoint:2.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 };
        assert!((composite_loss(&c, &w2) - 2.0 * composite_loss(&c, &w1)).abs() < 1e-6);
        let wz = LossWeights { keypoint:0.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 };
        assert_eq!(composite_loss(&c, &wz), 0.0);
    }

    // -- Learning-rate schedulers --
    #[test] fn cosine_scheduler_starts_at_initial() {
        assert!((CosineScheduler::new(0.01, 0.0001, 100).get_lr(0) - 0.01).abs() < 1e-6);
    }
    #[test] fn cosine_scheduler_ends_at_min() {
        assert!((CosineScheduler::new(0.01, 0.0001, 100).get_lr(100) - 0.0001).abs() < 1e-6);
    }
    #[test] fn cosine_scheduler_midpoint() {
        assert!((CosineScheduler::new(0.01, 0.0, 100).get_lr(50) - 0.005).abs() < 1e-4);
    }
    #[test] fn warmup_starts_at_zero() {
        assert!(WarmupCosineScheduler::new(10, 0.01, 0.0001, 100).get_lr(0) < 1e-6);
    }
    #[test] fn warmup_reaches_initial_at_warmup_end() {
        assert!((WarmupCosineScheduler::new(10, 0.01, 0.0001, 100).get_lr(10) - 0.01).abs() < 1e-6);
    }

    // -- Metrics --
    #[test] fn pck_perfect_prediction_is_1() {
        assert!((pck_at_threshold(&mkp(0.0), &mkp(0.0), 0.2) - 1.0).abs() < 1e-6);
    }
    #[test] fn pck_all_wrong_is_0() {
        assert!(pck_at_threshold(&mkp(0.0), &mkp(100.0), 0.2) < 1e-6);
    }
    #[test] fn oks_perfect_is_1() {
        assert!((oks_single(&mkp(0.0), &mkp(0.0), &COCO_KEYPOINT_SIGMAS, 1.0) - 1.0).abs() < 1e-6);
    }

    // -- Optimizer and gradient utilities --
    #[test] fn sgd_step_reduces_simple_loss() {
        let mut p = vec![5.0f32];
        let mut opt = SgdOptimizer::new(0.1, 0.0, 0.0);
        let init = p[0] * p[0];
        for _ in 0..10 { let grad = vec![2.0 * p[0]]; opt.step(&mut p, &grad); }
        assert!(p[0] * p[0] < init);
    }
    #[test] fn gradient_clipping_respects_max_norm() {
        let mut g = vec![3.0, 4.0];
        clip_gradients(&mut g, 2.5);
        assert!((g.iter().map(|x| x*x).sum::<f32>().sqrt() - 2.5).abs() < 1e-4);
    }

    // -- Trainer lifecycle --
    #[test] fn early_stopping_triggers() {
        let cfg = TrainerConfig { epochs: 100, early_stop_patience: 3, ..Default::default() };
        let mut t = Trainer::new(cfg);
        let s = vec![sample()];
        t.best_val_loss = -1.0;
        let mut stopped = false;
        for _ in 0..20 {
            t.train_epoch(&s);
            t.epochs_without_improvement += 1;
            if t.should_stop() { stopped = true; break; }
        }
        assert!(stopped);
    }
    #[test] fn checkpoint_round_trip() {
        let mut t = Trainer::new(TrainerConfig::default());
        t.train_epoch(&[sample()]);
        let ckpt = t.checkpoint();
        let dir = std::env::temp_dir().join("trainer_ckpt_test");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("ckpt.json");
        ckpt.save_to_file(&path).unwrap();
        let loaded = Checkpoint::load_from_file(&path).unwrap();
        assert_eq!(loaded.epoch, ckpt.epoch);
        assert_eq!(loaded.params.len(), ckpt.params.len());
        assert!((loaded.best_loss - ckpt.best_loss).abs() < 1e-6);
        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_dir(&dir);
    }

    // ── Integration tests: transformer + trainer pipeline ──────────

    #[test]
    fn dataset_to_trainer_conversion() {
        let ds = crate::dataset::TrainingSample {
            csi_window: vec![vec![1.0; 8]; 4],
            pose_label: crate::dataset::PoseLabel {
                keypoints: {
                    let mut kp = [(0.0f32, 0.0f32, 1.0f32); 17];
                    for (i, k) in kp.iter_mut().enumerate() {
                        k.0 = i as f32; k.1 = i as f32 * 2.0;
                    }
                    kp
                },
                body_parts: Vec::new(),
                confidence: 1.0,
            },
            source: "test",
        };
        let ts = from_dataset_sample(&ds);
        assert_eq!(ts.csi_features.len(), 4);
        assert_eq!(ts.csi_features[0].len(), 8);
        assert_eq!(ts.target_keypoints.len(), 17);
        assert!((ts.target_keypoints[0].0 - 0.0).abs() < 1e-6);
        assert!((ts.target_keypoints[1].0 - 1.0).abs() < 1e-6);
        assert!(ts.target_body_parts.is_empty()); // no body parts in source
    }

    #[test]
    fn trainer_with_transformer_runs_epoch() {
        use crate::graph_transformer::{CsiToPoseTransformer, TransformerConfig};
        let tf_config = TransformerConfig {
            n_subcarriers: 8, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 1,
        };
        let transformer = CsiToPoseTransformer::new(tf_config);
        let config = TrainerConfig {
            epochs: 2, batch_size: 4, lr: 0.001,
            warmup_epochs: 0, early_stop_patience: 100,
            ..Default::default()
        };
        let mut t = Trainer::with_transformer(config, transformer);

        // The params should be the transformer's flattened weights
        assert!(t.params().len() > 100, "transformer should have many params");

        // Create samples matching the transformer's n_subcarriers=8
        let samples: Vec<TrainingSample> = (0..8).map(|i| TrainingSample {
            csi_features: vec![vec![(i as f32 * 0.1).sin(); 8]; 4],
            target_keypoints: (0..17).map(|k| (k as f32 * 0.5, k as f32 * 0.3, 1.0)).collect(),
            target_body_parts: vec![0, 1, 2],
            target_uv: (vec![0.5; 3], vec![0.5; 3]),
        }).collect();

        let stats = t.train_epoch(&samples);
        assert!(stats.train_loss.is_finite(), "loss should be finite");
    }

    #[test]
    fn trainer_with_transformer_loss_finite_after_training() {
        use crate::graph_transformer::{CsiToPoseTransformer, TransformerConfig};
        let tf_config = TransformerConfig {
            n_subcarriers: 8, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 1,
        };
        let transformer = CsiToPoseTransformer::new(tf_config);
        let config = TrainerConfig {
            epochs: 3, batch_size: 4, lr: 0.0001,
            warmup_epochs: 0, early_stop_patience: 100,
            ..Default::default()
        };
        let mut t = Trainer::with_transformer(config, transformer);

        let samples: Vec<TrainingSample> = (0..4).map(|i| TrainingSample {
            csi_features: vec![vec![(i as f32 * 0.2).sin(); 8]; 4],
            target_keypoints: (0..17).map(|k| (k as f32 * 0.5, k as f32 * 0.3, 1.0)).collect(),
            target_body_parts: vec![],
            target_uv: (vec![], vec![]),
        }).collect();

        let result = t.run_training(&samples, &[]);
        assert!(result.history.iter().all(|s| s.train_loss.is_finite()),
            "all losses should be finite");

        // Sync weights back and verify transformer still works
        t.sync_transformer_weights();
        if let Some(tf) = t.transformer() {
            let out = tf.forward(&vec![vec![1.0; 8]; 4]);
            assert_eq!(out.keypoints.len(), 17);
            for (i, &(x, y, z)) in out.keypoints.iter().enumerate() {
                assert!(x.is_finite() && y.is_finite() && z.is_finite(),
                    "kp {i} not finite after training");
            }
        }
    }
}
|
||||
@@ -0,0 +1,774 @@
|
||||
//! Vital sign detection from WiFi CSI data.
|
||||
//!
|
||||
//! Implements breathing rate (0.1-0.5 Hz) and heart rate (0.8-2.0 Hz)
|
||||
//! estimation using FFT-based spectral analysis on CSI amplitude and phase
|
||||
//! time series. Designed per ADR-021 (rvdna vital sign pipeline).
|
||||
//!
|
||||
//! All math is pure Rust -- no external FFT crate required. Uses a radix-2
|
||||
//! DIT FFT for buffers zero-padded to power-of-two length. A windowed-sinc
|
||||
//! FIR bandpass filter isolates the frequency bands of interest before
|
||||
//! spectral analysis.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::f64::consts::PI;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ── Configuration constants ────────────────────────────────────────────────
|
||||
|
||||
/// Breathing rate physiological band: 6-30 breaths per minute.
|
||||
const BREATHING_MIN_HZ: f64 = 0.1; // 6 BPM
|
||||
const BREATHING_MAX_HZ: f64 = 0.5; // 30 BPM
|
||||
|
||||
/// Heart rate physiological band: 40-120 beats per minute.
|
||||
const HEARTBEAT_MIN_HZ: f64 = 0.667; // 40 BPM
|
||||
const HEARTBEAT_MAX_HZ: f64 = 2.0; // 120 BPM
|
||||
|
||||
/// Minimum number of samples before attempting extraction.
|
||||
const MIN_BREATHING_SAMPLES: usize = 40; // ~2s at 20 Hz
|
||||
const MIN_HEARTBEAT_SAMPLES: usize = 30; // ~1.5s at 20 Hz
|
||||
|
||||
/// Peak-to-mean ratio threshold for confident detection.
|
||||
const CONFIDENCE_THRESHOLD: f64 = 2.0;
|
||||
|
||||
// ── Output types ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Vital sign readings produced each frame.
///
/// Rate fields are `None` when the corresponding band had no confident
/// spectral peak (or not enough buffered samples yet). Serde derives allow
/// readings to be emitted as JSON telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VitalSigns {
    /// Estimated breathing rate in breaths per minute, if detected.
    pub breathing_rate_bpm: Option<f64>,
    /// Estimated heart rate in beats per minute, if detected.
    pub heart_rate_bpm: Option<f64>,
    /// Confidence of breathing estimate (0.0 - 1.0).
    pub breathing_confidence: f64,
    /// Confidence of heartbeat estimate (0.0 - 1.0).
    pub heartbeat_confidence: f64,
    /// Overall signal quality metric (0.0 - 1.0).
    pub signal_quality: f64,
}
|
||||
|
||||
impl Default for VitalSigns {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
breathing_rate_bpm: None,
|
||||
heart_rate_bpm: None,
|
||||
breathing_confidence: 0.0,
|
||||
heartbeat_confidence: 0.0,
|
||||
signal_quality: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Detector ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Stateful vital sign detector. Maintains rolling buffers of CSI amplitude
/// data and extracts breathing and heart rate via spectral analysis.
// NOTE(review): dead_code allow suggests some fields are not yet read by
// callers — confirm before removing.
#[allow(dead_code)]
pub struct VitalSignDetector {
    /// Rolling buffer of mean-amplitude samples for breathing detection.
    breathing_buffer: VecDeque<f64>,
    /// Rolling buffer of phase-variance samples for heartbeat detection.
    heartbeat_buffer: VecDeque<f64>,
    /// CSI frame arrival rate in Hz.
    sample_rate: f64,
    /// Window duration for breathing FFT in seconds.
    breathing_window_secs: f64,
    /// Window duration for heartbeat FFT in seconds.
    heartbeat_window_secs: f64,
    /// Maximum breathing buffer capacity (samples).
    breathing_capacity: usize,
    /// Maximum heartbeat buffer capacity (samples).
    heartbeat_capacity: usize,
    /// Running frame count for signal quality estimation.
    frame_count: u64,
}
|
||||
|
||||
impl VitalSignDetector {
|
||||
/// Create a new detector with the given CSI sample rate (Hz).
|
||||
///
|
||||
/// Typical sample rates:
|
||||
/// - ESP32 CSI: 20-100 Hz
|
||||
/// - Windows WiFi RSSI: 2 Hz (insufficient for heartbeat)
|
||||
/// - Simulation: 2-20 Hz
|
||||
pub fn new(sample_rate: f64) -> Self {
|
||||
let breathing_window_secs = 30.0;
|
||||
let heartbeat_window_secs = 15.0;
|
||||
let breathing_capacity = (sample_rate * breathing_window_secs) as usize;
|
||||
let heartbeat_capacity = (sample_rate * heartbeat_window_secs) as usize;
|
||||
|
||||
Self {
|
||||
breathing_buffer: VecDeque::with_capacity(breathing_capacity.max(1)),
|
||||
heartbeat_buffer: VecDeque::with_capacity(heartbeat_capacity.max(1)),
|
||||
sample_rate,
|
||||
breathing_window_secs,
|
||||
heartbeat_window_secs,
|
||||
breathing_capacity: breathing_capacity.max(1),
|
||||
heartbeat_capacity: heartbeat_capacity.max(1),
|
||||
frame_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process one CSI frame and return updated vital signs.
|
||||
///
|
||||
/// `amplitude` - per-subcarrier amplitude values for this frame.
|
||||
/// `phase` - per-subcarrier phase values for this frame.
|
||||
///
|
||||
/// The detector extracts two aggregate features per frame:
|
||||
/// 1. Mean amplitude (breathing signal -- chest movement modulates path loss)
|
||||
/// 2. Phase variance across subcarriers (heartbeat signal -- subtle phase shifts)
|
||||
pub fn process_frame(&mut self, amplitude: &[f64], phase: &[f64]) -> VitalSigns {
|
||||
self.frame_count += 1;
|
||||
|
||||
if amplitude.is_empty() {
|
||||
return VitalSigns::default();
|
||||
}
|
||||
|
||||
// -- Feature 1: Mean amplitude for breathing detection --
|
||||
// Respiratory chest displacement (1-5 mm) modulates CSI amplitudes
|
||||
// across all subcarriers. Mean amplitude captures this well.
|
||||
let n = amplitude.len() as f64;
|
||||
let mean_amp: f64 = amplitude.iter().sum::<f64>() / n;
|
||||
|
||||
self.breathing_buffer.push_back(mean_amp);
|
||||
while self.breathing_buffer.len() > self.breathing_capacity {
|
||||
self.breathing_buffer.pop_front();
|
||||
}
|
||||
|
||||
// -- Feature 2: Phase variance for heartbeat detection --
|
||||
// Cardiac-induced body surface displacement is < 0.5 mm, producing
|
||||
// tiny phase changes. Cross-subcarrier phase variance captures this
|
||||
// more sensitively than amplitude alone.
|
||||
let phase_var = if phase.len() > 1 {
|
||||
let mean_phase: f64 = phase.iter().sum::<f64>() / phase.len() as f64;
|
||||
phase
|
||||
.iter()
|
||||
.map(|p| (p - mean_phase).powi(2))
|
||||
.sum::<f64>()
|
||||
/ phase.len() as f64
|
||||
} else {
|
||||
// Fallback: use amplitude high-pass residual when phase is unavailable
|
||||
let half = amplitude.len() / 2;
|
||||
if half > 0 {
|
||||
let hi_mean: f64 =
|
||||
amplitude[half..].iter().sum::<f64>() / (amplitude.len() - half) as f64;
|
||||
amplitude[half..]
|
||||
.iter()
|
||||
.map(|a| (a - hi_mean).powi(2))
|
||||
.sum::<f64>()
|
||||
/ (amplitude.len() - half) as f64
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
};
|
||||
|
||||
self.heartbeat_buffer.push_back(phase_var);
|
||||
while self.heartbeat_buffer.len() > self.heartbeat_capacity {
|
||||
self.heartbeat_buffer.pop_front();
|
||||
}
|
||||
|
||||
// -- Extract vital signs --
|
||||
let (breathing_rate, breathing_confidence) = self.extract_breathing();
|
||||
let (heart_rate, heartbeat_confidence) = self.extract_heartbeat();
|
||||
|
||||
// -- Signal quality --
|
||||
let signal_quality = self.compute_signal_quality(amplitude);
|
||||
|
||||
VitalSigns {
|
||||
breathing_rate_bpm: breathing_rate,
|
||||
heart_rate_bpm: heart_rate,
|
||||
breathing_confidence,
|
||||
heartbeat_confidence,
|
||||
signal_quality,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract breathing rate from the breathing buffer via FFT.
|
||||
/// Returns (rate_bpm, confidence).
|
||||
pub fn extract_breathing(&self) -> (Option<f64>, f64) {
|
||||
if self.breathing_buffer.len() < MIN_BREATHING_SAMPLES {
|
||||
return (None, 0.0);
|
||||
}
|
||||
|
||||
let data: Vec<f64> = self.breathing_buffer.iter().copied().collect();
|
||||
let filtered = bandpass_filter(&data, BREATHING_MIN_HZ, BREATHING_MAX_HZ, self.sample_rate);
|
||||
self.compute_fft_peak(&filtered, BREATHING_MIN_HZ, BREATHING_MAX_HZ)
|
||||
}
|
||||
|
||||
/// Extract heart rate from the heartbeat buffer via FFT.
|
||||
/// Returns (rate_bpm, confidence).
|
||||
pub fn extract_heartbeat(&self) -> (Option<f64>, f64) {
|
||||
if self.heartbeat_buffer.len() < MIN_HEARTBEAT_SAMPLES {
|
||||
return (None, 0.0);
|
||||
}
|
||||
|
||||
let data: Vec<f64> = self.heartbeat_buffer.iter().copied().collect();
|
||||
let filtered = bandpass_filter(&data, HEARTBEAT_MIN_HZ, HEARTBEAT_MAX_HZ, self.sample_rate);
|
||||
self.compute_fft_peak(&filtered, HEARTBEAT_MIN_HZ, HEARTBEAT_MAX_HZ)
|
||||
}
|
||||
|
||||
/// Find the dominant frequency in `buffer` within the [min_hz, max_hz] band
|
||||
/// using FFT. Returns (frequency_as_bpm, confidence).
|
||||
pub fn compute_fft_peak(
|
||||
&self,
|
||||
buffer: &[f64],
|
||||
min_hz: f64,
|
||||
max_hz: f64,
|
||||
) -> (Option<f64>, f64) {
|
||||
if buffer.len() < 4 {
|
||||
return (None, 0.0);
|
||||
}
|
||||
|
||||
// Zero-pad to next power of two for radix-2 FFT
|
||||
let fft_len = buffer.len().next_power_of_two();
|
||||
let mut signal = vec![0.0; fft_len];
|
||||
signal[..buffer.len()].copy_from_slice(buffer);
|
||||
|
||||
// Apply Hann window to reduce spectral leakage
|
||||
for i in 0..buffer.len() {
|
||||
let w = 0.5 * (1.0 - (2.0 * PI * i as f64 / (buffer.len() as f64 - 1.0)).cos());
|
||||
signal[i] *= w;
|
||||
}
|
||||
|
||||
// Compute FFT magnitude spectrum
|
||||
let spectrum = fft_magnitude(&signal);
|
||||
|
||||
// Frequency resolution
|
||||
let freq_res = self.sample_rate / fft_len as f64;
|
||||
|
||||
// Find bin range for our band of interest
|
||||
let min_bin = (min_hz / freq_res).ceil() as usize;
|
||||
let max_bin = ((max_hz / freq_res).floor() as usize).min(spectrum.len().saturating_sub(1));
|
||||
|
||||
if min_bin >= max_bin || min_bin >= spectrum.len() {
|
||||
return (None, 0.0);
|
||||
}
|
||||
|
||||
// Find peak magnitude and its bin index within the band
|
||||
let mut peak_mag = 0.0f64;
|
||||
let mut peak_bin = min_bin;
|
||||
let mut band_sum = 0.0f64;
|
||||
let mut band_count = 0usize;
|
||||
|
||||
for bin in min_bin..=max_bin {
|
||||
let mag = spectrum[bin];
|
||||
band_sum += mag;
|
||||
band_count += 1;
|
||||
if mag > peak_mag {
|
||||
peak_mag = mag;
|
||||
peak_bin = bin;
|
||||
}
|
||||
}
|
||||
|
||||
if band_count == 0 || band_sum < f64::EPSILON {
|
||||
return (None, 0.0);
|
||||
}
|
||||
|
||||
let band_mean = band_sum / band_count as f64;
|
||||
|
||||
// Confidence: ratio of peak to band mean, normalized to 0-1
|
||||
let peak_ratio = if band_mean > f64::EPSILON {
|
||||
peak_mag / band_mean
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Parabolic interpolation for sub-bin frequency accuracy
|
||||
let peak_freq = if peak_bin > min_bin && peak_bin < max_bin {
|
||||
let alpha = spectrum[peak_bin - 1];
|
||||
let beta = spectrum[peak_bin];
|
||||
let gamma = spectrum[peak_bin + 1];
|
||||
let denom = alpha - 2.0 * beta + gamma;
|
||||
if denom.abs() > f64::EPSILON {
|
||||
let p = 0.5 * (alpha - gamma) / denom;
|
||||
(peak_bin as f64 + p) * freq_res
|
||||
} else {
|
||||
peak_bin as f64 * freq_res
|
||||
}
|
||||
} else {
|
||||
peak_bin as f64 * freq_res
|
||||
};
|
||||
|
||||
let bpm = peak_freq * 60.0;
|
||||
|
||||
// Confidence mapping: peak_ratio >= CONFIDENCE_THRESHOLD maps to high confidence
|
||||
let confidence = if peak_ratio >= CONFIDENCE_THRESHOLD {
|
||||
((peak_ratio - 1.0) / (CONFIDENCE_THRESHOLD * 2.0 - 1.0)).clamp(0.0, 1.0)
|
||||
} else {
|
||||
((peak_ratio - 1.0) / (CONFIDENCE_THRESHOLD - 1.0) * 0.5).clamp(0.0, 0.5)
|
||||
};
|
||||
|
||||
if confidence > 0.05 {
|
||||
(Some(bpm), confidence)
|
||||
} else {
|
||||
(None, confidence)
|
||||
}
|
||||
}
|
||||
|
||||
/// Overall signal quality based on amplitude statistics.
|
||||
fn compute_signal_quality(&self, amplitude: &[f64]) -> f64 {
|
||||
if amplitude.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let n = amplitude.len() as f64;
|
||||
let mean = amplitude.iter().sum::<f64>() / n;
|
||||
|
||||
if mean < f64::EPSILON {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let variance = amplitude.iter().map(|a| (a - mean).powi(2)).sum::<f64>() / n;
|
||||
let cv = variance.sqrt() / mean; // coefficient of variation
|
||||
|
||||
// Good signal: moderate CV (some variation from body motion, not pure noise).
|
||||
// - Too low CV (~0) = static, no person present
|
||||
// - Too high CV (>1) = noisy/unstable signal
|
||||
// Sweet spot around 0.05-0.3
|
||||
let quality = if cv < 0.01 {
|
||||
cv / 0.01 * 0.3 // very low variation => low quality
|
||||
} else if cv < 0.3 {
|
||||
0.3 + 0.7 * (1.0 - ((cv - 0.15) / 0.15).abs()).max(0.0) // peak around 0.15
|
||||
} else {
|
||||
(1.0 - (cv - 0.3) / 0.7).clamp(0.1, 0.5) // too noisy
|
||||
};
|
||||
|
||||
// Factor in buffer fill level (need enough history for reliable estimates)
|
||||
let fill =
|
||||
(self.breathing_buffer.len() as f64) / (self.breathing_capacity as f64).max(1.0);
|
||||
let fill_factor = fill.clamp(0.0, 1.0);
|
||||
|
||||
(quality * (0.3 + 0.7 * fill_factor)).clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Clear all internal buffers and reset state.
|
||||
pub fn reset(&mut self) {
|
||||
self.breathing_buffer.clear();
|
||||
self.heartbeat_buffer.clear();
|
||||
self.frame_count = 0;
|
||||
}
|
||||
|
||||
/// Current buffer fill levels for diagnostics.
|
||||
/// Returns (breathing_len, breathing_capacity, heartbeat_len, heartbeat_capacity).
|
||||
pub fn buffer_status(&self) -> (usize, usize, usize, usize) {
|
||||
(
|
||||
self.breathing_buffer.len(),
|
||||
self.breathing_capacity,
|
||||
self.heartbeat_buffer.len(),
|
||||
self.heartbeat_capacity,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Bandpass filter ────────────────────────────────────────────────────────
|
||||
|
||||
/// FIR bandpass filter built from a windowed-sinc design.
///
/// The kernel is the difference of two Hamming-windowed lowpass prototypes
/// (LPF at `high_hz` minus LPF at `low_hz`), normalized to unit gain at the
/// band's center frequency, then applied by direct zero-padded convolution.
/// Zero external dependencies -- adequate for the short buffers
/// (up to ~600 samples) this module processes.
///
/// Degenerate inputs (fewer than 3 samples, non-positive sample rate, or an
/// empty/invalid band) are returned unchanged.
pub fn bandpass_filter(data: &[f64], low_hz: f64, high_hz: f64, sample_rate: f64) -> Vec<f64> {
    if data.len() < 3 || sample_rate < f64::EPSILON {
        return data.to_vec();
    }

    // Cutoffs normalized by the sample rate (valid band: 0 .. 0.5).
    let low_norm = low_hz / sample_rate;
    let high_norm = high_hz / sample_rate;
    if low_norm >= high_norm || low_norm >= 0.5 || high_norm <= 0.0 {
        return data.to_vec();
    }

    // Kernel length: ~3 cycles of the lowest passband frequency, clamped to
    // [5, 127] and forced odd for type-I FIR symmetry.
    let mut taps = ((3.0 / low_norm).ceil() as usize).clamp(5, 127);
    if taps % 2 == 0 {
        taps += 1;
    }
    let half = taps / 2;

    // BPF kernel = sinc LPF(high) - sinc LPF(low), Hamming-windowed.
    let mut kernel: Vec<f64> = (0..taps)
        .map(|i| {
            let n = i as f64 - half as f64;
            let lp = |cutoff: f64| {
                if n.abs() < f64::EPSILON {
                    2.0 * cutoff
                } else {
                    (2.0 * PI * cutoff * n).sin() / (PI * n)
                }
            };
            let window = 0.54 - 0.46 * (2.0 * PI * i as f64 / (taps as f64 - 1.0)).cos();
            (lp(high_norm) - lp(low_norm)) * window
        })
        .collect();

    // Scale for unit gain at the center of the passband.
    let center = (low_norm + high_norm) / 2.0;
    let gain: f64 = kernel
        .iter()
        .enumerate()
        .map(|(i, &c)| c * (2.0 * PI * center * i as f64).cos())
        .sum();
    if gain.abs() > f64::EPSILON {
        for c in &mut kernel {
            *c /= gain;
        }
    }

    // Direct convolution; samples outside the buffer are treated as zero.
    (0..data.len())
        .map(|i| {
            kernel
                .iter()
                .enumerate()
                .filter_map(|(j, &c)| {
                    let idx = i as isize - half as isize + j as isize;
                    data.get(usize::try_from(idx).ok()?).map(|&d| d * c)
                })
                .sum::<f64>()
        })
        .collect()
}
|
||||
|
||||
// ── FFT implementation ─────────────────────────────────────────────────────
|
||||
|
||||
/// Compute the magnitude spectrum of a real-valued signal using radix-2 DIT FFT.
|
||||
///
|
||||
/// Input must be power-of-2 length (caller should zero-pad).
|
||||
/// Returns magnitudes for bins 0..N/2+1.
|
||||
fn fft_magnitude(signal: &[f64]) -> Vec<f64> {
|
||||
let n = signal.len();
|
||||
debug_assert!(n.is_power_of_two(), "FFT input must be power-of-2 length");
|
||||
|
||||
if n <= 1 {
|
||||
return signal.to_vec();
|
||||
}
|
||||
|
||||
// Convert to complex (imaginary = 0)
|
||||
let mut real = signal.to_vec();
|
||||
let mut imag = vec![0.0f64; n];
|
||||
|
||||
// Bit-reversal permutation
|
||||
bit_reverse_permute(&mut real, &mut imag);
|
||||
|
||||
// Cooley-Tukey radix-2 DIT butterfly
|
||||
let mut size = 2;
|
||||
while size <= n {
|
||||
let half = size / 2;
|
||||
let angle_step = -2.0 * PI / size as f64;
|
||||
|
||||
for start in (0..n).step_by(size) {
|
||||
for k in 0..half {
|
||||
let angle = angle_step * k as f64;
|
||||
let wr = angle.cos();
|
||||
let wi = angle.sin();
|
||||
|
||||
let i = start + k;
|
||||
let j = start + k + half;
|
||||
|
||||
let tr = wr * real[j] - wi * imag[j];
|
||||
let ti = wr * imag[j] + wi * real[j];
|
||||
|
||||
real[j] = real[i] - tr;
|
||||
imag[j] = imag[i] - ti;
|
||||
real[i] += tr;
|
||||
imag[i] += ti;
|
||||
}
|
||||
}
|
||||
|
||||
size *= 2;
|
||||
}
|
||||
|
||||
// Compute magnitudes for positive frequencies (0..N/2+1)
|
||||
let out_len = n / 2 + 1;
|
||||
let mut magnitudes = Vec::with_capacity(out_len);
|
||||
for i in 0..out_len {
|
||||
magnitudes.push((real[i] * real[i] + imag[i] * imag[i]).sqrt());
|
||||
}
|
||||
|
||||
magnitudes
|
||||
}
|
||||
|
||||
/// In-place bit-reversal permutation for FFT.
///
/// Reorders the parallel `real`/`imag` arrays (same length, which must be a
/// power of two) so element `i` lands at the index whose low bits are the
/// reverse of `i`'s. The `i < j` guard swaps each pair exactly once.
fn bit_reverse_permute(real: &mut [f64], imag: &mut [f64]) {
    let n = real.len();
    debug_assert_eq!(real.len(), imag.len());
    // Nothing to permute for 0- or 1-element inputs; also keeps the shift
    // below well-defined (bits >= 1 afterwards).
    if n <= 1 {
        return;
    }
    // For power-of-two n, trailing_zeros() is exactly log2(n). This replaces
    // the previous `(n as f64).log2() as u32`, which relied on float rounding.
    let bits = n.trailing_zeros();
    for i in 0..n {
        // Reverse all 32 bits, then shift down so only the low `bits` bits
        // remain — equivalent to a per-bit reversal loop but branch-free.
        let j = ((i as u32).reverse_bits() >> (32 - bits)) as usize;
        if i < j {
            real.swap(i, j);
            imag.swap(i, j);
        }
    }
}
|
||||
|
||||
/// Reverse the lower `bits` bits of `val`.
fn reverse_bits(val: u32, bits: u32) -> u32 {
    // Peel bits off the low end of the remaining value and push them onto
    // the low end of the accumulator; after `bits` steps their order is
    // reversed. With bits == 0 the fold runs zero times and yields 0.
    (0..bits)
        .fold((0u32, val), |(acc, rest), _| {
            ((acc << 1) | (rest & 1), rest >> 1)
        })
        .0
}
|
||||
|
||||
// ── Benchmark ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run a benchmark: process `n_frames` synthetic frames and report timing.
|
||||
///
|
||||
/// Generates frames with embedded breathing (0.25 Hz / 15 BPM) and heartbeat
|
||||
/// (1.2 Hz / 72 BPM) signals on 56 subcarriers at 20 Hz sample rate.
|
||||
///
|
||||
/// Returns (total_duration, per_frame_duration).
|
||||
pub fn run_benchmark(n_frames: usize) -> (std::time::Duration, std::time::Duration) {
|
||||
use std::time::Instant;
|
||||
|
||||
let sample_rate = 20.0;
|
||||
let mut detector = VitalSignDetector::new(sample_rate);
|
||||
|
||||
// Pre-generate synthetic CSI data (56 subcarriers, matching simulation mode)
|
||||
let n_sub = 56;
|
||||
let frames: Vec<(Vec<f64>, Vec<f64>)> = (0..n_frames)
|
||||
.map(|tick| {
|
||||
let t = tick as f64 / sample_rate;
|
||||
let mut amp = Vec::with_capacity(n_sub);
|
||||
let mut phase = Vec::with_capacity(n_sub);
|
||||
for i in 0..n_sub {
|
||||
// Embedded breathing at 0.25 Hz (15 BPM) and heartbeat at 1.2 Hz (72 BPM)
|
||||
let breathing = 2.0 * (2.0 * PI * 0.25 * t).sin();
|
||||
let heartbeat = 0.3 * (2.0 * PI * 1.2 * t).sin();
|
||||
let base = 15.0 + 5.0 * (i as f64 * 0.1).sin();
|
||||
let noise = (i as f64 * 7.3 + t * 13.7).sin() * 0.5;
|
||||
amp.push(base + breathing + heartbeat + noise);
|
||||
phase.push((i as f64 * 0.2 + t * 0.5).sin() * PI + heartbeat * 0.1);
|
||||
}
|
||||
(amp, phase)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let start = Instant::now();
|
||||
let mut last_vital = VitalSigns::default();
|
||||
for (amp, phase) in &frames {
|
||||
last_vital = detector.process_frame(amp, phase);
|
||||
}
|
||||
let total = start.elapsed();
|
||||
let per_frame = total / n_frames as u32;
|
||||
|
||||
eprintln!("=== Vital Sign Detection Benchmark ===");
|
||||
eprintln!("Frames processed: {}", n_frames);
|
||||
eprintln!("Sample rate: {} Hz", sample_rate);
|
||||
eprintln!("Subcarriers: {}", n_sub);
|
||||
eprintln!("Total time: {:?}", total);
|
||||
eprintln!("Per-frame time: {:?}", per_frame);
|
||||
eprintln!(
|
||||
"Throughput: {:.0} frames/sec",
|
||||
n_frames as f64 / total.as_secs_f64()
|
||||
);
|
||||
eprintln!();
|
||||
eprintln!("Final vital signs:");
|
||||
eprintln!(
|
||||
" Breathing rate: {:?} BPM",
|
||||
last_vital.breathing_rate_bpm
|
||||
);
|
||||
eprintln!(" Heart rate: {:?} BPM", last_vital.heart_rate_bpm);
|
||||
eprintln!(
|
||||
" Breathing confidence: {:.3}",
|
||||
last_vital.breathing_confidence
|
||||
);
|
||||
eprintln!(
|
||||
" Heartbeat confidence: {:.3}",
|
||||
last_vital.heartbeat_confidence
|
||||
);
|
||||
eprintln!(
|
||||
" Signal quality: {:.3}",
|
||||
last_vital.signal_quality
|
||||
);
|
||||
|
||||
(total, per_frame)
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fft_magnitude_dc() {
        let signal = vec![1.0; 8];
        let mag = fft_magnitude(&signal);
        // DC bin should be 8.0 (sum), all others near zero
        assert!((mag[0] - 8.0).abs() < 1e-10);
        for m in &mag[1..] {
            assert!(*m < 1e-10, "non-DC bin should be near zero, got {m}");
        }
    }

    #[test]
    fn test_fft_magnitude_sine() {
        // 16-point signal with a single sinusoid at bin 2
        let n = 16;
        let mut signal = vec![0.0; n];
        for i in 0..n {
            signal[i] = (2.0 * PI * 2.0 * i as f64 / n as f64).sin();
        }
        let mag = fft_magnitude(&signal);
        // Peak should be at bin 2
        let peak_bin = mag
            .iter()
            .enumerate()
            .skip(1) // skip DC
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(peak_bin, 2);
    }

    #[test]
    fn test_bit_reverse() {
        assert_eq!(reverse_bits(0b000, 3), 0b000);
        assert_eq!(reverse_bits(0b001, 3), 0b100);
        assert_eq!(reverse_bits(0b110, 3), 0b011);
    }

    #[test]
    fn test_bandpass_filter_passthrough() {
        // A sine at the center of the passband should mostly pass through
        let sr = 20.0;
        let freq = 0.25; // center of breathing band
        let n = 200;
        let data: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / sr).sin())
            .collect();
        let filtered = bandpass_filter(&data, 0.1, 0.5, sr);
        // Check that the filtered signal has significant energy
        let energy: f64 = filtered.iter().map(|x| x * x).sum::<f64>() / n as f64;
        assert!(
            energy > 0.01,
            "passband signal should pass through, energy={energy}"
        );
    }

    #[test]
    fn test_bandpass_filter_rejects_out_of_band() {
        // A sine well outside the passband should be attenuated
        let sr = 20.0;
        let freq = 5.0; // way above breathing band
        let n = 200;
        let data: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / sr).sin())
            .collect();
        let in_energy: f64 = data.iter().map(|x| x * x).sum::<f64>() / n as f64;
        let filtered = bandpass_filter(&data, 0.1, 0.5, sr);
        let out_energy: f64 = filtered.iter().map(|x| x * x).sum::<f64>() / n as f64;
        let attenuation = out_energy / in_energy;
        assert!(
            attenuation < 0.3,
            "out-of-band signal should be attenuated, ratio={attenuation}"
        );
    }

    #[test]
    fn test_vital_sign_detector_breathing() {
        let sr = 20.0;
        let mut detector = VitalSignDetector::new(sr);
        let target_bpm = 15.0; // 0.25 Hz
        let target_hz = target_bpm / 60.0;

        // Feed 30 seconds of data with a clear breathing signal
        let n_frames = (sr * 30.0) as usize;
        let mut vitals = VitalSigns::default();
        for frame in 0..n_frames {
            let t = frame as f64 / sr;
            let amp: Vec<f64> = (0..56)
                .map(|i| {
                    let base = 15.0 + 5.0 * (i as f64 * 0.1).sin();
                    let breathing = 3.0 * (2.0 * PI * target_hz * t).sin();
                    base + breathing
                })
                .collect();
            let phase: Vec<f64> = (0..56).map(|i| (i as f64 * 0.2).sin()).collect();
            // Fix: the amplitude argument was garbled to `&,`; pass `&amp`.
            vitals = detector.process_frame(&amp, &phase);
        }

        // After 30s, breathing should be detected
        assert!(
            vitals.breathing_rate_bpm.is_some(),
            "breathing should be detected after 30s"
        );
        if let Some(rate) = vitals.breathing_rate_bpm {
            let error = (rate - target_bpm).abs();
            assert!(
                error < 3.0,
                "breathing rate {rate:.1} BPM should be near {target_bpm} BPM (error={error:.1})"
            );
        }
    }

    #[test]
    fn test_vital_sign_detector_reset() {
        let mut detector = VitalSignDetector::new(20.0);
        let amp = vec![10.0; 56];
        let phase = vec![0.0; 56];
        for _ in 0..100 {
            // Fix: the amplitude argument was garbled to `&,`; pass `&amp`.
            detector.process_frame(&amp, &phase);
        }
        let (br_len, _, hb_len, _) = detector.buffer_status();
        assert!(br_len > 0);
        assert!(hb_len > 0);

        detector.reset();
        let (br_len, _, hb_len, _) = detector.buffer_status();
        assert_eq!(br_len, 0);
        assert_eq!(hb_len, 0);
    }

    #[test]
    fn test_vital_signs_default() {
        let vs = VitalSigns::default();
        assert!(vs.breathing_rate_bpm.is_none());
        assert!(vs.heart_rate_bpm.is_none());
        assert_eq!(vs.breathing_confidence, 0.0);
        assert_eq!(vs.heartbeat_confidence, 0.0);
        assert_eq!(vs.signal_quality, 0.0);
    }

    #[test]
    fn test_empty_amplitude() {
        let mut detector = VitalSignDetector::new(20.0);
        let vs = detector.process_frame(&[], &[]);
        assert!(vs.breathing_rate_bpm.is_none());
        assert!(vs.heart_rate_bpm.is_none());
    }

    #[test]
    fn test_single_subcarrier() {
        let mut detector = VitalSignDetector::new(20.0);
        // Single subcarrier should not crash
        for i in 0..100 {
            let t = i as f64 / 20.0;
            let amp = vec![10.0 + (2.0 * PI * 0.25 * t).sin()];
            let phase = vec![0.0];
            // Fix: the amplitude argument was garbled to `&,`; pass `&amp`.
            let _ = detector.process_frame(&amp, &phase);
        }
    }

    #[test]
    fn test_benchmark_runs() {
        let (total, per_frame) = run_benchmark(100);
        assert!(total.as_nanos() > 0);
        assert!(per_frame.as_nanos() > 0);
    }
}
|
||||
@@ -0,0 +1,556 @@
|
||||
//! Integration tests for the RVF (RuVector Format) container module.
|
||||
//!
|
||||
//! These tests exercise the public RvfBuilder and RvfReader APIs through
|
||||
//! the library crate's public interface. They complement the inline unit
|
||||
//! tests in rvf_container.rs by testing from the perspective of an external
|
||||
//! consumer.
|
||||
//!
|
||||
//! Test matrix:
|
||||
//! - Empty builder produces valid (empty) container
|
||||
//! - Full round-trip: manifest + weights + metadata -> build -> read -> verify
|
||||
//! - Segment type tagging and ordering
|
||||
//! - Magic byte corruption is rejected
|
||||
//! - Float32 precision is preserved bit-for-bit
|
||||
//! - Large payload (1M weights) round-trip
|
||||
//! - Multiple metadata segments coexist
|
||||
//! - File I/O round-trip
|
||||
//! - Witness/proof segment verification
|
||||
//! - Write/read benchmark for ~10MB container
|
||||
|
||||
use wifi_densepose_sensing_server::rvf_container::{
|
||||
RvfBuilder, RvfReader, VitalSignConfig,
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
fn test_rvf_builder_empty() {
    // No segments added -> no headers -> zero bytes on the wire.
    let bytes = RvfBuilder::new().build();
    assert!(
        bytes.is_empty(),
        "empty builder should produce empty byte vec"
    );

    // An empty byte stream still parses: zero segments, zero total size.
    let parsed = RvfReader::from_bytes(&bytes).expect("should parse empty container");
    assert_eq!(parsed.segment_count(), 0);
    assert_eq!(parsed.total_size(), 0);
}
|
||||
|
||||
#[test]
fn test_rvf_round_trip() {
    // Build a container holding one of each basic segment type.
    let mut builder = RvfBuilder::new();
    builder.add_manifest("vital-signs-v1", "0.1.0", "Vital sign detection model");

    let weights: Vec<f32> = (0..100).map(|i| i as f32 * 0.01).collect();
    builder.add_weights(&weights);

    builder.add_metadata(&serde_json::json!({
        "training_epochs": 50,
        "loss": 0.023,
        "optimizer": "adam",
    }));

    let bytes = builder.build();
    assert!(!bytes.is_empty(), "container with data should not be empty");

    // Alignment invariant: every segment starts on a 64-byte boundary,
    // so the whole container is 64-byte sized.
    assert_eq!(
        bytes.len() % 64,
        0,
        "total size should be a multiple of 64 bytes"
    );

    let parsed = RvfReader::from_bytes(&bytes).expect("should parse container");
    assert_eq!(parsed.segment_count(), 3);

    // Manifest fields survive the round trip.
    let manifest = parsed.manifest().expect("should have manifest");
    assert_eq!(manifest["model_id"], "vital-signs-v1");
    assert_eq!(manifest["version"], "0.1.0");
    assert_eq!(manifest["description"], "Vital sign detection model");

    // Weights survive bit-for-bit.
    let restored = parsed.weights().expect("should have weights");
    assert_eq!(restored.len(), weights.len());
    for (i, (&want, &got)) in weights.iter().zip(restored.iter()).enumerate() {
        assert_eq!(want.to_bits(), got.to_bits(), "weight[{i}] mismatch");
    }

    // Metadata survives.
    let meta = parsed.metadata().expect("should have metadata");
    assert_eq!(meta["training_epochs"], 50);
    assert_eq!(meta["optimizer"], "adam");
}
|
||||
|
||||
#[test]
fn test_rvf_segment_types() {
    // One segment of every kind, added in a fixed order.
    let mut builder = RvfBuilder::new();
    builder.add_manifest("test", "1.0", "test model");
    builder.add_weights(&[1.0, 2.0]);
    builder.add_metadata(&serde_json::json!({"key": "value"}));
    builder.add_witness("sha256:abc123", &serde_json::json!({"accuracy": 0.95}));

    let bytes = builder.build();
    let parsed = RvfReader::from_bytes(&bytes).expect("should parse");

    assert_eq!(parsed.segment_count(), 4);

    // Every segment type must be retrievable.
    assert!(parsed.manifest().is_some(), "manifest should be present");
    assert!(parsed.weights().is_some(), "weights should be present");
    assert!(parsed.metadata().is_some(), "metadata should be present");
    assert!(parsed.witness().is_some(), "witness should be present");

    // Segment IDs are assigned in insertion order.
    let ids: Vec<u64> = parsed.segments().map(|(h, _)| h.segment_id).collect();
    assert_eq!(ids, vec![0, 1, 2, 3], "segment IDs should be 0,1,2,3");
}
|
||||
|
||||
#[test]
fn test_rvf_magic_validation() {
    let mut builder = RvfBuilder::new();
    builder.add_manifest("test", "1.0", "test");
    let mut bytes = builder.build();

    // Stomp the 4 magic bytes at offset 0x00..0x04 of the first header.
    bytes[..4].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);

    let result = RvfReader::from_bytes(&bytes);
    assert!(result.is_err(), "corrupted magic should fail to parse");

    // The error message should identify the failing check.
    let err = result.unwrap_err();
    assert!(
        err.contains("magic"),
        "error message should mention 'magic', got: {}",
        err
    );
}
|
||||
|
||||
#[test]
fn test_rvf_weights_f32_precision() {
    // Awkward float32 values that must survive encoding bit-for-bit:
    // zeros (both signs), extremes, epsilon, constants, and a subnormal.
    let weights: Vec<f32> = vec![
        0.0,
        1.0,
        -1.0,
        f32::MIN_POSITIVE,
        f32::MAX,
        f32::MIN,
        f32::EPSILON,
        std::f32::consts::PI,
        std::f32::consts::E,
        1.0e-30,
        1.0e30,
        -0.0,
        0.123456789,
        1.0e-45, // subnormal
    ];

    let mut builder = RvfBuilder::new();
    builder.add_weights(&weights);
    let bytes = builder.build();

    let reader = RvfReader::from_bytes(&bytes).expect("should parse");
    let decoded = reader.weights().expect("should have weights");

    assert_eq!(decoded.len(), weights.len());
    // Compare raw bit patterns, not float equality (NaN/-0.0 safe).
    for (i, (&original, &parsed)) in weights.iter().zip(decoded.iter()).enumerate() {
        assert_eq!(
            original.to_bits(),
            parsed.to_bits(),
            "weight[{i}] bit-level mismatch: original={original} (0x{:08X}), parsed={parsed} (0x{:08X})",
            original.to_bits(),
            parsed.to_bits(),
        );
    }
}
|
||||
|
||||
#[test]
fn test_rvf_large_payload() {
    // 1 million f32 weights = 4 MB of payload data
    let num_weights = 1_000_000;
    let weights: Vec<f32> = (0..num_weights)
        .map(|i| (i as f32 * 0.000001).sin())
        .collect();

    let mut builder = RvfBuilder::new();
    builder.add_manifest("large-test", "1.0", "Large payload test");
    builder.add_weights(&weights);
    let bytes = builder.build();

    // Lower bound: at least one 64-byte header plus the raw weight bytes.
    assert!(
        bytes.len() >= 64 + num_weights * 4,
        "container should be large enough, got {} bytes",
        bytes.len()
    );

    let parsed = RvfReader::from_bytes(&bytes).expect("should parse large container");
    let decoded = parsed.weights().expect("should have weights");
    assert_eq!(
        decoded.len(),
        num_weights,
        "all 1M weights should round-trip"
    );

    // Spot-check a scattering of indices rather than all 1M values.
    for idx in [0, 1, 100, 1000, 500_000, 999_999] {
        assert_eq!(
            weights[idx].to_bits(),
            decoded[idx].to_bits(),
            "weight[{idx}] mismatch"
        );
    }
}
|
||||
|
||||
#[test]
fn test_rvf_multiple_metadata_segments() {
    // The current builder only stores one metadata segment, but we can add
    // multiple by adding metadata and then other segments to verify all coexist.
    let mut builder = RvfBuilder::new();
    builder.add_manifest("multi-meta", "1.0", "Multiple segment types");
    builder.add_metadata(&serde_json::json!({"training_config": {"optimizer": "adam"}}));
    builder.add_vital_config(&VitalSignConfig::default());
    builder.add_quant_info("int8", 0.0078125, -128);

    let bytes = builder.build();
    let parsed = RvfReader::from_bytes(&bytes).expect("should parse");

    assert_eq!(
        parsed.segment_count(),
        4,
        "should have 4 segments (manifest + meta + vital_config + quant)"
    );

    assert!(parsed.manifest().is_some());
    assert!(parsed.metadata().is_some());
    assert!(parsed.vital_config().is_some());
    assert!(parsed.quant_info().is_some());

    // Metadata content survives alongside the other segments.
    let meta = parsed.metadata().unwrap();
    assert_eq!(meta["training_config"]["optimizer"], "adam");
}
|
||||
|
||||
#[test]
fn test_rvf_file_io() {
    let tmp_dir = tempfile::tempdir().expect("should create temp dir");
    let file_path = tmp_dir.path().join("test_model.rvf");

    let weights: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5];

    let mut builder = RvfBuilder::new();
    builder.add_manifest("file-io-test", "1.0.0", "File I/O test model");
    builder.add_weights(&weights);
    builder.add_metadata(&serde_json::json!({"created": "2026-02-28"}));

    // Persist to disk, then load through the file-based constructor.
    builder
        .write_to_file(&file_path)
        .expect("should write to file");
    let parsed = RvfReader::from_file(&file_path).expect("should read from file");

    assert_eq!(parsed.segment_count(), 3);

    let manifest = parsed.manifest().expect("should have manifest");
    assert_eq!(manifest["model_id"], "file-io-test");

    // Weights survive bit-for-bit through the file round trip.
    let restored = parsed.weights().expect("should have weights");
    assert_eq!(restored.len(), weights.len());
    for (got, want) in restored.iter().zip(weights.iter()) {
        assert_eq!(got.to_bits(), want.to_bits());
    }

    let meta = parsed.metadata().expect("should have metadata");
    assert_eq!(meta["created"], "2026-02-28");

    // On-disk size must equal the in-memory serialization exactly.
    let in_memory = builder.build();
    let file_meta = std::fs::metadata(&file_path).expect("should stat file");
    assert_eq!(
        file_meta.len() as usize,
        in_memory.len(),
        "file size should match serialized size"
    );
}
|
||||
|
||||
#[test]
fn test_rvf_witness_proof() {
    let training_hash = "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
    let metrics = serde_json::json!({
        "accuracy": 0.957,
        "loss": 0.023,
        "epochs": 200,
        "dataset_size": 50000,
    });

    let mut builder = RvfBuilder::new();
    builder.add_manifest("witnessed-model", "2.0", "Model with witness proof");
    builder.add_weights(&[1.0, 2.0, 3.0]);
    builder.add_witness(training_hash, &metrics);

    let bytes = builder.build();
    let parsed = RvfReader::from_bytes(&bytes).expect("should parse");

    // Both the hash and the nested metrics object must round-trip.
    let witness = parsed.witness().expect("should have witness segment");
    assert_eq!(
        witness["training_hash"],
        training_hash,
        "training hash should round-trip"
    );
    assert_eq!(witness["metrics"]["accuracy"], 0.957);
    assert_eq!(witness["metrics"]["epochs"], 200);
}
|
||||
|
||||
#[test]
fn test_rvf_benchmark_write_read() {
    // Create a container with ~10 MB of weights
    let num_weights = 2_500_000; // 10 MB of f32 data
    let weights: Vec<f32> = (0..num_weights)
        .map(|i| (i as f32 * 0.0001).sin())
        .collect();

    let mut builder = RvfBuilder::new();
    builder.add_manifest("benchmark-model", "1.0", "Benchmark test");
    builder.add_weights(&weights);
    builder.add_metadata(&serde_json::json!({"benchmark": true}));

    // Time serialization.
    let write_start = std::time::Instant::now();
    let bytes = builder.build();
    let write_elapsed = write_start.elapsed();

    let size_mb = bytes.len() as f64 / (1024.0 * 1024.0);
    let write_speed = size_mb / write_elapsed.as_secs_f64();
    println!(
        "RVF write benchmark: {:.1} MB in {:.2}ms = {:.0} MB/s",
        size_mb,
        write_elapsed.as_secs_f64() * 1000.0,
        write_speed,
    );

    // Time deserialization (includes CRC validation).
    let read_start = std::time::Instant::now();
    let parsed = RvfReader::from_bytes(&bytes).expect("should parse benchmark container");
    let read_elapsed = read_start.elapsed();

    let read_speed = size_mb / read_elapsed.as_secs_f64();
    println!(
        "RVF read benchmark: {:.1} MB in {:.2}ms = {:.0} MB/s",
        size_mb,
        read_elapsed.as_secs_f64() * 1000.0,
        read_speed,
    );

    // Correctness spot-checks on the first and last weight.
    let decoded = parsed.weights().expect("should have weights");
    assert_eq!(decoded.len(), num_weights);
    assert_eq!(weights[0].to_bits(), decoded[0].to_bits());
    assert_eq!(
        weights[num_weights - 1].to_bits(),
        decoded[num_weights - 1].to_bits()
    );

    // Throughput floors: both directions should comfortably beat 10 MB/s.
    assert!(
        write_speed > 10.0,
        "write speed {:.0} MB/s is too slow",
        write_speed
    );
    assert!(
        read_speed > 10.0,
        "read speed {:.0} MB/s is too slow",
        read_speed
    );
}
|
||||
|
||||
#[test]
fn test_rvf_content_hash_integrity() {
    let mut builder = RvfBuilder::new();
    builder.add_metadata(&serde_json::json!({"integrity": "test"}));
    let mut bytes = builder.build();

    // Flip one payload byte (past the 64-byte header) and expect rejection.
    if bytes.len() > 65 {
        bytes[65] ^= 0xFF;
        let result = RvfReader::from_bytes(&bytes);
        assert!(
            result.is_err(),
            "corrupted payload should fail CRC32 hash check"
        );
        assert!(
            result.unwrap_err().contains("hash mismatch"),
            "error should mention hash mismatch"
        );
    }
}
||||
|
||||
#[test]
fn test_rvf_truncated_data() {
    let mut builder = RvfBuilder::new();
    builder.add_manifest("truncation-test", "1.0", "Truncation test");
    builder.add_weights(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    let full = builder.build();

    // Cut the container at assorted points around the 64-byte header boundary.
    for truncate_at in [0, 10, 32, 63, 64, 65, 80] {
        if truncate_at >= full.len() {
            continue;
        }
        let result = RvfReader::from_bytes(&full[..truncate_at]);
        if truncate_at < 64 {
            // Less than one header: reader returns 0 segments (no error on empty)
            // or fails if partial header data is present
            // The reader skips if offset + HEADER_SIZE > data.len()
            if truncate_at == 0 {
                let parsed_empty = matches!(&result, Ok(r) if r.segment_count() == 0);
                assert!(parsed_empty, "empty data should parse as 0 segments");
            }
        } else {
            // A complete header followed by a truncated payload must error.
            assert!(
                result.is_err(),
                "truncated at {truncate_at} bytes should fail"
            );
        }
    }
}
|
||||
|
||||
#[test]
fn test_rvf_empty_weights() {
    // A zero-length weight vector is still a segment and must round-trip.
    let mut builder = RvfBuilder::new();
    builder.add_weights(&[]);
    let bytes = builder.build();

    let parsed = RvfReader::from_bytes(&bytes).expect("should parse");
    let weights = parsed.weights().expect("should have weights segment");
    assert!(weights.is_empty(), "empty weight vector should round-trip");
}
|
||||
|
||||
#[test]
fn test_rvf_vital_config_round_trip() {
    // Non-default values in every field so defaults can't mask a bug.
    let config = VitalSignConfig {
        breathing_low_hz: 0.15,
        breathing_high_hz: 0.45,
        heartrate_low_hz: 0.9,
        heartrate_high_hz: 1.8,
        min_subcarriers: 64,
        window_size: 1024,
        confidence_threshold: 0.7,
    };

    let mut builder = RvfBuilder::new();
    builder.add_vital_config(&config);
    let bytes = builder.build();

    let parsed = RvfReader::from_bytes(&bytes).expect("should parse");
    let decoded = parsed.vital_config().expect("should have vital config");

    // Float fields compare to within machine epsilon; integers exactly.
    assert!(
        (decoded.breathing_low_hz - 0.15).abs() < f64::EPSILON,
        "breathing_low_hz mismatch"
    );
    assert!(
        (decoded.breathing_high_hz - 0.45).abs() < f64::EPSILON,
        "breathing_high_hz mismatch"
    );
    assert!(
        (decoded.heartrate_low_hz - 0.9).abs() < f64::EPSILON,
        "heartrate_low_hz mismatch"
    );
    assert!(
        (decoded.heartrate_high_hz - 1.8).abs() < f64::EPSILON,
        "heartrate_high_hz mismatch"
    );
    assert_eq!(decoded.min_subcarriers, 64);
    assert_eq!(decoded.window_size, 1024);
    assert!(
        (decoded.confidence_threshold - 0.7).abs() < f64::EPSILON,
        "confidence_threshold mismatch"
    );
}
|
||||
|
||||
#[test]
fn test_rvf_info_struct() {
    let mut builder = RvfBuilder::new();
    builder.add_manifest("info-test", "2.0", "Info struct test");
    builder.add_weights(&[1.0, 2.0, 3.0]);
    builder.add_vital_config(&VitalSignConfig::default());
    builder.add_witness("sha256:test", &serde_json::json!({"ok": true}));

    let bytes = builder.build();
    let parsed = RvfReader::from_bytes(&bytes).expect("should parse");

    // The summary struct mirrors exactly what was added (and nothing more).
    let info = parsed.info();
    assert_eq!(info.segment_count, 4);
    assert!(info.total_size > 0);
    assert!(info.manifest.is_some());
    assert!(info.has_weights);
    assert!(info.has_vital_config);
    assert!(info.has_witness);
    assert!(!info.has_quant_info, "no quant segment was added");
}
|
||||
|
||||
#[test]
fn test_rvf_alignment_invariant() {
    // Total container size must be a multiple of 64 for any payload size,
    // including sizes straddling the alignment boundary (255/256).
    for num_weights in [0, 1, 10, 100, 255, 256, 1000] {
        let weights: Vec<f32> = (0..num_weights).map(|i| i as f32).collect();
        let mut builder = RvfBuilder::new();
        builder.add_weights(&weights);
        let bytes = builder.build();

        assert_eq!(
            bytes.len() % 64,
            0,
            "container with {num_weights} weights should be 64-byte aligned, got {} bytes",
            bytes.len()
        );
    }
}
|
||||
@@ -0,0 +1,645 @@
|
||||
//! Comprehensive integration tests for the vital sign detection module.
|
||||
//!
|
||||
//! These tests exercise the public VitalSignDetector API by feeding
|
||||
//! synthetic CSI frames (amplitude + phase vectors) and verifying the
|
||||
//! extracted breathing rate, heart rate, confidence, and signal quality.
|
||||
//!
|
||||
//! Test matrix:
|
||||
//! - Detector creation and sane defaults
|
||||
//! - Breathing rate detection from synthetic 0.25 Hz (15 BPM) sine
|
||||
//! - Heartbeat detection from synthetic 1.2 Hz (72 BPM) sine
|
||||
//! - Combined breathing + heartbeat detection
|
||||
//! - No-signal (constant amplitude) returns None or low confidence
|
||||
//! - Out-of-range frequencies are rejected or produce low confidence
|
||||
//! - Confidence increases with signal-to-noise ratio
|
||||
//! - Reset clears all internal buffers
|
||||
//! - Minimum samples threshold
|
||||
//! - Throughput benchmark (10000 frames)
|
||||
|
||||
use std::f64::consts::PI;
|
||||
use wifi_densepose_sensing_server::vital_signs::{VitalSignDetector, VitalSigns};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const N_SUBCARRIERS: usize = 56;

/// Generate a single CSI frame's amplitude vector with an embedded
/// breathing-band sine wave at `freq_hz` Hz.
///
/// Returns `N_SUBCARRIERS` amplitudes: a fixed per-subcarrier baseline
/// (15 +/- 5, varying with subcarrier index) plus the breathing modulation
/// sampled at time `t` seconds.
fn make_breathing_frame(freq_hz: f64, t: f64) -> Vec<f64> {
    // The breathing term does not depend on the subcarrier index,
    // so compute it once and add it to every baseline.
    let modulation = 2.0 * (2.0 * PI * freq_hz * t).sin();
    let mut amplitudes = Vec::with_capacity(N_SUBCARRIERS);
    for idx in 0..N_SUBCARRIERS {
        let baseline = 15.0 + 5.0 * (idx as f64 * 0.1).sin();
        amplitudes.push(baseline + modulation);
    }
    amplitudes
}
|
||||
|
||||
/// Generate a phase vector that produces a phase-variance signal oscillating
|
||||
/// at `freq_hz` Hz.
|
||||
///
|
||||
/// The heartbeat detector uses cross-subcarrier phase variance as its input
|
||||
/// feature. To produce variance that oscillates at freq_hz, we modulate the
|
||||
/// spread of phases across subcarriers at that frequency.
|
||||
fn make_heartbeat_phase_variance(freq_hz: f64, t: f64) -> Vec<f64> {
|
||||
// Modulation factor: variance peaks when modulation is high
|
||||
let modulation = 0.5 * (1.0 + (2.0 * PI * freq_hz * t).sin());
|
||||
(0..N_SUBCARRIERS)
|
||||
.map(|i| {
|
||||
// Each subcarrier gets a different phase offset, scaled by modulation
|
||||
let base = (i as f64 * 0.2).sin();
|
||||
base * modulation
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate constant-phase vector (no heartbeat signal).
|
||||
fn make_static_phase() -> Vec<f64> {
|
||||
(0..N_SUBCARRIERS)
|
||||
.map(|i| (i as f64 * 0.2).sin())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Feed `n_frames` of synthetic breathing data to a detector.
|
||||
fn feed_breathing_signal(
|
||||
detector: &mut VitalSignDetector,
|
||||
freq_hz: f64,
|
||||
sample_rate: f64,
|
||||
n_frames: usize,
|
||||
) -> VitalSigns {
|
||||
let phase = make_static_phase();
|
||||
let mut vitals = VitalSigns::default();
|
||||
for frame in 0..n_frames {
|
||||
let t = frame as f64 / sample_rate;
|
||||
let amp = make_breathing_frame(freq_hz, t);
|
||||
vitals = detector.process_frame(&, &phase);
|
||||
}
|
||||
vitals
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
fn test_vital_detector_creation() {
    // A fresh detector must start with empty buffers whose capacities are
    // derived from the sample rate and the fixed analysis windows.
    let detector = VitalSignDetector::new(20.0);
    let (br_len, br_cap, hb_len, hb_cap) = detector.buffer_status();

    assert_eq!(br_len, 0, "breathing buffer should start empty");
    assert_eq!(hb_len, 0, "heartbeat buffer should start empty");
    assert!(br_cap > 0, "breathing capacity should be positive");
    assert!(hb_cap > 0, "heartbeat capacity should be positive");

    // 20 Hz * 30 s breathing window = 600 samples;
    // 20 Hz * 15 s heartbeat window = 300 samples.
    assert_eq!(br_cap, 600, "breathing capacity at 20 Hz * 30s = 600");
    assert_eq!(hb_cap, 300, "heartbeat capacity at 20 Hz * 15s = 300");
}
|
||||
|
||||
#[test]
fn test_breathing_detection_synthetic() {
    // 30 seconds of a clean 0.25 Hz (15 BPM) breathing sine must yield a
    // detection near 15 BPM with non-zero confidence.
    let sample_rate = 20.0;
    let mut detector = VitalSignDetector::new(sample_rate);

    let n_frames = (sample_rate * 30.0) as usize; // 600 frames
    let vitals = feed_breathing_signal(&mut detector, 0.25, sample_rate, n_frames);

    let bpm = vitals
        .breathing_rate_bpm
        .expect("should detect breathing rate from 0.25 Hz sine");

    // FFT bin resolution at 20 Hz over 600 samples allows +/- 3 BPM slack.
    let expected_bpm = 15.0;
    assert!(
        (bpm - expected_bpm).abs() < 3.0,
        "breathing rate {:.1} BPM should be close to {:.1} BPM",
        bpm,
        expected_bpm,
    );

    assert!(
        vitals.breathing_confidence > 0.0,
        "breathing confidence should be > 0, got {}",
        vitals.breathing_confidence,
    );
}
|
||||
|
||||
#[test]
fn test_heartbeat_detection_synthetic() {
    // Feeds a 1.2 Hz (72 BPM) oscillation into the phase-variance channel
    // with a static amplitude, and checks the detector's output is sane.
    let sample_rate = 20.0;
    let heartbeat_freq = 1.2; // 72 BPM
    let mut detector = VitalSignDetector::new(sample_rate);

    // Feed 15 seconds of data with heartbeat signal in the phase variance
    let n_frames = (sample_rate * 15.0) as usize;

    // Static amplitude -- no breathing signal
    let amp: Vec<f64> = (0..N_SUBCARRIERS)
        .map(|i| 15.0 + 5.0 * (i as f64 * 0.1).sin())
        .collect();

    let mut vitals = VitalSigns::default();
    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;
        let phase = make_heartbeat_phase_variance(heartbeat_freq, t);
        vitals = detector.process_frame(&amp, &phase);
    }

    // Heart rate detection from phase variance is more challenging.
    // We verify that if a heart rate is detected, it's in the valid
    // physiological range (40-120 BPM).
    if let Some(bpm) = vitals.heart_rate_bpm {
        assert!(
            (40.0..=120.0).contains(&bpm),
            "detected heart rate {:.1} BPM should be in physiological range [40, 120]",
            bpm
        );
    }

    // At minimum, heartbeat confidence should be non-negative
    assert!(
        vitals.heartbeat_confidence >= 0.0,
        "heartbeat confidence should be >= 0"
    );
}
|
||||
|
||||
#[test]
fn test_combined_vital_signs() {
    // Breathing in the amplitude channel and heartbeat in the phase-variance
    // channel simultaneously; both extractions must remain coherent.
    let sample_rate = 20.0;
    let breathing_freq = 0.25; // 15 BPM
    let heartbeat_freq = 1.2; // 72 BPM
    let mut detector = VitalSignDetector::new(sample_rate);

    // Feed 30 seconds with both signals
    let n_frames = (sample_rate * 30.0) as usize;
    let mut vitals = VitalSigns::default();
    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;

        // Amplitude carries breathing modulation
        let amp = make_breathing_frame(breathing_freq, t);

        // Phase carries heartbeat modulation (via variance)
        let phase = make_heartbeat_phase_variance(heartbeat_freq, t);

        vitals = detector.process_frame(&amp, &phase);
    }

    // Breathing should be detected accurately
    let breathing_bpm = vitals
        .breathing_rate_bpm
        .expect("should detect breathing in combined signal");
    assert!(
        (breathing_bpm - 15.0).abs() < 3.0,
        "breathing {:.1} BPM should be close to 15 BPM",
        breathing_bpm
    );

    // Heartbeat: verify it's in the valid range if detected
    if let Some(hb_bpm) = vitals.heart_rate_bpm {
        assert!(
            (40.0..=120.0).contains(&hb_bpm),
            "heartbeat {:.1} BPM should be in range [40, 120]",
            hb_bpm
        );
    }
}
|
||||
|
||||
#[test]
fn test_no_signal_lower_confidence_than_true_signal() {
    // A flat amplitude stream must never look more confident than a genuine
    // periodic breathing signal.
    let sample_rate = 20.0;
    let n_frames = (sample_rate * 30.0) as usize;

    // Detector A: constant amplitude (no real breathing signal)
    let mut detector_flat = VitalSignDetector::new(sample_rate);
    let amp_flat = vec![50.0; N_SUBCARRIERS];
    let phase = vec![0.0; N_SUBCARRIERS];
    for _ in 0..n_frames {
        detector_flat.process_frame(&amp_flat, &phase);
    }
    let (_, flat_conf) = detector_flat.extract_breathing();

    // Detector B: clear 0.25 Hz breathing signal
    let mut detector_signal = VitalSignDetector::new(sample_rate);
    let phase_b = make_static_phase();
    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;
        let amp = make_breathing_frame(0.25, t);
        detector_signal.process_frame(&amp, &phase_b);
    }
    let (signal_rate, signal_conf) = detector_signal.extract_breathing();

    // The real signal should be detected
    assert!(
        signal_rate.is_some(),
        "true breathing signal should be detected"
    );

    // The real signal should have higher confidence than the flat signal.
    // Note: the bandpass filter creates transient artifacts on flat signals
    // that may produce non-zero confidence, but a true periodic signal should
    // always produce a stronger spectral peak.
    assert!(
        signal_conf >= flat_conf,
        "true signal confidence ({:.3}) should be >= flat signal confidence ({:.3})",
        signal_conf,
        flat_conf,
    );
}
|
||||
|
||||
#[test]
fn test_out_of_range_lower_confidence_than_in_band() {
    // A 5 Hz oscillation (outside the breathing band) must not out-score an
    // in-band 0.25 Hz breathing signal.
    let sample_rate = 20.0;
    let n_frames = (sample_rate * 30.0) as usize;
    let phase = make_static_phase();

    // Detector A: 5 Hz amplitude oscillation (outside breathing band)
    let mut detector_oob = VitalSignDetector::new(sample_rate);
    let out_of_band_freq = 5.0;
    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;
        let amp: Vec<f64> = (0..N_SUBCARRIERS)
            .map(|i| {
                let base = 15.0 + 5.0 * (i as f64 * 0.1).sin();
                base + 2.0 * (2.0 * PI * out_of_band_freq * t).sin()
            })
            .collect();
        detector_oob.process_frame(&amp, &phase);
    }
    let (_, oob_conf) = detector_oob.extract_breathing();

    // Detector B: 0.25 Hz amplitude oscillation (inside breathing band)
    let mut detector_inband = VitalSignDetector::new(sample_rate);
    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;
        let amp = make_breathing_frame(0.25, t);
        detector_inband.process_frame(&amp, &phase);
    }
    let (inband_rate, inband_conf) = detector_inband.extract_breathing();

    // The in-band signal should be detected
    assert!(
        inband_rate.is_some(),
        "in-band 0.25 Hz signal should be detected as breathing"
    );

    // The in-band signal should have higher confidence than the out-of-band one.
    // The bandpass filter may leak some energy from 5 Hz harmonics, but a true
    // 0.25 Hz signal should always dominate.
    assert!(
        inband_conf >= oob_conf,
        "in-band confidence ({:.3}) should be >= out-of-band confidence ({:.3})",
        inband_conf,
        oob_conf,
    );
}
|
||||
|
||||
#[test]
fn test_confidence_increases_with_snr() {
    // Confidence must be monotone with SNR: a strong clean breathing sine
    // should out-score a weak sine buried in broadband noise.
    let sample_rate = 20.0;
    let breathing_freq = 0.25;
    let n_frames = (sample_rate * 30.0) as usize;

    // High SNR: large breathing amplitude, no noise
    let mut detector_clean = VitalSignDetector::new(sample_rate);
    let phase = make_static_phase();

    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;
        let amp: Vec<f64> = (0..N_SUBCARRIERS)
            .map(|i| {
                let base = 15.0 + 5.0 * (i as f64 * 0.1).sin();
                // Strong breathing signal (amplitude 5.0)
                base + 5.0 * (2.0 * PI * breathing_freq * t).sin()
            })
            .collect();
        detector_clean.process_frame(&amp, &phase);
    }
    let (_, clean_conf) = detector_clean.extract_breathing();

    // Low SNR: small breathing amplitude, lots of noise
    let mut detector_noisy = VitalSignDetector::new(sample_rate);
    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;
        let amp: Vec<f64> = (0..N_SUBCARRIERS)
            .map(|i| {
                let base = 15.0 + 5.0 * (i as f64 * 0.1).sin();
                // Weak breathing signal (amplitude 0.1) + heavy noise
                // (deterministic pseudo-noise so the test is reproducible)
                let noise = 3.0
                    * ((i as f64 * 7.3 + t * 113.7).sin()
                        + (i as f64 * 13.1 + t * 79.3).sin())
                    / 2.0;
                base + 0.1 * (2.0 * PI * breathing_freq * t).sin() + noise
            })
            .collect();
        detector_noisy.process_frame(&amp, &phase);
    }
    let (_, noisy_conf) = detector_noisy.extract_breathing();

    assert!(
        clean_conf > noisy_conf,
        "clean signal confidence ({:.3}) should exceed noisy signal confidence ({:.3})",
        clean_conf,
        noisy_conf,
    );
}
|
||||
|
||||
#[test]
fn test_reset_clears_buffers() {
    // reset() must empty both internal buffers and make extraction return
    // None again (below the minimum-sample threshold).
    let mut detector = VitalSignDetector::new(20.0);
    let amp = vec![10.0; N_SUBCARRIERS];
    let phase = vec![0.0; N_SUBCARRIERS];

    // Feed some frames to fill buffers
    for _ in 0..100 {
        detector.process_frame(&amp, &phase);
    }

    let (br_len, _, hb_len, _) = detector.buffer_status();
    assert!(br_len > 0, "breathing buffer should have data before reset");
    assert!(hb_len > 0, "heartbeat buffer should have data before reset");

    // Reset
    detector.reset();

    let (br_len, _, hb_len, _) = detector.buffer_status();
    assert_eq!(br_len, 0, "breathing buffer should be empty after reset");
    assert_eq!(hb_len, 0, "heartbeat buffer should be empty after reset");

    // Extraction should return None after reset
    let (breathing, _) = detector.extract_breathing();
    let (heartbeat, _) = detector.extract_heartbeat();
    assert!(
        breathing.is_none(),
        "breathing should be None after reset (not enough samples)"
    );
    assert!(
        heartbeat.is_none(),
        "heartbeat should be None after reset (not enough samples)"
    );
}
|
||||
|
||||
#[test]
fn test_minimum_samples_required() {
    // Extraction is gated on a minimum sample count (MIN_BREATHING_SAMPLES
    // = 40 per the module); 39 frames must yield None, 40 must be attempted.
    let sample_rate = 20.0;
    let mut detector = VitalSignDetector::new(sample_rate);
    let amp = vec![10.0; N_SUBCARRIERS];
    let phase = vec![0.0; N_SUBCARRIERS];

    // Feed fewer than MIN_BREATHING_SAMPLES (40) frames
    for _ in 0..39 {
        detector.process_frame(&amp, &phase);
    }

    let (breathing, _) = detector.extract_breathing();
    assert!(
        breathing.is_none(),
        "with 39 samples (< 40 min), breathing should return None"
    );

    // One more frame should meet the minimum
    detector.process_frame(&amp, &phase);

    let (br_len, _, _, _) = detector.buffer_status();
    assert_eq!(br_len, 40, "should have exactly 40 samples now");

    // Now extraction is at least attempted (may still be None if flat signal,
    // but should not be blocked by the min-samples check)
    let _ = detector.extract_breathing();
}
|
||||
|
||||
#[test]
fn test_benchmark_throughput() {
    // Throughput smoke test: the detector must sustain > 100 frames/sec.
    let sample_rate = 20.0;
    let mut detector = VitalSignDetector::new(sample_rate);

    let num_frames = 10_000;

    // Synthesize every input frame up front so the timed section measures
    // detector work only, not signal generation.
    let frames: Vec<(Vec<f64>, Vec<f64>)> = (0..num_frames)
        .map(|tick| {
            let t = tick as f64 / sample_rate;
            let amp: Vec<f64> = (0..N_SUBCARRIERS)
                .map(|i| {
                    let base = 15.0 + 5.0 * (i as f64 * 0.1).sin();
                    let breathing = 2.0 * (2.0 * PI * 0.25 * t).sin();
                    let heartbeat = 0.3 * (2.0 * PI * 1.2 * t).sin();
                    let noise = (i as f64 * 7.3 + t * 13.7).sin() * 0.5;
                    base + breathing + heartbeat + noise
                })
                .collect();
            let phase: Vec<f64> = (0..N_SUBCARRIERS)
                .map(|i| (i as f64 * 0.2 + t * 0.5).sin() * PI)
                .collect();
            (amp, phase)
        })
        .collect();

    let start = std::time::Instant::now();
    for (amp, phase) in &frames {
        detector.process_frame(amp, phase);
    }
    let elapsed = start.elapsed();
    let fps = num_frames as f64 / elapsed.as_secs_f64();

    println!(
        "Vital sign benchmark: {} frames in {:.2}ms = {:.0} frames/sec",
        num_frames,
        elapsed.as_secs_f64() * 1000.0,
        fps
    );

    // Should process at least 100 frames/sec on any reasonable hardware
    assert!(
        fps > 100.0,
        "throughput {:.0} fps is too low (expected > 100 fps)",
        fps,
    );
}
|
||||
|
||||
#[test]
fn test_vital_signs_default() {
    // Default-constructed vitals carry no detections and zeroed metrics.
    let defaults = VitalSigns::default();
    assert!(defaults.breathing_rate_bpm.is_none());
    assert!(defaults.heart_rate_bpm.is_none());
    assert_eq!(defaults.breathing_confidence, 0.0);
    assert_eq!(defaults.heartbeat_confidence, 0.0);
    assert_eq!(defaults.signal_quality, 0.0);
}
|
||||
|
||||
#[test]
fn test_empty_amplitude_frame() {
    // Zero-length input slices must be handled gracefully: no detections,
    // zero quality, and above all no panic.
    let mut detector = VitalSignDetector::new(20.0);
    let result = detector.process_frame(&[], &[]);

    assert!(result.breathing_rate_bpm.is_none());
    assert!(result.heart_rate_bpm.is_none());
    assert_eq!(result.signal_quality, 0.0);
}
|
||||
|
||||
#[test]
fn test_single_subcarrier_no_panic() {
    // Degenerate input: a single subcarrier must not crash the detector.
    let mut detector = VitalSignDetector::new(20.0);

    // The phase vector never changes; build it once outside the loop.
    let phase = vec![0.0];
    for i in 0..100 {
        let t = i as f64 / 20.0;
        let amp = vec![10.0 + (2.0 * PI * 0.25 * t).sin()];
        let _ = detector.process_frame(&amp, &phase);
    }
}
|
||||
|
||||
#[test]
fn test_signal_quality_varies_with_input() {
    // A signal with cross-subcarrier/time variation (body motion) should
    // score at least as high a signal quality as a perfectly static one.
    let mut detector_static = VitalSignDetector::new(20.0);
    let mut detector_varied = VitalSignDetector::new(20.0);

    // Feed static signal (all same amplitude). The frames are identical,
    // so allocate them once outside the loop.
    let amp_static = vec![10.0; N_SUBCARRIERS];
    let phase_static = vec![0.0; N_SUBCARRIERS];
    for _ in 0..100 {
        detector_static.process_frame(&amp_static, &phase_static);
    }

    // Feed varied signal (moderate CV -- body motion)
    for i in 0..100 {
        let t = i as f64 / 20.0;
        let amp: Vec<f64> = (0..N_SUBCARRIERS)
            .map(|j| {
                let base = 15.0;
                let modulation = 2.0 * (2.0 * PI * 0.25 * t + j as f64 * 0.1).sin();
                base + modulation
            })
            .collect();
        let phase: Vec<f64> = (0..N_SUBCARRIERS)
            .map(|j| (j as f64 * 0.2 + t).sin())
            .collect();
        detector_varied.process_frame(&amp, &phase);
    }

    // One final frame of each style to read out the current quality.
    let static_vitals = detector_static.process_frame(&amp_static, &phase_static);
    let amp_varied: Vec<f64> = (0..N_SUBCARRIERS)
        .map(|j| 15.0 + 2.0 * (j as f64 * 0.3).sin())
        .collect();
    let phase_varied: Vec<f64> = (0..N_SUBCARRIERS).map(|j| (j as f64 * 0.2).sin()).collect();
    let varied_vitals = detector_varied.process_frame(&amp_varied, &phase_varied);

    assert!(
        varied_vitals.signal_quality >= static_vitals.signal_quality,
        "varied signal quality ({:.3}) should be >= static ({:.3})",
        varied_vitals.signal_quality,
        static_vitals.signal_quality,
    );
}
|
||||
|
||||
#[test]
fn test_buffer_capacity_respected() {
    // Overfilling the detector (1000 frames > 600-sample breathing window)
    // must never push buffer lengths past their capacities.
    let sample_rate = 20.0;
    let mut detector = VitalSignDetector::new(sample_rate);

    let amp = vec![10.0; N_SUBCARRIERS];
    let phase = vec![0.0; N_SUBCARRIERS];

    // Feed more frames than breathing capacity (600)
    for _ in 0..1000 {
        detector.process_frame(&amp, &phase);
    }

    let (br_len, br_cap, hb_len, hb_cap) = detector.buffer_status();
    assert!(
        br_len <= br_cap,
        "breathing buffer length {} should not exceed capacity {}",
        br_len,
        br_cap
    );
    assert!(
        hb_len <= hb_cap,
        "heartbeat buffer length {} should not exceed capacity {}",
        hb_len,
        hb_cap
    );
}
|
||||
|
||||
#[test]
fn test_run_benchmark_function() {
    use wifi_densepose_sensing_server::vital_signs::run_benchmark;

    // Even a tiny 50-frame run must report strictly positive timings.
    let (total, per_frame) = run_benchmark(50);
    assert!(total.as_nanos() > 0, "benchmark total duration should be > 0");
    assert!(
        per_frame.as_nanos() > 0,
        "benchmark per-frame duration should be > 0"
    );
}
|
||||
|
||||
#[test]
fn test_breathing_rate_in_physiological_range() {
    // If breathing is detected, it must always be in the physiological range
    // (6-30 BPM = 0.1-0.5 Hz)
    let sample_rate = 20.0;
    let mut detector = VitalSignDetector::new(sample_rate);
    let n_frames = (sample_rate * 30.0) as usize;

    // The static phase never changes; build it once instead of per frame.
    let phase = make_static_phase();
    let mut vitals = VitalSigns::default();
    for frame in 0..n_frames {
        let t = frame as f64 / sample_rate;
        let amp = make_breathing_frame(0.3, t); // 18 BPM
        vitals = detector.process_frame(&amp, &phase);
    }

    if let Some(bpm) = vitals.breathing_rate_bpm {
        assert!(
            (6.0..=30.0).contains(&bpm),
            "breathing rate {:.1} BPM must be in range [6, 30]",
            bpm
        );
    }
}
|
||||
|
||||
#[test]
fn test_multiple_detectors_independent() {
    // Two detectors should not interfere with each other: feeding them
    // distinct breathing frequencies must yield distinct rate estimates.
    let sample_rate = 20.0;
    let mut detector_a = VitalSignDetector::new(sample_rate);
    let mut detector_b = VitalSignDetector::new(sample_rate);

    let phase = make_static_phase();

    // Feed different breathing rates
    for frame in 0..(sample_rate * 30.0) as usize {
        let t = frame as f64 / sample_rate;
        let amp_a = make_breathing_frame(0.2, t); // 12 BPM
        let amp_b = make_breathing_frame(0.4, t); // 24 BPM
        detector_a.process_frame(&amp_a, &phase);
        detector_b.process_frame(&amp_b, &phase);
    }

    let (rate_a, _) = detector_a.extract_breathing();
    let (rate_b, _) = detector_b.extract_breathing();

    if let (Some(a), Some(b)) = (rate_a, rate_b) {
        // They should detect different rates
        assert!(
            (a - b).abs() > 2.0,
            "detector A ({:.1} BPM) and B ({:.1} BPM) should detect different rates",
            a,
            b
        );
    }
}
|
||||
@@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "wifi-densepose-vitals"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description = "ESP32 CSI-grade vital sign extraction (ADR-021): heart rate and respiratory rate from WiFi Channel State Information"
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tracing.workspace = true
|
||||
serde = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["serde"]
|
||||
serde = ["dep:serde"]
|
||||
|
||||
[lints.rust]
|
||||
unsafe_code = "forbid"
|
||||
|
||||
[lints.clippy]
|
||||
all = "warn"
|
||||
pedantic = "warn"
|
||||
doc_markdown = "allow"
|
||||
module_name_repetitions = "allow"
|
||||
must_use_candidate = "allow"
|
||||
missing_errors_doc = "allow"
|
||||
missing_panics_doc = "allow"
|
||||
cast_precision_loss = "allow"
|
||||
cast_lossless = "allow"
|
||||
cast_possible_truncation = "allow"
|
||||
cast_sign_loss = "allow"
|
||||
many_single_char_names = "allow"
|
||||
uninlined_format_args = "allow"
|
||||
assigning_clones = "allow"
|
||||
@@ -0,0 +1,399 @@
|
||||
//! Vital sign anomaly detection.
|
||||
//!
|
||||
//! Monitors vital sign readings for anomalies (apnea, tachycardia,
|
||||
//! bradycardia, sudden changes) using z-score detection with
|
||||
//! running mean and standard deviation.
|
||||
//!
|
||||
//! Modeled on the DNA biomarker anomaly detection pattern from
|
||||
//! `vendor/ruvector/examples/dna`, using Welford's online algorithm
|
||||
//! for numerically stable running statistics.
|
||||
|
||||
use crate::types::VitalReading;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// An anomaly alert generated from vital sign analysis.
///
/// Alerts are plain data: the detector emits them and callers decide how to
/// surface, log, or aggregate them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AnomalyAlert {
    /// Type of vital sign: `"respiratory"` or `"cardiac"`.
    pub vital_type: String,
    /// Type of anomaly: `"apnea"`, `"tachypnea"`, `"bradypnea"`,
    /// `"tachycardia"`, `"bradycardia"`, `"sudden_change"`.
    pub alert_type: String,
    /// Severity [0.0, 1.0].
    pub severity: f64,
    /// Human-readable description.
    pub message: String,
}
|
||||
|
||||
/// Welford online statistics accumulator.
///
/// Tracks sample count, running mean, and the sum of squared deviations
/// (M2), allowing mean/variance updates in O(1) per sample with good
/// numerical stability.
#[derive(Debug, Clone)]
struct WelfordStats {
    count: u64,
    mean: f64,
    m2: f64,
}

impl WelfordStats {
    fn new() -> Self {
        Self { count: 0, mean: 0.0, m2: 0.0 }
    }

    /// Fold one observation into the running statistics.
    fn update(&mut self, value: f64) {
        self.count += 1;
        let dev_before = value - self.mean;
        self.mean += dev_before / self.count as f64;
        // Product of the deviations from the old and the updated mean is
        // the standard Welford M2 increment.
        let dev_after = value - self.mean;
        self.m2 += dev_before * dev_after;
    }

    /// Sample (Bessel-corrected) variance; 0.0 until two samples exist.
    fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Z-score of `value` against the running distribution; 0.0 when the
    /// spread is (near) zero to avoid dividing by ~0.
    fn z_score(&self, value: f64) -> f64 {
        let sd = self.std_dev();
        if sd < 1e-10 {
            0.0
        } else {
            (value - self.mean) / sd
        }
    }
}
|
||||
|
||||
/// Vital sign anomaly detector using z-score analysis with
/// running statistics.
///
/// Combines fixed clinical thresholds with per-stream Welford statistics
/// so alerts fire both on absolute out-of-range values and on statistically
/// sudden changes relative to the subject's own baseline.
pub struct VitalAnomalyDetector {
    /// Running statistics for respiratory rate.
    rr_stats: WelfordStats,
    /// Running statistics for heart rate.
    hr_stats: WelfordStats,
    /// Recent respiratory rate values for windowed analysis.
    rr_history: Vec<f64>,
    /// Recent heart rate values for windowed analysis.
    hr_history: Vec<f64>,
    /// Maximum window size for history.
    window: usize,
    /// Z-score threshold for anomaly detection.
    z_threshold: f64,
}
|
||||
|
||||
impl VitalAnomalyDetector {
    /// Create a new anomaly detector.
    ///
    /// - `window`: number of recent readings to retain.
    /// - `z_threshold`: z-score threshold for anomaly alerts (default: 2.5).
    #[must_use]
    pub fn new(window: usize, z_threshold: f64) -> Self {
        Self {
            rr_stats: WelfordStats::new(),
            hr_stats: WelfordStats::new(),
            rr_history: Vec::with_capacity(window),
            hr_history: Vec::with_capacity(window),
            window,
            z_threshold,
        }
    }

    /// Create with defaults (window = 60, z_threshold = 2.5).
    #[must_use]
    pub fn default_config() -> Self {
        Self::new(60, 2.5)
    }

    /// Check a vital sign reading for anomalies.
    ///
    /// Updates running statistics and returns a list of detected
    /// anomaly alerts (may be empty if all readings are normal).
    ///
    /// Alert order within the returned vector: respiratory clinical alert
    /// (at most one of apnea/tachypnea/bradypnea), respiratory sudden
    /// change, cardiac clinical alert, cardiac sudden change.
    pub fn check(&mut self, reading: &VitalReading) -> Vec<AnomalyAlert> {
        let mut alerts = Vec::new();

        let rr = reading.respiratory_rate.value_bpm;
        let hr = reading.heart_rate.value_bpm;

        // Update histories
        // NOTE(review): remove(0) is O(window); fine for window ~60, but a
        // VecDeque would make this O(1) if the window grows.
        self.rr_history.push(rr);
        if self.rr_history.len() > self.window {
            self.rr_history.remove(0);
        }
        self.hr_history.push(hr);
        if self.hr_history.len() > self.window {
            self.hr_history.remove(0);
        }

        // Update running statistics
        // NOTE(review): the current reading is folded into the stats before
        // its z-score is computed below, which slightly dampens its own
        // z-score — confirm this ordering is intentional.
        self.rr_stats.update(rr);
        self.hr_stats.update(hr);

        // Need at least a few readings before detecting anomalies
        if self.rr_stats.count < 5 {
            return alerts;
        }

        // --- Respiratory rate anomalies ---
        let rr_z = self.rr_stats.z_score(rr);

        // Clinical thresholds for respiratory rate (adult). Each clinical
        // alert only fires when the estimate itself is minimally trusted
        // (confidence > 0.3).
        if rr < 4.0 && reading.respiratory_rate.confidence > 0.3 {
            alerts.push(AnomalyAlert {
                vital_type: "respiratory".to_string(),
                alert_type: "apnea".to_string(),
                severity: 0.9,
                message: format!("Possible apnea detected: RR = {rr:.1} BPM"),
            });
        } else if rr > 30.0 && reading.respiratory_rate.confidence > 0.3 {
            alerts.push(AnomalyAlert {
                vital_type: "respiratory".to_string(),
                alert_type: "tachypnea".to_string(),
                // Severity scales with how far above 30 BPM, capped at 1.0.
                severity: ((rr - 30.0) / 20.0).clamp(0.3, 1.0),
                message: format!("Elevated respiratory rate: RR = {rr:.1} BPM"),
            });
        } else if rr < 8.0 && reading.respiratory_rate.confidence > 0.3 {
            alerts.push(AnomalyAlert {
                vital_type: "respiratory".to_string(),
                alert_type: "bradypnea".to_string(),
                severity: ((8.0 - rr) / 8.0).clamp(0.3, 0.8),
                message: format!("Low respiratory rate: RR = {rr:.1} BPM"),
            });
        }

        // Z-score based sudden change detection for RR
        if rr_z.abs() > self.z_threshold {
            alerts.push(AnomalyAlert {
                vital_type: "respiratory".to_string(),
                alert_type: "sudden_change".to_string(),
                severity: (rr_z.abs() / (self.z_threshold * 2.0)).clamp(0.2, 1.0),
                message: format!(
                    "Sudden respiratory rate change: z-score = {rr_z:.2} (RR = {rr:.1} BPM)"
                ),
            });
        }

        // --- Heart rate anomalies ---
        let hr_z = self.hr_stats.z_score(hr);

        if hr > 100.0 && reading.heart_rate.confidence > 0.3 {
            alerts.push(AnomalyAlert {
                vital_type: "cardiac".to_string(),
                alert_type: "tachycardia".to_string(),
                severity: ((hr - 100.0) / 80.0).clamp(0.3, 1.0),
                message: format!("Elevated heart rate: HR = {hr:.1} BPM"),
            });
        } else if hr < 50.0 && reading.heart_rate.confidence > 0.3 {
            alerts.push(AnomalyAlert {
                vital_type: "cardiac".to_string(),
                alert_type: "bradycardia".to_string(),
                severity: ((50.0 - hr) / 30.0).clamp(0.3, 1.0),
                message: format!("Low heart rate: HR = {hr:.1} BPM"),
            });
        }

        // Z-score based sudden change detection for HR
        if hr_z.abs() > self.z_threshold {
            alerts.push(AnomalyAlert {
                vital_type: "cardiac".to_string(),
                alert_type: "sudden_change".to_string(),
                severity: (hr_z.abs() / (self.z_threshold * 2.0)).clamp(0.2, 1.0),
                message: format!(
                    "Sudden heart rate change: z-score = {hr_z:.2} (HR = {hr:.1} BPM)"
                ),
            });
        }

        alerts
    }

    /// Reset all accumulated statistics and history.
    pub fn reset(&mut self) {
        self.rr_stats = WelfordStats::new();
        self.hr_stats = WelfordStats::new();
        self.rr_history.clear();
        self.hr_history.clear();
    }

    /// Number of readings processed so far.
    #[must_use]
    pub fn reading_count(&self) -> u64 {
        self.rr_stats.count
    }

    /// Current running mean for respiratory rate.
    #[must_use]
    pub fn rr_mean(&self) -> f64 {
        self.rr_stats.mean
    }

    /// Current running mean for heart rate.
    #[must_use]
    pub fn hr_mean(&self) -> f64 {
        self.hr_stats.mean
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{VitalEstimate, VitalReading, VitalStatus};
|
||||
|
||||
/// Build a `VitalReading` with the given respiratory/heart rates and
/// fixed, trusted metadata (confidence 0.8 > the detector's 0.3 gate,
/// quality 0.9) so tests exercise the rate thresholds, not the gating.
fn make_reading(rr_bpm: f64, hr_bpm: f64) -> VitalReading {
    VitalReading {
        respiratory_rate: VitalEstimate {
            value_bpm: rr_bpm,
            confidence: 0.8,
            status: VitalStatus::Valid,
        },
        heart_rate: VitalEstimate {
            value_bpm: hr_bpm,
            confidence: 0.8,
            status: VitalStatus::Valid,
        },
        subcarrier_count: 56,
        signal_quality: 0.9,
        timestamp_secs: 0.0,
    }
}
|
||||
|
||||
#[test]
fn no_alerts_for_normal_readings() {
    // A steady 15 BPM / 72 BPM stream must never raise an alert once the
    // warmup threshold (5 readings) has passed.
    let mut detector = VitalAnomalyDetector::new(30, 2.5);
    for _ in 0..20 {
        let alerts = detector.check(&make_reading(15.0, 72.0));
        let warmed_up = detector.reading_count() > 5;
        if warmed_up {
            assert!(alerts.is_empty(), "normal readings should not trigger alerts");
        }
    }
}
|
||||
|
||||
#[test]
fn detects_tachycardia() {
    // Establish a normal baseline, then inject 130 BPM — well above the
    // 100 BPM clinical threshold.
    let mut detector = VitalAnomalyDetector::new(30, 2.5);
    for _ in 0..10 {
        detector.check(&make_reading(15.0, 72.0));
    }
    let found = detector
        .check(&make_reading(15.0, 130.0))
        .iter()
        .any(|alert| alert.alert_type == "tachycardia");
    assert!(found, "should detect tachycardia at 130 BPM");
}
|
||||
|
||||
#[test]
|
||||
fn detects_bradycardia() {
|
||||
let mut det = VitalAnomalyDetector::new(30, 2.5);
|
||||
for _ in 0..10 {
|
||||
det.check(&make_reading(15.0, 72.0));
|
||||
}
|
||||
let alerts = det.check(&make_reading(15.0, 40.0));
|
||||
let brady = alerts.iter().any(|a| a.alert_type == "bradycardia");
|
||||
assert!(brady, "should detect bradycardia at 40 BPM");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_apnea() {
|
||||
let mut det = VitalAnomalyDetector::new(30, 2.5);
|
||||
for _ in 0..10 {
|
||||
det.check(&make_reading(15.0, 72.0));
|
||||
}
|
||||
let alerts = det.check(&make_reading(2.0, 72.0));
|
||||
let apnea = alerts.iter().any(|a| a.alert_type == "apnea");
|
||||
assert!(apnea, "should detect apnea at 2 BPM");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_tachypnea() {
|
||||
let mut det = VitalAnomalyDetector::new(30, 2.5);
|
||||
for _ in 0..10 {
|
||||
det.check(&make_reading(15.0, 72.0));
|
||||
}
|
||||
let alerts = det.check(&make_reading(35.0, 72.0));
|
||||
let tachypnea = alerts.iter().any(|a| a.alert_type == "tachypnea");
|
||||
assert!(tachypnea, "should detect tachypnea at 35 BPM");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_sudden_change() {
|
||||
let mut det = VitalAnomalyDetector::new(30, 2.0);
|
||||
// Build a stable baseline
|
||||
for _ in 0..30 {
|
||||
det.check(&make_reading(15.0, 72.0));
|
||||
}
|
||||
// Sudden jump (still in normal clinical range but statistically anomalous)
|
||||
let alerts = det.check(&make_reading(15.0, 95.0));
|
||||
let sudden = alerts.iter().any(|a| a.alert_type == "sudden_change");
|
||||
assert!(sudden, "should detect sudden HR change from 72 to 95 BPM");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_clears_state() {
|
||||
let mut det = VitalAnomalyDetector::new(30, 2.5);
|
||||
for _ in 0..10 {
|
||||
det.check(&make_reading(15.0, 72.0));
|
||||
}
|
||||
assert!(det.reading_count() > 0);
|
||||
det.reset();
|
||||
assert_eq!(det.reading_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn welford_stats_basic() {
|
||||
let mut stats = WelfordStats::new();
|
||||
stats.update(10.0);
|
||||
stats.update(20.0);
|
||||
stats.update(30.0);
|
||||
assert!((stats.mean - 20.0).abs() < 1e-10);
|
||||
assert!(stats.std_dev() > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn welford_z_score() {
|
||||
let mut stats = WelfordStats::new();
|
||||
for i in 0..100 {
|
||||
stats.update(50.0 + (i % 3) as f64);
|
||||
}
|
||||
// A value far from the mean should have a high z-score
|
||||
let z = stats.z_score(100.0);
|
||||
assert!(z > 2.0, "z-score for extreme value should be > 2: {z}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn running_means_are_tracked() {
|
||||
let mut det = VitalAnomalyDetector::new(30, 2.5);
|
||||
for _ in 0..10 {
|
||||
det.check(&make_reading(16.0, 75.0));
|
||||
}
|
||||
assert!((det.rr_mean() - 16.0).abs() < 0.5);
|
||||
assert!((det.hr_mean() - 75.0).abs() < 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn severity_is_clamped() {
|
||||
let mut det = VitalAnomalyDetector::new(30, 2.5);
|
||||
for _ in 0..10 {
|
||||
det.check(&make_reading(15.0, 72.0));
|
||||
}
|
||||
let alerts = det.check(&make_reading(15.0, 200.0));
|
||||
for alert in &alerts {
|
||||
assert!(
|
||||
alert.severity >= 0.0 && alert.severity <= 1.0,
|
||||
"severity should be in [0,1]: {}",
|
||||
alert.severity,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
//! Respiratory rate extraction from CSI residuals.
|
||||
//!
|
||||
//! Uses bandpass filtering (0.1-0.5 Hz) and spectral analysis
|
||||
//! to extract breathing rate from multi-subcarrier CSI data.
|
||||
//!
|
||||
//! The approach follows the same IIR bandpass + zero-crossing pattern
|
||||
//! used by [`CoarseBreathingExtractor`](wifi_densepose_wifiscan::pipeline::CoarseBreathingExtractor)
|
||||
//! in the wifiscan crate, adapted for multi-subcarrier f64 processing
|
||||
//! with weighted subcarrier fusion.
|
||||
|
||||
use crate::types::{VitalEstimate, VitalStatus};
|
||||
|
||||
/// IIR bandpass filter state (2nd-order resonator).
///
/// Holds the two most recent inputs (`x1`, `x2`) and outputs
/// (`y1`, `y2`) of the difference equation. The all-zero state is the
/// correct "at rest" starting point, so `Default` is derived instead
/// of being hand-written.
#[derive(Clone, Debug, Default)]
struct IirState {
    /// Input delayed by one sample, x[n-1].
    x1: f64,
    /// Input delayed by two samples, x[n-2].
    x2: f64,
    /// Output delayed by one sample, y[n-1].
    y1: f64,
    /// Output delayed by two samples, y[n-2].
    y2: f64,
}
|
||||
|
||||
/// Respiratory rate extractor using bandpass filtering and zero-crossing analysis.
|
||||
pub struct BreathingExtractor {
|
||||
/// Per-sample filtered signal history.
|
||||
filtered_history: Vec<f64>,
|
||||
/// Sample rate in Hz.
|
||||
sample_rate: f64,
|
||||
/// Analysis window in seconds.
|
||||
window_secs: f64,
|
||||
/// Maximum subcarrier slots.
|
||||
n_subcarriers: usize,
|
||||
/// Breathing band low cutoff (Hz).
|
||||
freq_low: f64,
|
||||
/// Breathing band high cutoff (Hz).
|
||||
freq_high: f64,
|
||||
/// IIR filter state.
|
||||
filter_state: IirState,
|
||||
}
|
||||
|
||||
impl BreathingExtractor {
|
||||
/// Create a new breathing extractor.
|
||||
///
|
||||
/// - `n_subcarriers`: number of subcarrier channels.
|
||||
/// - `sample_rate`: input sample rate in Hz.
|
||||
/// - `window_secs`: analysis window length in seconds (default: 30).
|
||||
#[must_use]
|
||||
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
|
||||
pub fn new(n_subcarriers: usize, sample_rate: f64, window_secs: f64) -> Self {
|
||||
let capacity = (sample_rate * window_secs) as usize;
|
||||
Self {
|
||||
filtered_history: Vec::with_capacity(capacity),
|
||||
sample_rate,
|
||||
window_secs,
|
||||
n_subcarriers,
|
||||
freq_low: 0.1,
|
||||
freq_high: 0.5,
|
||||
filter_state: IirState::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with ESP32 defaults (56 subcarriers, 100 Hz, 30 s window).
|
||||
#[must_use]
|
||||
pub fn esp32_default() -> Self {
|
||||
Self::new(56, 100.0, 30.0)
|
||||
}
|
||||
|
||||
/// Extract respiratory rate from a vector of per-subcarrier residuals.
|
||||
///
|
||||
/// - `residuals`: amplitude residuals from the preprocessor.
|
||||
/// - `weights`: per-subcarrier attention weights (higher = more
|
||||
/// body-sensitive). If shorter than `residuals`, missing weights
|
||||
/// default to uniform.
|
||||
///
|
||||
/// Returns a `VitalEstimate` with the breathing rate in BPM, or
|
||||
/// `None` if insufficient history has been accumulated.
|
||||
pub fn extract(&mut self, residuals: &[f64], weights: &[f64]) -> Option<VitalEstimate> {
|
||||
let n = residuals.len().min(self.n_subcarriers);
|
||||
if n == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Weighted fusion of subcarrier residuals
|
||||
let uniform_w = 1.0 / n as f64;
|
||||
let weighted_signal: f64 = residuals
|
||||
.iter()
|
||||
.enumerate()
|
||||
.take(n)
|
||||
.map(|(i, &r)| {
|
||||
let w = weights.get(i).copied().unwrap_or(uniform_w);
|
||||
r * w
|
||||
})
|
||||
.sum();
|
||||
|
||||
// Apply IIR bandpass filter
|
||||
let filtered = self.bandpass_filter(weighted_signal);
|
||||
|
||||
// Append to history, enforce window limit
|
||||
self.filtered_history.push(filtered);
|
||||
let max_len = (self.sample_rate * self.window_secs) as usize;
|
||||
if self.filtered_history.len() > max_len {
|
||||
self.filtered_history.remove(0);
|
||||
}
|
||||
|
||||
// Need at least 10 seconds of data
|
||||
let min_samples = (self.sample_rate * 10.0) as usize;
|
||||
if self.filtered_history.len() < min_samples {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Zero-crossing rate -> frequency
|
||||
let crossings = count_zero_crossings(&self.filtered_history);
|
||||
let duration_s = self.filtered_history.len() as f64 / self.sample_rate;
|
||||
let frequency_hz = crossings as f64 / (2.0 * duration_s);
|
||||
|
||||
// Validate frequency is within the breathing band
|
||||
if frequency_hz < self.freq_low || frequency_hz > self.freq_high {
|
||||
return None;
|
||||
}
|
||||
|
||||
let bpm = frequency_hz * 60.0;
|
||||
let confidence = compute_confidence(&self.filtered_history);
|
||||
|
||||
let status = if confidence >= 0.7 {
|
||||
VitalStatus::Valid
|
||||
} else if confidence >= 0.4 {
|
||||
VitalStatus::Degraded
|
||||
} else {
|
||||
VitalStatus::Unreliable
|
||||
};
|
||||
|
||||
Some(VitalEstimate {
|
||||
value_bpm: bpm,
|
||||
confidence,
|
||||
status,
|
||||
})
|
||||
}
|
||||
|
||||
/// 2nd-order IIR bandpass filter using a resonator topology.
|
||||
///
|
||||
/// y[n] = (1-r)*(x[n] - x[n-2]) + 2*r*cos(w0)*y[n-1] - r^2*y[n-2]
|
||||
fn bandpass_filter(&mut self, input: f64) -> f64 {
|
||||
let state = &mut self.filter_state;
|
||||
|
||||
let omega_low = 2.0 * std::f64::consts::PI * self.freq_low / self.sample_rate;
|
||||
let omega_high = 2.0 * std::f64::consts::PI * self.freq_high / self.sample_rate;
|
||||
let bw = omega_high - omega_low;
|
||||
let center = f64::midpoint(omega_low, omega_high);
|
||||
|
||||
let r = 1.0 - bw / 2.0;
|
||||
let cos_w0 = center.cos();
|
||||
|
||||
let output =
|
||||
(1.0 - r) * (input - state.x2) + 2.0 * r * cos_w0 * state.y1 - r * r * state.y2;
|
||||
|
||||
state.x2 = state.x1;
|
||||
state.x1 = input;
|
||||
state.y2 = state.y1;
|
||||
state.y1 = output;
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Reset all filter state and history.
|
||||
pub fn reset(&mut self) {
|
||||
self.filtered_history.clear();
|
||||
self.filter_state = IirState::default();
|
||||
}
|
||||
|
||||
/// Current number of samples in the history buffer.
|
||||
#[must_use]
|
||||
pub fn history_len(&self) -> usize {
|
||||
self.filtered_history.len()
|
||||
}
|
||||
|
||||
/// Breathing band cutoff frequencies.
|
||||
#[must_use]
|
||||
pub fn band(&self) -> (f64, f64) {
|
||||
(self.freq_low, self.freq_high)
|
||||
}
|
||||
}
|
||||
|
||||
/// Count sign changes between consecutive samples of a signal.
///
/// A crossing is counted only when adjacent samples have strictly
/// opposite signs; a sample sitting exactly at zero contributes none.
fn count_zero_crossings(signal: &[f64]) -> usize {
    let mut crossings = 0;
    for pair in signal.windows(2) {
        if pair[0] * pair[1] < 0.0 {
            crossings += 1;
        }
    }
    crossings
}
|
||||
|
||||
/// Compute confidence in the breathing estimate based on signal regularity.
///
/// Confidence is the peak-amplitude-to-noise ratio of the filtered
/// history, mapped linearly into `[0, 1]` (an SNR of 5 or more
/// saturates at 1). Histories shorter than 4 samples or with
/// effectively zero variance score 0.
fn compute_confidence(history: &[f64]) -> f64 {
    if history.len() < 4 {
        return 0.0;
    }

    let n = history.len() as f64;
    let mean = history.iter().sum::<f64>() / n;
    let variance = history
        .iter()
        .map(|x| {
            let d = x - mean;
            d * d
        })
        .sum::<f64>()
        / n;

    // A flat signal carries no breathing information.
    if variance < 1e-15 {
        return 0.0;
    }

    let noise = variance.sqrt();
    let peak = history.iter().fold(0.0_f64, |acc, x| acc.max(x.abs()));
    let snr = if noise > 1e-15 { peak / noise } else { 0.0 };

    // Map SNR to [0, 1] confidence.
    (snr / 5.0).min(1.0)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_data_returns_none() {
        let mut ext = BreathingExtractor::new(4, 10.0, 30.0);
        assert!(ext.extract(&[], &[]).is_none());
    }

    #[test]
    fn insufficient_history_returns_none() {
        let mut ext = BreathingExtractor::new(2, 10.0, 30.0);
        // A handful of frames is far short of the 10 s minimum.
        for _ in 0..5 {
            assert!(ext.extract(&[1.0, 2.0], &[0.5, 0.5]).is_none());
        }
    }

    #[test]
    fn zero_crossings_count() {
        assert_eq!(count_zero_crossings(&[1.0, -1.0, 1.0, -1.0, 1.0]), 4);
    }

    #[test]
    fn zero_crossings_constant() {
        assert_eq!(count_zero_crossings(&[1.0, 1.0, 1.0, 1.0]), 0);
    }

    #[test]
    fn sinusoidal_breathing_detected() {
        let sample_rate = 10.0;
        let mut ext = BreathingExtractor::new(1, sample_rate, 60.0);
        let breathing_freq = 0.25; // 15 BPM

        // Drive the extractor with 60 s of pure sinusoidal breathing.
        for i in 0..600 {
            let t = i as f64 / sample_rate;
            let signal = (2.0 * std::f64::consts::PI * breathing_freq * t).sin();
            ext.extract(&[signal], &[1.0]);
        }

        if let Some(est) = ext.extract(&[0.0], &[1.0]) {
            // Expect roughly 15 BPM (0.25 Hz * 60).
            assert!(
                est.value_bpm > 5.0 && est.value_bpm < 40.0,
                "estimated BPM should be in breathing range: {}",
                est.value_bpm,
            );
            assert!(est.confidence > 0.0, "confidence should be > 0");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut ext = BreathingExtractor::new(2, 10.0, 30.0);
        ext.extract(&[1.0, 2.0], &[0.5, 0.5]);
        assert!(ext.history_len() > 0);
        ext.reset();
        assert_eq!(ext.history_len(), 0);
    }

    #[test]
    fn band_returns_correct_values() {
        let (low, high) = BreathingExtractor::new(1, 10.0, 30.0).band();
        assert!((low - 0.1).abs() < f64::EPSILON);
        assert!((high - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn confidence_zero_for_flat_signal() {
        let conf = compute_confidence(&[0.0; 100]);
        assert!((conf - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn confidence_positive_for_oscillating_signal() {
        let history: Vec<f64> = (0..100).map(|i| (i as f64 * 0.5).sin()).collect();
        assert!(compute_confidence(&history) > 0.0);
    }

    #[test]
    fn esp32_default_creates_correctly() {
        assert_eq!(BreathingExtractor::esp32_default().n_subcarriers, 56);
    }
}
|
||||
@@ -0,0 +1,396 @@
|
||||
//! Heart rate extraction from CSI phase coherence.
|
||||
//!
|
||||
//! Uses bandpass filtering (0.8-2.0 Hz) and autocorrelation-based
|
||||
//! peak detection to extract cardiac rate from inter-subcarrier
|
||||
//! phase data. Requires multi-subcarrier CSI data (ESP32 mode only).
|
||||
//!
|
||||
//! The cardiac signal (0.1-0.5 mm body surface displacement) is
|
||||
//! ~10x weaker than the respiratory signal (1-5 mm chest displacement),
|
||||
//! so this module relies on phase coherence across subcarriers rather
|
||||
//! than single-channel amplitude analysis.
|
||||
|
||||
use crate::types::{VitalEstimate, VitalStatus};
|
||||
|
||||
/// IIR bandpass filter state (2nd-order resonator).
///
/// Holds the two most recent inputs (`x1`, `x2`) and outputs
/// (`y1`, `y2`) of the difference equation. The all-zero state is the
/// correct "at rest" starting point, so `Default` is derived instead
/// of being hand-written.
#[derive(Clone, Debug, Default)]
struct IirState {
    /// Input delayed by one sample, x[n-1].
    x1: f64,
    /// Input delayed by two samples, x[n-2].
    x2: f64,
    /// Output delayed by one sample, y[n-1].
    y1: f64,
    /// Output delayed by two samples, y[n-2].
    y2: f64,
}
|
||||
|
||||
/// Heart rate extractor using bandpass filtering and autocorrelation
|
||||
/// peak detection.
|
||||
pub struct HeartRateExtractor {
|
||||
/// Per-sample filtered signal history.
|
||||
filtered_history: Vec<f64>,
|
||||
/// Sample rate in Hz.
|
||||
sample_rate: f64,
|
||||
/// Analysis window in seconds.
|
||||
window_secs: f64,
|
||||
/// Maximum subcarrier slots.
|
||||
n_subcarriers: usize,
|
||||
/// Cardiac band low cutoff (Hz) -- 0.8 Hz = 48 BPM.
|
||||
freq_low: f64,
|
||||
/// Cardiac band high cutoff (Hz) -- 2.0 Hz = 120 BPM.
|
||||
freq_high: f64,
|
||||
/// IIR filter state.
|
||||
filter_state: IirState,
|
||||
/// Minimum subcarriers required for reliable HR estimation.
|
||||
min_subcarriers: usize,
|
||||
}
|
||||
|
||||
impl HeartRateExtractor {
|
||||
/// Create a new heart rate extractor.
|
||||
///
|
||||
/// - `n_subcarriers`: number of subcarrier channels.
|
||||
/// - `sample_rate`: input sample rate in Hz.
|
||||
/// - `window_secs`: analysis window length in seconds (default: 15).
|
||||
#[must_use]
|
||||
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
|
||||
pub fn new(n_subcarriers: usize, sample_rate: f64, window_secs: f64) -> Self {
|
||||
let capacity = (sample_rate * window_secs) as usize;
|
||||
Self {
|
||||
filtered_history: Vec::with_capacity(capacity),
|
||||
sample_rate,
|
||||
window_secs,
|
||||
n_subcarriers,
|
||||
freq_low: 0.8,
|
||||
freq_high: 2.0,
|
||||
filter_state: IirState::default(),
|
||||
min_subcarriers: 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with ESP32 defaults (56 subcarriers, 100 Hz, 15 s window).
|
||||
#[must_use]
|
||||
pub fn esp32_default() -> Self {
|
||||
Self::new(56, 100.0, 15.0)
|
||||
}
|
||||
|
||||
/// Extract heart rate from per-subcarrier residuals and phase data.
|
||||
///
|
||||
/// - `residuals`: amplitude residuals from the preprocessor.
|
||||
/// - `phases`: per-subcarrier unwrapped phases (radians).
|
||||
///
|
||||
/// Returns a `VitalEstimate` with heart rate in BPM, or `None`
|
||||
/// if insufficient data or too few subcarriers.
|
||||
pub fn extract(&mut self, residuals: &[f64], phases: &[f64]) -> Option<VitalEstimate> {
|
||||
let n = residuals.len().min(self.n_subcarriers).min(phases.len());
|
||||
if n == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// For cardiac signals, use phase-coherence weighted fusion.
|
||||
// Compute mean phase differential as a proxy for body-surface
|
||||
// displacement sensitivity.
|
||||
let phase_signal = compute_phase_coherence_signal(residuals, phases, n);
|
||||
|
||||
// Apply cardiac-band IIR bandpass filter
|
||||
let filtered = self.bandpass_filter(phase_signal);
|
||||
|
||||
// Append to history, enforce window limit
|
||||
self.filtered_history.push(filtered);
|
||||
let max_len = (self.sample_rate * self.window_secs) as usize;
|
||||
if self.filtered_history.len() > max_len {
|
||||
self.filtered_history.remove(0);
|
||||
}
|
||||
|
||||
// Need at least 5 seconds of data for cardiac detection
|
||||
let min_samples = (self.sample_rate * 5.0) as usize;
|
||||
if self.filtered_history.len() < min_samples {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Use autocorrelation to find the dominant periodicity
|
||||
let (period_samples, acf_peak) =
|
||||
autocorrelation_peak(&self.filtered_history, self.sample_rate, self.freq_low, self.freq_high);
|
||||
|
||||
if period_samples == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let frequency_hz = self.sample_rate / period_samples as f64;
|
||||
let bpm = frequency_hz * 60.0;
|
||||
|
||||
// Validate BPM is in physiological range (40-180 BPM)
|
||||
if !(40.0..=180.0).contains(&bpm) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Confidence based on autocorrelation peak strength and subcarrier count
|
||||
let subcarrier_factor = if n >= self.min_subcarriers {
|
||||
1.0
|
||||
} else {
|
||||
n as f64 / self.min_subcarriers as f64
|
||||
};
|
||||
let confidence = (acf_peak * subcarrier_factor).clamp(0.0, 1.0);
|
||||
|
||||
let status = if confidence >= 0.6 && n >= self.min_subcarriers {
|
||||
VitalStatus::Valid
|
||||
} else if confidence >= 0.3 {
|
||||
VitalStatus::Degraded
|
||||
} else {
|
||||
VitalStatus::Unreliable
|
||||
};
|
||||
|
||||
Some(VitalEstimate {
|
||||
value_bpm: bpm,
|
||||
confidence,
|
||||
status,
|
||||
})
|
||||
}
|
||||
|
||||
/// 2nd-order IIR bandpass filter (cardiac band: 0.8-2.0 Hz).
|
||||
fn bandpass_filter(&mut self, input: f64) -> f64 {
|
||||
let state = &mut self.filter_state;
|
||||
|
||||
let omega_low = 2.0 * std::f64::consts::PI * self.freq_low / self.sample_rate;
|
||||
let omega_high = 2.0 * std::f64::consts::PI * self.freq_high / self.sample_rate;
|
||||
let bw = omega_high - omega_low;
|
||||
let center = f64::midpoint(omega_low, omega_high);
|
||||
|
||||
let r = 1.0 - bw / 2.0;
|
||||
let cos_w0 = center.cos();
|
||||
|
||||
let output =
|
||||
(1.0 - r) * (input - state.x2) + 2.0 * r * cos_w0 * state.y1 - r * r * state.y2;
|
||||
|
||||
state.x2 = state.x1;
|
||||
state.x1 = input;
|
||||
state.y2 = state.y1;
|
||||
state.y1 = output;
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Reset all filter state and history.
|
||||
pub fn reset(&mut self) {
|
||||
self.filtered_history.clear();
|
||||
self.filter_state = IirState::default();
|
||||
}
|
||||
|
||||
/// Current number of samples in the history buffer.
|
||||
#[must_use]
|
||||
pub fn history_len(&self) -> usize {
|
||||
self.filtered_history.len()
|
||||
}
|
||||
|
||||
/// Cardiac band cutoff frequencies.
|
||||
#[must_use]
|
||||
pub fn band(&self) -> (f64, f64) {
|
||||
(self.freq_low, self.freq_high)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute a phase-coherence-weighted signal from residuals and phases.
///
/// Combines amplitude residuals with inter-subcarrier phase coherence
/// to enhance the cardiac signal: subcarriers whose phases track their
/// neighbours closely are likely sensing the same body surface, so
/// their residuals receive weight close to 1 (`exp(-|delta_phi|)`).
fn compute_phase_coherence_signal(residuals: &[f64], phases: &[f64], n: usize) -> f64 {
    // With one (or zero) subcarriers there is nothing to weight.
    if n <= 1 {
        return residuals.first().copied().unwrap_or(0.0);
    }

    let mut weighted_sum = 0.0;
    let mut weight_total = 0.0;

    for i in 0..n {
        // Phase gap to the forward neighbour, falling back to the
        // backward neighbour for the last element (n >= 2 here, so
        // one of the two always exists).
        let neighbour_gap = if i + 1 < n {
            (phases[i + 1] - phases[i]).abs()
        } else {
            (phases[i] - phases[i - 1]).abs()
        };
        let coherence = (-neighbour_gap).exp();

        weighted_sum += residuals[i] * coherence;
        weight_total += coherence;
    }

    if weight_total > 1e-15 {
        weighted_sum / weight_total
    } else {
        0.0
    }
}
|
||||
|
||||
/// Find the dominant periodicity via autocorrelation in the cardiac band.
///
/// Scans lags corresponding to `[freq_low, freq_high]` and returns
/// `(period_in_samples, peak_normalized_acf)`. If no peak can be
/// found (too few samples, a flat signal, or an empty lag range),
/// returns `(0, 0.0)`.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn autocorrelation_peak(
    signal: &[f64],
    sample_rate: f64,
    freq_low: f64,
    freq_high: f64,
) -> (usize, f64) {
    let n = signal.len();
    if n < 4 {
        return (0, 0.0);
    }

    // Lag range corresponding to the cardiac band. Lag 0 is excluded
    // (clamped to >= 1): its normalized ACF is exactly 1 by definition
    // and carries no periodicity information, and without the clamp a
    // sample rate below `freq_high` made the search degenerate and
    // return (0, 1.0), violating the documented (0, 0.0) contract.
    let min_lag = ((sample_rate / freq_high).floor() as usize).max(1); // shortest period
    let max_lag = ((sample_rate / freq_low).ceil() as usize).min(n / 2); // longest period

    if min_lag >= max_lag || min_lag >= n {
        return (0, 0.0);
    }

    let mean: f64 = signal.iter().sum::<f64>() / n as f64;

    // Lag-0 autocorrelation (signal energy about the mean) for
    // normalisation; a flat signal has no peak to find.
    let acf0: f64 = signal.iter().map(|&x| (x - mean) * (x - mean)).sum();
    if acf0 < 1e-15 {
        return (0, 0.0);
    }

    // Search for the strongest normalized ACF in the cardiac lag range.
    let mut best_lag = 0;
    let mut best_acf = f64::MIN;

    for lag in min_lag..=max_lag {
        // Pair each sample with its lagged counterpart; zipping two
        // subslices avoids per-element bounds checks in the hot loop.
        let acf: f64 = signal[..n - lag]
            .iter()
            .zip(&signal[lag..])
            .map(|(&a, &b)| (a - mean) * (b - mean))
            .sum();

        let normalized = acf / acf0;
        if normalized > best_acf {
            best_acf = normalized;
            best_lag = lag;
        }
    }

    if best_acf > 0.0 {
        (best_lag, best_acf)
    } else {
        (0, 0.0)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_data_returns_none() {
        let mut ext = HeartRateExtractor::new(4, 100.0, 15.0);
        assert!(ext.extract(&[], &[]).is_none());
    }

    #[test]
    fn insufficient_history_returns_none() {
        let mut ext = HeartRateExtractor::new(2, 100.0, 15.0);
        // Ten frames is well below the 5 s minimum at 100 Hz.
        for _ in 0..10 {
            assert!(ext.extract(&[0.1, 0.2], &[0.0, 0.0]).is_none());
        }
    }

    #[test]
    fn sinusoidal_heartbeat_detected() {
        let sample_rate = 50.0;
        let mut ext = HeartRateExtractor::new(4, sample_rate, 20.0);
        let heart_freq = 1.2; // 72 BPM

        // 20 seconds of simulated cardiac motion across 4 highly
        // phase-coherent subcarriers.
        for i in 0..1000 {
            let t = i as f64 / sample_rate;
            let base = (2.0 * std::f64::consts::PI * heart_freq * t).sin();
            let residuals = [base * 0.1, base * 0.08, base * 0.12, base * 0.09];
            let phases = [0.0, 0.01, 0.02, 0.03];
            ext.extract(&residuals, &phases);
        }

        let result = ext.extract(&[0.0; 4], &[0.0; 4]);
        if let Some(est) = result {
            assert!(
                est.value_bpm > 40.0 && est.value_bpm < 180.0,
                "estimated BPM should be in cardiac range: {}",
                est.value_bpm,
            );
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut ext = HeartRateExtractor::new(2, 100.0, 15.0);
        ext.extract(&[0.1, 0.2], &[0.0, 0.1]);
        assert!(ext.history_len() > 0);
        ext.reset();
        assert_eq!(ext.history_len(), 0);
    }

    #[test]
    fn band_returns_correct_values() {
        let (low, high) = HeartRateExtractor::new(1, 100.0, 15.0).band();
        assert!((low - 0.8).abs() < f64::EPSILON);
        assert!((high - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn autocorrelation_finds_known_period() {
        let sample_rate = 50.0;
        let freq = 1.0; // 1 Hz = period of 50 samples
        let signal: Vec<f64> = (0..500)
            .map(|i| (2.0 * std::f64::consts::PI * freq * i as f64 / sample_rate).sin())
            .collect();

        let (period, acf) = autocorrelation_peak(&signal, sample_rate, 0.8, 2.0);
        assert!(period > 0, "should find a period");
        assert!(acf > 0.5, "autocorrelation peak should be strong: {acf}");

        let estimated_freq = sample_rate / period as f64;
        assert!(
            (estimated_freq - 1.0).abs() < 0.1,
            "estimated frequency should be ~1 Hz, got {estimated_freq}",
        );
    }

    #[test]
    fn phase_coherence_single_subcarrier() {
        let result = compute_phase_coherence_signal(&[5.0], &[0.0], 1);
        assert!((result - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn phase_coherence_multi_subcarrier() {
        // Two coherent subcarriers: exp(-0.01) ~ 0.99, so result ~ 1.0.
        let result = compute_phase_coherence_signal(&[1.0, 1.0], &[0.0, 0.01], 2);
        assert!((result - 1.0).abs() < 0.1, "coherent result should be ~1.0: {result}");
    }

    #[test]
    fn esp32_default_creates_correctly() {
        assert_eq!(HeartRateExtractor::esp32_default().n_subcarriers, 56);
    }
}
|
||||
@@ -0,0 +1,80 @@
|
||||
//! ESP32 CSI-grade vital sign extraction (ADR-021).
|
||||
//!
|
||||
//! Extracts heart rate and respiratory rate from WiFi Channel
|
||||
//! State Information using multi-subcarrier amplitude and phase
|
||||
//! analysis.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The pipeline processes CSI frames through four stages:
|
||||
//!
|
||||
//! 1. **Preprocessing** ([`CsiVitalPreprocessor`]): EMA-based static
|
||||
//! component suppression, producing per-subcarrier residuals.
|
||||
//! 2. **Breathing extraction** ([`BreathingExtractor`]): Bandpass
|
||||
//! filtering (0.1-0.5 Hz) with zero-crossing analysis for
|
||||
//! respiratory rate.
|
||||
//! 3. **Heart rate extraction** ([`HeartRateExtractor`]): Bandpass
|
||||
//! filtering (0.8-2.0 Hz) with autocorrelation peak detection
|
||||
//! and inter-subcarrier phase coherence weighting.
|
||||
//! 4. **Anomaly detection** ([`VitalAnomalyDetector`]): Z-score
|
||||
//! analysis with Welford running statistics for clinical alerts
|
||||
//! (apnea, tachycardia, bradycardia).
|
||||
//!
|
||||
//! Results are stored in a [`VitalSignStore`] with configurable
|
||||
//! retention for historical analysis.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! use wifi_densepose_vitals::{
|
||||
//! CsiVitalPreprocessor, BreathingExtractor, HeartRateExtractor,
|
||||
//! VitalAnomalyDetector, VitalSignStore, CsiFrame,
|
||||
//! VitalReading, VitalEstimate, VitalStatus,
|
||||
//! };
|
||||
//!
|
||||
//! let mut preprocessor = CsiVitalPreprocessor::new(56, 0.05);
|
||||
//! let mut breathing = BreathingExtractor::new(56, 100.0, 30.0);
|
||||
//! let mut heartrate = HeartRateExtractor::new(56, 100.0, 15.0);
|
||||
//! let mut anomaly = VitalAnomalyDetector::default_config();
|
||||
//! let mut store = VitalSignStore::new(3600);
|
||||
//!
|
||||
//! // Process a CSI frame
|
||||
//! let frame = CsiFrame {
|
||||
//! amplitudes: vec![1.0; 56],
|
||||
//! phases: vec![0.0; 56],
|
||||
//! n_subcarriers: 56,
|
||||
//! sample_index: 0,
|
||||
//! sample_rate_hz: 100.0,
|
||||
//! };
|
||||
//!
|
||||
//! if let Some(residuals) = preprocessor.process(&frame) {
|
||||
//! let weights = vec![1.0 / 56.0; 56];
|
||||
//! let rr = breathing.extract(&residuals, &weights);
|
||||
//! let hr = heartrate.extract(&residuals, &frame.phases);
|
||||
//!
|
||||
//! let reading = VitalReading {
|
||||
//! respiratory_rate: rr.unwrap_or_else(VitalEstimate::unavailable),
|
||||
//! heart_rate: hr.unwrap_or_else(VitalEstimate::unavailable),
|
||||
//! subcarrier_count: frame.n_subcarriers,
|
||||
//! signal_quality: 0.9,
|
||||
//! timestamp_secs: 0.0,
|
||||
//! };
|
||||
//!
|
||||
//! let alerts = anomaly.check(&reading);
|
||||
//! store.push(reading);
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod anomaly;
|
||||
pub mod breathing;
|
||||
pub mod heartrate;
|
||||
pub mod preprocessor;
|
||||
pub mod store;
|
||||
pub mod types;
|
||||
|
||||
pub use anomaly::{AnomalyAlert, VitalAnomalyDetector};
|
||||
pub use breathing::BreathingExtractor;
|
||||
pub use heartrate::HeartRateExtractor;
|
||||
pub use preprocessor::CsiVitalPreprocessor;
|
||||
pub use store::{VitalSignStore, VitalStats};
|
||||
pub use types::{CsiFrame, VitalEstimate, VitalReading, VitalStatus};
|
||||
@@ -0,0 +1,206 @@
|
||||
//! CSI vital sign preprocessor.
|
||||
//!
|
||||
//! Suppresses static subcarrier components and extracts the
|
||||
//! body-modulated signal residuals for vital sign analysis.
|
||||
//!
|
||||
//! Uses an EMA-based predictive filter (same pattern as
|
||||
//! [`PredictiveGate`](wifi_densepose_wifiscan::pipeline::PredictiveGate)
|
||||
//! in the wifiscan crate) operating on per-subcarrier amplitudes.
|
||||
//! The residuals represent deviations from the static environment
|
||||
//! baseline, isolating physiological movements (breathing, heartbeat).
|
||||
|
||||
use crate::types::CsiFrame;
|
||||
|
||||
/// EMA-based preprocessor that extracts body-modulated residuals
/// from raw CSI subcarrier amplitudes.
pub struct CsiVitalPreprocessor {
    /// Current EMA prediction for each subcarrier's amplitude.
    predictions: Vec<f64>,
    /// Per-subcarrier flag: has this slot seen its first observation?
    initialized: Vec<bool>,
    /// EMA smoothing factor; lower tracks more slowly and suppresses
    /// the static environment baseline more strongly.
    alpha: f64,
    /// Number of subcarrier slots tracked.
    n_subcarriers: usize,
}
|
||||
|
||||
impl CsiVitalPreprocessor {
|
||||
/// Create a new preprocessor.
|
||||
///
|
||||
/// - `n_subcarriers`: number of subcarrier slots to track.
|
||||
/// - `alpha`: EMA smoothing factor in `(0, 1)`. Lower values
|
||||
/// provide better static component suppression but slower
|
||||
/// adaptation. Default for vital signs: `0.05`.
|
||||
#[must_use]
|
||||
pub fn new(n_subcarriers: usize, alpha: f64) -> Self {
|
||||
Self {
|
||||
predictions: vec![0.0; n_subcarriers],
|
||||
initialized: vec![false; n_subcarriers],
|
||||
alpha: alpha.clamp(0.001, 0.999),
|
||||
n_subcarriers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a preprocessor with defaults suitable for ESP32 CSI
|
||||
/// vital sign extraction (56 subcarriers, alpha = 0.05).
|
||||
#[must_use]
|
||||
pub fn esp32_default() -> Self {
|
||||
Self::new(56, 0.05)
|
||||
}
|
||||
|
||||
/// Process a CSI frame and return the residual vector.
|
||||
///
|
||||
/// The residuals represent the difference between observed and
|
||||
/// predicted (EMA) amplitudes. On the first frame for each
|
||||
/// subcarrier, the prediction is seeded and the raw amplitude
|
||||
/// is returned.
|
||||
///
|
||||
/// Returns `None` if the frame has zero subcarriers.
|
||||
pub fn process(&mut self, frame: &CsiFrame) -> Option<Vec<f64>> {
|
||||
let n = frame.amplitudes.len().min(self.n_subcarriers);
|
||||
if n == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut residuals = vec![0.0; n];
|
||||
|
||||
for (i, residual) in residuals.iter_mut().enumerate().take(n) {
|
||||
if self.initialized[i] {
|
||||
// Compute residual: observed - predicted
|
||||
*residual = frame.amplitudes[i] - self.predictions[i];
|
||||
// Update EMA prediction
|
||||
self.predictions[i] =
|
||||
self.alpha * frame.amplitudes[i] + (1.0 - self.alpha) * self.predictions[i];
|
||||
} else {
|
||||
// First observation: seed the prediction
|
||||
self.predictions[i] = frame.amplitudes[i];
|
||||
self.initialized[i] = true;
|
||||
// First-frame residual is zero (no prior to compare against)
|
||||
*residual = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
Some(residuals)
|
||||
}
|
||||
|
||||
/// Reset all predictions and initialisation state.
|
||||
pub fn reset(&mut self) {
|
||||
self.predictions.fill(0.0);
|
||||
self.initialized.fill(false);
|
||||
}
|
||||
|
||||
/// Current EMA smoothing factor.
|
||||
#[must_use]
|
||||
pub fn alpha(&self) -> f64 {
|
||||
self.alpha
|
||||
}
|
||||
|
||||
/// Update the EMA smoothing factor.
|
||||
pub fn set_alpha(&mut self, alpha: f64) {
|
||||
self.alpha = alpha.clamp(0.001, 0.999);
|
||||
}
|
||||
|
||||
/// Number of subcarrier slots.
|
||||
#[must_use]
|
||||
pub fn n_subcarriers(&self) -> usize {
|
||||
self.n_subcarriers
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CsiFrame;

    /// Build a test frame with the given amplitudes and zeroed phases.
    fn make_frame(amplitudes: Vec<f64>, n: usize) -> CsiFrame {
        let phases = vec![0.0; n];
        CsiFrame {
            amplitudes,
            phases,
            n_subcarriers: n,
            sample_index: 0,
            sample_rate_hz: 100.0,
        }
    }

    #[test]
    fn empty_frame_returns_none() {
        let mut pp = CsiVitalPreprocessor::new(4, 0.05);
        let frame = make_frame(vec![], 0);
        assert!(pp.process(&frame).is_none());
    }

    #[test]
    fn first_frame_residuals_are_zero() {
        let mut pp = CsiVitalPreprocessor::new(3, 0.05);
        let frame = make_frame(vec![1.0, 2.0, 3.0], 3);
        let residuals = pp.process(&frame).unwrap();
        assert_eq!(residuals.len(), 3);
        for &r in &residuals {
            assert!((r - 0.0).abs() < f64::EPSILON, "first frame residual should be 0");
        }
    }

    #[test]
    fn static_signal_residuals_converge_to_zero() {
        let mut pp = CsiVitalPreprocessor::new(2, 0.1);
        let frame = make_frame(vec![5.0, 10.0], 2);

        // Seed
        pp.process(&frame);

        // After many identical frames, residuals should be near zero
        let mut last_residuals = vec![0.0; 2];
        for _ in 0..100 {
            last_residuals = pp.process(&frame).unwrap();
        }

        for &r in &last_residuals {
            assert!(r.abs() < 0.01, "residuals should converge to ~0 for static signal, got {r}");
        }
    }

    #[test]
    fn step_change_produces_large_residual() {
        let mut pp = CsiVitalPreprocessor::new(1, 0.05);
        let frame1 = make_frame(vec![10.0], 1);

        // Converge EMA
        pp.process(&frame1);
        for _ in 0..200 {
            pp.process(&frame1);
        }

        // Step change
        let frame2 = make_frame(vec![20.0], 1);
        let residuals = pp.process(&frame2).unwrap();
        assert!(residuals[0] > 5.0, "step change should produce large residual, got {}", residuals[0]);
    }

    #[test]
    fn reset_clears_state() {
        let mut pp = CsiVitalPreprocessor::new(2, 0.1);
        let frame = make_frame(vec![1.0, 2.0], 2);
        pp.process(&frame);
        pp.reset();
        // After reset, next frame is treated as first
        let residuals = pp.process(&frame).unwrap();
        for &r in &residuals {
            assert!((r - 0.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn alpha_clamped() {
        // Out-of-range alphas are clamped into (0, 1) by the constructor.
        let pp = CsiVitalPreprocessor::new(1, -5.0);
        assert!(pp.alpha() > 0.0);
        let pp = CsiVitalPreprocessor::new(1, 100.0);
        assert!(pp.alpha() < 1.0);
    }

    #[test]
    fn esp32_default_has_correct_subcarriers() {
        let pp = CsiVitalPreprocessor::esp32_default();
        assert_eq!(pp.n_subcarriers(), 56);
    }
}
|
||||
@@ -0,0 +1,290 @@
|
||||
//! Vital sign time series store.
|
||||
//!
|
||||
//! Stores vital sign readings with configurable retention.
|
||||
//! Designed for upgrade to `TieredStore` when `ruvector-temporal-tensor`
|
||||
//! becomes available (ADR-021 phase 2).
|
||||
|
||||
use crate::types::{VitalReading, VitalStatus};

/// Simple vital sign store with capacity-limited ring buffer semantics.
///
/// Readings are kept oldest-first; once `max_readings` is reached,
/// each push evicts the oldest entry.
pub struct VitalSignStore {
    /// Stored readings (oldest first).
    readings: Vec<VitalReading>,
    /// Maximum number of readings to retain.
    max_readings: usize,
}
|
||||
|
||||
/// Summary statistics for stored vital sign readings.
#[derive(Debug, Clone)]
pub struct VitalStats {
    /// Number of readings in the store.
    pub count: usize,
    /// Mean respiratory rate (BPM).
    pub rr_mean: f64,
    /// Mean heart rate (BPM).
    pub hr_mean: f64,
    /// Min respiratory rate (BPM).
    pub rr_min: f64,
    /// Max respiratory rate (BPM).
    pub rr_max: f64,
    /// Min heart rate (BPM).
    pub hr_min: f64,
    /// Max heart rate (BPM).
    pub hr_max: f64,
    /// Fraction of readings whose respiratory rate AND heart rate
    /// both have `Valid` status.
    pub valid_fraction: f64,
}
|
||||
|
||||
impl VitalSignStore {
|
||||
/// Create a new store with a given maximum capacity.
|
||||
///
|
||||
/// When the capacity is exceeded, the oldest readings are evicted.
|
||||
#[must_use]
|
||||
pub fn new(max_readings: usize) -> Self {
|
||||
Self {
|
||||
readings: Vec::with_capacity(max_readings.min(4096)),
|
||||
max_readings: max_readings.max(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default capacity (3600 readings ~ 1 hour at 1 Hz).
|
||||
#[must_use]
|
||||
pub fn default_capacity() -> Self {
|
||||
Self::new(3600)
|
||||
}
|
||||
|
||||
/// Push a new reading into the store.
|
||||
///
|
||||
/// If the store is at capacity, the oldest reading is evicted.
|
||||
pub fn push(&mut self, reading: VitalReading) {
|
||||
if self.readings.len() >= self.max_readings {
|
||||
self.readings.remove(0);
|
||||
}
|
||||
self.readings.push(reading);
|
||||
}
|
||||
|
||||
/// Get the most recent reading, if any.
|
||||
#[must_use]
|
||||
pub fn latest(&self) -> Option<&VitalReading> {
|
||||
self.readings.last()
|
||||
}
|
||||
|
||||
/// Get the last `n` readings (most recent last).
|
||||
///
|
||||
/// Returns fewer than `n` if the store contains fewer readings.
|
||||
#[must_use]
|
||||
pub fn history(&self, n: usize) -> &[VitalReading] {
|
||||
let start = self.readings.len().saturating_sub(n);
|
||||
&self.readings[start..]
|
||||
}
|
||||
|
||||
/// Compute summary statistics over all stored readings.
|
||||
///
|
||||
/// Returns `None` if the store is empty.
|
||||
#[must_use]
|
||||
pub fn stats(&self) -> Option<VitalStats> {
|
||||
if self.readings.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let n = self.readings.len() as f64;
|
||||
let mut rr_sum = 0.0;
|
||||
let mut hr_sum = 0.0;
|
||||
let mut rr_min = f64::MAX;
|
||||
let mut rr_max = f64::MIN;
|
||||
let mut hr_min = f64::MAX;
|
||||
let mut hr_max = f64::MIN;
|
||||
let mut valid_count = 0_usize;
|
||||
|
||||
for r in &self.readings {
|
||||
let rr = r.respiratory_rate.value_bpm;
|
||||
let hr = r.heart_rate.value_bpm;
|
||||
rr_sum += rr;
|
||||
hr_sum += hr;
|
||||
rr_min = rr_min.min(rr);
|
||||
rr_max = rr_max.max(rr);
|
||||
hr_min = hr_min.min(hr);
|
||||
hr_max = hr_max.max(hr);
|
||||
|
||||
if r.respiratory_rate.status == VitalStatus::Valid
|
||||
&& r.heart_rate.status == VitalStatus::Valid
|
||||
{
|
||||
valid_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Some(VitalStats {
|
||||
count: self.readings.len(),
|
||||
rr_mean: rr_sum / n,
|
||||
hr_mean: hr_sum / n,
|
||||
rr_min,
|
||||
rr_max,
|
||||
hr_min,
|
||||
hr_max,
|
||||
valid_fraction: valid_count as f64 / n,
|
||||
})
|
||||
}
|
||||
|
||||
/// Number of readings currently stored.
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
self.readings.len()
|
||||
}
|
||||
|
||||
/// Whether the store is empty.
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.readings.is_empty()
|
||||
}
|
||||
|
||||
/// Maximum capacity of the store.
|
||||
#[must_use]
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.max_readings
|
||||
}
|
||||
|
||||
/// Clear all stored readings.
|
||||
pub fn clear(&mut self) {
|
||||
self.readings.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{VitalEstimate, VitalReading, VitalStatus};

    /// Build a reading with both vitals marked Valid.
    fn make_reading(rr: f64, hr: f64) -> VitalReading {
        VitalReading {
            respiratory_rate: VitalEstimate {
                value_bpm: rr,
                confidence: 0.9,
                status: VitalStatus::Valid,
            },
            heart_rate: VitalEstimate {
                value_bpm: hr,
                confidence: 0.85,
                status: VitalStatus::Valid,
            },
            subcarrier_count: 56,
            signal_quality: 0.9,
            timestamp_secs: 0.0,
        }
    }

    #[test]
    fn empty_store() {
        let store = VitalSignStore::new(10);
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.latest().is_none());
        assert!(store.stats().is_none());
    }

    #[test]
    fn push_and_retrieve() {
        let mut store = VitalSignStore::new(10);
        store.push(make_reading(15.0, 72.0));
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());

        let latest = store.latest().unwrap();
        assert!((latest.respiratory_rate.value_bpm - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn eviction_at_capacity() {
        let mut store = VitalSignStore::new(3);
        store.push(make_reading(10.0, 60.0));
        store.push(make_reading(15.0, 72.0));
        store.push(make_reading(20.0, 80.0));
        assert_eq!(store.len(), 3);

        // Push one more; oldest should be evicted
        store.push(make_reading(25.0, 90.0));
        assert_eq!(store.len(), 3);

        // Oldest should now be 15.0, not 10.0
        let oldest = &store.history(10)[0];
        assert!((oldest.respiratory_rate.value_bpm - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn history_returns_last_n() {
        let mut store = VitalSignStore::new(10);
        for i in 0..5 {
            store.push(make_reading(10.0 + i as f64, 60.0 + i as f64));
        }

        let last3 = store.history(3);
        assert_eq!(last3.len(), 3);
        assert!((last3[0].respiratory_rate.value_bpm - 12.0).abs() < f64::EPSILON);
        assert!((last3[2].respiratory_rate.value_bpm - 14.0).abs() < f64::EPSILON);
    }

    #[test]
    fn history_when_fewer_than_n() {
        let mut store = VitalSignStore::new(10);
        store.push(make_reading(15.0, 72.0));
        let all = store.history(100);
        assert_eq!(all.len(), 1);
    }

    #[test]
    fn stats_computation() {
        let mut store = VitalSignStore::new(10);
        store.push(make_reading(10.0, 60.0));
        store.push(make_reading(20.0, 80.0));
        store.push(make_reading(15.0, 70.0));

        let stats = store.stats().unwrap();
        assert_eq!(stats.count, 3);
        assert!((stats.rr_mean - 15.0).abs() < f64::EPSILON);
        assert!((stats.hr_mean - 70.0).abs() < f64::EPSILON);
        assert!((stats.rr_min - 10.0).abs() < f64::EPSILON);
        assert!((stats.rr_max - 20.0).abs() < f64::EPSILON);
        assert!((stats.hr_min - 60.0).abs() < f64::EPSILON);
        assert!((stats.hr_max - 80.0).abs() < f64::EPSILON);
        assert!((stats.valid_fraction - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn stats_valid_fraction() {
        // One fully-Valid reading plus one with a Degraded respiratory
        // rate: only the first counts towards valid_fraction.
        let mut store = VitalSignStore::new(10);
        store.push(make_reading(15.0, 72.0)); // Valid
        store.push(VitalReading {
            respiratory_rate: VitalEstimate {
                value_bpm: 15.0,
                confidence: 0.3,
                status: VitalStatus::Degraded,
            },
            heart_rate: VitalEstimate {
                value_bpm: 72.0,
                confidence: 0.8,
                status: VitalStatus::Valid,
            },
            subcarrier_count: 56,
            signal_quality: 0.5,
            timestamp_secs: 1.0,
        });

        let stats = store.stats().unwrap();
        assert!((stats.valid_fraction - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn clear_empties_store() {
        let mut store = VitalSignStore::new(10);
        store.push(make_reading(15.0, 72.0));
        store.push(make_reading(16.0, 73.0));
        assert_eq!(store.len(), 2);
        store.clear();
        assert!(store.is_empty());
    }

    #[test]
    fn default_capacity_is_3600() {
        let store = VitalSignStore::default_capacity();
        assert_eq!(store.capacity(), 3600);
    }
}
|
||||
@@ -0,0 +1,174 @@
|
||||
//! Vital sign domain types (ADR-021).
|
||||
|
||||
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Status of a vital sign measurement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum VitalStatus {
    /// Valid measurement with clinical-grade confidence.
    Valid,
    /// Measurement present but with reduced confidence.
    Degraded,
    /// Measurement unreliable (e.g., single RSSI source).
    Unreliable,
    /// No measurement possible.
    Unavailable,
}
|
||||
|
||||
/// A single vital sign estimate.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct VitalEstimate {
    /// Estimated value in BPM (beats/breaths per minute).
    pub value_bpm: f64,
    /// Confidence in the estimate [0.0, 1.0].
    pub confidence: f64,
    /// Measurement status.
    pub status: VitalStatus,
}
|
||||
|
||||
/// Combined vital sign reading (respiratory + heart rate plus
/// signal-quality metadata).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct VitalReading {
    /// Respiratory rate estimate.
    pub respiratory_rate: VitalEstimate,
    /// Heart rate estimate.
    pub heart_rate: VitalEstimate,
    /// Number of subcarriers used.
    pub subcarrier_count: usize,
    /// Signal quality score [0.0, 1.0].
    pub signal_quality: f64,
    /// Timestamp (seconds since epoch).
    pub timestamp_secs: f64,
}
|
||||
|
||||
/// Input frame for the vital sign pipeline.
///
/// `amplitudes` and `phases` are expected to each have `n_subcarriers`
/// elements; [`CsiFrame::new`] enforces this, but direct struct
/// construction does not.
#[derive(Debug, Clone)]
pub struct CsiFrame {
    /// Per-subcarrier amplitudes.
    pub amplitudes: Vec<f64>,
    /// Per-subcarrier phases (radians).
    pub phases: Vec<f64>,
    /// Number of subcarriers.
    pub n_subcarriers: usize,
    /// Sample index (monotonically increasing).
    pub sample_index: u64,
    /// Sample rate in Hz.
    pub sample_rate_hz: f64,
}
|
||||
|
||||
impl CsiFrame {
|
||||
/// Create a new CSI frame, validating that amplitude and phase
|
||||
/// vectors match the declared subcarrier count.
|
||||
///
|
||||
/// Returns `None` if the lengths are inconsistent.
|
||||
pub fn new(
|
||||
amplitudes: Vec<f64>,
|
||||
phases: Vec<f64>,
|
||||
n_subcarriers: usize,
|
||||
sample_index: u64,
|
||||
sample_rate_hz: f64,
|
||||
) -> Option<Self> {
|
||||
if amplitudes.len() != n_subcarriers || phases.len() != n_subcarriers {
|
||||
return None;
|
||||
}
|
||||
Some(Self {
|
||||
amplitudes,
|
||||
phases,
|
||||
n_subcarriers,
|
||||
sample_index,
|
||||
sample_rate_hz,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl VitalEstimate {
|
||||
/// Create an unavailable estimate (no measurement possible).
|
||||
pub fn unavailable() -> Self {
|
||||
Self {
|
||||
value_bpm: 0.0,
|
||||
confidence: 0.0,
|
||||
status: VitalStatus::Unavailable,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vital_status_equality() {
        assert_eq!(VitalStatus::Valid, VitalStatus::Valid);
        assert_ne!(VitalStatus::Valid, VitalStatus::Degraded);
    }

    #[test]
    fn vital_estimate_unavailable() {
        let est = VitalEstimate::unavailable();
        assert_eq!(est.status, VitalStatus::Unavailable);
        assert!((est.value_bpm - 0.0).abs() < f64::EPSILON);
        assert!((est.confidence - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn csi_frame_new_valid() {
        let frame = CsiFrame::new(
            vec![1.0, 2.0, 3.0],
            vec![0.1, 0.2, 0.3],
            3,
            0,
            100.0,
        );
        assert!(frame.is_some());
        let f = frame.unwrap();
        assert_eq!(f.n_subcarriers, 3);
        assert_eq!(f.amplitudes.len(), 3);
    }

    #[test]
    fn csi_frame_new_mismatched_lengths() {
        // 2 amplitudes vs declared 3 subcarriers -> rejected.
        let frame = CsiFrame::new(
            vec![1.0, 2.0],
            vec![0.1, 0.2, 0.3],
            3,
            0,
            100.0,
        );
        assert!(frame.is_none());
    }

    #[test]
    fn csi_frame_clone() {
        let frame = CsiFrame::new(vec![1.0], vec![0.5], 1, 42, 50.0).unwrap();
        let cloned = frame.clone();
        assert_eq!(cloned.sample_index, 42);
        assert_eq!(cloned.n_subcarriers, 1);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn vital_reading_serde_roundtrip() {
        let reading = VitalReading {
            respiratory_rate: VitalEstimate {
                value_bpm: 15.0,
                confidence: 0.9,
                status: VitalStatus::Valid,
            },
            heart_rate: VitalEstimate {
                value_bpm: 72.0,
                confidence: 0.85,
                status: VitalStatus::Valid,
            },
            subcarrier_count: 56,
            signal_quality: 0.92,
            timestamp_secs: 1_700_000_000.0,
        };
        let json = serde_json::to_string(&reading).unwrap();
        let parsed: VitalReading = serde_json::from_str(&json).unwrap();
        assert!((parsed.heart_rate.value_bpm - 72.0).abs() < f64::EPSILON);
    }
}
|
||||
@@ -0,0 +1,40 @@
|
||||
[package]
|
||||
name = "wifi-densepose-wifiscan"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description = "Multi-BSSID WiFi scanning domain layer for enhanced Windows WiFi DensePose sensing (ADR-022)"
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
# Logging
|
||||
tracing.workspace = true
|
||||
|
||||
# Serialization (optional, for domain types)
|
||||
serde = { workspace = true, optional = true }
|
||||
|
||||
# Async runtime (optional, for Tier 2 async scanning)
|
||||
tokio = { workspace = true, optional = true }
|
||||
|
||||
[features]
|
||||
default = ["serde", "pipeline"]
|
||||
serde = ["dep:serde"]
|
||||
pipeline = []
|
||||
## Tier 2: enables async scan_async() method on WlanApiScanner via tokio
|
||||
wlanapi = ["dep:tokio"]
|
||||
|
||||
[lints.rust]
|
||||
unsafe_code = "forbid"
|
||||
|
||||
[lints.clippy]
|
||||
all = "warn"
|
||||
pedantic = "warn"
|
||||
doc_markdown = "allow"
|
||||
module_name_repetitions = "allow"
|
||||
must_use_candidate = "allow"
|
||||
missing_errors_doc = "allow"
|
||||
missing_panics_doc = "allow"
|
||||
cast_precision_loss = "allow"
|
||||
cast_lossless = "allow"
|
||||
many_single_char_names = "allow"
|
||||
uninlined_format_args = "allow"
|
||||
assigning_clones = "allow"
|
||||
@@ -0,0 +1,12 @@
|
||||
//! Adapter implementations for the [`WlanScanPort`] port.
//!
//! Each adapter targets a specific platform scanning mechanism:
//! - [`NetshBssidScanner`]: Tier 1 -- parses `netsh wlan show networks mode=bssid`.
//! - [`WlanApiScanner`]: Tier 2 -- async wrapper with metrics and future native FFI path.

// `netsh_scanner` is crate-private; its scanner type and parser are
// re-exported below so callers use `adapter::NetshBssidScanner` directly.
pub(crate) mod netsh_scanner;
pub mod wlanapi_scanner;

pub use netsh_scanner::NetshBssidScanner;
pub use netsh_scanner::parse_netsh_output;
pub use wlanapi_scanner::WlanApiScanner;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,474 @@
|
||||
//! Tier 2: Windows WLAN API adapter for higher scan rates.
|
||||
//!
|
||||
//! This module provides a higher-rate scanning interface that targets 10-20 Hz
|
||||
//! scan rates compared to the Tier 1 [`NetshBssidScanner`]'s ~2 Hz limitation
|
||||
//! (caused by subprocess spawn overhead per scan).
|
||||
//!
|
||||
//! # Current implementation
|
||||
//!
|
||||
//! The adapter currently wraps [`NetshBssidScanner`] and provides:
|
||||
//!
|
||||
//! - **Synchronous scanning** via [`WlanScanPort`] trait implementation
|
||||
//! - **Async scanning** (feature-gated behind `"wlanapi"`) via
|
||||
//! `tokio::task::spawn_blocking`
|
||||
//! - **Scan metrics** (count, timing) for performance monitoring
|
||||
//! - **Rate estimation** based on observed inter-scan intervals
|
||||
//!
|
||||
//! # Future: native `wlanapi.dll` FFI
|
||||
//!
|
||||
//! When native WLAN API bindings are available, this adapter will call:
|
||||
//!
|
||||
//! - `WlanOpenHandle` -- open a session to the WLAN service
|
||||
//! - `WlanEnumInterfaces` -- discover WLAN adapters
|
||||
//! - `WlanScan` -- trigger a fresh scan
|
||||
//! - `WlanGetNetworkBssList` -- retrieve raw BSS entries with RSSI
|
||||
//! - `WlanCloseHandle` -- clean up the session handle
|
||||
//!
|
||||
//! This eliminates the `netsh.exe` process-spawn bottleneck and enables
|
||||
//! true 10-20 Hz scan rates suitable for real-time sensing.
|
||||
//!
|
||||
//! # Platform
|
||||
//!
|
||||
//! Windows only. On other platforms this module is not compiled.
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::adapter::netsh_scanner::NetshBssidScanner;
|
||||
use crate::domain::bssid::BssidObservation;
|
||||
use crate::error::WifiScanError;
|
||||
use crate::port::WlanScanPort;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Scan metrics
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Accumulated metrics from scan operations.
#[derive(Debug, Clone)]
pub struct ScanMetrics {
    /// Total number of scans performed since creation.
    pub scan_count: u64,
    /// Total number of BSSIDs observed across all scans.
    pub total_bssids_observed: u64,
    /// Duration of the most recent scan; `None` before the first scan.
    pub last_scan_duration: Option<Duration>,
    /// Estimated scan rate in Hz derived from the last scan duration;
    /// `None` if no scans have been performed yet.
    pub estimated_rate_hz: Option<f64>,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WlanApiScanner
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Tier 2 WLAN API scanner with async support and scan metrics.
///
/// Currently wraps [`NetshBssidScanner`] with performance instrumentation.
/// When native WLAN API bindings become available, the inner implementation
/// will switch to `WlanGetNetworkBssList` for approximately 10x higher scan
/// rates without changing the public interface.
///
/// # Example (sync)
///
/// ```no_run
/// use wifi_densepose_wifiscan::adapter::wlanapi_scanner::WlanApiScanner;
/// use wifi_densepose_wifiscan::port::WlanScanPort;
///
/// let scanner = WlanApiScanner::new();
/// let observations = scanner.scan().unwrap();
/// for obs in &observations {
///     println!("{}: {} dBm", obs.bssid, obs.rssi_dbm);
/// }
/// println!("metrics: {:?}", scanner.metrics());
/// ```
pub struct WlanApiScanner {
    /// The underlying Tier 1 scanner.
    inner: NetshBssidScanner,

    /// Number of scans performed.
    scan_count: AtomicU64,

    /// Total BSSIDs observed across all scans.
    total_bssids: AtomicU64,

    /// Timestamp of the most recent scan start (for rate estimation).
    ///
    /// Uses `std::sync::Mutex` because `Instant` is not atomic but we need
    /// interior mutability. The lock duration is negligible (one write per
    /// scan) so contention is not a concern.
    last_scan_start: std::sync::Mutex<Option<Instant>>,

    /// Duration of the most recent scan.
    last_scan_duration: std::sync::Mutex<Option<Duration>>,
}
|
||||
|
||||
impl WlanApiScanner {
|
||||
/// Create a new Tier 2 scanner.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: NetshBssidScanner::new(),
|
||||
scan_count: AtomicU64::new(0),
|
||||
total_bssids: AtomicU64::new(0),
|
||||
last_scan_start: std::sync::Mutex::new(None),
|
||||
last_scan_duration: std::sync::Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return accumulated scan metrics.
|
||||
pub fn metrics(&self) -> ScanMetrics {
|
||||
let scan_count = self.scan_count.load(Ordering::Relaxed);
|
||||
let total_bssids_observed = self.total_bssids.load(Ordering::Relaxed);
|
||||
let last_scan_duration =
|
||||
*self.last_scan_duration.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let estimated_rate_hz = last_scan_duration.map(|d| {
|
||||
let secs = d.as_secs_f64();
|
||||
if secs > 0.0 {
|
||||
1.0 / secs
|
||||
} else {
|
||||
f64::INFINITY
|
||||
}
|
||||
});
|
||||
|
||||
ScanMetrics {
|
||||
scan_count,
|
||||
total_bssids_observed,
|
||||
last_scan_duration,
|
||||
estimated_rate_hz,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the number of scans performed so far.
|
||||
pub fn scan_count(&self) -> u64 {
|
||||
self.scan_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Perform a synchronous scan with timing instrumentation.
|
||||
///
|
||||
/// This is the core scan method that both the [`WlanScanPort`] trait
|
||||
/// implementation and the async wrapper delegate to.
|
||||
fn scan_instrumented(&self) -> Result<Vec<BssidObservation>, WifiScanError> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Record scan start time.
|
||||
if let Ok(mut guard) = self.last_scan_start.lock() {
|
||||
*guard = Some(start);
|
||||
}
|
||||
|
||||
// Delegate to the Tier 1 scanner.
|
||||
let results = self.inner.scan_sync()?;
|
||||
|
||||
// Record metrics.
|
||||
let elapsed = start.elapsed();
|
||||
if let Ok(mut guard) = self.last_scan_duration.lock() {
|
||||
*guard = Some(elapsed);
|
||||
}
|
||||
|
||||
self.scan_count.fetch_add(1, Ordering::Relaxed);
|
||||
self.total_bssids
|
||||
.fetch_add(results.len() as u64, Ordering::Relaxed);
|
||||
|
||||
tracing::debug!(
|
||||
scan_count = self.scan_count.load(Ordering::Relaxed),
|
||||
bssid_count = results.len(),
|
||||
elapsed_ms = elapsed.as_millis(),
|
||||
"Tier 2 scan complete"
|
||||
);
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Perform an async scan by offloading the blocking netsh call to
|
||||
/// a background thread.
|
||||
///
|
||||
/// This is gated behind the `"wlanapi"` feature because it requires
|
||||
/// the `tokio` runtime dependency.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`WifiScanError::ScanFailed`] if the background task panics
|
||||
/// or is cancelled, or propagates any error from the underlying scan.
|
||||
#[cfg(feature = "wlanapi")]
|
||||
pub async fn scan_async(&self) -> Result<Vec<BssidObservation>, WifiScanError> {
|
||||
// We need to create a fresh scanner for the blocking task because
|
||||
// `&self` is not `Send` across the spawn_blocking boundary.
|
||||
// `NetshBssidScanner` is cheap (zero-size struct) so this is fine.
|
||||
let inner = NetshBssidScanner::new();
|
||||
let start = Instant::now();
|
||||
|
||||
let results = tokio::task::spawn_blocking(move || inner.scan_sync())
|
||||
.await
|
||||
.map_err(|e| WifiScanError::ScanFailed {
|
||||
reason: format!("async scan task failed: {e}"),
|
||||
})??;
|
||||
|
||||
// Record metrics.
|
||||
let elapsed = start.elapsed();
|
||||
if let Ok(mut guard) = self.last_scan_duration.lock() {
|
||||
*guard = Some(elapsed);
|
||||
}
|
||||
self.scan_count.fetch_add(1, Ordering::Relaxed);
|
||||
self.total_bssids
|
||||
.fetch_add(results.len() as u64, Ordering::Relaxed);
|
||||
|
||||
tracing::debug!(
|
||||
scan_count = self.scan_count.load(Ordering::Relaxed),
|
||||
bssid_count = results.len(),
|
||||
elapsed_ms = elapsed.as_millis(),
|
||||
"Tier 2 async scan complete"
|
||||
);
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WlanApiScanner {
    /// Equivalent to [`WlanApiScanner::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WlanScanPort implementation (sync)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
impl WlanScanPort for WlanApiScanner {
|
||||
fn scan(&self) -> Result<Vec<BssidObservation>, WifiScanError> {
|
||||
self.scan_instrumented()
|
||||
}
|
||||
|
||||
fn connected(&self) -> Result<Option<BssidObservation>, WifiScanError> {
|
||||
// Not yet implemented for Tier 2 -- fall back to a full scan and
|
||||
// return the strongest signal (heuristic for "likely connected").
|
||||
let mut results = self.scan_instrumented()?;
|
||||
if results.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
// Sort by signal strength descending; return the strongest.
|
||||
results.sort_by(|a, b| {
|
||||
b.rssi_dbm
|
||||
.partial_cmp(&a.rssi_dbm)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
Ok(Some(results.swap_remove(0)))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Native WLAN API constants and frequency utilities
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Native WLAN API constants and frequency conversion utilities.
|
||||
///
|
||||
/// When implemented, this will contain:
|
||||
///
|
||||
/// ```ignore
|
||||
/// extern "system" {
|
||||
/// fn WlanOpenHandle(
|
||||
/// dwClientVersion: u32,
|
||||
/// pReserved: *const std::ffi::c_void,
|
||||
/// pdwNegotiatedVersion: *mut u32,
|
||||
/// phClientHandle: *mut HANDLE,
|
||||
/// ) -> u32;
|
||||
///
|
||||
/// fn WlanEnumInterfaces(
|
||||
/// hClientHandle: HANDLE,
|
||||
/// pReserved: *const std::ffi::c_void,
|
||||
/// ppInterfaceList: *mut *mut WLAN_INTERFACE_INFO_LIST,
|
||||
/// ) -> u32;
|
||||
///
|
||||
/// fn WlanGetNetworkBssList(
|
||||
/// hClientHandle: HANDLE,
|
||||
/// pInterfaceGuid: *const GUID,
|
||||
/// pDot11Ssid: *const DOT11_SSID,
|
||||
/// dot11BssType: DOT11_BSS_TYPE,
|
||||
/// bSecurityEnabled: BOOL,
|
||||
/// pReserved: *const std::ffi::c_void,
|
||||
/// ppWlanBssList: *mut *mut WLAN_BSS_LIST,
|
||||
/// ) -> u32;
|
||||
///
|
||||
/// fn WlanCloseHandle(
|
||||
/// hClientHandle: HANDLE,
|
||||
/// pReserved: *const std::ffi::c_void,
|
||||
/// ) -> u32;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// The native API returns `WLAN_BSS_ENTRY` structs that include:
|
||||
/// - `dot11Bssid` (6-byte MAC)
|
||||
/// - `lRssi` (dBm as i32)
|
||||
/// - `ulChCenterFrequency` (kHz, from which channel/band are derived)
|
||||
/// - `dot11BssPhyType` (maps to `RadioType`)
|
||||
///
|
||||
/// This eliminates the netsh subprocess overhead entirely.
|
||||
#[allow(dead_code)]
mod wlan_ffi {
    /// WLAN API client version 2 (Vista+).
    pub const WLAN_CLIENT_VERSION_2: u32 = 2;

    /// BSS type for infrastructure networks.
    pub const DOT11_BSS_TYPE_INFRASTRUCTURE: u32 = 1;

    /// Convert a center frequency in kHz to an 802.11 channel number.
    ///
    /// Covers the 2.4 GHz (ch 1-14), 5 GHz, and 6 GHz (Wi-Fi 6E) bands;
    /// frequencies outside all three bands map to channel 0.
    #[allow(clippy::cast_possible_truncation)] // Channel numbers always fit in u8
    pub fn freq_khz_to_channel(frequency_khz: u32) -> u8 {
        let mhz = frequency_khz / 1000;
        if (2412..=2472).contains(&mhz) {
            // 2.4 GHz band: channels 1-13 at 5 MHz spacing from 2407 MHz.
            ((mhz - 2407) / 5) as u8
        } else if mhz == 2484 {
            // Channel 14 sits off the regular 5 MHz grid.
            14
        } else if (5170..=5825).contains(&mhz) {
            // 5 GHz band.
            ((mhz - 5000) / 5) as u8
        } else if (5955..=7115).contains(&mhz) {
            // 6 GHz band (Wi-Fi 6E); channel 1 is centered at 5955 MHz.
            ((mhz - 5950) / 5) as u8
        } else {
            0
        }
    }

    /// Convert a center frequency in kHz to a band type discriminant.
    ///
    /// Returns 0 for 2.4 GHz (and anything unrecognised), 1 for 5 GHz,
    /// and 2 for 6 GHz.
    pub fn freq_khz_to_band(frequency_khz: u32) -> u8 {
        match frequency_khz / 1000 {
            5000..=5900 => 1, // 5 GHz
            5925..=7200 => 2, // 6 GHz
            _ => 0,           // 2.4 GHz and unknown
        }
    }
}
|
||||
|
||||
// ===========================================================================
|
||||
// Tests
|
||||
// ===========================================================================
|
||||
|
||||
// Unit tests for `WlanApiScanner`: construction, metrics bookkeeping, the
// `wlan_ffi` frequency helpers, and compile-time trait guarantees.
#[cfg(test)]
mod tests {
    use super::*;

    // -- construction ---------------------------------------------------------

    #[test]
    fn new_creates_scanner_with_zero_metrics() {
        // A freshly built scanner must report no scans, no observations,
        // and no duration/rate estimates.
        let scanner = WlanApiScanner::new();
        assert_eq!(scanner.scan_count(), 0);

        let m = scanner.metrics();
        assert_eq!(m.scan_count, 0);
        assert_eq!(m.total_bssids_observed, 0);
        assert!(m.last_scan_duration.is_none());
        assert!(m.estimated_rate_hz.is_none());
    }

    #[test]
    fn default_creates_scanner() {
        // Default must behave exactly like `new()`.
        let scanner = WlanApiScanner::default();
        assert_eq!(scanner.scan_count(), 0);
    }

    // -- frequency conversion (FFI placeholder) --------------------------------

    #[test]
    fn freq_khz_to_channel_2_4ghz() {
        assert_eq!(wlan_ffi::freq_khz_to_channel(2_412_000), 1);
        assert_eq!(wlan_ffi::freq_khz_to_channel(2_437_000), 6);
        assert_eq!(wlan_ffi::freq_khz_to_channel(2_462_000), 11);
        // Channel 14 (2484 MHz) sits off the regular 5 MHz grid.
        assert_eq!(wlan_ffi::freq_khz_to_channel(2_484_000), 14);
    }

    #[test]
    fn freq_khz_to_channel_5ghz() {
        assert_eq!(wlan_ffi::freq_khz_to_channel(5_180_000), 36);
        assert_eq!(wlan_ffi::freq_khz_to_channel(5_240_000), 48);
        assert_eq!(wlan_ffi::freq_khz_to_channel(5_745_000), 149);
    }

    #[test]
    fn freq_khz_to_channel_6ghz() {
        // 6 GHz channel 1 = 5955 MHz
        assert_eq!(wlan_ffi::freq_khz_to_channel(5_955_000), 1);
        // 6 GHz channel 5 = 5975 MHz
        assert_eq!(wlan_ffi::freq_khz_to_channel(5_975_000), 5);
    }

    #[test]
    fn freq_khz_to_channel_unknown_returns_zero() {
        // Frequencies outside all recognised bands map to the sentinel 0.
        assert_eq!(wlan_ffi::freq_khz_to_channel(900_000), 0);
        assert_eq!(wlan_ffi::freq_khz_to_channel(0), 0);
    }

    #[test]
    fn freq_khz_to_band_classification() {
        assert_eq!(wlan_ffi::freq_khz_to_band(2_437_000), 0); // 2.4 GHz
        assert_eq!(wlan_ffi::freq_khz_to_band(5_180_000), 1); // 5 GHz
        assert_eq!(wlan_ffi::freq_khz_to_band(5_975_000), 2); // 6 GHz
    }

    // -- WlanScanPort trait compliance -----------------------------------------

    #[test]
    fn implements_wlan_scan_port() {
        // Compile-time check: WlanApiScanner implements WlanScanPort.
        fn assert_port<T: WlanScanPort>() {}
        assert_port::<WlanApiScanner>();
    }

    #[test]
    fn implements_send_and_sync() {
        // Compile-time check: the scanner can be shared across threads.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<WlanApiScanner>();
    }

    // -- metrics structure -----------------------------------------------------

    #[test]
    fn scan_metrics_debug_display() {
        // Debug output should surface the raw counters for log inspection.
        let m = ScanMetrics {
            scan_count: 42,
            total_bssids_observed: 126,
            last_scan_duration: Some(Duration::from_millis(150)),
            estimated_rate_hz: Some(1.0 / 0.15),
        };
        let debug = format!("{m:?}");
        assert!(debug.contains("42"));
        assert!(debug.contains("126"));
    }

    #[test]
    fn scan_metrics_clone() {
        let m = ScanMetrics {
            scan_count: 1,
            total_bssids_observed: 5,
            last_scan_duration: None,
            estimated_rate_hz: None,
        };
        let m2 = m.clone();
        assert_eq!(m2.scan_count, 1);
        assert_eq!(m2.total_bssids_observed, 5);
    }

    // -- rate estimation -------------------------------------------------------

    #[test]
    fn estimated_rate_from_known_duration() {
        let scanner = WlanApiScanner::new();

        // Manually set last_scan_duration to simulate a completed scan.
        // (Scoped so the mutex guard is dropped before `metrics()` locks it.)
        {
            let mut guard = scanner.last_scan_duration.lock().unwrap();
            *guard = Some(Duration::from_millis(100));
        }

        let m = scanner.metrics();
        let rate = m.estimated_rate_hz.unwrap();
        // 100ms per scan => 10 Hz
        assert!((rate - 10.0).abs() < 0.01, "expected ~10 Hz, got {rate}");
    }

    #[test]
    fn estimated_rate_none_before_first_scan() {
        // No scan has run, so there is no duration to derive a rate from.
        let scanner = WlanApiScanner::new();
        assert!(scanner.metrics().estimated_rate_hz.is_none());
    }
}
|
||||
@@ -0,0 +1,282 @@
|
||||
//! Core value objects for BSSID identification and observation.
|
||||
//!
|
||||
//! These types form the shared kernel of the BSSID Acquisition bounded context
|
||||
//! as defined in ADR-022 section 3.1.
|
||||
|
||||
use std::fmt;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::WifiScanError;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BssidId -- Value Object
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A unique BSSID identifier wrapping a 6-byte IEEE 802.11 MAC address.
///
/// This is the primary identity for access points in the multi-BSSID scanning
/// pipeline. Two `BssidId` values are equal when their MAC bytes match.
/// The derived `Ord`/`PartialOrd` compare the raw bytes lexicographically,
/// and `Copy` is cheap since the type is just 6 bytes.
#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct BssidId(pub [u8; 6]);
|
||||
|
||||
impl BssidId {
|
||||
/// Create a `BssidId` from a byte slice.
|
||||
///
|
||||
/// Returns an error if the slice is not exactly 6 bytes.
|
||||
pub fn from_bytes(bytes: &[u8]) -> Result<Self, WifiScanError> {
|
||||
let arr: [u8; 6] = bytes
|
||||
.try_into()
|
||||
.map_err(|_| WifiScanError::InvalidMac { len: bytes.len() })?;
|
||||
Ok(Self(arr))
|
||||
}
|
||||
|
||||
/// Parse a `BssidId` from a colon-separated hex string such as
|
||||
/// `"aa:bb:cc:dd:ee:ff"`.
|
||||
pub fn parse(s: &str) -> Result<Self, WifiScanError> {
|
||||
let parts: Vec<&str> = s.split(':').collect();
|
||||
if parts.len() != 6 {
|
||||
return Err(WifiScanError::MacParseFailed {
|
||||
input: s.to_owned(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut bytes = [0u8; 6];
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
bytes[i] = u8::from_str_radix(part, 16).map_err(|_| WifiScanError::MacParseFailed {
|
||||
input: s.to_owned(),
|
||||
})?;
|
||||
}
|
||||
Ok(Self(bytes))
|
||||
}
|
||||
|
||||
/// Return the raw 6-byte MAC address.
|
||||
pub fn as_bytes(&self) -> &[u8; 6] {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for BssidId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "BssidId({self})")
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for BssidId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let [a, b, c, d, e, g] = self.0;
|
||||
write!(f, "{a:02x}:{b:02x}:{c:02x}:{d:02x}:{e:02x}:{g:02x}")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BandType -- Value Object
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The WiFi frequency band on which a BSSID operates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BandType {
    /// 2.4 GHz (channels 1-14)
    Band2_4GHz,
    /// 5 GHz (channels 36-177)
    Band5GHz,
    /// 6 GHz (Wi-Fi 6E / 7)
    Band6GHz,
}

impl BandType {
    /// Infer the band from an 802.11 channel number.
    ///
    /// Channels 1-14 map to 2.4 GHz, 32-177 to 5 GHz, and everything else
    /// is treated as 6 GHz (whose channel numbering overlaps the others,
    /// so this is a heuristic rather than an exact classification).
    pub fn from_channel(channel: u8) -> Self {
        if (1..=14).contains(&channel) {
            Self::Band2_4GHz
        } else if (32..=177).contains(&channel) {
            Self::Band5GHz
        } else {
            Self::Band6GHz
        }
    }
}

impl fmt::Display for BandType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Band2_4GHz => "2.4 GHz",
            Self::Band5GHz => "5 GHz",
            Self::Band6GHz => "6 GHz",
        };
        f.write_str(label)
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RadioType -- Value Object
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The 802.11 radio standard reported by the access point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RadioType {
    /// 802.11n (Wi-Fi 4)
    N,
    /// 802.11ac (Wi-Fi 5)
    Ac,
    /// 802.11ax (Wi-Fi 6 / 6E)
    Ax,
    /// 802.11be (Wi-Fi 7)
    Be,
}

impl RadioType {
    /// Parse a radio type from a `netsh` output string such as `"802.11ax"`.
    ///
    /// Matching is case-insensitive and substring-based, checked from the
    /// newest standard to the oldest. Returns `None` for unrecognised
    /// strings.
    pub fn from_netsh_str(s: &str) -> Option<Self> {
        let lower = s.trim().to_ascii_lowercase();
        let has = |needle: &str| lower.contains(needle);

        if has("802.11be") || has("be") {
            return Some(Self::Be);
        }
        if has("802.11ax") || has("ax") || has("wi-fi 6") {
            return Some(Self::Ax);
        }
        if has("802.11ac") || has("ac") || has("wi-fi 5") {
            return Some(Self::Ac);
        }
        if has("802.11n") || has("wi-fi 4") {
            return Some(Self::N);
        }
        None
    }
}

impl fmt::Display for RadioType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::N => "802.11n",
            Self::Ac => "802.11ac",
            Self::Ax => "802.11ax",
            Self::Be => "802.11be",
        };
        f.write_str(label)
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BssidObservation -- Value Object
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A single observation of a BSSID from a WiFi scan.
///
/// This is the fundamental measurement unit: one access point observed once
/// at a specific point in time. Unlike `BandType`/`RadioType` it has no
/// serde derive, since `Instant` is not serialisable.
#[derive(Clone, Debug)]
pub struct BssidObservation {
    /// The MAC address of the observed access point.
    pub bssid: BssidId,
    /// Received signal strength in dBm (typically -30 to -90).
    pub rssi_dbm: f64,
    /// Signal quality as a percentage (0-100), as reported by the driver.
    pub signal_pct: f64,
    /// The 802.11 channel number.
    pub channel: u8,
    /// The frequency band.
    pub band: BandType,
    /// The 802.11 radio standard.
    pub radio_type: RadioType,
    /// The SSID (network name). May be empty for hidden networks.
    pub ssid: String,
    /// When this observation was captured.
    pub timestamp: Instant,
}
|
||||
|
||||
impl BssidObservation {
    /// Convert signal percentage (0-100) to an approximate dBm value.
    ///
    /// Uses the common linear mapping: `dBm = (pct / 2) - 100`, i.e.
    /// 0% -> -100 dBm and 100% -> -50 dBm.
    /// This matches the conversion used by Windows WLAN API.
    pub fn pct_to_dbm(pct: f64) -> f64 {
        (pct / 2.0) - 100.0
    }

    /// Convert dBm to a linear amplitude suitable for pseudo-CSI frames.
    ///
    /// Formula: `10^((rssi_dbm + 100) / 20)`, mapping -100 dBm to 1.0
    /// and -80 dBm to 10.0.
    pub fn rssi_to_amplitude(rssi_dbm: f64) -> f64 {
        10.0_f64.powf((rssi_dbm + 100.0) / 20.0)
    }

    /// Return the amplitude of this observation (linear scale).
    pub fn amplitude(&self) -> f64 {
        Self::rssi_to_amplitude(self.rssi_dbm)
    }

    /// Encode the channel number as a pseudo-phase value.
    ///
    /// Computed as `(channel / 48) * pi`. Note this lies in `[0, pi]` only
    /// for channels <= 48; 5/6 GHz channels above 48 (e.g. 149) yield values
    /// beyond pi. NOTE(review): if downstream stages require `[0, pi]`, the
    /// channel should be normalised first -- confirm against consumers.
    ///
    /// This provides downstream pipeline compatibility with code that expects
    /// phase data, even though RSSI-based scanning has no true phase.
    pub fn pseudo_phase(&self) -> f64 {
        (self.channel as f64 / 48.0) * std::f64::consts::PI
    }
}
|
||||
|
||||
// Unit tests for the BSSID value objects: identifier parsing/formatting,
// band inference, radio-type parsing, and RSSI conversion helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bssid_id_roundtrip() {
        // Display and parse must be inverses for a canonical MAC string.
        let mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff];
        let id = BssidId(mac);
        assert_eq!(id.to_string(), "aa:bb:cc:dd:ee:ff");
        assert_eq!(BssidId::parse("aa:bb:cc:dd:ee:ff").unwrap(), id);
    }

    #[test]
    fn bssid_id_parse_errors() {
        // Too few octets, non-hex digits, and empty input are all rejected.
        assert!(BssidId::parse("aa:bb:cc").is_err());
        assert!(BssidId::parse("zz:bb:cc:dd:ee:ff").is_err());
        assert!(BssidId::parse("").is_err());
    }

    #[test]
    fn bssid_id_from_bytes() {
        let bytes = vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        let id = BssidId::from_bytes(&bytes).unwrap();
        assert_eq!(id.0, [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]);

        // Anything other than exactly 6 bytes is an error.
        assert!(BssidId::from_bytes(&[0x01, 0x02]).is_err());
    }

    #[test]
    fn band_type_from_channel() {
        assert_eq!(BandType::from_channel(1), BandType::Band2_4GHz);
        assert_eq!(BandType::from_channel(11), BandType::Band2_4GHz);
        assert_eq!(BandType::from_channel(36), BandType::Band5GHz);
        assert_eq!(BandType::from_channel(149), BandType::Band5GHz);
    }

    #[test]
    fn radio_type_from_netsh() {
        assert_eq!(RadioType::from_netsh_str("802.11ax"), Some(RadioType::Ax));
        assert_eq!(RadioType::from_netsh_str("802.11ac"), Some(RadioType::Ac));
        assert_eq!(RadioType::from_netsh_str("802.11n"), Some(RadioType::N));
        assert_eq!(RadioType::from_netsh_str("802.11be"), Some(RadioType::Be));
        assert_eq!(RadioType::from_netsh_str("unknown"), None);
    }

    #[test]
    fn pct_to_dbm_conversion() {
        // 100% -> -50 dBm
        assert!((BssidObservation::pct_to_dbm(100.0) - (-50.0)).abs() < f64::EPSILON);
        // 0% -> -100 dBm
        assert!((BssidObservation::pct_to_dbm(0.0) - (-100.0)).abs() < f64::EPSILON);
    }

    #[test]
    fn rssi_to_amplitude_baseline() {
        // At -100 dBm, amplitude should be 1.0
        let amp = BssidObservation::rssi_to_amplitude(-100.0);
        assert!((amp - 1.0).abs() < 1e-9);
        // At -80 dBm, amplitude should be 10.0
        let amp = BssidObservation::rssi_to_amplitude(-80.0);
        assert!((amp - 10.0).abs() < 1e-9);
    }
}
|
||||
@@ -0,0 +1,148 @@
|
||||
//! Multi-AP frame value object.
|
||||
//!
|
||||
//! A `MultiApFrame` is a snapshot of all BSSID observations at a single point
|
||||
//! in time. It serves as the input to the signal intelligence pipeline
|
||||
//! (Bounded Context 2 in ADR-022), providing the multi-dimensional
|
||||
//! pseudo-CSI data that replaces the single-RSSI approach.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Instant;
|
||||
|
||||
/// A snapshot of all tracked BSSIDs at a single point in time.
///
/// This value object is produced by [`BssidRegistry::to_multi_ap_frame`] and
/// consumed by the signal intelligence pipeline. Each index `i` in the
/// vectors corresponds to the `i`-th entry in the registry's subcarrier map.
///
/// Invariant (maintained by the producer, not enforced here): the per-BSSID
/// vectors (`rssi_dbm`, `amplitudes`, `phases`, `per_bssid_variance`,
/// `histories`) are expected to all have length `bssid_count` -- confirm
/// against the registry when relying on this.
///
/// [`BssidRegistry::to_multi_ap_frame`]: crate::domain::registry::BssidRegistry::to_multi_ap_frame
#[derive(Debug, Clone)]
pub struct MultiApFrame {
    /// Number of BSSIDs (pseudo-subcarriers) in this frame.
    pub bssid_count: usize,

    /// RSSI values in dBm, one per BSSID.
    ///
    /// Index matches the subcarrier map ordering.
    pub rssi_dbm: Vec<f64>,

    /// Linear amplitudes derived from RSSI via `10^((rssi + 100) / 20)`.
    ///
    /// This maps -100 dBm to amplitude 1.0, providing a scale that is
    /// compatible with the downstream attention and correlation stages.
    pub amplitudes: Vec<f64>,

    /// Pseudo-phase values derived from channel numbers.
    ///
    /// Encoded as `(channel / 48) * pi`, giving a value in `[0, pi]`.
    /// This is a heuristic that provides spatial diversity information
    /// to pipeline stages that expect phase data.
    pub phases: Vec<f64>,

    /// Per-BSSID RSSI variance (Welford), one per BSSID.
    ///
    /// High variance indicates a BSSID whose signal is modulated by body
    /// movement; low variance indicates a static background AP.
    pub per_bssid_variance: Vec<f64>,

    /// Per-BSSID RSSI history (ring buffer), one per BSSID.
    ///
    /// Used by the spatial correlator and breathing extractor to compute
    /// cross-correlation and spectral features.
    pub histories: Vec<VecDeque<f64>>,

    /// Estimated effective sample rate in Hz.
    ///
    /// Tier 1 (netsh): approximately 2 Hz.
    /// Tier 2 (wlanapi): approximately 10-20 Hz.
    pub sample_rate_hz: f64,

    /// When this frame was constructed.
    pub timestamp: Instant,
}
|
||||
|
||||
impl MultiApFrame {
|
||||
/// Whether this frame has enough BSSIDs for multi-AP sensing.
|
||||
///
|
||||
/// The `min_bssids` parameter comes from `WindowsWifiConfig::min_bssids`.
|
||||
pub fn is_sufficient(&self, min_bssids: usize) -> bool {
|
||||
self.bssid_count >= min_bssids
|
||||
}
|
||||
|
||||
/// The maximum amplitude across all BSSIDs. Returns 0.0 for empty frames.
|
||||
pub fn max_amplitude(&self) -> f64 {
|
||||
self.amplitudes
|
||||
.iter()
|
||||
.copied()
|
||||
.fold(0.0_f64, f64::max)
|
||||
}
|
||||
|
||||
/// The mean RSSI across all BSSIDs in dBm. Returns `f64::NEG_INFINITY`
|
||||
/// for empty frames.
|
||||
pub fn mean_rssi(&self) -> f64 {
|
||||
if self.rssi_dbm.is_empty() {
|
||||
return f64::NEG_INFINITY;
|
||||
}
|
||||
let sum: f64 = self.rssi_dbm.iter().sum();
|
||||
sum / self.rssi_dbm.len() as f64
|
||||
}
|
||||
|
||||
/// The total variance across all BSSIDs (sum of per-BSSID variances).
|
||||
///
|
||||
/// Higher values indicate more environmental change, which correlates
|
||||
/// with human presence and movement.
|
||||
pub fn total_variance(&self) -> f64 {
|
||||
self.per_bssid_variance.iter().sum()
|
||||
}
|
||||
}
|
||||
|
||||
// Unit tests for `MultiApFrame` aggregate statistics and sufficiency checks.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test frame with the given RSSI values. Amplitudes are derived
    /// with the same `10^((rssi + 100) / 20)` mapping the production code
    /// documents, and variances default to a uniform 0.1 per BSSID.
    fn make_frame(bssid_count: usize, rssi_values: &[f64]) -> MultiApFrame {
        let amplitudes: Vec<f64> = rssi_values
            .iter()
            .map(|&r| 10.0_f64.powf((r + 100.0) / 20.0))
            .collect();
        MultiApFrame {
            bssid_count,
            rssi_dbm: rssi_values.to_vec(),
            amplitudes,
            phases: vec![0.0; bssid_count],
            per_bssid_variance: vec![0.1; bssid_count],
            histories: vec![VecDeque::new(); bssid_count],
            sample_rate_hz: 2.0,
            timestamp: Instant::now(),
        }
    }

    #[test]
    fn is_sufficient_checks_threshold() {
        let frame = make_frame(5, &[-60.0, -65.0, -70.0, -75.0, -80.0]);
        assert!(frame.is_sufficient(3));
        // The threshold comparison is inclusive.
        assert!(frame.is_sufficient(5));
        assert!(!frame.is_sufficient(6));
    }

    #[test]
    fn mean_rssi_calculation() {
        let frame = make_frame(3, &[-60.0, -70.0, -80.0]);
        assert!((frame.mean_rssi() - (-70.0)).abs() < 1e-9);
    }

    #[test]
    fn empty_frame_handles_gracefully() {
        // A zero-BSSID frame must not panic and must report sentinel values
        // (0.0 amplitude/variance, NEG_INFINITY mean RSSI).
        let frame = make_frame(0, &[]);
        assert_eq!(frame.max_amplitude(), 0.0);
        assert!(frame.mean_rssi().is_infinite());
        assert_eq!(frame.total_variance(), 0.0);
        assert!(!frame.is_sufficient(1));
    }

    #[test]
    fn total_variance_sums_per_bssid() {
        let mut frame = make_frame(3, &[-60.0, -70.0, -80.0]);
        frame.per_bssid_variance = vec![0.1, 0.2, 0.3];
        assert!((frame.total_variance() - 0.6).abs() < 1e-9);
    }
}
|
||||
@@ -0,0 +1,11 @@
|
||||
//! Domain types for the BSSID Acquisition bounded context (ADR-022).
|
||||
|
||||
pub mod bssid;
|
||||
pub mod frame;
|
||||
pub mod registry;
|
||||
pub mod result;
|
||||
|
||||
pub use bssid::{BandType, BssidId, BssidObservation, RadioType};
|
||||
pub use frame::MultiApFrame;
|
||||
pub use registry::{BssidEntry, BssidMeta, BssidRegistry, RunningStats};
|
||||
pub use result::EnhancedSensingResult;
|
||||
@@ -0,0 +1,511 @@
|
||||
//! BSSID Registry aggregate root.
|
||||
//!
|
||||
//! The `BssidRegistry` is the aggregate root of the BSSID Acquisition bounded
|
||||
//! context. It tracks all visible access points across scans, maintains
|
||||
//! identity stability as BSSIDs appear and disappear, and provides a
|
||||
//! consistent subcarrier mapping for pseudo-CSI frame construction.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::domain::bssid::{BandType, BssidId, BssidObservation, RadioType};
|
||||
use crate::domain::frame::MultiApFrame;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RunningStats -- Welford online statistics
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Welford online algorithm for computing running mean and variance.
///
/// This allows us to compute per-BSSID statistics incrementally without
/// storing the entire history, which is essential for detecting which BSSIDs
/// show body-correlated variance versus static background.
#[derive(Debug, Clone)]
pub struct RunningStats {
    /// Number of samples seen.
    count: u64,
    /// Running mean.
    mean: f64,
    /// Running M2 accumulator (sum of squared differences from the mean).
    m2: f64,
}

impl RunningStats {
    /// Create a new empty `RunningStats`.
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Push a new sample into the running statistics.
    ///
    /// Standard Welford update: the M2 accumulator uses the deltas from the
    /// mean both before and after this sample is folded in.
    pub fn push(&mut self, value: f64) {
        self.count += 1;
        let delta_before = value - self.mean;
        self.mean += delta_before / self.count as f64;
        let delta_after = value - self.mean;
        self.m2 += delta_before * delta_after;
    }

    /// The number of samples observed.
    pub fn count(&self) -> u64 {
        self.count
    }

    /// The running mean. Returns 0.0 if no samples have been pushed.
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// The population variance (`M2 / n`). Returns 0.0 if fewer than 2 samples.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / n as f64,
        }
    }

    /// The sample variance (Bessel-corrected, `M2 / (n - 1)`).
    /// Returns 0.0 if fewer than 2 samples.
    pub fn sample_variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    /// The population standard deviation.
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Reset all statistics to zero.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}

impl Default for RunningStats {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BssidMeta -- metadata about a tracked BSSID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Metadata about a tracked BSSID, captured on first observation.
///
/// Note: despite being "metadata", every field except `first_seen` is
/// refreshed on later observations (see `BssidEntry::record`), since APs can
/// change channel, band, or advertised SSID over time.
#[derive(Debug, Clone)]
pub struct BssidMeta {
    /// The SSID (network name). May be empty for hidden networks.
    pub ssid: String,
    /// The 802.11 channel number.
    pub channel: u8,
    /// The frequency band.
    pub band: BandType,
    /// The radio standard.
    pub radio_type: RadioType,
    /// When this BSSID was first observed (never updated afterwards).
    pub first_seen: Instant,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BssidEntry -- Entity
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A tracked BSSID with observation history and running statistics.
///
/// Each entry corresponds to one physical access point. The ring buffer
/// stores recent RSSI values (in dBm) for temporal analysis, while the
/// `RunningStats` provides efficient online mean/variance without needing
/// the full history.
#[derive(Debug, Clone)]
pub struct BssidEntry {
    /// The unique identifier for this BSSID.
    pub id: BssidId,
    /// Static metadata (SSID, channel, band, radio type).
    pub meta: BssidMeta,
    /// Ring buffer of recent RSSI observations (dBm), bounded by
    /// `DEFAULT_HISTORY_CAPACITY`.
    pub history: VecDeque<f64>,
    /// Welford online statistics over the full observation lifetime
    /// (not just the window retained in `history`).
    pub stats: RunningStats,
    /// When this BSSID was last observed; drives registry expiry.
    pub last_seen: Instant,
    /// Index in the subcarrier map, or `None` if not yet assigned.
    pub subcarrier_idx: Option<usize>,
}
|
||||
|
||||
impl BssidEntry {
|
||||
/// Maximum number of RSSI samples kept in the ring buffer history.
|
||||
pub const DEFAULT_HISTORY_CAPACITY: usize = 128;
|
||||
|
||||
/// Create a new entry from a first observation.
|
||||
fn new(obs: &BssidObservation) -> Self {
|
||||
let mut stats = RunningStats::new();
|
||||
stats.push(obs.rssi_dbm);
|
||||
|
||||
let mut history = VecDeque::with_capacity(Self::DEFAULT_HISTORY_CAPACITY);
|
||||
history.push_back(obs.rssi_dbm);
|
||||
|
||||
Self {
|
||||
id: obs.bssid,
|
||||
meta: BssidMeta {
|
||||
ssid: obs.ssid.clone(),
|
||||
channel: obs.channel,
|
||||
band: obs.band,
|
||||
radio_type: obs.radio_type,
|
||||
first_seen: obs.timestamp,
|
||||
},
|
||||
history,
|
||||
stats,
|
||||
last_seen: obs.timestamp,
|
||||
subcarrier_idx: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a new observation for this BSSID.
|
||||
fn record(&mut self, obs: &BssidObservation) {
|
||||
self.stats.push(obs.rssi_dbm);
|
||||
|
||||
if self.history.len() >= Self::DEFAULT_HISTORY_CAPACITY {
|
||||
self.history.pop_front();
|
||||
}
|
||||
self.history.push_back(obs.rssi_dbm);
|
||||
|
||||
self.last_seen = obs.timestamp;
|
||||
|
||||
// Update mutable metadata in case the AP changed channel/band
|
||||
self.meta.channel = obs.channel;
|
||||
self.meta.band = obs.band;
|
||||
self.meta.radio_type = obs.radio_type;
|
||||
if !obs.ssid.is_empty() {
|
||||
self.meta.ssid = obs.ssid.clone();
|
||||
}
|
||||
}
|
||||
|
||||
/// The RSSI variance over the observation lifetime (Welford).
|
||||
pub fn variance(&self) -> f64 {
|
||||
self.stats.variance()
|
||||
}
|
||||
|
||||
/// The most recent RSSI observation in dBm.
|
||||
pub fn latest_rssi(&self) -> Option<f64> {
|
||||
self.history.back().copied()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BssidRegistry -- Aggregate Root
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Aggregate root that tracks all visible BSSIDs across scans.
///
/// The registry maintains:
/// - A map of known BSSIDs with per-BSSID history and statistics.
/// - An ordered subcarrier map that assigns each BSSID a stable index,
///   sorted by first-seen time so that the mapping is deterministic.
/// - Expiry logic to remove BSSIDs that have not been observed recently.
///
/// `entries` and `subcarrier_map` are kept in sync by the update/expiry
/// paths: every tracked id appears in both, and each entry's
/// `subcarrier_idx` matches its position in `subcarrier_map`.
#[derive(Debug, Clone)]
pub struct BssidRegistry {
    /// Known BSSIDs with sliding window of observations.
    entries: HashMap<BssidId, BssidEntry>,
    /// Ordered list of BSSID IDs for consistent subcarrier mapping.
    /// Sorted by first-seen time for stability.
    subcarrier_map: Vec<BssidId>,
    /// Maximum number of tracked BSSIDs (maps to max pseudo-subcarriers).
    max_bssids: usize,
    /// How long a BSSID can go unseen before being expired (in seconds).
    expiry_secs: u64,
}
|
||||
|
||||
impl BssidRegistry {
|
||||
/// Default maximum number of tracked BSSIDs.
|
||||
pub const DEFAULT_MAX_BSSIDS: usize = 32;
|
||||
|
||||
/// Default expiry time in seconds.
|
||||
pub const DEFAULT_EXPIRY_SECS: u64 = 30;
|
||||
|
||||
/// Create a new registry with the given capacity and expiry settings.
|
||||
pub fn new(max_bssids: usize, expiry_secs: u64) -> Self {
|
||||
Self {
|
||||
entries: HashMap::with_capacity(max_bssids),
|
||||
subcarrier_map: Vec::with_capacity(max_bssids),
|
||||
max_bssids,
|
||||
expiry_secs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the registry with a batch of observations from a single scan.
|
||||
///
|
||||
/// New BSSIDs are registered and assigned subcarrier indices. Existing
|
||||
/// BSSIDs have their history and statistics updated. BSSIDs that have
|
||||
/// not been seen within the expiry window are removed.
|
||||
pub fn update(&mut self, observations: &[BssidObservation]) {
|
||||
let now = if let Some(obs) = observations.first() {
|
||||
obs.timestamp
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Update or insert each observed BSSID
|
||||
for obs in observations {
|
||||
if let Some(entry) = self.entries.get_mut(&obs.bssid) {
|
||||
entry.record(obs);
|
||||
} else if self.subcarrier_map.len() < self.max_bssids {
|
||||
// New BSSID: register it
|
||||
let mut entry = BssidEntry::new(obs);
|
||||
let idx = self.subcarrier_map.len();
|
||||
entry.subcarrier_idx = Some(idx);
|
||||
self.subcarrier_map.push(obs.bssid);
|
||||
self.entries.insert(obs.bssid, entry);
|
||||
}
|
||||
// If we are at capacity, silently ignore new BSSIDs.
|
||||
// A smarter policy (evict lowest-variance) can be added later.
|
||||
}
|
||||
|
||||
// Expire stale BSSIDs
|
||||
self.expire(now);
|
||||
}
|
||||
|
||||
/// Remove BSSIDs that have not been observed within the expiry window.
|
||||
fn expire(&mut self, now: Instant) {
|
||||
let expiry = std::time::Duration::from_secs(self.expiry_secs);
|
||||
let stale: Vec<BssidId> = self
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|(_, entry)| now.duration_since(entry.last_seen) > expiry)
|
||||
.map(|(id, _)| *id)
|
||||
.collect();
|
||||
|
||||
for id in &stale {
|
||||
self.entries.remove(id);
|
||||
}
|
||||
|
||||
if !stale.is_empty() {
|
||||
// Rebuild the subcarrier map without the stale entries,
|
||||
// preserving relative ordering.
|
||||
self.subcarrier_map.retain(|id| !stale.contains(id));
|
||||
// Re-index remaining entries
|
||||
for (idx, id) in self.subcarrier_map.iter().enumerate() {
|
||||
if let Some(entry) = self.entries.get_mut(id) {
|
||||
entry.subcarrier_idx = Some(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up the subcarrier index assigned to a BSSID.
|
||||
pub fn subcarrier_index(&self, bssid: &BssidId) -> Option<usize> {
|
||||
self.entries
|
||||
.get(bssid)
|
||||
.and_then(|entry| entry.subcarrier_idx)
|
||||
}
|
||||
|
||||
/// Return the ordered subcarrier map (list of BSSID IDs).
|
||||
pub fn subcarrier_map(&self) -> &[BssidId] {
|
||||
&self.subcarrier_map
|
||||
}
|
||||
|
||||
/// The number of currently tracked BSSIDs.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
|
||||
/// The maximum number of BSSIDs this registry can track.
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.max_bssids
|
||||
}
|
||||
|
||||
/// Get an entry by BSSID ID.
|
||||
pub fn get(&self, bssid: &BssidId) -> Option<&BssidEntry> {
|
||||
self.entries.get(bssid)
|
||||
}
|
||||
|
||||
/// Iterate over all tracked entries.
|
||||
pub fn entries(&self) -> impl Iterator<Item = &BssidEntry> {
|
||||
self.entries.values()
|
||||
}
|
||||
|
||||
/// Build a `MultiApFrame` from the current registry state.
|
||||
///
|
||||
/// The frame contains one slot per subcarrier (BSSID), with amplitudes
|
||||
/// derived from the most recent RSSI observation and pseudo-phase from
|
||||
/// the channel number.
|
||||
pub fn to_multi_ap_frame(&self) -> MultiApFrame {
|
||||
let n = self.subcarrier_map.len();
|
||||
let mut rssi_dbm = vec![0.0_f64; n];
|
||||
let mut amplitudes = vec![0.0_f64; n];
|
||||
let mut phases = vec![0.0_f64; n];
|
||||
let mut per_bssid_variance = vec![0.0_f64; n];
|
||||
let mut histories: Vec<VecDeque<f64>> = Vec::with_capacity(n);
|
||||
|
||||
for (idx, bssid_id) in self.subcarrier_map.iter().enumerate() {
|
||||
if let Some(entry) = self.entries.get(bssid_id) {
|
||||
let latest = entry.latest_rssi().unwrap_or(-100.0);
|
||||
rssi_dbm[idx] = latest;
|
||||
amplitudes[idx] = BssidObservation::rssi_to_amplitude(latest);
|
||||
phases[idx] = (entry.meta.channel as f64 / 48.0) * std::f64::consts::PI;
|
||||
per_bssid_variance[idx] = entry.variance();
|
||||
histories.push(entry.history.clone());
|
||||
} else {
|
||||
histories.push(VecDeque::new());
|
||||
}
|
||||
}
|
||||
|
||||
// Estimate sample rate from observation count and time span
|
||||
let sample_rate_hz = self.estimate_sample_rate();
|
||||
|
||||
MultiApFrame {
|
||||
bssid_count: n,
|
||||
rssi_dbm,
|
||||
amplitudes,
|
||||
phases,
|
||||
per_bssid_variance,
|
||||
histories,
|
||||
sample_rate_hz,
|
||||
timestamp: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Rough estimate of the effective sample rate based on observation history.
|
||||
fn estimate_sample_rate(&self) -> f64 {
|
||||
// Default to 2 Hz (Tier 1 netsh rate) when we cannot compute
|
||||
if self.entries.is_empty() {
|
||||
return 2.0;
|
||||
}
|
||||
|
||||
// Use the first entry with enough history
|
||||
for entry in self.entries.values() {
|
||||
if entry.stats.count() >= 4 {
|
||||
let elapsed = entry
|
||||
.last_seen
|
||||
.duration_since(entry.meta.first_seen)
|
||||
.as_secs_f64();
|
||||
if elapsed > 0.0 {
|
||||
return entry.stats.count() as f64 / elapsed;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2.0 // Fallback: assume Tier 1 rate
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BssidRegistry {
    /// Build a registry with the default capacity (32 BSSIDs) and the
    /// default expiry window (30 s).
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_BSSIDS, Self::DEFAULT_EXPIRY_SECS)
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::bssid::{BandType, RadioType};

    /// Build a synthetic observation; signal percentage is derived from the
    /// RSSI so the two stay consistent.
    fn make_obs(mac: [u8; 6], rssi: f64, channel: u8) -> BssidObservation {
        BssidObservation {
            bssid: BssidId(mac),
            rssi_dbm: rssi,
            signal_pct: (rssi + 100.0) * 2.0,
            channel,
            band: BandType::from_channel(channel),
            radio_type: RadioType::Ax,
            ssid: "TestNetwork".to_string(),
            timestamp: Instant::now(),
        }
    }

    /// New BSSIDs are registered and assigned subcarrier slots in order.
    #[test]
    fn registry_tracks_new_bssids() {
        let mut reg = BssidRegistry::default();
        let obs = vec![
            make_obs([0x01; 6], -60.0, 6),
            make_obs([0x02; 6], -70.0, 36),
        ];
        reg.update(&obs);

        assert_eq!(reg.len(), 2);
        assert_eq!(reg.subcarrier_index(&BssidId([0x01; 6])), Some(0));
        assert_eq!(reg.subcarrier_index(&BssidId([0x02; 6])), Some(1));
    }

    /// Repeated observations of the same BSSID update stats and history.
    #[test]
    fn registry_updates_existing_bssid() {
        let mut reg = BssidRegistry::default();
        let mac = [0xaa; 6];

        let obs1 = vec![make_obs(mac, -60.0, 6)];
        reg.update(&obs1);

        let obs2 = vec![make_obs(mac, -65.0, 6)];
        reg.update(&obs2);

        let entry = reg.get(&BssidId(mac)).unwrap();
        assert_eq!(entry.stats.count(), 2);
        assert_eq!(entry.history.len(), 2);
        assert!((entry.stats.mean() - (-62.5)).abs() < 1e-9);
    }

    /// BSSIDs beyond the configured capacity are silently ignored.
    #[test]
    fn registry_respects_capacity() {
        let mut reg = BssidRegistry::new(2, 30);
        let obs = vec![
            make_obs([0x01; 6], -60.0, 1),
            make_obs([0x02; 6], -70.0, 6),
            make_obs([0x03; 6], -80.0, 11), // Should be ignored
        ];
        reg.update(&obs);

        assert_eq!(reg.len(), 2);
        assert!(reg.get(&BssidId([0x03; 6])).is_none());
    }

    /// A frame carries one slot per registered BSSID, in subcarrier order.
    #[test]
    fn to_multi_ap_frame_builds_correct_frame() {
        let mut reg = BssidRegistry::default();
        let obs = vec![
            make_obs([0x01; 6], -60.0, 6),
            make_obs([0x02; 6], -70.0, 36),
        ];
        reg.update(&obs);

        let frame = reg.to_multi_ap_frame();
        assert_eq!(frame.bssid_count, 2);
        assert_eq!(frame.rssi_dbm.len(), 2);
        assert_eq!(frame.amplitudes.len(), 2);
        assert_eq!(frame.phases.len(), 2);
        assert!(frame.amplitudes[0] > frame.amplitudes[1]); // -60 dBm > -70 dBm
    }

    /// Welford running stats match the textbook values for a known dataset.
    #[test]
    fn welford_stats_accuracy() {
        let mut stats = RunningStats::new();
        let values = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        for v in &values {
            stats.push(*v);
        }

        assert_eq!(stats.count(), 8);
        assert!((stats.mean() - 5.0).abs() < 1e-9);
        // Population variance of this dataset is 4.0
        assert!((stats.variance() - 4.0).abs() < 1e-9);
        // Sample variance is 4.571428...
        assert!((stats.sample_variance() - (32.0 / 7.0)).abs() < 1e-9);
    }
}
|
||||
@@ -0,0 +1,216 @@
|
||||
//! Enhanced sensing result value object.
|
||||
//!
|
||||
//! The `EnhancedSensingResult` is the output of the signal intelligence
|
||||
//! pipeline, carrying motion, breathing, posture, and quality metrics
|
||||
//! derived from multi-BSSID pseudo-CSI data.
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MotionLevel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Coarse classification of detected motion intensity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MotionLevel {
    /// No significant change in BSSID variance; room likely empty.
    None,
    /// Very small fluctuations consistent with a stationary person
    /// (e.g., breathing, minor fidgeting).
    Minimal,
    /// Moderate changes suggesting slow movement (e.g., walking, gesturing).
    Moderate,
    /// Large variance swings indicating vigorous or rapid movement.
    High,
}

impl MotionLevel {
    /// Map a normalised motion score `[0.0, 1.0]` to a `MotionLevel`.
    ///
    /// The thresholds are tuned for multi-BSSID RSSI variance and can be
    /// overridden via `WindowsWifiConfig` in the pipeline layer. Any score
    /// that matches no band (including NaN) classifies as `High`.
    pub fn from_score(score: f64) -> Self {
        match score {
            s if s < 0.05 => Self::None,
            s if s < 0.20 => Self::Minimal,
            s if s < 0.60 => Self::Moderate,
            _ => Self::High,
        }
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MotionEstimate
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Quantitative motion estimate from the multi-BSSID pipeline.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MotionEstimate {
    /// Normalised motion score in `[0.0, 1.0]`.
    pub score: f64,
    /// Coarse classification derived from the score.
    pub level: MotionLevel,
    /// The number of BSSIDs contributing to this estimate.
    pub contributing_bssids: usize,
}

// ---------------------------------------------------------------------------
// BreathingEstimate
// ---------------------------------------------------------------------------

/// Coarse respiratory rate estimate extracted from body-sensitive BSSIDs.
///
/// Only valid when motion level is `Minimal` (person stationary) and at
/// least 3 body-correlated BSSIDs are available. The accuracy is limited
/// by the low sample rate of Tier 1 scanning (~2 Hz).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BreathingEstimate {
    /// Estimated breaths per minute (typical: 12-20 for adults at rest).
    pub rate_bpm: f64,
    /// Confidence in the estimate, `[0.0, 1.0]`.
    pub confidence: f64,
    /// Number of BSSIDs used for the spectral analysis.
    pub bssid_count: usize,
}

// ---------------------------------------------------------------------------
// PostureClass
// ---------------------------------------------------------------------------

/// Coarse posture classification from BSSID fingerprint matching.
///
/// Based on Hopfield template matching of the multi-BSSID amplitude
/// signature against stored reference patterns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PostureClass {
    /// Room appears empty.
    Empty,
    /// Person standing.
    Standing,
    /// Person sitting.
    Sitting,
    /// Person lying down.
    LyingDown,
    /// Person walking / in motion.
    Walking,
    /// Unknown posture (insufficient confidence).
    Unknown,
}

// ---------------------------------------------------------------------------
// SignalQuality
// ---------------------------------------------------------------------------

/// Signal quality metrics for the current multi-BSSID frame.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SignalQuality {
    /// Overall quality score `[0.0, 1.0]`, where 1.0 is excellent.
    pub score: f64,
    /// Number of BSSIDs in the current frame.
    pub bssid_count: usize,
    /// Spectral gap from the BSSID correlation graph.
    /// A large gap indicates good signal separation.
    pub spectral_gap: f64,
    /// Mean RSSI across all tracked BSSIDs (dBm).
    pub mean_rssi_dbm: f64,
}

// ---------------------------------------------------------------------------
// Verdict
// ---------------------------------------------------------------------------

/// Quality gate verdict from the ruQu three-filter pipeline.
///
/// The pipeline evaluates structural integrity, statistical shift
/// significance, and evidence accumulation before permitting a reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Verdict {
    /// Reading passed all quality gates and is reliable.
    Permit,
    /// Reading shows some anomalies but is usable with reduced confidence.
    Warn,
    /// Reading failed quality checks and should be discarded.
    Deny,
}

// ---------------------------------------------------------------------------
// EnhancedSensingResult
// ---------------------------------------------------------------------------

/// The output of the multi-BSSID signal intelligence pipeline.
///
/// This value object carries all sensing information derived from a single
/// scan cycle. It is converted to a `SensingUpdate` by the Sensing Output
/// bounded context for delivery to the UI.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EnhancedSensingResult {
    /// Motion detection result.
    pub motion: MotionEstimate,
    /// Coarse respiratory rate, if detectable.
    pub breathing: Option<BreathingEstimate>,
    /// Posture classification, if available.
    pub posture: Option<PostureClass>,
    /// Signal quality metrics for the current frame.
    pub signal_quality: SignalQuality,
    /// Number of BSSIDs used in this sensing cycle.
    pub bssid_count: usize,
    /// Quality gate verdict.
    pub verdict: Verdict,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Score-to-level boundaries: each threshold is inclusive on its
    /// upper band (0.05 -> Minimal, 0.20 -> Moderate, 0.60 -> High).
    #[test]
    fn motion_level_thresholds() {
        assert_eq!(MotionLevel::from_score(0.0), MotionLevel::None);
        assert_eq!(MotionLevel::from_score(0.04), MotionLevel::None);
        assert_eq!(MotionLevel::from_score(0.05), MotionLevel::Minimal);
        assert_eq!(MotionLevel::from_score(0.19), MotionLevel::Minimal);
        assert_eq!(MotionLevel::from_score(0.20), MotionLevel::Moderate);
        assert_eq!(MotionLevel::from_score(0.59), MotionLevel::Moderate);
        assert_eq!(MotionLevel::from_score(0.60), MotionLevel::High);
        assert_eq!(MotionLevel::from_score(1.0), MotionLevel::High);
    }

    /// A fully-populated result retains its field values unchanged.
    #[test]
    fn enhanced_result_construction() {
        let result = EnhancedSensingResult {
            motion: MotionEstimate {
                score: 0.3,
                level: MotionLevel::Moderate,
                contributing_bssids: 10,
            },
            breathing: Some(BreathingEstimate {
                rate_bpm: 16.0,
                confidence: 0.7,
                bssid_count: 5,
            }),
            posture: Some(PostureClass::Standing),
            signal_quality: SignalQuality {
                score: 0.85,
                bssid_count: 15,
                spectral_gap: 0.42,
                mean_rssi_dbm: -65.0,
            },
            bssid_count: 15,
            verdict: Verdict::Permit,
        };

        assert_eq!(result.motion.level, MotionLevel::Moderate);
        assert_eq!(result.verdict, Verdict::Permit);
        assert_eq!(result.bssid_count, 15);
    }
}
|
||||
@@ -0,0 +1,112 @@
|
||||
//! Error types for the wifi-densepose-wifiscan crate.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Errors that can occur during WiFi scanning and BSSID processing.
#[derive(Debug, Clone)]
pub enum WifiScanError {
    /// The BSSID MAC address bytes are invalid (must be exactly 6 bytes).
    InvalidMac {
        /// The number of bytes that were provided.
        len: usize,
    },

    /// Failed to parse a MAC address string (expected `aa:bb:cc:dd:ee:ff`).
    MacParseFailed {
        /// The input string that could not be parsed.
        input: String,
    },

    /// The scan backend returned an error.
    ScanFailed {
        /// Human-readable description of what went wrong.
        reason: String,
    },

    /// Too few BSSIDs are visible for multi-AP mode.
    InsufficientBssids {
        /// Number of BSSIDs observed.
        observed: usize,
        /// Minimum required for multi-AP mode.
        required: usize,
    },

    /// A BSSID was not found in the registry.
    BssidNotFound {
        /// The MAC address that was not found.
        bssid: [u8; 6],
    },

    /// The subcarrier map is full and cannot accept more BSSIDs.
    SubcarrierMapFull {
        /// Maximum capacity of the subcarrier map.
        max: usize,
    },

    /// An RSSI value is out of the expected range.
    RssiOutOfRange {
        /// The invalid RSSI value in dBm.
        value: f64,
    },

    /// The requested operation is not supported by this adapter.
    Unsupported(String),

    /// Failed to execute the scan subprocess.
    ProcessError(String),

    /// Failed to parse scan output.
    ParseError(String),
}

impl fmt::Display for WifiScanError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidMac { len } => {
                write!(f, "invalid MAC address: expected 6 bytes, got {len}")
            }
            Self::MacParseFailed { input } => write!(
                f,
                "failed to parse MAC address from '{input}': expected aa:bb:cc:dd:ee:ff"
            ),
            Self::ScanFailed { reason } => write!(f, "WiFi scan failed: {reason}"),
            Self::InsufficientBssids { observed, required } => write!(
                f,
                "insufficient BSSIDs for multi-AP mode: {observed} observed, {required} required"
            ),
            Self::BssidNotFound { bssid } => write!(
                f,
                "BSSID not found in registry: {:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
                bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]
            ),
            Self::SubcarrierMapFull { max } => write!(
                f,
                "subcarrier map is full at {max} entries; cannot add more BSSIDs"
            ),
            Self::RssiOutOfRange { value } => {
                write!(f, "RSSI value {value} dBm is out of expected range [-120, 0]")
            }
            Self::Unsupported(msg) => write!(f, "unsupported operation: {msg}"),
            Self::ProcessError(msg) => write!(f, "scan process error: {msg}"),
            Self::ParseError(msg) => write!(f, "scan output parse error: {msg}"),
        }
    }
}

impl std::error::Error for WifiScanError {}
|
||||
@@ -0,0 +1,30 @@
|
||||
//! # wifi-densepose-wifiscan
|
||||
//!
|
||||
//! Domain layer for multi-BSSID WiFi scanning and enhanced sensing (ADR-022).
|
||||
//!
|
||||
//! This crate implements the **BSSID Acquisition** bounded context, providing:
|
||||
//!
|
||||
//! - **Domain types**: [`BssidId`], [`BssidObservation`], [`BandType`], [`RadioType`]
|
||||
//! - **Port**: [`WlanScanPort`] -- trait abstracting the platform scan backend
|
||||
//! - **Adapter**: [`NetshBssidScanner`] -- Tier 1 adapter that parses
|
||||
//! `netsh wlan show networks mode=bssid` output
|
||||
|
||||
pub mod adapter;
|
||||
pub mod domain;
|
||||
pub mod error;
|
||||
pub mod pipeline;
|
||||
pub mod port;
|
||||
|
||||
// Re-export key types at the crate root for convenience.
|
||||
pub use adapter::NetshBssidScanner;
|
||||
pub use adapter::parse_netsh_output;
|
||||
pub use adapter::WlanApiScanner;
|
||||
pub use domain::bssid::{BandType, BssidId, BssidObservation, RadioType};
|
||||
pub use domain::frame::MultiApFrame;
|
||||
pub use domain::registry::{BssidEntry, BssidMeta, BssidRegistry, RunningStats};
|
||||
pub use domain::result::EnhancedSensingResult;
|
||||
pub use error::WifiScanError;
|
||||
pub use port::WlanScanPort;
|
||||
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub use pipeline::WindowsWifiPipeline;
|
||||
@@ -0,0 +1,129 @@
|
||||
//! Stage 2: Attention-based BSSID weighting.
|
||||
//!
|
||||
//! Uses scaled dot-product attention to learn which BSSIDs respond
|
||||
//! most to body movement. High-variance BSSIDs on body-affected
|
||||
//! paths get higher attention weights.
|
||||
//!
|
||||
//! When the `pipeline` feature is enabled, this uses
|
||||
//! `ruvector_attention::ScaledDotProductAttention` for the core
|
||||
//! attention computation. Otherwise, it falls back to a pure-Rust
|
||||
//! softmax implementation.
|
||||
|
||||
/// Weights BSSIDs by body-sensitivity using a scaled dot-product
/// attention mechanism over per-BSSID residual vectors.
pub struct AttentionWeighter {
    /// Dimensionality of the attention space (typically 1 for scalar RSSI).
    dim: usize,
}

impl AttentionWeighter {
    /// Create a new attention weighter.
    ///
    /// - `dim`: dimensionality of the attention space (typically 1 for scalar RSSI).
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Compute attention-weighted output from BSSID residuals.
    ///
    /// - `query`: the aggregated variance profile (1 x dim).
    /// - `keys`: per-BSSID residual vectors (`n_bssids` x dim).
    /// - `values`: per-BSSID amplitude vectors (`n_bssids` x dim).
    ///
    /// Returns the weighted amplitude vector and per-BSSID weights.
    #[must_use]
    pub fn weight(
        &self,
        query: &[f32],
        keys: &[Vec<f32>],
        values: &[Vec<f32>],
    ) -> (Vec<f32>, Vec<f32>) {
        // With nothing to attend over there is no meaningful output:
        // return a zero vector and no weights.
        if keys.is_empty() || values.is_empty() {
            return (vec![0.0; self.dim], Vec::new());
        }

        // Softmax-normalised per-BSSID attention weights.
        let attention = self.compute_scores(query, keys);

        // Accumulate the attention-weighted sum of value vectors.
        let mut combined = vec![0.0f32; self.dim];
        for (score, value) in attention.iter().zip(values.iter()) {
            for (acc, component) in combined.iter_mut().zip(value.iter()) {
                *acc += score * component;
            }
        }

        (combined, attention)
    }

    /// Raw attention scores: softmax of (q·k / sqrt(d)) over all keys.
    #[allow(clippy::cast_precision_loss)]
    fn compute_scores(&self, query: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
        let scale = (self.dim as f32).sqrt();

        // Scaled dot products q·k / sqrt(d), one per key.
        let mut logits = Vec::with_capacity(keys.len());
        for key in keys {
            let mut dot = 0.0f32;
            for (q, k) in query.iter().zip(key.iter()) {
                dot += q * k;
            }
            logits.push(dot / scale);
        }

        // Numerically stable softmax: subtract the max before exponentiating.
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let normaliser: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
        for logit in &mut logits {
            *logit = (*logit - max_logit).exp() / normaliser;
        }
        logits
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// With no keys/values the output is a zero vector and no weights.
    #[test]
    fn empty_input_returns_zero() {
        let weighter = AttentionWeighter::new(1);
        let (output, scores) = weighter.weight(&[0.0], &[], &[]);
        assert_eq!(output, vec![0.0]);
        assert!(scores.is_empty());
    }

    /// Softmax over a single key always yields weight 1.0.
    #[test]
    fn single_bssid_gets_full_weight() {
        let weighter = AttentionWeighter::new(1);
        let query = vec![1.0];
        let keys = vec![vec![1.0]];
        let values = vec![vec![5.0]];
        let (output, scores) = weighter.weight(&query, &keys, &values);
        assert!((scores[0] - 1.0).abs() < 1e-5, "single BSSID should have weight 1.0");
        assert!((output[0] - 5.0).abs() < 1e-3, "output should equal the single value");
    }

    /// Larger q·k dot products should dominate after the softmax.
    #[test]
    fn higher_residual_gets_more_weight() {
        let weighter = AttentionWeighter::new(1);
        let query = vec![1.0];
        // BSSID 0 has low residual, BSSID 1 has high residual
        let keys = vec![vec![0.1], vec![10.0]];
        let values = vec![vec![1.0], vec![1.0]];
        let (_output, scores) = weighter.weight(&query, &keys, &values);
        assert!(
            scores[1] > scores[0],
            "high-residual BSSID should get higher weight: {scores:?}"
        );
    }

    /// Softmax outputs form a probability distribution.
    #[test]
    fn scores_sum_to_one() {
        let weighter = AttentionWeighter::new(1);
        let query = vec![1.0];
        let keys = vec![vec![0.5], vec![1.0], vec![2.0]];
        let values = vec![vec![1.0], vec![2.0], vec![3.0]];
        let (_output, scores) = weighter.weight(&query, &keys, &values);
        let sum: f32 = scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "scores should sum to 1.0, got {sum}");
    }
}
|
||||
@@ -0,0 +1,277 @@
|
||||
//! Stage 5: Coarse breathing rate extraction.
|
||||
//!
|
||||
//! Extracts respiratory rate from body-sensitive BSSID oscillations.
|
||||
//! Uses a simple bandpass filter (0.1-0.5 Hz) and zero-crossing
|
||||
//! analysis rather than `OscillatoryRouter` (which is designed for
|
||||
//! gamma-band frequencies, not sub-Hz breathing).
|
||||
|
||||
/// Coarse breathing extractor from multi-BSSID signal variance.
///
/// Feeds an attention-weighted residual signal through a 2nd-order IIR
/// bandpass filter and estimates breathing frequency by zero-crossing
/// counting over a sliding window.
pub struct CoarseBreathingExtractor {
    /// Combined filtered signal history (sliding window, oldest first).
    filtered_history: Vec<f32>,
    /// Window size for analysis, in samples (30 s at the configured rate).
    window: usize,
    /// Maximum tracked BSSIDs.
    n_bssids: usize,
    /// Breathing band low cutoff (Hz).
    freq_low: f32,
    /// Breathing band high cutoff (Hz).
    freq_high: f32,
    /// Sample rate (Hz) -- typically 2 Hz for Tier 1.
    sample_rate: f32,
    /// IIR filter state (simple 2nd-order bandpass).
    filter_state: IirState,
}
|
||||
|
||||
/// Simple IIR bandpass filter state (2nd order).
///
/// All taps start at zero, which is exactly what `f32::default()` yields,
/// so `Default` is derived instead of hand-written.
#[derive(Clone, Debug, Default)]
struct IirState {
    /// Previous input sample x[n-1].
    x1: f32,
    /// Input sample x[n-2].
    x2: f32,
    /// Previous output sample y[n-1].
    y1: f32,
    /// Output sample y[n-2].
    y2: f32,
}
|
||||
|
||||
impl CoarseBreathingExtractor {
    /// Create a breathing extractor.
    ///
    /// - `n_bssids`: maximum BSSID slots.
    /// - `sample_rate`: input sample rate in Hz.
    /// - `freq_low`: breathing band low cutoff (default 0.1 Hz).
    /// - `freq_high`: breathing band high cutoff (default 0.5 Hz).
    #[must_use]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn new(n_bssids: usize, sample_rate: f32, freq_low: f32, freq_high: f32) -> Self {
        let window = (sample_rate * 30.0) as usize; // 30 seconds of data
        Self {
            filtered_history: Vec::with_capacity(window),
            window,
            n_bssids,
            freq_low,
            freq_high,
            sample_rate,
            filter_state: IirState::default(),
        }
    }

    /// Create with defaults suitable for Tier 1 (2 Hz sample rate).
    #[must_use]
    pub fn tier1_default(n_bssids: usize) -> Self {
        Self::new(n_bssids, 2.0, 0.1, 0.5)
    }

    /// Process a frame of residuals with attention weights.
    /// Returns estimated breathing rate (BPM) if detectable.
    ///
    /// - `residuals`: per-BSSID residuals from `PredictiveGate`.
    /// - `weights`: per-BSSID attention weights.
    ///
    /// Returns `None` until at least 10 seconds of history have
    /// accumulated, or when the detected frequency falls outside the
    /// configured breathing band.
    pub fn extract(&mut self, residuals: &[f32], weights: &[f32]) -> Option<BreathingEstimate> {
        let n = residuals.len().min(self.n_bssids);
        if n == 0 {
            return None;
        }

        // Compute weighted sum of residuals for breathing analysis.
        // A missing weight defaults to a uniform 1/n share.
        #[allow(clippy::cast_precision_loss)]
        let weighted_signal: f32 = residuals
            .iter()
            .enumerate()
            .take(n)
            .map(|(i, &r)| {
                let w = weights.get(i).copied().unwrap_or(1.0 / n as f32);
                r * w
            })
            .sum();

        // Apply bandpass filter
        let filtered = self.bandpass_filter(weighted_signal);

        // Store in sliding-window history.
        // NOTE(review): `remove(0)` is O(window) per sample -- fine at the
        // ~2 Hz Tier 1 rate; consider a VecDeque if the rate grows.
        self.filtered_history.push(filtered);
        if self.filtered_history.len() > self.window {
            self.filtered_history.remove(0);
        }

        // Need at least 10 seconds of data to estimate breathing
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let min_samples = (self.sample_rate * 10.0) as usize;
        if self.filtered_history.len() < min_samples {
            return None;
        }

        // Zero-crossing rate -> frequency (two crossings per full cycle)
        let crossings = count_zero_crossings(&self.filtered_history);
        #[allow(clippy::cast_precision_loss)]
        let duration_s = self.filtered_history.len() as f32 / self.sample_rate;
        #[allow(clippy::cast_precision_loss)]
        let frequency_hz = crossings as f32 / (2.0 * duration_s);

        // Validate frequency is in breathing range
        if frequency_hz < self.freq_low || frequency_hz > self.freq_high {
            return None;
        }

        let bpm = frequency_hz * 60.0;

        // Compute confidence based on signal regularity
        let confidence = compute_confidence(&self.filtered_history);

        Some(BreathingEstimate {
            bpm,
            frequency_hz,
            confidence,
        })
    }

    /// Simple 2nd-order IIR bandpass filter (resonator form).
    fn bandpass_filter(&mut self, input: f32) -> f32 {
        let state = &mut self.filter_state;

        // Butterworth bandpass coefficients for [freq_low, freq_high] at given sample rate.
        // Using bilinear transform approximation.
        let omega_low = 2.0 * std::f32::consts::PI * self.freq_low / self.sample_rate;
        let omega_high = 2.0 * std::f32::consts::PI * self.freq_high / self.sample_rate;
        let bw = omega_high - omega_low;
        let center = f32::midpoint(omega_low, omega_high);

        // Pole radius: closer to 1 means a narrower passband.
        let r = 1.0 - bw / 2.0;
        let cos_w0 = center.cos();

        // y[n] = (1-r)*(x[n] - x[n-2]) + 2*r*cos(w0)*y[n-1] - r^2*y[n-2]
        let output =
            (1.0 - r) * (input - state.x2) + 2.0 * r * cos_w0 * state.y1 - r * r * state.y2;

        // Shift the two-sample delay line.
        state.x2 = state.x1;
        state.x1 = input;
        state.y2 = state.y1;
        state.y1 = output;

        output
    }

    /// Reset all filter states and histories.
    pub fn reset(&mut self) {
        self.filtered_history.clear();
        self.filter_state = IirState::default();
    }
}
|
||||
|
||||
/// Result of breathing extraction.
#[derive(Debug, Clone)]
pub struct BreathingEstimate {
    /// Estimated breathing rate in breaths per minute (`frequency_hz * 60`).
    pub bpm: f32,
    /// Estimated breathing frequency in Hz.
    pub frequency_hz: f32,
    /// Confidence in the estimate [0, 1].
    pub confidence: f32,
}
|
||||
|
||||
/// Compute confidence in the breathing estimate based on signal regularity.
///
/// Uses a peak-to-standard-deviation ratio as a crude SNR proxy and maps
/// it linearly onto `[0, 1]`, saturating at an SNR of 5. Returns 0.0 for
/// fewer than four samples or an (almost) constant signal.
#[allow(clippy::cast_precision_loss)]
fn compute_confidence(history: &[f32]) -> f32 {
    let n = history.len();
    if n < 4 {
        return 0.0;
    }
    let n_f = n as f32;

    // Population mean and variance of the filtered signal.
    let mean = history.iter().sum::<f32>() / n_f;
    let variance = history
        .iter()
        .map(|&x| {
            let d = x - mean;
            d * d
        })
        .sum::<f32>()
        / n_f;

    // A flat signal carries no breathing information.
    if variance < 1e-10 {
        return 0.0;
    }

    let noise = variance.sqrt();
    let peak = history.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));

    let snr = if noise > 1e-10 { peak / noise } else { 0.0 };
    (snr / 5.0).min(1.0)
}
|
||||
|
||||
/// Count zero crossings in a signal.
///
/// A crossing is counted whenever two consecutive samples have strictly
/// opposite signs; samples equal to exactly zero never contribute.
fn count_zero_crossings(signal: &[f32]) -> usize {
    signal
        .iter()
        .zip(signal.iter().skip(1))
        .filter(|(&a, &b)| a * b < 0.0)
        .count()
}
|
||||
|
||||
// Unit tests for the coarse breathing extractor and its free helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // An empty frame must never produce an estimate.
    #[test]
    fn no_data_returns_none() {
        let mut ext = CoarseBreathingExtractor::tier1_default(4);
        assert!(ext.extract(&[], &[]).is_none());
    }

    // Estimates require a minimum amount of accumulated history.
    #[test]
    fn insufficient_history_returns_none() {
        let mut ext = CoarseBreathingExtractor::tier1_default(4);
        // Just a few frames are not enough
        for _ in 0..5 {
            assert!(ext.extract(&[1.0, 2.0], &[0.5, 0.5]).is_none());
        }
    }

    // Feed a clean sinusoid inside the breathing band and check the
    // estimate (when produced) lands in a plausible BPM range.
    #[test]
    fn sinusoidal_breathing_detected() {
        let mut ext = CoarseBreathingExtractor::new(1, 10.0, 0.1, 0.5);
        let breathing_freq = 0.25; // 15 BPM

        // Generate 60 seconds of sinusoidal breathing signal at 10 Hz
        for i in 0..600 {
            let t = i as f32 / 10.0;
            let signal = (2.0 * std::f32::consts::PI * breathing_freq * t).sin();
            ext.extract(&[signal], &[1.0]);
        }

        let result = ext.extract(&[0.0], &[1.0]);
        if let Some(est) = result {
            // Should be approximately 15 BPM (0.25 Hz * 60)
            assert!(
                est.bpm > 5.0 && est.bpm < 40.0,
                "estimated BPM should be in breathing range: {}",
                est.bpm
            );
        }
        // It is acceptable if None -- the bandpass filter may need tuning
    }

    // Alternating signs: every adjacent pair crosses zero.
    #[test]
    fn zero_crossings_count() {
        let signal = vec![1.0, -1.0, 1.0, -1.0, 1.0];
        assert_eq!(count_zero_crossings(&signal), 4);
    }

    // A constant signal never crosses zero.
    #[test]
    fn zero_crossings_constant() {
        let signal = vec![1.0, 1.0, 1.0, 1.0];
        assert_eq!(count_zero_crossings(&signal), 0);
    }

    // reset() must drop all accumulated history.
    #[test]
    fn reset_clears_state() {
        let mut ext = CoarseBreathingExtractor::tier1_default(2);
        ext.extract(&[1.0, 2.0], &[0.5, 0.5]);
        ext.reset();
        assert!(ext.filtered_history.is_empty());
    }
}
|
||||
@@ -0,0 +1,267 @@
|
||||
//! Stage 3: BSSID spatial correlation via GNN message passing.
|
||||
//!
|
||||
//! Builds a cross-correlation graph where nodes are BSSIDs and edges
|
||||
//! represent temporal cross-correlation between their RSSI histories.
|
||||
//! A single message-passing step identifies co-varying BSSID clusters
|
||||
//! that are likely affected by the same person.
|
||||
|
||||
/// BSSID correlator that computes pairwise Pearson correlation
/// and identifies co-varying clusters.
///
/// Note: The full `RuvectorLayer` GNN requires matching dimension
/// weights trained on CSI data. For Phase 2 we use a lightweight
/// correlation-based approach that can be upgraded to GNN later.
pub struct BssidCorrelator {
    /// Per-BSSID history buffers for correlation computation
    /// (one `Vec` per BSSID slot, capped at `window` samples).
    histories: Vec<Vec<f32>>,
    /// Maximum history length per BSSID, in frames.
    window: usize,
    /// Number of tracked BSSIDs (fixed at construction).
    n_bssids: usize,
    /// Correlation threshold for "co-varying" classification
    /// (compared against |r|).
    correlation_threshold: f32,
}
|
||||
|
||||
impl BssidCorrelator {
|
||||
/// Create a new correlator.
|
||||
///
|
||||
/// - `n_bssids`: number of BSSID slots.
|
||||
/// - `window`: correlation window size (number of frames).
|
||||
/// - `correlation_threshold`: minimum |r| to consider BSSIDs co-varying.
|
||||
#[must_use]
|
||||
pub fn new(n_bssids: usize, window: usize, correlation_threshold: f32) -> Self {
|
||||
Self {
|
||||
histories: vec![Vec::with_capacity(window); n_bssids],
|
||||
window,
|
||||
n_bssids,
|
||||
correlation_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a new frame of amplitudes and compute correlation features.
|
||||
///
|
||||
/// Returns a `CorrelationResult` with the correlation matrix and
|
||||
/// cluster assignments.
|
||||
pub fn update(&mut self, amplitudes: &[f32]) -> CorrelationResult {
|
||||
let n = amplitudes.len().min(self.n_bssids);
|
||||
|
||||
// Update histories
|
||||
for (i, &) in amplitudes.iter().enumerate().take(n) {
|
||||
let hist = &mut self.histories[i];
|
||||
hist.push(amp);
|
||||
if hist.len() > self.window {
|
||||
hist.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute pairwise Pearson correlation
|
||||
let mut corr_matrix = vec![vec![0.0f32; n]; n];
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..n {
|
||||
corr_matrix[i][i] = 1.0;
|
||||
for j in (i + 1)..n {
|
||||
let r = pearson_r(&self.histories[i], &self.histories[j]);
|
||||
corr_matrix[i][j] = r;
|
||||
corr_matrix[j][i] = r;
|
||||
}
|
||||
}
|
||||
|
||||
// Find strongly correlated clusters (simple union-find)
|
||||
let clusters = self.find_clusters(&corr_matrix, n);
|
||||
|
||||
// Compute per-BSSID "spatial diversity" score:
|
||||
// how many other BSSIDs is each one correlated with
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let diversity: Vec<f32> = (0..n)
|
||||
.map(|i| {
|
||||
let count = (0..n)
|
||||
.filter(|&j| j != i && corr_matrix[i][j].abs() > self.correlation_threshold)
|
||||
.count();
|
||||
count as f32 / (n.max(1) - 1) as f32
|
||||
})
|
||||
.collect();
|
||||
|
||||
CorrelationResult {
|
||||
matrix: corr_matrix,
|
||||
clusters,
|
||||
diversity,
|
||||
n_active: n,
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple cluster assignment via thresholded correlation.
|
||||
fn find_clusters(&self, corr: &[Vec<f32>], n: usize) -> Vec<usize> {
|
||||
let mut cluster_id = vec![0usize; n];
|
||||
let mut next_cluster = 0usize;
|
||||
let mut assigned = vec![false; n];
|
||||
|
||||
for i in 0..n {
|
||||
if assigned[i] {
|
||||
continue;
|
||||
}
|
||||
cluster_id[i] = next_cluster;
|
||||
assigned[i] = true;
|
||||
|
||||
// BFS: assign same cluster to correlated BSSIDs
|
||||
let mut queue = vec![i];
|
||||
while let Some(current) = queue.pop() {
|
||||
for j in 0..n {
|
||||
if !assigned[j] && corr[current][j].abs() > self.correlation_threshold {
|
||||
cluster_id[j] = next_cluster;
|
||||
assigned[j] = true;
|
||||
queue.push(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
next_cluster += 1;
|
||||
}
|
||||
cluster_id
|
||||
}
|
||||
|
||||
/// Reset all correlation histories.
|
||||
pub fn reset(&mut self) {
|
||||
for h in &mut self.histories {
|
||||
h.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of correlation analysis.
#[derive(Debug, Clone)]
pub struct CorrelationResult {
    /// n x n Pearson correlation matrix, where n == `n_active`
    /// (symmetric, with 1.0 on the diagonal).
    pub matrix: Vec<Vec<f32>>,
    /// Cluster assignment per BSSID (dense ids starting at 0).
    pub clusters: Vec<usize>,
    /// Per-BSSID spatial diversity score [0, 1].
    pub diversity: Vec<f32>,
    /// Number of active BSSIDs in this frame.
    pub n_active: usize,
}
|
||||
|
||||
impl CorrelationResult {
|
||||
/// Number of distinct clusters.
|
||||
#[must_use]
|
||||
pub fn n_clusters(&self) -> usize {
|
||||
self.clusters.iter().copied().max().map_or(0, |m| m + 1)
|
||||
}
|
||||
|
||||
/// Mean absolute correlation (proxy for signal coherence).
|
||||
#[must_use]
|
||||
pub fn mean_correlation(&self) -> f32 {
|
||||
if self.n_active < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
let mut sum = 0.0f32;
|
||||
let mut count = 0;
|
||||
for i in 0..self.n_active {
|
||||
for j in (i + 1)..self.n_active {
|
||||
sum += self.matrix[i][j].abs();
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let mean = if count == 0 { 0.0 } else { sum / count as f32 };
|
||||
mean
|
||||
}
|
||||
}
|
||||
|
||||
/// Pearson correlation coefficient between two equal-length slices.
///
/// Operates on the common prefix of `x` and `y`. Returns 0.0 when fewer
/// than two samples are available or either input has (near-)zero
/// variance, rather than dividing by zero.
#[allow(clippy::cast_precision_loss)]
fn pearson_r(x: &[f32], y: &[f32]) -> f32 {
    let n = x.len().min(y.len());
    if n < 2 {
        return 0.0;
    }
    let n_f = n as f32;
    let (x, y) = (&x[..n], &y[..n]);

    let mean_x = x.iter().sum::<f32>() / n_f;
    let mean_y = y.iter().sum::<f32>() / n_f;

    // Accumulate covariance and both variances in a single pass.
    let mut cov = 0.0f32;
    let mut var_x = 0.0f32;
    let mut var_y = 0.0f32;
    for (&a, &b) in x.iter().zip(y) {
        let dx = a - mean_x;
        let dy = b - mean_y;
        cov += dx * dy;
        var_x += dx * dx;
        var_y += dy * dy;
    }

    let denom = (var_x * var_y).sqrt();
    if denom < 1e-12 { 0.0 } else { cov / denom }
}
|
||||
|
||||
// Unit tests for `pearson_r`, the correlator, and `CorrelationResult`.
#[cfg(test)]
mod tests {
    use super::*;

    // y = 2x is perfectly positively correlated with x.
    #[test]
    fn pearson_perfect_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let r = pearson_r(&x, &y);
        assert!((r - 1.0).abs() < 1e-5, "perfect positive correlation: {r}");
    }

    // A strictly decreasing linear series is perfectly anti-correlated.
    #[test]
    fn pearson_negative_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![10.0, 8.0, 6.0, 4.0, 2.0];
        let r = pearson_r(&x, &y);
        assert!((r - (-1.0)).abs() < 1e-5, "perfect negative correlation: {r}");
    }

    // A shuffled series should show only weak correlation.
    #[test]
    fn pearson_no_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![5.0, 1.0, 4.0, 2.0, 3.0]; // shuffled
        let r = pearson_r(&x, &y);
        assert!(r.abs() < 0.5, "low correlation expected: {r}");
    }

    // Smoke test: all three BSSID slots are reported active.
    #[test]
    fn correlator_basic_update() {
        let mut corr = BssidCorrelator::new(3, 10, 0.7);
        // Push several identical frames
        for _ in 0..5 {
            corr.update(&[1.0, 2.0, 3.0]);
        }
        let result = corr.update(&[1.0, 2.0, 3.0]);
        assert_eq!(result.n_active, 3);
    }

    // Linearly related BSSIDs must end up in the same cluster.
    #[test]
    fn correlator_detects_covarying_bssids() {
        let mut corr = BssidCorrelator::new(3, 20, 0.8);
        // BSSID 0 and 1 co-vary, BSSID 2 is independent
        for i in 0..20 {
            let v = i as f32;
            corr.update(&[v, v * 2.0, 5.0]); // 0 and 1 correlate, 2 is constant
        }
        let result = corr.update(&[20.0, 40.0, 5.0]);
        // BSSIDs 0 and 1 should be in the same cluster
        assert_eq!(
            result.clusters[0], result.clusters[1],
            "co-varying BSSIDs should cluster: {:?}",
            result.clusters
        );
    }

    // With a single BSSID there are no pairs to average over.
    #[test]
    fn mean_correlation_zero_for_one_bssid() {
        let result = CorrelationResult {
            matrix: vec![vec![1.0]],
            clusters: vec![0],
            diversity: vec![0.0],
            n_active: 1,
        };
        assert!((result.mean_correlation() - 0.0).abs() < 1e-5);
    }
}
|
||||
@@ -0,0 +1,288 @@
|
||||
//! Stage 7: BSSID fingerprint matching via cosine similarity.
|
||||
//!
|
||||
//! Stores reference BSSID amplitude patterns for known postures
|
||||
//! (standing, sitting, walking, empty) and classifies new observations
|
||||
//! by retrieving the nearest stored template.
|
||||
//!
|
||||
//! This is a pure-Rust implementation using cosine similarity. When
|
||||
//! `ruvector-nervous-system` becomes available, the inner store can
|
||||
//! be replaced with `ModernHopfield` for richer associative memory.
|
||||
|
||||
use crate::domain::result::PostureClass;
|
||||
|
||||
/// A stored posture fingerprint template.
#[derive(Debug, Clone)]
struct PostureTemplate {
    /// Reference amplitude pattern. Stored exactly as provided by the
    /// caller; matching uses cosine similarity, which is scale-invariant,
    /// so no explicit normalisation is performed here.
    pattern: Vec<f32>,
    /// The posture label for this template.
    label: PostureClass,
}
|
||||
|
||||
/// BSSID fingerprint matcher using cosine similarity.
///
/// Templates are searched linearly on every classification; this is fine
/// for the handful of posture templates used here.
pub struct FingerprintMatcher {
    /// Stored reference templates, in insertion order.
    templates: Vec<PostureTemplate>,
    /// Minimum cosine similarity for a match.
    confidence_threshold: f32,
    /// Expected dimension (number of BSSID slots) for both stored
    /// patterns and observations.
    n_bssids: usize,
}
|
||||
|
||||
impl FingerprintMatcher {
|
||||
/// Create a new fingerprint matcher.
|
||||
///
|
||||
/// - `n_bssids`: number of BSSID slots (pattern dimension).
|
||||
/// - `confidence_threshold`: minimum cosine similarity for a match.
|
||||
#[must_use]
|
||||
pub fn new(n_bssids: usize, confidence_threshold: f32) -> Self {
|
||||
Self {
|
||||
templates: Vec::new(),
|
||||
confidence_threshold,
|
||||
n_bssids,
|
||||
}
|
||||
}
|
||||
|
||||
/// Store a reference pattern with its posture label.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the pattern dimension does not match `n_bssids`.
|
||||
pub fn store_pattern(
|
||||
&mut self,
|
||||
pattern: Vec<f32>,
|
||||
label: PostureClass,
|
||||
) -> Result<(), String> {
|
||||
if pattern.len() != self.n_bssids {
|
||||
return Err(format!(
|
||||
"pattern dimension {} != expected {}",
|
||||
pattern.len(),
|
||||
self.n_bssids
|
||||
));
|
||||
}
|
||||
self.templates.push(PostureTemplate { pattern, label });
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Classify an observation by matching against stored fingerprints.
|
||||
///
|
||||
/// Returns the best-matching posture and similarity score, or `None`
|
||||
/// if no patterns are stored or similarity is below threshold.
|
||||
#[must_use]
|
||||
pub fn classify(&self, observation: &[f32]) -> Option<(PostureClass, f32)> {
|
||||
if self.templates.is_empty() || observation.len() != self.n_bssids {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut best_label = None;
|
||||
let mut best_sim = f32::NEG_INFINITY;
|
||||
|
||||
for tmpl in &self.templates {
|
||||
let sim = cosine_similarity(&tmpl.pattern, observation);
|
||||
if sim > best_sim {
|
||||
best_sim = sim;
|
||||
best_label = Some(tmpl.label);
|
||||
}
|
||||
}
|
||||
|
||||
match best_label {
|
||||
Some(label) if best_sim >= self.confidence_threshold => Some((label, best_sim)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Match posture and return a structured result.
|
||||
#[must_use]
|
||||
pub fn match_posture(&self, observation: &[f32]) -> MatchResult {
|
||||
match self.classify(observation) {
|
||||
Some((posture, confidence)) => MatchResult {
|
||||
posture: Some(posture),
|
||||
confidence,
|
||||
matched: true,
|
||||
},
|
||||
None => MatchResult {
|
||||
posture: None,
|
||||
confidence: 0.0,
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate default templates from a baseline signal.
|
||||
///
|
||||
/// Creates heuristic patterns for standing, sitting, and empty by
|
||||
/// scaling the baseline amplitude pattern.
|
||||
pub fn generate_defaults(&mut self, baseline: &[f32]) {
|
||||
if baseline.len() != self.n_bssids {
|
||||
return;
|
||||
}
|
||||
|
||||
// Empty: very low amplitude (background noise only)
|
||||
let empty: Vec<f32> = baseline.iter().map(|&a| a * 0.1).collect();
|
||||
let _ = self.store_pattern(empty, PostureClass::Empty);
|
||||
|
||||
// Standing: moderate perturbation of some BSSIDs
|
||||
let standing: Vec<f32> = baseline
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &a)| if i % 3 == 0 { a * 1.3 } else { a })
|
||||
.collect();
|
||||
let _ = self.store_pattern(standing, PostureClass::Standing);
|
||||
|
||||
// Sitting: different perturbation pattern
|
||||
let sitting: Vec<f32> = baseline
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &a)| if i % 2 == 0 { a * 1.2 } else { a * 0.9 })
|
||||
.collect();
|
||||
let _ = self.store_pattern(sitting, PostureClass::Sitting);
|
||||
}
|
||||
|
||||
/// Number of stored patterns.
|
||||
#[must_use]
|
||||
pub fn num_patterns(&self) -> usize {
|
||||
self.templates.len()
|
||||
}
|
||||
|
||||
/// Clear all stored patterns.
|
||||
pub fn clear(&mut self) {
|
||||
self.templates.clear();
|
||||
}
|
||||
|
||||
/// Set the minimum similarity threshold for classification.
|
||||
pub fn set_confidence_threshold(&mut self, threshold: f32) {
|
||||
self.confidence_threshold = threshold;
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of fingerprint matching.
#[derive(Debug, Clone)]
pub struct MatchResult {
    /// Matched posture class (None if no match).
    /// As produced by `match_posture`, this is `Some` exactly when
    /// `matched` is true.
    pub posture: Option<PostureClass>,
    /// Cosine similarity of the best match (0.0 when unmatched).
    pub confidence: f32,
    /// Whether a match was found above threshold.
    pub matched: bool,
}
|
||||
|
||||
/// Cosine similarity between two vectors.
///
/// Operates on the common prefix of `a` and `b`. Returns 0.0 for empty
/// input or when either vector has (near-)zero norm, rather than
/// dividing by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }

    // Single pass over the common prefix: dot product plus both
    // squared norms.
    let (mut dot, mut norm_a_sq, mut norm_b_sq) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a[..n].iter().zip(&b[..n]) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let denom = (norm_a_sq * norm_b_sq).sqrt();
    if denom < 1e-12 { 0.0 } else { dot / denom }
}
|
||||
|
||||
// Unit tests for the fingerprint matcher and `cosine_similarity`.
#[cfg(test)]
mod tests {
    use super::*;

    // No stored templates -> nothing can match.
    #[test]
    fn empty_matcher_returns_none() {
        let matcher = FingerprintMatcher::new(4, 0.5);
        assert!(matcher.classify(&[1.0, 2.0, 3.0, 4.0]).is_none());
    }

    // Observations with the wrong dimension are rejected.
    #[test]
    fn wrong_dimension_returns_none() {
        let mut matcher = FingerprintMatcher::new(4, 0.5);
        matcher
            .store_pattern(vec![1.0; 4], PostureClass::Standing)
            .unwrap();
        // Wrong dimension
        assert!(matcher.classify(&[1.0, 2.0]).is_none());
    }

    // A noisy query near a stored template should retrieve that template.
    #[test]
    fn store_and_recall() {
        let mut matcher = FingerprintMatcher::new(4, 0.5);

        // Store distinct patterns
        matcher
            .store_pattern(vec![1.0, 0.0, 0.0, 0.0], PostureClass::Standing)
            .unwrap();
        matcher
            .store_pattern(vec![0.0, 1.0, 0.0, 0.0], PostureClass::Sitting)
            .unwrap();

        assert_eq!(matcher.num_patterns(), 2);

        // Query close to "Standing" pattern
        let result = matcher.classify(&[0.9, 0.1, 0.0, 0.0]);
        if let Some((posture, sim)) = result {
            assert_eq!(posture, PostureClass::Standing);
            assert!(sim > 0.5, "similarity should be above threshold: {sim}");
        }
    }

    // Storing a pattern of the wrong dimension must fail.
    #[test]
    fn wrong_dim_store_rejected() {
        let mut matcher = FingerprintMatcher::new(4, 0.5);
        let result = matcher.store_pattern(vec![1.0, 2.0], PostureClass::Empty);
        assert!(result.is_err());
    }

    // clear() drops every stored template.
    #[test]
    fn clear_removes_all() {
        let mut matcher = FingerprintMatcher::new(2, 0.5);
        matcher
            .store_pattern(vec![1.0, 0.0], PostureClass::Standing)
            .unwrap();
        assert_eq!(matcher.num_patterns(), 1);
        matcher.clear();
        assert_eq!(matcher.num_patterns(), 0);
    }

    // Identical vectors have similarity 1.
    #[test]
    fn cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-5, "identical vectors: {sim}");
    }

    // Orthogonal vectors have similarity 0.
    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5, "orthogonal vectors: {sim}");
    }

    // The structured wrapper reports a successful match.
    #[test]
    fn match_posture_result() {
        let mut matcher = FingerprintMatcher::new(3, 0.5);
        matcher
            .store_pattern(vec![1.0, 0.0, 0.0], PostureClass::Standing)
            .unwrap();

        let result = matcher.match_posture(&[0.95, 0.05, 0.0]);
        assert!(result.matched);
        assert_eq!(result.posture, Some(PostureClass::Standing));
    }

    // generate_defaults stores one template per heuristic posture.
    #[test]
    fn generate_defaults_creates_templates() {
        let mut matcher = FingerprintMatcher::new(4, 0.3);
        matcher.generate_defaults(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(matcher.num_patterns(), 3); // Empty, Standing, Sitting
    }
}
|
||||
@@ -0,0 +1,36 @@
|
||||
//! Signal Intelligence pipeline (Phase 2, ADR-022).
|
||||
//!
|
||||
//! Composes `RuVector` primitives into a multi-stage sensing pipeline
|
||||
//! that transforms multi-BSSID RSSI frames into presence, motion,
|
||||
//! and coarse vital sign estimates.
|
||||
//!
|
||||
//! ## Stages
|
||||
//!
|
||||
//! 1. [`predictive_gate`] -- residual gating via `PredictiveLayer`
|
||||
//! 2. [`attention_weighter`] -- BSSID attention weighting
|
||||
//! 3. [`correlator`] -- cross-BSSID Pearson correlation & clustering
|
||||
//! 4. [`motion_estimator`] -- multi-AP motion estimation
|
||||
//! 5. [`breathing_extractor`] -- coarse breathing rate extraction
|
||||
//! 6. [`quality_gate`] -- ruQu three-filter quality gate
|
||||
//! 7. [`fingerprint_matcher`] -- `ModernHopfield` posture fingerprinting
|
||||
//! 8. [`orchestrator`] -- full pipeline orchestrator
|
||||
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod predictive_gate;
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod attention_weighter;
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod correlator;
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod motion_estimator;
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod breathing_extractor;
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod quality_gate;
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod fingerprint_matcher;
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub mod orchestrator;
|
||||
|
||||
#[cfg(feature = "pipeline")]
|
||||
pub use orchestrator::WindowsWifiPipeline;
|
||||
@@ -0,0 +1,210 @@
|
||||
//! Stage 4: Multi-AP motion estimation.
|
||||
//!
|
||||
//! Combines per-BSSID residuals, attention weights, and correlation
|
||||
//! features to estimate overall motion intensity and classify
|
||||
//! motion level (None / Minimal / Moderate / High).
|
||||
|
||||
use crate::domain::result::MotionLevel;
|
||||
|
||||
/// Multi-AP motion estimator using weighted variance of BSSID residuals.
///
/// The three thresholds are compared against the EMA-smoothed motion
/// score to pick a `MotionLevel`.
pub struct MultiApMotionEstimator {
    /// EMA smoothing factor for motion score (weight of the new sample).
    alpha: f32,
    /// Running EMA of motion score.
    ema_motion: f32,
    /// Motion threshold for None->Minimal transition.
    threshold_minimal: f32,
    /// Motion threshold for Minimal->Moderate transition.
    threshold_moderate: f32,
    /// Motion threshold for Moderate->High transition.
    threshold_high: f32,
}
|
||||
|
||||
impl MultiApMotionEstimator {
|
||||
/// Create a motion estimator with default thresholds.
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
alpha: 0.3,
|
||||
ema_motion: 0.0,
|
||||
threshold_minimal: 0.02,
|
||||
threshold_moderate: 0.10,
|
||||
threshold_high: 0.30,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom thresholds.
|
||||
#[must_use]
|
||||
pub fn with_thresholds(minimal: f32, moderate: f32, high: f32) -> Self {
|
||||
Self {
|
||||
alpha: 0.3,
|
||||
ema_motion: 0.0,
|
||||
threshold_minimal: minimal,
|
||||
threshold_moderate: moderate,
|
||||
threshold_high: high,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate motion from weighted residuals.
|
||||
///
|
||||
/// - `residuals`: per-BSSID residual from `PredictiveGate`.
|
||||
/// - `weights`: per-BSSID attention weights from `AttentionWeighter`.
|
||||
/// - `diversity`: per-BSSID correlation diversity from `BssidCorrelator`.
|
||||
///
|
||||
/// Returns `MotionEstimate` with score and level.
|
||||
pub fn estimate(
|
||||
&mut self,
|
||||
residuals: &[f32],
|
||||
weights: &[f32],
|
||||
diversity: &[f32],
|
||||
) -> MotionEstimate {
|
||||
let n = residuals.len();
|
||||
if n == 0 {
|
||||
return MotionEstimate {
|
||||
score: 0.0,
|
||||
level: MotionLevel::None,
|
||||
weighted_variance: 0.0,
|
||||
n_contributing: 0,
|
||||
};
|
||||
}
|
||||
|
||||
// Weighted variance of residuals (body-sensitive BSSIDs contribute more)
|
||||
let mut weighted_sum = 0.0f32;
|
||||
let mut weight_total = 0.0f32;
|
||||
let mut n_contributing = 0usize;
|
||||
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
for (i, residual) in residuals.iter().enumerate() {
|
||||
let w = weights.get(i).copied().unwrap_or(1.0 / n as f32);
|
||||
let d = diversity.get(i).copied().unwrap_or(0.5);
|
||||
// Combine attention weight with diversity (correlated BSSIDs
|
||||
// that respond together are better indicators)
|
||||
let combined_w = w * (0.5 + 0.5 * d);
|
||||
weighted_sum += combined_w * residual.abs();
|
||||
weight_total += combined_w;
|
||||
|
||||
if residual.abs() > 0.001 {
|
||||
n_contributing += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let weighted_variance = if weight_total > 1e-9 {
|
||||
weighted_sum / weight_total
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// EMA smoothing
|
||||
self.ema_motion = self.alpha * weighted_variance + (1.0 - self.alpha) * self.ema_motion;
|
||||
|
||||
let level = if self.ema_motion < self.threshold_minimal {
|
||||
MotionLevel::None
|
||||
} else if self.ema_motion < self.threshold_moderate {
|
||||
MotionLevel::Minimal
|
||||
} else if self.ema_motion < self.threshold_high {
|
||||
MotionLevel::Moderate
|
||||
} else {
|
||||
MotionLevel::High
|
||||
};
|
||||
|
||||
MotionEstimate {
|
||||
score: self.ema_motion,
|
||||
level,
|
||||
weighted_variance,
|
||||
n_contributing,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the EMA state.
|
||||
pub fn reset(&mut self) {
|
||||
self.ema_motion = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MultiApMotionEstimator {
    // Delegates to `new` so `Default` and `new` cannot diverge.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Result of motion estimation.
#[derive(Debug, Clone)]
pub struct MotionEstimate {
    /// Smoothed motion score (EMA of weighted variance).
    pub score: f32,
    /// Classified motion level derived from `score`.
    pub level: MotionLevel,
    /// Raw weighted variance before smoothing (the current frame's
    /// weighted mean absolute residual).
    pub weighted_variance: f32,
    /// Number of BSSIDs with residual magnitude above 0.001.
    pub n_contributing: usize,
}
|
||||
|
||||
// Unit tests for the multi-AP motion estimator.
#[cfg(test)]
mod tests {
    use super::*;

    // An empty frame maps to MotionLevel::None with a zero score.
    #[test]
    fn no_residuals_yields_no_motion() {
        let mut est = MultiApMotionEstimator::new();
        let result = est.estimate(&[], &[], &[]);
        assert_eq!(result.level, MotionLevel::None);
        assert!((result.score - 0.0).abs() < f32::EPSILON);
    }

    // All-zero residuals keep the score below the minimal threshold.
    #[test]
    fn zero_residuals_yield_no_motion() {
        let mut est = MultiApMotionEstimator::new();
        let residuals = vec![0.0, 0.0, 0.0];
        let weights = vec![0.33, 0.33, 0.34];
        let diversity = vec![0.5, 0.5, 0.5];
        let result = est.estimate(&residuals, &weights, &diversity);
        assert_eq!(result.level, MotionLevel::None);
    }

    // Sustained large residuals eventually classify as High.
    #[test]
    fn large_residuals_yield_high_motion() {
        let mut est = MultiApMotionEstimator::new();
        let residuals = vec![5.0, 5.0, 5.0];
        let weights = vec![0.33, 0.33, 0.34];
        let diversity = vec![1.0, 1.0, 1.0];
        // Push several frames to overcome EMA smoothing
        for _ in 0..20 {
            est.estimate(&residuals, &weights, &diversity);
        }
        let result = est.estimate(&residuals, &weights, &diversity);
        assert_eq!(result.level, MotionLevel::High);
    }

    // After a single spike the smoothed score decays monotonically.
    #[test]
    fn ema_smooths_transients() {
        let mut est = MultiApMotionEstimator::new();
        let big = vec![10.0, 10.0, 10.0];
        let zero = vec![0.0, 0.0, 0.0];
        let w = vec![0.33, 0.33, 0.34];
        let d = vec![0.5, 0.5, 0.5];

        // One big spike followed by zeros
        est.estimate(&big, &w, &d);
        let r1 = est.estimate(&zero, &w, &d);
        let r2 = est.estimate(&zero, &w, &d);
        // Score should decay
        assert!(r2.score < r1.score, "EMA should decay: {} < {}", r2.score, r1.score);
    }

    // Only residuals above the 0.001 magnitude floor are counted.
    #[test]
    fn n_contributing_counts_nonzero() {
        let mut est = MultiApMotionEstimator::new();
        let residuals = vec![0.0, 1.0, 0.0, 2.0];
        let weights = vec![0.25; 4];
        let diversity = vec![0.5; 4];
        let result = est.estimate(&residuals, &weights, &diversity);
        assert_eq!(result.n_contributing, 2);
    }

    // Default construction uses the documented default thresholds.
    #[test]
    fn default_creates_estimator() {
        let est = MultiApMotionEstimator::default();
        assert!((est.threshold_minimal - 0.02).abs() < f32::EPSILON);
    }
}
|
||||
@@ -0,0 +1,432 @@
|
||||
//! Stage 8: Pipeline orchestrator (Domain Service).
|
||||
//!
|
||||
//! `WindowsWifiPipeline` connects all pipeline stages (1-7) into a
|
||||
//! single processing step that transforms a `MultiApFrame` into an
|
||||
//! `EnhancedSensingResult`.
|
||||
//!
|
||||
//! This is the Domain Service described in ADR-022 section 3.2.
|
||||
|
||||
use crate::domain::frame::MultiApFrame;
|
||||
use crate::domain::result::{
|
||||
BreathingEstimate as DomainBreathingEstimate, EnhancedSensingResult,
|
||||
MotionEstimate as DomainMotionEstimate, MotionLevel, PostureClass, SignalQuality,
|
||||
Verdict as DomainVerdict,
|
||||
};
|
||||
|
||||
use super::attention_weighter::AttentionWeighter;
|
||||
use super::breathing_extractor::CoarseBreathingExtractor;
|
||||
use super::correlator::BssidCorrelator;
|
||||
use super::fingerprint_matcher::FingerprintMatcher;
|
||||
use super::motion_estimator::MultiApMotionEstimator;
|
||||
use super::predictive_gate::PredictiveGate;
|
||||
use super::quality_gate::{QualityGate, Verdict};
|
||||
|
||||
/// Configuration for the Windows `WiFi` sensing pipeline.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Maximum number of BSSID slots.
    pub max_bssids: usize,
    /// Residual gating threshold (stage 1).
    pub gate_threshold: f32,
    /// Correlation window size in frames (stage 3).
    pub correlation_window: usize,
    /// Correlation threshold for co-varying classification (stage 3).
    pub correlation_threshold: f32,
    /// Minimum BSSIDs for a valid frame.
    pub min_bssids: usize,
    /// Enable breathing extraction (stage 5).
    pub enable_breathing: bool,
    /// Enable fingerprint matching (stage 7).
    pub enable_fingerprint: bool,
    /// Sample rate in Hz (frames per second fed into the pipeline).
    pub sample_rate: f32,
}
|
||||
|
||||
impl Default for PipelineConfig {
    fn default() -> Self {
        Self {
            max_bssids: 32,
            gate_threshold: 0.05,
            // 30 frames = 15 s of history at the default 2 Hz sample rate.
            correlation_window: 30,
            correlation_threshold: 0.7,
            min_bssids: 3,
            enable_breathing: true,
            enable_fingerprint: true,
            sample_rate: 2.0,
        }
    }
}
|
||||
|
||||
/// The complete Windows `WiFi` sensing pipeline (Domain Service).
///
/// Connects stages 1-7 into a single `process()` call that transforms
/// a `MultiApFrame` into an `EnhancedSensingResult`.
///
/// Stages:
/// 1. Predictive gating (EMA residual filter)
/// 2. Attention weighting (softmax dot-product)
/// 3. Spatial correlation (Pearson + clustering)
/// 4. Motion estimation (weighted variance + EMA)
/// 5. Breathing extraction (bandpass + zero-crossing)
/// 6. Quality gate (three-filter: structural / shift / evidence)
/// 7. Fingerprint matching (cosine similarity templates)
pub struct WindowsWifiPipeline {
    /// Stage 1: predictive residual gate.
    gate: PredictiveGate,
    /// Stage 2: per-BSSID attention weighting.
    attention: AttentionWeighter,
    /// Stage 3: cross-BSSID correlation and clustering.
    correlator: BssidCorrelator,
    /// Stage 4: motion estimation.
    motion: MultiApMotionEstimator,
    /// Stage 5: coarse breathing extraction.
    breathing: CoarseBreathingExtractor,
    /// Stage 6: three-filter quality gate.
    quality: QualityGate,
    /// Stage 7: posture fingerprint matching.
    fingerprint: FingerprintMatcher,
    /// Configuration the pipeline was built with.
    config: PipelineConfig,
    /// Whether fingerprint defaults have been initialised.
    fingerprints_initialised: bool,
    /// Frame counter.
    frame_count: u64,
}
|
||||
|
||||
impl WindowsWifiPipeline {
|
||||
/// Create a new pipeline with default configuration.
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(PipelineConfig::default())
|
||||
}
|
||||
|
||||
/// Create with default configuration (alias for `new`).
|
||||
#[must_use]
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a new pipeline with custom configuration.
|
||||
#[must_use]
|
||||
pub fn with_config(config: PipelineConfig) -> Self {
|
||||
Self {
|
||||
gate: PredictiveGate::new(config.max_bssids, config.gate_threshold),
|
||||
attention: AttentionWeighter::new(1),
|
||||
correlator: BssidCorrelator::new(
|
||||
config.max_bssids,
|
||||
config.correlation_window,
|
||||
config.correlation_threshold,
|
||||
),
|
||||
motion: MultiApMotionEstimator::new(),
|
||||
breathing: CoarseBreathingExtractor::new(
|
||||
config.max_bssids,
|
||||
config.sample_rate,
|
||||
0.1,
|
||||
0.5,
|
||||
),
|
||||
quality: QualityGate::new(),
|
||||
fingerprint: FingerprintMatcher::new(config.max_bssids, 0.5),
|
||||
fingerprints_initialised: false,
|
||||
frame_count: 0,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single multi-BSSID frame through all pipeline stages.
|
||||
///
|
||||
/// Returns an `EnhancedSensingResult` with motion, breathing,
|
||||
/// posture, and quality information.
|
||||
pub fn process(&mut self, frame: &MultiApFrame) -> EnhancedSensingResult {
|
||||
self.frame_count += 1;
|
||||
|
||||
let n = frame.bssid_count;
|
||||
|
||||
// Convert f64 amplitudes to f32 for pipeline stages.
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let amps_f32: Vec<f32> = frame.amplitudes.iter().map(|&a| a as f32).collect();
|
||||
|
||||
// Initialise fingerprint defaults on first frame with enough BSSIDs.
|
||||
if !self.fingerprints_initialised
|
||||
&& self.config.enable_fingerprint
|
||||
&& amps_f32.len() == self.config.max_bssids
|
||||
{
|
||||
self.fingerprint.generate_defaults(&s_f32);
|
||||
self.fingerprints_initialised = true;
|
||||
}
|
||||
|
||||
// Check minimum BSSID count.
|
||||
if n < self.config.min_bssids {
|
||||
return Self::make_empty_result(frame, n);
|
||||
}
|
||||
|
||||
// -- Stage 1: Predictive gating --
|
||||
let Some(residuals) = self.gate.gate(&s_f32) else {
|
||||
// Static environment, no body present.
|
||||
return Self::make_empty_result(frame, n);
|
||||
};
|
||||
|
||||
// -- Stage 2: Attention weighting --
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let mean_residual =
|
||||
residuals.iter().map(|r| r.abs()).sum::<f32>() / residuals.len().max(1) as f32;
|
||||
let query = vec![mean_residual];
|
||||
let keys: Vec<Vec<f32>> = residuals.iter().map(|&r| vec![r]).collect();
|
||||
let values: Vec<Vec<f32>> = amps_f32.iter().map(|&a| vec![a]).collect();
|
||||
let (_weighted, weights) = self.attention.weight(&query, &keys, &values);
|
||||
|
||||
// -- Stage 3: Spatial correlation --
|
||||
let corr = self.correlator.update(&s_f32);
|
||||
|
||||
// -- Stage 4: Motion estimation --
|
||||
let motion = self.motion.estimate(&residuals, &weights, &corr.diversity);
|
||||
|
||||
// -- Stage 5: Breathing extraction (only when stationary) --
|
||||
let breathing = if self.config.enable_breathing && motion.level == MotionLevel::Minimal {
|
||||
self.breathing.extract(&residuals, &weights)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// -- Stage 6: Quality gate --
|
||||
let quality_result = self.quality.evaluate(
|
||||
n,
|
||||
frame.mean_rssi(),
|
||||
f64::from(corr.mean_correlation()),
|
||||
motion.score,
|
||||
);
|
||||
|
||||
// -- Stage 7: Fingerprint matching --
|
||||
let posture = if self.config.enable_fingerprint {
|
||||
self.fingerprint.classify(&s_f32).map(|(p, _sim)| p)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Count body-sensitive BSSIDs (attention weight above 1.5x average).
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let avg_weight = 1.0 / n.max(1) as f32;
|
||||
let sensitive_count = weights.iter().filter(|&&w| w > avg_weight * 1.5).count();
|
||||
|
||||
// Map internal quality gate verdict to domain Verdict.
|
||||
let domain_verdict = match &quality_result.verdict {
|
||||
Verdict::Permit => DomainVerdict::Permit,
|
||||
Verdict::Defer => DomainVerdict::Warn,
|
||||
Verdict::Deny(_) => DomainVerdict::Deny,
|
||||
};
|
||||
|
||||
// Build the domain BreathingEstimate if we have one.
|
||||
let domain_breathing = breathing.map(|b| DomainBreathingEstimate {
|
||||
rate_bpm: f64::from(b.bpm),
|
||||
confidence: f64::from(b.confidence),
|
||||
bssid_count: sensitive_count,
|
||||
});
|
||||
|
||||
EnhancedSensingResult {
|
||||
motion: DomainMotionEstimate {
|
||||
score: f64::from(motion.score),
|
||||
level: motion.level,
|
||||
contributing_bssids: motion.n_contributing,
|
||||
},
|
||||
breathing: domain_breathing,
|
||||
posture,
|
||||
signal_quality: SignalQuality {
|
||||
score: quality_result.quality,
|
||||
bssid_count: n,
|
||||
spectral_gap: f64::from(corr.mean_correlation()),
|
||||
mean_rssi_dbm: frame.mean_rssi(),
|
||||
},
|
||||
bssid_count: n,
|
||||
verdict: domain_verdict,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an empty/gated result for frames that don't pass initial checks.
|
||||
fn make_empty_result(frame: &MultiApFrame, n: usize) -> EnhancedSensingResult {
|
||||
EnhancedSensingResult {
|
||||
motion: DomainMotionEstimate {
|
||||
score: 0.0,
|
||||
level: MotionLevel::None,
|
||||
contributing_bssids: 0,
|
||||
},
|
||||
breathing: None,
|
||||
posture: None,
|
||||
signal_quality: SignalQuality {
|
||||
score: 0.0,
|
||||
bssid_count: n,
|
||||
spectral_gap: 0.0,
|
||||
mean_rssi_dbm: frame.mean_rssi(),
|
||||
},
|
||||
bssid_count: n,
|
||||
verdict: DomainVerdict::Deny,
|
||||
}
|
||||
}
|
||||
|
||||
/// Store a reference fingerprint pattern.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the pattern dimension does not match `max_bssids`.
|
||||
pub fn store_fingerprint(
|
||||
&mut self,
|
||||
pattern: Vec<f32>,
|
||||
label: PostureClass,
|
||||
) -> Result<(), String> {
|
||||
self.fingerprint.store_pattern(pattern, label)
|
||||
}
|
||||
|
||||
/// Reset all pipeline state.
|
||||
pub fn reset(&mut self) {
|
||||
self.gate = PredictiveGate::new(self.config.max_bssids, self.config.gate_threshold);
|
||||
self.correlator = BssidCorrelator::new(
|
||||
self.config.max_bssids,
|
||||
self.config.correlation_window,
|
||||
self.config.correlation_threshold,
|
||||
);
|
||||
self.motion.reset();
|
||||
self.breathing.reset();
|
||||
self.quality.reset();
|
||||
self.fingerprint.clear();
|
||||
self.fingerprints_initialised = false;
|
||||
self.frame_count = 0;
|
||||
}
|
||||
|
||||
/// Number of frames processed.
|
||||
#[must_use]
|
||||
pub fn frame_count(&self) -> u64 {
|
||||
self.frame_count
|
||||
}
|
||||
|
||||
/// Current pipeline configuration.
|
||||
#[must_use]
|
||||
pub fn config(&self) -> &PipelineConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WindowsWifiPipeline {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::time::Instant;

    /// Build a synthetic frame whose linear amplitudes are derived from
    /// RSSI via a +100 dB offset and a 20 dB/decade mapping.
    fn make_frame(bssid_count: usize, rssi_values: &[f64]) -> MultiApFrame {
        let amplitudes: Vec<f64> = rssi_values
            .iter()
            .map(|&dbm| 10.0_f64.powf((dbm + 100.0) / 20.0))
            .collect();
        MultiApFrame {
            bssid_count,
            rssi_dbm: rssi_values.to_vec(),
            amplitudes,
            phases: vec![0.0; bssid_count],
            per_bssid_variance: vec![0.1; bssid_count],
            histories: vec![VecDeque::new(); bssid_count],
            sample_rate_hz: 2.0,
            timestamp: Instant::now(),
        }
    }

    /// Shared configuration for the small four-slot test pipelines.
    fn small_config() -> PipelineConfig {
        PipelineConfig {
            min_bssids: 1,
            max_bssids: 4,
            ..Default::default()
        }
    }

    #[test]
    fn pipeline_creates_ok() {
        let pl = WindowsWifiPipeline::with_defaults();
        assert_eq!(pl.frame_count(), 0);
        assert_eq!(pl.config().max_bssids, 32);
    }

    #[test]
    fn too_few_bssids_returns_deny() {
        let mut pl = WindowsWifiPipeline::new();
        let result = pl.process(&make_frame(2, &[-60.0, -70.0]));
        assert_eq!(result.verdict, DomainVerdict::Deny);
    }

    #[test]
    fn first_frame_increments_count() {
        let mut pl = WindowsWifiPipeline::with_config(small_config());
        let _ = pl.process(&make_frame(4, &[-60.0, -65.0, -70.0, -75.0]));
        assert_eq!(pl.frame_count(), 1);
    }

    #[test]
    fn static_signal_returns_deny_after_learning() {
        let mut pl = WindowsWifiPipeline::with_config(small_config());
        let frame = make_frame(4, &[-60.0, -65.0, -70.0, -75.0]);

        // Train on static signal.
        for _ in 0..3 {
            pl.process(&frame);
        }

        // After learning, static signal should be gated (Deny verdict).
        let result = pl.process(&frame);
        assert_eq!(
            result.verdict,
            DomainVerdict::Deny,
            "static signal should be gated"
        );
    }

    #[test]
    fn changing_signal_increments_count() {
        let mut pl = WindowsWifiPipeline::with_config(small_config());
        let baseline = make_frame(4, &[-60.0, -65.0, -70.0, -75.0]);

        // Learn baseline.
        for _ in 0..5 {
            pl.process(&baseline);
        }

        // Significant change should be noticed.
        pl.process(&make_frame(4, &[-60.0, -65.0, -70.0, -30.0]));
        assert!(pl.frame_count() > 5);
    }

    #[test]
    fn reset_clears_state() {
        let mut pl = WindowsWifiPipeline::new();
        pl.process(&make_frame(4, &[-60.0, -65.0, -70.0, -75.0]));
        assert_eq!(pl.frame_count(), 1);
        pl.reset();
        assert_eq!(pl.frame_count(), 0);
    }

    #[test]
    fn default_creates_pipeline() {
        let _pl = WindowsWifiPipeline::default();
    }

    #[test]
    fn pipeline_throughput_benchmark() {
        let mut pl = WindowsWifiPipeline::with_config(small_config());
        let frame = make_frame(4, &[-60.0, -65.0, -70.0, -75.0]);

        let n_frames = 10_000;
        let start = Instant::now();
        for _ in 0..n_frames {
            pl.process(&frame);
        }
        let elapsed = start.elapsed();
        #[allow(clippy::cast_precision_loss)]
        let fps = n_frames as f64 / elapsed.as_secs_f64();
        println!("Pipeline throughput: {fps:.0} frames/sec ({elapsed:?} for {n_frames} frames)");
        assert!(fps > 100.0, "Pipeline should process >100 frames/sec, got {fps:.0}");
    }
}
|
||||
@@ -0,0 +1,141 @@
|
||||
//! Stage 1: Predictive gating via EMA-based residual filter.
|
||||
//!
|
||||
//! Suppresses static BSSIDs by computing residuals between predicted
|
||||
//! (EMA) and actual RSSI values. Only transmits frames where significant
|
||||
//! change is detected (body interaction).
|
||||
//!
|
||||
//! This is a lightweight pure-Rust implementation. When `ruvector-nervous-system`
|
||||
//! becomes available, the inner EMA predictor can be replaced with
|
||||
//! `PredictiveLayer` for more sophisticated prediction.
|
||||
|
||||
/// EMA-based residual filter over multiple BSSID amplitude slots.
///
/// Each slot keeps an exponential moving average of its amplitude; the
/// difference between the incoming value and the EMA is the residual. A
/// frame is transmitted (returned) only when at least one residual exceeds
/// the configured threshold.
pub struct PredictiveGate {
    /// Per-slot EMA prediction of the amplitude.
    predicted: Vec<f32>,
    /// True once a slot has seen its first observation.
    seeded: Vec<bool>,
    /// EMA smoothing factor (higher = faster tracking).
    alpha: f32,
    /// Residual magnitude that counts as "change detected".
    threshold: f32,
    /// Residuals computed by the most recent `gate()` call.
    prev_residuals: Vec<f32>,
    /// Number of tracked BSSID slots.
    capacity: usize,
}

impl PredictiveGate {
    /// Create a new predictive gate.
    ///
    /// - `n_bssids`: maximum number of tracked BSSIDs (subcarrier slots).
    /// - `threshold`: residual threshold for change detection (ADR-022 default: 0.05).
    #[must_use]
    pub fn new(n_bssids: usize, threshold: f32) -> Self {
        PredictiveGate {
            predicted: vec![0.0; n_bssids],
            seeded: vec![false; n_bssids],
            alpha: 0.3,
            threshold,
            prev_residuals: vec![0.0; n_bssids],
            capacity: n_bssids,
        }
    }

    /// Process one frame of amplitudes.
    ///
    /// Returns `Some(residuals)` when body-correlated change is detected,
    /// `None` when the environment looks static. The first observation of
    /// any slot always transmits (there is no prediction to compare against).
    pub fn gate(&mut self, amplitudes: &[f32]) -> Option<Vec<f32>> {
        let n = self.capacity.min(amplitudes.len());
        let mut residuals = Vec::with_capacity(n);
        let mut peak = 0.0f32;

        for (slot, &amp) in amplitudes.iter().enumerate().take(n) {
            let residual = if self.seeded[slot] {
                let r = amp - self.predicted[slot];
                peak = peak.max(r.abs());
                // Fold the new observation into the EMA.
                self.predicted[slot] = self.alpha * amp + (1.0 - self.alpha) * self.predicted[slot];
                r
            } else {
                // First observation: seed the prediction and force transmission.
                self.predicted[slot] = amp;
                self.seeded[slot] = true;
                peak = f32::MAX;
                amp
            };
            residuals.push(residual);
        }

        self.prev_residuals.clone_from(&residuals);

        if peak > self.threshold {
            Some(residuals)
        } else {
            None
        }
    }

    /// Residuals computed by the most recent `gate()` call.
    #[must_use]
    pub fn last_residuals(&self) -> &[f32] {
        &self.prev_residuals
    }

    /// Update the threshold dynamically (e.g., from SONA adaptation).
    pub fn set_threshold(&mut self, threshold: f32) {
        self.threshold = threshold;
    }

    /// Current threshold.
    #[must_use]
    pub fn threshold(&self) -> f32 {
        self.threshold
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn static_signal_is_gated() {
        let mut g = PredictiveGate::new(4, 0.05);
        let sig = vec![1.0, 2.0, 3.0, 4.0];
        // First frame always transmits (no prediction yet).
        assert!(g.gate(&sig).is_some());
        // After many repeated frames the EMA converges and residuals shrink.
        for _ in 0..20 {
            g.gate(&sig);
        }
        assert!(g.gate(&sig).is_none());
    }

    #[test]
    fn changing_signal_transmits() {
        let mut g = PredictiveGate::new(4, 0.05);
        let steady = vec![1.0, 2.0, 3.0, 4.0];
        // Seed and let the EMA converge (1 + 20 frames, as before).
        for _ in 0..21 {
            g.gate(&steady);
        }

        // A large change should be transmitted.
        assert!(g.gate(&[1.0, 2.0, 3.0, 10.0]).is_some());
    }

    #[test]
    fn residuals_are_stored() {
        let mut g = PredictiveGate::new(3, 0.05);
        g.gate(&[1.0, 2.0, 3.0]);
        assert_eq!(g.last_residuals().len(), 3);
    }

    #[test]
    fn threshold_can_be_updated() {
        let mut g = PredictiveGate::new(2, 0.05);
        assert!((g.threshold() - 0.05).abs() < f32::EPSILON);
        g.set_threshold(0.1);
        assert!((g.threshold() - 0.1).abs() < f32::EPSILON);
    }
}
|
||||
@@ -0,0 +1,261 @@
|
||||
//! Stage 6: Signal quality gate.
|
||||
//!
|
||||
//! Evaluates signal quality using three factors inspired by the ruQu
|
||||
//! three-filter architecture (structural integrity, distribution drift,
|
||||
//! evidence accumulation):
|
||||
//!
|
||||
//! - **Structural**: number of active BSSIDs (graph connectivity proxy).
|
||||
//! - **Shift**: RSSI drift from running baseline.
|
||||
//! - **Evidence**: accumulated weighted variance evidence.
|
||||
//!
|
||||
//! This is a pure-Rust implementation. When the `ruqu` crate becomes
|
||||
//! available, the inner filter can be replaced with `FilterPipeline`.
|
||||
|
||||
/// Tunable thresholds for the quality gate.
#[derive(Debug, Clone)]
pub struct QualityGateConfig {
    /// Minimum active BSSIDs for a "Permit" verdict.
    pub min_bssids: usize,
    /// Evidence threshold for "Permit" (accumulated variance).
    pub evidence_threshold: f64,
    /// RSSI drift threshold (dBm) for triggering a "Warn".
    pub drift_threshold: f64,
    /// Maximum evidence decay per frame.
    pub evidence_decay: f64,
}

impl Default for QualityGateConfig {
    /// Defaults: 3 BSSIDs minimum, 0.5 evidence threshold, 10 dBm drift
    /// threshold, 0.95 per-frame evidence decay.
    fn default() -> Self {
        QualityGateConfig {
            min_bssids: 3,
            evidence_threshold: 0.5,
            drift_threshold: 10.0,
            evidence_decay: 0.95,
        }
    }
}

/// Quality gate combining structural, shift, and evidence filters.
pub struct QualityGate {
    config: QualityGateConfig,
    /// Accumulated evidence score.
    evidence: f64,
    /// Running mean RSSI baseline for drift detection.
    prev_mean_rssi: Option<f64>,
    /// EMA smoothing factor for drift baseline.
    alpha: f64,
}

impl QualityGate {
    /// Create a quality gate with default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(QualityGateConfig::default())
    }

    /// Create a quality gate with custom configuration.
    #[must_use]
    pub fn with_config(config: QualityGateConfig) -> Self {
        QualityGate {
            config,
            evidence: 0.0,
            prev_mean_rssi: None,
            alpha: 0.3,
        }
    }

    /// Evaluate signal quality.
    ///
    /// - `bssid_count`: number of active BSSIDs.
    /// - `mean_rssi_dbm`: mean RSSI across all BSSIDs.
    /// - `mean_correlation`: mean cross-BSSID correlation (spectral gap proxy).
    /// - `motion_score`: smoothed motion score from the estimator.
    ///
    /// Returns a `QualityResult` with verdict and quality score.
    pub fn evaluate(
        &mut self,
        bssid_count: usize,
        mean_rssi_dbm: f64,
        mean_correlation: f64,
        motion_score: f32,
    ) -> QualityResult {
        // Filter 1: structural — enough active BSSIDs?
        let structural_ok = bssid_count >= self.config.min_bssids;

        // Filter 2: shift — how far has the mean RSSI moved from baseline?
        let drift = self
            .prev_mean_rssi
            .map_or(0.0, |prev| (mean_rssi_dbm - prev).abs());
        let drift_detected = drift > self.config.drift_threshold;
        // Fold the new observation into the EMA baseline.
        self.prev_mean_rssi = Some(match self.prev_mean_rssi {
            None => mean_rssi_dbm,
            Some(prev) => self.alpha * mean_rssi_dbm + (1.0 - self.alpha) * prev,
        });

        // Filter 3: evidence — decay, then add this frame's contribution.
        // Motion and correlation both contribute positive evidence.
        let frame_evidence = f64::from(motion_score) * 0.7 + mean_correlation * 0.3;
        self.evidence = self.evidence * self.config.evidence_decay + frame_evidence;

        // Combine the factors into a [0, 1] quality score.
        let quality = compute_quality_score(
            bssid_count,
            f64::from(motion_score),
            mean_correlation,
            drift_detected,
        );

        // Verdict: structural failure denies outright; drift or thin
        // evidence defers; otherwise permit.
        let verdict = if !structural_ok {
            Verdict::Deny("insufficient BSSIDs".to_string())
        } else if drift_detected || self.evidence < self.config.evidence_threshold * 0.5 {
            Verdict::Defer
        } else {
            Verdict::Permit
        };

        QualityResult {
            verdict,
            quality,
            drift_detected,
        }
    }

    /// Reset the gate state.
    pub fn reset(&mut self) {
        self.evidence = 0.0;
        self.prev_mean_rssi = None;
    }
}

impl Default for QualityGate {
    fn default() -> Self {
        Self::new()
    }
}

/// Quality verdict from the gate.
#[derive(Debug, Clone)]
pub struct QualityResult {
    /// Filter decision.
    pub verdict: Verdict,
    /// Signal quality score [0, 1].
    pub quality: f64,
    /// Whether environmental drift was detected.
    pub drift_detected: bool,
}

/// Simplified quality gate verdict.
#[derive(Debug, Clone, PartialEq)]
pub enum Verdict {
    /// Reading passed all quality gates and is reliable.
    Permit,
    /// Reading failed quality checks with a reason.
    Deny(String),
    /// Evidence still accumulating.
    Defer,
}

impl Verdict {
    /// Returns true if this verdict permits the reading.
    #[must_use]
    pub fn is_permit(&self) -> bool {
        matches!(self, Self::Permit)
    }
}

/// Compute a quality score from pipeline metrics.
#[allow(clippy::cast_precision_loss)]
fn compute_quality_score(
    n_active: usize,
    weighted_variance: f64,
    mean_correlation: f64,
    drift: bool,
) -> f64 {
    // More active BSSIDs help, saturating at 10 (diminishing returns).
    let coverage = (n_active as f64 / 10.0).min(1.0);

    // Higher weighted variance means more usable signal, saturating at 0.1.
    let strength = (weighted_variance * 10.0).min(1.0);

    // Moderate correlation (~0.5) scores best; extremes score zero.
    let coherence = 1.0 - (mean_correlation - 0.5).abs() * 2.0;

    // Detected drift scales the whole score down.
    let penalty = if drift { 0.7 } else { 1.0 };

    let raw = (coverage * 0.3 + strength * 0.4 + coherence.max(0.0) * 0.3) * penalty;
    raw.clamp(0.0, 1.0)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_gate_creates_ok() {
        let g = QualityGate::new();
        assert!((g.evidence - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn evaluate_with_good_signal() {
        let mut g = QualityGate::new();
        // Pump several frames to build evidence.
        for _ in 0..20 {
            g.evaluate(10, -60.0, 0.5, 0.3);
        }
        let result = g.evaluate(10, -60.0, 0.5, 0.3);
        assert!(result.quality > 0.0, "quality should be positive");
        assert!(result.verdict.is_permit(), "should permit good signal");
    }

    #[test]
    fn too_few_bssids_denied() {
        let mut g = QualityGate::new();
        let result = g.evaluate(1, -60.0, 0.5, 0.3);
        assert!(
            matches!(result.verdict, Verdict::Deny(_)),
            "too few BSSIDs should be denied"
        );
    }

    #[test]
    fn quality_increases_with_more_bssids() {
        let few = compute_quality_score(3, 0.1, 0.5, false);
        let many = compute_quality_score(10, 0.1, 0.5, false);
        assert!(many > few, "more BSSIDs should give higher quality");
    }

    #[test]
    fn drift_reduces_quality() {
        let stable = compute_quality_score(5, 0.1, 0.5, false);
        let drifting = compute_quality_score(5, 0.1, 0.5, true);
        assert!(drifting < stable, "drift should reduce quality");
    }

    #[test]
    fn verdict_is_permit_check() {
        assert!(Verdict::Permit.is_permit());
        assert!(!Verdict::Deny("test".to_string()).is_permit());
        assert!(!Verdict::Defer.is_permit());
    }

    #[test]
    fn default_creates_gate() {
        let _g = QualityGate::default();
    }

    #[test]
    fn reset_clears_state() {
        let mut g = QualityGate::new();
        g.evaluate(10, -60.0, 0.5, 0.3);
        g.reset();
        assert!(g.prev_mean_rssi.is_none());
        assert!((g.evidence - 0.0).abs() < f64::EPSILON);
    }
}
|
||||
@@ -0,0 +1,9 @@
|
||||
//! Port definitions for the BSSID Acquisition bounded context.
|
||||
//!
|
||||
//! Hexagonal-architecture ports that abstract the WiFi scanning backend,
|
||||
//! enabling Tier 1 (netsh), Tier 2 (wlanapi FFI), and test-double adapters
|
||||
//! to be swapped transparently.
|
||||
|
||||
mod scan_port;
|
||||
|
||||
pub use scan_port::WlanScanPort;
|
||||
@@ -0,0 +1,17 @@
|
||||
//! The primary port (driving side) for WiFi BSSID scanning.
|
||||
|
||||
use crate::domain::bssid::BssidObservation;
|
||||
use crate::error::WifiScanError;
|
||||
|
||||
/// Port that abstracts the platform WiFi scanning backend.
///
/// Implementations include:
/// - [`crate::adapter::NetshBssidScanner`] -- Tier 1, subprocess-based.
/// - Future: `WlanApiBssidScanner` -- Tier 2, native FFI (feature-gated).
///
/// `Send + Sync` so an adapter can be shared across threads behind an `Arc`.
pub trait WlanScanPort: Send + Sync {
    /// Perform a scan and return all currently visible BSSIDs.
    ///
    /// # Errors
    ///
    /// Returns a [`WifiScanError`] when the scan fails (causes are
    /// backend-specific; see the individual adapter documentation).
    fn scan(&self) -> Result<Vec<BssidObservation>, WifiScanError>;

    /// Return the BSSID to which the adapter is currently connected, if any.
    ///
    /// # Errors
    ///
    /// Returns a [`WifiScanError`] when the connection state cannot be
    /// queried (backend-specific).
    fn connected(&self) -> Result<Option<BssidObservation>, WifiScanError>;
}
|
||||
Reference in New Issue
Block a user