// NOTE(review): the lines below are file-listing residue from an extraction
// tool (not Rust source); preserved here as a comment so they cannot be
// mistaken for code:
//   wifi-densepose/vendor/ruvector/crates/rvf/rvf-federation/src/aggregate.rs
//   421 lines, 15 KiB, Rust
//! Federated aggregation: FedAvg, FedProx, Byzantine-tolerant weighted averaging.
use crate::error::FederationError;
use crate::types::AggregateWeights;
/// Aggregation strategy used when combining a round's contributions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Federated Averaging (McMahan et al., 2017): contributions are
    /// weighted by their trajectory counts.
    FedAvg,
    /// Federated Proximal (Li et al., 2020). `mu` is the proximal
    /// coefficient expressed in hundredths (e.g. `mu: 50` acts as 0.5);
    /// the aggregator divides by 100 before applying it.
    FedProx { mu: u32 },
    /// Simple weighted average by each contribution's quality weight.
    WeightedAverage,
}
impl Default for AggregationStrategy {
fn default() -> Self {
Self::FedAvg
}
}
/// A single contribution to a federated averaging round.
#[derive(Clone, Debug)]
pub struct Contribution {
    /// Contributor pseudonym.
    pub contributor: String,
    /// Weight vector (LoRA deltas).
    pub weights: Vec<f64>,
    /// Quality/reputation weight for this contributor. The aggregator
    /// clamps it into [0, 1] when deriving loss statistics, so values are
    /// expected in that range.
    pub quality_weight: f64,
    /// Number of training trajectories behind this contribution; used as
    /// the weighting mass under the FedAvg strategy.
    pub trajectory_count: u64,
}
/// Federated aggregation server.
///
/// Collects [`Contribution`]s for the current round and, once enough have
/// arrived, combines them into a single weight update.
pub struct FederatedAggregator {
    /// Aggregation strategy.
    strategy: AggregationStrategy,
    /// Domain identifier, copied into every emitted aggregate.
    domain_id: String,
    /// Current round number; starts at 0 and is incremented on each
    /// successful aggregation (so the first emitted round is 1).
    round: u64,
    /// Minimum contributions required for a round (default 2).
    min_contributions: usize,
    /// Standard deviation threshold for Byzantine outlier detection
    /// (default 2.0): contributions whose L2-norm z-score exceeds this
    /// are dropped before aggregation.
    byzantine_std_threshold: f64,
    /// Collected contributions for the current round; cleared after each
    /// successful aggregation.
    contributions: Vec<Contribution>,
}
impl FederatedAggregator {
    /// Create a new aggregator for `domain_id` using `strategy`.
    ///
    /// Defaults: minimum of 2 contributions per round and a 2.0 standard
    /// deviation Byzantine filtering threshold.
    pub fn new(domain_id: String, strategy: AggregationStrategy) -> Self {
        Self {
            strategy,
            domain_id,
            round: 0,
            min_contributions: 2,
            byzantine_std_threshold: 2.0,
            contributions: Vec::new(),
        }
    }

    /// Set minimum contributions required before `aggregate` succeeds.
    pub fn with_min_contributions(mut self, min: usize) -> Self {
        self.min_contributions = min;
        self
    }

    /// Set the Byzantine outlier threshold (in standard deviations of the
    /// contribution L2 norms).
    pub fn with_byzantine_threshold(mut self, threshold: f64) -> Self {
        self.byzantine_std_threshold = threshold;
        self
    }

    /// Add a contribution for the current round.
    pub fn add_contribution(&mut self, contribution: Contribution) {
        self.contributions.push(contribution);
    }

    /// Number of contributions collected so far this round.
    pub fn contribution_count(&self) -> usize {
        self.contributions.len()
    }

    /// Current round number (incremented on each successful `aggregate`).
    pub fn round(&self) -> u64 {
        self.round
    }

    /// Check if we have enough contributions to aggregate.
    pub fn ready(&self) -> bool {
        self.contributions.len() >= self.min_contributions
    }

    /// L2 norm of a weight vector.
    fn l2_norm(weights: &[f64]) -> f64 {
        weights.iter().map(|w| w * w).sum::<f64>().sqrt()
    }

