Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
commit d803bfe2b1
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,420 @@
//! Federated aggregation: FedAvg, FedProx, Byzantine-tolerant weighted averaging.
use crate::error::FederationError;
use crate::types::AggregateWeights;
/// Aggregation strategy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Federated Averaging (McMahan et al., 2017): contributions are
    /// weighted by their trajectory counts.
    FedAvg,
    /// Federated Proximal (Li et al., 2020).
    ///
    /// `mu` is the proximal coefficient in hundredths: the aggregator
    /// divides it by 100, so `mu: 50` means a coefficient of 0.5.
    FedProx { mu: u32 },
    /// Simple weighted average by each contributor's quality weight.
    WeightedAverage,
}
impl Default for AggregationStrategy {
fn default() -> Self {
Self::FedAvg
}
}
/// A single contribution to a federated averaging round.
#[derive(Clone, Debug)]
pub struct Contribution {
    /// Contributor pseudonym.
    pub contributor: String,
    /// Weight vector (LoRA deltas). Contributions in a round are expected
    /// to share one dimension; rounds with mismatched dimensions skip
    /// Byzantine outlier detection, and indices beyond the first
    /// contribution's dimension are ignored during averaging.
    pub weights: Vec<f64>,
    /// Quality/reputation weight for this contributor. Clamped to [0, 1]
    /// when the aggregator derives loss statistics from it.
    pub quality_weight: f64,
    /// Number of training trajectories behind this contribution; used as
    /// the weighting factor under FedAvg/FedProx.
    pub trajectory_count: u64,
}
/// Federated aggregation server.
///
/// Buffers [`Contribution`]s for the current round and produces an
/// `AggregateWeights` segment once enough have been collected.
pub struct FederatedAggregator {
    /// Aggregation strategy.
    strategy: AggregationStrategy,
    /// Domain identifier stamped onto every aggregate produced.
    domain_id: String,
    /// Current round number; incremented on each successful aggregation.
    round: u64,
    /// Minimum contributions required for a round (default 2).
    min_contributions: usize,
    /// Standard deviation threshold for Byzantine outlier detection
    /// (default 2.0).
    byzantine_std_threshold: f64,
    /// Collected contributions for the current round; cleared after a
    /// successful aggregation.
    contributions: Vec<Contribution>,
}
impl FederatedAggregator {
    /// Create a new aggregator with default settings
    /// (minimum 2 contributions, 2.0 std-dev Byzantine threshold).
    pub fn new(domain_id: String, strategy: AggregationStrategy) -> Self {
        Self {
            strategy,
            domain_id,
            round: 0,
            min_contributions: 2,
            byzantine_std_threshold: 2.0,
            contributions: Vec::new(),
        }
    }

    /// Set minimum contributions required (builder-style).
    pub fn with_min_contributions(mut self, min: usize) -> Self {
        self.min_contributions = min;
        self
    }

    /// Set Byzantine outlier threshold in standard deviations (builder-style).
    pub fn with_byzantine_threshold(mut self, threshold: f64) -> Self {
        self.byzantine_std_threshold = threshold;
        self
    }

    /// Add a contribution for the current round.
    pub fn add_contribution(&mut self, contribution: Contribution) {
        self.contributions.push(contribution);
    }

    /// Number of contributions collected so far.
    pub fn contribution_count(&self) -> usize {
        self.contributions.len()
    }

    /// Current round number (incremented after each successful aggregation).
    pub fn round(&self) -> u64 {
        self.round
    }

    /// Check if we have enough contributions to aggregate.
    pub fn ready(&self) -> bool {
        self.contributions.len() >= self.min_contributions
    }

    /// L2 norm of a weight vector.
    fn l2_norm(weights: &[f64]) -> f64 {
        weights.iter().map(|w| w * w).sum::<f64>().sqrt()
    }

    /// Detect and remove Byzantine outliers by z-scoring contribution L2 norms.
    ///
    /// Returns the number of outliers removed.
    fn remove_byzantine_outliers(&mut self) -> u32 {
        // With fewer than 3 contributions the sample statistics are meaningless.
        if self.contributions.len() < 3 {
            return 0;
        }
        let dim = self.contributions[0].weights.len();
        // Skip detection for empty or ragged (mismatched-dimension) rounds.
        if dim == 0 || !self.contributions.iter().all(|c| c.weights.len() == dim) {
            return 0;
        }
        // Mean and standard deviation of the per-contribution L2 norms.
        let norms: Vec<f64> = self
            .contributions
            .iter()
            .map(|c| Self::l2_norm(&c.weights))
            .collect();
        let mean_norm = norms.iter().sum::<f64>() / norms.len() as f64;
        let variance =
            norms.iter().map(|n| (n - mean_norm).powi(2)).sum::<f64>() / norms.len() as f64;
        let std_dev = variance.sqrt();
        // All norms (nearly) identical: nothing can be an outlier.
        if std_dev < 1e-10 {
            return 0;
        }
        let original_count = self.contributions.len();
        let threshold = self.byzantine_std_threshold;
        // Reuse the precomputed norms instead of recomputing each one
        // inside the retain closure.
        let mut idx = 0;
        self.contributions.retain(|_| {
            let keep = ((norms[idx] - mean_norm) / std_dev).abs() <= threshold;
            idx += 1;
            keep
        });
        (original_count - self.contributions.len()) as u32
    }

    /// Aggregate contributions and produce an `AggregateWeights` segment.
    ///
    /// Removes Byzantine outliers first, then applies the configured
    /// strategy. On success the round counter is incremented and the
    /// contribution buffer is cleared.
    ///
    /// # Errors
    ///
    /// Returns `FederationError::InsufficientContributions` when fewer than
    /// `min_contributions` were collected, or when outlier removal leaves
    /// no contributions at all.
    pub fn aggregate(&mut self) -> Result<AggregateWeights, FederationError> {
        if self.contributions.len() < self.min_contributions {
            return Err(FederationError::InsufficientContributions {
                min: self.min_contributions,
                got: self.contributions.len(),
            });
        }
        // Byzantine outlier removal.
        let outliers_removed = self.remove_byzantine_outliers();
        if self.contributions.is_empty() {
            return Err(FederationError::InsufficientContributions {
                min: self.min_contributions,
                got: 0,
            });
        }
        let dim = self.contributions[0].weights.len();
        let (lora_deltas, confidences) = match self.strategy {
            AggregationStrategy::FedAvg => self.fedavg(dim),
            // `mu` is stored in hundredths; convert to the real coefficient.
            AggregationStrategy::FedProx { mu } => self.fedprox(dim, mu as f64 / 100.0),
            AggregationStrategy::WeightedAverage => self.weighted_avg(dim),
        };
        self.round += 1;
        let participation_count = self.contributions.len() as u32;
        // Loss statistics: inverse quality serves as a proxy for loss.
        let losses: Vec<f64> = self
            .contributions
            .iter()
            .map(|c| 1.0 - c.quality_weight.clamp(0.0, 1.0))
            .collect();
        let mean_loss = losses.iter().sum::<f64>() / losses.len() as f64;
        let loss_variance =
            losses.iter().map(|l| (l - mean_loss).powi(2)).sum::<f64>() / losses.len() as f64;
        self.contributions.clear();
        Ok(AggregateWeights {
            round: self.round,
            participation_count,
            lora_deltas,
            confidences,
            mean_loss,
            loss_variance,
            domain_id: self.domain_id.clone(),
            byzantine_filtered: outliers_removed > 0,
            outliers_removed,
        })
    }

    /// Shared core of all strategies: normalized weighted average of the
    /// contribution vectors, plus per-dimension confidences.
    ///
    /// `weight_of` maps each contribution to its (unnormalized) weight;
    /// weights are normalized by their sum. Returns all-zero vectors when
    /// the total weight is non-positive.
    fn weighted_mean_with_confidences<F>(&self, dim: usize, weight_of: F) -> (Vec<f64>, Vec<f64>)
    where
        F: Fn(&Contribution) -> f64,
    {
        let total: f64 = self.contributions.iter().map(|c| weight_of(c)).sum();
        let mut avg = vec![0.0f64; dim];
        let mut confidences = vec![0.0f64; dim];
        if total <= 0.0 {
            return (avg, confidences);
        }
        for c in &self.contributions {
            let w = weight_of(c) / total;
            for (i, val) in c.weights.iter().enumerate() {
                // Ignore dimensions beyond the first contribution's length.
                if i < dim {
                    avg[i] += w * val;
                }
            }
        }
        // Confidence = 1 / (1 + variance) across contributions per dimension.
        for i in 0..dim {
            let mean = avg[i];
            let var: f64 = self
                .contributions
                .iter()
                .map(|c| {
                    // Missing dimensions are treated as zero.
                    let v = c.weights.get(i).copied().unwrap_or(0.0);
                    (v - mean).powi(2)
                })
                .sum::<f64>()
                / self.contributions.len() as f64;
            confidences[i] = 1.0 / (1.0 + var);
        }
        (avg, confidences)
    }

