Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
142
crates/ruvector-profiler/src/config_hash.rs
Normal file
142
crates/ruvector-profiler/src/config_hash.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct BenchConfig {
|
||||
pub model_commit: String,
|
||||
pub weights_hash: String,
|
||||
pub lambda: f32,
|
||||
pub tau: usize,
|
||||
pub eps: f32,
|
||||
pub compiler_flags: String,
|
||||
}
|
||||
|
||||
/// SHA-256 hex digest of the JSON-serialised config.
|
||||
pub fn config_hash(config: &BenchConfig) -> String {
|
||||
let json = serde_json::to_string(config).expect("BenchConfig serializable");
|
||||
sha256(json.as_bytes())
|
||||
.iter()
|
||||
.map(|b| format!("{b:02x}"))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sha256(data: &[u8]) -> [u8; 32] {
|
||||
#[rustfmt::skip]
|
||||
const K: [u32; 64] = [
|
||||
0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5,0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5,
|
||||
0xd807aa98,0x12835b01,0x243185be,0x550c7dc3,0x72be5d74,0x80deb1fe,0x9bdc06a7,0xc19bf174,
|
||||
0xe49b69c1,0xefbe4786,0x0fc19dc6,0x240ca1cc,0x2de92c6f,0x4a7484aa,0x5cb0a9dc,0x76f988da,
|
||||
0x983e5152,0xa831c66d,0xb00327c8,0xbf597fc7,0xc6e00bf3,0xd5a79147,0x06ca6351,0x14292967,
|
||||
0x27b70a85,0x2e1b2138,0x4d2c6dfc,0x53380d13,0x650a7354,0x766a0abb,0x81c2c92e,0x92722c85,
|
||||
0xa2bfe8a1,0xa81a664b,0xc24b8b70,0xc76c51a3,0xd192e819,0xd6990624,0xf40e3585,0x106aa070,
|
||||
0x19a4c116,0x1e376c08,0x2748774c,0x34b0bcb5,0x391c0cb3,0x4ed8aa4a,0x5b9cca4f,0x682e6ff3,
|
||||
0x748f82ee,0x78a5636f,0x84c87814,0x8cc70208,0x90befffa,0xa4506ceb,0xbef9a3f7,0xc67178f2,
|
||||
];
|
||||
let mut h: [u32; 8] = [
|
||||
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
|
||||
0x5be0cd19,
|
||||
];
|
||||
let bit_len = (data.len() as u64) * 8;
|
||||
let mut msg = data.to_vec();
|
||||
msg.push(0x80);
|
||||
while msg.len() % 64 != 56 {
|
||||
msg.push(0);
|
||||
}
|
||||
msg.extend_from_slice(&bit_len.to_be_bytes());
|
||||
|
||||
for chunk in msg.chunks_exact(64) {
|
||||
let mut w = [0u32; 64];
|
||||
for i in 0..16 {
|
||||
w[i] = u32::from_be_bytes([
|
||||
chunk[4 * i],
|
||||
chunk[4 * i + 1],
|
||||
chunk[4 * i + 2],
|
||||
chunk[4 * i + 3],
|
||||
]);
|
||||
}
|
||||
for i in 16..64 {
|
||||
let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
|
||||
let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
|
||||
w[i] = w[i - 16]
|
||||
.wrapping_add(s0)
|
||||
.wrapping_add(w[i - 7])
|
||||
.wrapping_add(s1);
|
||||
}
|
||||
let (mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut hh) =
|
||||
(h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7]);
|
||||
for i in 0..64 {
|
||||
let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
|
||||
let ch = (e & f) ^ (!e & g);
|
||||
let t1 = hh
|
||||
.wrapping_add(s1)
|
||||
.wrapping_add(ch)
|
||||
.wrapping_add(K[i])
|
||||
.wrapping_add(w[i]);
|
||||
let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
|
||||
let maj = (a & b) ^ (a & c) ^ (b & c);
|
||||
let t2 = s0.wrapping_add(maj);
|
||||
hh = g;
|
||||
g = f;
|
||||
f = e;
|
||||
e = d.wrapping_add(t1);
|
||||
d = c;
|
||||
c = b;
|
||||
b = a;
|
||||
a = t1.wrapping_add(t2);
|
||||
}
|
||||
for (i, v) in [a, b, c, d, e, f, g, hh].iter().enumerate() {
|
||||
h[i] = h[i].wrapping_add(*v);
|
||||
}
|
||||
}
|
||||
let mut out = [0u8; 32];
|
||||
for (i, v) in h.iter().enumerate() {
|
||||
out[4 * i..4 * i + 4].copy_from_slice(&v.to_be_bytes());
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
fn hex(data: &[u8]) -> String {
|
||||
sha256(data).iter().map(|b| format!("{b:02x}")).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sha_empty() {
|
||||
assert_eq!(
|
||||
hex(b""),
|
||||
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn sha_abc() {
|
||||
assert_eq!(
|
||||
hex(b"abc"),
|
||||
"ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn deterministic() {
|
||||
let c = BenchConfig {
|
||||
model_commit: "a".into(),
|
||||
weights_hash: "b".into(),
|
||||
lambda: 0.1,
|
||||
tau: 64,
|
||||
eps: 1e-6,
|
||||
compiler_flags: "-O3".into(),
|
||||
};
|
||||
let (h1, h2) = (config_hash(&c), config_hash(&c));
|
||||
assert_eq!(h1, h2);
|
||||
assert_eq!(h1.len(), 64);
|
||||
}
|
||||
#[test]
|
||||
fn varies() {
|
||||
let mk = |s: &str| BenchConfig {
|
||||
model_commit: s.into(),
|
||||
weights_hash: "x".into(),
|
||||
lambda: 0.1,
|
||||
tau: 64,
|
||||
eps: 1e-6,
|
||||
compiler_flags: "".into(),
|
||||
};
|
||||
assert_ne!(config_hash(&mk("a")), config_hash(&mk("b")));
|
||||
}
|
||||
}
|
||||
145
crates/ruvector-profiler/src/csv_emitter.rs
Normal file
145
crates/ruvector-profiler/src/csv_emitter.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use crate::latency::LatencyRecord;
|
||||
use crate::memory::MemorySnapshot;
|
||||
use std::io::Write;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ResultRow {
|
||||
pub setting: String,
|
||||
pub coherence_delta: f64,
|
||||
pub kv_cache_reduction: f64,
|
||||
pub peak_mem_reduction: f64,
|
||||
pub energy_reduction: f64,
|
||||
pub p95_latency_us: u64,
|
||||
pub accuracy: f64,
|
||||
}
|
||||
|
||||
pub fn write_results_csv(path: &str, rows: &[ResultRow]) -> std::io::Result<()> {
|
||||
let mut f = std::fs::File::create(path)?;
|
||||
writeln!(f, "setting,coherence_delta,kv_cache_reduction,peak_mem_reduction,energy_reduction,p95_latency_us,accuracy")?;
|
||||
for r in rows {
|
||||
writeln!(
|
||||
f,
|
||||
"{},{},{},{},{},{},{}",
|
||||
esc(&r.setting),
|
||||
r.coherence_delta,
|
||||
r.kv_cache_reduction,
|
||||
r.peak_mem_reduction,
|
||||
r.energy_reduction,
|
||||
r.p95_latency_us,
|
||||
r.accuracy
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_latency_csv(path: &str, records: &[LatencyRecord]) -> std::io::Result<()> {
|
||||
let mut f = std::fs::File::create(path)?;
|
||||
writeln!(f, "sample_id,wall_time_us,kernel_time_us,seq_len")?;
|
||||
for r in records {
|
||||
writeln!(
|
||||
f,
|
||||
"{},{},{},{}",
|
||||
r.sample_id, r.wall_time_us, r.kernel_time_us, r.seq_len
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_memory_csv(path: &str, snapshots: &[MemorySnapshot]) -> std::io::Result<()> {
|
||||
let mut f = std::fs::File::create(path)?;
|
||||
writeln!(
|
||||
f,
|
||||
"timestamp_us,peak_rss_bytes,kv_cache_bytes,activation_bytes,temp_buffer_bytes"
|
||||
)?;
|
||||
for s in snapshots {
|
||||
writeln!(
|
||||
f,
|
||||
"{},{},{},{},{}",
|
||||
s.timestamp_us,
|
||||
s.peak_rss_bytes,
|
||||
s.kv_cache_bytes,
|
||||
s.activation_bytes,
|
||||
s.temp_buffer_bytes
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn esc(s: &str) -> String {
|
||||
if s.contains(',') || s.contains('"') || s.contains('\n') {
|
||||
format!("\"{}\"", s.replace('"', "\"\""))
|
||||
} else {
|
||||
s.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn esc_plain() {
|
||||
assert_eq!(esc("hello"), "hello");
|
||||
}
|
||||
#[test]
|
||||
fn esc_comma() {
|
||||
assert_eq!(esc("a,b"), "\"a,b\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_results() {
|
||||
let d = tempfile::tempdir().unwrap();
|
||||
let p = d.path().join("r.csv");
|
||||
write_results_csv(
|
||||
p.to_str().unwrap(),
|
||||
&[ResultRow {
|
||||
setting: "base".into(),
|
||||
coherence_delta: 0.01,
|
||||
kv_cache_reduction: 0.0,
|
||||
peak_mem_reduction: 0.0,
|
||||
energy_reduction: 0.0,
|
||||
p95_latency_us: 1200,
|
||||
accuracy: 0.95,
|
||||
}],
|
||||
)
|
||||
.unwrap();
|
||||
let c = std::fs::read_to_string(&p).unwrap();
|
||||
assert_eq!(c.lines().count(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_latency() {
|
||||
let d = tempfile::tempdir().unwrap();
|
||||
let p = d.path().join("l.csv");
|
||||
write_latency_csv(
|
||||
p.to_str().unwrap(),
|
||||
&[LatencyRecord {
|
||||
sample_id: 0,
|
||||
wall_time_us: 100,
|
||||
kernel_time_us: 80,
|
||||
seq_len: 64,
|
||||
}],
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(std::fs::read_to_string(&p).unwrap().lines().count(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_memory() {
|
||||
let d = tempfile::tempdir().unwrap();
|
||||
let p = d.path().join("m.csv");
|
||||
write_memory_csv(
|
||||
p.to_str().unwrap(),
|
||||
&[MemorySnapshot {
|
||||
peak_rss_bytes: 1024,
|
||||
kv_cache_bytes: 256,
|
||||
activation_bytes: 512,
|
||||
temp_buffer_bytes: 128,
|
||||
timestamp_us: 999,
|
||||
}],
|
||||
)
|
||||
.unwrap();
|
||||
let c = std::fs::read_to_string(&p).unwrap();
|
||||
assert!(c.contains("999,1024,256,512,128"));
|
||||
}
|
||||
}
|
||||
94
crates/ruvector-profiler/src/latency.rs
Normal file
94
crates/ruvector-profiler/src/latency.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct LatencyRecord {
|
||||
pub sample_id: usize,
|
||||
pub wall_time_us: u64,
|
||||
pub kernel_time_us: u64,
|
||||
pub seq_len: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct LatencyStats {
|
||||
pub p50_us: u64,
|
||||
pub p95_us: u64,
|
||||
pub p99_us: u64,
|
||||
pub mean_us: f64,
|
||||
pub std_us: f64,
|
||||
pub n: usize,
|
||||
}
|
||||
|
||||
/// Compute percentile and summary statistics from wall-time latencies.
|
||||
pub fn compute_latency_stats(records: &[LatencyRecord]) -> LatencyStats {
|
||||
let n = records.len();
|
||||
if n == 0 {
|
||||
return LatencyStats {
|
||||
p50_us: 0,
|
||||
p95_us: 0,
|
||||
p99_us: 0,
|
||||
mean_us: 0.0,
|
||||
std_us: 0.0,
|
||||
n: 0,
|
||||
};
|
||||
}
|
||||
let mut times: Vec<u64> = records.iter().map(|r| r.wall_time_us).collect();
|
||||
times.sort_unstable();
|
||||
let mean = times.iter().sum::<u64>() as f64 / n as f64;
|
||||
let var = times
|
||||
.iter()
|
||||
.map(|&t| (t as f64 - mean).powi(2))
|
||||
.sum::<f64>()
|
||||
/ n as f64;
|
||||
LatencyStats {
|
||||
p50_us: pctl(×, 50.0),
|
||||
p95_us: pctl(×, 95.0),
|
||||
p99_us: pctl(×, 99.0),
|
||||
mean_us: mean,
|
||||
std_us: var.sqrt(),
|
||||
n,
|
||||
}
|
||||
}
|
||||
|
||||
fn pctl(sorted: &[u64], p: f64) -> u64 {
|
||||
let idx = ((p / 100.0 * sorted.len() as f64).ceil() as usize)
|
||||
.min(sorted.len())
|
||||
.saturating_sub(1);
|
||||
sorted[idx]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
fn recs(ts: &[u64]) -> Vec<LatencyRecord> {
|
||||
ts.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &t)| LatencyRecord {
|
||||
sample_id: i,
|
||||
wall_time_us: t,
|
||||
kernel_time_us: t,
|
||||
seq_len: 128,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty() {
|
||||
assert_eq!(compute_latency_stats(&[]).n, 0);
|
||||
}
|
||||
#[test]
|
||||
fn single() {
|
||||
let s = compute_latency_stats(&recs(&[42]));
|
||||
assert_eq!((s.p50_us, s.p99_us, s.n), (42, 42, 1));
|
||||
}
|
||||
#[test]
|
||||
fn multi() {
|
||||
let s = compute_latency_stats(&recs(&[10, 20, 30, 40, 50, 60, 70, 80, 90, 100]));
|
||||
assert_eq!(s.p50_us, 50);
|
||||
assert!((s.mean_us - 55.0).abs() < 1e-9);
|
||||
}
|
||||
#[test]
|
||||
fn unsorted() {
|
||||
assert_eq!(
|
||||
compute_latency_stats(&recs(&[100, 10, 50, 90, 20])).p50_us,
|
||||
50
|
||||
);
|
||||
}
|
||||
}
|
||||
13
crates/ruvector-profiler/src/lib.rs
Normal file
13
crates/ruvector-profiler/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! Memory, power, and latency profiling for attention-mechanism benchmarks.
|
||||
|
||||
pub mod config_hash;
|
||||
pub mod csv_emitter;
|
||||
pub mod latency;
|
||||
pub mod memory;
|
||||
pub mod power;
|
||||
|
||||
pub use config_hash::{config_hash, BenchConfig};
|
||||
pub use csv_emitter::{write_latency_csv, write_memory_csv, write_results_csv, ResultRow};
|
||||
pub use latency::{compute_latency_stats, LatencyRecord, LatencyStats};
|
||||
pub use memory::{capture_memory, MemoryReport, MemorySnapshot, MemoryTracker};
|
||||
pub use power::{EnergyResult, MockPowerSource, PowerSample, PowerSource, PowerTracker};
|
||||
130
crates/ruvector-profiler/src/memory.rs
Normal file
130
crates/ruvector-profiler/src/memory.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct MemorySnapshot {
|
||||
pub peak_rss_bytes: u64,
|
||||
pub kv_cache_bytes: u64,
|
||||
pub activation_bytes: u64,
|
||||
pub temp_buffer_bytes: u64,
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct MemoryReport {
|
||||
pub label: String,
|
||||
pub peak_rss: u64,
|
||||
pub mean_rss: u64,
|
||||
pub kv_cache_total: u64,
|
||||
pub activation_total: u64,
|
||||
}
|
||||
|
||||
/// Capture current memory via /proc/self/status (Linux) or zero fallback.
|
||||
pub fn capture_memory() -> MemorySnapshot {
|
||||
let ts = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_micros() as u64;
|
||||
MemorySnapshot {
|
||||
peak_rss_bytes: read_vm_rss(),
|
||||
kv_cache_bytes: 0,
|
||||
activation_bytes: 0,
|
||||
temp_buffer_bytes: 0,
|
||||
timestamp_us: ts,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
fn read_vm_rss() -> u64 {
|
||||
std::fs::read_to_string("/proc/self/status")
|
||||
.ok()
|
||||
.and_then(|s| {
|
||||
s.lines()
|
||||
.find(|l| l.starts_with("VmRSS:"))
|
||||
.and_then(|l| {
|
||||
l.trim_start_matches("VmRSS:")
|
||||
.trim()
|
||||
.trim_end_matches("kB")
|
||||
.trim()
|
||||
.parse::<u64>()
|
||||
.ok()
|
||||
})
|
||||
.map(|kb| kb * 1024)
|
||||
})
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
fn read_vm_rss() -> u64 {
|
||||
0
|
||||
}
|
||||
|
||||
pub struct MemoryTracker {
|
||||
pub snapshots: Vec<MemorySnapshot>,
|
||||
pub label: String,
|
||||
}
|
||||
|
||||
impl MemoryTracker {
|
||||
pub fn new(label: &str) -> Self {
|
||||
Self {
|
||||
snapshots: Vec::new(),
|
||||
label: label.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn snapshot(&mut self) {
|
||||
self.snapshots.push(capture_memory());
|
||||
}
|
||||
|
||||
pub fn peak(&self) -> u64 {
|
||||
self.snapshots
|
||||
.iter()
|
||||
.map(|s| s.peak_rss_bytes)
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn report(&self) -> MemoryReport {
|
||||
let n = self.snapshots.len().max(1) as u64;
|
||||
MemoryReport {
|
||||
label: self.label.clone(),
|
||||
peak_rss: self.peak(),
|
||||
mean_rss: self.snapshots.iter().map(|s| s.peak_rss_bytes).sum::<u64>() / n,
|
||||
kv_cache_total: self.snapshots.iter().map(|s| s.kv_cache_bytes).sum(),
|
||||
activation_total: self.snapshots.iter().map(|s| s.activation_bytes).sum(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn capture_returns_nonzero_timestamp() {
|
||||
assert!(capture_memory().timestamp_us > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_peak_empty() {
|
||||
assert_eq!(MemoryTracker::new("x").peak(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_report_aggregates() {
|
||||
let mut t = MemoryTracker::new("test");
|
||||
let mk = |rss, kv, act| MemorySnapshot {
|
||||
peak_rss_bytes: rss,
|
||||
kv_cache_bytes: kv,
|
||||
activation_bytes: act,
|
||||
temp_buffer_bytes: 0,
|
||||
timestamp_us: 1,
|
||||
};
|
||||
t.snapshots.push(mk(100, 10, 20));
|
||||
t.snapshots.push(mk(200, 30, 40));
|
||||
let r = t.report();
|
||||
assert_eq!(
|
||||
(r.peak_rss, r.mean_rss, r.kv_cache_total, r.activation_total),
|
||||
(200, 150, 40, 60)
|
||||
);
|
||||
}
|
||||
}
|
||||
149
crates/ruvector-profiler/src/power.rs
Normal file
149
crates/ruvector-profiler/src/power.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct PowerSample {
|
||||
pub watts: f64,
|
||||
pub timestamp_us: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct EnergyResult {
|
||||
pub total_joules: f64,
|
||||
pub mean_watts: f64,
|
||||
pub peak_watts: f64,
|
||||
pub duration_s: f64,
|
||||
pub samples: usize,
|
||||
}
|
||||
|
||||
/// Trait for reading instantaneous power (NVML, RAPL, etc.).
|
||||
pub trait PowerSource {
|
||||
fn read_watts(&self) -> f64;
|
||||
}
|
||||
|
||||
/// Fixed-wattage mock for deterministic tests.
|
||||
pub struct MockPowerSource {
|
||||
pub watts: f64,
|
||||
}
|
||||
impl PowerSource for MockPowerSource {
|
||||
fn read_watts(&self) -> f64 {
|
||||
self.watts
|
||||
}
|
||||
}
|
||||
|
||||
/// Trapezoidal integration of power samples (must be sorted by timestamp).
|
||||
pub fn estimate_energy(samples: &[PowerSample]) -> EnergyResult {
|
||||
let n = samples.len();
|
||||
if n < 2 {
|
||||
return EnergyResult {
|
||||
total_joules: 0.0,
|
||||
samples: n,
|
||||
duration_s: 0.0,
|
||||
mean_watts: samples.first().map_or(0.0, |s| s.watts),
|
||||
peak_watts: samples.first().map_or(0.0, |s| s.watts),
|
||||
};
|
||||
}
|
||||
let (mut joules, mut peak, mut sum) = (0.0f64, f64::NEG_INFINITY, 0.0f64);
|
||||
for i in 0..n {
|
||||
let w = samples[i].watts;
|
||||
sum += w;
|
||||
if w > peak {
|
||||
peak = w;
|
||||
}
|
||||
if i > 0 {
|
||||
let dt = samples[i]
|
||||
.timestamp_us
|
||||
.saturating_sub(samples[i - 1].timestamp_us) as f64
|
||||
/ 1e6;
|
||||
joules += (samples[i - 1].watts + w) / 2.0 * dt;
|
||||
}
|
||||
}
|
||||
let dur = samples
|
||||
.last()
|
||||
.unwrap()
|
||||
.timestamp_us
|
||||
.saturating_sub(samples[0].timestamp_us) as f64
|
||||
/ 1e6;
|
||||
EnergyResult {
|
||||
total_joules: joules,
|
||||
mean_watts: sum / n as f64,
|
||||
peak_watts: peak,
|
||||
duration_s: dur,
|
||||
samples: n,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PowerTracker {
|
||||
pub samples: Vec<PowerSample>,
|
||||
pub label: String,
|
||||
}
|
||||
|
||||
impl PowerTracker {
|
||||
pub fn new(label: &str) -> Self {
|
||||
Self {
|
||||
samples: Vec::new(),
|
||||
label: label.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sample(&mut self, source: &dyn PowerSource) {
|
||||
let ts = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_micros() as u64;
|
||||
self.samples.push(PowerSample {
|
||||
watts: source.read_watts(),
|
||||
timestamp_us: ts,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn energy(&self) -> EnergyResult {
|
||||
estimate_energy(&self.samples)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
fn ps(w: f64, t: u64) -> PowerSample {
|
||||
PowerSample {
|
||||
watts: w,
|
||||
timestamp_us: t,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn energy_empty() {
|
||||
let r = estimate_energy(&[]);
|
||||
assert_eq!(r.samples, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn energy_single() {
|
||||
let r = estimate_energy(&[ps(42.0, 0)]);
|
||||
assert_eq!((r.total_joules, r.mean_watts), (0.0, 42.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn energy_constant_100w_1s() {
|
||||
let r = estimate_energy(&[ps(100.0, 0), ps(100.0, 1_000_000)]);
|
||||
assert!((r.total_joules - 100.0).abs() < 1e-9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn energy_ramp() {
|
||||
let r = estimate_energy(&[ps(0.0, 0), ps(200.0, 1_000_000)]);
|
||||
assert!((r.total_joules - 100.0).abs() < 1e-9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mock_source() {
|
||||
assert_eq!(MockPowerSource { watts: 75.0 }.read_watts(), 75.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_collects() {
|
||||
let src = MockPowerSource { watts: 50.0 };
|
||||
let mut t = PowerTracker::new("gpu");
|
||||
t.sample(&src);
|
||||
t.sample(&src);
|
||||
assert_eq!(t.samples.len(), 2);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user