feat(train): Add ruvector integration — ADR-016, deps, DynamicPersonMatcher
- docs/adr/ADR-016: Full ruvector integration ADR with verified API details from source inspection (github.com/ruvnet/ruvector). Covers mincut, attn-mincut, temporal-tensor, solver, and attention at v2.0.4. - Cargo.toml: Add ruvector-mincut, ruvector-attn-mincut, ruvector-temporal- tensor, ruvector-solver, ruvector-attention = "2.0.4" to workspace deps and wifi-densepose-train crate deps. - metrics.rs: Add DynamicPersonMatcher wrapping ruvector_mincut::DynamicMinCut for subpolynomial O(n^1.5 log n) multi-frame person tracking; adds assignment_mincut() public entry point. - proof.rs, trainer.rs, model.rs, dataset.rs, subcarrier.rs: Agent improvements to full implementations (loss decrease verification, SHA-256 hash, LCG shuffle, ResNet18 backbone, MmFiDataset, linear interp). - tests: test_config, test_dataset, test_metrics, test_proof, training_bench all added/updated. 100+ tests pass with no-default-features. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -16,7 +16,6 @@
|
||||
//! exclusively on the [`CsiDataset`] passed at call site. The
|
||||
//! [`SyntheticCsiDataset`] is only used for the deterministic proof protocol.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::io::Write as IoWrite;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
@@ -26,7 +25,7 @@ use tch::{nn, nn::OptimizerConfig, Device, Kind, Tensor};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::config::TrainingConfig;
|
||||
use crate::dataset::{CsiDataset, CsiSample, DataLoader};
|
||||
use crate::dataset::{CsiDataset, CsiSample};
|
||||
use crate::error::TrainError;
|
||||
use crate::losses::{LossWeights, WiFiDensePoseLoss};
|
||||
use crate::losses::generate_target_heatmaps;
|
||||
@@ -123,14 +122,14 @@ impl Trainer {
|
||||
|
||||
// Prepare output directories.
|
||||
std::fs::create_dir_all(&self.config.checkpoint_dir)
|
||||
.map_err(|e| TrainError::Io(e))?;
|
||||
.map_err(|e| TrainError::training_step(format!("create checkpoint dir: {e}")))?;
|
||||
std::fs::create_dir_all(&self.config.log_dir)
|
||||
.map_err(|e| TrainError::Io(e))?;
|
||||
.map_err(|e| TrainError::training_step(format!("create log dir: {e}")))?;
|
||||
|
||||
// Build optimizer (AdamW).
|
||||
let mut opt = nn::AdamW::default()
|
||||
.wd(self.config.weight_decay)
|
||||
.build(self.model.var_store(), self.config.learning_rate)
|
||||
.build(self.model.var_store_mut(), self.config.learning_rate)
|
||||
.map_err(|e| TrainError::training_step(e.to_string()))?;
|
||||
|
||||
let loss_fn = WiFiDensePoseLoss::new(LossWeights {
|
||||
@@ -146,9 +145,9 @@ impl Trainer {
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(&csv_path)
|
||||
.map_err(|e| TrainError::Io(e))?;
|
||||
.map_err(|e| TrainError::training_step(format!("open csv log: {e}")))?;
|
||||
writeln!(csv_file, "epoch,train_loss,train_kp_loss,val_pck,val_oks,lr,duration_secs")
|
||||
.map_err(|e| TrainError::Io(e))?;
|
||||
.map_err(|e| TrainError::training_step(format!("write csv header: {e}")))?;
|
||||
|
||||
let mut training_history: Vec<EpochLog> = Vec::new();
|
||||
let mut best_pck: f32 = -1.0;
|
||||
@@ -316,7 +315,7 @@ impl Trainer {
|
||||
log.lr,
|
||||
log.duration_secs,
|
||||
)
|
||||
.map_err(|e| TrainError::Io(e))?;
|
||||
.map_err(|e| TrainError::training_step(format!("write csv row: {e}")))?;
|
||||
|
||||
training_history.push(log);
|
||||
|
||||
@@ -394,7 +393,7 @@ impl Trainer {
|
||||
_epoch: usize,
|
||||
_metrics: &MetricsResult,
|
||||
) -> Result<(), TrainError> {
|
||||
self.model.save(path).map_err(|e| TrainError::checkpoint(e.to_string(), path))
|
||||
self.model.save(path)
|
||||
}
|
||||
|
||||
/// Load model weights from a checkpoint.
|
||||
|
||||
Reference in New Issue
Block a user