    /// FedAvg: weighted average by trajectory count.
    fn fedavg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
        self.weighted_mean_with_confidences(dim, |c| c.trajectory_count as f64)
    }

    /// FedProx: FedAvg followed by a proximal shrink toward the global
    /// model (zero).
    fn fedprox(&self, dim: usize, mu: f64) -> (Vec<f64>, Vec<f64>) {
        let (mut avg, confidences) = self.fedavg(dim);
        // Proximal regularization: pull each coordinate toward zero.
        for val in &mut avg {
            *val *= 1.0 / (1.0 + mu);
        }
        (avg, confidences)
    }

    /// Weighted average by per-contributor quality weight.
    fn weighted_avg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
        self.weighted_mean_with_confidences(dim, |c| c.quality_weight)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a test [`Contribution`].
    fn make_contribution(name: &str, weights: Vec<f64>, quality: f64, trajectories: u64) -> Contribution {
        Contribution {
            contributor: name.to_string(),
            weights,
            quality_weight: quality,
            trajectory_count: trajectories,
        }
    }

    #[test]
    fn fedavg_two_equal_contributions() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);
        agg.add_contribution(make_contribution("a", vec![1.0, 2.0, 3.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![3.0, 4.0, 5.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        assert_eq!(result.round, 1);
        assert_eq!(result.participation_count, 2);
        // Equal trajectory counts: plain element-wise mean.
        assert!((result.lora_deltas[0] - 2.0).abs() < 1e-10);
        assert!((result.lora_deltas[1] - 3.0).abs() < 1e-10);
        assert!((result.lora_deltas[2] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn fedavg_weighted_by_trajectories() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);
        // A has 3x more trajectories, so A's values should dominate
        agg.add_contribution(make_contribution("a", vec![10.0], 1.0, 300));
        agg.add_contribution(make_contribution("b", vec![0.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        // (300*10 + 100*0) / 400 = 7.5
        assert!((result.lora_deltas[0] - 7.5).abs() < 1e-10);
    }

    #[test]
    fn fedprox_shrinks_toward_zero() {
        let mut agg_avg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);
        agg_avg.add_contribution(make_contribution("a", vec![10.0], 1.0, 100));
        agg_avg.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));
        let avg_result = agg_avg.aggregate().unwrap();
        // mu: 50 corresponds to a proximal coefficient of 0.5.
        let mut agg_prox = FederatedAggregator::new("test".into(), AggregationStrategy::FedProx { mu: 50 })
            .with_min_contributions(2);
        agg_prox.add_contribution(make_contribution("a", vec![10.0], 1.0, 100));
        agg_prox.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));
        let prox_result = agg_prox.aggregate().unwrap();
        // FedProx should produce smaller values due to proximal regularization
        assert!(prox_result.lora_deltas[0] < avg_result.lora_deltas[0]);
    }

    #[test]
    fn byzantine_outlier_removal() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2)
            .with_byzantine_threshold(2.0);
        // Need enough good contributions so the outlier's z-score exceeds 2.0.
        // With k good + 1 evil, the evil z-score grows with sqrt(k).
        agg.add_contribution(make_contribution("good1", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good2", vec![1.1, 0.9], 1.0, 100));
        agg.add_contribution(make_contribution("good3", vec![0.9, 1.1], 1.0, 100));
        agg.add_contribution(make_contribution("good4", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good5", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good6", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("evil", vec![100.0, 100.0], 1.0, 100)); // outlier
        let result = agg.aggregate().unwrap();
        assert!(result.byzantine_filtered);
        assert!(result.outliers_removed >= 1);
        // Result should be close to 1.0, not pulled toward 100
        assert!(result.lora_deltas[0] < 5.0);
    }

    #[test]
    fn insufficient_contributions_error() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(3);
        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));
        let result = agg.aggregate();
        assert!(result.is_err());
    }

    #[test]
    fn weighted_average_strategy() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::WeightedAverage)
            .with_min_contributions(2);
        agg.add_contribution(make_contribution("a", vec![10.0], 0.9, 10));
        agg.add_contribution(make_contribution("b", vec![0.0], 0.1, 10));
        let result = agg.aggregate().unwrap();
        // (0.9*10 + 0.1*0) / 1.0 = 9.0
        assert!((result.lora_deltas[0] - 9.0).abs() < 1e-10);
    }

    #[test]
    fn round_increments() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);
        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![2.0], 1.0, 100));
        let r1 = agg.aggregate().unwrap();
        assert_eq!(r1.round, 1);
        agg.add_contribution(make_contribution("a", vec![3.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![4.0], 1.0, 100));
        let r2 = agg.aggregate().unwrap();
        assert_eq!(r2.round, 2);
    }

    #[test]
    fn confidences_high_when_agreement() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);
        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![1.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        // When all agree, variance = 0, confidence = 1/(1+0) = 1.0
        assert!((result.confidences[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn confidences_lower_when_disagreement() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);
        agg.add_contribution(make_contribution("a", vec![0.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        // When disagreement, confidence < 1.0
        assert!(result.confidences[0] < 1.0);
    }
}

View File

@@ -0,0 +1,416 @@
//! Differential privacy primitives for federated learning.
//!
//! Provides calibrated noise injection, gradient clipping, and a Renyi
//! Differential Privacy (RDP) accountant for tracking cumulative privacy loss.
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal};
use crate::error::FederationError;
use crate::types::{DiffPrivacyProof, NoiseMechanism};
/// Differential privacy engine for adding calibrated noise.
pub struct DiffPrivacyEngine {
    /// Target epsilon (privacy loss bound).
    epsilon: f64,
    /// Target delta (probability of exceeding epsilon); fixed at 0 for the
    /// Laplace (pure epsilon-DP) constructor.
    delta: f64,
    /// L2 sensitivity bound.
    sensitivity: f64,
    /// Gradient clipping norm (L2), applied before noise is added.
    clipping_norm: f64,
    /// Noise mechanism.
    mechanism: NoiseMechanism,
    /// Random number generator; replaceable with a seeded one via
    /// `with_seed` for deterministic tests.
    rng: StdRng,
}
impl DiffPrivacyEngine {
/// Create a new DP engine with Gaussian mechanism.
///
/// Default: epsilon=1.0, delta=1e-5 (strong privacy).
pub fn gaussian(
epsilon: f64,
delta: f64,
sensitivity: f64,
clipping_norm: f64,
) -> Result<Self, FederationError> {
if epsilon <= 0.0 {
return Err(FederationError::InvalidEpsilon(epsilon));
}
if delta <= 0.0 || delta >= 1.0 {
return Err(FederationError::InvalidDelta(delta));
}
Ok(Self {
epsilon,
delta,
sensitivity,
clipping_norm,
mechanism: NoiseMechanism::Gaussian,
rng: StdRng::from_rng(rand::thread_rng()).unwrap(),
})
}
/// Create a new DP engine with Laplace mechanism.
pub fn laplace(
epsilon: f64,
sensitivity: f64,
clipping_norm: f64,
) -> Result<Self, FederationError> {
if epsilon <= 0.0 {
return Err(FederationError::InvalidEpsilon(epsilon));
}
Ok(Self {
epsilon,
delta: 0.0,
sensitivity,
clipping_norm,
mechanism: NoiseMechanism::Laplace,
rng: StdRng::from_rng(rand::thread_rng()).unwrap(),
})
}
/// Create with a deterministic seed (for testing).
pub fn with_seed(mut self, seed: u64) -> Self {
self.rng = StdRng::seed_from_u64(seed);
self
}
/// Compute the Gaussian noise standard deviation (sigma).
fn gaussian_sigma(&self) -> f64 {
self.sensitivity * (2.0_f64 * (1.25_f64 / self.delta).ln()).sqrt() / self.epsilon
}
/// Compute the Laplace noise scale (b).
fn laplace_scale(&self) -> f64 {
self.sensitivity / self.epsilon
}
/// Clip a gradient vector to the configured L2 norm bound.
pub fn clip_gradients(&self, gradients: &mut [f64]) {
let norm: f64 = gradients.iter().map(|x| x * x).sum::<f64>().sqrt();
if norm > self.clipping_norm {
let scale = self.clipping_norm / norm;
for g in gradients.iter_mut() {
*g *= scale;
}
}
}
/// Add calibrated noise to a vector of parameters.
///
/// Clips gradients first, then adds noise per the configured mechanism.
pub fn add_noise(&mut self, params: &mut [f64]) -> DiffPrivacyProof {
self.clip_gradients(params);
match self.mechanism {
NoiseMechanism::Gaussian => {
let sigma = self.gaussian_sigma();
let normal = Normal::new(0.0, sigma).unwrap();
for p in params.iter_mut() {
*p += normal.sample(&mut self.rng);
}
DiffPrivacyProof {
epsilon: self.epsilon,
delta: self.delta,
mechanism: NoiseMechanism::Gaussian,
sensitivity: self.sensitivity,
clipping_norm: self.clipping_norm,
noise_scale: sigma,
noised_parameter_count: params.len() as u64,
}
}
NoiseMechanism::Laplace => {
let b = self.laplace_scale();
for p in params.iter_mut() {
// Laplace noise via inverse CDF: b * sign(u-0.5) * ln(1 - 2|u-0.5|)
let u: f64 = self.rng.gen::<f64>() - 0.5;
let noise = -b * u.signum() * (1.0 - 2.0 * u.abs()).ln();
*p += noise;
}
DiffPrivacyProof {
epsilon: self.epsilon,
delta: 0.0,
mechanism: NoiseMechanism::Laplace,
sensitivity: self.sensitivity,
clipping_norm: self.clipping_norm,
noise_scale: b,
noised_parameter_count: params.len() as u64,
}
}
}
}
/// Add noise to a single scalar value.
pub fn add_noise_scalar(&mut self, value: &mut f64) -> f64 {
let mut v = [*value];
self.add_noise(&mut v);
*value = v[0];
v[0]
}
/// Current epsilon setting.
pub fn epsilon(&self) -> f64 {
self.epsilon
}
/// Current delta setting.
pub fn delta(&self) -> f64 {
self.delta
}
}
// -- Privacy Accountant (RDP) ------------------------------------------------
/// Renyi Differential Privacy (RDP) accountant for tracking cumulative privacy loss.
///
/// Tracks privacy budget across multiple export rounds using RDP composition,
/// which provides tighter bounds than naive (epsilon, delta)-DP composition.
/// The reported epsilon is the minimum over all tracked alpha orders.
pub struct PrivacyAccountant {
    /// Maximum allowed cumulative epsilon.
    epsilon_limit: f64,
    /// Target delta for conversion from RDP to (epsilon, delta)-DP.
    target_delta: f64,
    /// Accumulated RDP values at various alpha orders.
    /// Each entry: (alpha_order, accumulated_rdp_epsilon)
    rdp_alphas: Vec<(f64, f64)>,
    /// History of exports, one [`ExportRecord`] per recorded query.
    history: Vec<ExportRecord>,
}
/// Record of a single privacy-consuming export.
#[derive(Clone, Debug)]
pub struct ExportRecord {
    /// UNIX timestamp of the export.
    pub timestamp_s: u64,
    /// Epsilon consumed by this export.
    pub epsilon: f64,
    /// Delta for this export (0 for pure epsilon-DP, i.e. Laplace).
    pub delta: f64,
    /// Mechanism used.
    pub mechanism: NoiseMechanism,
    /// Number of parameters noised in this export.
    pub parameter_count: u64,
}
impl PrivacyAccountant {
    /// Create a new accountant with the given budget.
    pub fn new(epsilon_limit: f64, target_delta: f64) -> Self {
        // Standard RDP alpha orders used for accounting.
        let alphas: Vec<f64> = vec![
            1.5, 1.75, 2.0, 2.5, 3.0, 4.0, 5.0, 6.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0,
            1024.0,
        ];
        let rdp_alphas = alphas.into_iter().map(|a| (a, 0.0)).collect();
        Self {
            epsilon_limit,
            target_delta,
            rdp_alphas,
            history: Vec::new(),
        }
    }

    /// Current UNIX time in seconds (0 if the clock is before the epoch).
    fn now_unix_s() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// RDP epsilon of the Gaussian mechanism at order `alpha`:
    /// alpha / (2 sigma^2) for unit sensitivity.
    fn gaussian_rdp(alpha: f64, sigma: f64) -> f64 {
        alpha / (2.0 * sigma * sigma)
    }

    /// Convert an RDP guarantee at order `alpha` to (epsilon, delta)-DP:
    /// eps = rdp_eps + ln(1/delta) / (alpha - 1).
    fn rdp_to_dp(alpha: f64, rdp_epsilon: f64, delta: f64) -> f64 {
        rdp_epsilon - (delta.ln()) / (alpha - 1.0)
    }

    /// Record a Gaussian mechanism query.
    pub fn record_gaussian(&mut self, sigma: f64, epsilon: f64, delta: f64, parameter_count: u64) {
        // Accumulate RDP at each tracked alpha order.
        for (alpha, rdp_eps) in &mut self.rdp_alphas {
            *rdp_eps += Self::gaussian_rdp(*alpha, sigma);
        }
        self.history.push(ExportRecord {
            // Stamp the record with the actual wall-clock export time.
            timestamp_s: Self::now_unix_s(),
            epsilon,
            delta,
            mechanism: NoiseMechanism::Gaussian,
            parameter_count,
        });
    }

    /// Record a Laplace mechanism query.
    pub fn record_laplace(&mut self, epsilon: f64, parameter_count: u64) {
        // For Laplace, RDP epsilon at order alpha is: alpha * eps / (alpha - 1)
        // when alpha > 1 (a conservative bound for a pure eps-DP mechanism).
        for (alpha, rdp_eps) in &mut self.rdp_alphas {
            if *alpha > 1.0 {
                *rdp_eps += *alpha * epsilon / (*alpha - 1.0);
            }
        }
        self.history.push(ExportRecord {
            // Stamp the record with the actual wall-clock export time.
            timestamp_s: Self::now_unix_s(),
            epsilon,
            delta: 0.0,
            mechanism: NoiseMechanism::Laplace,
            parameter_count,
        });
    }

    /// Get the current best (tightest) epsilon estimate: the minimum
    /// conversion over all tracked alpha orders.
    pub fn current_epsilon(&self) -> f64 {
        self.rdp_alphas
            .iter()
            .map(|(alpha, rdp_eps)| Self::rdp_to_dp(*alpha, *rdp_eps, self.target_delta))
            .fold(f64::INFINITY, f64::min)
    }

    /// Remaining privacy budget (never negative).
    pub fn remaining_budget(&self) -> f64 {
        (self.epsilon_limit - self.current_epsilon()).max(0.0)
    }

    /// Check if we can afford another export with the given epsilon.
    pub fn can_afford(&self, additional_epsilon: f64) -> bool {
        self.current_epsilon() + additional_epsilon <= self.epsilon_limit
    }

    /// Check if budget is exhausted.
    pub fn is_exhausted(&self) -> bool {
        self.current_epsilon() >= self.epsilon_limit
    }

    /// Fraction of budget consumed (0.0 to 1.0+).
    pub fn budget_fraction_used(&self) -> f64 {
        if self.epsilon_limit <= 0.0 {
            // A non-positive limit means the budget is always fully used;
            // avoid the naive division producing NaN (0/0).
            return f64::INFINITY;
        }
        self.current_epsilon() / self.epsilon_limit
    }

    /// Number of exports recorded.
    pub fn export_count(&self) -> usize {
        self.history.len()
    }

    /// Export history.
    pub fn history(&self) -> &[ExportRecord] {
        &self.history
    }

    /// Epsilon limit.
    pub fn epsilon_limit(&self) -> f64 {
        self.epsilon_limit
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gaussian_engine_creates() {
        let engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 1.0);
        assert!(engine.is_ok());
    }

    #[test]
    fn invalid_epsilon_rejected() {
        // epsilon must be strictly positive.
        let engine = DiffPrivacyEngine::gaussian(0.0, 1e-5, 1.0, 1.0);
        assert!(engine.is_err());
        let engine = DiffPrivacyEngine::gaussian(-1.0, 1e-5, 1.0, 1.0);
        assert!(engine.is_err());
    }

    #[test]
    fn invalid_delta_rejected() {
        // delta must lie strictly between 0 and 1.
        let engine = DiffPrivacyEngine::gaussian(1.0, 0.0, 1.0, 1.0);
        assert!(engine.is_err());
        let engine = DiffPrivacyEngine::gaussian(1.0, 1.0, 1.0, 1.0);
        assert!(engine.is_err());
    }

    #[test]
    fn gradient_clipping() {
        let engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 1.0).unwrap();
        let mut grads = vec![3.0, 4.0]; // norm = 5.0
        engine.clip_gradients(&mut grads);
        let norm: f64 = grads.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6); // clipped to norm 1.0
    }

    #[test]
    fn gradient_no_clip_when_small() {
        let engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0).unwrap();
        let mut grads = vec![3.0, 4.0]; // norm = 5.0, clip = 10.0
        engine.clip_gradients(&mut grads);
        assert!((grads[0] - 3.0).abs() < 1e-10);
        assert!((grads[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn add_noise_gaussian_deterministic() {
        let mut engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 100.0)
            .unwrap()
            .with_seed(42);
        let mut params = vec![1.0, 2.0, 3.0];
        let original = params.clone();
        let proof = engine.add_noise(&mut params);
        assert_eq!(proof.mechanism, NoiseMechanism::Gaussian);
        assert_eq!(proof.noised_parameter_count, 3);
        // Params should be different from original (noise added)
        assert!(params
            .iter()
            .zip(original.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10));
    }

    #[test]
    fn add_noise_laplace_deterministic() {
        let mut engine = DiffPrivacyEngine::laplace(1.0, 1.0, 100.0)
            .unwrap()
            .with_seed(42);
        let mut params = vec![1.0, 2.0, 3.0];
        let proof = engine.add_noise(&mut params);
        assert_eq!(proof.mechanism, NoiseMechanism::Laplace);
        assert_eq!(proof.noised_parameter_count, 3);
    }

    #[test]
    fn privacy_accountant_initial_state() {
        let acc = PrivacyAccountant::new(10.0, 1e-5);
        assert_eq!(acc.export_count(), 0);
        assert!(!acc.is_exhausted());
        assert!(acc.can_afford(1.0));
        assert!(acc.remaining_budget() > 9.9);
    }

    #[test]
    fn privacy_accountant_tracks_gaussian() {
        let mut acc = PrivacyAccountant::new(10.0, 1e-5);
        // sigma=1.0 with epsilon=1.0 per query
        acc.record_gaussian(1.0, 1.0, 1e-5, 100);
        assert_eq!(acc.export_count(), 1);
        let eps = acc.current_epsilon();
        assert!(eps > 0.0);
        assert!(eps < 10.0);
    }

    #[test]
    fn privacy_accountant_composition() {
        let mut acc = PrivacyAccountant::new(10.0, 1e-5);
        let eps_after_1 = {
            acc.record_gaussian(1.0, 1.0, 1e-5, 100);
            acc.current_epsilon()
        };
        acc.record_gaussian(1.0, 1.0, 1e-5, 100);
        let eps_after_2 = acc.current_epsilon();
        // After 2 queries, epsilon should be larger
        assert!(eps_after_2 > eps_after_1);
    }

    #[test]
    fn privacy_accountant_exhaustion() {
        let mut acc = PrivacyAccountant::new(1.0, 1e-5);
        // Use a very small sigma to burn budget fast
        for _ in 0..100 {
            acc.record_gaussian(0.1, 10.0, 1e-5, 10);
        }
        assert!(acc.is_exhausted());
        assert!(!acc.can_afford(0.1));
    }
}

View File

@@ -0,0 +1,52 @@
//! Federation error types.
use thiserror::Error;
/// Errors that can occur during federation operations.
#[derive(Debug, Error)]
pub enum FederationError {
    /// The cumulative privacy spend reached or exceeded the configured limit.
    #[error("privacy budget exhausted: spent {spent:.4}, limit {limit:.4}")]
    PrivacyBudgetExhausted { spent: f64, limit: f64 },
    /// Epsilon was zero or negative.
    #[error("invalid epsilon value: {0} (must be > 0)")]
    InvalidEpsilon(f64),
    /// Delta was outside the open interval (0, 1).
    #[error("invalid delta value: {0} (must be in (0, 1))")]
    InvalidDelta(f64),
    /// A segment failed structural validation; the message describes why.
    #[error("segment validation failed: {0}")]
    SegmentValidation(String),
    /// A format/protocol version did not match what was expected.
    #[error("version mismatch: expected {expected}, got {got}")]
    VersionMismatch { expected: u32, got: u32 },
    /// A cryptographic signature could not be verified.
    #[error("signature verification failed")]
    SignatureVerification,
    /// The witness chain was broken at the given index.
    #[error("witness chain broken at index {0}")]
    WitnessChainBroken(usize),
    /// Too few observations to meet a quality gate.
    #[error("insufficient observations: need {needed}, have {have}")]
    InsufficientObservations { needed: u64, have: u64 },
    /// A quality score fell below the acceptance threshold.
    #[error("quality below threshold: {score:.4} < {threshold:.4}")]
    QualityBelowThreshold { score: f64, threshold: f64 },
    /// Export was rate limited until the given epoch time (seconds).
    #[error("export rate limited: next export allowed at {next_allowed_epoch_s}")]
    RateLimited { next_allowed_epoch_s: u64 },
    /// PII survived the stripping pass in the named field.
    #[error("PII detected after stripping: {field}")]
    PiiLeakDetected { field: String },
    /// A contribution was flagged as a Byzantine outlier.
    #[error("Byzantine outlier detected from contributor {contributor}")]
    ByzantineOutlier { contributor: String },
    /// Not enough contributions were collected to aggregate a round.
    #[error("aggregation requires at least {min} contributions, got {got}")]
    InsufficientContributions { min: usize, got: usize },
    /// Serialization/deserialization failure, carried as a message string.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// I/O failure, carried as a message string.
    #[error("io error: {0}")]
    Io(String),
}

View File

@@ -0,0 +1,477 @@
//! Federation protocol: export builder, import merger, version-aware conflict resolution.
use crate::diff_privacy::DiffPrivacyEngine;
use crate::error::FederationError;
use crate::pii_strip::PiiStripper;
use crate::policy::FederationPolicy;
use crate::types::*;
/// Builder for constructing a federated learning export.
///
/// Follows the export flow from ADR-057:
/// 1. Extract learning (priors, kernels, cost curves, weights)
/// 2. PII-strip all payloads
/// 3. Add differential privacy noise
/// 4. Assemble manifest + attestation segments
pub struct ExportBuilder {
    /// Pseudonym identifying the contributor in the manifest.
    contributor_pseudonym: String,
    /// Domain this export was trained on.
    domain_id: String,
    /// Transfer prior sets queued for export.
    priors: Vec<TransferPriorSet>,
    /// Policy kernel snapshots queued for export.
    kernels: Vec<PolicyKernelSnapshot>,
    /// Cost curve snapshots queued for export.
    cost_curves: Vec<CostCurveSnapshot>,
    /// Raw weight vectors (LoRA deltas) queued for export.
    weights: Vec<Vec<f64>>,
    /// Policy applied during `build` (e.g. minimum-observation gate).
    policy: FederationPolicy,
    /// Named string fields queued for PII scanning.
    string_fields: Vec<(String, String)>,
}
/// A completed federated export ready for publishing.
///
/// Produced by `ExportBuilder::build`; all payloads below have already
/// been PII-stripped and DP-noised.
#[derive(Clone, Debug)]
pub struct FederatedExport {
    /// The manifest describing this export.
    pub manifest: FederatedManifest,
    /// PII redaction attestation.
    pub redaction_log: RedactionLog,
    /// Differential privacy attestation.
    pub privacy_proof: DiffPrivacyProof,
    /// Transfer priors (after PII stripping and DP noise).
    pub priors: Vec<TransferPriorSet>,
    /// Policy kernel snapshots.
    pub kernels: Vec<PolicyKernelSnapshot>,
    /// Cost curve snapshots.
    pub cost_curves: Vec<CostCurveSnapshot>,
    /// Noised aggregate weights (if any).
    pub weights: Vec<Vec<f64>>,
}
impl ExportBuilder {
/// Create a new export builder.
pub fn new(contributor_pseudonym: String, domain_id: String) -> Self {
Self {
contributor_pseudonym,
domain_id,
priors: Vec::new(),
kernels: Vec::new(),
cost_curves: Vec::new(),
weights: Vec::new(),
policy: FederationPolicy::default(),
string_fields: Vec::new(),
}
}
/// Set the federation policy.
pub fn with_policy(mut self, policy: FederationPolicy) -> Self {
self.policy = policy;
self
}
/// Add transfer priors from a trained domain.
pub fn add_priors(mut self, priors: TransferPriorSet) -> Self {
self.priors.push(priors);
self
}
/// Add a policy kernel snapshot.
pub fn add_kernel(mut self, kernel: PolicyKernelSnapshot) -> Self {
self.kernels.push(kernel);
self
}
/// Add a cost curve snapshot.
pub fn add_cost_curve(mut self, curve: CostCurveSnapshot) -> Self {
self.cost_curves.push(curve);
self
}
/// Add raw weight vectors (LoRA deltas).
pub fn add_weights(mut self, weights: Vec<f64>) -> Self {
self.weights.push(weights);
self
}
/// Add a named string field for PII scanning.
pub fn add_string_field(mut self, name: String, value: String) -> Self {
self.string_fields.push((name, value));
self
}
/// Build the export: PII-strip, add DP noise, assemble manifest.
pub fn build(mut self, dp_engine: &mut DiffPrivacyEngine) -> Result<FederatedExport, FederationError> {
// 1. Apply quality gate from policy
self.priors.retain(|ps| {
ps.entries.iter().all(|e| e.observation_count >= self.policy.min_observations)
});
// 2. PII stripping
let mut stripper = PiiStripper::new();
let field_refs: Vec<(&str, &str)> = self.string_fields
.iter()
.map(|(n, v)| (n.as_str(), v.as_str()))
.collect();
let (_redacted_fields, redaction_log) = stripper.strip_fields(&field_refs);
// Strip PII from domain IDs and bucket IDs in priors
for ps in &mut self.priors {
ps.source_domain = stripper.strip_value(&ps.source_domain);
for entry in &mut ps.entries {
entry.bucket_id = stripper.strip_value(&entry.bucket_id);
}
}
// Strip PII from cost curve domain IDs
for curve in &mut self.cost_curves {
curve.domain_id = stripper.strip_value(&curve.domain_id);
}
// 3. Add differential privacy noise to numerical parameters
// Noise the Beta posteriors
let mut noised_count: u64 = 0;
for ps in &mut self.priors {
for entry in &mut ps.entries {
let mut params = [entry.params.alpha, entry.params.beta];
dp_engine.add_noise(&mut params);
entry.params.alpha = params[0].max(0.01); // Keep positive
entry.params.beta = params[1].max(0.01);
noised_count += 2;
}
}
// Noise the weight vectors
for w in &mut self.weights {
dp_engine.add_noise(w);
noised_count += w.len() as u64;
}
// Noise kernel knobs
for kernel in &mut self.kernels {
dp_engine.add_noise(&mut kernel.knobs);
noised_count += kernel.knobs.len() as u64;
}
// Noise cost curve values
for curve in &mut self.cost_curves {
let mut costs: Vec<f64> = curve.points.iter().map(|(_, c)| *c).collect();
dp_engine.add_noise(&mut costs);
for (i, (_, c)) in curve.points.iter_mut().enumerate() {
*c = costs[i];
}
noised_count += costs.len() as u64;
}
let privacy_proof = DiffPrivacyProof {
epsilon: dp_engine.epsilon(),
delta: dp_engine.delta(),
mechanism: NoiseMechanism::Gaussian,
sensitivity: 1.0,
clipping_norm: 1.0,
noise_scale: 0.0,
noised_parameter_count: noised_count,
};
// 4. Build manifest
let total_trajectories: u64 = self.priors.iter()
.flat_map(|ps| ps.entries.iter())
.map(|e| e.observation_count)
.sum();
let avg_quality = if !self.priors.is_empty() {
self.priors.iter()
.flat_map(|ps| ps.entries.iter())
.map(|e| e.params.mean())
.sum::<f64>()
/ self.priors.iter().map(|ps| ps.entries.len()).sum::<usize>().max(1) as f64
} else {
0.0
};
let manifest = FederatedManifest {
format_version: 1,
contributor_pseudonym: self.contributor_pseudonym,
export_timestamp_s: 0,
included_segment_ids: Vec::new(),
privacy_budget_spent: dp_engine.epsilon(),
domain_id: self.domain_id,
rvf_version_tag: String::from("rvf-v1"),
trajectory_count: total_trajectories,
avg_quality_score: avg_quality,
};
Ok(FederatedExport {
manifest,
redaction_log,
privacy_proof,
priors: self.priors,
kernels: self.kernels,
cost_curves: self.cost_curves,
weights: self.weights,
})
}
}
/// Merger for importing federated learning into local engines.
///
/// Follows the import flow from ADR-057:
/// 1. Validate signature and witness chain
/// 2. Check version compatibility
/// 3. Merge with dampened confidence
pub struct ImportMerger {
    /// Current RVF version for compatibility checks.
    current_version: u32,
    /// Dampening factor applied to imports from a different version
    /// (clamped to [0, 1] by `with_version_dampen`).
    version_dampen_factor: f64,
}
impl ImportMerger {
    /// Create a new import merger with default settings
    /// (version 1, dampening factor 0.5).
    pub fn new() -> Self {
        Self {
            current_version: 1,
            version_dampen_factor: 0.5,
        }
    }

    /// Set the dampening factor for imports from different versions;
    /// the value is clamped into [0, 1].
    pub fn with_version_dampen(mut self, factor: f64) -> Self {
        self.version_dampen_factor = factor.clamp(0.0, 1.0);
        self
    }

    /// Validate a federated export before merging.
    pub fn validate(&self, export: &FederatedExport) -> Result<(), FederationError> {
        // Reject unversioned manifests.
        if export.manifest.format_version == 0 {
            return Err(FederationError::SegmentValidation(
                "format_version must be > 0".into(),
            ));
        }
        // The privacy proof must carry a meaningful epsilon.
        if export.privacy_proof.epsilon <= 0.0 {
            return Err(FederationError::InvalidEpsilon(export.privacy_proof.epsilon));
        }
        // Every Beta posterior must have strictly positive parameters.
        for entry in export.priors.iter().flat_map(|ps| ps.entries.iter()) {
            if entry.params.alpha <= 0.0 || entry.params.beta <= 0.0 {
                return Err(FederationError::SegmentValidation(format!(
                    "invalid Beta params in bucket {}: alpha={}, beta={}",
                    entry.bucket_id, entry.params.alpha, entry.params.beta
                )));
            }
        }
        Ok(())
    }

    /// Merge imported priors with local priors.
    ///
    /// Uses version-aware dampening: same version gets full weight,
    /// older versions get dampened (sqrt-scaling per MetaThompsonEngine).
    pub fn merge_priors(
        &self,
        local: &mut Vec<TransferPriorEntry>,
        remote: &[TransferPriorEntry],
        remote_version: u32,
    ) {
        // Same-version imports get full weight; all others are dampened.
        let dampen = if remote_version == self.current_version {
            1.0
        } else {
            self.version_dampen_factor
        };
        for incoming in remote {
            let dampened = incoming.params.dampen(dampen);
            let existing = local.iter_mut().find(|entry| {
                entry.bucket_id == incoming.bucket_id && entry.arm_id == incoming.arm_id
            });
            match existing {
                // Known bucket/arm: fold the dampened parameters in and
                // accumulate the observation count.
                Some(entry) => {
                    entry.params = entry.params.merge(&dampened);
                    entry.observation_count += incoming.observation_count;
                }
                // Unknown bucket/arm: insert a fresh entry with the
                // dampened parameters.
                None => local.push(TransferPriorEntry {
                    bucket_id: incoming.bucket_id.clone(),
                    arm_id: incoming.arm_id.clone(),
                    params: dampened,
                    observation_count: incoming.observation_count,
                }),
            }
        }
    }

    /// Merge imported weights into local weights using a weighted average.
    /// Silently does nothing when the total weight is non-positive or the
    /// vector lengths differ.
    pub fn merge_weights(
        &self,
        local: &mut [f64],
        remote: &[f64],
        local_weight: f64,
        remote_weight: f64,
    ) {
        let total = local_weight + remote_weight;
        if total <= 0.0 || local.len() != remote.len() {
            return;
        }
        for (dst, src) in local.iter_mut().zip(remote.iter()) {
            *dst = (local_weight * *dst + remote_weight * *src) / total;
        }
    }
}
impl Default for ImportMerger {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diff_privacy::DiffPrivacyEngine;
    /// Shared fixture: a two-bucket prior set with known parameters.
    fn make_test_priors() -> TransferPriorSet {
        TransferPriorSet {
            source_domain: "test_domain".into(),
            entries: vec![
                TransferPriorEntry {
                    bucket_id: "medium_algorithm".into(),
                    arm_id: "arm_0".into(),
                    params: BetaParams::new(10.0, 5.0),
                    observation_count: 50,
                },
                TransferPriorEntry {
                    bucket_id: "hard_synthesis".into(),
                    arm_id: "arm_1".into(),
                    params: BetaParams::new(8.0, 12.0),
                    observation_count: 30,
                },
            ],
            cost_ema: 0.85,
        }
    }
    // Exporting priors populates the manifest identity fields.
    #[test]
    fn export_builder_basic() {
        let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0)
            .unwrap()
            .with_seed(42);
        let export = ExportBuilder::new("alice_pseudo".into(), "code_review".into())
            .add_priors(make_test_priors())
            .build(&mut dp)
            .unwrap();
        assert_eq!(export.manifest.contributor_pseudonym, "alice_pseudo");
        assert_eq!(export.manifest.domain_id, "code_review");
        assert_eq!(export.manifest.format_version, 1);
        assert!(!export.priors.is_empty());
    }
    // DP noise must actually perturb the exported weight vector.
    #[test]
    fn export_builder_with_weights() {
        let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0)
            .unwrap()
            .with_seed(42);
        let weights = vec![0.1, 0.2, 0.3, 0.4];
        let export = ExportBuilder::new("bob_pseudo".into(), "genomics".into())
            .add_weights(weights.clone())
            .build(&mut dp)
            .unwrap();
        assert_eq!(export.weights.len(), 1);
        // Weights should be different from original (noise added)
        assert!(export.weights[0].iter().zip(weights.iter()).any(|(a, b)| (a - b).abs() > 1e-10));
    }
    // A freshly built export passes ImportMerger validation.
    #[test]
    fn import_merger_validate() {
        let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0)
            .unwrap()
            .with_seed(42);
        let export = ExportBuilder::new("alice".into(), "domain".into())
            .add_priors(make_test_priors())
            .build(&mut dp)
            .unwrap();
        let merger = ImportMerger::new();
        assert!(merger.validate(&export).is_ok());
    }
    // Same-version import: full-strength merge (no dampening).
    #[test]
    fn import_merger_merge_priors_same_version() {
        let merger = ImportMerger::new();
        let mut local = vec![TransferPriorEntry {
            bucket_id: "medium_algorithm".into(),
            arm_id: "arm_0".into(),
            params: BetaParams::new(5.0, 3.0),
            observation_count: 20,
        }];
        let remote = vec![TransferPriorEntry {
            bucket_id: "medium_algorithm".into(),
            arm_id: "arm_0".into(),
            params: BetaParams::new(10.0, 5.0),
            observation_count: 50,
        }];
        merger.merge_priors(&mut local, &remote, 1);
        assert_eq!(local.len(), 1);
        // Merged: alpha = 5 + 10 - 1 = 14, beta = 3 + 5 - 1 = 7
        assert!((local[0].params.alpha - 14.0).abs() < 1e-10);
        assert!((local[0].params.beta - 7.0).abs() < 1e-10);
        assert_eq!(local[0].observation_count, 70);
    }
    // Cross-version import: remote evidence halved before merging.
    #[test]
    fn import_merger_merge_priors_different_version() {
        let merger = ImportMerger::new();
        let mut local = vec![TransferPriorEntry {
            bucket_id: "b".into(),
            arm_id: "a".into(),
            params: BetaParams::new(10.0, 10.0),
            observation_count: 50,
        }];
        let remote = vec![TransferPriorEntry {
            bucket_id: "b".into(),
            arm_id: "a".into(),
            params: BetaParams::new(20.0, 5.0),
            observation_count: 40,
        }];
        merger.merge_priors(&mut local, &remote, 0); // older version -> dampened
        assert_eq!(local.len(), 1);
        // Remote dampened by 0.5: alpha = 1 + (20-1)*0.5 = 10.5, beta = 1 + (5-1)*0.5 = 3.0
        // Merged: alpha = 10 + 10.5 - 1 = 19.5, beta = 10 + 3.0 - 1 = 12.0
        assert!((local[0].params.alpha - 19.5).abs() < 1e-10);
        assert!((local[0].params.beta - 12.0).abs() < 1e-10);
    }
    // Unknown bucket/arm pairs are inserted rather than merged.
    #[test]
    fn import_merger_merge_new_bucket() {
        let merger = ImportMerger::new();
        let mut local: Vec<TransferPriorEntry> = Vec::new();
        let remote = vec![TransferPriorEntry {
            bucket_id: "new_bucket".into(),
            arm_id: "arm_0".into(),
            params: BetaParams::new(10.0, 5.0),
            observation_count: 30,
        }];
        merger.merge_priors(&mut local, &remote, 1);
        assert_eq!(local.len(), 1);
        assert_eq!(local[0].bucket_id, "new_bucket");
    }
    // Equal weights => plain midpoint average.
    #[test]
    fn merge_weights_weighted_average() {
        let merger = ImportMerger::new();
        let mut local = vec![1.0, 2.0, 3.0];
        let remote = vec![3.0, 4.0, 5.0];
        merger.merge_weights(&mut local, &remote, 0.5, 0.5);
        assert!((local[0] - 2.0).abs() < 1e-10);
        assert!((local[1] - 3.0).abs() < 1e-10);
        assert!((local[2] - 4.0).abs() < 1e-10);
    }
}