    /// Detect and remove Byzantine outliers by the z-score of each
    /// contribution's L2 norm.
    ///
    /// Returns the number of outliers removed. Filtering is skipped
    /// (returns 0) when there are fewer than 3 contributions, when weight
    /// dimensions are empty or inconsistent, or when the norms have
    /// (near-)zero spread.
    fn remove_byzantine_outliers(&mut self) -> u32 {
        if self.contributions.len() < 3 {
            return 0; // Need at least 3 for meaningful outlier detection
        }
        let dim = self.contributions[0].weights.len();
        if dim == 0 || !self.contributions.iter().all(|c| c.weights.len() == dim) {
            return 0; // Inconsistent shapes: don't guess which is "right"
        }
        // Compute each norm exactly once, up front. (The previous version
        // recomputed every norm a second time inside `retain`.)
        let norms: Vec<f64> = self
            .contributions
            .iter()
            .map(|c| Self::l2_norm(&c.weights))
            .collect();
        let mean_norm = norms.iter().sum::<f64>() / norms.len() as f64;
        let variance =
            norms.iter().map(|n| (n - mean_norm).powi(2)).sum::<f64>() / norms.len() as f64;
        let std_dev = variance.sqrt();
        if std_dev < 1e-10 {
            return 0; // Degenerate spread: every z-score would blow up
        }
        let original_count = self.contributions.len();
        let threshold = self.byzantine_std_threshold;
        // `Vec::retain` visits elements in order, so a running index pairs
        // each contribution with its precomputed norm.
        let mut idx = 0;
        self.contributions.retain(|_| {
            let keep = ((norms[idx] - mean_norm) / std_dev).abs() <= threshold;
            idx += 1;
            keep
        });
        (original_count - self.contributions.len()) as u32
    }

    /// Aggregate contributions and produce an `AggregateWeights` segment.
    ///
    /// Applies Byzantine filtering first, then the configured strategy,
    /// then clears the contribution buffer and advances the round counter.
    ///
    /// # Errors
    /// Returns `FederationError::InsufficientContributions` when fewer
    /// than the configured minimum were collected, or when filtering
    /// removed every contribution.
    pub fn aggregate(&mut self) -> Result<AggregateWeights, FederationError> {
        if self.contributions.len() < self.min_contributions {
            return Err(FederationError::InsufficientContributions {
                min: self.min_contributions,
                got: self.contributions.len(),
            });
        }
        // Byzantine outlier removal (may shrink the contribution set).
        let outliers_removed = self.remove_byzantine_outliers();
        if self.contributions.is_empty() {
            return Err(FederationError::InsufficientContributions {
                min: self.min_contributions,
                got: 0,
            });
        }
        let dim = self.contributions[0].weights.len();
        let (lora_deltas, confidences) = match self.strategy {
            AggregationStrategy::FedAvg => self.fedavg(dim),
            // `mu` is stored in hundredths; convert to the real coefficient.
            AggregationStrategy::FedProx { mu } => self.fedprox(dim, mu as f64 / 100.0),
            AggregationStrategy::WeightedAverage => self.weighted_avg(dim),
        };
        self.round += 1;
        let participation_count = self.contributions.len() as u32;
        // Loss stats: inverse (clamped) quality serves as a proxy for the
        // per-contributor loss.
        let losses: Vec<f64> = self
            .contributions
            .iter()
            .map(|c| 1.0 - c.quality_weight.clamp(0.0, 1.0))
            .collect();
        let mean_loss = losses.iter().sum::<f64>() / losses.len() as f64;
        let loss_variance =
            losses.iter().map(|l| (l - mean_loss).powi(2)).sum::<f64>() / losses.len() as f64;
        self.contributions.clear();
        Ok(AggregateWeights {
            round: self.round,
            participation_count,
            lora_deltas,
            confidences,
            mean_loss,
            loss_variance,
            domain_id: self.domain_id.clone(),
            byzantine_filtered: outliers_removed > 0,
            outliers_removed,
        })
    }

