Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
420
vendor/ruvector/crates/rvf/rvf-federation/src/aggregate.rs
vendored
Normal file
420
vendor/ruvector/crates/rvf/rvf-federation/src/aggregate.rs
vendored
Normal file
@@ -0,0 +1,420 @@
|
||||
//! Federated aggregation: FedAvg, FedProx, Byzantine-tolerant weighted averaging.
|
||||
|
||||
use crate::error::FederationError;
|
||||
use crate::types::AggregateWeights;
|
||||
|
||||
/// Aggregation strategy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Federated Averaging (McMahan et al., 2017).
    FedAvg,
    /// Federated Proximal (Li et al., 2020).
    ///
    /// `mu` is a fixed-point proximal coefficient: the aggregator divides it
    /// by 100 before use, so `mu: 50` corresponds to a proximal term of 0.5.
    FedProx { mu: u32 },
    /// Simple weighted average (weighted by each contributor's quality weight).
    WeightedAverage,
}
|
||||
|
||||
impl Default for AggregationStrategy {
|
||||
fn default() -> Self {
|
||||
Self::FedAvg
|
||||
}
|
||||
}
|
||||
|
||||
/// A single contribution to a federated averaging round.
#[derive(Clone, Debug)]
pub struct Contribution {
    /// Contributor pseudonym.
    pub contributor: String,
    /// Weight vector (LoRA deltas). All contributions in a round are expected
    /// to share the same dimension; mismatched dimensions disable outlier
    /// filtering and out-of-range entries are ignored during averaging.
    pub weights: Vec<f64>,
    /// Quality/reputation weight for this contributor (used directly by the
    /// `WeightedAverage` strategy, and as an inverse loss proxy in stats).
    pub quality_weight: f64,
    /// Number of training trajectories behind this contribution (the FedAvg
    /// weighting factor).
    pub trajectory_count: u64,
}
|
||||
|
||||
/// Federated aggregation server.
///
/// Collects [`Contribution`]s for the current round, optionally filters
/// Byzantine outliers by L2-norm z-score, and produces aggregated weights.
pub struct FederatedAggregator {
    /// Aggregation strategy.
    strategy: AggregationStrategy,
    /// Domain identifier (copied verbatim into each aggregate result).
    domain_id: String,
    /// Current round number (0 until the first successful aggregation).
    round: u64,
    /// Minimum contributions required for a round.
    min_contributions: usize,
    /// Standard deviation threshold for Byzantine outlier detection.
    byzantine_std_threshold: f64,
    /// Collected contributions for the current round (cleared on aggregate).
    contributions: Vec<Contribution>,
}
|
||||
|
||||
impl FederatedAggregator {
    /// Create a new aggregator.
    ///
    /// Defaults: minimum of 2 contributions per round, Byzantine threshold
    /// of 2.0 standard deviations.
    pub fn new(domain_id: String, strategy: AggregationStrategy) -> Self {
        Self {
            strategy,
            domain_id,
            round: 0,
            min_contributions: 2,
            byzantine_std_threshold: 2.0,
            contributions: Vec::new(),
        }
    }

    /// Set minimum contributions required.
    pub fn with_min_contributions(mut self, min: usize) -> Self {
        self.min_contributions = min;
        self
    }

    /// Set Byzantine outlier threshold (in standard deviations).
    pub fn with_byzantine_threshold(mut self, threshold: f64) -> Self {
        self.byzantine_std_threshold = threshold;
        self
    }

    /// Add a contribution for the current round.
    pub fn add_contribution(&mut self, contribution: Contribution) {
        self.contributions.push(contribution);
    }

    /// Number of contributions collected so far.
    pub fn contribution_count(&self) -> usize {
        self.contributions.len()
    }

    /// Current round number.
    pub fn round(&self) -> u64 {
        self.round
    }

    /// Check if we have enough contributions to aggregate.
    pub fn ready(&self) -> bool {
        self.contributions.len() >= self.min_contributions
    }

    /// Detect and remove Byzantine outliers.
    ///
    /// A contribution is dropped when the z-score of its weight vector's L2
    /// norm (relative to the mean norm across all contributions) exceeds
    /// `byzantine_std_threshold` in absolute value.
    ///
    /// Returns the number of outliers removed.
    fn remove_byzantine_outliers(&mut self) -> u32 {
        if self.contributions.len() < 3 {
            return 0; // Need at least 3 for meaningful outlier detection
        }

        // Filtering requires a consistent, non-zero dimension across all
        // contributions; otherwise skip it entirely.
        let dim = self.contributions[0].weights.len();
        if dim == 0 || !self.contributions.iter().all(|c| c.weights.len() == dim) {
            return 0;
        }

        // Compute mean and std of L2 norms
        let norms: Vec<f64> = self.contributions.iter()
            .map(|c| c.weights.iter().map(|w| w * w).sum::<f64>().sqrt())
            .collect();

        let mean_norm = norms.iter().sum::<f64>() / norms.len() as f64;
        let variance = norms.iter().map(|n| (n - mean_norm).powi(2)).sum::<f64>() / norms.len() as f64;
        let std_dev = variance.sqrt();

        // Near-zero spread: all contributions agree, and a z-score would be
        // numerically meaningless — nothing to filter.
        if std_dev < 1e-10 {
            return 0;
        }

        let original_count = self.contributions.len();
        let threshold = self.byzantine_std_threshold;

        self.contributions.retain(|c| {
            let norm = c.weights.iter().map(|w| w * w).sum::<f64>().sqrt();
            ((norm - mean_norm) / std_dev).abs() <= threshold
        });

        (original_count - self.contributions.len()) as u32
    }

    /// Aggregate contributions and produce an `AggregateWeights` segment.
    ///
    /// Flow: enforce `min_contributions`, drop Byzantine outliers, run the
    /// configured strategy, compute loss statistics (using inverse quality
    /// weight as a loss proxy), then clear the contribution buffer and
    /// advance the round counter — so the returned `round` is 1-based.
    ///
    /// NOTE(review): after outlier removal only non-emptiness is re-checked,
    /// so a round may aggregate fewer than `min_contributions` survivors —
    /// confirm this is intended.
    pub fn aggregate(&mut self) -> Result<AggregateWeights, FederationError> {
        if self.contributions.len() < self.min_contributions {
            return Err(FederationError::InsufficientContributions {
                min: self.min_contributions,
                got: self.contributions.len(),
            });
        }

        // Byzantine outlier removal
        let outliers_removed = self.remove_byzantine_outliers();

        if self.contributions.is_empty() {
            return Err(FederationError::InsufficientContributions {
                min: self.min_contributions,
                got: 0,
            });
        }

        // Dimension is taken from the first surviving contribution; shorter
        // vectors contribute zeros implicitly, longer ones are truncated.
        let dim = self.contributions[0].weights.len();

        let result = match self.strategy {
            AggregationStrategy::FedAvg => self.fedavg(dim),
            // `mu` is stored as fixed-point hundredths (50 -> 0.5).
            AggregationStrategy::FedProx { mu } => self.fedprox(dim, mu as f64 / 100.0),
            AggregationStrategy::WeightedAverage => self.weighted_avg(dim),
        };

        self.round += 1;
        let participation_count = self.contributions.len() as u32;

        // Compute loss stats
        let losses: Vec<f64> = self.contributions.iter()
            .map(|c| {
                // Use inverse quality as a proxy for loss
                1.0 - c.quality_weight.clamp(0.0, 1.0)
            })
            .collect();
        let mean_loss = losses.iter().sum::<f64>() / losses.len() as f64;
        let loss_variance = losses.iter().map(|l| (l - mean_loss).powi(2)).sum::<f64>() / losses.len() as f64;

        // The buffer is per-round: clear it so the next round starts fresh.
        self.contributions.clear();

        Ok(AggregateWeights {
            round: self.round,
            participation_count,
            lora_deltas: result.0,
            confidences: result.1,
            mean_loss,
            loss_variance,
            domain_id: self.domain_id.clone(),
            byzantine_filtered: outliers_removed > 0,
            outliers_removed,
        })
    }

    /// FedAvg: weighted average by trajectory count.
    ///
    /// Per-dimension confidence is `1 / (1 + variance)` of that dimension's
    /// values across contributions, so perfect agreement yields 1.0.
    ///
    /// NOTE(review): the variance is computed against the trajectory-weighted
    /// mean but sums contributions unweighted — confirm that mix is intended.
    fn fedavg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
        let total_trajectories: f64 = self.contributions.iter()
            .map(|c| c.trajectory_count as f64)
            .sum();

        let mut avg = vec![0.0f64; dim];
        let mut confidences = vec![0.0f64; dim];

        // All-zero trajectory counts: no defensible weighting, return zeros.
        if total_trajectories <= 0.0 {
            return (avg, confidences);
        }

        for c in &self.contributions {
            let w = c.trajectory_count as f64 / total_trajectories;
            for (i, val) in c.weights.iter().enumerate() {
                if i < dim {
                    avg[i] += w * val;
                }
            }
        }

        // Confidence = inverse of variance across contributions per dimension
        for i in 0..dim {
            let mean = avg[i];
            let var: f64 = self.contributions.iter()
                .map(|c| {
                    // Missing dimensions are treated as 0.0.
                    let v = if i < c.weights.len() { c.weights[i] } else { 0.0 };
                    (v - mean).powi(2)
                })
                .sum::<f64>() / self.contributions.len() as f64;
            confidences[i] = 1.0 / (1.0 + var);
        }

