Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
54
vendor/ruvector/crates/ruvector-attn-mincut/src/config.rs
vendored
Normal file
54
vendor/ruvector/crates/ruvector-attn-mincut/src/config.rs
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for the min-cut gating attention operator.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MinCutConfig {
|
||||
pub lambda: f32,
|
||||
pub tau: usize,
|
||||
pub eps: f32,
|
||||
pub seed: u64,
|
||||
pub witness_enabled: bool,
|
||||
}
|
||||
|
||||
impl Default for MinCutConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
lambda: 0.5,
|
||||
tau: 2,
|
||||
eps: 0.01,
|
||||
seed: 42,
|
||||
witness_enabled: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let c = MinCutConfig::default();
|
||||
assert!((c.lambda - 0.5).abs() < f32::EPSILON);
|
||||
assert_eq!(c.tau, 2);
|
||||
assert!((c.eps - 0.01).abs() < f32::EPSILON);
|
||||
assert_eq!(c.seed, 42);
|
||||
assert!(c.witness_enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serde_roundtrip() {
|
||||
let c = MinCutConfig {
|
||||
lambda: 0.3,
|
||||
tau: 5,
|
||||
eps: 0.001,
|
||||
seed: 99,
|
||||
witness_enabled: false,
|
||||
};
|
||||
let json = serde_json::to_string(&c).unwrap();
|
||||
let r: MinCutConfig = serde_json::from_str(&json).unwrap();
|
||||
assert!((r.lambda - 0.3).abs() < f32::EPSILON);
|
||||
assert_eq!(r.tau, 5);
|
||||
assert!(!r.witness_enabled);
|
||||
}
|
||||
}
|
||||
149
vendor/ruvector/crates/ruvector-attn-mincut/src/gating.rs
vendored
Normal file
149
vendor/ruvector/crates/ruvector-attn-mincut/src/gating.rs
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
use crate::mincut::{dynamic_min_cut, GatingResult};
|
||||
|
||||
/// Combined output from min-cut gated attention.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttentionOutput {
|
||||
pub output: Vec<f32>,
|
||||
pub gating: GatingResult,
|
||||
}
|
||||
|
||||
/// Compute raw logits: Q * K^T / sqrt(d). Returns flattened `seq_len x seq_len`.
|
||||
fn compute_logits(q: &[f32], k: &[f32], d: usize, seq_len: usize) -> Vec<f32> {
|
||||
let scale = 1.0 / (d as f32).sqrt();
|
||||
let mut logits = vec![0.0f32; seq_len * seq_len];
|
||||
for i in 0..seq_len {
|
||||
for j in 0..seq_len {
|
||||
let mut dot = 0.0f32;
|
||||
for h in 0..d {
|
||||
dot += q[i * d + h] * k[j * d + h];
|
||||
}
|
||||
logits[i * seq_len + j] = dot * scale;
|
||||
}
|
||||
}
|
||||
logits
|
||||
}
|
||||
|
||||
/// Row-wise softmax in place on a flattened `rows x cols` matrix.
|
||||
fn row_softmax(mat: &mut [f32], rows: usize, cols: usize) {
|
||||
for i in 0..rows {
|
||||
let row = &mut mat[i * cols..(i + 1) * cols];
|
||||
let mx = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut sum = 0.0f32;
|
||||
for v in row.iter_mut() {
|
||||
*v = (*v - mx).exp();
|
||||
sum += *v;
|
||||
}
|
||||
if sum > 0.0 {
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply weights (seq_len x seq_len) by V (seq_len x d).
|
||||
fn matmul_wv(w: &[f32], v: &[f32], seq_len: usize, d: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; seq_len * d];
|
||||
for i in 0..seq_len {
|
||||
for j in 0..seq_len {
|
||||
let wij = w[i * seq_len + j];
|
||||
if wij != 0.0 {
|
||||
for h in 0..d {
|
||||
out[i * d + h] += wij * v[j * d + h];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Baseline standard softmax attention. Returns flattened `seq_len x d`.
|
||||
pub fn attn_softmax(q: &[f32], k: &[f32], v: &[f32], d: usize, seq_len: usize) -> Vec<f32> {
|
||||
assert!(q.len() == seq_len * d && k.len() == seq_len * d && v.len() == seq_len * d);
|
||||
let mut logits = compute_logits(q, k, d, seq_len);
|
||||
row_softmax(&mut logits, seq_len, seq_len);
|
||||
matmul_wv(&logits, v, seq_len, d)
|
||||
}
|
||||
|
||||
/// Min-cut gated attention.
|
||||
/// 1. Compute logits 2. Min-cut gating 3. Mask with -INF 4. Row-softmax 5. Multiply V
|
||||
pub fn attn_mincut(
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
d: usize,
|
||||
seq_len: usize,
|
||||
lambda: f32,
|
||||
tau: usize,
|
||||
eps: f32,
|
||||
) -> AttentionOutput {
|
||||
assert!(q.len() == seq_len * d && k.len() == seq_len * d && v.len() == seq_len * d);
|
||||
let mut logits = compute_logits(q, k, d, seq_len);
|
||||
let gating = dynamic_min_cut(&logits, seq_len, lambda, tau, eps);
|
||||
|
||||
// Gate entries with -INF so softmax zeroes them
|
||||
for i in 0..logits.len() {
|
||||
if !gating.keep_mask[i] {
|
||||
logits[i] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
row_softmax(&mut logits, seq_len, seq_len);
|
||||
// Replace NaN (fully-gated rows) with 0
|
||||
for v in logits.iter_mut() {
|
||||
if v.is_nan() {
|
||||
*v = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
AttentionOutput {
|
||||
output: matmul_wv(&logits, v, seq_len, d),
|
||||
gating,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_qkv(seq: usize, d: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let mut q = vec![0.0f32; seq * d];
|
||||
let mut k = vec![0.0f32; seq * d];
|
||||
let v: Vec<f32> = (0..seq * d).map(|i| i as f32).collect();
|
||||
for i in 0..seq.min(d) {
|
||||
q[i * d + i] = 1.0;
|
||||
k[i * d + i] = 1.0;
|
||||
}
|
||||
(q, k, v)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_shape_and_finite() {
|
||||
let (q, k, v) = make_qkv(4, 3);
|
||||
let out = attn_softmax(&q, &k, &v, 3, 4);
|
||||
assert_eq!(out.len(), 12);
|
||||
assert!(out.iter().all(|x| x.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mincut_shape_and_finite() {
|
||||
let (q, k, v) = make_qkv(4, 3);
|
||||
let r = attn_mincut(&q, &k, &v, 3, 4, 0.5, 2, 0.01);
|
||||
assert_eq!(r.output.len(), 12);
|
||||
assert!(r.output.iter().all(|x| x.is_finite()));
|
||||
assert_eq!(r.gating.edges_total, 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_logit_scale() {
|
||||
let logits = compute_logits(&[1.0; 4], &[1.0; 4], 4, 1);
|
||||
assert!((logits[0] - 2.0).abs() < 1e-5); // dot=4, scale=1/2
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_row_softmax_sums_to_one() {
|
||||
let mut m = vec![1.0, 2.0, 3.0, 4.0];
|
||||
row_softmax(&mut m, 2, 2);
|
||||
assert!(((m[0] + m[1]) - 1.0).abs() < 1e-5);
|
||||
assert!(((m[2] + m[3]) - 1.0).abs() < 1e-5);
|
||||
}
|
||||
}
|
||||
73
vendor/ruvector/crates/ruvector-attn-mincut/src/graph.rs
vendored
Normal file
73
vendor/ruvector/crates/ruvector-attn-mincut/src/graph.rs
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A directed edge in the attention graph.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Edge {
|
||||
pub src: usize,
|
||||
pub dst: usize,
|
||||
pub weight: f32,
|
||||
}
|
||||
|
||||
/// Weighted directed graph built from attention logits.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttentionGraph {
|
||||
pub nodes: usize,
|
||||
pub edges: Vec<Edge>,
|
||||
}
|
||||
|
||||
/// Build a weighted directed graph from flattened `seq_len x seq_len` logits.
|
||||
/// Only positive logits become edges; non-positive entries are omitted.
|
||||
pub fn graph_from_logits(logits: &[f32], seq_len: usize) -> AttentionGraph {
|
||||
assert_eq!(
|
||||
logits.len(),
|
||||
seq_len * seq_len,
|
||||
"logits length must equal seq_len^2"
|
||||
);
|
||||
let mut edges = Vec::new();
|
||||
for i in 0..seq_len {
|
||||
for j in 0..seq_len {
|
||||
let w = logits[i * seq_len + j];
|
||||
if w > 0.0 {
|
||||
edges.push(Edge {
|
||||
src: i,
|
||||
dst: j,
|
||||
weight: w,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
AttentionGraph {
|
||||
nodes: seq_len,
|
||||
edges,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_all_positive() {
|
||||
let g = graph_from_logits(&[1.0, 2.0, 3.0, 4.0], 2);
|
||||
assert_eq!(g.nodes, 2);
|
||||
assert_eq!(g.edges.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filters_non_positive() {
|
||||
let g = graph_from_logits(&[1.0, -0.5, 0.0, 2.0], 2);
|
||||
assert_eq!(g.edges.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "logits length must equal seq_len^2")]
|
||||
fn test_mismatched_length() {
|
||||
graph_from_logits(&[1.0, 2.0], 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_graph() {
|
||||
let g = graph_from_logits(&[-1.0; 9], 3);
|
||||
assert!(g.edges.is_empty());
|
||||
}
|
||||
}
|
||||
99
vendor/ruvector/crates/ruvector-attn-mincut/src/hysteresis.rs
vendored
Normal file
99
vendor/ruvector/crates/ruvector-attn-mincut/src/hysteresis.rs
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
/// Temporal hysteresis tracker for stable gating decisions.
|
||||
/// An edge only flips after the new decision is consistent for `tau` consecutive steps.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HysteresisTracker {
|
||||
prev_mask: Option<Vec<bool>>,
|
||||
counts: Vec<usize>,
|
||||
tau: usize,
|
||||
step: usize,
|
||||
}
|
||||
|
||||
impl HysteresisTracker {
|
||||
pub fn new(tau: usize) -> Self {
|
||||
Self {
|
||||
prev_mask: None,
|
||||
counts: Vec::new(),
|
||||
tau,
|
||||
step: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply hysteresis to a raw gating mask, returning the stabilised mask.
|
||||
pub fn apply(&mut self, raw: &[bool]) -> Vec<bool> {
|
||||
self.step += 1;
|
||||
let stable = match &self.prev_mask {
|
||||
None => {
|
||||
self.counts = vec![0; raw.len()];
|
||||
self.prev_mask = Some(raw.to_vec());
|
||||
return raw.to_vec();
|
||||
}
|
||||
Some(p) => p.clone(),
|
||||
};
|
||||
if self.counts.len() != raw.len() {
|
||||
self.counts = vec![0; raw.len()];
|
||||
self.prev_mask = Some(raw.to_vec());
|
||||
return raw.to_vec();
|
||||
}
|
||||
let mut result = stable.clone();
|
||||
for i in 0..raw.len() {
|
||||
if raw[i] != stable[i] {
|
||||
self.counts[i] += 1;
|
||||
if self.counts[i] >= self.tau {
|
||||
result[i] = raw[i];
|
||||
self.counts[i] = 0;
|
||||
}
|
||||
} else {
|
||||
self.counts[i] = 0;
|
||||
}
|
||||
}
|
||||
self.prev_mask = Some(result.clone());
|
||||
result
|
||||
}
|
||||
|
||||
pub fn step(&self) -> usize {
|
||||
self.step
|
||||
}
|
||||
pub fn current_mask(&self) -> Option<&[bool]> {
|
||||
self.prev_mask.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_first_step_passthrough() {
|
||||
let mut t = HysteresisTracker::new(3);
|
||||
assert_eq!(t.apply(&[true, false, true]), vec![true, false, true]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_flip_before_tau() {
|
||||
let mut t = HysteresisTracker::new(3);
|
||||
let init = vec![true, true, false];
|
||||
t.apply(&init);
|
||||
let changed = vec![false, true, true];
|
||||
assert_eq!(t.apply(&changed), init);
|
||||
assert_eq!(t.apply(&changed), init);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flip_at_tau() {
|
||||
let mut t = HysteresisTracker::new(2);
|
||||
t.apply(&[true, false]);
|
||||
let c = vec![false, true];
|
||||
t.apply(&c);
|
||||
assert_eq!(t.apply(&c), c);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_counter_reset_on_agreement() {
|
||||
let mut t = HysteresisTracker::new(3);
|
||||
t.apply(&[true]);
|
||||
t.apply(&[false]); // count=1
|
||||
t.apply(&[true]); // reset
|
||||
t.apply(&[false]); // count=1
|
||||
assert_eq!(t.apply(&[false]), vec![true]); // count=2 < 3
|
||||
}
|
||||
}
|
||||
32
vendor/ruvector/crates/ruvector-attn-mincut/src/lib.rs
vendored
Normal file
32
vendor/ruvector/crates/ruvector-attn-mincut/src/lib.rs
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
//! # ruvector-attn-mincut
|
||||
//!
|
||||
//! Dynamic min-cut gating as an alternative to softmax attention.
|
||||
//!
|
||||
//! Instead of applying softmax uniformly over all Q*K^T logits, this crate
|
||||
//! builds a weighted directed graph from the logits and computes a minimum
|
||||
//! cut (via Dinic's max-flow algorithm) to gate irrelevant edges. Surviving
|
||||
//! edges are then normalised with row-softmax and multiplied by V.
|
||||
//!
|
||||
//! ## Key features
|
||||
//!
|
||||
//! - **Graph construction** from attention logits (`graph` module).
|
||||
//! - **Dinic's max-flow / min-cut** solver (`mincut` module).
|
||||
//! - **Gating operators**: standard softmax and min-cut gated (`gating` module).
|
||||
//! - **Temporal hysteresis** to stabilise gating over time (`hysteresis` module).
|
||||
//! - **Witness logging** with SHA-256 hashing for determinism verification (`witness` module).
|
||||
//! - **Configuration** with sane defaults (`config` module).
|
||||
|
||||
pub mod config;
|
||||
pub mod gating;
|
||||
pub mod graph;
|
||||
pub mod hysteresis;
|
||||
pub mod mincut;
|
||||
pub mod witness;
|
||||
|
||||
// Re-export primary types for ergonomic usage.
|
||||
pub use config::MinCutConfig;
|
||||
pub use gating::{attn_mincut, attn_softmax, AttentionOutput};
|
||||
pub use graph::{graph_from_logits, AttentionGraph, Edge};
|
||||
pub use hysteresis::HysteresisTracker;
|
||||
pub use mincut::{dynamic_min_cut, CutResult, DinicSolver, GatingResult};
|
||||
pub use witness::{hash_tensor, witness_log, WitnessEntry};
|
||||
257
vendor/ruvector/crates/ruvector-attn-mincut/src/mincut.rs
vendored
Normal file
257
vendor/ruvector/crates/ruvector-attn-mincut/src/mincut.rs
vendored
Normal file
@@ -0,0 +1,257 @@
|
||||
use crate::graph::AttentionGraph;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Result of a single s-t min-cut.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CutResult {
|
||||
pub cut_edges: Vec<(usize, usize)>,
|
||||
pub cut_cost: f32,
|
||||
pub keep_mask: Vec<bool>,
|
||||
}
|
||||
|
||||
/// Aggregated gating decision from `dynamic_min_cut`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GatingResult {
|
||||
pub keep_mask: Vec<bool>,
|
||||
pub cut_cost: f32,
|
||||
pub edges_kept: usize,
|
||||
pub edges_total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct FlowEdge {
|
||||
to: usize,
|
||||
rev: usize,
|
||||
cap: f32,
|
||||
}
|
||||
|
||||
/// Dinic's max-flow solver for s-t min-cut on an attention graph.
|
||||
pub struct DinicSolver {
|
||||
adj: Vec<Vec<FlowEdge>>,
|
||||
level: Vec<i32>,
|
||||
iter: Vec<usize>,
|
||||
}
|
||||
|
||||
impl DinicSolver {
|
||||
fn new(n: usize) -> Self {
|
||||
Self {
|
||||
adj: vec![Vec::new(); n],
|
||||
level: vec![0; n],
|
||||
iter: vec![0; n],
|
||||
}
|
||||
}
|
||||
|
||||
fn add_edge(&mut self, from: usize, to: usize, cap: f32) {
|
||||
let (rf, rt) = (self.adj[to].len(), self.adj[from].len());
|
||||
self.adj[from].push(FlowEdge { to, rev: rf, cap });
|
||||
self.adj[to].push(FlowEdge {
|
||||
to: from,
|
||||
rev: rt,
|
||||
cap: 0.0,
|
||||
});
|
||||
}
|
||||
|
||||
fn bfs(&mut self, s: usize) {
|
||||
self.level.fill(-1);
|
||||
self.level[s] = 0;
|
||||
let mut q = VecDeque::new();
|
||||
q.push_back(s);
|
||||
while let Some(v) = q.pop_front() {
|
||||
for e in &self.adj[v] {
|
||||
if e.cap > 0.0 && self.level[e.to] < 0 {
|
||||
self.level[e.to] = self.level[v] + 1;
|
||||
q.push_back(e.to);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dfs(&mut self, v: usize, t: usize, f: f32) -> f32 {
|
||||
if v == t {
|
||||
return f;
|
||||
}
|
||||
while self.iter[v] < self.adj[v].len() {
|
||||
let i = self.iter[v];
|
||||
let (to, cap) = (self.adj[v][i].to, self.adj[v][i].cap);
|
||||
if cap > 0.0 && self.level[v] < self.level[to] {
|
||||
let d = self.dfs(to, t, f.min(cap));
|
||||
if d > 0.0 {
|
||||
self.adj[v][i].cap -= d;
|
||||
let rev = self.adj[v][i].rev;
|
||||
self.adj[to][rev].cap += d;
|
||||
return d;
|
||||
}
|
||||
}
|
||||
self.iter[v] += 1;
|
||||
}
|
||||
0.0
|
||||
}
|
||||
|
||||
/// Compute s-t min-cut on the given attention graph.
|
||||
pub fn min_cut(&mut self, graph: &AttentionGraph, s: usize, t: usize) -> CutResult {
|
||||
assert!(s < graph.nodes && t < graph.nodes && s != t);
|
||||
*self = Self::new(graph.nodes);
|
||||
for edge in &graph.edges {
|
||||
self.add_edge(edge.src, edge.dst, edge.weight);
|
||||
}
|
||||
|
||||
let inf = f32::MAX / 2.0;
|
||||
loop {
|
||||
self.bfs(s);
|
||||
if self.level[t] < 0 {
|
||||
break;
|
||||
}
|
||||
self.iter.fill(0);
|
||||
while self.dfs(s, t, inf) > 0.0 {}
|
||||
}
|
||||
|
||||
// Final BFS to find S-side of the cut
|
||||
self.bfs(s);
|
||||
let mut cut_edges = Vec::new();
|
||||
let mut cut_cost = 0.0f32;
|
||||
let mut keep_mask = vec![true; graph.edges.len()];
|
||||
for (idx, e) in graph.edges.iter().enumerate() {
|
||||
if self.level[e.src] >= 0 && self.level[e.dst] < 0 {
|
||||
cut_edges.push((e.src, e.dst));
|
||||
cut_cost += e.weight;
|
||||
keep_mask[idx] = false;
|
||||
}
|
||||
}
|
||||
CutResult {
|
||||
cut_edges,
|
||||
cut_cost,
|
||||
keep_mask,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute dynamic min-cut gating over a flattened `seq_len x seq_len` logit matrix.
|
||||
pub fn dynamic_min_cut(
|
||||
logits: &[f32],
|
||||
seq_len: usize,
|
||||
lambda: f32,
|
||||
_tau: usize,
|
||||
eps: f32,
|
||||
) -> GatingResult {
|
||||
assert_eq!(logits.len(), seq_len * seq_len);
|
||||
let n = seq_len * seq_len;
|
||||
let clamped: Vec<f32> = logits
|
||||
.iter()
|
||||
.map(|&v| if v > eps { v } else { 0.0 })
|
||||
.collect();
|
||||
let graph = crate::graph::graph_from_logits(&clamped, seq_len);
|
||||
|
||||
if graph.edges.is_empty() || seq_len < 2 {
|
||||
return GatingResult {
|
||||
keep_mask: vec![false; n],
|
||||
cut_cost: 0.0,
|
||||
edges_kept: 0,
|
||||
edges_total: n,
|
||||
};
|
||||
}
|
||||
|
||||
let mean_w: f32 = graph.edges.iter().map(|e| e.weight).sum::<f32>() / graph.edges.len() as f32;
|
||||
let threshold = lambda * mean_w;
|
||||
let mut flat_keep = vec![true; n];
|
||||
let mut total_cut_cost = 0.0f32;
|
||||
|
||||
let mut solver = DinicSolver::new(seq_len);
|
||||
let result = solver.min_cut(&graph, 0, seq_len - 1);
|
||||
if result.cut_cost <= threshold {
|
||||
total_cut_cost += result.cut_cost;
|
||||
for &(s, d) in &result.cut_edges {
|
||||
flat_keep[s * seq_len + d] = false;
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
if clamped[i] <= 0.0 {
|
||||
flat_keep[i] = false;
|
||||
}
|
||||
}
|
||||
let edges_kept = flat_keep.iter().filter(|&&k| k).count();
|
||||
GatingResult {
|
||||
keep_mask: flat_keep,
|
||||
cut_cost: total_cut_cost,
|
||||
edges_kept,
|
||||
edges_total: n,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::graph::Edge;
|
||||
|
||||
#[test]
|
||||
fn test_dinic_simple_cut() {
|
||||
let graph = AttentionGraph {
|
||||
nodes: 4,
|
||||
edges: vec![
|
||||
Edge {
|
||||
src: 0,
|
||||
dst: 1,
|
||||
weight: 5.0,
|
||||
},
|
||||
Edge {
|
||||
src: 0,
|
||||
dst: 2,
|
||||
weight: 4.0,
|
||||
},
|
||||
Edge {
|
||||
src: 1,
|
||||
dst: 3,
|
||||
weight: 3.0,
|
||||
},
|
||||
Edge {
|
||||
src: 2,
|
||||
dst: 3,
|
||||
weight: 6.0,
|
||||
},
|
||||
Edge {
|
||||
src: 1,
|
||||
dst: 2,
|
||||
weight: 2.0,
|
||||
},
|
||||
],
|
||||
};
|
||||
let mut solver = DinicSolver::new(4);
|
||||
let r = solver.min_cut(&graph, 0, 3);
|
||||
assert!((r.cut_cost - 9.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dinic_two_node() {
|
||||
let graph = AttentionGraph {
|
||||
nodes: 2,
|
||||
edges: vec![Edge {
|
||||
src: 0,
|
||||
dst: 1,
|
||||
weight: 3.5,
|
||||
}],
|
||||
};
|
||||
let mut solver = DinicSolver::new(2);
|
||||
let r = solver.min_cut(&graph, 0, 1);
|
||||
assert!((r.cut_cost - 3.5).abs() < 0.01);
|
||||
assert!(!r.keep_mask[0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dynamic_basic() {
|
||||
let logits = vec![1.0, 0.5, 0.0, 0.0, 1.0, 0.5, 0.0, 0.0, 1.0];
|
||||
let r = dynamic_min_cut(&logits, 3, 0.5, 2, 0.01);
|
||||
assert_eq!(r.edges_total, 9);
|
||||
assert!(r.edges_kept > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dynamic_all_negative() {
|
||||
assert_eq!(dynamic_min_cut(&[-1.0; 4], 2, 0.5, 2, 0.01).edges_kept, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dynamic_single_token() {
|
||||
assert_eq!(dynamic_min_cut(&[1.0], 1, 0.5, 2, 0.01).edges_total, 1);
|
||||
}
|
||||
}
|
||||
64
vendor/ruvector/crates/ruvector-attn-mincut/src/witness.rs
vendored
Normal file
64
vendor/ruvector/crates/ruvector-attn-mincut/src/witness.rs
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// A single witness entry for determinism verification.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WitnessEntry {
|
||||
pub q_hash: String,
|
||||
pub k_hash: String,
|
||||
pub keep_mask: Vec<bool>,
|
||||
pub cut_cost: f32,
|
||||
pub lambda: f32,
|
||||
pub tau: usize,
|
||||
pub eps: f32,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
/// Serialize a witness entry to a single JSONL line.
|
||||
pub fn witness_log(entry: &WitnessEntry) -> String {
|
||||
serde_json::to_string(entry).unwrap_or_else(|_| "{}".to_string())
|
||||
}
|
||||
|
||||
/// SHA-256 hash of a float tensor (little-endian bytes), returned as hex.
|
||||
pub fn hash_tensor(data: &[f32]) -> String {
|
||||
let mut h = Sha256::new();
|
||||
for &v in data {
|
||||
h.update(v.to_le_bytes());
|
||||
}
|
||||
h.finalize().iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_hash_deterministic() {
|
||||
let d = vec![1.0f32, 2.0, 3.0];
|
||||
assert_eq!(hash_tensor(&d), hash_tensor(&d));
|
||||
assert_eq!(hash_tensor(&d).len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hash_differs() {
|
||||
assert_ne!(hash_tensor(&[1.0, 2.0]), hash_tensor(&[1.0, 3.0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_roundtrip() {
|
||||
let e = WitnessEntry {
|
||||
q_hash: "a".into(),
|
||||
k_hash: "b".into(),
|
||||
keep_mask: vec![true, false],
|
||||
cut_cost: 1.5,
|
||||
lambda: 0.5,
|
||||
tau: 2,
|
||||
eps: 0.01,
|
||||
timestamp: 1000,
|
||||
};
|
||||
let json = witness_log(&e);
|
||||
let r: WitnessEntry = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(r.q_hash, "a");
|
||||
assert!((r.cut_cost - 1.5).abs() < f32::EPSILON);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user