    /// Shared core of FedAvg and quality-weighted averaging: a normalized
    /// weighted mean plus per-dimension confidences. `weight_of` supplies
    /// each contribution's unnormalized weight.
    ///
    /// When the total weight is not positive, both vectors are all zeros
    /// (matching the original per-strategy early return).
    fn weighted_with_confidences(
        &self,
        dim: usize,
        weight_of: impl Fn(&Contribution) -> f64,
    ) -> (Vec<f64>, Vec<f64>) {
        let total: f64 = self.contributions.iter().map(|c| weight_of(c)).sum();
        let mut avg = vec![0.0f64; dim];
        if total <= 0.0 {
            return (avg, vec![0.0f64; dim]);
        }
        for c in &self.contributions {
            let w = weight_of(c) / total;
            // Ignore any trailing entries beyond `dim` (shape-mismatch guard).
            for (i, val) in c.weights.iter().enumerate().take(dim) {
                avg[i] += w * val;
            }
        }
        let confidences = self.confidences_around(&avg);
        (avg, confidences)
    }

    /// Per-dimension confidence = 1 / (1 + variance) of the raw
    /// contributions around `mean`, so unanimous agreement yields 1.0.
    /// Entries missing from short weight vectors count as 0.0, as in the
    /// original per-strategy loops.
    fn confidences_around(&self, mean: &[f64]) -> Vec<f64> {
        let n = self.contributions.len() as f64;
        mean.iter()
            .enumerate()
            .map(|(i, &m)| {
                let var: f64 = self
                    .contributions
                    .iter()
                    .map(|c| {
                        let v = c.weights.get(i).copied().unwrap_or(0.0);
                        (v - m).powi(2)
                    })
                    .sum::<f64>()
                    / n;
                1.0 / (1.0 + var)
            })
            .collect()
    }

