diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs index 7cf72d2..7ee18d5 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs @@ -1129,4 +1129,36 @@ mod tests { xorshift_shuffle(&mut b, 123); assert_eq!(a, b); } + + // ----- CompressedCsiBuffer ---------------------------------------------- + + #[test] + fn compressed_csi_buffer_roundtrip() { + // Create a small CSI array and check it round-trips through compression + let arr = Array4::<f32>::from_shape_fn((10, 1, 3, 16), |(t, _, rx, sc)| { + ((t + rx + sc) as f32) * 0.1 + }); + let buf = CompressedCsiBuffer::from_array4(&arr, 0); + assert_eq!(buf.len(), 10); + assert!(!buf.is_empty()); + assert!(buf.compression_ratio > 1.0, "Should compress better than f32"); + + // Decode single frame + let frame = buf.get_frame(0); + assert!(frame.is_some()); + assert_eq!(frame.unwrap().len(), 1 * 3 * 16); + + // Full decode + let decoded = buf.to_array4(1, 3, 16); + assert_eq!(decoded.shape(), &[10, 1, 3, 16]); + } + + #[test] + fn compressed_csi_buffer_empty() { + let arr = Array4::<f32>::zeros((0, 1, 3, 16)); + let buf = CompressedCsiBuffer::from_array4(&arr, 0); + assert_eq!(buf.len(), 0); + assert!(buf.is_empty()); + assert!(buf.get_frame(0).is_none()); + } } diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs index d2742d6..8f112c7 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs @@ -82,14 +82,16 @@ impl WiFiDensePoseModel { let root = vs.root(); // Compute the flattened CSI input size used by the modality translator. 
- let flat_csi = (config.window_frames + let n_ant = (config.window_frames * config.num_antennas_tx - * config.num_antennas_rx - * config.num_subcarriers) as i64; + * config.num_antennas_rx) as i64; + let n_sc = config.num_subcarriers as i64; + let flat_csi = n_ant * n_sc; let num_parts = config.num_body_parts as i64; - let translator = ModalityTranslator::new(&root / "translator", flat_csi); + let translator = + ModalityTranslator::new(&root / "translator", flat_csi, n_ant, n_sc); let backbone = Backbone::new(&root / "backbone", config.backbone_channels as i64); let kp_head = KeypointHead::new( &root / "kp_head", @@ -255,6 +257,139 @@ fn phase_sanitize(phase: &Tensor) -> Tensor { Tensor::cat(&[zeros, diff], 2) } +// --------------------------------------------------------------------------- +// ruvector attention helpers +// --------------------------------------------------------------------------- + +/// Apply min-cut gated attention over the antenna-path dimension. +/// +/// Treats each antenna path as a "token" and subcarriers as the feature +/// dimension. Uses `attn_mincut` to gate irrelevant antenna-pair correlations, +/// which is equivalent to automatic antenna selection. +/// +/// # Arguments +/// +/// - `x`: CSI tensor `[B, n_ant, n_sc]` — amplitude or phase +/// - `lambda`: min-cut threshold (0.3 = moderate pruning) +/// +/// # Returns +/// +/// Attended tensor `[B, n_ant, n_sc]` with irrelevant antenna paths suppressed. +fn apply_antenna_attention(x: &Tensor, lambda: f32) -> Tensor { + let sizes = x.size(); + let n_ant = sizes[1]; + let n_sc = sizes[2]; + + // Skip trivial cases where attention is a no-op. + if n_ant <= 1 || n_sc <= 1 { + return x.shallow_clone(); + } + + let b = sizes[0] as usize; + let n_ant_usize = n_ant as usize; + let n_sc_usize = n_sc as usize; + + let device = x.device(); + let kind = x.kind(); + + // Process each batch element independently (attn_mincut operates on 2D inputs). 
+ let mut results: Vec<Tensor> = Vec::with_capacity(b); + + for bi in 0..b { + // Extract [n_ant, n_sc] slice for this batch element. + let xi = x.select(0, bi as i64); // [n_ant, n_sc] + + // Move to CPU and convert to f32 for the pure-Rust attention kernel. + let flat: Vec<f32> = + Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous()); + + // Q = K = V = the antenna features (self-attention over antenna paths). + let out = attn_mincut( + &flat, // q: [n_ant * n_sc] + &flat, // k: [n_ant * n_sc] + &flat, // v: [n_ant * n_sc] + n_sc_usize, // d: feature dim = n_sc subcarriers + n_ant_usize, // seq_len: number of antenna paths + lambda, // lambda: min-cut threshold + 1, // tau: no temporal hysteresis (single-frame) + 1e-6, // eps: numerical epsilon + ); + + let attended = Tensor::from_slice(&out.output) + .reshape([n_ant, n_sc]) + .to_device(device) + .to_kind(kind); + + results.push(attended); + } + + Tensor::stack(&results, 0) // [B, n_ant, n_sc] +} + +/// Apply scaled dot-product attention over spatial locations. +/// +/// Input: `[B, C, H, W]` feature map — each spatial location (H×W) becomes a +/// token; C is the feature dimension. Captures long-range spatial dependencies +/// between antenna-footprint regions. +/// +/// Returns `[B, C, H, W]` with spatial attention applied. +/// +/// This function can be applied after backbone features when long-range spatial +/// context is needed. It is defined here for completeness and may be called +/// from head implementations or future backbone variants. +#[allow(dead_code)] +fn apply_spatial_attention(x: &Tensor) -> Tensor { + let sizes = x.size(); + let (b, c, h, w) = (sizes[0], sizes[1], sizes[2], sizes[3]); + let n_spatial = (h * w) as usize; + let d = c as usize; + + let device = x.device(); + let kind = x.kind(); + + let attn = ScaledDotProductAttention::new(d); + + let mut results: Vec<Tensor> = Vec::with_capacity(b as usize); + + for bi in 0..b { + // Extract [C, H*W] and transpose to [H*W, C]. 
+ let xi = x.select(0, bi).reshape([c, h * w]).transpose(0, 1); // [H*W, C] + let flat: Vec<f32> = + Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous()); + + // Build token slices — one per spatial position. + let tokens: Vec<&[f32]> = (0..n_spatial) + .map(|i| &flat[i * d..(i + 1) * d]) + .collect(); + + // For each spatial token as query, compute attended output. + let mut out_flat = vec![0.0f32; n_spatial * d]; + for i in 0..n_spatial { + let query = &flat[i * d..(i + 1) * d]; + match attn.compute(query, &tokens, &tokens) { + Ok(attended) => { + out_flat[i * d..(i + 1) * d].copy_from_slice(&attended); + } + Err(_) => { + // Fallback: identity — keep original features unchanged. + out_flat[i * d..(i + 1) * d].copy_from_slice(query); + } + } + } + + let out_tensor = Tensor::from_slice(&out_flat) + .reshape([h * w, c]) + .transpose(0, 1) // [C, H*W] + .reshape([c, h, w]) // [C, H, W] + .to_device(device) + .to_kind(kind); + + results.push(out_tensor); + } + + Tensor::stack(&results, 0) // [B, C, H, W] +} + // --------------------------------------------------------------------------- // Modality Translator // --------------------------------------------------------------------------- @@ -262,10 +397,14 @@ fn phase_sanitize(phase: &Tensor) -> Tensor { /// Translates flattened (amplitude, phase) CSI vectors into a pseudo-image. 
/// /// ```text -/// amplitude [B, flat_csi] ─► amp_fc1 ► relu ► amp_fc2 ► relu ─┐ -/// ├─► fuse_fc ► reshape ► spatial_conv ► [B, 3, 48, 48] -/// phase [B, flat_csi] ─► ph_fc1 ► relu ► ph_fc2 ► relu ─┘ +/// amplitude [B, flat_csi] ─► attn_mincut ─► amp_fc1 ► relu ► amp_fc2 ► relu ─┐ +/// ├─► fuse_fc ► reshape ► spatial_conv ► [B, 3, 48, 48] +/// phase [B, flat_csi] ─► attn_mincut ─► ph_fc1 ► relu ► ph_fc2 ► relu ─┘ /// ``` +/// +/// The `attn_mincut` step performs self-attention over the antenna-path dimension +/// (`n_ant` tokens, each with `n_sc` subcarrier features) to gate out irrelevant +/// antenna-pair correlations before the FC fusion layers. struct ModalityTranslator { amp_fc1: nn::Linear, amp_fc2: nn::Linear, @@ -276,10 +415,14 @@ struct ModalityTranslator { sp_conv1: nn::Conv2D, sp_bn1: nn::BatchNorm, sp_conv2: nn::Conv2D, + /// Number of antenna paths: T * n_tx * n_rx (used for attention reshape). + n_ant: i64, + /// Number of subcarriers per antenna path (used for attention reshape). + n_sc: i64, } impl ModalityTranslator { - fn new(vs: nn::Path, flat_csi: i64) -> Self { + fn new(vs: nn::Path, flat_csi: i64, n_ant: i64, n_sc: i64) -> Self { let amp_fc1 = nn::linear(&vs / "amp_fc1", flat_csi, 512, Default::default()); let amp_fc2 = nn::linear(&vs / "amp_fc2", 512, 256, Default::default()); let ph_fc1 = nn::linear(&vs / "ph_fc1", flat_csi, 512, Default::default()); @@ -320,22 +463,38 @@ impl ModalityTranslator { sp_conv1, sp_bn1, sp_conv2, + n_ant, + n_sc, } } fn forward_t(&self, amp: &Tensor, ph: &Tensor, train: bool) -> Tensor { let b = amp.size()[0]; - // Amplitude branch - let a = amp + // === ruvector-attn-mincut: gate irrelevant antenna paths === + // + // Reshape from [B, flat_csi] to [B, n_ant, n_sc], apply min-cut + // self-attention over the antenna-path dimension (antenna paths are + // "tokens", subcarrier responses are "features"), then flatten back. 
+ let amp_3d = amp.reshape([b, self.n_ant, self.n_sc]); + let ph_3d = ph.reshape([b, self.n_ant, self.n_sc]); + + let amp_attended = apply_antenna_attention(&amp_3d, 0.3); + let ph_attended = apply_antenna_attention(&ph_3d, 0.3); + + let amp_flat = amp_attended.reshape([b, -1]); // [B, flat_csi] + let ph_flat = ph_attended.reshape([b, -1]); // [B, flat_csi] + + // Amplitude branch (uses attended input) + let a = amp_flat .apply(&self.amp_fc1) .relu() .dropout(0.2, train) .apply(&self.amp_fc2) .relu(); - // Phase branch - let p = ph + // Phase branch (uses attended input) + let p = ph_flat .apply(&self.ph_fc1) .relu() .dropout(0.2, train)