View File

@@ -0,0 +1,24 @@
//! Federated RVF transfer learning.
//!
//! This crate implements the federation protocol described in ADR-057:
//! - **PII stripping**: Three-stage pipeline (detect, redact, attest)
//! - **Differential privacy**: Gaussian/Laplace noise, RDP accountant, gradient clipping
//! - **Federation protocol**: Export builder, import merger, version-aware conflict resolution
//! - **Federated aggregation**: FedAvg, FedProx, Byzantine-tolerant weighted averaging
//! - **Segment types**: FederatedManifest, DiffPrivacyProof, RedactionLog, AggregateWeights
pub mod types;
pub mod error;
pub mod pii_strip;
pub mod diff_privacy;
pub mod federation;
pub mod aggregate;
pub mod policy;
pub use types::*;
pub use error::FederationError;
pub use pii_strip::PiiStripper;
pub use diff_privacy::{DiffPrivacyEngine, PrivacyAccountant};
pub use federation::{ExportBuilder, ImportMerger};
pub use aggregate::{FederatedAggregator, AggregationStrategy};
pub use policy::FederationPolicy;

View File

@@ -0,0 +1,354 @@
//! Three-stage PII stripping pipeline.
//!
//! **Stage 1 — Detection**: Scan string fields for PII patterns.
//! **Stage 2 — Redaction**: Replace PII with deterministic pseudonyms.
//! **Stage 3 — Attestation**: Generate a `RedactionLog` segment.
use std::collections::HashMap;
use regex::Regex;
use sha3::{Shake256, digest::{Update, ExtendableOutput, XofReader}};
use crate::types::{RedactionLog, RedactionEntry};
/// PII category with its detection regex and replacement template.
struct PiiRule {
    /// Category label recorded in the redaction log (e.g. "path", "ip").
    category: &'static str,
    /// Stable identifier of the rule that fired, for attestation entries.
    rule_id: &'static str,
    /// Compiled detection pattern.
    pattern: Regex,
    /// Pseudonym prefix used in replacements, e.g. `<PATH_1>`.
    prefix: &'static str,
}
/// Three-stage PII stripping pipeline.
pub struct PiiStripper {
    /// Built-in detection rules (paths, IPs, emails, API keys, env vars, usernames).
    rules: Vec<PiiRule>,
    /// Custom regex rules added by the user.
    custom_rules: Vec<PiiRule>,
    /// Pseudonym counter per category (for deterministic replacement).
    counters: HashMap<String, u32>,
    /// Map from original value to pseudonym (preserves structural relationships).
    pseudonym_map: HashMap<String, String>,
}
impl PiiStripper {
    /// Create a new stripper with default detection rules.
    ///
    /// Rule order matters: rules run in the order listed, and earlier
    /// replacements rewrite the text that later rules scan. In particular
    /// the email rule precedes the `@handle` username rule, so a full
    /// address is redacted as EMAIL before the USER rule can consume it.
    pub fn new() -> Self {
        let rules = vec![
            PiiRule {
                category: "path",
                rule_id: "rule_path_unix",
                pattern: Regex::new(r#"(?:/(?:home|Users|var|tmp|opt|etc)/[^\s,;:"'\]}>)]+)"#).unwrap(),
                prefix: "PATH",
            },
            PiiRule {
                category: "path",
                rule_id: "rule_path_windows",
                pattern: Regex::new(r#"(?i:[A-Z]:\\(?:Users|Documents|Program Files)[^\s,;:"'\]}>)]+)"#).unwrap(),
                prefix: "PATH",
            },
            PiiRule {
                category: "ip",
                rule_id: "rule_ipv4",
                pattern: Regex::new(r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b").unwrap(),
                prefix: "IP",
            },
            PiiRule {
                category: "ip",
                // NOTE(review): matches only the full 8-group IPv6 form;
                // compressed (`::`) addresses are presumably not detected — confirm.
                rule_id: "rule_ipv6",
                pattern: Regex::new(r"\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b").unwrap(),
                prefix: "IP",
            },
            PiiRule {
                category: "email",
                rule_id: "rule_email",
                pattern: Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b").unwrap(),
                prefix: "EMAIL",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_api_key_sk",
                pattern: Regex::new(r"\bsk-[A-Za-z0-9]{20,}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_api_key_aws",
                pattern: Regex::new(r"\bAKIA[A-Z0-9]{16}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_api_key_github",
                pattern: Regex::new(r"\bghp_[A-Za-z0-9]{36}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_bearer_token",
                pattern: Regex::new(r"\bBearer\s+[A-Za-z0-9._~+/=-]{20,}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "env_var",
                rule_id: "rule_env_unix",
                pattern: Regex::new(r"\$(?:HOME|USER|USERNAME|USERPROFILE|PATH|TMPDIR)\b").unwrap(),
                prefix: "ENV",
            },
            PiiRule {
                category: "env_var",
                rule_id: "rule_env_windows",
                pattern: Regex::new(r"%(?:HOME|USER|USERNAME|USERPROFILE|PATH|TEMP)%").unwrap(),
                prefix: "ENV",
            },
            PiiRule {
                category: "username",
                rule_id: "rule_username_at",
                pattern: Regex::new(r"@[A-Za-z][A-Za-z0-9_-]{2,30}\b").unwrap(),
                prefix: "USER",
            },
        ];
        Self {
            rules,
            custom_rules: Vec::new(),
            counters: HashMap::new(),
            pseudonym_map: HashMap::new(),
        }
    }
    /// Add a custom detection rule.
    ///
    /// Custom rules run after all built-in rules, in insertion order.
    ///
    /// # Errors
    /// Returns the `regex::Error` if `pattern` does not compile.
    pub fn add_rule(&mut self, category: &'static str, rule_id: &'static str, pattern: &str, prefix: &'static str) -> Result<(), regex::Error> {
        self.custom_rules.push(PiiRule {
            category,
            rule_id,
            pattern: Regex::new(pattern)?,
            prefix,
        });
        Ok(())
    }
    /// Reset the pseudonym map and counters (call between exports).
    pub fn reset(&mut self) {
        self.counters.clear();
        self.pseudonym_map.clear();
    }
    /// Get or create a deterministic pseudonym for a matched value.
    ///
    /// The same original value always maps to the same `<PREFIX_n>` token
    /// within one export, so structural relationships between redacted
    /// fields are preserved.
    fn pseudonym(&mut self, original: &str, prefix: &str) -> String {
        if let Some(existing) = self.pseudonym_map.get(original) {
            return existing.clone();
        }
        // Per-prefix counter: the first PATH match becomes <PATH_1>, etc.
        let counter = self.counters.entry(prefix.to_string()).or_insert(0);
        *counter += 1;
        let pseudo = format!("<{}_{}>", prefix, counter);
        self.pseudonym_map.insert(original.to_string(), pseudo.clone());
        pseudo
    }
    /// Stage 1+2: Detect and redact PII in a single string.
    /// Returns (redacted_string, list of (category, rule_id, count) tuples).
    fn strip_string(&mut self, input: &str) -> (String, Vec<(String, String, u32)>) {
        let mut result = input.to_string();
        let mut detections: Vec<(String, String, u32)> = Vec::new();
        // Index loop instead of iterating the rule vectors directly:
        // calling `self.pseudonym` below needs `&mut self`, which would
        // conflict with an outstanding borrow of `self.rules`.
        let num_builtin = self.rules.len();
        let num_custom = self.custom_rules.len();
        for i in 0..(num_builtin + num_custom) {
            let (pattern, prefix, category, rule_id) = if i < num_builtin {
                let r = &self.rules[i];
                (&r.pattern as &Regex, r.prefix, r.category, r.rule_id)
            } else {
                let r = &self.custom_rules[i - num_builtin];
                (&r.pattern as &Regex, r.prefix, r.category, r.rule_id)
            };
            // Collect matches first; the pattern borrow ends before the
            // mutable `self.pseudonym` calls below.
            let matches: Vec<String> = pattern.find_iter(&result).map(|m| m.as_str().to_string()).collect();
            if matches.is_empty() {
                continue;
            }
            let count = matches.len() as u32;
            // Build pseudonyms and perform replacements
            let mut replacements: Vec<(String, String)> = Vec::new();
            for m in &matches {
                let pseudo = self.pseudonym(m, prefix);
                replacements.push((m.clone(), pseudo));
            }
            // NOTE: `String::replace` substitutes every occurrence, so a
            // value matched twice is fully replaced by the first pass and
            // later identical replacements are no-ops; `count` still
            // reflects the number of matches found.
            for (original, pseudo) in &replacements {
                result = result.replace(original.as_str(), pseudo.as_str());
            }
            detections.push((category.to_string(), rule_id.to_string(), count));
        }
        (result, detections)
    }
    /// Strip PII from a collection of named string fields.
    ///
    /// Returns the redacted fields (same order as `fields`) and a
    /// `RedactionLog` attestation whose hash covers the ORIGINAL
    /// (pre-redaction) names and values. `timestamp_s` is left at 0
    /// for the caller to fill in.
    pub fn strip_fields(&mut self, fields: &[(&str, &str)]) -> (Vec<(String, String)>, RedactionLog) {
        // Stage 1+2: Detect and redact
        let mut redacted_fields = Vec::new();
        let mut all_detections: HashMap<(String, String), u32> = HashMap::new();
        // Compute pre-redaction hash (Stage 3 prep)
        let mut hasher = Shake256::default();
        for (name, value) in fields {
            hasher.update(name.as_bytes());
            hasher.update(value.as_bytes());
        }
        let mut pre_hash = [0u8; 32];
        hasher.finalize_xof().read(&mut pre_hash);
        for (name, value) in fields {
            let (redacted, detections) = self.strip_string(value);
            redacted_fields.push((name.to_string(), redacted));
            // Aggregate per (category, rule) counts across all fields.
            for (cat, rule, count) in detections {
                *all_detections.entry((cat, rule)).or_insert(0) += count;
            }
        }
        // Stage 3: Build attestation
        let mut log = RedactionLog {
            entries: Vec::new(),
            pre_redaction_hash: pre_hash,
            fields_scanned: fields.len() as u64,
            total_redactions: 0,
            timestamp_s: 0, // caller should set this
        };
        for ((category, rule_id), count) in &all_detections {
            log.entries.push(RedactionEntry {
                category: category.clone(),
                count: *count,
                rule_id: rule_id.clone(),
            });
            log.total_redactions += *count as u64;
        }
        (redacted_fields, log)
    }
    /// Strip PII from a single string value.
    pub fn strip_value(&mut self, input: &str) -> String {
        let (result, _) = self.strip_string(input);
        result
    }
    /// Check if a string contains any detectable PII (built-in or custom rules).
    pub fn contains_pii(&self, input: &str) -> bool {
        let all_rules: Vec<&PiiRule> = self.rules.iter().chain(self.custom_rules.iter()).collect();
        for rule in all_rules {
            if rule.pattern.is_match(input) {
                return true;
            }
        }
        false
    }
    /// Return the current pseudonym map (for debugging/auditing).
    pub fn pseudonym_map(&self) -> &HashMap<String, String> {
        &self.pseudonym_map
    }
}
impl Default for PiiStripper {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Detection-only checks exercise `contains_pii` per rule family.
    #[test]
    fn detect_unix_paths() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("/home/user/project/src/main.rs"));
        assert!(stripper.contains_pii("/Users/alice/.ssh/id_rsa"));
    }
    #[test]
    fn detect_ipv4() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("connecting to 192.168.1.100:8080"));
        assert!(stripper.contains_pii("server at 10.0.0.1"));
    }
    #[test]
    fn detect_emails() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("contact user@example.com for help"));
    }
    #[test]
    fn detect_api_keys() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("key: sk-abcdefghijklmnopqrstuv"));
        assert!(stripper.contains_pii("aws: AKIAIOSFODNN7EXAMPLE"));
        assert!(stripper.contains_pii("token: ghp_abcdefghijklmnopqrstuvwxyz0123456789"));
    }
    #[test]
    fn detect_env_vars() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("path is $HOME/.config"));
        assert!(stripper.contains_pii("dir is %USERPROFILE%\\Desktop"));
    }
    // Redaction replaces paths with <PATH_n> pseudonyms and removes originals.
    #[test]
    fn redact_preserves_structure() {
        let mut stripper = PiiStripper::new();
        let input1 = "file at /home/alice/project/a.rs";
        let input2 = "also at /home/alice/project/b.rs";
        let r1 = stripper.strip_value(input1);
        let r2 = stripper.strip_value(input2);
        // Same path prefix should get same pseudonym
        assert!(r1.contains("<PATH_"));
        assert!(r2.contains("<PATH_"));
        assert!(!r1.contains("/home/alice"));
        assert!(!r2.contains("/home/alice"));
    }
    // End-to-end: redacted fields plus a populated attestation log.
    #[test]
    fn strip_fields_produces_redaction_log() {
        let mut stripper = PiiStripper::new();
        let fields = vec![
            ("path_field", "/home/user/data.csv"),
            ("ip_field", "connecting to 10.0.0.1"),
            ("clean_field", "no pii here"),
        ];
        let (redacted, log) = stripper.strip_fields(&fields);
        assert_eq!(redacted.len(), 3);
        assert_eq!(log.fields_scanned, 3);
        assert!(log.total_redactions >= 2);
        assert!(log.pre_redaction_hash != [0u8; 32]);
        // clean field should be unchanged
        assert_eq!(redacted[2].1, "no pii here");
    }
    #[test]
    fn no_pii_returns_clean() {
        let stripper = PiiStripper::new();
        assert!(!stripper.contains_pii("just a normal string"));
        assert!(!stripper.contains_pii("alpha = 10.5, beta = 3.2"));
    }
    // `reset` must clear pseudonym state between exports.
    #[test]
    fn reset_clears_state() {
        let mut stripper = PiiStripper::new();
        stripper.strip_value("/home/user/test");
        assert!(!stripper.pseudonym_map().is_empty());
        stripper.reset();
        assert!(stripper.pseudonym_map().is_empty());
    }
    // User-added rules participate in both detection and redaction.
    #[test]
    fn custom_rule() {
        let mut stripper = PiiStripper::new();
        stripper.add_rule("ssn", "rule_ssn", r"\b\d{3}-\d{2}-\d{4}\b", "SSN").unwrap();
        assert!(stripper.contains_pii("ssn: 123-45-6789"));
        let redacted = stripper.strip_value("ssn: 123-45-6789");
        assert!(redacted.contains("<SSN_"));
        assert!(!redacted.contains("123-45-6789"));
    }
}