    /// FedAvg: weighted average by trajectory count.
    fn fedavg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
        self.weighted_with_confidences(dim, |c| c.trajectory_count as f64)
    }

    /// FedProx: FedAvg followed by a proximal shrink toward zero (the
    /// global model), scaling every delta by 1 / (1 + mu).
    fn fedprox(&self, dim: usize, mu: f64) -> (Vec<f64>, Vec<f64>) {
        let (mut avg, confidences) = self.fedavg(dim);
        let shrink = 1.0 / (1.0 + mu);
        for val in &mut avg {
            *val *= shrink;
        }
        (avg, confidences)
    }

    /// Weighted average by `quality_weight`.
    fn weighted_avg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
        self.weighted_with_confidences(dim, |c| c.quality_weight)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience builder for a test contribution.
    fn contrib(name: &str, weights: Vec<f64>, quality: f64, trajectories: u64) -> Contribution {
        Contribution {
            contributor: String::from(name),
            weights,
            quality_weight: quality,
            trajectory_count: trajectories,
        }
    }

    /// Fresh FedAvg aggregator requiring two contributions.
    fn fedavg_pair() -> FederatedAggregator {
        FederatedAggregator::new(String::from("test"), AggregationStrategy::FedAvg)
            .with_min_contributions(2)
    }

    #[test]
    fn fedavg_two_equal_contributions() {
        let mut agg = fedavg_pair();
        agg.add_contribution(contrib("a", vec![1.0, 2.0, 3.0], 1.0, 100));
        agg.add_contribution(contrib("b", vec![3.0, 4.0, 5.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        assert_eq!(result.round, 1);
        assert_eq!(result.participation_count, 2);
        // Equal trajectory mass => plain element-wise mean.
        for (got, want) in result.lora_deltas.iter().zip([2.0, 3.0, 4.0]) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn fedavg_weighted_by_trajectories() {
        let mut agg = fedavg_pair();
        // A carries 3x the trajectories, so A's value dominates:
        // (300*10 + 100*0) / 400 = 7.5
        agg.add_contribution(contrib("a", vec![10.0], 1.0, 300));
        agg.add_contribution(contrib("b", vec![0.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        assert!((result.lora_deltas[0] - 7.5).abs() < 1e-10);
    }

    #[test]
    fn fedprox_shrinks_toward_zero() {
        let mut agg_avg = fedavg_pair();
        agg_avg.add_contribution(contrib("a", vec![10.0], 1.0, 100));
        agg_avg.add_contribution(contrib("b", vec![10.0], 1.0, 100));
        let avg_result = agg_avg.aggregate().unwrap();

        let mut agg_prox =
            FederatedAggregator::new(String::from("test"), AggregationStrategy::FedProx { mu: 50 })
                .with_min_contributions(2);
        agg_prox.add_contribution(contrib("a", vec![10.0], 1.0, 100));
        agg_prox.add_contribution(contrib("b", vec![10.0], 1.0, 100));
        let prox_result = agg_prox.aggregate().unwrap();

        // The proximal term pulls the deltas toward zero, so FedProx must
        // land strictly below plain FedAvg.
        assert!(prox_result.lora_deltas[0] < avg_result.lora_deltas[0]);
    }

    #[test]
    fn byzantine_outlier_removal() {
        let mut agg = fedavg_pair().with_byzantine_threshold(2.0);
        // Need enough honest contributions that the attacker's z-score
        // exceeds 2.0: with k honest + 1 evil, it grows roughly as sqrt(k).
        let honest = [
            ("good1", vec![1.0, 1.0]),
            ("good2", vec![1.1, 0.9]),
            ("good3", vec![0.9, 1.1]),
            ("good4", vec![1.0, 1.0]),
            ("good5", vec![1.0, 1.0]),
            ("good6", vec![1.0, 1.0]),
        ];
        for (name, weights) in honest {
            agg.add_contribution(contrib(name, weights, 1.0, 100));
        }
        agg.add_contribution(contrib("evil", vec![100.0, 100.0], 1.0, 100)); // outlier
        let result = agg.aggregate().unwrap();
        assert!(result.byzantine_filtered);
        assert!(result.outliers_removed >= 1);
        // The aggregate should stay near 1.0, not be dragged toward 100.
        assert!(result.lora_deltas[0] < 5.0);
    }

    #[test]
    fn insufficient_contributions_error() {
        let mut agg = FederatedAggregator::new(String::from("test"), AggregationStrategy::FedAvg)
            .with_min_contributions(3);
        agg.add_contribution(contrib("a", vec![1.0], 1.0, 100));
        assert!(agg.aggregate().is_err());
    }

    #[test]
    fn weighted_average_strategy() {
        let mut agg =
            FederatedAggregator::new(String::from("test"), AggregationStrategy::WeightedAverage)
                .with_min_contributions(2);
        agg.add_contribution(contrib("a", vec![10.0], 0.9, 10));
        agg.add_contribution(contrib("b", vec![0.0], 0.1, 10));
        let result = agg.aggregate().unwrap();
        // (0.9*10 + 0.1*0) / 1.0 = 9.0
        assert!((result.lora_deltas[0] - 9.0).abs() < 1e-10);
    }

    #[test]
    fn round_increments() {
        let mut agg = fedavg_pair();
        agg.add_contribution(contrib("a", vec![1.0], 1.0, 100));
        agg.add_contribution(contrib("b", vec![2.0], 1.0, 100));
        assert_eq!(agg.aggregate().unwrap().round, 1);
        agg.add_contribution(contrib("a", vec![3.0], 1.0, 100));
        agg.add_contribution(contrib("b", vec![4.0], 1.0, 100));
        assert_eq!(agg.aggregate().unwrap().round, 2);
    }

    #[test]
    fn confidences_high_when_agreement() {
        let mut agg = fedavg_pair();
        agg.add_contribution(contrib("a", vec![1.0], 1.0, 100));
        agg.add_contribution(contrib("b", vec![1.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        // Perfect agreement: variance = 0, so confidence = 1/(1+0) = 1.0.
        assert!((result.confidences[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn confidences_lower_when_disagreement() {
        let mut agg = fedavg_pair();
        agg.add_contribution(contrib("a", vec![0.0], 1.0, 100));
        agg.add_contribution(contrib("b", vec![10.0], 1.0, 100));
        let result = agg.aggregate().unwrap();
        // Disagreement inflates per-dimension variance, lowering confidence.
        assert!(result.confidences[0] < 1.0);
    }
}