        (avg, confidences)
    }

    /// FedProx: weighted average with proximal term.
    ///
    /// Runs FedAvg, then shrinks every component by `1 / (1 + mu)`, pulling
    /// the result toward zero (the global model). Confidences pass through
    /// unchanged.
    fn fedprox(&self, dim: usize, mu: f64) -> (Vec<f64>, Vec<f64>) {
        let (mut avg, confidences) = self.fedavg(dim);
        // Apply proximal regularization: pull toward zero (global model)
        for val in &mut avg {
            *val *= 1.0 / (1.0 + mu);
        }
        (avg, confidences)
    }

    /// Weighted average by quality_weight.
    ///
    /// Same structure as `fedavg` but weights each contribution by its
    /// `quality_weight` instead of its trajectory count.
    fn weighted_avg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
        let total_weight: f64 = self.contributions.iter().map(|c| c.quality_weight).sum();

        let mut avg = vec![0.0f64; dim];
        let mut confidences = vec![0.0f64; dim];

        // Non-positive total weight: weighting is undefined, return zeros.
        if total_weight <= 0.0 {
            return (avg, confidences);
        }

        for c in &self.contributions {
            let w = c.quality_weight / total_weight;
            for (i, val) in c.weights.iter().enumerate() {
                if i < dim {
                    avg[i] += w * val;
                }
            }
        }

        for i in 0..dim {
            let mean = avg[i];
            let var: f64 = self.contributions.iter()
                .map(|c| {
                    // Missing dimensions are treated as 0.0.
                    let v = if i < c.weights.len() { c.weights[i] } else { 0.0 };
                    (v - mean).powi(2)
                })
                .sum::<f64>() / self.contributions.len() as f64;
            confidences[i] = 1.0 / (1.0 + var);
        }

        (avg, confidences)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for a test contribution.
    fn make_contribution(name: &str, weights: Vec<f64>, quality: f64, trajectories: u64) -> Contribution {
        Contribution {
            contributor: name.to_string(),
            weights,
            quality_weight: quality,
            trajectory_count: trajectories,
        }
    }

    #[test]
    fn fedavg_two_equal_contributions() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);

        agg.add_contribution(make_contribution("a", vec![1.0, 2.0, 3.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![3.0, 4.0, 5.0], 1.0, 100));

        // Equal trajectory counts -> plain element-wise mean.
        let result = agg.aggregate().unwrap();
        assert_eq!(result.round, 1);
        assert_eq!(result.participation_count, 2);
        assert!((result.lora_deltas[0] - 2.0).abs() < 1e-10);
        assert!((result.lora_deltas[1] - 3.0).abs() < 1e-10);
        assert!((result.lora_deltas[2] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn fedavg_weighted_by_trajectories() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);

        // A has 3x more trajectories, so A's values should dominate
        agg.add_contribution(make_contribution("a", vec![10.0], 1.0, 300));
        agg.add_contribution(make_contribution("b", vec![0.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        // (300*10 + 100*0) / 400 = 7.5
        assert!((result.lora_deltas[0] - 7.5).abs() < 1e-10);
    }

    #[test]
    fn fedprox_shrinks_toward_zero() {
        let mut agg_avg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);
        agg_avg.add_contribution(make_contribution("a", vec![10.0], 1.0, 100));
        agg_avg.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));
        let avg_result = agg_avg.aggregate().unwrap();

        // mu: 50 is fixed-point hundredths, i.e. an effective mu of 0.5.
        let mut agg_prox = FederatedAggregator::new("test".into(), AggregationStrategy::FedProx { mu: 50 })
            .with_min_contributions(2);
        agg_prox.add_contribution(make_contribution("a", vec![10.0], 1.0, 100));
        agg_prox.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));
        let prox_result = agg_prox.aggregate().unwrap();

        // FedProx should produce smaller values due to proximal regularization
        assert!(prox_result.lora_deltas[0] < avg_result.lora_deltas[0]);
    }

    #[test]
    fn byzantine_outlier_removal() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2)
            .with_byzantine_threshold(2.0);

        // Need enough good contributions so the outlier's z-score exceeds 2.0.
        // With k good + 1 evil, the evil z-score grows with sqrt(k).
        agg.add_contribution(make_contribution("good1", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good2", vec![1.1, 0.9], 1.0, 100));
        agg.add_contribution(make_contribution("good3", vec![0.9, 1.1], 1.0, 100));
        agg.add_contribution(make_contribution("good4", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good5", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good6", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("evil", vec![100.0, 100.0], 1.0, 100)); // outlier

        let result = agg.aggregate().unwrap();
        assert!(result.byzantine_filtered);
        assert!(result.outliers_removed >= 1);
        // Result should be close to 1.0, not pulled toward 100
        assert!(result.lora_deltas[0] < 5.0);
    }

    #[test]
    fn insufficient_contributions_error() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(3);

        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));

        let result = agg.aggregate();
        assert!(result.is_err());
    }

    #[test]
    fn weighted_average_strategy() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::WeightedAverage)
            .with_min_contributions(2);

        agg.add_contribution(make_contribution("a", vec![10.0], 0.9, 10));
        agg.add_contribution(make_contribution("b", vec![0.0], 0.1, 10));

        let result = agg.aggregate().unwrap();
        // (0.9*10 + 0.1*0) / 1.0 = 9.0
        assert!((result.lora_deltas[0] - 9.0).abs() < 1e-10);
    }

    #[test]
    fn round_increments() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);

        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![2.0], 1.0, 100));
        let r1 = agg.aggregate().unwrap();
        assert_eq!(r1.round, 1);

        agg.add_contribution(make_contribution("a", vec![3.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![4.0], 1.0, 100));
        let r2 = agg.aggregate().unwrap();
        assert_eq!(r2.round, 2);
    }

    #[test]
    fn confidences_high_when_agreement() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);

        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![1.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        // When all agree, variance = 0, confidence = 1/(1+0) = 1.0
        assert!((result.confidences[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn confidences_lower_when_disagreement() {
        let mut agg = FederatedAggregator::new("test".into(), AggregationStrategy::FedAvg)
            .with_min_contributions(2);

        agg.add_contribution(make_contribution("a", vec![0.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        // When disagreement, confidence < 1.0
        assert!(result.confidences[0] < 1.0);
    }
}
|
||||
416
vendor/ruvector/crates/rvf/rvf-federation/src/diff_privacy.rs
vendored
Normal file
416
vendor/ruvector/crates/rvf/rvf-federation/src/diff_privacy.rs
vendored
Normal file
@@ -0,0 +1,416 @@
|
||||
//! Differential privacy primitives for federated learning.
|
||||
//!
|
||||
//! Provides calibrated noise injection, gradient clipping, and a Renyi
|
||||
//! Differential Privacy (RDP) accountant for tracking cumulative privacy loss.
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_distr::{Distribution, Normal};
|
||||
|
||||
use crate::error::FederationError;
|
||||
use crate::types::{DiffPrivacyProof, NoiseMechanism};
|
||||
|
||||
/// Differential privacy engine for adding calibrated noise.
///
/// Construct via [`DiffPrivacyEngine::gaussian`] or
/// [`DiffPrivacyEngine::laplace`]; then use `clip_gradients` / `add_noise`.
pub struct DiffPrivacyEngine {
    /// Target epsilon (privacy loss bound).
    epsilon: f64,
    /// Target delta (probability of exceeding epsilon). 0.0 for the pure-DP
    /// Laplace mechanism.
    delta: f64,
    /// L2 sensitivity bound used to scale the noise.
    sensitivity: f64,
    /// Gradient clipping norm (L2) applied before noising.
    clipping_norm: f64,
    /// Noise mechanism.
    mechanism: NoiseMechanism,
    /// Random number generator (seedable for deterministic tests).
    rng: StdRng,
}
|
||||
|
||||
impl DiffPrivacyEngine {
    /// Create a new DP engine with Gaussian mechanism.
    ///
    /// Default: epsilon=1.0, delta=1e-5 (strong privacy).
    ///
    /// # Errors
    /// Returns `InvalidEpsilon` if `epsilon <= 0`, `InvalidDelta` if
    /// `delta` is outside the open interval (0, 1).
    pub fn gaussian(
        epsilon: f64,
        delta: f64,
        sensitivity: f64,
        clipping_norm: f64,
    ) -> Result<Self, FederationError> {
        if epsilon <= 0.0 {
            return Err(FederationError::InvalidEpsilon(epsilon));
        }
        if delta <= 0.0 || delta >= 1.0 {
            return Err(FederationError::InvalidDelta(delta));
        }
        Ok(Self {
            epsilon,
            delta,
            sensitivity,
            clipping_norm,
            mechanism: NoiseMechanism::Gaussian,
            rng: StdRng::from_rng(rand::thread_rng()).unwrap(),
        })
    }

    /// Create a new DP engine with Laplace mechanism.
    ///
    /// Laplace gives pure epsilon-DP, so delta is fixed at 0.
    ///
    /// # Errors
    /// Returns `InvalidEpsilon` if `epsilon <= 0`.
    pub fn laplace(
        epsilon: f64,
        sensitivity: f64,
        clipping_norm: f64,
    ) -> Result<Self, FederationError> {
        if epsilon <= 0.0 {
            return Err(FederationError::InvalidEpsilon(epsilon));
        }
        Ok(Self {
            epsilon,
            delta: 0.0,
            sensitivity,
            clipping_norm,
            mechanism: NoiseMechanism::Laplace,
            rng: StdRng::from_rng(rand::thread_rng()).unwrap(),
        })
    }

    /// Create with a deterministic seed (for testing).
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.rng = StdRng::seed_from_u64(seed);
        self
    }

    /// Compute the Gaussian noise standard deviation (sigma).
    ///
    /// Classic analytic bound: sigma = Δ * sqrt(2 ln(1.25/δ)) / ε.
    fn gaussian_sigma(&self) -> f64 {
        self.sensitivity * (2.0_f64 * (1.25_f64 / self.delta).ln()).sqrt() / self.epsilon
    }

    /// Compute the Laplace noise scale (b = Δ / ε).
    fn laplace_scale(&self) -> f64 {
        self.sensitivity / self.epsilon
    }

    /// Clip a gradient vector to the configured L2 norm bound.
    ///
    /// No-op when the vector's norm is already within the bound.
    pub fn clip_gradients(&self, gradients: &mut [f64]) {
        let norm: f64 = gradients.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm > self.clipping_norm {
            let scale = self.clipping_norm / norm;
            for g in gradients.iter_mut() {
                *g *= scale;
            }
        }
    }

    /// Add calibrated noise to a vector of parameters.
    ///
    /// Clips gradients first, then adds noise per the configured mechanism.
    /// Returns a [`DiffPrivacyProof`] attestation describing the noise applied.
    pub fn add_noise(&mut self, params: &mut [f64]) -> DiffPrivacyProof {
        self.clip_gradients(params);

        match self.mechanism {
            NoiseMechanism::Gaussian => {
                let sigma = self.gaussian_sigma();
                let normal = Normal::new(0.0, sigma).unwrap();
                for p in params.iter_mut() {
                    *p += normal.sample(&mut self.rng);
                }
                DiffPrivacyProof {
                    epsilon: self.epsilon,
                    delta: self.delta,
                    mechanism: NoiseMechanism::Gaussian,
                    sensitivity: self.sensitivity,
                    clipping_norm: self.clipping_norm,
                    noise_scale: sigma,
                    noised_parameter_count: params.len() as u64,
                }
            }
            NoiseMechanism::Laplace => {
                let b = self.laplace_scale();
                for p in params.iter_mut() {
                    // Laplace noise via inverse CDF: b * sign(u-0.5) * ln(1 - 2|u-0.5|)
                    // NOTE(review): gen::<f64>() is in [0, 1), so u can be
                    // exactly -0.5, giving ln(0) = -inf noise — confirm this
                    // zero-probability-in-practice edge is acceptable.
                    let u: f64 = self.rng.gen::<f64>() - 0.5;
                    let noise = -b * u.signum() * (1.0 - 2.0 * u.abs()).ln();
                    *p += noise;
                }
                DiffPrivacyProof {
                    epsilon: self.epsilon,
                    delta: 0.0,
                    mechanism: NoiseMechanism::Laplace,
                    sensitivity: self.sensitivity,
                    clipping_norm: self.clipping_norm,
                    noise_scale: b,
                    noised_parameter_count: params.len() as u64,
                }
            }
        }
    }

    /// Add noise to a single scalar value.
    ///
    /// Mutates `value` in place and also returns the noised result.
    pub fn add_noise_scalar(&mut self, value: &mut f64) -> f64 {
        let mut v = [*value];
        self.add_noise(&mut v);
        *value = v[0];
        v[0]
    }

    /// Current epsilon setting.
    pub fn epsilon(&self) -> f64 {
        self.epsilon
    }

    /// Current delta setting.
    pub fn delta(&self) -> f64 {
        self.delta
    }
}
|
||||
|
||||
// -- Privacy Accountant (RDP) ------------------------------------------------
|
||||
|
||||
/// Renyi Differential Privacy (RDP) accountant for tracking cumulative privacy loss.
///
/// Tracks privacy budget across multiple export rounds using RDP composition,
/// which provides tighter bounds than naive (epsilon, delta)-DP composition.
pub struct PrivacyAccountant {
    /// Maximum allowed cumulative epsilon.
    epsilon_limit: f64,
    /// Target delta for conversion from RDP to (epsilon, delta)-DP.
    target_delta: f64,
    /// Accumulated RDP values at various alpha orders.
    /// Each entry: (alpha_order, accumulated_rdp_epsilon)
    rdp_alphas: Vec<(f64, f64)>,
    /// History of exports: (timestamp, epsilon_spent, mechanism).
    history: Vec<ExportRecord>,
}
|
||||
|
||||
/// Record of a single privacy-consuming export.
#[derive(Clone, Debug)]
pub struct ExportRecord {
    /// UNIX timestamp of the export.
    /// NOTE(review): both record paths currently store 0 here — confirm a
    /// clock source is plugged in elsewhere.
    pub timestamp_s: u64,
    /// Epsilon consumed by this export.
    pub epsilon: f64,
    /// Delta for this export (0 for pure epsilon-DP).
    pub delta: f64,
    /// Mechanism used.
    pub mechanism: NoiseMechanism,
    /// Number of parameters.
    pub parameter_count: u64,
}
|
||||
|
||||
impl PrivacyAccountant {
|
||||
/// Create a new accountant with the given budget.
|
||||
pub fn new(epsilon_limit: f64, target_delta: f64) -> Self {
|
||||
// Standard RDP alpha orders for accounting
|
||||
let alphas: Vec<f64> = vec![
|
||||
1.5, 1.75, 2.0, 2.5, 3.0, 4.0, 5.0, 6.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0,
|
||||
1024.0,
|
||||
];
|
||||
let rdp_alphas = alphas.into_iter().map(|a| (a, 0.0)).collect();
|
||||
Self {
|
||||
epsilon_limit,
|
||||
target_delta,
|
||||
rdp_alphas,
|
||||
history: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute RDP epsilon for the Gaussian mechanism at a given alpha order.
|
||||
fn gaussian_rdp(alpha: f64, sigma: f64) -> f64 {
|
||||
alpha / (2.0 * sigma * sigma)
|
||||
}
|
||||
|
||||
/// Convert RDP to (epsilon, delta)-DP for a given alpha order.
|
||||
fn rdp_to_dp(alpha: f64, rdp_epsilon: f64, delta: f64) -> f64 {
|
||||
rdp_epsilon - (delta.ln()) / (alpha - 1.0)
|
||||
}
|
||||
|
||||
/// Record a Gaussian mechanism query.
|
||||
pub fn record_gaussian(&mut self, sigma: f64, epsilon: f64, delta: f64, parameter_count: u64) {
|
||||
// Accumulate RDP at each alpha order
|
||||
for (alpha, rdp_eps) in &mut self.rdp_alphas {
|
||||
*rdp_eps += Self::gaussian_rdp(*alpha, sigma);
|
||||
}
|
||||
self.history.push(ExportRecord {
|
||||
timestamp_s: 0,
|
||||
epsilon,
|
||||
delta,
|
||||
mechanism: NoiseMechanism::Gaussian,
|
||||
parameter_count,
|
||||
});
|
||||
}
|
||||
|
||||
/// Record a Laplace mechanism query.
|
||||
pub fn record_laplace(&mut self, epsilon: f64, parameter_count: u64) {
|
||||
// For Laplace, RDP epsilon at order alpha is: alpha * eps / (alpha - 1)
|
||||
// when alpha > 1
|
||||
for (alpha, rdp_eps) in &mut self.rdp_alphas {
|
||||
if *alpha > 1.0 {
|
||||
*rdp_eps += *alpha * epsilon / (*alpha - 1.0);
|
||||
}
|
||||
}
|
||||
self.history.push(ExportRecord {
|
||||
timestamp_s: 0,
|
||||
epsilon,
|
||||
delta: 0.0,
|
||||
mechanism: NoiseMechanism::Laplace,
|
||||
parameter_count,
|
||||
});
|
||||
}
|
||||
|
||||
/// Get the current best (tightest) epsilon estimate.
|
||||
pub fn current_epsilon(&self) -> f64 {
|
||||
self.rdp_alphas
|
||||
.iter()
|
||||
.map(|(alpha, rdp_eps)| Self::rdp_to_dp(*alpha, *rdp_eps, self.target_delta))
|
||||
.fold(f64::INFINITY, f64::min)
|
||||
}
|
||||
|
||||
/// Remaining privacy budget.
|
||||
pub fn remaining_budget(&self) -> f64 {
|
||||
(self.epsilon_limit - self.current_epsilon()).max(0.0)
|
||||
}
|
||||
|
||||
/// Check if we can afford another export with the given epsilon.
|
||||
pub fn can_afford(&self, additional_epsilon: f64) -> bool {
|
||||
self.current_epsilon() + additional_epsilon <= self.epsilon_limit
|
||||
}
|
||||
|
||||
/// Check if budget is exhausted.
|
||||
pub fn is_exhausted(&self) -> bool {
|
||||
self.current_epsilon() >= self.epsilon_limit
|
||||
}
|
||||
|
||||
/// Fraction of budget consumed (0.0 to 1.0+).
|
||||
pub fn budget_fraction_used(&self) -> f64 {
|
||||
self.current_epsilon() / self.epsilon_limit
|
||||
}
|
||||
|
||||
/// Number of exports recorded.
|
||||
pub fn export_count(&self) -> usize {
|
||||
self.history.len()
|
||||
}
|
||||
|
||||
/// Export history.
|
||||
pub fn history(&self) -> &[ExportRecord] {
|
||||
&self.history
|
||||
}
|
||||
|
||||
/// Epsilon limit.
|
||||
pub fn epsilon_limit(&self) -> f64 {
|
||||
self.epsilon_limit
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gaussian_engine_creates() {
        let engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 1.0);
        assert!(engine.is_ok());
    }

    #[test]
    fn invalid_epsilon_rejected() {
        // epsilon must be strictly positive.
        let engine = DiffPrivacyEngine::gaussian(0.0, 1e-5, 1.0, 1.0);
        assert!(engine.is_err());
        let engine = DiffPrivacyEngine::gaussian(-1.0, 1e-5, 1.0, 1.0);
        assert!(engine.is_err());
    }

    #[test]
    fn invalid_delta_rejected() {
        // delta must lie in the open interval (0, 1).
        let engine = DiffPrivacyEngine::gaussian(1.0, 0.0, 1.0, 1.0);
        assert!(engine.is_err());
        let engine = DiffPrivacyEngine::gaussian(1.0, 1.0, 1.0, 1.0);
        assert!(engine.is_err());
    }

    #[test]
    fn gradient_clipping() {
        let engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 1.0).unwrap();
        let mut grads = vec![3.0, 4.0]; // norm = 5.0
        engine.clip_gradients(&mut grads);
        let norm: f64 = grads.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6); // clipped to norm 1.0
    }

    #[test]
    fn gradient_no_clip_when_small() {
        let engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0).unwrap();
        let mut grads = vec![3.0, 4.0]; // norm = 5.0, clip = 10.0
        engine.clip_gradients(&mut grads);
        assert!((grads[0] - 3.0).abs() < 1e-10);
        assert!((grads[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn add_noise_gaussian_deterministic() {
        // Seeded RNG: reproducible draws across runs.
        let mut engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 100.0)
            .unwrap()
            .with_seed(42);
        let mut params = vec![1.0, 2.0, 3.0];
        let original = params.clone();
        let proof = engine.add_noise(&mut params);
        assert_eq!(proof.mechanism, NoiseMechanism::Gaussian);
        assert_eq!(proof.noised_parameter_count, 3);
        // Params should be different from original (noise added)
        assert!(params
            .iter()
            .zip(original.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10));
    }

    #[test]
    fn add_noise_laplace_deterministic() {
        let mut engine = DiffPrivacyEngine::laplace(1.0, 1.0, 100.0)
            .unwrap()
            .with_seed(42);
        let mut params = vec![1.0, 2.0, 3.0];
        let proof = engine.add_noise(&mut params);
        assert_eq!(proof.mechanism, NoiseMechanism::Laplace);
        assert_eq!(proof.noised_parameter_count, 3);
    }

    #[test]
    fn privacy_accountant_initial_state() {
        let acc = PrivacyAccountant::new(10.0, 1e-5);
        assert_eq!(acc.export_count(), 0);
        assert!(!acc.is_exhausted());
        assert!(acc.can_afford(1.0));
        assert!(acc.remaining_budget() > 9.9);
    }

    #[test]
    fn privacy_accountant_tracks_gaussian() {
        let mut acc = PrivacyAccountant::new(10.0, 1e-5);
        // sigma=1.0 with epsilon=1.0 per query
        acc.record_gaussian(1.0, 1.0, 1e-5, 100);
        assert_eq!(acc.export_count(), 1);
        let eps = acc.current_epsilon();
        assert!(eps > 0.0);
        assert!(eps < 10.0);
    }

    #[test]
    fn privacy_accountant_composition() {
        let mut acc = PrivacyAccountant::new(10.0, 1e-5);
        let eps_after_1 = {
            acc.record_gaussian(1.0, 1.0, 1e-5, 100);
            acc.current_epsilon()
        };
        acc.record_gaussian(1.0, 1.0, 1e-5, 100);
        let eps_after_2 = acc.current_epsilon();
        // After 2 queries, epsilon should be larger
        assert!(eps_after_2 > eps_after_1);
    }

    #[test]
    fn privacy_accountant_exhaustion() {
        let mut acc = PrivacyAccountant::new(1.0, 1e-5);
        // Use a very small sigma to burn budget fast
        for _ in 0..100 {
            acc.record_gaussian(0.1, 10.0, 1e-5, 10);
        }
        assert!(acc.is_exhausted());
        assert!(!acc.can_afford(0.1));
    }
}
|
||||
52
vendor/ruvector/crates/rvf/rvf-federation/src/error.rs
vendored
Normal file
52
vendor/ruvector/crates/rvf/rvf-federation/src/error.rs
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
//! Federation error types.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur during federation operations.
///
/// Display strings come from the `#[error]` attributes via `thiserror`.
#[derive(Debug, Error)]
pub enum FederationError {
    #[error("privacy budget exhausted: spent {spent:.4}, limit {limit:.4}")]
    PrivacyBudgetExhausted { spent: f64, limit: f64 },

    #[error("invalid epsilon value: {0} (must be > 0)")]
    InvalidEpsilon(f64),

    #[error("invalid delta value: {0} (must be in (0, 1))")]
    InvalidDelta(f64),

    #[error("segment validation failed: {0}")]
    SegmentValidation(String),

    #[error("version mismatch: expected {expected}, got {got}")]
    VersionMismatch { expected: u32, got: u32 },

    #[error("signature verification failed")]
    SignatureVerification,

    #[error("witness chain broken at index {0}")]
    WitnessChainBroken(usize),

    #[error("insufficient observations: need {needed}, have {have}")]
    InsufficientObservations { needed: u64, have: u64 },

    #[error("quality below threshold: {score:.4} < {threshold:.4}")]
    QualityBelowThreshold { score: f64, threshold: f64 },

    #[error("export rate limited: next export allowed at {next_allowed_epoch_s}")]
    RateLimited { next_allowed_epoch_s: u64 },

    #[error("PII detected after stripping: {field}")]
    PiiLeakDetected { field: String },

    #[error("Byzantine outlier detected from contributor {contributor}")]
    ByzantineOutlier { contributor: String },

    #[error("aggregation requires at least {min} contributions, got {got}")]
    InsufficientContributions { min: usize, got: usize },

    #[error("serialization error: {0}")]
    Serialization(String),

    #[error("io error: {0}")]
    Io(String),
}
|
||||
477
vendor/ruvector/crates/rvf/rvf-federation/src/federation.rs
vendored
Normal file
477
vendor/ruvector/crates/rvf/rvf-federation/src/federation.rs
vendored
Normal file
@@ -0,0 +1,477 @@
|
||||
//! Federation protocol: export builder, import merger, version-aware conflict resolution.
|
||||
|
||||
use crate::diff_privacy::DiffPrivacyEngine;
|
||||
use crate::error::FederationError;
|
||||
use crate::pii_strip::PiiStripper;
|
||||
use crate::policy::FederationPolicy;
|
||||
use crate::types::*;
|
||||
|
||||
/// Builder for constructing a federated learning export.
///
/// Follows the export flow from ADR-057:
/// 1. Extract learning (priors, kernels, cost curves, weights)
/// 2. PII-strip all payloads
/// 3. Add differential privacy noise
/// 4. Assemble manifest + attestation segments
pub struct ExportBuilder {
    /// Pseudonymous contributor identity recorded in the manifest.
    contributor_pseudonym: String,
    /// Domain this export's learning came from.
    domain_id: String,
    /// Transfer priors queued for export (PII-stripped and noised in `build`).
    priors: Vec<TransferPriorSet>,
    /// Policy kernel snapshots queued for export.
    kernels: Vec<PolicyKernelSnapshot>,
    /// Cost curve snapshots queued for export.
    cost_curves: Vec<CostCurveSnapshot>,
    /// Raw weight vectors (LoRA deltas) queued for export.
    weights: Vec<Vec<f64>>,
    /// Selective-sharing policy applied during `build`.
    policy: FederationPolicy,
    /// Named string fields to be scanned for PII during `build`.
    string_fields: Vec<(String, String)>,
}
|
||||
|
||||
/// A completed federated export ready for publishing.
///
/// Produced by [`ExportBuilder::build`]; all payloads have already been
/// PII-stripped and DP-noised by the time this struct exists.
#[derive(Clone, Debug)]
pub struct FederatedExport {
    /// The manifest describing this export.
    pub manifest: FederatedManifest,
    /// PII redaction attestation.
    pub redaction_log: RedactionLog,
    /// Differential privacy attestation.
    pub privacy_proof: DiffPrivacyProof,
    /// Transfer priors (after PII stripping and DP noise).
    pub priors: Vec<TransferPriorSet>,
    /// Policy kernel snapshots.
    pub kernels: Vec<PolicyKernelSnapshot>,
    /// Cost curve snapshots.
    pub cost_curves: Vec<CostCurveSnapshot>,
    /// Noised aggregate weights (if any).
    pub weights: Vec<Vec<f64>>,
}
|
||||
|
||||
impl ExportBuilder {
|
||||
/// Create a new export builder.
|
||||
pub fn new(contributor_pseudonym: String, domain_id: String) -> Self {
|
||||
Self {
|
||||
contributor_pseudonym,
|
||||
domain_id,
|
||||
priors: Vec::new(),
|
||||
kernels: Vec::new(),
|
||||
cost_curves: Vec::new(),
|
||||
weights: Vec::new(),
|
||||
policy: FederationPolicy::default(),
|
||||
string_fields: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the federation policy.
|
||||
pub fn with_policy(mut self, policy: FederationPolicy) -> Self {
|
||||
self.policy = policy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add transfer priors from a trained domain.
|
||||
pub fn add_priors(mut self, priors: TransferPriorSet) -> Self {
|
||||
self.priors.push(priors);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a policy kernel snapshot.
|
||||
pub fn add_kernel(mut self, kernel: PolicyKernelSnapshot) -> Self {
|
||||
self.kernels.push(kernel);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a cost curve snapshot.
|
||||
pub fn add_cost_curve(mut self, curve: CostCurveSnapshot) -> Self {
|
||||
self.cost_curves.push(curve);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add raw weight vectors (LoRA deltas).
|
||||
pub fn add_weights(mut self, weights: Vec<f64>) -> Self {
|
||||
self.weights.push(weights);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a named string field for PII scanning.
|
||||
pub fn add_string_field(mut self, name: String, value: String) -> Self {
|
||||
self.string_fields.push((name, value));
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the export: PII-strip, add DP noise, assemble manifest.
|
||||
pub fn build(mut self, dp_engine: &mut DiffPrivacyEngine) -> Result<FederatedExport, FederationError> {
|
||||
// 1. Apply quality gate from policy
|
||||
self.priors.retain(|ps| {
|
||||
ps.entries.iter().all(|e| e.observation_count >= self.policy.min_observations)
|
||||
});
|
||||
|
||||
// 2. PII stripping
|
||||
let mut stripper = PiiStripper::new();
|
||||
let field_refs: Vec<(&str, &str)> = self.string_fields
|
||||
.iter()
|
||||
.map(|(n, v)| (n.as_str(), v.as_str()))
|
||||
.collect();
|
||||
let (_redacted_fields, redaction_log) = stripper.strip_fields(&field_refs);
|
||||
|
||||
// Strip PII from domain IDs and bucket IDs in priors
|
||||
for ps in &mut self.priors {
|
||||
ps.source_domain = stripper.strip_value(&ps.source_domain);
|
||||
for entry in &mut ps.entries {
|
||||
entry.bucket_id = stripper.strip_value(&entry.bucket_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Strip PII from cost curve domain IDs
|
||||
for curve in &mut self.cost_curves {
|
||||
curve.domain_id = stripper.strip_value(&curve.domain_id);
|
||||
}
|
||||
|
||||
// 3. Add differential privacy noise to numerical parameters
|
||||
// Noise the Beta posteriors
|
||||
let mut noised_count: u64 = 0;
|
||||
for ps in &mut self.priors {
|
||||
for entry in &mut ps.entries {
|
||||
let mut params = [entry.params.alpha, entry.params.beta];
|
||||
dp_engine.add_noise(&mut params);
|
||||
entry.params.alpha = params[0].max(0.01); // Keep positive
|
||||
entry.params.beta = params[1].max(0.01);
|
||||
noised_count += 2;
|
||||
}
|
||||
}
|
||||
|
||||
// Noise the weight vectors
|
||||
for w in &mut self.weights {
|
||||
dp_engine.add_noise(w);
|
||||
noised_count += w.len() as u64;
|
||||
}
|
||||
|
||||
// Noise kernel knobs
|
||||
for kernel in &mut self.kernels {
|
||||
dp_engine.add_noise(&mut kernel.knobs);
|
||||
noised_count += kernel.knobs.len() as u64;
|
||||
}
|
||||
|
||||
// Noise cost curve values
|
||||
for curve in &mut self.cost_curves {
|
||||
let mut costs: Vec<f64> = curve.points.iter().map(|(_, c)| *c).collect();
|
||||
dp_engine.add_noise(&mut costs);
|
||||
for (i, (_, c)) in curve.points.iter_mut().enumerate() {
|
||||
*c = costs[i];
|
||||
}
|
||||
noised_count += costs.len() as u64;
|
||||
}
|
||||
|
||||
let privacy_proof = DiffPrivacyProof {
|
||||
epsilon: dp_engine.epsilon(),
|
||||
delta: dp_engine.delta(),
|
||||
mechanism: NoiseMechanism::Gaussian,
|
||||
sensitivity: 1.0,
|
||||
clipping_norm: 1.0,
|
||||
noise_scale: 0.0,
|
||||
noised_parameter_count: noised_count,
|
||||
};
|
||||
|
||||
// 4. Build manifest
|
||||
let total_trajectories: u64 = self.priors.iter()
|
||||
.flat_map(|ps| ps.entries.iter())
|
||||
.map(|e| e.observation_count)
|
||||
.sum();
|
||||
|
||||
let avg_quality = if !self.priors.is_empty() {
|
||||
self.priors.iter()
|
||||
.flat_map(|ps| ps.entries.iter())
|
||||
.map(|e| e.params.mean())
|
||||
.sum::<f64>()
|
||||
/ self.priors.iter().map(|ps| ps.entries.len()).sum::<usize>().max(1) as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let manifest = FederatedManifest {
|
||||
format_version: 1,
|
||||
contributor_pseudonym: self.contributor_pseudonym,
|
||||
export_timestamp_s: 0,
|
||||
included_segment_ids: Vec::new(),
|
||||
privacy_budget_spent: dp_engine.epsilon(),
|
||||
domain_id: self.domain_id,
|
||||
rvf_version_tag: String::from("rvf-v1"),
|
||||
trajectory_count: total_trajectories,
|
||||
avg_quality_score: avg_quality,
|
||||
};
|
||||
|
||||
Ok(FederatedExport {
|
||||
manifest,
|
||||
redaction_log,
|
||||
privacy_proof,
|
||||
priors: self.priors,
|
||||
kernels: self.kernels,
|
||||
cost_curves: self.cost_curves,
|
||||
weights: self.weights,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Merger for importing federated learning into local engines.
///
/// Follows the import flow from ADR-057:
/// 1. Validate signature and witness chain
/// 2. Check version compatibility
/// 3. Merge with dampened confidence
pub struct ImportMerger {
    /// Current RVF version for compatibility checks.
    current_version: u32,
    /// Dampening factor for cross-version imports (clamped to [0, 1]).
    version_dampen_factor: f64,
}
|
||||
|
||||
impl ImportMerger {
|
||||
/// Create a new import merger.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
current_version: 1,
|
||||
version_dampen_factor: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the dampening factor for imports from different versions.
|
||||
pub fn with_version_dampen(mut self, factor: f64) -> Self {
|
||||
self.version_dampen_factor = factor.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
|
||||
/// Validate a federated export.
|
||||
pub fn validate(&self, export: &FederatedExport) -> Result<(), FederationError> {
|
||||
// Check format version
|
||||
if export.manifest.format_version == 0 {
|
||||
return Err(FederationError::SegmentValidation(
|
||||
"format_version must be > 0".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check privacy proof has valid parameters
|
||||
if export.privacy_proof.epsilon <= 0.0 {
|
||||
return Err(FederationError::InvalidEpsilon(export.privacy_proof.epsilon));
|
||||
}
|
||||
|
||||
// Check priors have positive parameters
|
||||
for ps in &export.priors {
|
||||
for entry in &ps.entries {
|
||||
if entry.params.alpha <= 0.0 || entry.params.beta <= 0.0 {
|
||||
return Err(FederationError::SegmentValidation(format!(
|
||||
"invalid Beta params in bucket {}: alpha={}, beta={}",
|
||||
entry.bucket_id, entry.params.alpha, entry.params.beta
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Merge imported priors with local priors.
|
||||
///
|
||||
/// Uses version-aware dampening: same version gets full weight,
|
||||
/// older versions get dampened (sqrt-scaling per MetaThompsonEngine).
|
||||
pub fn merge_priors(
|
||||
&self,
|
||||
local: &mut Vec<TransferPriorEntry>,
|
||||
remote: &[TransferPriorEntry],
|
||||
remote_version: u32,
|
||||
) {
|
||||
let dampen = if remote_version == self.current_version {
|
||||
1.0
|
||||
} else {
|
||||
self.version_dampen_factor
|
||||
};
|
||||
|
||||
for remote_entry in remote {
|
||||
let dampened = remote_entry.params.dampen(dampen);
|
||||
|
||||
if let Some(local_entry) = local.iter_mut().find(|l| {
|
||||
l.bucket_id == remote_entry.bucket_id && l.arm_id == remote_entry.arm_id
|
||||
}) {
|
||||
// Merge: sum parameters minus uniform prior
|
||||
local_entry.params = local_entry.params.merge(&dampened);
|
||||
local_entry.observation_count += remote_entry.observation_count;
|
||||
} else {
|
||||
// New entry: insert with dampened params
|
||||
local.push(TransferPriorEntry {
|
||||
bucket_id: remote_entry.bucket_id.clone(),
|
||||
arm_id: remote_entry.arm_id.clone(),
|
||||
params: dampened,
|
||||
observation_count: remote_entry.observation_count,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge imported weights with local weights using weighted average.
|
||||
pub fn merge_weights(
|
||||
&self,
|
||||
local: &mut [f64],
|
||||
remote: &[f64],
|
||||
local_weight: f64,
|
||||
remote_weight: f64,
|
||||
) {
|
||||
let total = local_weight + remote_weight;
|
||||
if total <= 0.0 || local.len() != remote.len() {
|
||||
return;
|
||||
}
|
||||
for (l, r) in local.iter_mut().zip(remote.iter()) {
|
||||
*l = (local_weight * *l + remote_weight * *r) / total;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ImportMerger {
    /// Equivalent to [`ImportMerger::new`]: version 1, 0.5 dampening.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diff_privacy::DiffPrivacyEngine;

    /// Two-bucket prior-set fixture; both entries clear the default
    /// `min_observations` gate (12), so nothing is filtered at build time.
    fn make_test_priors() -> TransferPriorSet {
        TransferPriorSet {
            source_domain: "test_domain".into(),
            entries: vec![
                TransferPriorEntry {
                    bucket_id: "medium_algorithm".into(),
                    arm_id: "arm_0".into(),
                    params: BetaParams::new(10.0, 5.0),
                    observation_count: 50,
                },
                TransferPriorEntry {
                    bucket_id: "hard_synthesis".into(),
                    arm_id: "arm_1".into(),
                    params: BetaParams::new(8.0, 12.0),
                    observation_count: 30,
                },
            ],
            cost_ema: 0.85,
        }
    }

    #[test]
    fn export_builder_basic() {
        // Seeded DP engine keeps the noise (and thus the test) deterministic.
        let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0)
            .unwrap()
            .with_seed(42);
        let export = ExportBuilder::new("alice_pseudo".into(), "code_review".into())
            .add_priors(make_test_priors())
            .build(&mut dp)
            .unwrap();

        assert_eq!(export.manifest.contributor_pseudonym, "alice_pseudo");
        assert_eq!(export.manifest.domain_id, "code_review");
        assert_eq!(export.manifest.format_version, 1);
        assert!(!export.priors.is_empty());
    }

    #[test]
    fn export_builder_with_weights() {
        let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0)
            .unwrap()
            .with_seed(42);
        let weights = vec![0.1, 0.2, 0.3, 0.4];
        let export = ExportBuilder::new("bob_pseudo".into(), "genomics".into())
            .add_weights(weights.clone())
            .build(&mut dp)
            .unwrap();

        assert_eq!(export.weights.len(), 1);
        // Weights should be different from original (noise added).
        assert!(export.weights[0].iter().zip(weights.iter()).any(|(a, b)| (a - b).abs() > 1e-10));
    }

    #[test]
    fn import_merger_validate() {
        let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0)
            .unwrap()
            .with_seed(42);
        let export = ExportBuilder::new("alice".into(), "domain".into())
            .add_priors(make_test_priors())
            .build(&mut dp)
            .unwrap();

        let merger = ImportMerger::new();
        assert!(merger.validate(&export).is_ok());
    }

    #[test]
    fn import_merger_merge_priors_same_version() {
        let merger = ImportMerger::new();
        let mut local = vec![TransferPriorEntry {
            bucket_id: "medium_algorithm".into(),
            arm_id: "arm_0".into(),
            params: BetaParams::new(5.0, 3.0),
            observation_count: 20,
        }];

        let remote = vec![TransferPriorEntry {
            bucket_id: "medium_algorithm".into(),
            arm_id: "arm_0".into(),
            params: BetaParams::new(10.0, 5.0),
            observation_count: 50,
        }];

        // Same version (1) means no dampening is applied.
        merger.merge_priors(&mut local, &remote, 1);
        assert_eq!(local.len(), 1);
        // Merged: alpha = 5 + 10 - 1 = 14, beta = 3 + 5 - 1 = 7
        // (merge sums parameters minus the uniform prior).
        assert!((local[0].params.alpha - 14.0).abs() < 1e-10);
        assert!((local[0].params.beta - 7.0).abs() < 1e-10);
        assert_eq!(local[0].observation_count, 70);
    }

    #[test]
    fn import_merger_merge_priors_different_version() {
        let merger = ImportMerger::new();
        let mut local = vec![TransferPriorEntry {
            bucket_id: "b".into(),
            arm_id: "a".into(),
            params: BetaParams::new(10.0, 10.0),
            observation_count: 50,
        }];

        let remote = vec![TransferPriorEntry {
            bucket_id: "b".into(),
            arm_id: "a".into(),
            params: BetaParams::new(20.0, 5.0),
            observation_count: 40,
        }];

        merger.merge_priors(&mut local, &remote, 0); // older version -> dampened
        assert_eq!(local.len(), 1);
        // Remote dampened by 0.5: alpha = 1 + (20-1)*0.5 = 10.5, beta = 1 + (5-1)*0.5 = 3.0
        // Merged: alpha = 10 + 10.5 - 1 = 19.5, beta = 10 + 3.0 - 1 = 12.0
        assert!((local[0].params.alpha - 19.5).abs() < 1e-10);
        assert!((local[0].params.beta - 12.0).abs() < 1e-10);
    }

    #[test]
    fn import_merger_merge_new_bucket() {
        let merger = ImportMerger::new();
        let mut local: Vec<TransferPriorEntry> = Vec::new();

        let remote = vec![TransferPriorEntry {
            bucket_id: "new_bucket".into(),
            arm_id: "arm_0".into(),
            params: BetaParams::new(10.0, 5.0),
            observation_count: 30,
        }];

        // Unknown (bucket, arm) pairs are inserted rather than merged.
        merger.merge_priors(&mut local, &remote, 1);
        assert_eq!(local.len(), 1);
        assert_eq!(local[0].bucket_id, "new_bucket");
    }

    #[test]
    fn merge_weights_weighted_average() {
        let merger = ImportMerger::new();
        let mut local = vec![1.0, 2.0, 3.0];
        let remote = vec![3.0, 4.0, 5.0];
        // Equal weights -> element-wise midpoint.
        merger.merge_weights(&mut local, &remote, 0.5, 0.5);
        assert!((local[0] - 2.0).abs() < 1e-10);
        assert!((local[1] - 3.0).abs() < 1e-10);
        assert!((local[2] - 4.0).abs() < 1e-10);
    }
}
|
||||
24
vendor/ruvector/crates/rvf/rvf-federation/src/lib.rs
vendored
Normal file
24
vendor/ruvector/crates/rvf/rvf-federation/src/lib.rs
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
//! Federated RVF transfer learning.
|
||||
//!
|
||||
//! This crate implements the federation protocol described in ADR-057:
|
||||
//! - **PII stripping**: Three-stage pipeline (detect, redact, attest)
|
||||
//! - **Differential privacy**: Gaussian/Laplace noise, RDP accountant, gradient clipping
|
||||
//! - **Federation protocol**: Export builder, import merger, version-aware conflict resolution
|
||||
//! - **Federated aggregation**: FedAvg, FedProx, Byzantine-tolerant weighted averaging
|
||||
//! - **Segment types**: FederatedManifest, DiffPrivacyProof, RedactionLog, AggregateWeights
|
||||
|
||||
pub mod types;
|
||||
pub mod error;
|
||||
pub mod pii_strip;
|
||||
pub mod diff_privacy;
|
||||
pub mod federation;
|
||||
pub mod aggregate;
|
||||
pub mod policy;
|
||||
|
||||
pub use types::*;
|
||||
pub use error::FederationError;
|
||||
pub use pii_strip::PiiStripper;
|
||||
pub use diff_privacy::{DiffPrivacyEngine, PrivacyAccountant};
|
||||
pub use federation::{ExportBuilder, ImportMerger};
|
||||
pub use aggregate::{FederatedAggregator, AggregationStrategy};
|
||||
pub use policy::FederationPolicy;
|
||||
354
vendor/ruvector/crates/rvf/rvf-federation/src/pii_strip.rs
vendored
Normal file
354
vendor/ruvector/crates/rvf/rvf-federation/src/pii_strip.rs
vendored
Normal file
@@ -0,0 +1,354 @@
|
||||
//! Three-stage PII stripping pipeline.
|
||||
//!
|
||||
//! **Stage 1 — Detection**: Scan string fields for PII patterns.
|
||||
//! **Stage 2 — Redaction**: Replace PII with deterministic pseudonyms.
|
||||
//! **Stage 3 — Attestation**: Generate a `RedactionLog` segment.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use regex::Regex;
|
||||
use sha3::{Shake256, digest::{Update, ExtendableOutput, XofReader}};
|
||||
|
||||
use crate::types::{RedactionLog, RedactionEntry};
|
||||
|
||||
/// PII category with its detection regex and replacement template.
struct PiiRule {
    /// Category name recorded in the redaction log (e.g. "path", "ip").
    category: &'static str,
    /// Stable rule identifier recorded in the redaction log.
    rule_id: &'static str,
    /// Compiled detection pattern.
    pattern: Regex,
    /// Pseudonym prefix used for replacements (e.g. yields `<PATH_1>`).
    prefix: &'static str,
}
|
||||
|
||||
/// Three-stage PII stripping pipeline.
///
/// State (counters and pseudonym map) persists across calls so repeated
/// values get the same pseudonym; call [`PiiStripper::reset`] between exports.
pub struct PiiStripper {
    /// Built-in detection rules, applied in declaration order.
    rules: Vec<PiiRule>,
    /// Custom regex rules added by the user (applied after built-ins).
    custom_rules: Vec<PiiRule>,
    /// Pseudonym counter per category prefix (for deterministic replacement).
    counters: HashMap<String, u32>,
    /// Map from original value to pseudonym (preserves structural relationships).
    pseudonym_map: HashMap<String, String>,
}
|
||||
|
||||
impl PiiStripper {
    /// Create a new stripper with the default built-in detection rules:
    /// Unix/Windows paths, IPv4/IPv6, emails, common API-key formats,
    /// environment variables, and `@user`-style handles.
    pub fn new() -> Self {
        let rules = vec![
            PiiRule {
                category: "path",
                rule_id: "rule_path_unix",
                pattern: Regex::new(r#"(?:/(?:home|Users|var|tmp|opt|etc)/[^\s,;:"'\]}>)]+)"#).unwrap(),
                prefix: "PATH",
            },
            PiiRule {
                category: "path",
                rule_id: "rule_path_windows",
                pattern: Regex::new(r#"(?i:[A-Z]:\\(?:Users|Documents|Program Files)[^\s,;:"'\]}>)]+)"#).unwrap(),
                prefix: "PATH",
            },
            PiiRule {
                category: "ip",
                rule_id: "rule_ipv4",
                pattern: Regex::new(r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b").unwrap(),
                prefix: "IP",
            },
            PiiRule {
                category: "ip",
                // NOTE(review): matches only fully-expanded IPv6 (eight
                // groups); compressed forms like `::1` are not detected.
                rule_id: "rule_ipv6",
                pattern: Regex::new(r"\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b").unwrap(),
                prefix: "IP",
            },
            PiiRule {
                category: "email",
                rule_id: "rule_email",
                pattern: Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b").unwrap(),
                prefix: "EMAIL",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_api_key_sk",
                pattern: Regex::new(r"\bsk-[A-Za-z0-9]{20,}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_api_key_aws",
                pattern: Regex::new(r"\bAKIA[A-Z0-9]{16}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_api_key_github",
                pattern: Regex::new(r"\bghp_[A-Za-z0-9]{36}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "api_key",
                rule_id: "rule_bearer_token",
                pattern: Regex::new(r"\bBearer\s+[A-Za-z0-9._~+/=-]{20,}\b").unwrap(),
                prefix: "REDACTED_KEY",
            },
            PiiRule {
                category: "env_var",
                rule_id: "rule_env_unix",
                pattern: Regex::new(r"\$(?:HOME|USER|USERNAME|USERPROFILE|PATH|TMPDIR)\b").unwrap(),
                prefix: "ENV",
            },
            PiiRule {
                category: "env_var",
                rule_id: "rule_env_windows",
                pattern: Regex::new(r"%(?:HOME|USER|USERNAME|USERPROFILE|PATH|TEMP)%").unwrap(),
                prefix: "ENV",
            },
            PiiRule {
                category: "username",
                rule_id: "rule_username_at",
                pattern: Regex::new(r"@[A-Za-z][A-Za-z0-9_-]{2,30}\b").unwrap(),
                prefix: "USER",
            },
        ];

        Self {
            rules,
            custom_rules: Vec::new(),
            counters: HashMap::new(),
            pseudonym_map: HashMap::new(),
        }
    }

    /// Add a custom detection rule, applied after all built-in rules.
    ///
    /// # Errors
    /// Returns the `regex::Error` if `pattern` fails to compile.
    pub fn add_rule(&mut self, category: &'static str, rule_id: &'static str, pattern: &str, prefix: &'static str) -> Result<(), regex::Error> {
        self.custom_rules.push(PiiRule {
            category,
            rule_id,
            pattern: Regex::new(pattern)?,
            prefix,
        });
        Ok(())
    }

    /// Reset the pseudonym map and counters (call between exports).
    pub fn reset(&mut self) {
        self.counters.clear();
        self.pseudonym_map.clear();
    }

    /// Get or create a deterministic pseudonym for a matched value.
    ///
    /// The map is keyed by the original value only, so if two different
    /// rules match the same string the first rule's prefix wins.
    fn pseudonym(&mut self, original: &str, prefix: &str) -> String {
        if let Some(existing) = self.pseudonym_map.get(original) {
            return existing.clone();
        }
        // Per-prefix counter makes replacements like <PATH_1>, <PATH_2>, ...
        let counter = self.counters.entry(prefix.to_string()).or_insert(0);
        *counter += 1;
        let pseudo = format!("<{}_{}>", prefix, counter);
        self.pseudonym_map.insert(original.to_string(), pseudo.clone());
        pseudo
    }

    /// Stage 1+2: Detect and redact PII in a single string.
    /// Returns (redacted_string, list of (category, rule_id, count) tuples).
    ///
    /// Rules are applied in order (built-ins first, then custom), each
    /// scanning the partially-redacted result of the previous rules; the
    /// index loop exists because `self.pseudonym(&mut self)` cannot be
    /// called while borrowing a rule from `self.rules`.
    fn strip_string(&mut self, input: &str) -> (String, Vec<(String, String, u32)>) {
        let mut result = input.to_string();
        let mut detections: Vec<(String, String, u32)> = Vec::new();

        let num_builtin = self.rules.len();
        let num_custom = self.custom_rules.len();

        for i in 0..(num_builtin + num_custom) {
            // Copy out the borrowed rule pieces so `self` is free for
            // the `pseudonym` call below.
            let (pattern, prefix, category, rule_id) = if i < num_builtin {
                let r = &self.rules[i];
                (&r.pattern as &Regex, r.prefix, r.category, r.rule_id)
            } else {
                let r = &self.custom_rules[i - num_builtin];
                (&r.pattern as &Regex, r.prefix, r.category, r.rule_id)
            };
            let matches: Vec<String> = pattern.find_iter(&result).map(|m| m.as_str().to_string()).collect();
            if matches.is_empty() {
                continue;
            }
            let count = matches.len() as u32;
            // Build pseudonyms and perform replacements.
            let mut replacements: Vec<(String, String)> = Vec::new();
            for m in &matches {
                let pseudo = self.pseudonym(m, prefix);
                replacements.push((m.clone(), pseudo));
            }
            // NOTE(review): `str::replace` is global — every occurrence of
            // a matched value is replaced, including ones the regex did not
            // anchor at; this is relied on for structural consistency.
            for (original, pseudo) in &replacements {
                result = result.replace(original.as_str(), pseudo.as_str());
            }
            detections.push((category.to_string(), rule_id.to_string(), count));
        }

        (result, detections)
    }

    /// Strip PII from a collection of named string fields.
    ///
    /// Returns the redacted fields (in input order) and a `RedactionLog`
    /// attestation whose `pre_redaction_hash` is a SHAKE-256 digest of the
    /// *original* field names and values.
    pub fn strip_fields(&mut self, fields: &[(&str, &str)]) -> (Vec<(String, String)>, RedactionLog) {
        // Stage 1+2: Detect and redact.
        let mut redacted_fields = Vec::new();
        let mut all_detections: HashMap<(String, String), u32> = HashMap::new();

        // Compute pre-redaction hash (Stage 3 prep).
        let mut hasher = Shake256::default();
        for (name, value) in fields {
            hasher.update(name.as_bytes());
            hasher.update(value.as_bytes());
        }
        let mut pre_hash = [0u8; 32];
        hasher.finalize_xof().read(&mut pre_hash);

        for (name, value) in fields {
            let (redacted, detections) = self.strip_string(value);
            redacted_fields.push((name.to_string(), redacted));
            // Aggregate per-(category, rule) counts across all fields.
            for (cat, rule, count) in detections {
                *all_detections.entry((cat, rule)).or_insert(0) += count;
            }
        }

        // Stage 3: Build attestation.
        let mut log = RedactionLog {
            entries: Vec::new(),
            pre_redaction_hash: pre_hash,
            fields_scanned: fields.len() as u64,
            total_redactions: 0,
            timestamp_s: 0, // caller should set this
        };

        // NOTE(review): entry order follows HashMap iteration and is
        // therefore unspecified.
        for ((category, rule_id), count) in &all_detections {
            log.entries.push(RedactionEntry {
                category: category.clone(),
                count: *count,
                rule_id: rule_id.clone(),
            });
            log.total_redactions += *count as u64;
        }

        (redacted_fields, log)
    }

    /// Strip PII from a single string value, discarding detection details.
    pub fn strip_value(&mut self, input: &str) -> String {
        let (result, _) = self.strip_string(input);
        result
    }

    /// Check if a string contains any detectable PII (built-in or custom rules).
    // NOTE(review): collecting the rule refs into a Vec is unnecessary;
    // `.chain(...).any(...)` over the iterators would do the same work.
    pub fn contains_pii(&self, input: &str) -> bool {
        let all_rules: Vec<&PiiRule> = self.rules.iter().chain(self.custom_rules.iter()).collect();
        for rule in all_rules {
            if rule.pattern.is_match(input) {
                return true;
            }
        }
        false
    }

    /// Return the current pseudonym map (for debugging/auditing).
    pub fn pseudonym_map(&self) -> &HashMap<String, String> {
        &self.pseudonym_map
    }
}
|
||||
|
||||
impl Default for PiiStripper {
    /// Equivalent to [`PiiStripper::new`]: default rules, empty state.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_unix_paths() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("/home/user/project/src/main.rs"));
        assert!(stripper.contains_pii("/Users/alice/.ssh/id_rsa"));
    }

    #[test]
    fn detect_ipv4() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("connecting to 192.168.1.100:8080"));
        assert!(stripper.contains_pii("server at 10.0.0.1"));
    }

    #[test]
    fn detect_emails() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("contact user@example.com for help"));
    }

    #[test]
    fn detect_api_keys() {
        let stripper = PiiStripper::new();
        // One example per built-in key format: sk-, AKIA, ghp_.
        assert!(stripper.contains_pii("key: sk-abcdefghijklmnopqrstuv"));
        assert!(stripper.contains_pii("aws: AKIAIOSFODNN7EXAMPLE"));
        assert!(stripper.contains_pii("token: ghp_abcdefghijklmnopqrstuvwxyz0123456789"));
    }

    #[test]
    fn detect_env_vars() {
        let stripper = PiiStripper::new();
        assert!(stripper.contains_pii("path is $HOME/.config"));
        assert!(stripper.contains_pii("dir is %USERPROFILE%\\Desktop"));
    }

    #[test]
    fn redact_preserves_structure() {
        let mut stripper = PiiStripper::new();
        let input1 = "file at /home/alice/project/a.rs";
        let input2 = "also at /home/alice/project/b.rs";
        let r1 = stripper.strip_value(input1);
        let r2 = stripper.strip_value(input2);
        // Same path prefix should get same pseudonym.
        assert!(r1.contains("<PATH_"));
        assert!(r2.contains("<PATH_"));
        assert!(!r1.contains("/home/alice"));
        assert!(!r2.contains("/home/alice"));
    }

    #[test]
    fn strip_fields_produces_redaction_log() {
        let mut stripper = PiiStripper::new();
        let fields = vec![
            ("path_field", "/home/user/data.csv"),
            ("ip_field", "connecting to 10.0.0.1"),
            ("clean_field", "no pii here"),
        ];
        let (redacted, log) = stripper.strip_fields(&fields);
        assert_eq!(redacted.len(), 3);
        assert_eq!(log.fields_scanned, 3);
        // At least the path and the IP must have been redacted.
        assert!(log.total_redactions >= 2);
        // The pre-redaction hash must have been written (not all zeros).
        assert!(log.pre_redaction_hash != [0u8; 32]);
        // clean field should be unchanged
        assert_eq!(redacted[2].1, "no pii here");
    }

    #[test]
    fn no_pii_returns_clean() {
        let stripper = PiiStripper::new();
        assert!(!stripper.contains_pii("just a normal string"));
        assert!(!stripper.contains_pii("alpha = 10.5, beta = 3.2"));
    }

    #[test]
    fn reset_clears_state() {
        let mut stripper = PiiStripper::new();
        stripper.strip_value("/home/user/test");
        assert!(!stripper.pseudonym_map().is_empty());
        stripper.reset();
        assert!(stripper.pseudonym_map().is_empty());
    }

    #[test]
    fn custom_rule() {
        let mut stripper = PiiStripper::new();
        // Custom rules participate in both detection and redaction.
        stripper.add_rule("ssn", "rule_ssn", r"\b\d{3}-\d{2}-\d{4}\b", "SSN").unwrap();
        assert!(stripper.contains_pii("ssn: 123-45-6789"));
        let redacted = stripper.strip_value("ssn: 123-45-6789");
        assert!(redacted.contains("<SSN_"));
        assert!(!redacted.contains("123-45-6789"));
    }
}
|
||||
193
vendor/ruvector/crates/rvf/rvf-federation/src/policy.rs
vendored
Normal file
193
vendor/ruvector/crates/rvf/rvf-federation/src/policy.rs
vendored
Normal file
@@ -0,0 +1,193 @@
|
||||
//! Federation policy for selective sharing.
|
||||
//!
|
||||
//! Controls what learning is exported, quality gates, rate limits,
|
||||
//! and privacy budget constraints.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Controls what a user shares in federated exports.
///
/// Deny-lists always win over allow-lists; an empty allow-list means
/// "everything allowed" (see `is_segment_allowed` / `is_domain_allowed`).
#[derive(Clone, Debug)]
pub struct FederationPolicy {
    /// Segment types allowed for export (empty = all allowed).
    pub allowed_segments: HashSet<u8>,
    /// Segment types explicitly denied for export.
    pub denied_segments: HashSet<u8>,
    /// Domain IDs allowed for export (empty = all allowed).
    pub allowed_domains: HashSet<String>,
    /// Domain IDs denied for export.
    pub denied_domains: HashSet<String>,
    /// Minimum quality score for exported trajectories (0.0 - 1.0).
    pub quality_threshold: f64,
    /// Minimum observations per prior entry for export.
    pub min_observations: u64,
    /// Maximum exports per hour.
    pub max_exports_per_hour: u32,
    /// Maximum cumulative privacy budget (epsilon).
    pub privacy_budget_limit: f64,
    /// Whether to include policy kernel snapshots.
    pub export_kernels: bool,
    /// Whether to include cost curve data.
    pub export_cost_curves: bool,
}
|
||||
|
||||
impl Default for FederationPolicy {
    /// Balanced defaults: empty allow/deny sets (empty allow-lists mean
    /// "allow everything"), a moderate quality bar, and kernel/cost-curve
    /// export enabled.
    fn default() -> Self {
        Self {
            allowed_segments: HashSet::new(),
            denied_segments: HashSet::new(),
            allowed_domains: HashSet::new(),
            denied_domains: HashSet::new(),
            quality_threshold: 0.5,
            min_observations: 12,
            max_exports_per_hour: 100,
            privacy_budget_limit: 10.0,
            export_kernels: true,
            export_cost_curves: true,
        }
    }
}
|
||||
|
||||
impl FederationPolicy {
|
||||
/// Create a restrictive policy (deny all by default).
|
||||
pub fn restrictive() -> Self {
|
||||
Self {
|
||||
quality_threshold: 0.8,
|
||||
min_observations: 50,
|
||||
max_exports_per_hour: 10,
|
||||
privacy_budget_limit: 5.0,
|
||||
export_kernels: false,
|
||||
export_cost_curves: false,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permissive policy (share everything).
|
||||
pub fn permissive() -> Self {
|
||||
Self {
|
||||
quality_threshold: 0.0,
|
||||
min_observations: 1,
|
||||
max_exports_per_hour: 1000,
|
||||
privacy_budget_limit: 100.0,
|
||||
export_kernels: true,
|
||||
export_cost_curves: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a segment type is allowed for export.
|
||||
pub fn is_segment_allowed(&self, seg_type: u8) -> bool {
|
||||
if self.denied_segments.contains(&seg_type) {
|
||||
return false;
|
||||
}
|
||||
if self.allowed_segments.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.allowed_segments.contains(&seg_type)
|
||||
}
|
||||
|
||||
/// Check if a domain is allowed for export.
|
||||
pub fn is_domain_allowed(&self, domain_id: &str) -> bool {
|
||||
if self.denied_domains.contains(domain_id) {
|
||||
return false;
|
||||
}
|
||||
if self.allowed_domains.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.allowed_domains.contains(domain_id)
|
||||
}
|
||||
|
||||
/// Allow a specific segment type.
|
||||
pub fn allow_segment(mut self, seg_type: u8) -> Self {
|
||||
self.allowed_segments.insert(seg_type);
|
||||
self
|
||||
}
|
||||
|
||||
/// Deny a specific segment type.
|
||||
pub fn deny_segment(mut self, seg_type: u8) -> Self {
|
||||
self.denied_segments.insert(seg_type);
|
||||
self
|
||||
}
|
||||
|
||||
/// Allow a specific domain.
|
||||
pub fn allow_domain(mut self, domain_id: &str) -> Self {
|
||||
self.allowed_domains.insert(domain_id.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Deny a specific domain.
|
||||
pub fn deny_domain(mut self, domain_id: &str) -> Self {
|
||||
self.denied_domains.insert(domain_id.to_string());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_policy() {
        // Out of the box: moderate thresholds, nothing filtered.
        let policy = FederationPolicy::default();
        assert_eq!(policy.quality_threshold, 0.5);
        assert_eq!(policy.min_observations, 12);
        assert!(policy.is_segment_allowed(0x33));
        assert!(policy.is_domain_allowed("anything"));
    }

    #[test]
    fn restrictive_policy() {
        let policy = FederationPolicy::restrictive();
        assert_eq!(policy.quality_threshold, 0.8);
        assert_eq!(policy.min_observations, 50);
        assert!(!policy.export_kernels);
        assert!(!policy.export_cost_curves);
    }

    #[test]
    fn permissive_policy() {
        let policy = FederationPolicy::permissive();
        assert_eq!(policy.quality_threshold, 0.0);
        assert_eq!(policy.min_observations, 1);
    }

    #[test]
    fn segment_allowlist() {
        let policy = FederationPolicy::default()
            .allow_segment(0x33)
            .allow_segment(0x34);
        assert!(policy.is_segment_allowed(0x33));
        assert!(policy.is_segment_allowed(0x34));
        // A non-empty allowlist excludes everything not listed.
        assert!(!policy.is_segment_allowed(0x35));
    }

    #[test]
    fn segment_denylist() {
        let policy = FederationPolicy::default().deny_segment(0x36);
        assert!(policy.is_segment_allowed(0x33));
        assert!(!policy.is_segment_allowed(0x36));
    }

    #[test]
    fn deny_takes_precedence() {
        // Denying a segment beats allowing the same segment.
        let policy = FederationPolicy::default()
            .allow_segment(0x33)
            .deny_segment(0x33);
        assert!(!policy.is_segment_allowed(0x33));
    }

    #[test]
    fn domain_filtering() {
        let policy = FederationPolicy::default()
            .allow_domain("genomics")
            .deny_domain("secret_project");
        assert!(policy.is_domain_allowed("genomics"));
        assert!(!policy.is_domain_allowed("secret_project"));
        // Not listed in the (non-empty) allowlist.
        assert!(!policy.is_domain_allowed("trading"));
    }

    #[test]
    fn empty_allowlist_allows_all() {
        let policy = FederationPolicy::default();
        assert!(policy.is_segment_allowed(0x33));
        assert!(policy.is_segment_allowed(0xFF));
        assert!(policy.is_domain_allowed("any_domain"));
    }
}
|
||||
426
vendor/ruvector/crates/rvf/rvf-federation/src/types.rs
vendored
Normal file
426
vendor/ruvector/crates/rvf/rvf-federation/src/types.rs
vendored
Normal file
@@ -0,0 +1,426 @@
|
||||
//! Federation segment payload types.
|
||||
//!
|
||||
//! Four new RVF segment types (0x33-0x36) defined in ADR-057.
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ── Segment type constants ──────────────────────────────────────────
//
// RVF segment discriminators 0x33-0x36 are reserved for the federation
// payloads defined in this module (per ADR-057, see module docs).

/// Segment type discriminator for FederatedManifest.
pub const SEG_FEDERATED_MANIFEST: u8 = 0x33;
/// Segment type discriminator for DiffPrivacyProof.
pub const SEG_DIFF_PRIVACY_PROOF: u8 = 0x34;
/// Segment type discriminator for RedactionLog.
pub const SEG_REDACTION_LOG: u8 = 0x35;
/// Segment type discriminator for AggregateWeights.
pub const SEG_AGGREGATE_WEIGHTS: u8 = 0x36;
|
||||
|
||||
// ── FederatedManifest (0x33) ────────────────────────────────────────

/// Describes a federated learning export.
///
/// Attached as the first segment in every federation RVF file.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FederatedManifest {
    /// Format version (currently 1).
    pub format_version: u32,
    /// Pseudonym of the contributor (never the real identity).
    pub contributor_pseudonym: String,
    /// UNIX timestamp (seconds) when the export was created.
    pub export_timestamp_s: u64,
    /// Segment IDs included in this export.
    pub included_segment_ids: Vec<u64>,
    /// Cumulative differential privacy budget spent (epsilon).
    pub privacy_budget_spent: f64,
    /// Domain identifier this export applies to.
    pub domain_id: String,
    /// RVF format version compatibility tag.
    pub rvf_version_tag: String,
    /// Number of trajectories summarized in the exported learning.
    pub trajectory_count: u64,
    /// Average quality score of exported trajectories.
    pub avg_quality_score: f64,
}

impl FederatedManifest {
    /// Build a fresh manifest for the given contributor/domain pair.
    ///
    /// All counters start at zero, no segment IDs are listed yet, and the
    /// RVF compatibility tag defaults to `"rvf-v1"`.
    pub fn new(contributor_pseudonym: String, domain_id: String) -> Self {
        Self {
            contributor_pseudonym,
            domain_id,
            format_version: 1,
            rvf_version_tag: "rvf-v1".to_string(),
            export_timestamp_s: 0,
            included_segment_ids: Vec::new(),
            privacy_budget_spent: 0.0,
            trajectory_count: 0,
            avg_quality_score: 0.0,
        }
    }
}
|
||||
|
||||
// ── DiffPrivacyProof (0x34) ─────────────────────────────────────────

/// Noise mechanism used for differential privacy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NoiseMechanism {
    /// Gaussian noise for (epsilon, delta)-DP.
    Gaussian,
    /// Laplace noise for epsilon-DP.
    Laplace,
}

/// Differential privacy attestation.
///
/// Records the privacy parameters and noise applied during export.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DiffPrivacyProof {
    /// Privacy loss parameter.
    pub epsilon: f64,
    /// Probability of privacy failure.
    pub delta: f64,
    /// Noise mechanism applied.
    pub mechanism: NoiseMechanism,
    /// L2 sensitivity bound used for noise calibration.
    pub sensitivity: f64,
    /// Gradient clipping norm.
    pub clipping_norm: f64,
    /// Noise scale (sigma for Gaussian, b for Laplace).
    pub noise_scale: f64,
    /// Number of parameters that had noise added.
    pub noised_parameter_count: u64,
}

impl DiffPrivacyProof {
    /// Create a proof for the Gaussian mechanism ((epsilon, delta)-DP).
    ///
    /// Noise is calibrated analytically:
    /// `sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon`.
    pub fn gaussian(epsilon: f64, delta: f64, sensitivity: f64, clipping_norm: f64) -> Self {
        let noise_scale = sensitivity * (2.0_f64 * (1.25_f64 / delta).ln()).sqrt() / epsilon;
        Self {
            mechanism: NoiseMechanism::Gaussian,
            epsilon,
            delta,
            sensitivity,
            clipping_norm,
            noise_scale,
            noised_parameter_count: 0,
        }
    }

    /// Create a proof for the Laplace mechanism (pure epsilon-DP, delta = 0).
    ///
    /// The scale parameter is `b = sensitivity / epsilon`.
    pub fn laplace(epsilon: f64, sensitivity: f64, clipping_norm: f64) -> Self {
        let noise_scale = sensitivity / epsilon;
        Self {
            mechanism: NoiseMechanism::Laplace,
            epsilon,
            delta: 0.0,
            sensitivity,
            clipping_norm,
            noise_scale,
            noised_parameter_count: 0,
        }
    }
}
|
||||
|
||||
// ── RedactionLog (0x35) ─────────────────────────────────────────────

/// A single redaction event.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RedactionEntry {
    /// Category of PII detected (e.g. "path", "ip", "email", "api_key").
    pub category: String,
    /// Number of occurrences redacted.
    pub count: u32,
    /// Rule identifier that triggered the redaction.
    pub rule_id: String,
}

/// PII stripping attestation.
///
/// Proves that PII scanning was performed without revealing the original content.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RedactionLog {
    /// Individual redaction entries by category.
    pub entries: Vec<RedactionEntry>,
    /// SHAKE-256 hash of the pre-redaction content (32 bytes).
    pub pre_redaction_hash: [u8; 32],
    /// Total number of fields scanned.
    pub fields_scanned: u64,
    /// Total number of redactions applied.
    pub total_redactions: u64,
    /// UNIX timestamp (seconds) when redaction was performed.
    pub timestamp_s: u64,
}

impl RedactionLog {
    /// Create an empty redaction log: no entries, zeroed hash and counters.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            pre_redaction_hash: [0u8; 32],
            fields_scanned: 0,
            total_redactions: 0,
            timestamp_s: 0,
        }
    }

    /// Record that `count` occurrences of `category` were redacted by
    /// `rule_id`, bumping the running total.
    pub fn add_entry(&mut self, category: &str, count: u32, rule_id: &str) {
        self.total_redactions += u64::from(count);
        let entry = RedactionEntry {
            category: category.to_owned(),
            count,
            rule_id: rule_id.to_owned(),
        };
        self.entries.push(entry);
    }
}

impl Default for RedactionLog {
    /// Same as [`RedactionLog::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ── AggregateWeights (0x36) ─────────────────────────────────────────

/// Federated-averaged weight vector with metadata.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AggregateWeights {
    /// Federated averaging round number.
    pub round: u64,
    /// Number of participants in this round.
    pub participation_count: u32,
    /// Aggregated LoRA delta weights (flattened).
    pub lora_deltas: Vec<f64>,
    /// Per-weight confidence scores.
    pub confidences: Vec<f64>,
    /// Mean loss across participants.
    pub mean_loss: f64,
    /// Loss variance across participants.
    pub loss_variance: f64,
    /// Domain identifier.
    pub domain_id: String,
    /// Whether Byzantine outlier removal was applied.
    pub byzantine_filtered: bool,
    /// Number of contributions removed as outliers.
    pub outliers_removed: u32,
}

impl AggregateWeights {
    /// Create empty aggregate weights for `domain_id` at the given round:
    /// no participants, no weights, all statistics zeroed.
    pub fn new(domain_id: String, round: u64) -> Self {
        Self {
            domain_id,
            round,
            participation_count: 0,
            lora_deltas: Vec::new(),
            confidences: Vec::new(),
            mean_loss: 0.0,
            loss_variance: 0.0,
            byzantine_filtered: false,
            outliers_removed: 0,
        }
    }
}
|
||||
|
||||
// ── BetaParams (local copy for federation) ──────────────────────────

/// Beta distribution parameters for Thompson Sampling priors.
///
/// Mirrors the type in `ruvector-domain-expansion` to avoid cross-crate dependency.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BetaParams {
    /// Alpha (success count + 1).
    pub alpha: f64,
    /// Beta (failure count + 1).
    pub beta: f64,
}

impl BetaParams {
    /// Create new Beta parameters.
    pub fn new(alpha: f64, beta: f64) -> Self {
        Self { alpha, beta }
    }

    /// Uniform (uninformative) Beta(1, 1) prior.
    pub fn uniform() -> Self {
        Self::new(1.0, 1.0)
    }

    /// Mean of the Beta distribution: `alpha / (alpha + beta)`.
    pub fn mean(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }

    /// Total observations (alpha + beta - 2 for a Beta(1,1) prior).
    pub fn observations(&self) -> f64 {
        self.alpha + self.beta - 2.0
    }

    /// Merge two Beta posteriors by summing parameters and subtracting one
    /// uniform prior, so the shared Beta(1,1) baseline is not double-counted.
    pub fn merge(&self, other: &BetaParams) -> BetaParams {
        BetaParams::new(
            self.alpha + other.alpha - 1.0,
            self.beta + other.beta - 1.0,
        )
    }

    /// Dampen this prior toward the uniform prior by linearly shrinking the
    /// accumulated evidence: `param = 1 + (param - 1) * factor`.
    ///
    /// `factor` is clamped to [0, 1]; 0 yields Beta(1,1), 1 leaves the prior
    /// unchanged. (An earlier doc described this as "sqrt-scaling", but the
    /// implementation is — and always was — linear interpolation.)
    pub fn dampen(&self, factor: f64) -> BetaParams {
        let f = factor.clamp(0.0, 1.0);
        BetaParams::new(
            1.0 + (self.alpha - 1.0) * f,
            1.0 + (self.beta - 1.0) * f,
        )
    }
}

impl Default for BetaParams {
    /// Uniform prior by default.
    fn default() -> Self {
        Self::uniform()
    }
}
|
||||
|
||||
// ── TransferPrior (local copy for federation) ───────────────────────

/// Compact summary of learned priors for a single context bucket.
///
/// Pairs a (bucket, arm) key with its Beta posterior and the sample size
/// that produced it.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TransferPriorEntry {
    /// Context bucket identifier.
    pub bucket_id: String,
    /// Arm identifier.
    pub arm_id: String,
    /// Beta posterior parameters.
    pub params: BetaParams,
    /// Number of observations backing this prior.
    pub observation_count: u64,
}

/// Collection of transfer priors from a trained domain.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TransferPriorSet {
    /// Source domain identifier.
    pub source_domain: String,
    /// Individual prior entries.
    pub entries: Vec<TransferPriorEntry>,
    /// EMA cost at time of extraction.
    pub cost_ema: f64,
}
|
||||
|
||||
// ── PolicyKernelSnapshot ────────────────────────────────────────────

/// Snapshot of a policy kernel configuration for federation export.
///
/// Plain data carrier; exported only when the federation policy permits
/// kernel snapshots.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PolicyKernelSnapshot {
    /// Kernel identifier.
    pub kernel_id: String,
    /// Tunable knob values.
    pub knobs: Vec<f64>,
    /// Fitness score.
    pub fitness: f64,
    /// Generation number.
    pub generation: u64,
}
|
||||
|
||||
// ── CostCurveSnapshot ───────────────────────────────────────────────

/// Snapshot of cost curve data for federation export.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CostCurveSnapshot {
    /// Domain identifier.
    pub domain_id: String,
    /// Ordered (step, cost) points.
    pub points: Vec<(u64, f64)>,
    /// Acceleration factor (> 1.0 means transfer helped).
    pub acceleration: f64,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn segment_type_constants() {
        // The four federation discriminators occupy 0x33..=0x36.
        assert_eq!(SEG_FEDERATED_MANIFEST, 0x33);
        assert_eq!(SEG_DIFF_PRIVACY_PROOF, 0x34);
        assert_eq!(SEG_REDACTION_LOG, 0x35);
        assert_eq!(SEG_AGGREGATE_WEIGHTS, 0x36);
    }

    #[test]
    fn federated_manifest_new() {
        let manifest = FederatedManifest::new("alice".into(), "genomics".into());
        assert_eq!(manifest.format_version, 1);
        assert_eq!(manifest.contributor_pseudonym, "alice");
        assert_eq!(manifest.domain_id, "genomics");
        assert_eq!(manifest.trajectory_count, 0);
    }

    #[test]
    fn diff_privacy_proof_gaussian() {
        let proof = DiffPrivacyProof::gaussian(1.0, 1e-5, 1.0, 1.0);
        assert_eq!(proof.mechanism, NoiseMechanism::Gaussian);
        assert_eq!(proof.epsilon, 1.0);
        assert!(proof.noise_scale > 0.0);
    }

    #[test]
    fn diff_privacy_proof_laplace() {
        // b = sensitivity / epsilon = 1.0 here.
        let proof = DiffPrivacyProof::laplace(1.0, 1.0, 1.0);
        assert_eq!(proof.mechanism, NoiseMechanism::Laplace);
        assert!((proof.noise_scale - 1.0).abs() < 1e-10);
    }

    #[test]
    fn redaction_log_add_entry() {
        let mut log = RedactionLog::new();
        log.add_entry("path", 3, "rule_path_unix");
        log.add_entry("ip", 2, "rule_ipv4");
        assert_eq!(log.entries.len(), 2);
        // Running total accumulates per-entry counts.
        assert_eq!(log.total_redactions, 5);
    }

    #[test]
    fn aggregate_weights_new() {
        let weights = AggregateWeights::new("code_review".into(), 1);
        assert_eq!(weights.round, 1);
        assert_eq!(weights.participation_count, 0);
        assert!(!weights.byzantine_filtered);
    }

    #[test]
    fn beta_params_merge() {
        let left = BetaParams::new(10.0, 5.0);
        let right = BetaParams::new(8.0, 3.0);
        // Sum minus one shared uniform prior: (10+8-1, 5+3-1).
        let merged = left.merge(&right);
        assert!((merged.alpha - 17.0).abs() < 1e-10);
        assert!((merged.beta - 7.0).abs() < 1e-10);
    }

    #[test]
    fn beta_params_dampen() {
        let prior = BetaParams::new(10.0, 5.0);
        let dampened = prior.dampen(0.25);
        // alpha = 1 + (10-1)*0.25 = 3.25; beta = 1 + (5-1)*0.25 = 2.0
        assert!((dampened.alpha - 3.25).abs() < 1e-10);
        assert!((dampened.beta - 2.0).abs() < 1e-10);
    }

    #[test]
    fn beta_params_mean() {
        let symmetric = BetaParams::new(10.0, 10.0);
        assert!((symmetric.mean() - 0.5).abs() < 1e-10);
    }
}
|
||||
Reference in New Issue
Block a user