Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
482
vendor/ruvector/crates/ruvector-domain-expansion/src/cost_curve.rs
vendored
Normal file
482
vendor/ruvector/crates/ruvector-domain-expansion/src/cost_curve.rs
vendored
Normal file
@@ -0,0 +1,482 @@
|
||||
//! Cost Curve Compression Tracker and Acceleration Scoreboard
|
||||
//!
|
||||
//! Measures whether cost curves compress faster in each new domain.
|
||||
//! If they do, you are increasing general problem-solving capability.
|
||||
//!
|
||||
//! ## Acceptance Test
|
||||
//!
|
||||
//! Domain 2 must converge faster than Domain 1.
|
||||
//! Measure cycles to reach:
|
||||
//! - 95% accuracy
|
||||
//! - Target cost per solve
|
||||
//! - Target robustness
|
||||
//! - Zero policy violations
|
||||
|
||||
use crate::domain::DomainId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A single data point on the cost curve.
///
/// NOTE: `CostCurve::auc_accuracy` computes `u64` differences of consecutive
/// `cycle` values, so points should be recorded in nondecreasing cycle order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostCurvePoint {
    /// Cycle number (training iteration).
    pub cycle: u64,
    /// Current accuracy [0.0, 1.0].
    pub accuracy: f32,
    /// Cost per solve at this point.
    pub cost_per_solve: f32,
    /// Robustness score [0.0, 1.0].
    pub robustness: f32,
    /// Number of policy violations in this cycle.
    pub policy_violations: u32,
    /// Wall-clock timestamp (seconds since epoch).
    pub timestamp: f64,
}
|
||||
|
||||
/// Convergence thresholds for the acceptance test.
///
/// A point converges when accuracy and robustness are at or above their
/// targets, cost is at or below `target_cost`, and policy violations do not
/// exceed `max_violations` (see `CostCurve::has_converged`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceThresholds {
    /// Target accuracy (default: 0.95).
    pub target_accuracy: f32,
    /// Target cost per solve (default: 0.01).
    pub target_cost: f32,
    /// Target robustness (default: 0.90).
    pub target_robustness: f32,
    /// Maximum allowed policy violations (default: 0).
    pub max_violations: u32,
}
|
||||
|
||||
impl Default for ConvergenceThresholds {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
target_accuracy: 0.95,
|
||||
target_cost: 0.01,
|
||||
target_robustness: 0.90,
|
||||
max_violations: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cost curve for a single domain, tracking convergence over cycles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostCurve {
    /// Domain this curve belongs to.
    pub domain_id: DomainId,
    /// Whether this was trained with transfer priors.
    pub used_transfer: bool,
    /// Source domain for transfer (if any).
    pub transfer_source: Option<DomainId>,
    /// Ordered data points (expected in nondecreasing `cycle` order).
    pub points: Vec<CostCurvePoint>,
    /// Convergence thresholds.
    pub thresholds: ConvergenceThresholds,
}
|
||||
|
||||
impl CostCurve {
|
||||
/// Create a new cost curve for a domain.
|
||||
pub fn new(domain_id: DomainId, thresholds: ConvergenceThresholds) -> Self {
|
||||
Self {
|
||||
domain_id,
|
||||
used_transfer: false,
|
||||
transfer_source: None,
|
||||
points: Vec::new(),
|
||||
thresholds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a cost curve with transfer metadata.
|
||||
pub fn with_transfer(
|
||||
domain_id: DomainId,
|
||||
source: DomainId,
|
||||
thresholds: ConvergenceThresholds,
|
||||
) -> Self {
|
||||
Self {
|
||||
domain_id,
|
||||
used_transfer: true,
|
||||
transfer_source: Some(source),
|
||||
points: Vec::new(),
|
||||
thresholds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a new data point.
|
||||
pub fn record(&mut self, point: CostCurvePoint) {
|
||||
self.points.push(point);
|
||||
}
|
||||
|
||||
/// Check if all convergence criteria are met at the latest point.
|
||||
pub fn has_converged(&self) -> bool {
|
||||
self.points.last().map_or(false, |p| {
|
||||
p.accuracy >= self.thresholds.target_accuracy
|
||||
&& p.cost_per_solve <= self.thresholds.target_cost
|
||||
&& p.robustness >= self.thresholds.target_robustness
|
||||
&& p.policy_violations <= self.thresholds.max_violations
|
||||
})
|
||||
}
|
||||
|
||||
/// Cycles to reach target accuracy (None if not yet reached).
|
||||
pub fn cycles_to_accuracy(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| p.accuracy >= self.thresholds.target_accuracy)
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Cycles to reach target cost (None if not yet reached).
|
||||
pub fn cycles_to_cost(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| p.cost_per_solve <= self.thresholds.target_cost)
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Cycles to reach target robustness.
|
||||
pub fn cycles_to_robustness(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| p.robustness >= self.thresholds.target_robustness)
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Cycles to full convergence (all criteria met).
|
||||
pub fn cycles_to_convergence(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| {
|
||||
p.accuracy >= self.thresholds.target_accuracy
|
||||
&& p.cost_per_solve <= self.thresholds.target_cost
|
||||
&& p.robustness >= self.thresholds.target_robustness
|
||||
&& p.policy_violations <= self.thresholds.max_violations
|
||||
})
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Area under the accuracy curve (higher = faster learning).
|
||||
pub fn auc_accuracy(&self) -> f32 {
|
||||
if self.points.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
self.points
|
||||
.windows(2)
|
||||
.map(|w| {
|
||||
let dx = (w[1].cycle - w[0].cycle) as f32;
|
||||
let avg_y = (w[0].accuracy + w[1].accuracy) / 2.0;
|
||||
dx * avg_y
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Compression ratio: how fast the cost curve drops.
|
||||
/// Computed as initial_cost / final_cost (higher = more compression).
|
||||
pub fn compression_ratio(&self) -> f32 {
|
||||
if self.points.len() < 2 {
|
||||
return 1.0;
|
||||
}
|
||||
let initial = self.points.first().unwrap().cost_per_solve;
|
||||
let final_cost = self.points.last().unwrap().cost_per_solve;
|
||||
if final_cost > 1e-10 {
|
||||
initial / final_cost
|
||||
} else {
|
||||
initial / 1e-10
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Acceleration scoreboard comparing domain learning curves.
/// Shows acceleration, not just improvement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelerationScoreboard {
    /// Per-domain cost curves, keyed by domain id.
    pub curves: HashMap<DomainId, CostCurve>,
    /// Pairwise acceleration factors, in the order they were computed
    /// (`progressive_acceleration` depends on this ordering).
    pub accelerations: Vec<AccelerationEntry>,
}
|
||||
|
||||
/// An entry showing how transfer from source to target affected convergence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelerationEntry {
    /// Source domain (the transfer curve's `transfer_source`, or "none").
    pub source: DomainId,
    /// Target domain.
    pub target: DomainId,
    /// Cycles to convergence without transfer (baseline).
    pub baseline_cycles: Option<u64>,
    /// Cycles to convergence with transfer.
    pub transfer_cycles: Option<u64>,
    /// Acceleration factor: baseline / transfer (>1 = transfer helped).
    /// Falls back to 1.0 when either side has not converged.
    pub acceleration: f32,
    /// AUC of the baseline accuracy curve (higher = better learning curve).
    pub auc_baseline: f32,
    /// AUC of the transfer accuracy curve.
    pub auc_transfer: f32,
    /// Compression ratio of the baseline cost curve.
    pub compression_baseline: f32,
    /// Compression ratio of the transfer cost curve.
    pub compression_transfer: f32,
    /// Whether generalization test passed (acceleration > 1.0).
    pub generalization_passed: bool,
}
|
||||
|
||||
impl AccelerationScoreboard {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
curves: HashMap::new(),
|
||||
accelerations: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a cost curve for a domain.
|
||||
pub fn add_curve(&mut self, curve: CostCurve) {
|
||||
self.curves.insert(curve.domain_id.clone(), curve);
|
||||
}
|
||||
|
||||
/// Compute acceleration between a baseline (no transfer) and transfer curve.
|
||||
pub fn compute_acceleration(
|
||||
&mut self,
|
||||
baseline_domain: &DomainId,
|
||||
transfer_domain: &DomainId,
|
||||
) -> Option<AccelerationEntry> {
|
||||
let baseline = self.curves.get(baseline_domain)?;
|
||||
let transfer = self.curves.get(transfer_domain)?;
|
||||
|
||||
let baseline_cycles = baseline.cycles_to_convergence();
|
||||
let transfer_cycles = transfer.cycles_to_convergence();
|
||||
|
||||
let acceleration = match (baseline_cycles, transfer_cycles) {
|
||||
(Some(b), Some(t)) if t > 0 => b as f32 / t as f32,
|
||||
_ => 1.0, // No measurable acceleration
|
||||
};
|
||||
|
||||
let entry = AccelerationEntry {
|
||||
source: transfer
|
||||
.transfer_source
|
||||
.clone()
|
||||
.unwrap_or_else(|| DomainId("none".into())),
|
||||
target: transfer_domain.clone(),
|
||||
baseline_cycles,
|
||||
transfer_cycles,
|
||||
acceleration,
|
||||
auc_baseline: baseline.auc_accuracy(),
|
||||
auc_transfer: transfer.auc_accuracy(),
|
||||
compression_baseline: baseline.compression_ratio(),
|
||||
compression_transfer: transfer.compression_ratio(),
|
||||
generalization_passed: acceleration > 1.0,
|
||||
};
|
||||
|
||||
self.accelerations.push(entry.clone());
|
||||
Some(entry)
|
||||
}
|
||||
|
||||
/// Check whether each successive domain converges faster (the IQ growth test).
|
||||
pub fn progressive_acceleration(&self) -> bool {
|
||||
if self.accelerations.len() < 2 {
|
||||
return true; // Not enough data to judge
|
||||
}
|
||||
|
||||
self.accelerations
|
||||
.windows(2)
|
||||
.all(|w| w[1].acceleration >= w[0].acceleration)
|
||||
}
|
||||
|
||||
/// Summary report of all domains.
|
||||
pub fn summary(&self) -> ScoreboardSummary {
|
||||
let domain_summaries: Vec<DomainSummary> = self
|
||||
.curves
|
||||
.iter()
|
||||
.map(|(id, curve)| DomainSummary {
|
||||
domain_id: id.clone(),
|
||||
total_cycles: curve.points.last().map(|p| p.cycle).unwrap_or(0),
|
||||
final_accuracy: curve.points.last().map(|p| p.accuracy).unwrap_or(0.0),
|
||||
final_cost: curve
|
||||
.points
|
||||
.last()
|
||||
.map(|p| p.cost_per_solve)
|
||||
.unwrap_or(f32::MAX),
|
||||
converged: curve.has_converged(),
|
||||
cycles_to_convergence: curve.cycles_to_convergence(),
|
||||
compression_ratio: curve.compression_ratio(),
|
||||
used_transfer: curve.used_transfer,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let overall_acceleration = if self.accelerations.is_empty() {
|
||||
1.0
|
||||
} else {
|
||||
self.accelerations
|
||||
.iter()
|
||||
.map(|a| a.acceleration)
|
||||
.sum::<f32>()
|
||||
/ self.accelerations.len() as f32
|
||||
};
|
||||
|
||||
ScoreboardSummary {
|
||||
domains: domain_summaries,
|
||||
accelerations: self.accelerations.clone(),
|
||||
overall_acceleration,
|
||||
progressive_improvement: self.progressive_acceleration(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AccelerationScoreboard {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary of a single domain's learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainSummary {
    /// Domain this summary describes.
    pub domain_id: DomainId,
    /// Cycle number of the last recorded point (0 if the curve is empty).
    pub total_cycles: u64,
    /// Accuracy at the last recorded point (0.0 if the curve is empty).
    pub final_accuracy: f32,
    /// Cost per solve at the last recorded point (`f32::MAX` if empty).
    pub final_cost: f32,
    /// Whether the latest point meets all convergence thresholds.
    pub converged: bool,
    /// Cycle at which all criteria were first met, if ever.
    pub cycles_to_convergence: Option<u64>,
    /// Initial cost / final cost (higher = more compression).
    pub compression_ratio: f32,
    /// Whether the domain was trained with transfer priors.
    pub used_transfer: bool,
}
|
||||
|
||||
/// Full scoreboard summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreboardSummary {
    /// Per-domain summaries (order follows the curves map iteration;
    /// unordered).
    pub domains: Vec<DomainSummary>,
    /// All pairwise acceleration entries computed so far.
    pub accelerations: Vec<AccelerationEntry>,
    /// Mean acceleration across entries (1.0 when there are none).
    pub overall_acceleration: f32,
    /// True if each new domain converges faster than the previous.
    pub progressive_improvement: bool,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a curve from `(cycle, accuracy, cost)` triples; robustness is
    /// derived as `accuracy * 0.95` and violations are always zero.
    fn make_curve(domain: &str, transfer: bool, accuracy_steps: &[(u64, f32, f32)]) -> CostCurve {
        let mut curve = if transfer {
            CostCurve::with_transfer(
                DomainId(domain.into()),
                DomainId("source".into()),
                ConvergenceThresholds::default(),
            )
        } else {
            CostCurve::new(DomainId(domain.into()), ConvergenceThresholds::default())
        };

        for &(cycle, accuracy, cost) in accuracy_steps {
            curve.record(CostCurvePoint {
                cycle,
                accuracy,
                cost_per_solve: cost,
                robustness: accuracy * 0.95,
                policy_violations: 0,
                timestamp: cycle as f64,
            });
        }
        curve
    }

    #[test]
    fn test_cost_curve_convergence() {
        // Final point: accuracy 0.95, cost 0.008, robustness 0.9025 — all
        // criteria are first met at cycle 30.
        let curve = make_curve(
            "test",
            false,
            &[
                (0, 0.3, 0.1),
                (10, 0.6, 0.05),
                (20, 0.8, 0.02),
                (30, 0.95, 0.008),
            ],
        );

        assert!(curve.has_converged());
        assert_eq!(curve.cycles_to_accuracy(), Some(30));
        assert_eq!(curve.cycles_to_cost(), Some(30));
    }

    #[test]
    fn test_cost_curve_not_converged() {
        let curve = make_curve("test", false, &[(0, 0.3, 0.1), (10, 0.6, 0.05)]);

        assert!(!curve.has_converged());
        assert_eq!(curve.cycles_to_accuracy(), None);
    }

    #[test]
    fn test_compression_ratio() {
        let curve = make_curve(
            "test",
            false,
            &[(0, 0.3, 1.0), (10, 0.6, 0.5), (20, 0.9, 0.1)],
        );

        let ratio = curve.compression_ratio();
        assert!((ratio - 10.0).abs() < 1e-4); // 1.0 / 0.1 = 10x
    }

    #[test]
    fn test_acceleration_scoreboard() {
        let mut board = AccelerationScoreboard::new();

        // Domain 1: baseline (slow convergence)
        let baseline = make_curve(
            "d1_baseline",
            false,
            &[
                (0, 0.2, 0.1),
                (20, 0.5, 0.05),
                (50, 0.8, 0.02),
                (100, 0.95, 0.008),
            ],
        );

        // Domain 2: with transfer (fast convergence)
        let transfer = make_curve(
            "d2_transfer",
            true,
            &[
                (0, 0.4, 0.08),
                (10, 0.7, 0.03),
                (20, 0.9, 0.01),
                (40, 0.96, 0.007),
            ],
        );

        board.add_curve(baseline);
        board.add_curve(transfer);

        let entry = board
            .compute_acceleration(
                &DomainId("d1_baseline".into()),
                &DomainId("d2_transfer".into()),
            )
            .unwrap();

        assert!(entry.acceleration > 1.0, "Transfer should accelerate");
        assert_eq!(entry.baseline_cycles, Some(100));
        assert_eq!(entry.transfer_cycles, Some(40));
        // 100 baseline cycles / 40 transfer cycles = 2.5x acceleration.
        assert!((entry.acceleration - 2.5).abs() < 1e-4);
        assert!(entry.generalization_passed);
    }

    #[test]
    fn test_scoreboard_summary() {
        let mut board = AccelerationScoreboard::new();
        let curve = make_curve("d1", false, &[(0, 0.5, 0.1), (50, 0.96, 0.005)]);
        board.add_curve(curve);

        let summary = board.summary();
        assert_eq!(summary.domains.len(), 1);
        assert!(summary.domains[0].converged);
    }

    #[test]
    fn test_auc_accuracy() {
        let curve = make_curve(
            "test",
            false,
            &[(0, 0.0, 1.0), (10, 0.5, 0.5), (20, 1.0, 0.1)],
        );

        let auc = curve.auc_accuracy();
        // Trapezoid: (10*(0+0.5)/2) + (10*(0.5+1.0)/2) = 2.5 + 7.5 = 10.0
        assert!((auc - 10.0).abs() < 1e-4);
    }
}
|
||||
212
vendor/ruvector/crates/ruvector-domain-expansion/src/domain.rs
vendored
Normal file
212
vendor/ruvector/crates/ruvector-domain-expansion/src/domain.rs
vendored
Normal file
@@ -0,0 +1,212 @@
|
||||
//! Core domain trait and types for cross-domain transfer learning.
|
||||
//!
|
||||
//! A domain defines a problem space with:
|
||||
//! - A task generator (produces training instances)
|
||||
//! - An evaluator (scores solutions on [0.0, 1.0])
|
||||
//! - Embedding extraction (maps solutions into a shared representation space)
|
||||
//!
|
||||
//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2
|
||||
//! faster than Domain 2 alone. That is generalization.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Unique identifier for a domain.
///
/// Newtype over `String`; derives `Hash`/`Eq` so it can key `HashMap`s.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct DomainId(pub String);
|
||||
|
||||
impl fmt::Display for DomainId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// A single task instance within a domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Unique task identifier (referenced by `Solution::task_id`).
    pub id: String,
    /// Domain this task belongs to.
    pub domain_id: DomainId,
    /// Difficulty level [0.0, 1.0].
    pub difficulty: f32,
    /// Structured task specification (domain-specific JSON).
    pub spec: serde_json::Value,
    /// Optional constraints the solution must satisfy.
    pub constraints: Vec<String>,
}
|
||||
|
||||
/// A candidate solution to a domain task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Solution {
    /// The task this solves (matches `Task::id`).
    pub task_id: String,
    /// Raw solution content (e.g., Rust source, plan steps, tool calls).
    pub content: String,
    /// Structured solution data (domain-specific).
    pub data: serde_json::Value,
}
|
||||
|
||||
/// Evaluation result for a solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evaluation {
    /// Overall score [0.0, 1.0] where 1.0 is perfect.
    pub score: f32,
    /// Correctness: does it produce the right answer?
    pub correctness: f32,
    /// Efficiency: resource usage relative to optimal.
    pub efficiency: f32,
    /// Elegance: structural quality, idiomatic patterns.
    pub elegance: f32,
    /// Per-constraint pass/fail results.
    pub constraint_results: Vec<bool>,
    /// Diagnostic notes from the evaluator.
    pub notes: Vec<String>,
}
|
||||
|
||||
impl Evaluation {
|
||||
/// Create a zero-score evaluation (failure).
|
||||
pub fn zero(notes: Vec<String>) -> Self {
|
||||
Self {
|
||||
score: 0.0,
|
||||
correctness: 0.0,
|
||||
efficiency: 0.0,
|
||||
elegance: 0.0,
|
||||
constraint_results: Vec::new(),
|
||||
notes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute composite score from weighted sub-scores.
|
||||
pub fn composite(correctness: f32, efficiency: f32, elegance: f32) -> Self {
|
||||
let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
|
||||
Self {
|
||||
score: score.clamp(0.0, 1.0),
|
||||
correctness,
|
||||
efficiency,
|
||||
elegance,
|
||||
constraint_results: Vec::new(),
|
||||
notes: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding vector for cross-domain representation.
/// Solutions from different domains are projected into a shared space
/// so that transfer learning can identify structural similarities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainEmbedding {
    /// The embedding vector.
    pub vector: Vec<f32>,
    /// Which domain produced this embedding.
    pub domain_id: DomainId,
    /// Dimensionality (set from `vector.len()` by `new`).
    pub dim: usize,
}
|
||||
|
||||
impl DomainEmbedding {
|
||||
/// Create a new embedding.
|
||||
pub fn new(vector: Vec<f32>, domain_id: DomainId) -> Self {
|
||||
let dim = vector.len();
|
||||
Self {
|
||||
vector,
|
||||
domain_id,
|
||||
dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine similarity with another embedding.
|
||||
pub fn cosine_similarity(&self, other: &DomainEmbedding) -> f32 {
|
||||
assert_eq!(self.dim, other.dim, "Embedding dimensions must match");
|
||||
|
||||
let mut dot = 0.0f32;
|
||||
let mut norm_a = 0.0f32;
|
||||
let mut norm_b = 0.0f32;
|
||||
|
||||
for i in 0..self.dim {
|
||||
dot += self.vector[i] * other.vector[i];
|
||||
norm_a += self.vector[i] * self.vector[i];
|
||||
norm_b += other.vector[i] * other.vector[i];
|
||||
}
|
||||
|
||||
let denom = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
|
||||
dot / denom
|
||||
}
|
||||
}
|
||||
|
||||
/// Core trait that every domain must implement.
///
/// Domains are problem spaces: Rust program synthesis, structured planning,
/// tool orchestration, etc. Each domain knows how to generate tasks,
/// evaluate solutions, and embed solutions into a shared representation space.
///
/// `Send + Sync` is required so domains can be shared across threads.
pub trait Domain: Send + Sync {
    /// Unique identifier for this domain.
    fn id(&self) -> &DomainId;

    /// Human-readable name.
    fn name(&self) -> &str;

    /// Generate a batch of tasks at the given difficulty level.
    ///
    /// # Arguments
    /// * `count` - Number of tasks to generate
    /// * `difficulty` - Target difficulty [0.0, 1.0]
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task>;

    /// Evaluate a solution against its task.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation;

    /// Project a solution into the shared embedding space.
    /// This enables cross-domain transfer by finding structural similarities
    /// between solutions across different problem domains.
    fn embed(&self, solution: &Solution) -> DomainEmbedding;

    /// Embedding dimensionality for this domain.
    fn embedding_dim(&self) -> usize;

    /// Generate a reference (optimal or near-optimal) solution for a task.
    /// Used for computing efficiency ratios and as training signal.
    /// Returns `None` when no reference solution is available.
    fn reference_solution(&self, task: &Task) -> Option<Solution>;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_domain_id_display() {
        let id = DomainId("rust_synthesis".to_string());
        assert_eq!(format!("{}", id), "rust_synthesis");
    }

    #[test]
    fn test_evaluation_zero() {
        let eval = Evaluation::zero(vec!["compile error".to_string()]);
        assert_eq!(eval.score, 0.0);
        assert_eq!(eval.notes.len(), 1);
    }

    #[test]
    fn test_evaluation_composite() {
        let eval = Evaluation::composite(1.0, 0.8, 0.6);
        // 0.6*1.0 + 0.25*0.8 + 0.15*0.6 = 0.6 + 0.2 + 0.09 = 0.89
        assert!((eval.score - 0.89).abs() < 1e-4);
    }

    #[test]
    fn test_embedding_cosine_similarity() {
        let id = DomainId("test".to_string());
        // Identical unit vectors: similarity 1.
        let a = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone());
        let b = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone());
        assert!((a.cosine_similarity(&b) - 1.0).abs() < 1e-6);

        // Orthogonal vectors: similarity 0.
        let c = DomainEmbedding::new(vec![0.0, 1.0, 0.0], id);
        assert!(a.cosine_similarity(&c).abs() < 1e-6);
    }

    #[test]
    fn test_evaluation_clamp() {
        // Raw weighted sum is exactly 1.0; clamp keeps it within range.
        let eval = Evaluation::composite(1.0, 1.0, 1.0);
        assert!(eval.score <= 1.0);
    }
}
|
||||
39
vendor/ruvector/crates/ruvector-domain-expansion/src/error.rs
vendored
Normal file
39
vendor/ruvector/crates/ruvector-domain-expansion/src/error.rs
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
//! Error types for domain expansion.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur during domain expansion operations.
#[derive(Error, Debug)]
pub enum DomainError {
    /// Problem generation failed.
    #[error("problem generation failed: {0}")]
    Generation(String),

    /// Solution evaluation failed.
    #[error("evaluation failed: {0}")]
    Evaluation(String),

    /// Dimension mismatch between domains.
    #[error("dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch {
        /// Expected embedding dimensionality.
        expected: usize,
        /// Actual dimensionality received.
        got: usize,
    },

    /// Domain not found in the expansion engine.
    #[error("domain not found: {0}")]
    DomainNotFound(String),

    /// Transfer failed between domains.
    #[error("transfer failed from {source} to {target}: {reason}")]
    TransferFailed {
        /// Source domain identifier.
        source: String,
        /// Target domain identifier.
        target: String,
        /// Human-readable failure description.
        reason: String,
    },

    /// Kernel has not been trained on any domain yet.
    #[error("kernel not initialized: {0}")]
    KernelNotInitialized(String),

    /// Invalid configuration.
    #[error("invalid config: {0}")]
    InvalidConfig(String),
}
|
||||
591
vendor/ruvector/crates/ruvector-domain-expansion/src/lib.rs
vendored
Normal file
591
vendor/ruvector/crates/ruvector-domain-expansion/src/lib.rs
vendored
Normal file
@@ -0,0 +1,591 @@
|
||||
//! # Domain Expansion Engine
|
||||
//!
|
||||
//! Cross-domain transfer learning for general problem-solving capability.
|
||||
//!
|
||||
//! ## Core Insight
|
||||
//!
|
||||
//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2
|
||||
//! faster than Domain 2 alone. That is generalization.
|
||||
//!
|
||||
//! ## Two-Layer Architecture
|
||||
//!
|
||||
//! **Policy learning layer**: Meta Thompson Sampling with Beta priors across
|
||||
//! context buckets. Chooses strategies via uncertainty-aware selection.
|
||||
//! Transfer happens through compact priors — not raw trajectories.
|
||||
//!
|
||||
//! **Operator layer**: Deterministic domain kernels (Rust synthesis, planning,
|
||||
//! tool orchestration) that generate tasks, evaluate solutions, and produce
|
||||
//! embeddings into a shared representation space.
|
||||
//!
|
||||
//! ## Domains
|
||||
//!
|
||||
//! - **Rust Program Synthesis**: Generate Rust functions from specifications
|
||||
//! - **Structured Planning**: Multi-step plans with dependencies and resources
|
||||
//! - **Tool Orchestration**: Coordinate multiple tools/agents for complex goals
|
||||
//!
|
||||
//! ## Transfer Protocol
|
||||
//!
|
||||
//! 1. Train on Domain 1, extract `TransferPrior` (posterior summaries)
|
||||
//! 2. Initialize Domain 2 with dampened priors from Domain 1
|
||||
//! 3. Measure acceleration: cycles to convergence with/without transfer
|
||||
//! 4. A delta is promotable only if it improves target without regressing source
|
||||
//!
|
||||
//! ## Population-Based Policy Search
|
||||
//!
|
||||
//! Run a population of `PolicyKernel` variants in parallel.
|
||||
//! Each variant tunes knobs (skip mode, prepass, speculation thresholds).
|
||||
//! Keep top performers on holdouts, mutate, repeat.
|
||||
//!
|
||||
//! ## Acceptance Test
|
||||
//!
|
||||
//! Domain 2 must converge faster than Domain 1 to target accuracy, cost,
|
||||
//! robustness, and zero policy violations.
|
||||
|
||||
#![warn(missing_docs)]
|
||||
|
||||
pub mod cost_curve;
|
||||
pub mod domain;
|
||||
pub mod meta_learning;
|
||||
pub mod planning;
|
||||
pub mod policy_kernel;
|
||||
pub mod rust_synthesis;
|
||||
pub mod tool_orchestration;
|
||||
pub mod transfer;
|
||||
|
||||
/// RVF format integration: segment serialization, witness chains, AGI packaging.
|
||||
///
|
||||
/// Requires the `rvf` feature to be enabled.
|
||||
#[cfg(feature = "rvf")]
|
||||
pub mod rvf_bridge;
|
||||
|
||||
// Re-export core types.
|
||||
pub use cost_curve::{
|
||||
AccelerationEntry, AccelerationScoreboard, ConvergenceThresholds, CostCurve, CostCurvePoint,
|
||||
ScoreboardSummary,
|
||||
};
|
||||
pub use domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
pub use meta_learning::{
|
||||
CuriosityBonus, DecayingBeta, MetaLearningEngine, MetaLearningHealth, ParetoFront, ParetoPoint,
|
||||
PlateauAction, PlateauDetector, RegretSummary, RegretTracker,
|
||||
};
|
||||
pub use planning::PlanningDomain;
|
||||
pub use policy_kernel::{PolicyKernel, PolicyKnobs, PopulationSearch, PopulationStats};
|
||||
pub use rust_synthesis::RustSynthesisDomain;
|
||||
pub use tool_orchestration::ToolOrchestrationDomain;
|
||||
pub use transfer::{
|
||||
ArmId, BetaParams, ContextBucket, DualPathResult, MetaThompsonEngine, TransferPrior,
|
||||
TransferVerification,
|
||||
};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// The domain expansion orchestrator.
///
/// Manages multiple domains, transfer learning between them,
/// population-based policy search, and the acceleration scoreboard.
///
/// The `meta` field provides five composable learning improvements:
/// regret tracking, decaying priors, plateau detection, Pareto front
/// optimization, and curiosity-driven exploration.
pub struct DomainExpansionEngine {
    /// Registered domains, keyed by their id.
    domains: HashMap<DomainId, Box<dyn Domain>>,
    /// Meta Thompson Sampling engine for cross-domain transfer.
    pub thompson: MetaThompsonEngine,
    /// Population-based policy search.
    pub population: PopulationSearch,
    /// Acceleration scoreboard tracking convergence across domains.
    pub scoreboard: AccelerationScoreboard,
    /// Meta-learning engine: regret, plateau, Pareto, curiosity, decay.
    pub meta: MetaLearningEngine,
    /// Holdout tasks per domain for verification.
    holdouts: HashMap<DomainId, Vec<Task>>,
    /// Counterexample set: failed solutions (score < 0.3) that inform
    /// future decisions.
    counterexamples: HashMap<DomainId, Vec<(Task, Solution, Evaluation)>>,
}
|
||||
|
||||
impl DomainExpansionEngine {
|
||||
    /// Create a new domain expansion engine with default configuration.
    ///
    /// Initializes the three core domains and the transfer engine.
    /// The four arm names are the strategy styles the Thompson engine
    /// selects between; the population search starts with 8 variants.
    pub fn new() -> Self {
        let arms = vec![
            "greedy".into(),
            "exploratory".into(),
            "conservative".into(),
            "speculative".into(),
        ];

        let mut engine = Self {
            domains: HashMap::new(),
            thompson: MetaThompsonEngine::new(arms),
            population: PopulationSearch::new(8),
            scoreboard: AccelerationScoreboard::new(),
            meta: MetaLearningEngine::new(),
            holdouts: HashMap::new(),
            counterexamples: HashMap::new(),
        };

        // Register the three core domains.
        engine.register_domain(Box::new(RustSynthesisDomain::new()));
        engine.register_domain(Box::new(PlanningDomain::new()));
        engine.register_domain(Box::new(ToolOrchestrationDomain::new()));

        engine
    }
|
||||
|
||||
    /// Register a new domain.
    ///
    /// Seeds the Thompson engine with uniform priors for the domain before
    /// storing it, so strategy selection is available immediately.
    pub fn register_domain(&mut self, domain: Box<dyn Domain>) {
        let id = domain.id().clone();
        self.thompson.init_domain_uniform(id.clone());
        self.domains.insert(id, domain);
    }
|
||||
|
||||
/// Generate holdout tasks for verification.
|
||||
pub fn generate_holdouts(&mut self, tasks_per_domain: usize, difficulty: f32) {
|
||||
for (id, domain) in &self.domains {
|
||||
let tasks = domain.generate_tasks(tasks_per_domain, difficulty);
|
||||
self.holdouts.insert(id.clone(), tasks);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate training tasks for a specific domain.
|
||||
pub fn generate_tasks(&self, domain_id: &DomainId, count: usize, difficulty: f32) -> Vec<Task> {
|
||||
self.domains
|
||||
.get(domain_id)
|
||||
.map(|d| d.generate_tasks(count, difficulty))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
    /// Evaluate a solution and record the outcome.
    ///
    /// Side effects, in order: the score is fed to the Thompson engine for
    /// the (bucket, arm) pair, then to the meta-learning engine, and finally
    /// solutions scoring below 0.3 are stored as counterexamples.
    /// Returns a zero evaluation if the domain is not registered.
    pub fn evaluate_and_record(
        &mut self,
        domain_id: &DomainId,
        task: &Task,
        solution: &Solution,
        bucket: ContextBucket,
        arm: ArmId,
    ) -> Evaluation {
        let eval = self
            .domains
            .get(domain_id)
            .map(|d| d.evaluate(task, solution))
            .unwrap_or_else(|| Evaluation::zero(vec!["Domain not found".into()]));

        // Record outcome in Thompson engine.
        self.thompson.record_outcome(
            domain_id,
            bucket.clone(),
            arm.clone(),
            eval.score,
            1.0, // unit cost for now
        );

        // Record in meta-learning engine (regret + curiosity + decaying beta).
        self.meta.record_decision(&bucket, &arm, eval.score);

        // Store counterexamples for poor solutions.
        if eval.score < 0.3 {
            self.counterexamples
                .entry(domain_id.clone())
                .or_default()
                .push((task.clone(), solution.clone(), eval.clone()));
        }

        eval
    }
|
||||
|
||||
/// Embed a solution into the shared representation space.
|
||||
pub fn embed(&self, domain_id: &DomainId, solution: &Solution) -> Option<DomainEmbedding> {
|
||||
self.domains.get(domain_id).map(|d| d.embed(solution))
|
||||
}
|
||||
|
||||
/// Initiate transfer from source domain to target domain.
|
||||
/// Extracts priors from source and seeds target.
|
||||
pub fn initiate_transfer(&mut self, source: &DomainId, target: &DomainId) {
|
||||
if let Some(prior) = self.thompson.extract_prior(source) {
|
||||
self.thompson
|
||||
.init_domain_with_transfer(target.clone(), &prior);
|
||||
}
|
||||
}
|
||||
|
||||
    /// Verify a transfer delta: did it improve target without regressing source?
    ///
    /// Thin delegation to [`TransferVerification::verify`].
    ///
    /// * `source_before` / `source_after` — source-domain scores around the transfer.
    /// * `target_before` / `target_after` — target-domain scores around the transfer.
    /// * `baseline_cycles` — cycles the target needed without transfer.
    /// * `transfer_cycles` — cycles the target needed with transfer.
    pub fn verify_transfer(
        &self,
        source: &DomainId,
        target: &DomainId,
        source_before: f32,
        source_after: f32,
        target_before: f32,
        target_after: f32,
        baseline_cycles: u64,
        transfer_cycles: u64,
    ) -> TransferVerification {
        TransferVerification::verify(
            source.clone(),
            target.clone(),
            source_before,
            source_after,
            target_before,
            target_after,
            baseline_cycles,
            transfer_cycles,
        )
    }
|
||||
|
||||
/// Evaluate all policy kernels on holdout tasks.
|
||||
pub fn evaluate_population(&mut self) {
|
||||
let holdout_snapshot: HashMap<DomainId, Vec<Task>> = self.holdouts.clone();
|
||||
let domain_ids: Vec<DomainId> = self.domains.keys().cloned().collect();
|
||||
|
||||
for i in 0..self.population.population().len() {
|
||||
for domain_id in &domain_ids {
|
||||
if let Some(holdout_tasks) = holdout_snapshot.get(domain_id) {
|
||||
let mut total_score = 0.0f32;
|
||||
let mut count = 0;
|
||||
|
||||
for task in holdout_tasks {
|
||||
if let Some(domain) = self.domains.get(domain_id) {
|
||||
if let Some(ref_sol) = domain.reference_solution(task) {
|
||||
let eval = domain.evaluate(task, &ref_sol);
|
||||
total_score += eval.score;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let avg_score = if count > 0 {
|
||||
total_score / count as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if let Some(kernel) = self.population.kernel_mut(i) {
|
||||
kernel.record_score(domain_id.clone(), avg_score, 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
    /// Evolve the policy kernel population and update Pareto front.
    ///
    /// Snapshots each kernel's (accuracy, cost, robustness) into the
    /// meta-learning engine's Pareto tracker for the current generation,
    /// then advances the population by one evolutionary step.
    pub fn evolve_population(&mut self) {
        // Record current population into Pareto front before evolving.
        let gen = self.population.generation();
        for kernel in self.population.population() {
            let accuracy = kernel.fitness();
            // Mean cost per cycle; zero for a kernel that has never run.
            let cost = if kernel.cycles > 0 {
                kernel.total_cost / kernel.cycles as f32
            } else {
                0.0
            };
            // Robustness approximated by consistency across domains.
            // NOTE(review): uses the overall fitness as the per-domain mean —
            // confirm `fitness()` is the average of `holdout_scores`.
            let robustness = if kernel.holdout_scores.len() > 1 {
                let mean = accuracy;
                let var: f32 = kernel
                    .holdout_scores
                    .values()
                    .map(|s| (s - mean).powi(2))
                    .sum::<f32>()
                    / kernel.holdout_scores.len() as f32;
                // 1 - stddev, floored at zero: lower spread => higher robustness.
                (1.0 - var.sqrt()).max(0.0)
            } else {
                // With at most one domain score there is no spread to measure.
                accuracy
            };
            self.meta
                .record_kernel(&kernel.id, accuracy, cost, robustness, gen);
        }

        self.population.evolve();
    }
|
||||
|
||||
/// Get the best policy kernel found so far.
|
||||
pub fn best_kernel(&self) -> Option<&PolicyKernel> {
|
||||
self.population.best()
|
||||
}
|
||||
|
||||
/// Get population statistics.
|
||||
pub fn population_stats(&self) -> PopulationStats {
|
||||
self.population.stats()
|
||||
}
|
||||
|
||||
/// Get the scoreboard summary.
|
||||
pub fn scoreboard_summary(&self) -> ScoreboardSummary {
|
||||
self.scoreboard.summary()
|
||||
}
|
||||
|
||||
/// Get registered domain IDs.
|
||||
pub fn domain_ids(&self) -> Vec<DomainId> {
|
||||
self.domains.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get counterexamples for a domain.
|
||||
pub fn counterexamples(&self, domain_id: &DomainId) -> &[(Task, Solution, Evaluation)] {
|
||||
self.counterexamples
|
||||
.get(domain_id)
|
||||
.map(|v| v.as_slice())
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Select best arm for a context using Thompson Sampling.
|
||||
pub fn select_arm(&self, domain_id: &DomainId, bucket: &ContextBucket) -> Option<ArmId> {
|
||||
let mut rng = rand::thread_rng();
|
||||
self.thompson.select_arm(domain_id, bucket, &mut rng)
|
||||
}
|
||||
|
||||
/// Check if dual-path speculation should be triggered.
|
||||
pub fn should_speculate(&self, domain_id: &DomainId, bucket: &ContextBucket) -> bool {
|
||||
self.thompson.is_uncertain(domain_id, bucket, 0.15)
|
||||
}
|
||||
|
||||
/// Select arm with curiosity-boosted Thompson Sampling.
|
||||
///
|
||||
/// Combines the standard Thompson sample with a UCB-style exploration
|
||||
/// bonus that favors under-visited bucket/arm combinations.
|
||||
pub fn select_arm_curious(
|
||||
&self,
|
||||
domain_id: &DomainId,
|
||||
bucket: &ContextBucket,
|
||||
) -> Option<ArmId> {
|
||||
let mut rng = rand::thread_rng();
|
||||
// Get all arms and compute boosted scores
|
||||
let prior = self.thompson.extract_prior(domain_id)?;
|
||||
let arms: Vec<ArmId> = prior
|
||||
.bucket_priors
|
||||
.get(bucket)
|
||||
.map(|m| m.keys().cloned().collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
if arms.is_empty() {
|
||||
return self.thompson.select_arm(domain_id, bucket, &mut rng);
|
||||
}
|
||||
|
||||
let mut best_arm = None;
|
||||
let mut best_score = f32::NEG_INFINITY;
|
||||
|
||||
for arm in &arms {
|
||||
let params = prior.get_prior(bucket, arm);
|
||||
let sample = params.sample(&mut rng);
|
||||
let boosted = self.meta.boosted_score(bucket, arm, sample);
|
||||
|
||||
if boosted > best_score {
|
||||
best_score = boosted;
|
||||
best_arm = Some(arm.clone());
|
||||
}
|
||||
}
|
||||
|
||||
best_arm.or_else(|| self.thompson.select_arm(domain_id, bucket, &mut rng))
|
||||
}
|
||||
|
||||
/// Get meta-learning health diagnostics.
|
||||
pub fn meta_health(&self) -> MetaLearningHealth {
|
||||
self.meta.health_check()
|
||||
}
|
||||
|
||||
/// Check cost curve for plateau and get recommended action.
|
||||
pub fn check_plateau(&mut self, domain_id: &DomainId) -> PlateauAction {
|
||||
if let Some(curve) = self.scoreboard.curves.get(domain_id) {
|
||||
self.meta.check_plateau(&curve.points)
|
||||
} else {
|
||||
PlateauAction::Continue
|
||||
}
|
||||
}
|
||||
|
||||
/// Get regret summary across all learning contexts.
|
||||
pub fn regret_summary(&self) -> RegretSummary {
|
||||
self.meta.regret.summary()
|
||||
}
|
||||
|
||||
/// Get the Pareto front of non-dominated policy kernels.
|
||||
pub fn pareto_front(&self) -> &ParetoFront {
|
||||
&self.meta.pareto
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DomainExpansionEngine {
    /// Equivalent to [`DomainExpansionEngine::new`]: an engine with the
    /// three core domains pre-registered.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // The engine's constructor registers exactly three core domains.
    #[test]
    fn test_engine_creation() {
        let engine = DomainExpansionEngine::new();
        let ids = engine.domain_ids();
        assert_eq!(ids.len(), 3);
    }

    // Every registered domain can generate the requested number of tasks.
    #[test]
    fn test_generate_tasks_all_domains() {
        let engine = DomainExpansionEngine::new();
        for domain_id in engine.domain_ids() {
            let tasks = engine.generate_tasks(&domain_id, 5, 0.5);
            assert_eq!(tasks.len(), 5);
        }
    }

    // Thompson sampling should always return an arm for a fresh engine.
    #[test]
    fn test_arm_selection() {
        let engine = DomainExpansionEngine::new();
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "general".into(),
        };
        for domain_id in engine.domain_ids() {
            let arm = engine.select_arm(&domain_id, &bucket);
            assert!(arm.is_some());
        }
    }

    // End-to-end: evaluate a plausible Rust solution and check the score
    // stays within [0, 1].
    #[test]
    fn test_evaluate_and_record() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.3);
        let task = &tasks[0];

        let solution = Solution {
            task_id: task.id.clone(),
            content:
                "fn double(values: &[i64]) -> Vec<i64> { values.iter().map(|&x| x * 2).collect() }"
                    .into(),
            data: serde_json::Value::Null,
        };

        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "transform".into(),
        };
        let arm = ArmId("greedy".into());

        let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm);
        assert!(eval.score >= 0.0 && eval.score <= 1.0);
    }

    // Solutions from different domains embed into the same shared space.
    #[test]
    fn test_cross_domain_embedding() {
        let engine = DomainExpansionEngine::new();

        let rust_sol = Solution {
            task_id: "rust".into(),
            content: "fn foo() { for i in 0..10 { if i > 5 { } } }".into(),
            data: serde_json::Value::Null,
        };

        let plan_sol = Solution {
            task_id: "plan".into(),
            content: "allocate cpu then schedule parallel jobs".into(),
            data: serde_json::json!({"steps": []}),
        };

        let rust_emb = engine
            .embed(&DomainId("rust_synthesis".into()), &rust_sol)
            .unwrap();
        let plan_emb = engine
            .embed(&DomainId("structured_planning".into()), &plan_sol)
            .unwrap();

        // Embeddings should be same dimension.
        assert_eq!(rust_emb.dim, plan_emb.dim);

        // Cross-domain similarity should be defined.
        let sim = rust_emb.cosine_similarity(&plan_emb);
        assert!(sim >= -1.0 && sim <= 1.0);
    }

    // A transfer seeded from a well-trained source should be promotable
    // and show an acceleration factor > 1 for the given cycle counts.
    #[test]
    fn test_transfer_flow() {
        let mut engine = DomainExpansionEngine::new();
        let source = DomainId("rust_synthesis".into());
        let target = DomainId("structured_planning".into());

        // Record some outcomes in source domain.
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };

        for _ in 0..30 {
            engine.thompson.record_outcome(
                &source,
                bucket.clone(),
                ArmId("greedy".into()),
                0.85,
                1.0,
            );
        }

        // Initiate transfer.
        engine.initiate_transfer(&source, &target);

        // Verify the transfer.
        let verification = engine.verify_transfer(
            &source, &target, 0.85, // source before
            0.845, // source after (within tolerance)
            0.3,   // target before
            0.7,   // target after
            100,   // baseline cycles
            45,    // transfer cycles
        );

        assert!(verification.promotable);
        assert!(verification.acceleration_factor > 1.0);
    }

    // One evolve step advances the population's generation counter by one.
    #[test]
    fn test_population_evolution() {
        let mut engine = DomainExpansionEngine::new();
        engine.generate_holdouts(3, 0.3);
        engine.evaluate_population();

        let stats_before = engine.population_stats();
        assert_eq!(stats_before.generation, 0);

        engine.evolve_population();
        let stats_after = engine.population_stats();
        assert_eq!(stats_after.generation, 1);
    }

    // A fresh engine has only uniform priors, so uncertainty-based
    // speculation should trigger.
    #[test]
    fn test_speculation_trigger() {
        let engine = DomainExpansionEngine::new();
        let bucket = ContextBucket {
            difficulty_tier: "hard".into(),
            category: "unknown".into(),
        };

        // With uniform priors, should be uncertain.
        assert!(engine.should_speculate(&DomainId("rust_synthesis".into()), &bucket,));
    }

    // Solutions scoring below the 0.3 threshold must be archived as
    // counterexamples for the domain.
    #[test]
    fn test_counterexample_tracking() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.9);
        let task = &tasks[0];

        // Submit a terrible solution.
        let solution = Solution {
            task_id: task.id.clone(),
            content: "".into(), // empty = bad
            data: serde_json::Value::Null,
        };

        let bucket = ContextBucket {
            difficulty_tier: "hard".into(),
            category: "algorithm".into(),
        };
        let arm = ArmId("speculative".into());

        let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm);
        assert!(eval.score < 0.3);

        // Should be recorded as counterexample.
        assert!(!engine.counterexamples(&domain_id).is_empty());
    }
}
|
||||
1398
vendor/ruvector/crates/ruvector-domain-expansion/src/meta_learning.rs
vendored
Normal file
1398
vendor/ruvector/crates/ruvector-domain-expansion/src/meta_learning.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
647
vendor/ruvector/crates/ruvector-domain-expansion/src/planning.rs
vendored
Normal file
647
vendor/ruvector/crates/ruvector-domain-expansion/src/planning.rs
vendored
Normal file
@@ -0,0 +1,647 @@
|
||||
//! Structured Planning Tasks Domain
|
||||
//!
|
||||
//! Generates tasks that require multi-step reasoning and plan construction.
|
||||
//! Task types include:
|
||||
//!
|
||||
//! - **ResourceAllocation**: Assign limited resources to maximize objective
|
||||
//! - **DependencyScheduling**: Order tasks respecting dependencies and deadlines
|
||||
//! - **StateSpaceSearch**: Navigate from initial to goal state
|
||||
//! - **ConstraintSatisfaction**: Find assignments satisfying all constraints
|
||||
//! - **HierarchicalDecomposition**: Break complex goals into sub-goals
|
||||
//!
|
||||
//! Solutions are plans: ordered sequences of actions with preconditions and effects.
|
||||
//! Cross-domain transfer from Rust synthesis helps because both require:
|
||||
//! structured decomposition, constraint satisfaction, and efficient search.
|
||||
|
||||
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Dimensionality of the planning-domain solution embedding vector.
const EMBEDDING_DIM: usize = 64;
|
||||
|
||||
/// Categories of planning tasks.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PlanningCategory {
|
||||
/// Assign limited resources to competing demands.
|
||||
ResourceAllocation,
|
||||
/// Schedule tasks with precedence constraints and deadlines.
|
||||
DependencyScheduling,
|
||||
/// Find a path from initial state to goal state.
|
||||
StateSpaceSearch,
|
||||
/// Assign values to variables satisfying all constraints.
|
||||
ConstraintSatisfaction,
|
||||
/// Decompose a high-level goal into achievable sub-tasks.
|
||||
HierarchicalDecomposition,
|
||||
}
|
||||
|
||||
/// A resource in the planning world.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Resource {
    // Resource identifier, e.g. "cpu", "memory", "worker".
    pub name: String,
    // Maximum capacity available for allocation.
    pub capacity: u32,
}
|
||||
|
||||
/// An action in a plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanAction {
    // Unique action name referenced by plan steps and dependencies.
    pub name: String,
    // Predicates that must hold before the action can execute.
    pub preconditions: Vec<String>,
    // Predicates established by executing the action.
    pub effects: Vec<String>,
    // Scalar cost charged against a task's cost budget.
    pub cost: f32,
    // Execution duration in abstract time units.
    pub duration: u32,
}
|
||||
|
||||
/// A dependency edge: task A must complete before task B.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dependency {
    // Name of the action that must complete first.
    pub from: String,
    // Name of the action that depends on `from`.
    pub to: String,
}
|
||||
|
||||
/// Specification for a planning task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanningTaskSpec {
    // Which kind of planning problem this is.
    pub category: PlanningCategory,
    // Human-readable statement of the problem.
    pub description: String,
    /// Available actions in the planning domain.
    pub available_actions: Vec<PlanAction>,
    /// Resources with capacity limits.
    pub resources: Vec<Resource>,
    /// Dependency constraints.
    pub dependencies: Vec<Dependency>,
    /// Initial state predicates.
    pub initial_state: Vec<String>,
    /// Goal state predicates.
    pub goal_state: Vec<String>,
    /// Maximum allowed plan cost.
    pub max_cost: Option<f32>,
    /// Maximum allowed plan steps.
    pub max_steps: Option<usize>,
}
|
||||
|
||||
/// A parsed plan from a solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plan {
    // Ordered sequence of steps; ordering encodes execution order.
    pub steps: Vec<PlanStep>,
}
|
||||
|
||||
/// A single step in a plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    // Name of the action to execute (must match a `PlanAction::name`).
    pub action: String,
    // Optional action arguments.
    pub args: Vec<String>,
    // Scheduled start time; steps sharing a start time run in parallel.
    pub start_time: Option<u32>,
}
|
||||
|
||||
/// Structured planning domain.
pub struct PlanningDomain {
    // Stable identifier ("structured_planning") used for registration.
    id: DomainId,
}
|
||||
|
||||
impl PlanningDomain {
    /// Create the planning domain with its fixed identifier.
    pub fn new() -> Self {
        Self {
            id: DomainId("structured_planning".to_string()),
        }
    }

    /// Build a resource-allocation task spec; difficulty scales the task
    /// count up and the resource capacities down.
    fn gen_resource_allocation(&self, difficulty: f32) -> PlanningTaskSpec {
        // 3 / 6 / 10 tasks at easy / medium / hard difficulty.
        let num_tasks = if difficulty < 0.3 {
            3
        } else if difficulty < 0.7 {
            6
        } else {
            10
        };

        // Each task needs one of three resources (round-robin) and has
        // linearly increasing cost.
        let actions: Vec<PlanAction> = (0..num_tasks)
            .map(|i| PlanAction {
                name: format!("task_{}", i),
                preconditions: vec![format!("resource_available_{}", i % 3)],
                effects: vec![format!("task_{}_complete", i)],
                cost: (i as f32 + 1.0) * 10.0,
                duration: (i as u32 % 5) + 1,
            })
            .collect();

        // Tighter capacities at higher difficulty.
        let resources = vec![
            Resource {
                name: "cpu".into(),
                capacity: if difficulty < 0.5 { 10 } else { 5 },
            },
            Resource {
                name: "memory".into(),
                capacity: if difficulty < 0.5 { 8 } else { 3 },
            },
            Resource {
                name: "io".into(),
                capacity: if difficulty < 0.5 { 6 } else { 2 },
            },
        ];

        // Goal: every task completed.
        let goal_state: Vec<String> = (0..num_tasks)
            .map(|i| format!("task_{}_complete", i))
            .collect();

        PlanningTaskSpec {
            category: PlanningCategory::ResourceAllocation,
            description: format!(
                "Allocate {} resources to complete {} tasks within capacity.",
                resources.len(),
                num_tasks
            ),
            available_actions: actions,
            resources,
            dependencies: Vec::new(),
            initial_state: vec![
                "resource_available_0".into(),
                "resource_available_1".into(),
                "resource_available_2".into(),
            ],
            goal_state,
            max_cost: Some(num_tasks as f32 * 50.0),
            max_steps: Some(num_tasks * 2),
        }
    }

    /// Build a dependency-scheduling task spec: a linear job chain with
    /// extra cross-dependencies at higher difficulty.
    fn gen_dependency_scheduling(&self, difficulty: f32) -> PlanningTaskSpec {
        // 4 / 7 / 12 jobs at easy / medium / hard difficulty.
        let num_tasks = if difficulty < 0.3 {
            4
        } else if difficulty < 0.7 {
            7
        } else {
            12
        };

        // Each job (except the first) requires the previous job's effect.
        let actions: Vec<PlanAction> = (0..num_tasks)
            .map(|i| PlanAction {
                name: format!("job_{}", i),
                preconditions: if i > 0 {
                    vec![format!("job_{}_done", i - 1)]
                } else {
                    Vec::new()
                },
                effects: vec![format!("job_{}_done", i)],
                cost: 1.0,
                duration: (i as u32 % 3) + 1,
            })
            .collect();

        // Create dependency chain with some parallelism
        let mut dependencies = Vec::new();
        for i in 1..num_tasks {
            // Linear chain
            dependencies.push(Dependency {
                from: format!("job_{}", i - 1),
                to: format!("job_{}", i),
            });
            // Add cross-dependencies at higher difficulty
            if difficulty > 0.5 && i >= 3 && i % 2 == 0 {
                dependencies.push(Dependency {
                    from: format!("job_{}", i - 3),
                    to: format!("job_{}", i),
                });
            }
        }

        PlanningTaskSpec {
            category: PlanningCategory::DependencyScheduling,
            description: format!(
                "Schedule {} jobs respecting {} dependencies, minimizing makespan.",
                num_tasks,
                dependencies.len()
            ),
            available_actions: actions,
            // Fewer workers at higher difficulty forces serialization.
            resources: vec![Resource {
                name: "worker".into(),
                capacity: if difficulty < 0.5 { 3 } else { 2 },
            }],
            dependencies,
            initial_state: Vec::new(),
            goal_state: (0..num_tasks).map(|i| format!("job_{}_done", i)).collect(),
            max_cost: None,
            max_steps: Some(num_tasks + 5),
        }
    }

    /// Build a grid-navigation (state-space search) task spec; the grid
    /// grows with difficulty.
    fn gen_state_space_search(&self, difficulty: f32) -> PlanningTaskSpec {
        // 3x3 / 5x5 / 8x8 grid at easy / medium / hard difficulty.
        let grid_size = if difficulty < 0.3 {
            3
        } else if difficulty < 0.7 {
            5
        } else {
            8
        };

        // Four cardinal movement actions, each guarded by an edge check.
        let actions = vec![
            PlanAction {
                name: "move_up".into(),
                preconditions: vec!["not_top_edge".into()],
                effects: vec!["moved_up".into()],
                cost: 1.0,
                duration: 1,
            },
            PlanAction {
                name: "move_down".into(),
                preconditions: vec!["not_bottom_edge".into()],
                effects: vec!["moved_down".into()],
                cost: 1.0,
                duration: 1,
            },
            PlanAction {
                name: "move_left".into(),
                preconditions: vec!["not_left_edge".into()],
                effects: vec!["moved_left".into()],
                cost: 1.0,
                duration: 1,
            },
            PlanAction {
                name: "move_right".into(),
                preconditions: vec!["not_right_edge".into()],
                effects: vec!["moved_right".into()],
                cost: 1.0,
                duration: 1,
            },
        ];

        PlanningTaskSpec {
            category: PlanningCategory::StateSpaceSearch,
            description: format!(
                "Navigate a {}x{} grid from (0,0) to ({},{}) avoiding obstacles.",
                grid_size,
                grid_size,
                grid_size - 1,
                grid_size - 1
            ),
            available_actions: actions,
            resources: Vec::new(),
            dependencies: Vec::new(),
            initial_state: vec!["at(0,0)".into()],
            goal_state: vec![format!("at({},{})", grid_size - 1, grid_size - 1)],
            max_cost: Some((grid_size as f32) * 4.0),
            max_steps: Some(grid_size * grid_size),
        }
    }

    /// Extract structural features from a planning solution.
    ///
    /// Produces an L2-normalized `EMBEDDING_DIM`-length vector; unset
    /// indices stay zero. Features mix parsed plan structure with keyword
    /// counts over the raw solution text.
    fn extract_features(&self, solution: &Solution) -> Vec<f32> {
        let content = &solution.content;
        let mut features = vec![0.0f32; EMBEDDING_DIM];

        // Parse the plan: structured `data` first, then raw text; an
        // unparseable solution degrades to an empty plan (not an error).
        let plan: Plan = serde_json::from_str(&solution.data.to_string())
            .or_else(|_| serde_json::from_str(content))
            .unwrap_or(Plan { steps: Vec::new() });

        // Feature 0-7: Plan structure
        features[0] = plan.steps.len() as f32 / 20.0;
        features[1] = {
            let unique_actions: std::collections::HashSet<&str> =
                plan.steps.iter().map(|s| s.action.as_str()).collect();
            unique_actions.len() as f32 / plan.steps.len().max(1) as f32
        };
        // Sequential vs parallel indicator
        features[2] = plan
            .steps
            .windows(2)
            .filter(|w| w[0].start_time == w[1].start_time)
            .count() as f32
            / plan.steps.len().max(1) as f32;
        // Average args per step
        features[3] = plan.steps.iter().map(|s| s.args.len() as f32).sum::<f32>()
            / plan.steps.len().max(1) as f32
            / 5.0;

        // Feature 8-15: Action type distribution
        features[8] = action_counts.len() as f32 / 10.0;
        let action_counts: std::collections::HashMap<&str, usize> =
            plan.steps
                .iter()
                .fold(std::collections::HashMap::new(), |mut acc, s| {
                    *acc.entry(s.action.as_str()).or_insert(0) += 1;
                    acc
                });
        let max_count = action_counts.values().max().copied().unwrap_or(0);
        features[9] = max_count as f32 / plan.steps.len().max(1) as f32;

        // Feature 16-23: Text-based features from content
        features[16] = content.matches("allocate").count() as f32 / 5.0;
        features[17] = content.matches("schedule").count() as f32 / 5.0;
        features[18] = content.matches("move").count() as f32 / 10.0;
        features[19] = content.matches("assign").count() as f32 / 5.0;
        features[20] = content.matches("wait").count() as f32 / 5.0;
        features[21] = content.matches("parallel").count() as f32 / 3.0;
        features[22] = content.matches("constraint").count() as f32 / 5.0;
        features[23] = content.matches("deadline").count() as f32 / 3.0;

        // Feature 32-39: Structural complexity indicators
        features[32] = content.matches("->").count() as f32 / 10.0;
        features[33] = content.matches("if ").count() as f32 / 5.0;
        features[34] = content.matches("then ").count() as f32 / 5.0;
        features[35] = content.matches("before").count() as f32 / 5.0;
        features[36] = content.matches("after").count() as f32 / 5.0;
        features[37] = content.matches("while").count() as f32 / 3.0;
        features[38] = content.matches("until").count() as f32 / 3.0;
        features[39] = content.matches("complete").count() as f32 / 5.0;

        // Feature 48-55: Resource usage indicators
        features[48] = content.matches("cpu").count() as f32 / 3.0;
        features[49] = content.matches("memory").count() as f32 / 3.0;
        features[50] = content.matches("worker").count() as f32 / 3.0;
        features[51] = content.matches("capacity").count() as f32 / 3.0;
        features[52] = content.matches("cost").count() as f32 / 5.0;
        features[53] = content.matches("time").count() as f32 / 5.0;
        features[54] = content.matches("resource").count() as f32 / 5.0;
        features[55] = content.matches("limit").count() as f32 / 3.0;

        // Normalize
        let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for f in &mut features {
                *f /= norm;
            }
        }

        features
    }

    /// Evaluate a planning solution.
    ///
    /// Scoring is a weighted blend: 60% correctness (goal coverage and
    /// dependency ordering), 25% efficiency (step/cost budgets), 15%
    /// elegance (low redundancy, parallelism bonus).
    fn score_plan(&self, spec: &PlanningTaskSpec, solution: &Solution) -> Evaluation {
        let content = &solution.content;
        let mut correctness = 0.0f32;
        let mut efficiency = 0.5f32;
        let mut elegance = 0.5f32;
        let mut notes = Vec::new();

        // Parse plan from solution: structured `data` first, raw text second.
        let plan: Option<Plan> = serde_json::from_str(&solution.data.to_string())
            .ok()
            .or_else(|| serde_json::from_str(content).ok());

        let plan = match plan {
            Some(p) => p,
            None => {
                // Fall back to text analysis
                let has_steps = content.contains("step") || content.contains("action");
                if has_steps {
                    // Partial credit for plan-like prose that did not parse.
                    correctness = 0.2;
                }
                return Evaluation {
                    score: correctness * 0.6,
                    correctness,
                    efficiency: 0.0,
                    elegance: 0.0,
                    constraint_results: Vec::new(),
                    notes: vec!["Could not parse structured plan".into()],
                };
            }
        };

        // Check plan is non-empty
        if plan.steps.is_empty() {
            return Evaluation::zero(vec!["Empty plan".into()]);
        }

        // Check goal coverage: how many goal predicates are addressed
        let goal_coverage = spec
            .goal_state
            .iter()
            .filter(|goal| {
                plan.steps.iter().any(|step| {
                    let action_name = &step.action;
                    // Check if any action's effects mention this goal
                    spec.available_actions
                        .iter()
                        .any(|a| a.name == *action_name && a.effects.iter().any(|e| e == *goal))
                })
            })
            .count() as f32
            / spec.goal_state.len().max(1) as f32;

        correctness = goal_coverage;

        // Check dependency ordering
        let mut dep_violations = 0;
        for dep in &spec.dependencies {
            // Compare first occurrence positions of the two actions.
            let from_pos = plan.steps.iter().position(|s| s.action == dep.from);
            let to_pos = plan.steps.iter().position(|s| s.action == dep.to);
            if let (Some(f), Some(t)) = (from_pos, to_pos) {
                if f >= t {
                    dep_violations += 1;
                    notes.push(format!(
                        "Dependency violation: {} must come before {}",
                        dep.from, dep.to
                    ));
                }
            }
        }
        if !spec.dependencies.is_empty() {
            // Blend goal coverage 50/50 with dependency compliance.
            let dep_score = 1.0 - (dep_violations as f32 / spec.dependencies.len() as f32);
            correctness = correctness * 0.5 + dep_score * 0.5;
        }

        // Efficiency: compare to max allowed steps/cost
        if let Some(max_steps) = spec.max_steps {
            let step_ratio = plan.steps.len() as f32 / max_steps as f32;
            efficiency = if step_ratio <= 1.0 {
                1.0 - (step_ratio * 0.5) // Fewer steps = better
            } else {
                0.5 / step_ratio // Penalty for exceeding max
            };
        }

        if let Some(max_cost) = spec.max_cost {
            // Sum costs of the spec actions actually used by the plan;
            // steps naming unknown actions contribute nothing.
            let total_cost: f32 = plan
                .steps
                .iter()
                .filter_map(|step| {
                    spec.available_actions
                        .iter()
                        .find(|a| a.name == step.action)
                        .map(|a| a.cost)
                })
                .sum();
            if total_cost > max_cost {
                efficiency *= 0.5;
                notes.push(format!(
                    "Plan cost {:.1} exceeds budget {:.1}",
                    total_cost, max_cost
                ));
            }
        }

        // Elegance: minimal redundancy, good parallelism
        let unique_actions: std::collections::HashSet<&str> =
            plan.steps.iter().map(|s| s.action.as_str()).collect();
        let redundancy = 1.0 - (unique_actions.len() as f32 / plan.steps.len().max(1) as f32);
        elegance = 1.0 - redundancy * 0.5;

        // Bonus for parallel scheduling
        if plan
            .steps
            .windows(2)
            .any(|w| w[0].start_time == w[1].start_time)
        {
            elegance += 0.1;
        }
        elegance = elegance.clamp(0.0, 1.0);

        let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
        Evaluation {
            score: score.clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes,
        }
    }
}
|
||||
|
||||
impl Default for PlanningDomain {
    /// Equivalent to [`PlanningDomain::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl Domain for PlanningDomain {
    /// Stable domain identifier ("structured_planning").
    fn id(&self) -> &DomainId {
        &self.id
    }

    /// Human-readable domain name.
    fn name(&self) -> &str {
        "Structured Planning"
    }

    /// Generate `count` planning tasks at the given difficulty.
    ///
    /// Category mix is randomized: ~35% resource allocation, ~35%
    /// dependency scheduling, ~30% state-space search. Difficulty is
    /// clamped to [0, 1] and encoded in the task id.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);

        (0..count)
            .map(|i| {
                // Pick a category by uniform roll.
                let category_roll: f32 = rng.gen();
                let spec = if category_roll < 0.35 {
                    self.gen_resource_allocation(difficulty)
                } else if category_roll < 0.7 {
                    self.gen_dependency_scheduling(difficulty)
                } else {
                    self.gen_state_space_search(difficulty)
                };

                Task {
                    id: format!("planning_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    // A spec that fails to serialize degrades to a null value.
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: Vec::new(),
                }
            })
            .collect()
    }

    /// Score a solution against the task's `PlanningTaskSpec`.
    /// A malformed spec yields a zero evaluation with the parse error noted.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        let spec: PlanningTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        self.score_plan(&spec, solution)
    }

    /// Embed a solution via structural/textual feature extraction.
    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }

    /// Fixed embedding dimensionality shared with other domains.
    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }

    /// Build a naive baseline plan: every available action executed once,
    /// sequentially, in spec order. Returns `None` when the task spec or
    /// plan fails (de)serialization.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: PlanningTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;

        // Generate a naive sequential plan that executes all actions in order
        let steps: Vec<PlanStep> = spec
            .available_actions
            .iter()
            .enumerate()
            .map(|(i, a)| PlanStep {
                action: a.name.clone(),
                args: Vec::new(),
                start_time: Some(i as u32),
            })
            .collect();

        let plan = Plan { steps };
        let content = serde_json::to_string_pretty(&plan).ok()?;

        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::to_value(&plan).ok()?,
        })
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Task generation yields the requested count, all tagged with this
    /// domain's id.
    #[test]
    fn test_generate_planning_tasks() {
        let domain = PlanningDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        assert!(tasks.iter().all(|t| t.domain_id == domain.id));
    }

    /// Every generated task should admit a baseline reference plan.
    #[test]
    fn test_reference_solution_exists() {
        let domain = PlanningDomain::new();
        for task in &domain.generate_tasks(3, 0.3) {
            assert!(
                domain.reference_solution(task).is_some(),
                "Should produce reference solution"
            );
        }
    }

    /// Evaluating the reference plan yields a score within [0, 1].
    #[test]
    fn test_evaluate_reference() {
        let domain = PlanningDomain::new();
        for task in &domain.generate_tasks(3, 0.3) {
            if let Some(solution) = domain.reference_solution(task) {
                let eval = domain.evaluate(task, &solution);
                assert!((0.0..=1.0).contains(&eval.score));
            }
        }
    }

    /// Embeddings always have the advertised dimensionality.
    #[test]
    fn test_embed_planning() {
        let domain = PlanningDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "allocate cpu to task_0, schedule job_1 after job_0".into(),
            data: serde_json::json!({ "steps": [] }),
        };
        assert_eq!(domain.embed(&solution).dim, EMBEDDING_DIM);
    }

    /// Raising difficulty should never shrink the available action set.
    #[test]
    fn test_difficulty_scaling() {
        let domain = PlanningDomain::new();
        let easy = domain.generate_tasks(1, 0.1);
        let hard = domain.generate_tasks(1, 0.9);

        let easy_spec: PlanningTaskSpec = serde_json::from_value(easy[0].spec.clone()).unwrap();
        let hard_spec: PlanningTaskSpec = serde_json::from_value(hard[0].spec.clone()).unwrap();

        assert!(
            hard_spec.available_actions.len() >= easy_spec.available_actions.len(),
            "Harder tasks should have more actions"
        );
    }
}
|
||||
468
vendor/ruvector/crates/ruvector-domain-expansion/src/policy_kernel.rs
vendored
Normal file
468
vendor/ruvector/crates/ruvector-domain-expansion/src/policy_kernel.rs
vendored
Normal file
@@ -0,0 +1,468 @@
|
||||
//! PolicyKernel: Population-Based Policy Search
|
||||
//!
|
||||
//! Run a small population of policy variants in parallel.
|
||||
//! Each variant changes a small set of knobs:
|
||||
//! - skip mode policy
|
||||
//! - prepass mode
|
||||
//! - speculation trigger thresholds
|
||||
//! - budget allocation
|
||||
//!
|
||||
//! Selection: keep top performers on holdouts, mutate knobs, repeat.
|
||||
//! Only merge deltas that pass replay-verify.
|
||||
|
||||
use crate::domain::DomainId;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration knobs that a PolicyKernel can tune.
///
/// Each field is an independently mutable dial; see [`PolicyKnobs::mutate`]
/// for the perturbation ranges and [`PolicyKnobs::crossover`] for
/// recombination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyKnobs {
    /// Whether to skip low-value operations.
    pub skip_mode: bool,
    /// Run a cheaper prepass before full execution.
    pub prepass_enabled: bool,
    /// Threshold for triggering speculative dual-path [0.0, 1.0].
    pub speculation_threshold: f32,
    /// Budget fraction allocated to exploration vs exploitation [0.0, 1.0].
    pub exploration_budget: f32,
    /// Maximum retries on failure.
    pub max_retries: u32,
    /// Batch size for parallel evaluation.
    pub batch_size: usize,
    /// Cost decay factor for EMA.
    pub cost_decay: f32,
    /// Minimum confidence to skip uncertainty check.
    pub confidence_floor: f32,
}
|
||||
|
||||
impl PolicyKnobs {
|
||||
/// Sensible defaults.
|
||||
pub fn default_knobs() -> Self {
|
||||
Self {
|
||||
skip_mode: false,
|
||||
prepass_enabled: true,
|
||||
speculation_threshold: 0.15,
|
||||
exploration_budget: 0.2,
|
||||
max_retries: 2,
|
||||
batch_size: 8,
|
||||
cost_decay: 0.9,
|
||||
confidence_floor: 0.7,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutate knobs with small random perturbations.
|
||||
pub fn mutate(&self, rng: &mut impl Rng, mutation_rate: f32) -> Self {
|
||||
let mut knobs = self.clone();
|
||||
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.skip_mode = !knobs.skip_mode;
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.prepass_enabled = !knobs.prepass_enabled;
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.1..0.1);
|
||||
knobs.speculation_threshold = (knobs.speculation_threshold + delta).clamp(0.01, 0.5);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.1..0.1);
|
||||
knobs.exploration_budget = (knobs.exploration_budget + delta).clamp(0.01, 0.5);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.max_retries = rng.gen_range(0..5);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.batch_size = rng.gen_range(1..32);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.05..0.05);
|
||||
knobs.cost_decay = (knobs.cost_decay + delta).clamp(0.5, 0.99);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.1..0.1);
|
||||
knobs.confidence_floor = (knobs.confidence_floor + delta).clamp(0.3, 0.95);
|
||||
}
|
||||
|
||||
knobs
|
||||
}
|
||||
|
||||
/// Crossover two parent knobs to produce a child.
|
||||
pub fn crossover(&self, other: &PolicyKnobs, rng: &mut impl Rng) -> Self {
|
||||
Self {
|
||||
skip_mode: if rng.gen() {
|
||||
self.skip_mode
|
||||
} else {
|
||||
other.skip_mode
|
||||
},
|
||||
prepass_enabled: if rng.gen() {
|
||||
self.prepass_enabled
|
||||
} else {
|
||||
other.prepass_enabled
|
||||
},
|
||||
speculation_threshold: if rng.gen() {
|
||||
self.speculation_threshold
|
||||
} else {
|
||||
other.speculation_threshold
|
||||
},
|
||||
exploration_budget: if rng.gen() {
|
||||
self.exploration_budget
|
||||
} else {
|
||||
other.exploration_budget
|
||||
},
|
||||
max_retries: if rng.gen() {
|
||||
self.max_retries
|
||||
} else {
|
||||
other.max_retries
|
||||
},
|
||||
batch_size: if rng.gen() {
|
||||
self.batch_size
|
||||
} else {
|
||||
other.batch_size
|
||||
},
|
||||
cost_decay: if rng.gen() {
|
||||
self.cost_decay
|
||||
} else {
|
||||
other.cost_decay
|
||||
},
|
||||
confidence_floor: if rng.gen() {
|
||||
self.confidence_floor
|
||||
} else {
|
||||
other.confidence_floor
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A PolicyKernel is a versioned policy configuration with performance history.
///
/// Fitness is derived from `holdout_scores`; lineage across generations is
/// tracked through `generation` and `parent_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyKernel {
    /// Unique identifier.
    pub id: String,
    /// Configuration knobs.
    pub knobs: PolicyKnobs,
    /// Performance on holdout tasks (domain_id -> score).
    pub holdout_scores: HashMap<DomainId, f32>,
    /// Total cost incurred.
    pub total_cost: f32,
    /// Number of evaluation cycles.
    pub cycles: u64,
    /// Generation (0 = initial, increments on mutation).
    pub generation: u32,
    /// Parent kernel ID (for lineage tracking).
    pub parent_id: Option<String>,
    /// Whether this kernel has been verified via replay.
    pub replay_verified: bool,
}
|
||||
|
||||
impl PolicyKernel {
|
||||
/// Create a new kernel with default knobs.
|
||||
pub fn new(id: String) -> Self {
|
||||
Self {
|
||||
id,
|
||||
knobs: PolicyKnobs::default_knobs(),
|
||||
holdout_scores: HashMap::new(),
|
||||
total_cost: 0.0,
|
||||
cycles: 0,
|
||||
generation: 0,
|
||||
parent_id: None,
|
||||
replay_verified: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a mutated child kernel.
|
||||
pub fn mutate(&self, child_id: String, rng: &mut impl Rng) -> Self {
|
||||
Self {
|
||||
id: child_id,
|
||||
knobs: self.knobs.mutate(rng, 0.3),
|
||||
holdout_scores: HashMap::new(),
|
||||
total_cost: 0.0,
|
||||
cycles: 0,
|
||||
generation: self.generation + 1,
|
||||
parent_id: Some(self.id.clone()),
|
||||
replay_verified: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a holdout score for a domain.
|
||||
pub fn record_score(&mut self, domain_id: DomainId, score: f32, cost: f32) {
|
||||
self.holdout_scores.insert(domain_id, score);
|
||||
self.total_cost += cost;
|
||||
self.cycles += 1;
|
||||
}
|
||||
|
||||
/// Fitness: average holdout score across all evaluated domains.
|
||||
pub fn fitness(&self) -> f32 {
|
||||
if self.holdout_scores.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let total: f32 = self.holdout_scores.values().sum();
|
||||
total / self.holdout_scores.len() as f32
|
||||
}
|
||||
|
||||
/// Cost-adjusted fitness: penalizes expensive kernels.
|
||||
pub fn cost_adjusted_fitness(&self) -> f32 {
|
||||
let raw = self.fitness();
|
||||
let cost_penalty = (self.total_cost / self.cycles.max(1) as f32).min(1.0);
|
||||
raw * (1.0 - cost_penalty * 0.3) // 30% weight on cost
|
||||
}
|
||||
}
|
||||
|
||||
/// Population-based policy search engine.
#[derive(Clone)]
pub struct PopulationSearch {
    /// Current population of kernels.
    population: Vec<PolicyKernel>,
    /// Target population size, maintained across generations.
    pop_size: usize,
    /// Best kernel seen so far (by raw fitness), across all generations.
    best_kernel: Option<PolicyKernel>,
    /// Generation counter (0 until the first `evolve`).
    generation: u32,
}
|
||||
|
||||
impl PopulationSearch {
|
||||
/// Create a new population search with initial random population.
|
||||
pub fn new(pop_size: usize) -> Self {
|
||||
let mut rng = rand::thread_rng();
|
||||
let population: Vec<PolicyKernel> = (0..pop_size)
|
||||
.map(|i| {
|
||||
let mut kernel = PolicyKernel::new(format!("kernel_g0_{}", i));
|
||||
// Random initial knobs
|
||||
kernel.knobs = PolicyKnobs::default_knobs().mutate(&mut rng, 0.8);
|
||||
kernel
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
population,
|
||||
pop_size,
|
||||
best_kernel: None,
|
||||
generation: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current population for evaluation.
|
||||
pub fn population(&self) -> &[PolicyKernel] {
|
||||
&self.population
|
||||
}
|
||||
|
||||
/// Get mutable reference to a kernel by index.
|
||||
pub fn kernel_mut(&mut self, index: usize) -> Option<&mut PolicyKernel> {
|
||||
self.population.get_mut(index)
|
||||
}
|
||||
|
||||
/// Evolve to next generation: select top performers, mutate, fill population.
|
||||
pub fn evolve(&mut self) {
|
||||
let mut rng = rand::thread_rng();
|
||||
self.generation += 1;
|
||||
|
||||
// Sort by cost-adjusted fitness (descending)
|
||||
self.population.sort_by(|a, b| {
|
||||
b.cost_adjusted_fitness()
|
||||
.partial_cmp(&a.cost_adjusted_fitness())
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// Track best
|
||||
if let Some(best) = self.population.first() {
|
||||
if self
|
||||
.best_kernel
|
||||
.as_ref()
|
||||
.map_or(true, |b| best.fitness() > b.fitness())
|
||||
{
|
||||
self.best_kernel = Some(best.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Elite selection: keep top 25%
|
||||
let elite_count = (self.pop_size / 4).max(1);
|
||||
let elites: Vec<PolicyKernel> = self.population[..elite_count].to_vec();
|
||||
|
||||
// Build next generation
|
||||
let mut next_gen = Vec::with_capacity(self.pop_size);
|
||||
|
||||
// Keep elites
|
||||
for elite in &elites {
|
||||
let mut kept = elite.clone();
|
||||
kept.id = format!("kernel_g{}_{}", self.generation, next_gen.len());
|
||||
kept.holdout_scores.clear();
|
||||
kept.total_cost = 0.0;
|
||||
kept.cycles = 0;
|
||||
next_gen.push(kept);
|
||||
}
|
||||
|
||||
// Fill rest with mutations and crossovers
|
||||
while next_gen.len() < self.pop_size {
|
||||
let parent_idx = rng.gen_range(0..elites.len());
|
||||
let child_id = format!("kernel_g{}_{}", self.generation, next_gen.len());
|
||||
|
||||
let child = if rng.gen::<f32>() < 0.3 && elites.len() > 1 {
|
||||
// Crossover
|
||||
let other_idx =
|
||||
(parent_idx + 1 + rng.gen_range(0..elites.len() - 1)) % elites.len();
|
||||
let mut child = PolicyKernel::new(child_id);
|
||||
child.knobs = elites[parent_idx]
|
||||
.knobs
|
||||
.crossover(&elites[other_idx].knobs, &mut rng);
|
||||
child.generation = self.generation;
|
||||
child.parent_id = Some(elites[parent_idx].id.clone());
|
||||
child
|
||||
} else {
|
||||
// Mutation
|
||||
elites[parent_idx].mutate(child_id, &mut rng)
|
||||
};
|
||||
|
||||
next_gen.push(child);
|
||||
}
|
||||
|
||||
self.population = next_gen;
|
||||
}
|
||||
|
||||
/// Get the best kernel found so far.
|
||||
pub fn best(&self) -> Option<&PolicyKernel> {
|
||||
self.best_kernel.as_ref()
|
||||
}
|
||||
|
||||
/// Current generation number.
|
||||
pub fn generation(&self) -> u32 {
|
||||
self.generation
|
||||
}
|
||||
|
||||
/// Get fitness statistics for the current population.
|
||||
pub fn stats(&self) -> PopulationStats {
|
||||
let fitnesses: Vec<f32> = self.population.iter().map(|k| k.fitness()).collect();
|
||||
let mean = fitnesses.iter().sum::<f32>() / fitnesses.len().max(1) as f32;
|
||||
let max = fitnesses.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let min = fitnesses.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||
let variance = fitnesses.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
|
||||
/ fitnesses.len().max(1) as f32;
|
||||
|
||||
PopulationStats {
|
||||
generation: self.generation,
|
||||
pop_size: self.population.len(),
|
||||
mean_fitness: mean,
|
||||
max_fitness: max,
|
||||
min_fitness: min,
|
||||
fitness_variance: variance,
|
||||
best_ever_fitness: self
|
||||
.best_kernel
|
||||
.as_ref()
|
||||
.map(|k| k.fitness())
|
||||
.unwrap_or(0.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics about the current population.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PopulationStats {
    /// Generation these statistics were computed for.
    pub generation: u32,
    /// Number of kernels in the population.
    pub pop_size: usize,
    /// Mean raw fitness across the population.
    pub mean_fitness: f32,
    /// Highest raw fitness in the population.
    pub max_fitness: f32,
    /// Lowest raw fitness in the population.
    pub min_fitness: f32,
    /// Population variance of raw fitness.
    pub fitness_variance: f32,
    /// Raw fitness of the best kernel ever seen (0.0 if none yet).
    pub best_ever_fitness: f32,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Default knobs match the documented starting configuration.
    #[test]
    fn test_policy_knobs_default() {
        let knobs = PolicyKnobs::default_knobs();
        assert!(!knobs.skip_mode);
        assert!(knobs.prepass_enabled);
        assert!(knobs.speculation_threshold > 0.0);
    }

    /// Mutation keeps continuous knobs inside their clamp ranges.
    #[test]
    fn test_policy_knobs_mutate() {
        let mut rng = rand::thread_rng();
        // Rate 1.0 perturbs every knob; values are random, but the
        // clamp bounds must always hold.
        let mutated = PolicyKnobs::default_knobs().mutate(&mut rng, 1.0);
        assert!((0.01..=0.5).contains(&mutated.speculation_threshold));
        assert!((0.01..=0.5).contains(&mutated.exploration_budget));
    }

    /// Fitness is 0 with no scores and the mean of recorded scores after.
    #[test]
    fn test_policy_kernel_fitness() {
        let mut kernel = PolicyKernel::new("test".into());
        assert_eq!(kernel.fitness(), 0.0);

        kernel.record_score(DomainId("d1".into()), 0.8, 1.0);
        kernel.record_score(DomainId("d2".into()), 0.6, 1.0);
        assert!((kernel.fitness() - 0.7).abs() < 1e-6);
    }

    /// One evolve step preserves the population size and records a best.
    #[test]
    fn test_population_search_evolve() {
        let mut search = PopulationSearch::new(8);
        assert_eq!(search.population().len(), 8);

        // Simulate an evaluation pass with increasing scores.
        for i in 0..8 {
            let score = 0.3 + (i as f32) * 0.08;
            if let Some(kernel) = search.kernel_mut(i) {
                kernel.record_score(DomainId("test".into()), score, 1.0);
            }
        }

        search.evolve();
        assert_eq!(search.population().len(), 8);
        assert_eq!(search.generation(), 1);
        assert!(search.best().is_some());
    }

    /// Summary statistics are internally consistent (min <= mean <= max).
    #[test]
    fn test_population_stats() {
        let mut search = PopulationSearch::new(4);
        for i in 0..4 {
            if let Some(kernel) = search.kernel_mut(i) {
                kernel.record_score(DomainId("test".into()), (i as f32) * 0.25, 1.0);
            }
        }

        let stats = search.stats();
        assert_eq!(stats.pop_size, 4);
        assert!(stats.min_fitness <= stats.max_fitness);
        assert!(stats.min_fitness <= stats.mean_fitness);
        assert!(stats.mean_fitness <= stats.max_fitness);
    }

    /// Crossover only ever copies field values from one of the parents.
    #[test]
    fn test_crossover() {
        let a = PolicyKnobs {
            skip_mode: true,
            prepass_enabled: false,
            speculation_threshold: 0.1,
            exploration_budget: 0.1,
            max_retries: 1,
            batch_size: 4,
            cost_decay: 0.8,
            confidence_floor: 0.5,
        };
        let b = PolicyKnobs {
            skip_mode: false,
            prepass_enabled: true,
            speculation_threshold: 0.4,
            exploration_budget: 0.4,
            max_retries: 4,
            batch_size: 16,
            cost_decay: 0.95,
            confidence_floor: 0.9,
        };

        let child = a.crossover(&b, &mut rand::thread_rng());
        assert!(matches!(child.max_retries, 1 | 4));
        assert!(matches!(child.batch_size, 4 | 16));
    }
}
|
||||
603
vendor/ruvector/crates/ruvector-domain-expansion/src/rust_synthesis.rs
vendored
Normal file
603
vendor/ruvector/crates/ruvector-domain-expansion/src/rust_synthesis.rs
vendored
Normal file
@@ -0,0 +1,603 @@
|
||||
//! Rust Program Synthesis Domain
|
||||
//!
|
||||
//! Generates tasks that require synthesizing Rust programs from specifications.
|
||||
//! Task types include:
|
||||
//!
|
||||
//! - **Transform**: Apply a function to data (map, filter, fold)
|
||||
//! - **DataStructure**: Implement a data structure with specific operations
|
||||
//! - **Algorithm**: Implement a named algorithm (sorting, searching, graph)
|
||||
//! - **TypeLevel**: Express constraints via Rust's type system
|
||||
//! - **Concurrency**: Safe concurrent data access patterns
|
||||
//!
|
||||
//! Solutions are evaluated on correctness (do test cases pass?),
|
||||
//! efficiency (complexity class), and elegance (idiomatic Rust patterns).
|
||||
|
||||
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Embedding dimension for Rust synthesis domain.
|
||||
const EMBEDDING_DIM: usize = 64;
|
||||
|
||||
/// Categories of Rust synthesis tasks.
///
/// `Transform`, `DataStructure`, and `Algorithm` are built by the
/// corresponding `gen_*` helpers on `RustSynthesisDomain`;
/// NOTE(review): generators for `TypeLevel`/`Concurrency` are not
/// visible in this file section — confirm they exist before relying on
/// those variants being produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RustTaskCategory {
    /// Transform data: map, filter, fold, scan.
    Transform,
    /// Implement a data structure with trait impls.
    DataStructure,
    /// Implement a named algorithm.
    Algorithm,
    /// Type-level programming: generics, trait bounds, associated types.
    TypeLevel,
    /// Concurrent programming: Arc, Mutex, channels, atomics.
    Concurrency,
}
|
||||
|
||||
/// Specification for a Rust synthesis task.
///
/// Built by the `gen_*` helpers on `RustSynthesisDomain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RustTaskSpec {
    /// Task category.
    pub category: RustTaskCategory,
    /// Function signature that must be implemented.
    pub signature: String,
    /// Natural language description of the required behavior.
    pub description: String,
    /// Test cases as (input_json, expected_output_json) pairs.
    pub test_cases: Vec<(String, String)>,
    /// Required traits the solution must implement.
    pub required_traits: Vec<String>,
    /// Banned patterns (e.g., "unsafe", "unwrap").
    pub banned_patterns: Vec<String>,
    /// Expected complexity class (e.g., "O(n log n)").
    pub expected_complexity: Option<String>,
}
|
||||
|
||||
/// Rust program synthesis domain.
///
/// Holds only its stable domain identifier; task generation and
/// scoring take all other inputs per call.
pub struct RustSynthesisDomain {
    /// Stable identifier ("rust_synthesis") used to tag tasks.
    id: DomainId,
}
|
||||
|
||||
impl RustSynthesisDomain {
|
||||
/// Create a new Rust synthesis domain.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: DomainId("rust_synthesis".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Generate a transform task at the given difficulty.
    ///
    /// Difficulty tiers: < 0.3 picks a random element-wise map
    /// (double/negate/abs/square); < 0.7 is a filter+fold task;
    /// otherwise Kadane's maximum-subarray. Every tier expects O(n)
    /// and bans `unsafe`.
    fn gen_transform(&self, difficulty: f32, rng: &mut impl Rng) -> RustTaskSpec {
        let (signature, description, tests, complexity) = if difficulty < 0.3 {
            // Easy: simple map
            let ops = ["double", "negate", "abs", "square"];
            let op = ops[rng.gen_range(0..ops.len())];
            (
                format!("fn {}(values: &[i64]) -> Vec<i64>", op),
                format!("Apply {} to each element in the slice.", op),
                match op {
                    "double" => vec![
                        ("[1, 2, 3]".into(), "[2, 4, 6]".into()),
                        ("[-1, 0, 5]".into(), "[-2, 0, 10]".into()),
                    ],
                    "negate" => vec![
                        ("[1, -2, 3]".into(), "[-1, 2, -3]".into()),
                        ("[0]".into(), "[0]".into()),
                    ],
                    "abs" => vec![
                        ("[-1, 2, -3]".into(), "[1, 2, 3]".into()),
                        ("[0, -0]".into(), "[0, 0]".into()),
                    ],
                    // Catch-all covers "square".
                    _ => vec![
                        ("[2, 3, 4]".into(), "[4, 9, 16]".into()),
                        ("[0, -1]".into(), "[0, 1]".into()),
                    ],
                },
                "O(n)",
            )
        } else if difficulty < 0.7 {
            // Medium: filter + fold combos
            (
                "fn sum_positives(values: &[i64]) -> i64".into(),
                "Sum all positive values in the slice.".into(),
                vec![
                    ("[1, -2, 3, -4, 5]".into(), "9".into()),
                    ("[-1, -2, -3]".into(), "0".into()),
                    ("[]".into(), "0".into()),
                ],
                "O(n)",
            )
        } else {
            // Hard: sliding window / scan
            (
                "fn max_subarray_sum(values: &[i64]) -> i64".into(),
                "Find the maximum sum contiguous subarray (Kadane's algorithm).".into(),
                vec![
                    ("[-2, 1, -3, 4, -1, 2, 1, -5, 4]".into(), "6".into()),
                    ("[-1, -2, -3]".into(), "-1".into()),
                    ("[5]".into(), "5".into()),
                ],
                "O(n)",
            )
        };

        RustTaskSpec {
            category: RustTaskCategory::Transform,
            signature,
            description,
            test_cases: tests,
            required_traits: Vec::new(),
            banned_patterns: vec!["unsafe".into()],
            expected_complexity: Some(complexity.into()),
        }
    }
|
||||
|
||||
    /// Generate a data structure task.
    ///
    /// Difficulty tiers: < 0.4 generic stack; < 0.7 binary min-heap
    /// (std `BinaryHeap` banned to force a from-scratch implementation);
    /// otherwise an LRU cache with O(1) operations. The `_rng` parameter
    /// is currently unused: each tier produces a fixed spec.
    fn gen_data_structure(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec {
        if difficulty < 0.4 {
            RustTaskSpec {
                category: RustTaskCategory::DataStructure,
                signature: "struct Stack<T>".into(),
                description: "Implement a generic stack with push, pop, peek, is_empty, len."
                    .into(),
                test_cases: vec![
                    ("push(1); push(2); pop()".into(), "Some(2)".into()),
                    ("is_empty()".into(), "true".into()),
                    ("push(1); len()".into(), "1".into()),
                ],
                required_traits: vec!["Default".into()],
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O(1) per operation".into()),
            }
        } else if difficulty < 0.7 {
            RustTaskSpec {
                category: RustTaskCategory::DataStructure,
                signature: "struct MinHeap<T: Ord>".into(),
                description: "Implement a binary min-heap with insert, extract_min, peek_min."
                    .into(),
                test_cases: vec![
                    (
                        "insert(3); insert(1); insert(2); extract_min()".into(),
                        "Some(1)".into(),
                    ),
                    ("peek_min() on empty".into(), "None".into()),
                ],
                required_traits: vec!["Default".into()],
                // Banning BinaryHeap rules out delegating to the std heap.
                banned_patterns: vec!["unsafe".into(), "BinaryHeap".into()],
                expected_complexity: Some("O(log n) insert/extract".into()),
            }
        } else {
            RustTaskSpec {
                category: RustTaskCategory::DataStructure,
                signature: "struct LRUCache<K: Hash + Eq, V>".into(),
                description: "Implement an LRU cache with get, put, and capacity eviction.".into(),
                test_cases: vec![
                    (
                        "cap=2; put(1,'a'); put(2,'b'); get(1); put(3,'c'); get(2)".into(),
                        "None".into(),
                    ),
                    (
                        "cap=1; put(1,'a'); put(2,'b'); get(1)".into(),
                        "None".into(),
                    ),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O(1) get/put".into()),
            }
        }
    }
|
||||
|
||||
    /// Generate an algorithm task.
    ///
    /// Difficulty tiers: < 0.4 binary search; < 0.7 merge sort
    /// (".sort" banned so the solution cannot delegate to std sorting);
    /// otherwise Dijkstra's shortest path. The `_rng` parameter is
    /// currently unused: each tier produces a fixed spec.
    fn gen_algorithm(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec {
        if difficulty < 0.4 {
            RustTaskSpec {
                category: RustTaskCategory::Algorithm,
                signature: "fn binary_search(sorted: &[i64], target: i64) -> Option<usize>".into(),
                description: "Implement binary search on a sorted slice.".into(),
                test_cases: vec![
                    ("[1,3,5,7,9], 5".into(), "Some(2)".into()),
                    ("[1,3,5,7,9], 4".into(), "None".into()),
                    ("[], 1".into(), "None".into()),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O(log n)".into()),
            }
        } else if difficulty < 0.7 {
            RustTaskSpec {
                category: RustTaskCategory::Algorithm,
                signature: "fn merge_sort(values: &mut [i64])".into(),
                description: "Implement stable merge sort in-place.".into(),
                test_cases: vec![
                    ("[3,1,4,1,5,9,2,6]".into(), "[1,1,2,3,4,5,6,9]".into()),
                    ("[1]".into(), "[1]".into()),
                    ("[]".into(), "[]".into()),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into(), ".sort".into()],
                expected_complexity: Some("O(n log n)".into()),
            }
        } else {
            RustTaskSpec {
                category: RustTaskCategory::Algorithm,
                signature: "fn shortest_path(adj: &[Vec<(usize, u64)>], src: usize, dst: usize) -> Option<u64>".into(),
                description: "Implement Dijkstra's shortest path on a weighted directed graph.".into(),
                test_cases: vec![
                    ("3 nodes, 0->1:2, 1->2:3, 0->2:10; src=0, dst=2".into(), "Some(5)".into()),
                    ("2 nodes, no edges; src=0, dst=1".into(), "None".into()),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O((V + E) log V)".into()),
            }
        }
    }
|
||||
|
||||
/// Extract structural features from a Rust solution for embedding.
|
||||
fn extract_features(&self, solution: &Solution) -> Vec<f32> {
|
||||
let code = &solution.content;
|
||||
let mut features = vec![0.0f32; EMBEDDING_DIM];
|
||||
|
||||
// Feature 0-7: Control flow complexity
|
||||
features[0] = code.matches("if ").count() as f32 / 10.0;
|
||||
features[1] = code.matches("for ").count() as f32 / 5.0;
|
||||
features[2] = code.matches("while ").count() as f32 / 5.0;
|
||||
features[3] = code.matches("match ").count() as f32 / 5.0;
|
||||
features[4] = code.matches("loop ").count() as f32 / 3.0;
|
||||
features[5] = code.matches("return ").count() as f32 / 5.0;
|
||||
features[6] = code.matches("break").count() as f32 / 3.0;
|
||||
features[7] = code.matches("continue").count() as f32 / 3.0;
|
||||
|
||||
// Feature 8-15: Type system usage
|
||||
features[8] = code.matches("impl ").count() as f32 / 5.0;
|
||||
features[9] = code.matches("trait ").count() as f32 / 3.0;
|
||||
features[10] = code.matches("struct ").count() as f32 / 3.0;
|
||||
features[11] = code.matches("enum ").count() as f32 / 3.0;
|
||||
features[12] = code.matches("where ").count() as f32 / 3.0;
|
||||
features[13] = code.matches("dyn ").count() as f32 / 3.0;
|
||||
features[14] = code.matches("Box<").count() as f32 / 3.0;
|
||||
features[15] = code.matches("Rc<").count() as f32 / 3.0;
|
||||
|
||||
// Feature 16-23: Functional patterns
|
||||
features[16] = code.matches(".map(").count() as f32 / 5.0;
|
||||
features[17] = code.matches(".filter(").count() as f32 / 5.0;
|
||||
features[18] = code.matches(".fold(").count() as f32 / 3.0;
|
||||
features[19] = code.matches(".collect()").count() as f32 / 3.0;
|
||||
features[20] = code.matches(".iter()").count() as f32 / 5.0;
|
||||
features[21] = code.matches("|").count() as f32 / 10.0; // closures
|
||||
features[22] = code.matches("Some(").count() as f32 / 5.0;
|
||||
features[23] = code.matches("None").count() as f32 / 5.0;
|
||||
|
||||
// Feature 24-31: Memory/ownership patterns
|
||||
features[24] = code.matches("&mut ").count() as f32 / 5.0;
|
||||
features[25] = code.matches("&self").count() as f32 / 5.0;
|
||||
features[26] = code.matches("mut ").count() as f32 / 10.0;
|
||||
features[27] = code.matches(".clone()").count() as f32 / 5.0;
|
||||
features[28] = code.matches("Vec<").count() as f32 / 5.0;
|
||||
features[29] = code.matches("HashMap").count() as f32 / 3.0;
|
||||
features[30] = code.matches("String").count() as f32 / 5.0;
|
||||
features[31] = code.matches("Result<").count() as f32 / 3.0;
|
||||
|
||||
// Feature 32-39: Concurrency patterns
|
||||
features[32] = code.matches("Arc<").count() as f32 / 3.0;
|
||||
features[33] = code.matches("Mutex<").count() as f32 / 3.0;
|
||||
features[34] = code.matches("RwLock").count() as f32 / 3.0;
|
||||
features[35] = code.matches("async ").count() as f32 / 3.0;
|
||||
features[36] = code.matches("await").count() as f32 / 5.0;
|
||||
features[37] = code.matches("spawn").count() as f32 / 3.0;
|
||||
features[38] = code.matches("channel").count() as f32 / 3.0;
|
||||
features[39] = code.matches("Atomic").count() as f32 / 3.0;
|
||||
|
||||
// Feature 40-47: Code structure metrics
|
||||
let lines: Vec<&str> = code.lines().collect();
|
||||
features[40] = (lines.len() as f32) / 100.0;
|
||||
features[41] = lines.iter().filter(|l| l.trim().is_empty()).count() as f32
|
||||
/ (lines.len().max(1) as f32);
|
||||
features[42] = code.matches("fn ").count() as f32 / 10.0;
|
||||
features[43] = code.matches("pub ").count() as f32 / 10.0;
|
||||
features[44] = code.matches("mod ").count() as f32 / 5.0;
|
||||
features[45] = code.matches("use ").count() as f32 / 10.0;
|
||||
features[46] = code.matches("#[").count() as f32 / 5.0; // attributes
|
||||
features[47] = code.matches("///").count() as f32 / 10.0; // doc comments
|
||||
|
||||
// Feature 48-55: Error handling patterns
|
||||
features[48] = code.matches("unwrap()").count() as f32 / 5.0;
|
||||
features[49] = code.matches("expect(").count() as f32 / 5.0;
|
||||
features[50] = code.matches("?;").count() as f32 / 5.0; // error propagation
|
||||
features[51] = code.matches("Err(").count() as f32 / 5.0;
|
||||
features[52] = code.matches("Ok(").count() as f32 / 5.0;
|
||||
features[53] = code.matches("panic!").count() as f32 / 3.0;
|
||||
features[54] = code.matches("assert").count() as f32 / 5.0;
|
||||
features[55] = code.matches("debug_assert").count() as f32 / 3.0;
|
||||
|
||||
// Feature 56-63: Algorithm indicators
|
||||
features[56] = code.matches("sort").count() as f32 / 3.0;
|
||||
features[57] = code.matches("binary_search").count() as f32 / 2.0;
|
||||
features[58] = code.matches("push").count() as f32 / 5.0;
|
||||
features[59] = code.matches("pop").count() as f32 / 5.0;
|
||||
features[60] = code.matches("swap").count() as f32 / 5.0;
|
||||
features[61] = code.matches("len()").count() as f32 / 5.0;
|
||||
features[62] = code.matches("is_empty").count() as f32 / 3.0;
|
||||
features[63] = code.matches("contains").count() as f32 / 3.0;
|
||||
|
||||
// Normalize to unit length
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-10 {
|
||||
for f in &mut features {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Score a Rust solution based on pattern matching heuristics.
|
||||
fn score_solution(&self, spec: &RustTaskSpec, solution: &Solution) -> Evaluation {
|
||||
let code = &solution.content;
|
||||
let mut correctness = 0.0f32;
|
||||
let mut efficiency = 0.5f32;
|
||||
let mut elegance = 0.5f32;
|
||||
let mut notes = Vec::new();
|
||||
|
||||
// Check for banned patterns
|
||||
let mut banned_found = false;
|
||||
for pattern in &spec.banned_patterns {
|
||||
if code.contains(pattern.as_str()) {
|
||||
notes.push(format!("Banned pattern found: {}", pattern));
|
||||
banned_found = true;
|
||||
}
|
||||
}
|
||||
|
||||
if banned_found {
|
||||
elegance *= 0.5;
|
||||
}
|
||||
|
||||
// Check that the solution contains the expected signature
|
||||
let sig_name = spec
|
||||
.signature
|
||||
.split('(')
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.split_whitespace()
|
||||
.last()
|
||||
.unwrap_or("");
|
||||
|
||||
if code.contains(sig_name) {
|
||||
correctness += 0.3;
|
||||
} else {
|
||||
notes.push(format!("Missing expected identifier: {}", sig_name));
|
||||
}
|
||||
|
||||
// Check for fn definition
|
||||
if code.contains("fn ") {
|
||||
correctness += 0.2;
|
||||
}
|
||||
|
||||
// Check for test case coverage hints
|
||||
let test_coverage = spec
|
||||
.test_cases
|
||||
.iter()
|
||||
.filter(|(input, _)| {
|
||||
// Heuristic: solution likely handles the input pattern
|
||||
let key_tokens: Vec<&str> = input.split(|c: char| !c.is_alphanumeric()).collect();
|
||||
key_tokens.iter().any(|t| !t.is_empty() && code.contains(t))
|
||||
})
|
||||
.count() as f32
|
||||
/ spec.test_cases.len().max(1) as f32;
|
||||
correctness += test_coverage * 0.5;
|
||||
correctness = correctness.clamp(0.0, 1.0);
|
||||
|
||||
// Efficiency: penalize obviously quadratic patterns
|
||||
let nested_loops = code.matches("for ").count() > 1 && code.matches("for ").count() > 2;
|
||||
if nested_loops {
|
||||
if let Some(ref expected) = spec.expected_complexity {
|
||||
if expected.contains("O(n)") || expected.contains("O(log") {
|
||||
efficiency *= 0.5;
|
||||
notes.push("Possible O(n^2) when O(n) or O(log n) expected".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Elegance: favor idiomatic Rust
|
||||
let iterator_usage = code.matches(".iter()").count()
|
||||
+ code.matches(".map(").count()
|
||||
+ code.matches(".filter(").count()
|
||||
+ code.matches(".fold(").count();
|
||||
if iterator_usage > 0 {
|
||||
elegance += 0.2;
|
||||
}
|
||||
|
||||
// Penalize excessive unwrap
|
||||
let unwrap_count = code.matches("unwrap()").count();
|
||||
if unwrap_count > 3 {
|
||||
elegance -= 0.2;
|
||||
notes.push("Excessive unwrap() usage".into());
|
||||
}
|
||||
|
||||
// Proper error handling bonus
|
||||
if code.contains("Result<") || code.contains("?;") {
|
||||
elegance += 0.1;
|
||||
}
|
||||
|
||||
elegance = elegance.clamp(0.0, 1.0);
|
||||
|
||||
// Constraint results
|
||||
let constraint_results = spec
|
||||
.banned_patterns
|
||||
.iter()
|
||||
.map(|p| !code.contains(p.as_str()))
|
||||
.collect();
|
||||
|
||||
let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
|
||||
|
||||
Evaluation {
|
||||
score: score.clamp(0.0, 1.0),
|
||||
correctness,
|
||||
efficiency,
|
||||
elegance,
|
||||
constraint_results,
|
||||
notes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RustSynthesisDomain {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// `Domain` implementation: task generation, heuristic evaluation,
/// feature-vector embedding, and reference solutions for the Rust
/// program-synthesis domain.
impl Domain for RustSynthesisDomain {
    /// Stable identifier of this domain instance.
    fn id(&self) -> &DomainId {
        &self.id
    }

    /// Human-readable domain name.
    fn name(&self) -> &str {
        "Rust Program Synthesis"
    }

    /// Generate `count` tasks at the given difficulty (clamped to [0.0, 1.0]).
    ///
    /// Per-task category mix is randomized: roll < 0.4 -> transform,
    /// < 0.7 -> data structure, otherwise algorithm.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);

        (0..count)
            .map(|i| {
                let category_roll: f32 = rng.gen();
                let spec = if category_roll < 0.4 {
                    self.gen_transform(difficulty, &mut rng)
                } else if category_roll < 0.7 {
                    self.gen_data_structure(difficulty, &mut rng)
                } else {
                    self.gen_algorithm(difficulty, &mut rng)
                };

                Task {
                    // Id embeds the index and the difficulty as a percentage.
                    id: format!("rust_synth_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    // Falls back to a default JSON value rather than panic if
                    // spec serialization ever fails.
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: spec.banned_patterns.clone(),
                }
            })
            .collect()
    }

    /// Evaluate a solution against the task's embedded `RustTaskSpec`.
    ///
    /// An unparseable spec yields a zero evaluation with an explanatory note.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        let spec: RustTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        self.score_solution(&spec, solution)
    }

    /// Embed a solution as a feature vector in this domain's space.
    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }

    /// Dimensionality of vectors produced by `embed`.
    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }

    /// Produce a canonical reference solution where one is known.
    ///
    /// Only `Transform` tasks have references; other categories return `None`.
    /// NOTE(review): the `max_subarray_sum` reference indexes `values[0]` and
    /// would panic on an empty slice — confirm generated inputs are non-empty.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: RustTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;

        let content = match spec.category {
            RustTaskCategory::Transform => {
                if spec.signature.contains("sum_positives") {
                    "fn sum_positives(values: &[i64]) -> i64 {\n values.iter().filter(|&&x| x > 0).sum()\n}".to_string()
                } else if spec.signature.contains("max_subarray_sum") {
                    // Kadane's-style scan over the slice.
                    "fn max_subarray_sum(values: &[i64]) -> i64 {\n let mut max_so_far = values[0];\n let mut max_ending = values[0];\n for &v in &values[1..] {\n max_ending = v.max(max_ending + v);\n max_so_far = max_so_far.max(max_ending);\n }\n max_so_far\n}".to_string()
                } else {
                    // Generic skeleton built from the spec's signature.
                    format!(
                        "{} {{\n values.iter().map(|&x| x /* TODO */).collect()\n}}",
                        spec.signature
                    )
                }
            }
            _ => return None,
        };

        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::Value::Null,
        })
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated tasks carry the domain id and the requested difficulty.
    #[test]
    fn test_generate_tasks() {
        let domain = RustSynthesisDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        for task in &tasks {
            assert_eq!(task.domain_id, domain.id);
            assert!((task.difficulty - 0.5).abs() < 1e-6);
        }
    }

    /// An idiomatic iterator-based solution should score above zero.
    #[test]
    fn test_evaluate_good_solution() {
        let domain = RustSynthesisDomain::new();
        let tasks = domain.generate_tasks(1, 0.0);
        let task = &tasks[0];

        let solution = Solution {
            task_id: task.id.clone(),
            content: "fn double(values: &[i64]) -> Vec<i64> {\n values.iter().map(|&x| x * 2).collect()\n}".to_string(),
            data: serde_json::Value::Null,
        };

        let eval = domain.evaluate(task, &solution);
        assert!(eval.score > 0.0);
    }

    /// Embedding dimensionality matches the advertised EMBEDDING_DIM.
    #[test]
    fn test_embed_produces_correct_dim() {
        let domain = RustSynthesisDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "fn foo() { let x = 1; }".into(),
            data: serde_json::Value::Null,
        };
        let embedding = domain.embed(&solution);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
        assert_eq!(embedding.vector.len(), EMBEDDING_DIM);
    }

    /// Non-trivial code yields a unit-length (L2-normalized) embedding.
    #[test]
    fn test_embedding_normalized() {
        let domain = RustSynthesisDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "fn foo() { for i in 0..10 { if i > 5 { println!(\"{}\", i); } } }".into(),
            data: serde_json::Value::Null,
        };
        let embedding = domain.embed(&solution);
        let norm: f32 = embedding.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4);
    }

    /// Specs produced at both ends of the difficulty range round-trip
    /// through JSON and carry a non-empty signature.
    #[test]
    fn test_difficulty_range() {
        let domain = RustSynthesisDomain::new();
        // Easy tasks
        let easy = domain.generate_tasks(3, 0.1);
        for t in &easy {
            let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            assert!(!spec.signature.is_empty());
        }
        // Hard tasks
        let hard = domain.generate_tasks(3, 0.9);
        for t in &hard {
            let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            assert!(!spec.signature.is_empty());
        }
    }
}
|
||||
715
vendor/ruvector/crates/ruvector-domain-expansion/src/rvf_bridge.rs
vendored
Normal file
715
vendor/ruvector/crates/ruvector-domain-expansion/src/rvf_bridge.rs
vendored
Normal file
@@ -0,0 +1,715 @@
|
||||
//! RVF Integration Bridge
|
||||
//!
|
||||
//! Connects the domain expansion engine to the RuVector Format (RVF):
|
||||
//! - Serializes `TransferPrior`, `PolicyKernel`, `CostCurve` into RVF segments
|
||||
//! - Creates SHAKE-256 witness chains for transfer verification
|
||||
//! - Packages domain expansion artifacts into AGI container TLV entries
|
||||
//! - Bridges priors to/from the rvf-solver-wasm `PolicyKernel`
|
||||
//!
|
||||
//! Requires the `rvf` feature to be enabled.
|
||||
|
||||
use rvf_types::{SegmentFlags, SegmentType};
|
||||
use rvf_wire::reader::{read_segment, validate_segment};
|
||||
use rvf_wire::writer::write_segment;
|
||||
|
||||
use crate::cost_curve::{AccelerationScoreboard, CostCurve};
|
||||
use crate::domain::DomainId;
|
||||
use crate::policy_kernel::PolicyKernel;
|
||||
use crate::transfer::{ArmId, BetaParams, ContextBucket, MetaThompsonEngine, TransferPrior};
|
||||
|
||||
// ─── Wire-format wrappers ───────────────────────────────────────────────────
|
||||
//
|
||||
// JSON requires string keys for objects. TransferPrior uses HashMap<ContextBucket, _>
|
||||
// which can't be directly serialized. These wrappers convert to/from Vec<(K,V)> form.
|
||||
|
||||
/// Wire-format representation of a TransferPrior (JSON-safe).
///
/// JSON objects require string keys, so the `HashMap` fields of
/// `TransferPrior` are flattened into Vec-of-tuples here.
#[derive(serde::Serialize, serde::Deserialize)]
struct WireTransferPrior {
    /// Domain the prior originates from.
    source_domain: DomainId,
    /// Per-bucket arm posteriors, flattened from nested maps.
    bucket_priors: Vec<(ContextBucket, Vec<(ArmId, BetaParams)>)>,
    /// Per-bucket cost EMA values, likewise flattened.
    cost_ema_priors: Vec<(ContextBucket, f32)>,
    /// Number of training cycles behind this prior.
    training_cycles: u64,
    /// Witness hash string, carried through serialization verbatim.
    witness_hash: String,
}
|
||||
|
||||
impl From<&TransferPrior> for WireTransferPrior {
|
||||
fn from(p: &TransferPrior) -> Self {
|
||||
Self {
|
||||
source_domain: p.source_domain.clone(),
|
||||
bucket_priors: p
|
||||
.bucket_priors
|
||||
.iter()
|
||||
.map(|(b, arms)| {
|
||||
let arm_vec: Vec<(ArmId, BetaParams)> =
|
||||
arms.iter().map(|(a, bp)| (a.clone(), bp.clone())).collect();
|
||||
(b.clone(), arm_vec)
|
||||
})
|
||||
.collect(),
|
||||
cost_ema_priors: p
|
||||
.cost_ema_priors
|
||||
.iter()
|
||||
.map(|(b, c)| (b.clone(), *c))
|
||||
.collect(),
|
||||
training_cycles: p.training_cycles,
|
||||
witness_hash: p.witness_hash.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireTransferPrior> for TransferPrior {
|
||||
fn from(w: WireTransferPrior) -> Self {
|
||||
let mut bucket_priors = std::collections::HashMap::new();
|
||||
for (bucket, arms) in w.bucket_priors {
|
||||
let arm_map: std::collections::HashMap<ArmId, BetaParams> = arms.into_iter().collect();
|
||||
bucket_priors.insert(bucket, arm_map);
|
||||
}
|
||||
let cost_ema_priors: std::collections::HashMap<ContextBucket, f32> =
|
||||
w.cost_ema_priors.into_iter().collect();
|
||||
Self {
|
||||
source_domain: w.source_domain,
|
||||
bucket_priors,
|
||||
cost_ema_priors,
|
||||
training_cycles: w.training_cycles,
|
||||
witness_hash: w.witness_hash,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wire-format representation of a PolicyKernel (JSON-safe).
///
/// The kernel's `holdout_scores` map is flattened into a Vec-of-tuples
/// because JSON objects require string keys.
#[derive(serde::Serialize, serde::Deserialize)]
struct WirePolicyKernel {
    /// Unique kernel identifier.
    id: String,
    /// Tunable policy parameters.
    knobs: crate::policy_kernel::PolicyKnobs,
    /// Per-domain holdout scores, flattened from a map.
    holdout_scores: Vec<(DomainId, f32)>,
    /// Accumulated training cost.
    total_cost: f32,
    /// Number of training cycles run.
    cycles: u64,
    /// Evolutionary generation number.
    generation: u32,
    /// Id of the parent kernel, if derived from one.
    parent_id: Option<String>,
    /// Whether this kernel passed replay verification.
    replay_verified: bool,
}
|
||||
|
||||
impl From<&PolicyKernel> for WirePolicyKernel {
    /// Flatten the kernel's holdout-score map into a JSON-safe tuple vector.
    fn from(kernel: &PolicyKernel) -> Self {
        let holdout_scores = kernel
            .holdout_scores
            .iter()
            .map(|(domain, score)| (domain.clone(), *score))
            .collect();

        Self {
            id: kernel.id.clone(),
            knobs: kernel.knobs.clone(),
            holdout_scores,
            total_cost: kernel.total_cost,
            cycles: kernel.cycles,
            generation: kernel.generation,
            parent_id: kernel.parent_id.clone(),
            replay_verified: kernel.replay_verified,
        }
    }
}
|
||||
|
||||
impl From<WirePolicyKernel> for PolicyKernel {
|
||||
fn from(w: WirePolicyKernel) -> Self {
|
||||
Self {
|
||||
id: w.id,
|
||||
knobs: w.knobs,
|
||||
holdout_scores: w.holdout_scores.into_iter().collect(),
|
||||
total_cost: w.total_cost,
|
||||
cycles: w.cycles,
|
||||
generation: w.generation,
|
||||
parent_id: w.parent_id,
|
||||
replay_verified: w.replay_verified,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Segment serialization ──────────────────────────────────────────────────
|
||||
|
||||
/// Serialize a `TransferPrior` into an RVF TRANSFER_PRIOR segment.
|
||||
///
|
||||
/// Wire format: JSON payload (using Vec-of-tuples for map keys) inside a
|
||||
/// 64-byte-aligned RVF segment. Type: `SegmentType::TransferPrior` (0x30).
|
||||
pub fn transfer_prior_to_segment(prior: &TransferPrior, segment_id: u64) -> Vec<u8> {
|
||||
let wire: WireTransferPrior = prior.into();
|
||||
let payload = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
|
||||
write_segment(
|
||||
SegmentType::TransferPrior as u8,
|
||||
&payload,
|
||||
SegmentFlags::empty(),
|
||||
segment_id,
|
||||
)
|
||||
}
|
||||
|
||||
/// Deserialize a `TransferPrior` from an RVF segment's raw bytes.
|
||||
///
|
||||
/// Validates the segment header, checks the content hash, and deserializes
|
||||
/// the JSON payload.
|
||||
pub fn transfer_prior_from_segment(data: &[u8]) -> Result<TransferPrior, RvfBridgeError> {
|
||||
let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
|
||||
if header.seg_type != SegmentType::TransferPrior as u8 {
|
||||
return Err(RvfBridgeError::WrongSegmentType {
|
||||
expected: SegmentType::TransferPrior as u8,
|
||||
got: header.seg_type,
|
||||
});
|
||||
}
|
||||
validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
|
||||
let wire: WireTransferPrior = serde_json::from_slice(payload).map_err(RvfBridgeError::Json)?;
|
||||
Ok(wire.into())
|
||||
}
|
||||
|
||||
/// Serialize a `PolicyKernel` into an RVF POLICY_KERNEL segment.
|
||||
pub fn policy_kernel_to_segment(kernel: &PolicyKernel, segment_id: u64) -> Vec<u8> {
|
||||
let wire: WirePolicyKernel = kernel.into();
|
||||
let payload = serde_json::to_vec(&wire).expect("WirePolicyKernel serialization cannot fail");
|
||||
write_segment(
|
||||
SegmentType::PolicyKernel as u8,
|
||||
&payload,
|
||||
SegmentFlags::empty(),
|
||||
segment_id,
|
||||
)
|
||||
}
|
||||
|
||||
/// Deserialize a `PolicyKernel` from an RVF segment.
|
||||
pub fn policy_kernel_from_segment(data: &[u8]) -> Result<PolicyKernel, RvfBridgeError> {
|
||||
let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
|
||||
if header.seg_type != SegmentType::PolicyKernel as u8 {
|
||||
return Err(RvfBridgeError::WrongSegmentType {
|
||||
expected: SegmentType::PolicyKernel as u8,
|
||||
got: header.seg_type,
|
||||
});
|
||||
}
|
||||
validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
|
||||
let wire: WirePolicyKernel = serde_json::from_slice(payload).map_err(RvfBridgeError::Json)?;
|
||||
Ok(wire.into())
|
||||
}
|
||||
|
||||
/// Serialize a `CostCurve` into an RVF COST_CURVE segment.
|
||||
pub fn cost_curve_to_segment(curve: &CostCurve, segment_id: u64) -> Vec<u8> {
|
||||
let payload = serde_json::to_vec(curve).expect("CostCurve serialization cannot fail");
|
||||
write_segment(
|
||||
SegmentType::CostCurve as u8,
|
||||
&payload,
|
||||
SegmentFlags::empty(),
|
||||
segment_id,
|
||||
)
|
||||
}
|
||||
|
||||
/// Deserialize a `CostCurve` from an RVF segment.
|
||||
pub fn cost_curve_from_segment(data: &[u8]) -> Result<CostCurve, RvfBridgeError> {
|
||||
let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
|
||||
if header.seg_type != SegmentType::CostCurve as u8 {
|
||||
return Err(RvfBridgeError::WrongSegmentType {
|
||||
expected: SegmentType::CostCurve as u8,
|
||||
got: header.seg_type,
|
||||
});
|
||||
}
|
||||
validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
|
||||
serde_json::from_slice(payload).map_err(RvfBridgeError::Json)
|
||||
}
|
||||
|
||||
// ─── Witness chain ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Witness type for a transfer-prior application event.
pub const WITNESS_TRANSFER: u8 = 0x10;
/// Witness type for policy kernel promotion.
pub const WITNESS_POLICY_PROMOTION: u8 = 0x11;
/// Witness type for cost curve convergence checkpoint.
pub const WITNESS_CONVERGENCE: u8 = 0x12;
|
||||
|
||||
/// Create a SHAKE-256 witness hash for a transfer prior.
|
||||
///
|
||||
/// The witness hash covers: source domain, training cycles, and the serialized
|
||||
/// bucket priors. This replaces the old string-based `witness_hash` field.
|
||||
pub fn compute_transfer_witness_hash(prior: &TransferPrior) -> [u8; 32] {
|
||||
let wire: WireTransferPrior = prior.into();
|
||||
let payload = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
|
||||
rvf_crypto::shake256_256(&payload)
|
||||
}
|
||||
|
||||
/// Build witness entries for a transfer verification event.
|
||||
///
|
||||
/// Returns entries suitable for `rvf_crypto::create_witness_chain()`.
|
||||
pub fn build_transfer_witness_entries(
|
||||
prior: &TransferPrior,
|
||||
source: &DomainId,
|
||||
target: &DomainId,
|
||||
acceleration_factor: f32,
|
||||
timestamp_ns: u64,
|
||||
) -> Vec<rvf_crypto::WitnessEntry> {
|
||||
let mut entries = Vec::with_capacity(2);
|
||||
|
||||
// Entry 1: Transfer prior hash
|
||||
let prior_hash = compute_transfer_witness_hash(prior);
|
||||
entries.push(rvf_crypto::WitnessEntry {
|
||||
prev_hash: [0u8; 32],
|
||||
action_hash: prior_hash,
|
||||
timestamp_ns,
|
||||
witness_type: WITNESS_TRANSFER,
|
||||
});
|
||||
|
||||
// Entry 2: Acceleration verification (hash of source→target + factor)
|
||||
let accel_payload = format!(
|
||||
"{}->{}:accel={:.6}",
|
||||
source.0, target.0, acceleration_factor
|
||||
);
|
||||
let accel_hash = rvf_crypto::shake256_256(accel_payload.as_bytes());
|
||||
entries.push(rvf_crypto::WitnessEntry {
|
||||
prev_hash: [0u8; 32], // chaining handled by create_witness_chain
|
||||
action_hash: accel_hash,
|
||||
timestamp_ns: timestamp_ns + 1,
|
||||
witness_type: WITNESS_CONVERGENCE,
|
||||
});
|
||||
|
||||
entries
|
||||
}
|
||||
|
||||
// ─── AGI Container TLV packaging ────────────────────────────────────────────
|
||||
|
||||
/// A TLV (Tag-Length-Value) entry for AGI container manifest packaging.
///
/// Encoded on the wire as `[tag: u16 LE][length: u32 LE][value]` — see
/// `encode_tlv_entries` / `decode_tlv_entries`.
#[derive(Debug, Clone)]
pub struct AgiTlvEntry {
    /// TLV tag (see `AGI_TAG_*` constants in rvf-types).
    pub tag: u16,
    /// Serialized value payload.
    pub value: Vec<u8>,
}
|
||||
|
||||
/// Package domain expansion artifacts into AGI container TLV entries.
|
||||
///
|
||||
/// Returns a vector of TLV entries ready for inclusion in an AGI container
|
||||
/// manifest segment. Each entry uses the corresponding `AGI_TAG_*` constant.
|
||||
pub fn package_for_agi_container(
|
||||
priors: &[TransferPrior],
|
||||
kernels: &[PolicyKernel],
|
||||
scoreboard: &AccelerationScoreboard,
|
||||
) -> Vec<AgiTlvEntry> {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
// Transfer priors (use wire format for JSON-safe serialization)
|
||||
for prior in priors {
|
||||
let wire: WireTransferPrior = prior.into();
|
||||
let value = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
|
||||
entries.push(AgiTlvEntry {
|
||||
tag: rvf_types::AGI_TAG_TRANSFER_PRIOR,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
// Policy kernels (use wire format for JSON-safe serialization)
|
||||
for kernel in kernels {
|
||||
let wire: WirePolicyKernel = kernel.into();
|
||||
let value = serde_json::to_vec(&wire).expect("WirePolicyKernel serialization cannot fail");
|
||||
entries.push(AgiTlvEntry {
|
||||
tag: rvf_types::AGI_TAG_POLICY_KERNEL,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
// Cost curves from the scoreboard
|
||||
for curve in scoreboard.curves.values() {
|
||||
let value = serde_json::to_vec(curve).expect("CostCurve serialization cannot fail");
|
||||
entries.push(AgiTlvEntry {
|
||||
tag: rvf_types::AGI_TAG_COST_CURVE,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
entries
|
||||
}
|
||||
|
||||
/// Encode TLV entries into a binary payload for inclusion in a META segment.
|
||||
///
|
||||
/// Wire format per entry: `[tag: u16 LE][length: u32 LE][value: length bytes]`
|
||||
pub fn encode_tlv_entries(entries: &[AgiTlvEntry]) -> Vec<u8> {
|
||||
let total_size: usize = entries.iter().map(|e| 6 + e.value.len()).sum();
|
||||
let mut buf = Vec::with_capacity(total_size);
|
||||
for entry in entries {
|
||||
buf.extend_from_slice(&entry.tag.to_le_bytes());
|
||||
buf.extend_from_slice(&(entry.value.len() as u32).to_le_bytes());
|
||||
buf.extend_from_slice(&entry.value);
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Decode TLV entries from a binary payload.
///
/// Each entry is `[tag: u16 LE][length: u32 LE][value: length bytes]`.
/// NOTE(review): 1–5 trailing bytes (a truncated header) are silently
/// ignored rather than reported — presumably to tolerate segment padding;
/// confirm that is intentional.
pub fn decode_tlv_entries(data: &[u8]) -> Result<Vec<AgiTlvEntry>, RvfBridgeError> {
    let mut entries = Vec::new();
    let mut offset = 0;
    // Consume entries while a full 6-byte header remains.
    while offset + 6 <= data.len() {
        let tag = u16::from_le_bytes([data[offset], data[offset + 1]]);
        let length = u32::from_le_bytes([
            data[offset + 2],
            data[offset + 3],
            data[offset + 4],
            data[offset + 5],
        ]) as usize;
        offset += 6;
        // A declared length overrunning the buffer is a hard error.
        if offset + length > data.len() {
            return Err(RvfBridgeError::TruncatedTlv);
        }
        entries.push(AgiTlvEntry {
            tag,
            value: data[offset..offset + length].to_vec(),
        });
        offset += length;
    }
    Ok(entries)
}
|
||||
|
||||
// ─── Solver bridge ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Compact prior exchange format bridging domain expansion's `MetaThompsonEngine`
/// to the rvf-solver-wasm `PolicyKernel`.
///
/// The solver-wasm uses per-bucket `SkipModeStats` with `(alpha_safety, beta_safety)`
/// and `cost_ema`. The domain expansion uses per-bucket `BetaParams` with
/// `(alpha, beta)` and `cost_ema_priors`. This type converts between them.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SolverPriorExchange {
    /// Context bucket key (e.g. "medium:some:clean").
    /// NOTE(review): `extract_solver_priors` emits two-part
    /// "difficulty_tier:category" keys — confirm which format the solver
    /// side actually expects.
    pub bucket_key: String,
    /// Per-arm alpha/beta pairs mapping arm name to (alpha, beta).
    pub arm_params: Vec<(String, f32, f32)>,
    /// Cost EMA for this bucket.
    pub cost_ema: f32,
    /// Training cycle count for confidence estimation.
    pub training_cycles: u64,
}
|
||||
|
||||
/// Extract solver-compatible prior exchange data from the Thompson engine.
///
/// Flattens the domain expansion's hierarchical buckets into flat
/// "range:distractor:noise" keys for the specified domain.
/// NOTE(review): the code actually emits two-part
/// "difficulty_tier:category" keys, not the three-part format the doc above
/// describes — confirm which one the solver expects.
pub fn extract_solver_priors(
    engine: &MetaThompsonEngine,
    domain_id: &DomainId,
) -> Vec<SolverPriorExchange> {
    // No prior recorded for this domain -> nothing to export.
    let prior = match engine.extract_prior(domain_id) {
        Some(p) => p,
        None => return Vec::new(),
    };

    prior
        .bucket_priors
        .iter()
        .map(|(bucket, arms)| {
            // Flat key: "<difficulty_tier>:<category>".
            let bucket_key = format!("{}:{}", bucket.difficulty_tier, bucket.category);
            let arm_params: Vec<(String, f32, f32)> = arms
                .iter()
                .map(|(arm, params)| (arm.0.clone(), params.alpha, params.beta))
                .collect();
            // Default cost EMA of 1.0 when none was tracked for the bucket.
            let cost_ema = prior.cost_ema_priors.get(bucket).copied().unwrap_or(1.0);

            SolverPriorExchange {
                bucket_key,
                arm_params,
                cost_ema,
                training_cycles: prior.training_cycles,
            }
        })
        .collect()
}
|
||||
|
||||
/// Import solver prior exchange data back into the Thompson engine.
///
/// Seeds the specified domain with the exchanged priors, enabling
/// cross-system transfer.
pub fn import_solver_priors(
    engine: &mut MetaThompsonEngine,
    domain_id: &DomainId,
    exchanges: &[SolverPriorExchange],
) {
    // Build a synthetic TransferPrior from the exchange data.
    let mut prior = TransferPrior::uniform(domain_id.clone());

    for exchange in exchanges {
        // Split "tier:category"; missing parts fall back to "medium"/"general".
        let parts: Vec<&str> = exchange.bucket_key.splitn(2, ':').collect();
        let bucket = ContextBucket {
            difficulty_tier: parts.first().unwrap_or(&"medium").to_string(),
            category: parts.get(1).unwrap_or(&"general").to_string(),
        };

        let mut arm_map = std::collections::HashMap::new();
        for (arm_name, alpha, beta) in &exchange.arm_params {
            arm_map.insert(
                crate::transfer::ArmId(arm_name.clone()),
                BetaParams {
                    alpha: *alpha,
                    beta: *beta,
                },
            );
        }
        prior.bucket_priors.insert(bucket.clone(), arm_map);
        prior.cost_ema_priors.insert(bucket, exchange.cost_ema);
        // NOTE(review): overwritten every iteration, so the synthetic prior
        // ends up with the LAST exchange's cycle count — confirm whether a
        // max or sum across exchanges was intended.
        prior.training_cycles = exchange.training_cycles;
    }

    engine.init_domain_with_transfer(domain_id.clone(), &prior);
}
|
||||
|
||||
// ─── Multi-segment file assembly ────────────────────────────────────────────
|
||||
|
||||
/// Assemble a complete RVF byte stream containing all domain expansion segments.
|
||||
///
|
||||
/// Outputs concatenated segments: transfer priors, then policy kernels, then
|
||||
/// cost curves. Each gets a unique segment ID starting from `base_segment_id`.
|
||||
///
|
||||
/// The returned bytes can be appended to an existing RVF file or written as
|
||||
/// a standalone domain expansion archive.
|
||||
pub fn assemble_domain_expansion_segments(
|
||||
priors: &[TransferPrior],
|
||||
kernels: &[PolicyKernel],
|
||||
curves: &[CostCurve],
|
||||
base_segment_id: u64,
|
||||
) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
let mut seg_id = base_segment_id;
|
||||
|
||||
for prior in priors {
|
||||
buf.extend_from_slice(&transfer_prior_to_segment(prior, seg_id));
|
||||
seg_id += 1;
|
||||
}
|
||||
for kernel in kernels {
|
||||
buf.extend_from_slice(&policy_kernel_to_segment(kernel, seg_id));
|
||||
seg_id += 1;
|
||||
}
|
||||
for curve in curves {
|
||||
buf.extend_from_slice(&cost_curve_to_segment(curve, seg_id));
|
||||
seg_id += 1;
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
// ─── Errors ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Errors specific to the RVF bridge operations.
///
/// Implements `Display` and `std::error::Error` (see impls below).
#[derive(Debug)]
pub enum RvfBridgeError {
    /// Underlying RVF format error (header parse or hash validation).
    Rvf(rvf_types::RvfError),
    /// JSON serialization/deserialization error.
    Json(serde_json::Error),
    /// Segment type mismatch.
    WrongSegmentType {
        /// Expected segment type discriminant.
        expected: u8,
        /// Actual segment type discriminant.
        got: u8,
    },
    /// TLV payload truncated (declared length overruns the buffer).
    TruncatedTlv,
}
|
||||
|
||||
/// Human-readable rendering of bridge errors; delegates to the inner
/// error's `Display` where one exists.
impl std::fmt::Display for RvfBridgeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Rvf(e) => write!(f, "RVF error: {e}"),
            Self::Json(e) => write!(f, "JSON error: {e}"),
            Self::WrongSegmentType { expected, got } => {
                write!(
                    f,
                    "wrong segment type: expected 0x{expected:02X}, got 0x{got:02X}"
                )
            }
            Self::TruncatedTlv => write!(f, "TLV payload truncated"),
        }
    }
}
|
||||
|
||||
impl std::error::Error for RvfBridgeError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Json(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::cost_curve::{ConvergenceThresholds, CostCurvePoint};
|
||||
|
||||
#[test]
|
||||
fn transfer_prior_round_trip() {
|
||||
let mut prior = TransferPrior::uniform(DomainId("test".into()));
|
||||
let bucket = ContextBucket {
|
||||
difficulty_tier: "medium".into(),
|
||||
category: "algo".into(),
|
||||
};
|
||||
prior.update_posterior(bucket, crate::transfer::ArmId("greedy".into()), 0.85);
|
||||
|
||||
let segment = transfer_prior_to_segment(&prior, 1);
|
||||
let decoded = transfer_prior_from_segment(&segment).unwrap();
|
||||
|
||||
assert_eq!(decoded.source_domain, prior.source_domain);
|
||||
assert_eq!(decoded.training_cycles, prior.training_cycles);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_kernel_round_trip() {
|
||||
let kernel = PolicyKernel::new("test_kernel".into());
|
||||
let segment = policy_kernel_to_segment(&kernel, 2);
|
||||
let decoded = policy_kernel_from_segment(&segment).unwrap();
|
||||
|
||||
assert_eq!(decoded.id, "test_kernel");
|
||||
assert_eq!(decoded.generation, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cost_curve_round_trip() {
|
||||
let mut curve = CostCurve::new(DomainId("test".into()), ConvergenceThresholds::default());
|
||||
curve.record(CostCurvePoint {
|
||||
cycle: 0,
|
||||
accuracy: 0.3,
|
||||
cost_per_solve: 0.1,
|
||||
robustness: 0.3,
|
||||
policy_violations: 0,
|
||||
timestamp: 0.0,
|
||||
});
|
||||
|
||||
let segment = cost_curve_to_segment(&curve, 3);
|
||||
let decoded = cost_curve_from_segment(&segment).unwrap();
|
||||
|
||||
assert_eq!(decoded.domain_id, DomainId("test".into()));
|
||||
assert_eq!(decoded.points.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_segment_type_detected() {
|
||||
let kernel = PolicyKernel::new("k".into());
|
||||
let segment = policy_kernel_to_segment(&kernel, 1);
|
||||
let result = transfer_prior_from_segment(&segment);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(RvfBridgeError::WrongSegmentType { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn witness_hash_is_deterministic() {
|
||||
let prior = TransferPrior::uniform(DomainId("test".into()));
|
||||
let h1 = compute_transfer_witness_hash(&prior);
|
||||
let h2 = compute_transfer_witness_hash(&prior);
|
||||
assert_eq!(h1, h2);
|
||||
assert_ne!(h1, [0u8; 32]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn witness_entries_chain() {
|
||||
let prior = TransferPrior::uniform(DomainId("d1".into()));
|
||||
let entries = build_transfer_witness_entries(
|
||||
&prior,
|
||||
&DomainId("d1".into()),
|
||||
&DomainId("d2".into()),
|
||||
2.5,
|
||||
1_000_000_000,
|
||||
);
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert_eq!(entries[0].witness_type, WITNESS_TRANSFER);
|
||||
assert_eq!(entries[1].witness_type, WITNESS_CONVERGENCE);
|
||||
|
||||
// Verify the chain is valid after linking
|
||||
let chain_bytes = rvf_crypto::create_witness_chain(&entries);
|
||||
let verified = rvf_crypto::verify_witness_chain(&chain_bytes).unwrap();
|
||||
assert_eq!(verified.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
fn tlv_round_trip() {
    // Two tagged payloads must survive an encode/decode cycle unchanged.
    let original = vec![
        AgiTlvEntry {
            tag: rvf_types::AGI_TAG_TRANSFER_PRIOR,
            value: b"hello".to_vec(),
        },
        AgiTlvEntry {
            tag: rvf_types::AGI_TAG_POLICY_KERNEL,
            value: b"world".to_vec(),
        },
    ];

    let round_tripped = decode_tlv_entries(&encode_tlv_entries(&original)).unwrap();

    assert_eq!(round_tripped.len(), 2);
    assert_eq!(round_tripped[0].tag, rvf_types::AGI_TAG_TRANSFER_PRIOR);
    assert_eq!(round_tripped[0].value, b"hello");
    assert_eq!(round_tripped[1].tag, rvf_types::AGI_TAG_POLICY_KERNEL);
    assert_eq!(round_tripped[1].value, b"world");
}
|
||||
|
||||
#[test]
fn agi_container_packaging() {
    // Packaging one prior, one kernel, and an empty scoreboard must yield
    // exactly two TLV entries that round-trip through encode/decode.
    let prior = TransferPrior::uniform(DomainId("test".into()));
    let kernel = PolicyKernel::new("k0".into());
    let scoreboard = crate::cost_curve::AccelerationScoreboard::new();

    let entries = package_for_agi_container(&[prior], &[kernel], &scoreboard);
    assert_eq!(entries.len(), 2); // 1 prior + 1 kernel, 0 curves

    let encoded = encode_tlv_entries(&entries);
    let decoded = decode_tlv_entries(&encoded).unwrap();
    assert_eq!(decoded.len(), 2);
}
|
||||
|
||||
#[test]
fn solver_prior_exchange_round_trip() {
    // Train a source engine on one bucket, export its priors, import them
    // into a fresh engine for a different domain, and check the priors
    // survive the round trip.
    let arms = vec!["greedy".into(), "exploratory".into()];
    let mut engine = MetaThompsonEngine::new(arms);
    let domain = DomainId("test".into());
    engine.init_domain_uniform(domain.clone());

    let bucket = ContextBucket {
        difficulty_tier: "medium".into(),
        category: "algorithm".into(),
    };
    // 20 high-reward outcomes for the "greedy" arm gives the posterior
    // something non-uniform to export.
    for _ in 0..20 {
        engine.record_outcome(
            &domain,
            bucket.clone(),
            crate::transfer::ArmId("greedy".into()),
            0.9,
            1.0,
        );
    }

    let exchanges = extract_solver_priors(&engine, &domain);
    assert!(!exchanges.is_empty());

    // Import into a fresh engine
    let new_arms = vec!["greedy".into(), "exploratory".into()];
    let mut new_engine = MetaThompsonEngine::new(new_arms);
    let target = DomainId("target".into());
    new_engine.init_domain_uniform(target.clone());
    import_solver_priors(&mut new_engine, &target, &exchanges);

    // Should have transferred priors
    let extracted = new_engine.extract_prior(&target);
    assert!(extracted.is_some());
}
|
||||
|
||||
#[test]
fn multi_segment_assembly() {
    // Assembling one prior, one kernel, and one single-point cost curve
    // must produce three 64-byte-aligned segments with a valid magic.
    let prior = TransferPrior::uniform(DomainId("d1".into()));
    let kernel = PolicyKernel::new("k0".into());
    let mut curve = CostCurve::new(DomainId("d1".into()), ConvergenceThresholds::default());
    curve.record(CostCurvePoint {
        cycle: 0,
        accuracy: 0.5,
        cost_per_solve: 0.05,
        robustness: 0.5,
        policy_violations: 0,
        timestamp: 0.0,
    });

    let assembled = assemble_domain_expansion_segments(&[prior], &[kernel], &[curve], 100);

    // Should contain 3 segments, each 64-byte aligned
    assert!(assembled.len() >= 3 * 64);
    assert_eq!(assembled.len() % 64, 0);

    // Verify first segment header magic
    let magic = u32::from_le_bytes([assembled[0], assembled[1], assembled[2], assembled[3]]);
    assert_eq!(magic, rvf_types::SEGMENT_MAGIC);
}
|
||||
}
|
||||
727
vendor/ruvector/crates/ruvector-domain-expansion/src/tool_orchestration.rs
vendored
Normal file
727
vendor/ruvector/crates/ruvector-domain-expansion/src/tool_orchestration.rs
vendored
Normal file
@@ -0,0 +1,727 @@
|
||||
//! Tool Orchestration Problems Domain
|
||||
//!
|
||||
//! Generates tasks requiring coordinating multiple tools/agents to achieve goals.
|
||||
//! Task types include:
|
||||
//!
|
||||
//! - **PipelineConstruction**: Build a data processing pipeline from available tools
|
||||
//! - **ErrorRecovery**: Handle failures in multi-step tool chains
|
||||
//! - **ParallelCoordination**: Execute independent tool calls concurrently
|
||||
//! - **ResourceNegotiation**: Manage shared resources across tool invocations
|
||||
//! - **AdaptiveRouting**: Select tools dynamically based on intermediate results
|
||||
//!
|
||||
//! Cross-domain transfer is strongest here: planning decomposes goals,
|
||||
//! Rust synthesis provides execution patterns, and orchestration combines them.
|
||||
|
||||
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const EMBEDDING_DIM: usize = 64;
|
||||
|
||||
/// Categories of tool orchestration tasks.
///
/// Serialized into `OrchestrationTaskSpec::category` (serde).
/// NOTE(review): `generate_tasks` currently only produces the first
/// three variants — confirm whether ResourceNegotiation and
/// AdaptiveRouting are intentionally unreachable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrchestrationCategory {
    /// Build a pipeline: chain tools to transform input to desired output.
    PipelineConstruction,
    /// Handle failure: detect errors and apply fallback strategies.
    ErrorRecovery,
    /// Coordinate parallel: dispatch independent calls and merge results.
    ParallelCoordination,
    /// Negotiate resources: manage rate limits, quotas, shared state.
    ResourceNegotiation,
    /// Adaptive routing: choose tool based on intermediate result properties.
    AdaptiveRouting,
}
|
||||
|
||||
/// A tool available in the orchestration environment.
///
/// Scoring uses `input_type`/`output_type` to validate call chains,
/// `latency_ms`/`cost` to estimate budgets, and `failure_rate` to decide
/// whether the reference solution adds retries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSpec {
    // Unique tool name; `ToolCall::tool_name` references this.
    pub name: String,
    pub description: String,
    /// Input type signature (e.g., "text", "json", "binary").
    pub input_type: String,
    /// Output type signature.
    pub output_type: String,
    /// Average latency in milliseconds.
    pub latency_ms: u32,
    /// Failure rate [0.0, 1.0].
    pub failure_rate: f32,
    /// Cost per invocation.
    pub cost: f32,
    /// Rate limit (max calls per minute), 0 = unlimited.
    pub rate_limit: u32,
}
|
||||
|
||||
/// An orchestration task specification.
///
/// Serialized into `Task::spec` by `generate_tasks` and deserialized
/// back in `evaluate`/`reference_solution`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationTaskSpec {
    pub category: OrchestrationCategory,
    pub description: String,
    /// Available tools in the environment.
    pub available_tools: Vec<ToolSpec>,
    /// Input to the pipeline.
    pub input: serde_json::Value,
    /// Expected output type/shape.
    pub expected_output_type: String,
    /// Maximum total latency budget (ms).
    pub latency_budget_ms: u32,
    /// Maximum total cost budget.
    pub cost_budget: f32,
    /// Required reliability (min success rate).
    pub min_reliability: f32,
    /// Error scenarios that must be handled.
    /// Empty for pipeline/parallel tasks; populated by error-recovery tasks.
    pub error_scenarios: Vec<String>,
}
|
||||
|
||||
/// A tool call in an orchestration solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    // Must match a `ToolSpec::name` for scoring to recognize the call.
    pub tool_name: String,
    /// Input to this tool call (ref to previous output or literal).
    pub input_ref: String,
    /// Whether this can run in parallel with other calls.
    /// Calls sharing a group id contribute only their max latency.
    pub parallel_group: Option<u32>,
    /// Fallback tool if this one fails.
    pub fallback: Option<String>,
    /// Retry count on failure. Each retry adds one unit of tool cost
    /// to the estimated total.
    pub retries: u32,
}
|
||||
|
||||
/// A parsed orchestration plan.
///
/// This is the JSON shape expected in `Solution::data` (or, as a
/// fallback, parseable from `Solution::content`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationPlan {
    // Ordered tool invocations; adjacent sequential calls are
    // type-checked output -> input during scoring.
    pub calls: Vec<ToolCall>,
    /// Error handling strategy description.
    pub error_strategy: String,
}
|
||||
|
||||
/// Tool orchestration domain.
///
/// Stateless apart from its fixed `DomainId`; implements the `Domain`
/// trait for task generation, evaluation, and embedding.
pub struct ToolOrchestrationDomain {
    // Fixed identifier "tool_orchestration", set in `new()`.
    id: DomainId,
}
|
||||
|
||||
impl ToolOrchestrationDomain {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: DomainId("tool_orchestration".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// The fixed catalogue of tools shared by all generated tasks.
///
/// Ordering matters: `gen_pipeline` hands out a difficulty-dependent
/// prefix of this list, so cheap/reliable tools come first and the
/// flaky/expensive ones (code_execute, http_fetch) appear only at
/// higher difficulty.
fn base_tools() -> Vec<ToolSpec> {
    vec![
        // Document ingestion: binary -> text.
        ToolSpec {
            name: "text_extract".into(),
            description: "Extract text from documents".into(),
            input_type: "binary".into(),
            output_type: "text".into(),
            latency_ms: 50,
            failure_rate: 0.02,
            cost: 0.001,
            rate_limit: 100,
        },
        // Embedding: text -> vector.
        ToolSpec {
            name: "text_embed".into(),
            description: "Generate embeddings from text".into(),
            input_type: "text".into(),
            output_type: "vector".into(),
            latency_ms: 30,
            failure_rate: 0.01,
            cost: 0.002,
            rate_limit: 200,
        },
        // Retrieval: vector -> json.
        ToolSpec {
            name: "vector_search".into(),
            description: "Search vector index for similar items".into(),
            input_type: "vector".into(),
            output_type: "json".into(),
            latency_ms: 10,
            failure_rate: 0.005,
            cost: 0.0005,
            rate_limit: 500,
        },
        // Generation: slowest and most expensive tool in the set.
        ToolSpec {
            name: "llm_generate".into(),
            description: "Generate text using language model".into(),
            input_type: "text".into(),
            output_type: "text".into(),
            latency_ms: 2000,
            failure_rate: 0.05,
            cost: 0.01,
            rate_limit: 30,
        },
        // Cheap json -> json reshaping; effectively unlimited.
        ToolSpec {
            name: "json_transform".into(),
            description: "Apply JQ-like transformations to JSON".into(),
            input_type: "json".into(),
            output_type: "json".into(),
            latency_ms: 5,
            failure_rate: 0.001,
            cost: 0.0001,
            rate_limit: 0,
        },
        // Sandboxed execution: highest failure rate (10%).
        ToolSpec {
            name: "code_execute".into(),
            description: "Execute code in sandboxed environment".into(),
            input_type: "text".into(),
            output_type: "json".into(),
            latency_ms: 500,
            failure_rate: 0.1,
            cost: 0.005,
            rate_limit: 20,
        },
        // External network call: free but flaky (15%).
        ToolSpec {
            name: "http_fetch".into(),
            description: "Fetch data from external HTTP endpoint".into(),
            input_type: "text".into(),
            output_type: "json".into(),
            latency_ms: 300,
            failure_rate: 0.15,
            cost: 0.0,
            rate_limit: 60,
        },
        // Local cache: free, instant, never fails.
        ToolSpec {
            name: "cache_lookup".into(),
            description: "Check local cache for previously computed results".into(),
            input_type: "text".into(),
            output_type: "json".into(),
            latency_ms: 1,
            failure_rate: 0.0,
            cost: 0.0,
            rate_limit: 0,
        },
        // Schema validation: elegance scoring rewards using this.
        ToolSpec {
            name: "validator".into(),
            description: "Validate output against schema".into(),
            input_type: "json".into(),
            output_type: "json".into(),
            latency_ms: 2,
            failure_rate: 0.0,
            cost: 0.0,
            rate_limit: 0,
        },
        // Fan-in: merges parallel results.
        ToolSpec {
            name: "aggregator".into(),
            description: "Merge multiple results into one".into(),
            input_type: "json".into(),
            output_type: "json".into(),
            latency_ms: 5,
            failure_rate: 0.0,
            cost: 0.0001,
            rate_limit: 0,
        },
    ]
}
|
||||
|
||||
/// Generate a PipelineConstruction task.
///
/// Difficulty controls how many tools from the catalogue prefix are
/// available (3 / 6 / 10) and tightens the latency, cost, and
/// reliability budgets past difficulty 0.5.
fn gen_pipeline(&self, difficulty: f32) -> OrchestrationTaskSpec {
    let tools = Self::base_tools();
    let num_tools = if difficulty < 0.3 {
        3
    } else if difficulty < 0.7 {
        6
    } else {
        10
    };

    OrchestrationTaskSpec {
        category: OrchestrationCategory::PipelineConstruction,
        description: format!(
            "Build a RAG pipeline using {} tools: extract, embed, search, generate.",
            num_tools
        ),
        // Prefix of the catalogue; `.min` guards against a shrunken catalogue.
        available_tools: tools[..num_tools.min(tools.len())].to_vec(),
        input: serde_json::json!({"type": "binary", "format": "pdf"}),
        expected_output_type: "text".into(),
        // Harder tasks get tighter budgets and stricter reliability.
        latency_budget_ms: if difficulty < 0.5 { 5000 } else { 2000 },
        cost_budget: if difficulty < 0.5 { 0.1 } else { 0.02 },
        min_reliability: if difficulty < 0.5 { 0.9 } else { 0.99 },
        error_scenarios: Vec::new(),
    }
}
|
||||
|
||||
/// Generate an ErrorRecovery task.
///
/// Difficulty scales the number of failure scenarios the plan must
/// handle (1 / 3 / 5); the full tool catalogue is always available.
fn gen_error_recovery(&self, difficulty: f32) -> OrchestrationTaskSpec {
    let tools = Self::base_tools();
    let error_scenarios = if difficulty < 0.3 {
        vec!["timeout on llm_generate".into()]
    } else if difficulty < 0.7 {
        vec![
            "timeout on llm_generate".into(),
            "http_fetch returns 429".into(),
            "code_execute sandbox OOM".into(),
        ]
    } else {
        vec![
            "timeout on llm_generate".into(),
            "http_fetch returns 429".into(),
            "code_execute sandbox OOM".into(),
            "vector_search index corruption".into(),
            "cascading failure: embed + search both down".into(),
        ]
    };

    OrchestrationTaskSpec {
        category: OrchestrationCategory::ErrorRecovery,
        description: format!(
            "Handle {} error scenarios in a multi-tool pipeline with graceful degradation.",
            error_scenarios.len()
        ),
        available_tools: tools,
        input: serde_json::json!({"type": "text", "content": "query"}),
        expected_output_type: "json".into(),
        // Fixed, generous budgets: the challenge here is error handling,
        // not latency/cost optimization.
        latency_budget_ms: 10000,
        cost_budget: 0.1,
        min_reliability: 0.95,
        error_scenarios,
    }
}
|
||||
|
||||
/// Generate a ParallelCoordination task.
///
/// Difficulty scales the number of independent chains (2 / 4 / 8) and
/// halves the latency budget past 0.5; the cost budget scales linearly
/// with parallelism.
fn gen_parallel_coordination(&self, difficulty: f32) -> OrchestrationTaskSpec {
    let tools = Self::base_tools();
    let parallelism = if difficulty < 0.3 {
        2
    } else if difficulty < 0.7 {
        4
    } else {
        8
    };

    OrchestrationTaskSpec {
        category: OrchestrationCategory::ParallelCoordination,
        description: format!(
            "Execute {} independent tool chains in parallel, merge results within latency budget.",
            parallelism
        ),
        available_tools: tools,
        // One query per parallel chain: query_0 .. query_{parallelism-1}.
        input: serde_json::json!({"queries": (0..parallelism).map(|i| format!("query_{}", i)).collect::<Vec<_>>()}),
        expected_output_type: "json".into(),
        latency_budget_ms: if difficulty < 0.5 { 3000 } else { 1000 },
        cost_budget: 0.05 * parallelism as f32,
        min_reliability: 0.95,
        error_scenarios: Vec::new(),
    }
}
|
||||
|
||||
fn extract_features(&self, solution: &Solution) -> Vec<f32> {
|
||||
let content = &solution.content;
|
||||
let mut features = vec![0.0f32; EMBEDDING_DIM];
|
||||
|
||||
let plan: OrchestrationPlan = serde_json::from_str(&solution.data.to_string())
|
||||
.or_else(|_| serde_json::from_str(content))
|
||||
.unwrap_or(OrchestrationPlan {
|
||||
calls: Vec::new(),
|
||||
error_strategy: String::new(),
|
||||
});
|
||||
|
||||
// Feature 0-7: Plan structure
|
||||
features[0] = plan.calls.len() as f32 / 20.0;
|
||||
let unique_tools: std::collections::HashSet<&str> =
|
||||
plan.calls.iter().map(|c| c.tool_name.as_str()).collect();
|
||||
features[1] = unique_tools.len() as f32 / 10.0;
|
||||
// Parallelism ratio
|
||||
let parallel_calls = plan
|
||||
.calls
|
||||
.iter()
|
||||
.filter(|c| c.parallel_group.is_some())
|
||||
.count();
|
||||
features[2] = parallel_calls as f32 / plan.calls.len().max(1) as f32;
|
||||
// Fallback coverage
|
||||
let fallback_calls = plan.calls.iter().filter(|c| c.fallback.is_some()).count();
|
||||
features[3] = fallback_calls as f32 / plan.calls.len().max(1) as f32;
|
||||
// Average retries
|
||||
let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum();
|
||||
features[4] = total_retries as f32 / plan.calls.len().max(1) as f32 / 5.0;
|
||||
|
||||
// Feature 8-15: Tool type usage
|
||||
let tool_names = [
|
||||
"extract",
|
||||
"embed",
|
||||
"search",
|
||||
"generate",
|
||||
"transform",
|
||||
"execute",
|
||||
"fetch",
|
||||
"cache",
|
||||
];
|
||||
for (i, name) in tool_names.iter().enumerate() {
|
||||
features[8 + i] = plan
|
||||
.calls
|
||||
.iter()
|
||||
.filter(|c| c.tool_name.contains(name))
|
||||
.count() as f32
|
||||
/ plan.calls.len().max(1) as f32;
|
||||
}
|
||||
|
||||
// Feature 16-23: Text pattern features
|
||||
features[16] = content.matches("pipeline").count() as f32 / 3.0;
|
||||
features[17] = content.matches("parallel").count() as f32 / 5.0;
|
||||
features[18] = content.matches("fallback").count() as f32 / 5.0;
|
||||
features[19] = content.matches("retry").count() as f32 / 5.0;
|
||||
features[20] = content.matches("cache").count() as f32 / 5.0;
|
||||
features[21] = content.matches("timeout").count() as f32 / 3.0;
|
||||
features[22] = content.matches("merge").count() as f32 / 3.0;
|
||||
features[23] = content.matches("validate").count() as f32 / 3.0;
|
||||
|
||||
// Feature 32-39: Error handling patterns
|
||||
features[32] = content.matches("error").count() as f32 / 5.0;
|
||||
features[33] = content.matches("recover").count() as f32 / 3.0;
|
||||
features[34] = content.matches("degrade").count() as f32 / 3.0;
|
||||
features[35] = content.matches("circuit_break").count() as f32 / 2.0;
|
||||
features[36] = content.matches("rate_limit").count() as f32 / 3.0;
|
||||
features[37] = content.matches("backoff").count() as f32 / 3.0;
|
||||
features[38] = content.matches("health_check").count() as f32 / 2.0;
|
||||
features[39] = content.matches("monitor").count() as f32 / 3.0;
|
||||
|
||||
// Feature 48-55: Coordination patterns
|
||||
features[48] = content.matches("scatter").count() as f32 / 2.0;
|
||||
features[49] = content.matches("gather").count() as f32 / 2.0;
|
||||
features[50] = content.matches("fan_out").count() as f32 / 2.0;
|
||||
features[51] = content.matches("aggregate").count() as f32 / 3.0;
|
||||
features[52] = content.matches("route").count() as f32 / 3.0;
|
||||
features[53] = content.matches("dispatch").count() as f32 / 3.0;
|
||||
features[54] = content.matches("await").count() as f32 / 5.0;
|
||||
features[55] = content.matches("join").count() as f32 / 3.0;
|
||||
|
||||
// Normalize
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-10 {
|
||||
for f in &mut features {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
fn score_orchestration(&self, spec: &OrchestrationTaskSpec, solution: &Solution) -> Evaluation {
|
||||
let content = &solution.content;
|
||||
let mut correctness = 0.0f32;
|
||||
let mut efficiency = 0.5f32;
|
||||
let mut elegance = 0.5f32;
|
||||
let mut notes = Vec::new();
|
||||
|
||||
let plan: Option<OrchestrationPlan> = serde_json::from_str(&solution.data.to_string())
|
||||
.ok()
|
||||
.or_else(|| serde_json::from_str(content).ok());
|
||||
|
||||
let plan = match plan {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
let has_tools = spec
|
||||
.available_tools
|
||||
.iter()
|
||||
.any(|t| content.contains(&t.name));
|
||||
if has_tools {
|
||||
correctness = 0.2;
|
||||
}
|
||||
return Evaluation {
|
||||
score: correctness * 0.6,
|
||||
correctness,
|
||||
efficiency: 0.0,
|
||||
elegance: 0.0,
|
||||
constraint_results: Vec::new(),
|
||||
notes: vec!["Could not parse orchestration plan".into()],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if plan.calls.is_empty() {
|
||||
return Evaluation::zero(vec!["Empty orchestration plan".into()]);
|
||||
}
|
||||
|
||||
// Correctness: type chain validity
|
||||
let mut type_errors = 0;
|
||||
for window in plan.calls.windows(2) {
|
||||
let output_tool = spec
|
||||
.available_tools
|
||||
.iter()
|
||||
.find(|t| t.name == window[0].tool_name);
|
||||
let input_tool = spec
|
||||
.available_tools
|
||||
.iter()
|
||||
.find(|t| t.name == window[1].tool_name);
|
||||
|
||||
if let (Some(out_t), Some(in_t)) = (output_tool, input_tool) {
|
||||
if window[1].parallel_group.is_none() && out_t.output_type != in_t.input_type {
|
||||
type_errors += 1;
|
||||
notes.push(format!(
|
||||
"Type mismatch: {} outputs {} but {} expects {}",
|
||||
out_t.name, out_t.output_type, in_t.name, in_t.input_type
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
let chain_len = (plan.calls.len() - 1).max(1);
|
||||
correctness = 1.0 - (type_errors as f32 / chain_len as f32);
|
||||
|
||||
// Tool coverage: do we use tools that produce the expected output?
|
||||
let produces_output = plan.calls.iter().any(|c| {
|
||||
spec.available_tools
|
||||
.iter()
|
||||
.any(|t| t.name == c.tool_name && t.output_type == spec.expected_output_type)
|
||||
});
|
||||
if !produces_output {
|
||||
correctness *= 0.5;
|
||||
notes.push("No tool produces the expected output type".into());
|
||||
}
|
||||
|
||||
// Error handling coverage
|
||||
if !spec.error_scenarios.is_empty() {
|
||||
let handled = spec
|
||||
.error_scenarios
|
||||
.iter()
|
||||
.filter(|scenario| {
|
||||
plan.calls
|
||||
.iter()
|
||||
.any(|c| c.fallback.is_some() || c.retries > 0)
|
||||
|| plan
|
||||
.error_strategy
|
||||
.contains(&scenario.as_str()[..scenario.len().min(10)])
|
||||
})
|
||||
.count() as f32
|
||||
/ spec.error_scenarios.len() as f32;
|
||||
correctness = correctness * 0.7 + handled * 0.3;
|
||||
}
|
||||
|
||||
// Efficiency: estimated latency and cost
|
||||
let est_latency: u32 = {
|
||||
let mut groups: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
|
||||
let mut sequential_latency = 0u32;
|
||||
for call in &plan.calls {
|
||||
let tool_latency = spec
|
||||
.available_tools
|
||||
.iter()
|
||||
.find(|t| t.name == call.tool_name)
|
||||
.map(|t| t.latency_ms)
|
||||
.unwrap_or(100);
|
||||
|
||||
if let Some(group) = call.parallel_group {
|
||||
let entry = groups.entry(group).or_insert(0);
|
||||
*entry = (*entry).max(tool_latency);
|
||||
} else {
|
||||
sequential_latency += tool_latency;
|
||||
}
|
||||
}
|
||||
sequential_latency + groups.values().sum::<u32>()
|
||||
};
|
||||
|
||||
if est_latency <= spec.latency_budget_ms {
|
||||
efficiency = 1.0 - (est_latency as f32 / spec.latency_budget_ms as f32 * 0.5);
|
||||
} else {
|
||||
efficiency = spec.latency_budget_ms as f32 / est_latency as f32 * 0.5;
|
||||
notes.push(format!(
|
||||
"Estimated latency {}ms exceeds budget {}ms",
|
||||
est_latency, spec.latency_budget_ms
|
||||
));
|
||||
}
|
||||
|
||||
let est_cost: f32 = plan
|
||||
.calls
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
spec.available_tools
|
||||
.iter()
|
||||
.find(|t| t.name == c.tool_name)
|
||||
.map(|t| t.cost * (1.0 + c.retries as f32))
|
||||
})
|
||||
.sum();
|
||||
|
||||
if est_cost > spec.cost_budget {
|
||||
efficiency *= 0.7;
|
||||
notes.push(format!(
|
||||
"Cost {:.4} exceeds budget {:.4}",
|
||||
est_cost, spec.cost_budget
|
||||
));
|
||||
}
|
||||
|
||||
// Elegance: parallelism, caching, minimal redundancy
|
||||
let parallelism_used = plan.calls.iter().any(|c| c.parallel_group.is_some());
|
||||
if parallelism_used {
|
||||
elegance += 0.15;
|
||||
}
|
||||
|
||||
let cache_used = plan.calls.iter().any(|c| c.tool_name.contains("cache"));
|
||||
if cache_used {
|
||||
elegance += 0.1;
|
||||
}
|
||||
|
||||
let validation_used = plan.calls.iter().any(|c| c.tool_name.contains("validat"));
|
||||
if validation_used {
|
||||
elegance += 0.1;
|
||||
}
|
||||
|
||||
// Penalize excessive retries
|
||||
let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum();
|
||||
if total_retries > plan.calls.len() as u32 * 2 {
|
||||
elegance -= 0.2;
|
||||
notes.push("Excessive retry configuration".into());
|
||||
}
|
||||
|
||||
elegance = elegance.clamp(0.0, 1.0);
|
||||
|
||||
let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
|
||||
Evaluation {
|
||||
score: score.clamp(0.0, 1.0),
|
||||
correctness,
|
||||
efficiency,
|
||||
elegance,
|
||||
constraint_results: Vec::new(),
|
||||
notes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// `Default` delegates to `new()` so the domain can be built generically.
impl Default for ToolOrchestrationDomain {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl Domain for ToolOrchestrationDomain {
    fn id(&self) -> &DomainId {
        &self.id
    }

    fn name(&self) -> &str {
        "Tool Orchestration"
    }

    /// Generate `count` tasks at the given (clamped) difficulty.
    ///
    /// Category mix: 40% pipeline, 30% error recovery, 30% parallel
    /// coordination. NOTE(review): ResourceNegotiation and
    /// AdaptiveRouting are never produced here — confirm intent.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);

        (0..count)
            .map(|i| {
                let roll: f32 = rng.gen();
                let spec = if roll < 0.4 {
                    self.gen_pipeline(difficulty)
                } else if roll < 0.7 {
                    self.gen_error_recovery(difficulty)
                } else {
                    self.gen_parallel_coordination(difficulty)
                };

                Task {
                    // e.g. "orch_3_d50" = task 3 at difficulty 0.50.
                    id: format!("orch_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: Vec::new(),
                }
            })
            .collect()
    }

    /// Deserialize the task spec and delegate to `score_orchestration`.
    /// A malformed spec scores zero with an explanatory note.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        let spec: OrchestrationTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        self.score_orchestration(&spec, solution)
    }

    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }

    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }

    /// Baseline solution: call every available tool in catalogue order,
    /// retrying flaky tools (failure rate > 5%) twice. Returns `None` if
    /// the spec or plan cannot be (de)serialized.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: OrchestrationTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;

        // Build a sequential pipeline through available tools
        let calls: Vec<ToolCall> = spec
            .available_tools
            .iter()
            .map(|t| ToolCall {
                tool_name: t.name.clone(),
                input_ref: "previous".into(),
                parallel_group: None,
                fallback: None,
                retries: if t.failure_rate > 0.05 { 2 } else { 0 },
            })
            .collect();

        let plan = OrchestrationPlan {
            calls,
            error_strategy: "retry with exponential backoff".into(),
        };

        let content = serde_json::to_string_pretty(&plan).ok()?;
        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::to_value(&plan).ok()?,
        })
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Generated tasks carry the domain's id and respect `count`.
    #[test]
    fn test_generate_orchestration_tasks() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        for task in &tasks {
            assert_eq!(task.domain_id, domain.id);
        }
    }

    // Every generated task must admit a reference solution.
    #[test]
    fn test_reference_solution() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            let ref_sol = domain.reference_solution(task);
            assert!(ref_sol.is_some());
        }
    }

    // Evaluating a reference solution yields a score in [0, 1].
    #[test]
    fn test_evaluate_reference() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            if let Some(solution) = domain.reference_solution(task) {
                let eval = domain.evaluate(task, &solution);
                assert!(eval.score >= 0.0 && eval.score <= 1.0);
            }
        }
    }

    // Embedding a hand-built solution produces a full-width vector.
    #[test]
    fn test_embed_orchestration() {
        let domain = ToolOrchestrationDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "pipeline: extract -> embed -> search with fallback and retry".into(),
            data: serde_json::json!({
                "calls": [
                    {"tool_name": "text_extract", "input_ref": "input", "retries": 1}
                ],
                "error_strategy": "retry"
            }),
        };
        let embedding = domain.embed(&solution);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
    }

    #[test]
    fn test_difficulty_affects_error_scenarios() {
        let domain = ToolOrchestrationDomain::new();
        // Generate many tasks at high difficulty to get error recovery tasks
        // (category selection is random, so 20 samples make a miss unlikely).
        let hard = domain.generate_tasks(20, 0.9);
        let has_error_tasks = hard.iter().any(|t| {
            let spec: OrchestrationTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            !spec.error_scenarios.is_empty()
        });
        assert!(
            has_error_tasks,
            "High difficulty should produce error scenarios"
        );
    }
}
|
||||
583
vendor/ruvector/crates/ruvector-domain-expansion/src/transfer.rs
vendored
Normal file
583
vendor/ruvector/crates/ruvector-domain-expansion/src/transfer.rs
vendored
Normal file
@@ -0,0 +1,583 @@
|
||||
//! Cross-Domain Transfer Engine with Meta Thompson Sampling
|
||||
//!
|
||||
//! Transfer happens through priors, not raw memories.
|
||||
//! Ship compact priors and verified kernels between domains.
|
||||
//!
|
||||
//! ## Two-Layer Learning Architecture
|
||||
//!
|
||||
//! **Policy learning layer**: Chooses strategies, budgets, and tool paths
|
||||
//! using uncertainty-aware selection (Thompson Sampling with Beta priors).
|
||||
//!
|
||||
//! **Operator layer**: Executes deterministic kernels and graders,
|
||||
//! logs witnesses, and commits state through gates.
|
||||
//!
|
||||
//! ## Meta Thompson Sampling
|
||||
//!
|
||||
//! After each cycle, compute posterior summary per bucket and arm.
|
||||
//! Store as TransferPrior. When a new domain starts, initialize its
|
||||
//! buckets with these priors instead of uniform, enabling faster adaptation.
|
||||
//!
|
||||
//! ## Cross-Domain Transfer Protocol
|
||||
//!
|
||||
//! A delta is promotable only if it improves Domain 2 without regressing
|
||||
//! Domain 1, or improves Domain 1 without regressing Domain 2.
|
||||
//! That is generalization.
|
||||
|
||||
use crate::domain::DomainId;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Beta distribution parameters for Thompson Sampling.
/// Represents uncertainty about an arm's reward probability.
///
/// Stored as floats (not counts) because `update` adds fractional
/// rewards in [0, 1] rather than binary successes/failures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetaParams {
    /// Success count + prior (alpha).
    pub alpha: f32,
    /// Failure count + prior (beta).
    pub beta: f32,
}
|
||||
|
||||
impl BetaParams {
    /// Uniform (uninformative) prior: Beta(1, 1).
    pub fn uniform() -> Self {
        Self {
            alpha: 1.0,
            beta: 1.0,
        }
    }

    /// Create from observed successes and failures.
    /// Adds the Beta(1, 1) prior on top of the raw counts.
    pub fn from_observations(successes: f32, failures: f32) -> Self {
        Self {
            alpha: successes + 1.0,
            beta: failures + 1.0,
        }
    }

    /// Mean of the Beta distribution: E[X] = alpha / (alpha + beta).
    pub fn mean(&self) -> f32 {
        self.alpha / (self.alpha + self.beta)
    }

    /// Variance: measures uncertainty. Lower = more confident.
    pub fn variance(&self) -> f32 {
        let total = self.alpha + self.beta;
        (self.alpha * self.beta) / (total * total * (total + 1.0))
    }

    /// Sample from the Beta distribution using the Kumaraswamy approximation.
    /// Fast, no special functions needed, good enough for Thompson Sampling.
    pub fn sample(&self, rng: &mut impl Rng) -> f32 {
        // Use inverse CDF of Beta via simple approximation.
        // u avoids 0/1 so the log/power terms below stay finite.
        let u: f32 = rng.gen_range(0.001..0.999);
        // Kumaraswamy approximation: x = (1 - (1 - u^(1/b))^(1/a))
        // Better approximation using ratio of gammas via the normal approach
        let x = Self::beta_inv_approx(u, self.alpha, self.beta);
        x.clamp(0.0, 1.0)
    }

    /// Approximate inverse CDF of Beta distribution.
    ///
    /// NOTE(review): the a,b > 1 branch uses the raw Abramowitz & Stegun
    /// `t` statistic without the rational-polynomial correction term, so
    /// tail quantiles are coarse; the fallback branch is ad hoc. Both are
    /// clamped, which is likely acceptable for bandit sampling — confirm
    /// if sampling fidelity ever matters.
    fn beta_inv_approx(p: f32, a: f32, b: f32) -> f32 {
        // Use normal approximation for Beta when a,b are not too small
        if a > 1.0 && b > 1.0 {
            let mean = a / (a + b);
            let var = (a * b) / ((a + b) * (a + b) * (a + b + 1.0));
            let std = var.sqrt();
            // Inverse normal approximation (Abramowitz & Stegun)
            let t = if p < 0.5 {
                (-2.0 * (p).ln()).sqrt()
            } else {
                (-2.0 * (1.0 - p).ln()).sqrt()
            };
            let x = if p < 0.5 {
                mean - std * t
            } else {
                mean + std * t
            };
            x.clamp(0.001, 0.999)
        } else {
            // Fallback: simple power approximation
            p.powf(1.0 / a) * (1.0 - (1.0 - p).powf(1.0 / b)) + p.powf(1.0 / a) * 0.5
        }
    }

    /// Update with an observation (Bayesian posterior update).
    /// `reward` in [0, 1] is split fractionally between alpha and beta.
    pub fn update(&mut self, reward: f32) {
        self.alpha += reward;
        self.beta += 1.0 - reward;
    }

    /// Merge two Beta distributions (approximate: sum parameters).
    pub fn merge(&self, other: &BetaParams) -> BetaParams {
        BetaParams {
            alpha: self.alpha + other.alpha - 1.0, // subtract uniform prior
            beta: self.beta + other.beta - 1.0,
        }
    }
}
|
||||
|
||||
/// A context bucket groups similar problem instances for targeted learning.
///
/// Used as a `HashMap` key (hence `Hash + Eq`) to index per-context
/// posteriors inside a transfer prior.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContextBucket {
    /// Difficulty tier: "easy", "medium", "hard".
    pub difficulty_tier: String,
    /// Problem category within the domain.
    pub category: String,
}
|
||||
|
||||
/// An arm in the multi-armed bandit: a strategy choice.
///
/// Newtype over the strategy's string name; used as a `HashMap` key.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArmId(pub String);
|
||||
|
||||
/// Transfer prior: compact posterior summary from a source domain.
/// This is what gets shipped between domains — not raw trajectories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferPrior {
    /// Source domain that generated this prior.
    pub source_domain: DomainId,
    /// Per-bucket, per-arm Beta parameters (posterior summaries).
    pub bucket_priors: HashMap<ContextBucket, HashMap<ArmId, BetaParams>>,
    /// Cost EMA (exponential moving average) priors per bucket.
    pub cost_ema_priors: HashMap<ContextBucket, f32>,
    /// Number of cycles this prior was trained on (incremented once per
    /// posterior update).
    pub training_cycles: u64,
    /// Witness hash: proof of how this prior was derived.
    /// NOTE(review): in this file it is a free-form string (e.g.
    /// "transfer_from_…"), not a cryptographic hash — confirm the
    /// intended format with callers.
    pub witness_hash: String,
}
|
||||
|
||||
impl TransferPrior {
|
||||
/// Create an empty (uniform) prior for a domain.
|
||||
pub fn uniform(source_domain: DomainId) -> Self {
|
||||
Self {
|
||||
source_domain,
|
||||
bucket_priors: HashMap::new(),
|
||||
cost_ema_priors: HashMap::new(),
|
||||
training_cycles: 0,
|
||||
witness_hash: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the prior for a specific bucket and arm, defaulting to uniform.
|
||||
pub fn get_prior(&self, bucket: &ContextBucket, arm: &ArmId) -> BetaParams {
|
||||
self.bucket_priors
|
||||
.get(bucket)
|
||||
.and_then(|arms| arms.get(arm))
|
||||
.cloned()
|
||||
.unwrap_or_else(BetaParams::uniform)
|
||||
}
|
||||
|
||||
/// Update the posterior for a bucket/arm with a new observation.
|
||||
pub fn update_posterior(&mut self, bucket: ContextBucket, arm: ArmId, reward: f32) {
|
||||
let arms = self.bucket_priors.entry(bucket.clone()).or_default();
|
||||
let params = arms.entry(arm).or_insert_with(BetaParams::uniform);
|
||||
params.update(reward);
|
||||
self.training_cycles += 1;
|
||||
}
|
||||
|
||||
/// Update cost EMA for a bucket.
|
||||
pub fn update_cost_ema(&mut self, bucket: ContextBucket, cost: f32, decay: f32) {
|
||||
let entry = self.cost_ema_priors.entry(bucket).or_insert(cost);
|
||||
*entry = decay * (*entry) + (1.0 - decay) * cost;
|
||||
}
|
||||
|
||||
/// Extract a compact summary suitable for shipping to another domain.
|
||||
pub fn extract_summary(&self) -> TransferPrior {
|
||||
// Only ship buckets with sufficient evidence (>10 observations)
|
||||
let filtered: HashMap<ContextBucket, HashMap<ArmId, BetaParams>> = self
|
||||
.bucket_priors
|
||||
.iter()
|
||||
.filter_map(|(bucket, arms)| {
|
||||
let significant_arms: HashMap<ArmId, BetaParams> = arms
|
||||
.iter()
|
||||
.filter(|(_, params)| (params.alpha + params.beta) > 12.0)
|
||||
.map(|(arm, params)| (arm.clone(), params.clone()))
|
||||
.collect();
|
||||
if significant_arms.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((bucket.clone(), significant_arms))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
TransferPrior {
|
||||
source_domain: self.source_domain.clone(),
|
||||
bucket_priors: filtered,
|
||||
cost_ema_priors: self.cost_ema_priors.clone(),
|
||||
training_cycles: self.training_cycles,
|
||||
witness_hash: self.witness_hash.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Meta Thompson Sampling engine that manages priors across domains.
///
/// Holds one `TransferPrior` per domain plus the arm set shared by all
/// domains, and drives arm selection, outcome recording, and prior
/// transfer between domains.
pub struct MetaThompsonEngine {
    /// Active priors per domain.
    domain_priors: HashMap<DomainId, TransferPrior>,
    /// Available arms (strategies) shared across domains.
    arms: Vec<ArmId>,
    /// Difficulty tiers for bucketing.
    /// NOTE(review): populated in `new` but never read in this file
    /// chunk — confirm it is used elsewhere or consider removing.
    difficulty_tiers: Vec<String>,
}
|
||||
|
||||
impl MetaThompsonEngine {
|
||||
/// Create a new engine with the given strategy arms.
|
||||
pub fn new(arms: Vec<String>) -> Self {
|
||||
Self {
|
||||
domain_priors: HashMap::new(),
|
||||
arms: arms.into_iter().map(ArmId).collect(),
|
||||
difficulty_tiers: vec!["easy".into(), "medium".into(), "hard".into()],
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a domain with uniform priors.
|
||||
pub fn init_domain_uniform(&mut self, domain_id: DomainId) {
|
||||
self.domain_priors
|
||||
.insert(domain_id.clone(), TransferPrior::uniform(domain_id));
|
||||
}
|
||||
|
||||
/// Initialize a domain using transfer priors from a source domain.
|
||||
/// This is the key mechanism: Meta-TS seeds new domains with learned priors.
|
||||
pub fn init_domain_with_transfer(
|
||||
&mut self,
|
||||
target_domain: DomainId,
|
||||
source_prior: &TransferPrior,
|
||||
) {
|
||||
let mut prior = TransferPrior::uniform(target_domain.clone());
|
||||
|
||||
// Copy bucket priors from source, scaling by confidence
|
||||
for (bucket, arms) in &source_prior.bucket_priors {
|
||||
for (arm, params) in arms {
|
||||
// Dampen the prior: don't fully trust cross-domain evidence.
|
||||
// Use sqrt scaling: reduces confidence while preserving mean.
|
||||
let dampened = BetaParams {
|
||||
alpha: 1.0 + (params.alpha - 1.0).sqrt(),
|
||||
beta: 1.0 + (params.beta - 1.0).sqrt(),
|
||||
};
|
||||
prior
|
||||
.bucket_priors
|
||||
.entry(bucket.clone())
|
||||
.or_default()
|
||||
.insert(arm.clone(), dampened);
|
||||
}
|
||||
}
|
||||
|
||||
// Transfer cost EMAs with dampening
|
||||
for (bucket, &cost) in &source_prior.cost_ema_priors {
|
||||
prior.cost_ema_priors.insert(bucket.clone(), cost * 1.5); // pessimistic transfer
|
||||
}
|
||||
|
||||
prior.witness_hash = format!("transfer_from_{}", source_prior.source_domain);
|
||||
self.domain_priors.insert(target_domain, prior);
|
||||
}
|
||||
|
||||
/// Select an arm for a given domain and context using Thompson Sampling.
|
||||
pub fn select_arm(
|
||||
&self,
|
||||
domain_id: &DomainId,
|
||||
bucket: &ContextBucket,
|
||||
rng: &mut impl Rng,
|
||||
) -> Option<ArmId> {
|
||||
let prior = self.domain_priors.get(domain_id)?;
|
||||
|
||||
let mut best_arm = None;
|
||||
let mut best_sample = f32::NEG_INFINITY;
|
||||
|
||||
for arm in &self.arms {
|
||||
let params = prior.get_prior(bucket, arm);
|
||||
let sample = params.sample(rng);
|
||||
if sample > best_sample {
|
||||
best_sample = sample;
|
||||
best_arm = Some(arm.clone());
|
||||
}
|
||||
}
|
||||
|
||||
best_arm
|
||||
}
|
||||
|
||||
/// Record the outcome of using an arm in a domain.
|
||||
pub fn record_outcome(
|
||||
&mut self,
|
||||
domain_id: &DomainId,
|
||||
bucket: ContextBucket,
|
||||
arm: ArmId,
|
||||
reward: f32,
|
||||
cost: f32,
|
||||
) {
|
||||
if let Some(prior) = self.domain_priors.get_mut(domain_id) {
|
||||
prior.update_posterior(bucket.clone(), arm, reward);
|
||||
prior.update_cost_ema(bucket, cost, 0.9);
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract transfer prior from a domain (for shipping to another domain).
|
||||
pub fn extract_prior(&self, domain_id: &DomainId) -> Option<TransferPrior> {
|
||||
self.domain_priors
|
||||
.get(domain_id)
|
||||
.map(|p| p.extract_summary())
|
||||
}
|
||||
|
||||
/// Get all domain IDs currently tracked.
|
||||
pub fn domain_ids(&self) -> Vec<&DomainId> {
|
||||
self.domain_priors.keys().collect()
|
||||
}
|
||||
|
||||
/// Check if posterior variance is high (triggers speculative dual-path).
|
||||
pub fn is_uncertain(
|
||||
&self,
|
||||
domain_id: &DomainId,
|
||||
bucket: &ContextBucket,
|
||||
threshold: f32,
|
||||
) -> bool {
|
||||
let prior = match self.domain_priors.get(domain_id) {
|
||||
Some(p) => p,
|
||||
None => return true, // No data = maximum uncertainty
|
||||
};
|
||||
|
||||
// Check if top two arms are within delta of each other
|
||||
let mut samples: Vec<(f32, &ArmId)> = self
|
||||
.arms
|
||||
.iter()
|
||||
.map(|arm| {
|
||||
let params = prior.get_prior(bucket, arm);
|
||||
(params.mean(), arm)
|
||||
})
|
||||
.collect();
|
||||
samples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
if samples.len() < 2 {
|
||||
return true;
|
||||
}
|
||||
|
||||
let gap = samples[0].0 - samples[1].0;
|
||||
gap < threshold
|
||||
}
|
||||
}
|
||||
|
||||
/// Speculative dual-path execution for high-uncertainty decisions.
/// When the top two arms are within delta, run both and pick the winner.
///
/// NOTE(review): this file chunk only declares the result type; the code
/// that produces it is not visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DualPathResult {
    /// Primary arm and its observed outcome.
    pub primary: (ArmId, f32),
    /// Secondary arm and its observed outcome.
    pub secondary: (ArmId, f32),
    /// Which arm won.
    pub winner: ArmId,
    /// The loser becomes a counterexample for that context.
    pub counterexample: ArmId,
}
|
||||
|
||||
/// Cross-domain transfer verification result.
///
/// Produced by [`TransferVerification::verify`]; encodes the
/// generalization rule "promotable iff the target improved and the
/// source did not regress".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferVerification {
    /// Source domain.
    pub source: DomainId,
    /// Target domain.
    pub target: DomainId,
    /// Did transfer improve the target domain?
    pub improved_target: bool,
    /// Did transfer regress the source domain?
    pub regressed_source: bool,
    /// Is this delta promotable? (improved target AND not regressed source).
    pub promotable: bool,
    /// Acceleration factor: ratio of convergence speeds
    /// (baseline cycles / transfer cycles; higher = better transfer).
    pub acceleration_factor: f32,
    /// Source score before/after.
    pub source_scores: (f32, f32),
    /// Target score before/after.
    pub target_scores: (f32, f32),
}
|
||||
|
||||
impl TransferVerification {
|
||||
/// Verify a transfer delta against the generalization rule:
|
||||
/// promotable iff it improves Domain 2 without regressing Domain 1.
|
||||
pub fn verify(
|
||||
source: DomainId,
|
||||
target: DomainId,
|
||||
source_before: f32,
|
||||
source_after: f32,
|
||||
target_before: f32,
|
||||
target_after: f32,
|
||||
target_baseline_cycles: u64,
|
||||
target_transfer_cycles: u64,
|
||||
) -> Self {
|
||||
let improved_target = target_after > target_before;
|
||||
let regressed_source = source_after < source_before - 0.01; // small tolerance
|
||||
|
||||
let promotable = improved_target && !regressed_source;
|
||||
|
||||
// Acceleration = baseline_cycles / transfer_cycles (higher = better transfer)
|
||||
let acceleration_factor = if target_transfer_cycles > 0 {
|
||||
target_baseline_cycles as f32 / target_transfer_cycles as f32
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
Self {
|
||||
source,
|
||||
target,
|
||||
improved_target,
|
||||
regressed_source,
|
||||
promotable,
|
||||
acceleration_factor,
|
||||
source_scores: (source_before, source_after),
|
||||
target_scores: (target_before, target_after),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The uniform prior is Beta(1,1) with mean exactly 0.5.
    #[test]
    fn test_beta_params_uniform() {
        let p = BetaParams::uniform();
        assert_eq!(p.alpha, 1.0);
        assert_eq!(p.beta, 1.0);
        assert!((p.mean() - 0.5).abs() < 1e-6);
    }

    /// A single full success bumps alpha only and shifts the mean above 0.5.
    #[test]
    fn test_beta_params_update() {
        let mut p = BetaParams::uniform();
        p.update(1.0); // success
        assert_eq!(p.alpha, 2.0);
        assert_eq!(p.beta, 1.0);
        assert!(p.mean() > 0.5);
    }

    /// Samples must always stay inside the Beta support [0, 1],
    /// whatever the RNG draws (100 trials).
    #[test]
    fn test_beta_params_sample_in_range() {
        let p = BetaParams::from_observations(10.0, 5.0);
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let s = p.sample(&mut rng);
            assert!(s >= 0.0 && s <= 1.0, "Sample {} out of [0,1]", s);
        }
    }

    /// Posterior updates survive extract_summary: 20 observations clear
    /// the (alpha + beta) > 12 evidence filter, and the retrieved prior
    /// keeps the learned optimistic mean.
    #[test]
    fn test_transfer_prior_round_trip() {
        let domain = DomainId("test".into());
        let mut prior = TransferPrior::uniform(domain);

        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "transform".into(),
        };
        let arm = ArmId("strategy_a".into());

        for _ in 0..20 {
            prior.update_posterior(bucket.clone(), arm.clone(), 0.8);
        }

        let summary = prior.extract_summary();
        assert!(!summary.bucket_priors.is_empty());

        let retrieved = summary.get_prior(&bucket, &arm);
        assert!(retrieved.mean() > 0.5);
    }

    /// End-to-end: train domain1 with strategy_a dominant, then transfer
    /// the extracted prior to domain2 and check the dampened prior still
    /// favors strategy_a.
    /// NOTE(review): relies on Thompson Sampling exploring strategy_a
    /// often enough in 50 draws — probabilistic, but a practical certainty.
    #[test]
    fn test_meta_thompson_engine() {
        let mut engine = MetaThompsonEngine::new(vec![
            "strategy_a".into(),
            "strategy_b".into(),
            "strategy_c".into(),
        ]);

        let domain1 = DomainId("rust_synthesis".into());
        engine.init_domain_uniform(domain1.clone());

        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };

        let mut rng = rand::thread_rng();

        // Record some outcomes: strategy_a is clearly the best arm.
        for _ in 0..50 {
            let arm = engine.select_arm(&domain1, &bucket, &mut rng).unwrap();
            let reward = if arm.0 == "strategy_a" { 0.9 } else { 0.3 };
            engine.record_outcome(&domain1, bucket.clone(), arm, reward, 1.0);
        }

        // Extract prior and transfer to domain2
        let prior = engine.extract_prior(&domain1).unwrap();
        let domain2 = DomainId("planning".into());
        engine.init_domain_with_transfer(domain2.clone(), &prior);

        // Domain2 should now have informative priors
        let d2_prior = engine.domain_priors.get(&domain2).unwrap();
        let a_params = d2_prior.get_prior(&bucket, &ArmId("strategy_a".into()));
        assert!(
            a_params.mean() > 0.5,
            "Transferred prior should favor strategy_a"
        );
    }

    /// Promotable case: big target improvement, source dip inside the
    /// 0.01 tolerance, acceleration 100/40 = 2.5.
    #[test]
    fn test_transfer_verification() {
        let v = TransferVerification::verify(
            DomainId("d1".into()),
            DomainId("d2".into()),
            0.8,  // source before
            0.79, // source after (slight decrease, within tolerance)
            0.3,  // target before
            0.7,  // target after (big improvement)
            100,  // baseline cycles
            40,   // transfer cycles
        );

        assert!(v.improved_target);
        assert!(!v.regressed_source); // within tolerance
        assert!(v.promotable);
        assert!((v.acceleration_factor - 2.5).abs() < 1e-4);
    }

    /// Not promotable: the target improved but the source regressed well
    /// beyond tolerance, so the generalization rule rejects the delta.
    #[test]
    fn test_transfer_not_promotable_on_regression() {
        let v = TransferVerification::verify(
            DomainId("d1".into()),
            DomainId("d2".into()),
            0.8, // source before
            0.5, // source after (regression!)
            0.3, // target before
            0.7, // target after
            100,
            40,
        );

        assert!(v.improved_target);
        assert!(v.regressed_source);
        assert!(!v.promotable);
    }

    /// Uncertainty flips from true (uniform priors, zero mean gap) to
    /// false once the evidence clearly separates the two arms.
    #[test]
    fn test_uncertainty_detection() {
        let mut engine = MetaThompsonEngine::new(vec!["a".into(), "b".into()]);

        let domain = DomainId("test".into());
        engine.init_domain_uniform(domain.clone());

        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "test".into(),
        };

        // With uniform priors, should be uncertain
        assert!(engine.is_uncertain(&domain, &bucket, 0.1));

        // After many observations favoring one arm, should be certain
        for _ in 0..100 {
            engine.record_outcome(&domain, bucket.clone(), ArmId("a".into()), 0.95, 1.0);
            engine.record_outcome(&domain, bucket.clone(), ArmId("b".into()), 0.1, 1.0);
        }

        assert!(!engine.is_uncertain(&domain, &bucket, 0.1));
    }
}
|
||||
Reference in New Issue
Block a user