View File

@@ -0,0 +1,193 @@
//! Federation policy for selective sharing.
//!
//! Controls what learning is exported, quality gates, rate limits,
//! and privacy budget constraints.
use std::collections::HashSet;
/// Controls what a user shares in federated exports.
#[derive(Clone, Debug)]
pub struct FederationPolicy {
    /// Segment types allowed for export (empty = all allowed).
    pub allowed_segments: HashSet<u8>,
    /// Segment types explicitly denied for export.
    pub denied_segments: HashSet<u8>,
    /// Domain IDs allowed for export (empty = all allowed).
    pub allowed_domains: HashSet<String>,
    /// Domain IDs denied for export.
    pub denied_domains: HashSet<String>,
    /// Minimum quality score for exported trajectories (0.0 - 1.0).
    pub quality_threshold: f64,
    /// Minimum observations per prior entry for export.
    pub min_observations: u64,
    /// Maximum exports per hour.
    pub max_exports_per_hour: u32,
    /// Maximum cumulative privacy budget (epsilon).
    pub privacy_budget_limit: f64,
    /// Whether to include policy kernel snapshots.
    pub export_kernels: bool,
    /// Whether to include cost curve data.
    pub export_cost_curves: bool,
}
impl Default for FederationPolicy {
    /// Balanced defaults: every segment and domain allowed, quality gate
    /// 0.5, at least 12 observations, 100 exports/hour, epsilon budget 10,
    /// kernels and cost curves included.
    fn default() -> Self {
        FederationPolicy {
            allowed_segments: HashSet::default(),
            denied_segments: HashSet::default(),
            allowed_domains: HashSet::default(),
            denied_domains: HashSet::default(),
            quality_threshold: 0.5,
            min_observations: 12,
            max_exports_per_hour: 100,
            privacy_budget_limit: 10.0,
            export_kernels: true,
            export_cost_curves: true,
        }
    }
}
impl FederationPolicy {
    /// Conservative preset: high quality bar, low rate limit, tight
    /// privacy budget, kernels and cost curves withheld. Segment and
    /// domain filters stay at their defaults (empty allowlists, which
    /// allow everything not explicitly denied).
    pub fn restrictive() -> Self {
        FederationPolicy {
            quality_threshold: 0.8,
            min_observations: 50,
            max_exports_per_hour: 10,
            privacy_budget_limit: 5.0,
            export_kernels: false,
            export_cost_curves: false,
            ..FederationPolicy::default()
        }
    }
    /// Permissive preset: no quality gate, generous rate limit and
    /// privacy budget, everything exported.
    pub fn permissive() -> Self {
        FederationPolicy {
            quality_threshold: 0.0,
            min_observations: 1,
            max_exports_per_hour: 1000,
            privacy_budget_limit: 100.0,
            export_kernels: true,
            export_cost_curves: true,
            ..FederationPolicy::default()
        }
    }
    /// Check if a segment type is allowed for export.
    /// A deny entry always wins; an empty allowlist allows everything else.
    pub fn is_segment_allowed(&self, seg_type: u8) -> bool {
        !self.denied_segments.contains(&seg_type)
            && (self.allowed_segments.is_empty() || self.allowed_segments.contains(&seg_type))
    }
    /// Check if a domain is allowed for export; same precedence rules
    /// as [`FederationPolicy::is_segment_allowed`].
    pub fn is_domain_allowed(&self, domain_id: &str) -> bool {
        !self.denied_domains.contains(domain_id)
            && (self.allowed_domains.is_empty() || self.allowed_domains.contains(domain_id))
    }
    /// Add a segment type to the allowlist (builder style).
    pub fn allow_segment(mut self, seg_type: u8) -> Self {
        self.allowed_segments.insert(seg_type);
        self
    }
    /// Add a segment type to the denylist (builder style).
    pub fn deny_segment(mut self, seg_type: u8) -> Self {
        self.denied_segments.insert(seg_type);
        self
    }
    /// Add a domain to the allowlist (builder style).
    pub fn allow_domain(mut self, domain_id: &str) -> Self {
        self.allowed_domains.insert(domain_id.to_string());
        self
    }
    /// Add a domain to the denylist (builder style).
    pub fn deny_domain(mut self, domain_id: &str) -> Self {
        self.denied_domains.insert(domain_id.to_string());
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Defaults allow everything and use the balanced thresholds.
    #[test]
    fn default_policy() {
        let p = FederationPolicy::default();
        assert_eq!(p.quality_threshold, 0.5);
        assert_eq!(p.min_observations, 12);
        assert!(p.is_segment_allowed(0x33));
        assert!(p.is_domain_allowed("anything"));
    }
    #[test]
    fn restrictive_policy() {
        let p = FederationPolicy::restrictive();
        assert_eq!(p.quality_threshold, 0.8);
        assert_eq!(p.min_observations, 50);
        assert!(!p.export_kernels);
        assert!(!p.export_cost_curves);
    }
    #[test]
    fn permissive_policy() {
        let p = FederationPolicy::permissive();
        assert_eq!(p.quality_threshold, 0.0);
        assert_eq!(p.min_observations, 1);
    }
    // A non-empty allowlist excludes anything not listed.
    #[test]
    fn segment_allowlist() {
        let p = FederationPolicy::default().allow_segment(0x33).allow_segment(0x34);
        assert!(p.is_segment_allowed(0x33));
        assert!(p.is_segment_allowed(0x34));
        assert!(!p.is_segment_allowed(0x35)); // not in allowlist
    }
    #[test]
    fn segment_denylist() {
        let p = FederationPolicy::default().deny_segment(0x36);
        assert!(p.is_segment_allowed(0x33));
        assert!(!p.is_segment_allowed(0x36)); // denied
    }
    // Deny beats allow for the same segment type.
    #[test]
    fn deny_takes_precedence() {
        let p = FederationPolicy::default()
            .allow_segment(0x33)
            .deny_segment(0x33);
        assert!(!p.is_segment_allowed(0x33)); // deny wins
    }
    #[test]
    fn domain_filtering() {
        let p = FederationPolicy::default()
            .allow_domain("genomics")
            .deny_domain("secret_project");
        assert!(p.is_domain_allowed("genomics"));
        assert!(!p.is_domain_allowed("secret_project"));
        assert!(!p.is_domain_allowed("trading")); // not in allowlist
    }
    // Empty allowlists mean "no restriction", not "nothing allowed".
    #[test]
    fn empty_allowlist_allows_all() {
        let p = FederationPolicy::default();
        assert!(p.is_segment_allowed(0x33));
        assert!(p.is_segment_allowed(0xFF));
        assert!(p.is_domain_allowed("any_domain"));
    }
}

View File

@@ -0,0 +1,426 @@
//! Federation segment payload types.
//!
//! Four new RVF segment types (0x33-0x36) defined in ADR-057.
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
// ── Segment type constants ──────────────────────────────────────────
// NOTE(review): these discriminators come from the ADR-057 range
// 0x33-0x36; presumably they are part of the on-disk RVF format and
// must never be renumbered — confirm against the segment registry.
/// Segment type discriminator for FederatedManifest.
pub const SEG_FEDERATED_MANIFEST: u8 = 0x33;
/// Segment type discriminator for DiffPrivacyProof.
pub const SEG_DIFF_PRIVACY_PROOF: u8 = 0x34;
/// Segment type discriminator for RedactionLog.
pub const SEG_REDACTION_LOG: u8 = 0x35;
/// Segment type discriminator for AggregateWeights.
pub const SEG_AGGREGATE_WEIGHTS: u8 = 0x36;
// ── FederatedManifest (0x33) ────────────────────────────────────────
/// Describes a federated learning export.
///
/// Attached as the first segment in every federation RVF file.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FederatedManifest {
    /// Format version (currently 1).
    pub format_version: u32,
    /// Pseudonym of the contributor (never the real identity).
    pub contributor_pseudonym: String,
    /// UNIX timestamp (seconds) when the export was created.
    pub export_timestamp_s: u64,
    /// Segment IDs included in this export.
    pub included_segment_ids: Vec<u64>,
    /// Cumulative differential privacy budget spent (epsilon).
    pub privacy_budget_spent: f64,
    /// Domain identifier this export applies to.
    pub domain_id: String,
    /// RVF format version compatibility tag.
    pub rvf_version_tag: String,
    /// Number of trajectories summarized in the exported learning.
    pub trajectory_count: u64,
    /// Average quality score of exported trajectories.
    pub avg_quality_score: f64,
}
impl FederatedManifest {
    /// Create a manifest for `contributor_pseudonym` in `domain_id` with
    /// the current format/version tags and every counter zeroed; the
    /// caller fills in timestamps, segment IDs, and statistics later.
    pub fn new(contributor_pseudonym: String, domain_id: String) -> Self {
        FederatedManifest {
            contributor_pseudonym,
            domain_id,
            format_version: 1,
            rvf_version_tag: "rvf-v1".to_string(),
            export_timestamp_s: 0,
            included_segment_ids: Vec::new(),
            privacy_budget_spent: 0.0,
            trajectory_count: 0,
            avg_quality_score: 0.0,
        }
    }
}
// ── DiffPrivacyProof (0x34) ─────────────────────────────────────────
/// Noise mechanism used for differential privacy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NoiseMechanism {
    /// Gaussian noise for (epsilon, delta)-DP.
    Gaussian,
    /// Laplace noise for epsilon-DP.
    Laplace,
}
/// Differential privacy attestation.
///
/// Records the privacy parameters and noise applied during export.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DiffPrivacyProof {
    /// Privacy loss parameter.
    pub epsilon: f64,
    /// Probability of privacy failure.
    pub delta: f64,
    /// Noise mechanism applied.
    pub mechanism: NoiseMechanism,
    /// L2 sensitivity bound used for noise calibration.
    pub sensitivity: f64,
    /// Gradient clipping norm.
    pub clipping_norm: f64,
    /// Noise scale (sigma for Gaussian, b for Laplace).
    pub noise_scale: f64,
    /// Number of parameters that had noise added.
    pub noised_parameter_count: u64,
}
impl DiffPrivacyProof {
    /// Proof for the Gaussian mechanism; the scale follows the analytic
    /// calibration `sigma = sensitivity * sqrt(2 ln(1.25/delta)) / epsilon`.
    pub fn gaussian(epsilon: f64, delta: f64, sensitivity: f64, clipping_norm: f64) -> Self {
        let noise_scale = sensitivity * (2.0_f64 * (1.25_f64 / delta).ln()).sqrt() / epsilon;
        DiffPrivacyProof {
            epsilon,
            delta,
            mechanism: NoiseMechanism::Gaussian,
            sensitivity,
            clipping_norm,
            noise_scale,
            noised_parameter_count: 0,
        }
    }
    /// Proof for the Laplace mechanism (pure epsilon-DP, so `delta` is 0);
    /// the scale is `b = sensitivity / epsilon`.
    pub fn laplace(epsilon: f64, sensitivity: f64, clipping_norm: f64) -> Self {
        DiffPrivacyProof {
            epsilon,
            delta: 0.0,
            mechanism: NoiseMechanism::Laplace,
            sensitivity,
            clipping_norm,
            noise_scale: sensitivity / epsilon,
            noised_parameter_count: 0,
        }
    }
}
// ── RedactionLog (0x35) ─────────────────────────────────────────────
/// A single redaction event.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RedactionEntry {
    /// Category of PII detected (e.g. "path", "ip", "email", "api_key").
    pub category: String,
    /// Number of occurrences redacted.
    pub count: u32,
    /// Rule identifier that triggered the redaction.
    pub rule_id: String,
}
/// PII stripping attestation.
///
/// Proves that PII scanning was performed without revealing the original content.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RedactionLog {
    /// Individual redaction entries by category.
    pub entries: Vec<RedactionEntry>,
    /// SHAKE-256 hash of the pre-redaction content (32 bytes).
    pub pre_redaction_hash: [u8; 32],
    /// Total number of fields scanned.
    pub fields_scanned: u64,
    /// Total number of redactions applied.
    pub total_redactions: u64,
    /// UNIX timestamp (seconds) when redaction was performed.
    pub timestamp_s: u64,
}
impl RedactionLog {
    /// Create an empty redaction log with a zeroed hash and timestamp.
    pub fn new() -> Self {
        RedactionLog {
            entries: Vec::new(),
            pre_redaction_hash: [0u8; 32],
            fields_scanned: 0,
            total_redactions: 0,
            timestamp_s: 0,
        }
    }
    /// Record `count` redactions for `category` under `rule_id` and bump
    /// the running total.
    pub fn add_entry(&mut self, category: &str, count: u32, rule_id: &str) {
        let entry = RedactionEntry {
            category: category.to_string(),
            count,
            rule_id: rule_id.to_string(),
        };
        self.entries.push(entry);
        self.total_redactions += u64::from(count);
    }
}
impl Default for RedactionLog {
    /// An empty log, identical to [`RedactionLog::new`].
    fn default() -> Self {
        RedactionLog::new()
    }
}
// ── AggregateWeights (0x36) ─────────────────────────────────────────
/// Federated-averaged weight vector with metadata.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AggregateWeights {
    /// Federated averaging round number.
    pub round: u64,
    /// Number of participants in this round.
    pub participation_count: u32,
    /// Aggregated LoRA delta weights (flattened).
    pub lora_deltas: Vec<f64>,
    /// Per-weight confidence scores.
    pub confidences: Vec<f64>,
    /// Mean loss across participants.
    pub mean_loss: f64,
    /// Loss variance across participants.
    pub loss_variance: f64,
    /// Domain identifier.
    pub domain_id: String,
    /// Whether Byzantine outlier removal was applied.
    pub byzantine_filtered: bool,
    /// Number of contributions removed as outliers.
    pub outliers_removed: u32,
}
impl AggregateWeights {
    /// Construct an empty aggregate for `domain_id` at round `round`;
    /// all vectors, counters, and statistics start at zero.
    pub fn new(domain_id: String, round: u64) -> Self {
        AggregateWeights {
            domain_id,
            round,
            participation_count: 0,
            lora_deltas: Vec::new(),
            confidences: Vec::new(),
            mean_loss: 0.0,
            loss_variance: 0.0,
            byzantine_filtered: false,
            outliers_removed: 0,
        }
    }
}
// ── BetaParams (local copy for federation) ──────────────────────────
/// Beta distribution parameters for Thompson Sampling priors.
///
/// Mirrors the type in `ruvector-domain-expansion` to avoid cross-crate dependency.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BetaParams {
    /// Alpha (success count + 1).
    pub alpha: f64,
    /// Beta (failure count + 1).
    pub beta: f64,
}
impl BetaParams {
    /// Create new Beta parameters.
    pub fn new(alpha: f64, beta: f64) -> Self {
        BetaParams { alpha, beta }
    }
    /// Uniform (uninformative) Beta(1, 1) prior.
    pub fn uniform() -> Self {
        BetaParams::new(1.0, 1.0)
    }
    /// Mean of the Beta distribution, `alpha / (alpha + beta)`.
    pub fn mean(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }
    /// Total observations backing this posterior
    /// (`alpha + beta - 2` relative to a Beta(1,1) prior).
    pub fn observations(&self) -> f64 {
        self.alpha + self.beta - 2.0
    }
    /// Merge two Beta posteriors: parameters are summed and one uniform
    /// prior is subtracted so the shared Beta(1,1) baseline is not
    /// double-counted.
    pub fn merge(&self, other: &BetaParams) -> BetaParams {
        BetaParams::new(
            self.alpha + other.alpha - 1.0,
            self.beta + other.beta - 1.0,
        )
    }
    /// Linearly shrink the evidence toward the uniform Beta(1,1) prior.
    /// `factor` is clamped to `[0, 1]`: 1.0 keeps the posterior intact,
    /// 0.0 collapses it to uniform.
    pub fn dampen(&self, factor: f64) -> BetaParams {
        let f = factor.clamp(0.0, 1.0);
        BetaParams::new(
            1.0 + (self.alpha - 1.0) * f,
            1.0 + (self.beta - 1.0) * f,
        )
    }
}
impl Default for BetaParams {
    /// The uniform prior.
    fn default() -> Self {
        BetaParams::uniform()
    }
}
// ── TransferPrior (local copy for federation) ───────────────────────
/// Compact summary of learned priors for a single context bucket.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TransferPriorEntry {
    /// Context bucket identifier.
    pub bucket_id: String,
    /// Arm identifier.
    pub arm_id: String,
    /// Beta posterior parameters.
    pub params: BetaParams,
    /// Number of observations backing this prior.
    /// Kept separately from `params.observations()` so dampened merges
    /// can still report the raw trajectory count.
    pub observation_count: u64,
}
/// Collection of transfer priors from a trained domain.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TransferPriorSet {
    /// Source domain identifier.
    pub source_domain: String,
    /// Individual prior entries (one per bucket/arm pair).
    pub entries: Vec<TransferPriorEntry>,
    /// EMA cost at time of extraction.
    pub cost_ema: f64,
}
// ── PolicyKernelSnapshot ────────────────────────────────────────────
/// Snapshot of a policy kernel configuration for federation export.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PolicyKernelSnapshot {
    /// Kernel identifier.
    pub kernel_id: String,
    /// Tunable knob values (flattened).
    pub knobs: Vec<f64>,
    /// Fitness score.
    pub fitness: f64,
    /// Generation number.
    pub generation: u64,
}
// ── CostCurveSnapshot ───────────────────────────────────────────────
/// Snapshot of cost curve data for federation export.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CostCurveSnapshot {
    /// Domain identifier.
    pub domain_id: String,
    /// Ordered (step, cost) points.
    pub points: Vec<(u64, f64)>,
    /// Acceleration factor (> 1.0 means transfer helped).
    pub acceleration: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Discriminators are pinned so the on-disk format cannot drift silently.
    #[test]
    fn segment_type_constants() {
        assert_eq!(SEG_FEDERATED_MANIFEST, 0x33);
        assert_eq!(SEG_DIFF_PRIVACY_PROOF, 0x34);
        assert_eq!(SEG_REDACTION_LOG, 0x35);
        assert_eq!(SEG_AGGREGATE_WEIGHTS, 0x36);
    }
    #[test]
    fn federated_manifest_new() {
        let m = FederatedManifest::new("alice".into(), "genomics".into());
        assert_eq!(m.format_version, 1);
        assert_eq!(m.contributor_pseudonym, "alice");
        assert_eq!(m.domain_id, "genomics");
        assert_eq!(m.trajectory_count, 0);
    }
    #[test]
    fn diff_privacy_proof_gaussian() {
        let p = DiffPrivacyProof::gaussian(1.0, 1e-5, 1.0, 1.0);
        assert_eq!(p.mechanism, NoiseMechanism::Gaussian);
        assert!(p.noise_scale > 0.0);
        assert_eq!(p.epsilon, 1.0);
    }
    // Laplace scale b = sensitivity / epsilon = 1.0 here.
    #[test]
    fn diff_privacy_proof_laplace() {
        let p = DiffPrivacyProof::laplace(1.0, 1.0, 1.0);
        assert_eq!(p.mechanism, NoiseMechanism::Laplace);
        assert!((p.noise_scale - 1.0).abs() < 1e-10);
    }
    #[test]
    fn redaction_log_add_entry() {
        let mut log = RedactionLog::new();
        log.add_entry("path", 3, "rule_path_unix");
        log.add_entry("ip", 2, "rule_ipv4");
        assert_eq!(log.entries.len(), 2);
        assert_eq!(log.total_redactions, 5);
    }
    #[test]
    fn aggregate_weights_new() {
        let w = AggregateWeights::new("code_review".into(), 1);
        assert_eq!(w.round, 1);
        assert_eq!(w.participation_count, 0);
        assert!(!w.byzantine_filtered);
    }
    // merge subtracts one uniform prior: (10+8-1, 5+3-1).
    #[test]
    fn beta_params_merge() {
        let a = BetaParams::new(10.0, 5.0);
        let b = BetaParams::new(8.0, 3.0);
        let merged = a.merge(&b);
        assert!((merged.alpha - 17.0).abs() < 1e-10);
        assert!((merged.beta - 7.0).abs() < 1e-10);
    }
    #[test]
    fn beta_params_dampen() {
        let p = BetaParams::new(10.0, 5.0);
        let dampened = p.dampen(0.25);
        // alpha = 1 + (10-1)*0.25 = 1 + 2.25 = 3.25
        assert!((dampened.alpha - 3.25).abs() < 1e-10);
        // beta = 1 + (5-1)*0.25 = 1 + 1.0 = 2.0
        assert!((dampened.beta - 2.0).abs() < 1e-10);
    }
    #[test]
    fn beta_params_mean() {
        let p = BetaParams::new(10.0, 10.0);
        assert!((p.mean() - 0.5).abs() < 1e-10);
    }
}