Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,482 @@
//! Cost Curve Compression Tracker and Acceleration Scoreboard
//!
//! Measures whether cost curves compress faster in each new domain.
//! If they do, you are increasing general problem-solving capability.
//!
//! ## Acceptance Test
//!
//! Domain 2 must converge faster than Domain 1.
//! Measure cycles to reach:
//! - 95% accuracy
//! - Target cost per solve
//! - Target robustness
//! - Zero policy violations
use crate::domain::DomainId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single data point on the cost curve.
///
/// Points are appended in recording order; `CostCurve::auc_accuracy`
/// assumes `cycle` is non-decreasing across consecutive points —
/// NOTE(review): this is not enforced anywhere, confirm at call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostCurvePoint {
    /// Cycle number (training iteration).
    pub cycle: u64,
    /// Current accuracy [0.0, 1.0].
    pub accuracy: f32,
    /// Cost per solve at this point.
    pub cost_per_solve: f32,
    /// Robustness score [0.0, 1.0].
    pub robustness: f32,
    /// Number of policy violations in this cycle.
    pub policy_violations: u32,
    /// Wall-clock timestamp (seconds since epoch).
    pub timestamp: f64,
}
/// Convergence thresholds for the acceptance test.
///
/// A curve point "converges" only when ALL four criteria hold at once
/// (see `CostCurve::has_converged`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceThresholds {
    /// Target accuracy (default: 0.95).
    pub target_accuracy: f32,
    /// Target cost per solve (default: 0.01).
    pub target_cost: f32,
    /// Target robustness (default: 0.90).
    pub target_robustness: f32,
    /// Maximum allowed policy violations (default: 0).
    pub max_violations: u32,
}
impl Default for ConvergenceThresholds {
fn default() -> Self {
Self {
target_accuracy: 0.95,
target_cost: 0.01,
target_robustness: 0.90,
max_violations: 0,
}
}
}
/// Cost curve for a single domain, tracking convergence over cycles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostCurve {
    /// Domain this curve belongs to.
    pub domain_id: DomainId,
    /// Whether this was trained with transfer priors.
    pub used_transfer: bool,
    /// Source domain for transfer (if any). `None` when `used_transfer`
    /// is false (the two fields are kept in sync by the constructors).
    pub transfer_source: Option<DomainId>,
    /// Ordered data points, in recording order.
    pub points: Vec<CostCurvePoint>,
    /// Convergence thresholds used by the `cycles_to_*` queries.
    pub thresholds: ConvergenceThresholds,
}
impl CostCurve {
    /// Create a new cost curve for a domain (no transfer priors).
    pub fn new(domain_id: DomainId, thresholds: ConvergenceThresholds) -> Self {
        Self {
            domain_id,
            used_transfer: false,
            transfer_source: None,
            points: Vec::new(),
            thresholds,
        }
    }

    /// Create a cost curve seeded with transfer metadata from `source`.
    pub fn with_transfer(
        domain_id: DomainId,
        source: DomainId,
        thresholds: ConvergenceThresholds,
    ) -> Self {
        Self {
            domain_id,
            used_transfer: true,
            transfer_source: Some(source),
            points: Vec::new(),
            thresholds,
        }
    }

    /// Record a new data point.
    ///
    /// Points are assumed to arrive in non-decreasing `cycle` order;
    /// `auc_accuracy` relies on this (not enforced here).
    pub fn record(&mut self, point: CostCurvePoint) {
        self.points.push(point);
    }

    /// True when `p` satisfies every convergence criterion at once.
    ///
    /// Single source of truth shared by `has_converged` and
    /// `cycles_to_convergence`, so the two predicates cannot drift apart.
    fn meets_thresholds(&self, p: &CostCurvePoint) -> bool {
        p.accuracy >= self.thresholds.target_accuracy
            && p.cost_per_solve <= self.thresholds.target_cost
            && p.robustness >= self.thresholds.target_robustness
            && p.policy_violations <= self.thresholds.max_violations
    }

    /// Check if all convergence criteria are met at the latest point.
    /// Returns `false` for an empty curve.
    pub fn has_converged(&self) -> bool {
        self.points.last().map_or(false, |p| self.meets_thresholds(p))
    }

    /// Cycles to reach target accuracy (None if not yet reached).
    pub fn cycles_to_accuracy(&self) -> Option<u64> {
        self.points
            .iter()
            .find(|p| p.accuracy >= self.thresholds.target_accuracy)
            .map(|p| p.cycle)
    }

    /// Cycles to reach target cost (None if not yet reached).
    pub fn cycles_to_cost(&self) -> Option<u64> {
        self.points
            .iter()
            .find(|p| p.cost_per_solve <= self.thresholds.target_cost)
            .map(|p| p.cycle)
    }

    /// Cycles to reach target robustness (None if not yet reached).
    pub fn cycles_to_robustness(&self) -> Option<u64> {
        self.points
            .iter()
            .find(|p| p.robustness >= self.thresholds.target_robustness)
            .map(|p| p.cycle)
    }

    /// Cycles to full convergence: the first point where ALL criteria
    /// hold simultaneously (None if never).
    pub fn cycles_to_convergence(&self) -> Option<u64> {
        self.points
            .iter()
            .find(|p| self.meets_thresholds(p))
            .map(|p| p.cycle)
    }

    /// Area under the accuracy curve via the trapezoid rule
    /// (higher = faster learning). Returns 0.0 with fewer than 2 points.
    ///
    /// NOTE(review): `w[1].cycle - w[0].cycle` underflows (panics in
    /// debug builds) if points are recorded out of cycle order — see
    /// the ordering assumption on `record`.
    pub fn auc_accuracy(&self) -> f32 {
        if self.points.len() < 2 {
            return 0.0;
        }
        self.points
            .windows(2)
            .map(|w| {
                let dx = (w[1].cycle - w[0].cycle) as f32;
                let avg_y = (w[0].accuracy + w[1].accuracy) / 2.0;
                dx * avg_y
            })
            .sum()
    }

    /// Compression ratio: how fast the cost curve drops, computed as
    /// initial_cost / final_cost (higher = more compression).
    ///
    /// Returns 1.0 with fewer than 2 points; the denominator is floored
    /// at 1e-10 to avoid division by (near-)zero.
    pub fn compression_ratio(&self) -> f32 {
        // Both endpoints must exist for a ratio to be measurable.
        let (initial, final_cost) = match (self.points.first(), self.points.last()) {
            (Some(a), Some(b)) if self.points.len() >= 2 => (a.cost_per_solve, b.cost_per_solve),
            _ => return 1.0,
        };
        // Equivalent to the old `if final_cost > 1e-10` branch.
        initial / final_cost.max(1e-10)
    }
}
/// Acceleration scoreboard comparing domain learning curves.
/// Shows acceleration, not just improvement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelerationScoreboard {
    /// Per-domain cost curves, keyed by the curve's own `domain_id`.
    pub curves: HashMap<DomainId, CostCurve>,
    /// Pairwise acceleration factors, appended by `compute_acceleration`
    /// in call order.
    pub accelerations: Vec<AccelerationEntry>,
}
/// An entry showing how transfer from source to target affected convergence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelerationEntry {
    /// Source domain (the transfer curve's `transfer_source`, or
    /// `"none"` when it had no recorded source).
    pub source: DomainId,
    /// Target domain.
    pub target: DomainId,
    /// Cycles to convergence without transfer (baseline).
    pub baseline_cycles: Option<u64>,
    /// Cycles to convergence with transfer.
    pub transfer_cycles: Option<u64>,
    /// Acceleration factor: baseline / transfer (>1 = transfer helped).
    /// Falls back to 1.0 when either side has not converged.
    pub acceleration: f32,
    /// AUC of the baseline accuracy curve (higher = better learning curve).
    pub auc_baseline: f32,
    /// AUC of the transfer accuracy curve.
    pub auc_transfer: f32,
    /// Compression ratio of the baseline cost curve.
    pub compression_baseline: f32,
    /// Compression ratio of the transfer cost curve.
    pub compression_transfer: f32,
    /// Whether generalization test passed (acceleration > 1.0).
    pub generalization_passed: bool,
}
impl AccelerationScoreboard {
    /// Create an empty scoreboard with no curves or acceleration entries.
    pub fn new() -> Self {
        Self {
            curves: HashMap::new(),
            accelerations: Vec::new(),
        }
    }

    /// Add a cost curve for a domain, keyed by its `domain_id`.
    pub fn add_curve(&mut self, curve: CostCurve) {
        let key = curve.domain_id.clone();
        self.curves.insert(key, curve);
    }

    /// Compute acceleration between a baseline (no transfer) and transfer
    /// curve. Returns `None` if either domain has no recorded curve; the
    /// resulting entry is also appended to `self.accelerations`.
    pub fn compute_acceleration(
        &mut self,
        baseline_domain: &DomainId,
        transfer_domain: &DomainId,
    ) -> Option<AccelerationEntry> {
        let base = self.curves.get(baseline_domain)?;
        let xfer = self.curves.get(transfer_domain)?;
        let base_cycles = base.cycles_to_convergence();
        let xfer_cycles = xfer.cycles_to_convergence();
        // >1.0 means transfer converged in fewer cycles than baseline;
        // 1.0 when either side never converged (no measurable acceleration).
        let factor = if let (Some(b), Some(t)) = (base_cycles, xfer_cycles) {
            if t > 0 { b as f32 / t as f32 } else { 1.0 }
        } else {
            1.0
        };
        let source = match xfer.transfer_source.clone() {
            Some(s) => s,
            None => DomainId("none".into()),
        };
        let entry = AccelerationEntry {
            source,
            target: transfer_domain.clone(),
            baseline_cycles: base_cycles,
            transfer_cycles: xfer_cycles,
            acceleration: factor,
            auc_baseline: base.auc_accuracy(),
            auc_transfer: xfer.auc_accuracy(),
            compression_baseline: base.compression_ratio(),
            compression_transfer: xfer.compression_ratio(),
            generalization_passed: factor > 1.0,
        };
        self.accelerations.push(entry.clone());
        Some(entry)
    }

    /// Check whether each successive domain converges faster (the IQ
    /// growth test). Vacuously true with fewer than two entries.
    pub fn progressive_acceleration(&self) -> bool {
        if self.accelerations.len() < 2 {
            return true; // Not enough data to judge
        }
        self.accelerations
            .iter()
            .zip(self.accelerations.iter().skip(1))
            .all(|(prev, next)| next.acceleration >= prev.acceleration)
    }

    /// Summary report of all domains plus the overall (mean) acceleration.
    pub fn summary(&self) -> ScoreboardSummary {
        let mut domains = Vec::with_capacity(self.curves.len());
        for (id, curve) in &self.curves {
            let last = curve.points.last();
            domains.push(DomainSummary {
                domain_id: id.clone(),
                total_cycles: last.map(|p| p.cycle).unwrap_or(0),
                final_accuracy: last.map(|p| p.accuracy).unwrap_or(0.0),
                final_cost: last.map(|p| p.cost_per_solve).unwrap_or(f32::MAX),
                converged: curve.has_converged(),
                cycles_to_convergence: curve.cycles_to_convergence(),
                compression_ratio: curve.compression_ratio(),
                used_transfer: curve.used_transfer,
            });
        }
        // Mean acceleration across entries; 1.0 when none recorded.
        let overall_acceleration = if self.accelerations.is_empty() {
            1.0
        } else {
            let total: f32 = self.accelerations.iter().map(|a| a.acceleration).sum();
            total / self.accelerations.len() as f32
        };
        ScoreboardSummary {
            domains,
            accelerations: self.accelerations.clone(),
            overall_acceleration,
            progressive_improvement: self.progressive_acceleration(),
        }
    }
}
impl Default for AccelerationScoreboard {
fn default() -> Self {
Self::new()
}
}
/// Summary of a single domain's learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainSummary {
    /// Domain this summary describes.
    pub domain_id: DomainId,
    /// Cycle number of the last recorded point (0 if no points).
    pub total_cycles: u64,
    /// Accuracy at the last recorded point (0.0 if no points).
    pub final_accuracy: f32,
    /// Cost per solve at the last recorded point (f32::MAX if no points).
    pub final_cost: f32,
    /// Whether the latest point meets all convergence thresholds.
    pub converged: bool,
    /// First cycle at which all thresholds were met, if any.
    pub cycles_to_convergence: Option<u64>,
    /// initial_cost / final_cost (1.0 with fewer than 2 points).
    pub compression_ratio: f32,
    /// Whether the curve was trained with transfer priors.
    pub used_transfer: bool,
}
/// Full scoreboard summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreboardSummary {
    /// One summary per tracked domain (iteration order of the curve map).
    pub domains: Vec<DomainSummary>,
    /// All acceleration entries recorded so far, in computation order.
    pub accelerations: Vec<AccelerationEntry>,
    /// Mean acceleration across all entries (1.0 when none recorded).
    pub overall_acceleration: f32,
    /// True if each new domain converges faster than the previous.
    pub progressive_improvement: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a curve from (cycle, accuracy, cost) triples.
    /// Robustness is derived as accuracy * 0.95, violations are always 0,
    /// and the cycle number doubles as the timestamp.
    fn make_curve(domain: &str, transfer: bool, accuracy_steps: &[(u64, f32, f32)]) -> CostCurve {
        let mut curve = if transfer {
            CostCurve::with_transfer(
                DomainId(domain.into()),
                DomainId("source".into()),
                ConvergenceThresholds::default(),
            )
        } else {
            CostCurve::new(DomainId(domain.into()), ConvergenceThresholds::default())
        };
        for &(cycle, accuracy, cost) in accuracy_steps {
            curve.record(CostCurvePoint {
                cycle,
                accuracy,
                cost_per_solve: cost,
                robustness: accuracy * 0.95,
                policy_violations: 0,
                timestamp: cycle as f64,
            });
        }
        curve
    }

    #[test]
    fn test_cost_curve_convergence() {
        // Final point: accuracy 0.95 >= 0.95, cost 0.008 <= 0.01,
        // robustness 0.95*0.95 = 0.9025 >= 0.90 — all thresholds met at cycle 30.
        let curve = make_curve(
            "test",
            false,
            &[
                (0, 0.3, 0.1),
                (10, 0.6, 0.05),
                (20, 0.8, 0.02),
                (30, 0.95, 0.008),
            ],
        );
        assert!(curve.has_converged());
        assert_eq!(curve.cycles_to_accuracy(), Some(30));
        assert_eq!(curve.cycles_to_cost(), Some(30));
    }

    #[test]
    fn test_cost_curve_not_converged() {
        // Best accuracy 0.6 < 0.95 target: no convergence yet.
        let curve = make_curve("test", false, &[(0, 0.3, 0.1), (10, 0.6, 0.05)]);
        assert!(!curve.has_converged());
        assert_eq!(curve.cycles_to_accuracy(), None);
    }

    #[test]
    fn test_compression_ratio() {
        let curve = make_curve(
            "test",
            false,
            &[(0, 0.3, 1.0), (10, 0.6, 0.5), (20, 0.9, 0.1)],
        );
        let ratio = curve.compression_ratio();
        assert!((ratio - 10.0).abs() < 1e-4); // 1.0 / 0.1 = 10x
    }

    #[test]
    fn test_acceleration_scoreboard() {
        let mut board = AccelerationScoreboard::new();
        // Domain 1: baseline (slow convergence, converges at cycle 100)
        let baseline = make_curve(
            "d1_baseline",
            false,
            &[
                (0, 0.2, 0.1),
                (20, 0.5, 0.05),
                (50, 0.8, 0.02),
                (100, 0.95, 0.008),
            ],
        );
        // Domain 2: with transfer (fast convergence, converges at cycle 40)
        let transfer = make_curve(
            "d2_transfer",
            true,
            &[
                (0, 0.4, 0.08),
                (10, 0.7, 0.03),
                (20, 0.9, 0.01),
                (40, 0.96, 0.007),
            ],
        );
        board.add_curve(baseline);
        board.add_curve(transfer);
        let entry = board
            .compute_acceleration(
                &DomainId("d1_baseline".into()),
                &DomainId("d2_transfer".into()),
            )
            .unwrap();
        assert!(entry.acceleration > 1.0, "Transfer should accelerate");
        assert_eq!(entry.baseline_cycles, Some(100));
        assert_eq!(entry.transfer_cycles, Some(40));
        // 100 baseline cycles / 40 transfer cycles = 2.5x
        assert!((entry.acceleration - 2.5).abs() < 1e-4);
        assert!(entry.generalization_passed);
    }

    #[test]
    fn test_scoreboard_summary() {
        let mut board = AccelerationScoreboard::new();
        let curve = make_curve("d1", false, &[(0, 0.5, 0.1), (50, 0.96, 0.005)]);
        board.add_curve(curve);
        let summary = board.summary();
        assert_eq!(summary.domains.len(), 1);
        assert!(summary.domains[0].converged);
    }

    #[test]
    fn test_auc_accuracy() {
        let curve = make_curve(
            "test",
            false,
            &[(0, 0.0, 1.0), (10, 0.5, 0.5), (20, 1.0, 0.1)],
        );
        let auc = curve.auc_accuracy();
        // Trapezoid: (10*(0+0.5)/2) + (10*(0.5+1.0)/2) = 2.5 + 7.5 = 10.0
        assert!((auc - 10.0).abs() < 1e-4);
    }
}

View File

@@ -0,0 +1,212 @@
//! Core domain trait and types for cross-domain transfer learning.
//!
//! A domain defines a problem space with:
//! - A task generator (produces training instances)
//! - An evaluator (scores solutions on [0.0, 1.0])
//! - Embedding extraction (maps solutions into a shared representation space)
//!
//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2
//! faster than Domain 2 alone. That is generalization.
use serde::{Deserialize, Serialize};
use std::fmt;
/// Unique identifier for a domain.
///
/// Newtype over `String`; derives `Hash`/`Eq` so it can key the
/// `HashMap`s of curves and registered domains.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct DomainId(pub String);
impl fmt::Display for DomainId {
    /// Displays the raw identifier string, unquoted and unmodified.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
/// A single task instance within a domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Unique task identifier.
    pub id: String,
    /// Domain this task belongs to.
    pub domain_id: DomainId,
    /// Difficulty level [0.0, 1.0].
    pub difficulty: f32,
    /// Structured task specification (domain-specific JSON; schema is
    /// defined by the owning `Domain` implementation).
    pub spec: serde_json::Value,
    /// Optional constraints the solution must satisfy; evaluated
    /// per-constraint into `Evaluation::constraint_results`.
    pub constraints: Vec<String>,
}
/// A candidate solution to a domain task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Solution {
    /// The id of the task this solves (matches `Task::id`).
    pub task_id: String,
    /// Raw solution content (e.g., Rust source, plan steps, tool calls).
    pub content: String,
    /// Structured solution data (domain-specific; may be `Value::Null`).
    pub data: serde_json::Value,
}
/// Evaluation result for a solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evaluation {
    /// Overall score [0.0, 1.0] where 1.0 is perfect.
    pub score: f32,
    /// Correctness: does it produce the right answer?
    pub correctness: f32,
    /// Efficiency: resource usage relative to optimal.
    pub efficiency: f32,
    /// Elegance: structural quality, idiomatic patterns.
    pub elegance: f32,
    /// Per-constraint pass/fail results, parallel to `Task::constraints`.
    pub constraint_results: Vec<bool>,
    /// Diagnostic notes from the evaluator.
    pub notes: Vec<String>,
}
impl Evaluation {
    /// Weight of the correctness sub-score in the composite.
    const W_CORRECTNESS: f32 = 0.6;
    /// Weight of the efficiency sub-score in the composite.
    const W_EFFICIENCY: f32 = 0.25;
    /// Weight of the elegance sub-score in the composite.
    const W_ELEGANCE: f32 = 0.15;

    /// Create a zero-score evaluation (failure), carrying diagnostic `notes`.
    pub fn zero(notes: Vec<String>) -> Self {
        Self {
            score: 0.0,
            correctness: 0.0,
            efficiency: 0.0,
            elegance: 0.0,
            constraint_results: Vec::new(),
            notes,
        }
    }

    /// Compute composite score from weighted sub-scores:
    /// 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance,
    /// clamped to [0.0, 1.0].
    ///
    /// Sub-scores are stored as given — they are NOT clamped; callers
    /// are expected to pass values in [0.0, 1.0].
    pub fn composite(correctness: f32, efficiency: f32, elegance: f32) -> Self {
        // Same literals and evaluation order as before: results are bit-identical.
        let score = Self::W_CORRECTNESS * correctness
            + Self::W_EFFICIENCY * efficiency
            + Self::W_ELEGANCE * elegance;
        Self {
            score: score.clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes: Vec::new(),
        }
    }
}
/// Embedding vector for cross-domain representation.
/// Solutions from different domains are projected into a shared space
/// so that transfer learning can identify structural similarities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainEmbedding {
    /// The embedding vector.
    pub vector: Vec<f32>,
    /// Which domain produced this embedding.
    pub domain_id: DomainId,
    /// Dimensionality; kept equal to `vector.len()` by `new`.
    pub dim: usize,
}
impl DomainEmbedding {
    /// Create a new embedding; `dim` is derived from the vector length.
    pub fn new(vector: Vec<f32>, domain_id: DomainId) -> Self {
        let dim = vector.len();
        Self {
            vector,
            domain_id,
            dim,
        }
    }

    /// Cosine similarity with another embedding, in [-1.0, 1.0].
    ///
    /// # Panics
    /// Panics if the two embeddings have different dimensionality.
    pub fn cosine_similarity(&self, other: &DomainEmbedding) -> f32 {
        assert_eq!(self.dim, other.dim, "Embedding dimensions must match");
        let mut dot = 0.0f32;
        let mut norm_a = 0.0f32;
        let mut norm_b = 0.0f32;
        // Single fused pass; zip iteration avoids the per-element bounds
        // checks of the previous indexed loop while keeping the same
        // sequential accumulation order (identical float results).
        for (a, b) in self.vector.iter().zip(other.vector.iter()) {
            dot += a * b;
            norm_a += a * a;
            norm_b += b * b;
        }
        // Floor the denominator to guard against zero vectors.
        let denom = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
        dot / denom
    }
}
/// Core trait that every domain must implement.
///
/// Domains are problem spaces: Rust program synthesis, structured planning,
/// tool orchestration, etc. Each domain knows how to generate tasks,
/// evaluate solutions, and embed solutions into a shared representation space.
///
/// `Send + Sync` so implementations can be shared behind `Box<dyn Domain>`
/// across threads.
pub trait Domain: Send + Sync {
    /// Unique identifier for this domain.
    fn id(&self) -> &DomainId;
    /// Human-readable name.
    fn name(&self) -> &str;
    /// Generate a batch of tasks at the given difficulty level.
    ///
    /// # Arguments
    /// * `count` - Number of tasks to generate
    /// * `difficulty` - Target difficulty [0.0, 1.0]
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task>;
    /// Evaluate a solution against its task.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation;
    /// Project a solution into the shared embedding space.
    /// This enables cross-domain transfer by finding structural similarities
    /// between solutions across different problem domains.
    ///
    /// NOTE(review): cross-domain comparison (cosine similarity) requires
    /// all implementations to agree on `embedding_dim` — not enforced here.
    fn embed(&self, solution: &Solution) -> DomainEmbedding;
    /// Embedding dimensionality for this domain.
    fn embedding_dim(&self) -> usize;
    /// Generate a reference (optimal or near-optimal) solution for a task,
    /// or `None` when no reference is available.
    /// Used for computing efficiency ratios and as training signal.
    fn reference_solution(&self, task: &Task) -> Option<Solution>;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_domain_id_display() {
        // Display must pass through the inner string unchanged.
        let id = DomainId("rust_synthesis".to_string());
        assert_eq!(format!("{}", id), "rust_synthesis");
    }

    #[test]
    fn test_evaluation_zero() {
        // A zero evaluation keeps the diagnostic notes it was given.
        let eval = Evaluation::zero(vec!["compile error".to_string()]);
        assert_eq!(eval.score, 0.0);
        assert_eq!(eval.notes.len(), 1);
    }

    #[test]
    fn test_evaluation_composite() {
        let eval = Evaluation::composite(1.0, 0.8, 0.6);
        // 0.6*1.0 + 0.25*0.8 + 0.15*0.6 = 0.6 + 0.2 + 0.09 = 0.89
        assert!((eval.score - 0.89).abs() < 1e-4);
    }

    #[test]
    fn test_embedding_cosine_similarity() {
        let id = DomainId("test".to_string());
        // Identical vectors -> similarity 1.0.
        let a = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone());
        let b = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone());
        assert!((a.cosine_similarity(&b) - 1.0).abs() < 1e-6);
        // Orthogonal vectors -> similarity 0.0.
        let c = DomainEmbedding::new(vec![0.0, 1.0, 0.0], id);
        assert!(a.cosine_similarity(&c).abs() < 1e-6);
    }

    #[test]
    fn test_evaluation_clamp() {
        // Weights sum to 1.0, so perfect sub-scores must not exceed 1.0.
        let eval = Evaluation::composite(1.0, 1.0, 1.0);
        assert!(eval.score <= 1.0);
    }
}

View File

@@ -0,0 +1,39 @@
//! Error types for domain expansion.
use thiserror::Error;
/// Errors that can occur during domain expansion operations.
#[derive(Error, Debug)]
pub enum DomainError {
    /// Problem generation failed.
    #[error("problem generation failed: {0}")]
    Generation(String),
    /// Solution evaluation failed.
    #[error("evaluation failed: {0}")]
    Evaluation(String),
    /// Dimension mismatch between domains.
    #[error("dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch {
        /// Expected embedding dimensionality.
        expected: usize,
        /// Dimensionality actually received.
        got: usize,
    },
    /// Domain not found in the expansion engine.
    #[error("domain not found: {0}")]
    DomainNotFound(String),
    /// Transfer failed between domains.
    #[error("transfer failed from {source} to {target}: {reason}")]
    TransferFailed {
        /// Domain the priors were extracted from.
        source: String,
        /// Domain the priors were being applied to.
        target: String,
        /// Human-readable cause of the failure.
        reason: String,
    },
    /// Kernel has not been trained on any domain yet.
    #[error("kernel not initialized: {0}")]
    KernelNotInitialized(String),
    /// Invalid configuration.
    #[error("invalid config: {0}")]
    InvalidConfig(String),
}

View File

@@ -0,0 +1,591 @@
//! # Domain Expansion Engine
//!
//! Cross-domain transfer learning for general problem-solving capability.
//!
//! ## Core Insight
//!
//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2
//! faster than Domain 2 alone. That is generalization.
//!
//! ## Two-Layer Architecture
//!
//! **Policy learning layer**: Meta Thompson Sampling with Beta priors across
//! context buckets. Chooses strategies via uncertainty-aware selection.
//! Transfer happens through compact priors — not raw trajectories.
//!
//! **Operator layer**: Deterministic domain kernels (Rust synthesis, planning,
//! tool orchestration) that generate tasks, evaluate solutions, and produce
//! embeddings into a shared representation space.
//!
//! ## Domains
//!
//! - **Rust Program Synthesis**: Generate Rust functions from specifications
//! - **Structured Planning**: Multi-step plans with dependencies and resources
//! - **Tool Orchestration**: Coordinate multiple tools/agents for complex goals
//!
//! ## Transfer Protocol
//!
//! 1. Train on Domain 1, extract `TransferPrior` (posterior summaries)
//! 2. Initialize Domain 2 with dampened priors from Domain 1
//! 3. Measure acceleration: cycles to convergence with/without transfer
//! 4. A delta is promotable only if it improves target without regressing source
//!
//! ## Population-Based Policy Search
//!
//! Run a population of `PolicyKernel` variants in parallel.
//! Each variant tunes knobs (skip mode, prepass, speculation thresholds).
//! Keep top performers on holdouts, mutate, repeat.
//!
//! ## Acceptance Test
//!
//! Domain 2 must converge faster than Domain 1 to target accuracy, cost,
//! robustness, and zero policy violations.
#![warn(missing_docs)]
pub mod cost_curve;
pub mod domain;
pub mod meta_learning;
pub mod planning;
pub mod policy_kernel;
pub mod rust_synthesis;
pub mod tool_orchestration;
pub mod transfer;
/// RVF format integration: segment serialization, witness chains, AGI packaging.
///
/// Requires the `rvf` feature to be enabled.
#[cfg(feature = "rvf")]
pub mod rvf_bridge;
// Re-export core types.
pub use cost_curve::{
AccelerationEntry, AccelerationScoreboard, ConvergenceThresholds, CostCurve, CostCurvePoint,
ScoreboardSummary,
};
pub use domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
pub use meta_learning::{
CuriosityBonus, DecayingBeta, MetaLearningEngine, MetaLearningHealth, ParetoFront, ParetoPoint,
PlateauAction, PlateauDetector, RegretSummary, RegretTracker,
};
pub use planning::PlanningDomain;
pub use policy_kernel::{PolicyKernel, PolicyKnobs, PopulationSearch, PopulationStats};
pub use rust_synthesis::RustSynthesisDomain;
pub use tool_orchestration::ToolOrchestrationDomain;
pub use transfer::{
ArmId, BetaParams, ContextBucket, DualPathResult, MetaThompsonEngine, TransferPrior,
TransferVerification,
};
use std::collections::HashMap;
/// The domain expansion orchestrator.
///
/// Manages multiple domains, transfer learning between them,
/// population-based policy search, and the acceleration scoreboard.
///
/// The `meta` field provides five composable learning improvements:
/// regret tracking, decaying priors, plateau detection, Pareto front
/// optimization, and curiosity-driven exploration.
pub struct DomainExpansionEngine {
    /// Registered domains, keyed by each domain's own id.
    domains: HashMap<DomainId, Box<dyn Domain>>,
    /// Meta Thompson Sampling engine for cross-domain transfer.
    pub thompson: MetaThompsonEngine,
    /// Population-based policy search.
    pub population: PopulationSearch,
    /// Acceleration scoreboard tracking convergence across domains.
    pub scoreboard: AccelerationScoreboard,
    /// Meta-learning engine: regret, plateau, Pareto, curiosity, decay.
    pub meta: MetaLearningEngine,
    /// Holdout tasks per domain for verification; populated by
    /// `generate_holdouts`, empty until then.
    holdouts: HashMap<DomainId, Vec<Task>>,
    /// Counterexample set: failed solutions (score < 0.3) that inform
    /// future decisions.
    counterexamples: HashMap<DomainId, Vec<(Task, Solution, Evaluation)>>,
}
impl DomainExpansionEngine {
    /// Create a new domain expansion engine with default configuration.
    ///
    /// Initializes the Thompson engine with four policy arms, a population
    /// of 8 policy kernels, and registers the three core domains.
    pub fn new() -> Self {
        let arms = vec![
            "greedy".into(),
            "exploratory".into(),
            "conservative".into(),
            "speculative".into(),
        ];
        let mut engine = Self {
            domains: HashMap::new(),
            thompson: MetaThompsonEngine::new(arms),
            population: PopulationSearch::new(8),
            scoreboard: AccelerationScoreboard::new(),
            meta: MetaLearningEngine::new(),
            holdouts: HashMap::new(),
            counterexamples: HashMap::new(),
        };
        // Register the three core domains.
        engine.register_domain(Box::new(RustSynthesisDomain::new()));
        engine.register_domain(Box::new(PlanningDomain::new()));
        engine.register_domain(Box::new(ToolOrchestrationDomain::new()));
        engine
    }

    /// Register a new domain and seed uniform Thompson priors for it.
    pub fn register_domain(&mut self, domain: Box<dyn Domain>) {
        let id = domain.id().clone();
        self.thompson.init_domain_uniform(id.clone());
        self.domains.insert(id, domain);
    }

    /// Generate holdout tasks for verification, replacing any previous
    /// holdout set for each registered domain.
    pub fn generate_holdouts(&mut self, tasks_per_domain: usize, difficulty: f32) {
        for (id, domain) in &self.domains {
            let tasks = domain.generate_tasks(tasks_per_domain, difficulty);
            self.holdouts.insert(id.clone(), tasks);
        }
    }

    /// Generate training tasks for a specific domain.
    /// Returns an empty Vec if the domain is not registered.
    pub fn generate_tasks(&self, domain_id: &DomainId, count: usize, difficulty: f32) -> Vec<Task> {
        self.domains
            .get(domain_id)
            .map(|d| d.generate_tasks(count, difficulty))
            .unwrap_or_default()
    }

    /// Evaluate a solution and record the outcome in both the Thompson
    /// and meta-learning engines. Solutions scoring below 0.3 are stored
    /// as counterexamples; an unknown domain yields a zero evaluation.
    pub fn evaluate_and_record(
        &mut self,
        domain_id: &DomainId,
        task: &Task,
        solution: &Solution,
        bucket: ContextBucket,
        arm: ArmId,
    ) -> Evaluation {
        let eval = self
            .domains
            .get(domain_id)
            .map(|d| d.evaluate(task, solution))
            .unwrap_or_else(|| Evaluation::zero(vec!["Domain not found".into()]));
        // Record outcome in Thompson engine.
        self.thompson.record_outcome(
            domain_id,
            bucket.clone(),
            arm.clone(),
            eval.score,
            1.0, // unit cost for now
        );
        // Record in meta-learning engine (regret + curiosity + decaying beta).
        self.meta.record_decision(&bucket, &arm, eval.score);
        // Store counterexamples for poor solutions (score < 0.3).
        if eval.score < 0.3 {
            self.counterexamples
                .entry(domain_id.clone())
                .or_default()
                .push((task.clone(), solution.clone(), eval.clone()));
        }
        eval
    }

    /// Embed a solution into the shared representation space.
    /// Returns `None` if the domain is not registered.
    pub fn embed(&self, domain_id: &DomainId, solution: &Solution) -> Option<DomainEmbedding> {
        self.domains.get(domain_id).map(|d| d.embed(solution))
    }

    /// Initiate transfer from source domain to target domain:
    /// extracts priors from source and seeds the target.
    /// A no-op when the source has no extractable prior.
    pub fn initiate_transfer(&mut self, source: &DomainId, target: &DomainId) {
        if let Some(prior) = self.thompson.extract_prior(source) {
            self.thompson
                .init_domain_with_transfer(target.clone(), &prior);
        }
    }

    /// Verify a transfer delta: did it improve target without regressing source?
    pub fn verify_transfer(
        &self,
        source: &DomainId,
        target: &DomainId,
        source_before: f32,
        source_after: f32,
        target_before: f32,
        target_after: f32,
        baseline_cycles: u64,
        transfer_cycles: u64,
    ) -> TransferVerification {
        TransferVerification::verify(
            source.clone(),
            target.clone(),
            source_before,
            source_after,
            target_before,
            target_after,
            baseline_cycles,
            transfer_cycles,
        )
    }

    /// Evaluate all policy kernels on holdout tasks.
    ///
    /// The holdout score for a domain is the mean evaluation score of its
    /// reference solutions, so it is identical for every kernel.
    /// NOTE(review): kernels do not yet influence solving; once they do,
    /// scoring must move back inside the per-kernel loop.
    ///
    /// Fix vs. previous version: the per-domain average was recomputed for
    /// every kernel (O(population × holdout) reference evaluations) and the
    /// entire holdout map was deep-cloned per call; the average is now
    /// computed once per domain and no clone is needed.
    pub fn evaluate_population(&mut self) {
        // Pass 1: per-domain average reference-solution score.
        // Only immutable borrows of `domains` / `holdouts` here.
        let mut domain_scores: Vec<(DomainId, f32)> = Vec::new();
        for (domain_id, domain) in &self.domains {
            if let Some(holdout_tasks) = self.holdouts.get(domain_id) {
                let mut total_score = 0.0f32;
                let mut count = 0;
                for task in holdout_tasks {
                    if let Some(ref_sol) = domain.reference_solution(task) {
                        total_score += domain.evaluate(task, &ref_sol).score;
                        count += 1;
                    }
                }
                let avg_score = if count > 0 {
                    total_score / count as f32
                } else {
                    0.0 // no reference solutions available for this domain
                };
                domain_scores.push((domain_id.clone(), avg_score));
            }
        }
        // Pass 2: record the scores into every kernel (mutable borrow of
        // `population` only).
        for i in 0..self.population.population().len() {
            for (domain_id, avg_score) in &domain_scores {
                if let Some(kernel) = self.population.kernel_mut(i) {
                    kernel.record_score(domain_id.clone(), *avg_score, 1.0);
                }
            }
        }
    }

    /// Evolve the policy kernel population and update the Pareto front.
    pub fn evolve_population(&mut self) {
        // Record current population into Pareto front before evolving.
        let gen = self.population.generation();
        for kernel in self.population.population() {
            let accuracy = kernel.fitness();
            // Average cost per cycle; 0.0 before any cycles have run.
            let cost = if kernel.cycles > 0 {
                kernel.total_cost / kernel.cycles as f32
            } else {
                0.0
            };
            // Robustness approximated by consistency across domains:
            // 1 - stddev of holdout scores, floored at 0.
            // NOTE(review): uses `fitness()` as the mean — assumes fitness
            // is the mean of `holdout_scores`; confirm in policy_kernel.
            let robustness = if kernel.holdout_scores.len() > 1 {
                let mean = accuracy;
                let var: f32 = kernel
                    .holdout_scores
                    .values()
                    .map(|s| (s - mean).powi(2))
                    .sum::<f32>()
                    / kernel.holdout_scores.len() as f32;
                (1.0 - var.sqrt()).max(0.0)
            } else {
                accuracy
            };
            self.meta
                .record_kernel(&kernel.id, accuracy, cost, robustness, gen);
        }
        self.population.evolve();
    }

    /// Get the best policy kernel found so far, if any.
    pub fn best_kernel(&self) -> Option<&PolicyKernel> {
        self.population.best()
    }

    /// Get population statistics.
    pub fn population_stats(&self) -> PopulationStats {
        self.population.stats()
    }

    /// Get the scoreboard summary.
    pub fn scoreboard_summary(&self) -> ScoreboardSummary {
        self.scoreboard.summary()
    }

    /// Get registered domain IDs (map iteration order).
    pub fn domain_ids(&self) -> Vec<DomainId> {
        self.domains.keys().cloned().collect()
    }

    /// Get counterexamples for a domain (empty slice if none recorded).
    pub fn counterexamples(&self, domain_id: &DomainId) -> &[(Task, Solution, Evaluation)] {
        self.counterexamples
            .get(domain_id)
            .map(|v| v.as_slice())
            .unwrap_or(&[])
    }

    /// Select best arm for a context using Thompson Sampling.
    pub fn select_arm(&self, domain_id: &DomainId, bucket: &ContextBucket) -> Option<ArmId> {
        let mut rng = rand::thread_rng();
        self.thompson.select_arm(domain_id, bucket, &mut rng)
    }

    /// Check if dual-path speculation should be triggered
    /// (uncertainty above the 0.15 threshold).
    pub fn should_speculate(&self, domain_id: &DomainId, bucket: &ContextBucket) -> bool {
        self.thompson.is_uncertain(domain_id, bucket, 0.15)
    }

    /// Select arm with curiosity-boosted Thompson Sampling.
    ///
    /// Combines the standard Thompson sample with a UCB-style exploration
    /// bonus that favors under-visited bucket/arm combinations. Falls back
    /// to plain Thompson selection when no priors exist for the bucket.
    pub fn select_arm_curious(
        &self,
        domain_id: &DomainId,
        bucket: &ContextBucket,
    ) -> Option<ArmId> {
        let mut rng = rand::thread_rng();
        // Get all arms and compute boosted scores.
        let prior = self.thompson.extract_prior(domain_id)?;
        let arms: Vec<ArmId> = prior
            .bucket_priors
            .get(bucket)
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default();
        if arms.is_empty() {
            return self.thompson.select_arm(domain_id, bucket, &mut rng);
        }
        let mut best_arm = None;
        let mut best_score = f32::NEG_INFINITY;
        for arm in &arms {
            let params = prior.get_prior(bucket, arm);
            let sample = params.sample(&mut rng);
            let boosted = self.meta.boosted_score(bucket, arm, sample);
            if boosted > best_score {
                best_score = boosted;
                best_arm = Some(arm.clone());
            }
        }
        best_arm.or_else(|| self.thompson.select_arm(domain_id, bucket, &mut rng))
    }

    /// Get meta-learning health diagnostics.
    pub fn meta_health(&self) -> MetaLearningHealth {
        self.meta.health_check()
    }

    /// Check a domain's cost curve for a plateau and get the recommended
    /// action; `Continue` when the domain has no curve.
    pub fn check_plateau(&mut self, domain_id: &DomainId) -> PlateauAction {
        if let Some(curve) = self.scoreboard.curves.get(domain_id) {
            self.meta.check_plateau(&curve.points)
        } else {
            PlateauAction::Continue
        }
    }

    /// Get regret summary across all learning contexts.
    pub fn regret_summary(&self) -> RegretSummary {
        self.meta.regret.summary()
    }

    /// Get the Pareto front of non-dominated policy kernels.
    pub fn pareto_front(&self) -> &ParetoFront {
        &self.meta.pareto
    }
}
impl Default for DomainExpansionEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // The engine registers exactly three built-in domains at construction.
    #[test]
    fn test_engine_creation() {
        let engine = DomainExpansionEngine::new();
        let ids = engine.domain_ids();
        assert_eq!(ids.len(), 3);
    }
    // Every registered domain honors the requested task count.
    #[test]
    fn test_generate_tasks_all_domains() {
        let engine = DomainExpansionEngine::new();
        for domain_id in engine.domain_ids() {
            let tasks = engine.generate_tasks(&domain_id, 5, 0.5);
            assert_eq!(tasks.len(), 5);
        }
    }
    // Arm selection must always return some arm, even with no prior outcomes.
    #[test]
    fn test_arm_selection() {
        let engine = DomainExpansionEngine::new();
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "general".into(),
        };
        for domain_id in engine.domain_ids() {
            let arm = engine.select_arm(&domain_id, &bucket);
            assert!(arm.is_some());
        }
    }
    // End-to-end evaluate + record: scores stay in the unit interval.
    #[test]
    fn test_evaluate_and_record() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.3);
        let task = &tasks[0];
        let solution = Solution {
            task_id: task.id.clone(),
            content:
"fn double(values: &[i64]) -> Vec<i64> { values.iter().map(|&x| x * 2).collect() }"
                .into(),
            data: serde_json::Value::Null,
        };
        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "transform".into(),
        };
        let arm = ArmId("greedy".into());
        let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm);
        assert!(eval.score >= 0.0 && eval.score <= 1.0);
    }
    // Embeddings from different domains share a dimension so cosine
    // similarity across domains is well-defined.
    #[test]
    fn test_cross_domain_embedding() {
        let engine = DomainExpansionEngine::new();
        let rust_sol = Solution {
            task_id: "rust".into(),
            content: "fn foo() { for i in 0..10 { if i > 5 { } } }".into(),
            data: serde_json::Value::Null,
        };
        let plan_sol = Solution {
            task_id: "plan".into(),
            content: "allocate cpu then schedule parallel jobs".into(),
            data: serde_json::json!({"steps": []}),
        };
        let rust_emb = engine
            .embed(&DomainId("rust_synthesis".into()), &rust_sol)
            .unwrap();
        let plan_emb = engine
            .embed(&DomainId("structured_planning".into()), &plan_sol)
            .unwrap();
        // Embeddings should be same dimension.
        assert_eq!(rust_emb.dim, plan_emb.dim);
        // Cross-domain similarity should be defined.
        let sim = rust_emb.cosine_similarity(&plan_emb);
        assert!(sim >= -1.0 && sim <= 1.0);
    }
    // Transfer acceptance test: seed the source domain, transfer, then
    // verify promotion criteria (no source regression, target speedup).
    #[test]
    fn test_transfer_flow() {
        let mut engine = DomainExpansionEngine::new();
        let source = DomainId("rust_synthesis".into());
        let target = DomainId("structured_planning".into());
        // Record some outcomes in source domain.
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };
        for _ in 0..30 {
            engine.thompson.record_outcome(
                &source,
                bucket.clone(),
                ArmId("greedy".into()),
                0.85,
                1.0,
            );
        }
        // Initiate transfer.
        engine.initiate_transfer(&source, &target);
        // Verify the transfer.
        let verification = engine.verify_transfer(
            &source, &target, 0.85, // source before
            0.845, // source after (within tolerance)
            0.3,   // target before
            0.7,   // target after
            100,   // baseline cycles
            45,    // transfer cycles
        );
        assert!(verification.promotable);
        assert!(verification.acceleration_factor > 1.0);
    }
    // One evolve() call advances the population exactly one generation.
    #[test]
    fn test_population_evolution() {
        let mut engine = DomainExpansionEngine::new();
        engine.generate_holdouts(3, 0.3);
        engine.evaluate_population();
        let stats_before = engine.population_stats();
        assert_eq!(stats_before.generation, 0);
        engine.evolve_population();
        let stats_after = engine.population_stats();
        assert_eq!(stats_after.generation, 1);
    }
    // A never-seen hard context should trigger speculative dual-path mode.
    #[test]
    fn test_speculation_trigger() {
        let engine = DomainExpansionEngine::new();
        let bucket = ContextBucket {
            difficulty_tier: "hard".into(),
            category: "unknown".into(),
        };
        // With uniform priors, should be uncertain.
        assert!(engine.should_speculate(&DomainId("rust_synthesis".into()), &bucket,));
    }
    // Low-scoring solutions must be captured as counterexamples.
    #[test]
    fn test_counterexample_tracking() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.9);
        let task = &tasks[0];
        // Submit a terrible solution.
        let solution = Solution {
            task_id: task.id.clone(),
            content: "".into(), // empty = bad
            data: serde_json::Value::Null,
        };
        let bucket = ContextBucket {
            difficulty_tier: "hard".into(),
            category: "algorithm".into(),
        };
        let arm = ArmId("speculative".into());
        let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm);
        assert!(eval.score < 0.3);
        // Should be recorded as counterexample.
        assert!(!engine.counterexamples(&domain_id).is_empty());
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,647 @@
//! Structured Planning Tasks Domain
//!
//! Generates tasks that require multi-step reasoning and plan construction.
//! Task types include:
//!
//! - **ResourceAllocation**: Assign limited resources to maximize objective
//! - **DependencyScheduling**: Order tasks respecting dependencies and deadlines
//! - **StateSpaceSearch**: Navigate from initial to goal state
//! - **ConstraintSatisfaction**: Find assignments satisfying all constraints
//! - **HierarchicalDecomposition**: Break complex goals into sub-goals
//!
//! Solutions are plans: ordered sequences of actions with preconditions and effects.
//! Cross-domain transfer from Rust synthesis helps because both require:
//! structured decomposition, constraint satisfaction, and efficient search.
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
use rand::Rng;
use serde::{Deserialize, Serialize};
/// Width of the feature vector produced by this domain's `embed`/`extract_features`.
const EMBEDDING_DIM: usize = 64;
/// Categories of planning tasks.
///
/// Task generation currently produces the first three variants; the last two
/// are declared for spec completeness.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlanningCategory {
    /// Assign limited resources to competing demands.
    ResourceAllocation,
    /// Schedule tasks with precedence constraints and deadlines.
    DependencyScheduling,
    /// Find a path from initial state to goal state.
    StateSpaceSearch,
    /// Assign values to variables satisfying all constraints.
    ConstraintSatisfaction,
    /// Decompose a high-level goal into achievable sub-tasks.
    HierarchicalDecomposition,
}
/// A resource in the planning world.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Resource {
    /// Resource name (e.g. "cpu", "memory", "worker").
    pub name: String,
    /// Capacity limit for this resource.
    pub capacity: u32,
}
/// An action in a plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanAction {
    /// Action name; plan steps and dependency edges refer to it.
    pub name: String,
    /// Predicates that must hold before the action runs.
    pub preconditions: Vec<String>,
    /// Predicates established by running the action.
    pub effects: Vec<String>,
    /// Cost charged per appearance of this action in a plan.
    pub cost: f32,
    /// Duration in abstract time units.
    pub duration: u32,
}
/// A dependency edge: task A must complete before task B.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dependency {
    /// Name of the action that must run first.
    pub from: String,
    /// Name of the action that must run after `from`.
    pub to: String,
}
/// Specification for a planning task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanningTaskSpec {
    /// Which planning sub-problem this task represents.
    pub category: PlanningCategory,
    /// Human-readable statement of the problem.
    pub description: String,
    /// Available actions in the planning domain.
    pub available_actions: Vec<PlanAction>,
    /// Resources with capacity limits.
    pub resources: Vec<Resource>,
    /// Dependency constraints.
    pub dependencies: Vec<Dependency>,
    /// Initial state predicates.
    pub initial_state: Vec<String>,
    /// Goal state predicates.
    pub goal_state: Vec<String>,
    /// Maximum allowed plan cost.
    pub max_cost: Option<f32>,
    /// Maximum allowed plan steps.
    pub max_steps: Option<usize>,
}
/// A parsed plan from a solution.
///
/// Deserialized from the solution's JSON `data` payload, with the text
/// `content` as a fallback source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plan {
    /// Ordered sequence of plan steps.
    pub steps: Vec<PlanStep>,
}
/// A single step in a plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// Name of the action invoked at this step.
    pub action: String,
    /// Arguments passed to the action (may be empty).
    pub args: Vec<String>,
    /// Optional scheduled start time; adjacent steps with equal start times
    /// are treated as running in parallel by the evaluator.
    pub start_time: Option<u32>,
}
/// Structured planning domain.
pub struct PlanningDomain {
    /// Stable domain identifier: `structured_planning`.
    id: DomainId,
}
impl PlanningDomain {
    /// Construct the domain with its canonical id (`structured_planning`).
    pub fn new() -> Self {
        Self {
            id: DomainId("structured_planning".to_string()),
        }
    }
    /// Build a resource-allocation task: complete N tasks under shared
    /// capacity limits. Difficulty scales the task count up (3/6/10) and the
    /// resource capacities down.
    fn gen_resource_allocation(&self, difficulty: f32) -> PlanningTaskSpec {
        let num_tasks = if difficulty < 0.3 {
            3
        } else if difficulty < 0.7 {
            6
        } else {
            10
        };
        let actions: Vec<PlanAction> = (0..num_tasks)
            .map(|i| PlanAction {
                name: format!("task_{}", i),
                // Tasks round-robin over the three resource pools below.
                preconditions: vec![format!("resource_available_{}", i % 3)],
                effects: vec![format!("task_{}_complete", i)],
                cost: (i as f32 + 1.0) * 10.0,
                duration: (i as u32 % 5) + 1,
            })
            .collect();
        let resources = vec![
            Resource {
                name: "cpu".into(),
                capacity: if difficulty < 0.5 { 10 } else { 5 },
            },
            Resource {
                name: "memory".into(),
                capacity: if difficulty < 0.5 { 8 } else { 3 },
            },
            Resource {
                name: "io".into(),
                capacity: if difficulty < 0.5 { 6 } else { 2 },
            },
        ];
        let goal_state: Vec<String> = (0..num_tasks)
            .map(|i| format!("task_{}_complete", i))
            .collect();
        PlanningTaskSpec {
            category: PlanningCategory::ResourceAllocation,
            description: format!(
                "Allocate {} resources to complete {} tasks within capacity.",
                resources.len(),
                num_tasks
            ),
            available_actions: actions,
            resources,
            dependencies: Vec::new(),
            initial_state: vec![
                "resource_available_0".into(),
                "resource_available_1".into(),
                "resource_available_2".into(),
            ],
            goal_state,
            max_cost: Some(num_tasks as f32 * 50.0),
            max_steps: Some(num_tasks * 2),
        }
    }
    /// Build a dependency-scheduling task: a linear job chain, plus extra
    /// cross-edges at higher difficulty, to be completed with a bounded
    /// worker pool.
    fn gen_dependency_scheduling(&self, difficulty: f32) -> PlanningTaskSpec {
        let num_tasks = if difficulty < 0.3 {
            4
        } else if difficulty < 0.7 {
            7
        } else {
            12
        };
        let actions: Vec<PlanAction> = (0..num_tasks)
            .map(|i| PlanAction {
                name: format!("job_{}", i),
                preconditions: if i > 0 {
                    vec![format!("job_{}_done", i - 1)]
                } else {
                    Vec::new()
                },
                effects: vec![format!("job_{}_done", i)],
                cost: 1.0,
                duration: (i as u32 % 3) + 1,
            })
            .collect();
        // Create dependency chain with some parallelism
        let mut dependencies = Vec::new();
        for i in 1..num_tasks {
            // Linear chain
            dependencies.push(Dependency {
                from: format!("job_{}", i - 1),
                to: format!("job_{}", i),
            });
            // Add cross-dependencies at higher difficulty
            if difficulty > 0.5 && i >= 3 && i % 2 == 0 {
                dependencies.push(Dependency {
                    from: format!("job_{}", i - 3),
                    to: format!("job_{}", i),
                });
            }
        }
        PlanningTaskSpec {
            category: PlanningCategory::DependencyScheduling,
            description: format!(
                "Schedule {} jobs respecting {} dependencies, minimizing makespan.",
                num_tasks,
                dependencies.len()
            ),
            available_actions: actions,
            resources: vec![Resource {
                name: "worker".into(),
                capacity: if difficulty < 0.5 { 3 } else { 2 },
            }],
            dependencies,
            initial_state: Vec::new(),
            goal_state: (0..num_tasks).map(|i| format!("job_{}_done", i)).collect(),
            max_cost: None,
            max_steps: Some(num_tasks + 5),
        }
    }
    /// Build a grid-navigation task. Difficulty scales the grid (3/5/8);
    /// the four move actions are shared across all sizes.
    fn gen_state_space_search(&self, difficulty: f32) -> PlanningTaskSpec {
        let grid_size = if difficulty < 0.3 {
            3
        } else if difficulty < 0.7 {
            5
        } else {
            8
        };
        let actions = vec![
            PlanAction {
                name: "move_up".into(),
                preconditions: vec!["not_top_edge".into()],
                effects: vec!["moved_up".into()],
                cost: 1.0,
                duration: 1,
            },
            PlanAction {
                name: "move_down".into(),
                preconditions: vec!["not_bottom_edge".into()],
                effects: vec!["moved_down".into()],
                cost: 1.0,
                duration: 1,
            },
            PlanAction {
                name: "move_left".into(),
                preconditions: vec!["not_left_edge".into()],
                effects: vec!["moved_left".into()],
                cost: 1.0,
                duration: 1,
            },
            PlanAction {
                name: "move_right".into(),
                preconditions: vec!["not_right_edge".into()],
                effects: vec!["moved_right".into()],
                cost: 1.0,
                duration: 1,
            },
        ];
        PlanningTaskSpec {
            category: PlanningCategory::StateSpaceSearch,
            description: format!(
                "Navigate a {}x{} grid from (0,0) to ({},{}) avoiding obstacles.",
                grid_size,
                grid_size,
                grid_size - 1,
                grid_size - 1
            ),
            available_actions: actions,
            resources: Vec::new(),
            dependencies: Vec::new(),
            initial_state: vec!["at(0,0)".into()],
            goal_state: vec![format!("at({},{})", grid_size - 1, grid_size - 1)],
            max_cost: Some((grid_size as f32) * 4.0),
            max_steps: Some(grid_size * grid_size),
        }
    }
    /// Extract structural features from a planning solution.
    ///
    /// Produces an L2-normalized `EMBEDDING_DIM`-wide vector mixing parsed
    /// plan shape (when JSON parses) with keyword counts from the text.
    /// Unused slots stay zero by design so indices align with the Rust
    /// synthesis domain's embedding layout.
    fn extract_features(&self, solution: &Solution) -> Vec<f32> {
        let content = &solution.content;
        let mut features = vec![0.0f32; EMBEDDING_DIM];
        // Parse the plan: prefer the structured `data` payload, fall back to
        // the raw content; an unparseable solution yields an empty plan.
        let plan: Plan = serde_json::from_str(&solution.data.to_string())
            .or_else(|_| serde_json::from_str(content))
            .unwrap_or(Plan { steps: Vec::new() });
        // Features 0-3: plan structure (length, action diversity, parallelism, arg arity).
        features[0] = plan.steps.len() as f32 / 20.0;
        features[1] = {
            let unique_actions: std::collections::HashSet<&str> =
                plan.steps.iter().map(|s| s.action.as_str()).collect();
            unique_actions.len() as f32 / plan.steps.len().max(1) as f32
        };
        // Fraction of adjacent step pairs sharing a start time (parallelism indicator).
        features[2] = plan
            .steps
            .windows(2)
            .filter(|w| w[0].start_time == w[1].start_time)
            .count() as f32
            / plan.steps.len().max(1) as f32;
        // Average args per step
        features[3] = plan.steps.iter().map(|s| s.args.len() as f32).sum::<f32>()
            / plan.steps.len().max(1) as f32
            / 5.0;
        // Features 8-9: action type distribution (diversity, dominance of the modal action).
        let action_counts: std::collections::HashMap<&str, usize> =
            plan.steps
                .iter()
                .fold(std::collections::HashMap::new(), |mut acc, s| {
                    *acc.entry(s.action.as_str()).or_insert(0) += 1;
                    acc
                });
        let max_count = action_counts.values().max().copied().unwrap_or(0);
        features[8] = action_counts.len() as f32 / 10.0;
        features[9] = max_count as f32 / plan.steps.len().max(1) as f32;
        // Features 16-23: planning-verb keyword counts from the raw content.
        features[16] = content.matches("allocate").count() as f32 / 5.0;
        features[17] = content.matches("schedule").count() as f32 / 5.0;
        features[18] = content.matches("move").count() as f32 / 10.0;
        features[19] = content.matches("assign").count() as f32 / 5.0;
        features[20] = content.matches("wait").count() as f32 / 5.0;
        features[21] = content.matches("parallel").count() as f32 / 3.0;
        features[22] = content.matches("constraint").count() as f32 / 5.0;
        features[23] = content.matches("deadline").count() as f32 / 3.0;
        // Features 32-39: structural/ordering complexity indicators.
        features[32] = content.matches("->").count() as f32 / 10.0;
        features[33] = content.matches("if ").count() as f32 / 5.0;
        features[34] = content.matches("then ").count() as f32 / 5.0;
        features[35] = content.matches("before").count() as f32 / 5.0;
        features[36] = content.matches("after").count() as f32 / 5.0;
        features[37] = content.matches("while").count() as f32 / 3.0;
        features[38] = content.matches("until").count() as f32 / 3.0;
        features[39] = content.matches("complete").count() as f32 / 5.0;
        // Features 48-55: resource usage indicators.
        features[48] = content.matches("cpu").count() as f32 / 3.0;
        features[49] = content.matches("memory").count() as f32 / 3.0;
        features[50] = content.matches("worker").count() as f32 / 3.0;
        features[51] = content.matches("capacity").count() as f32 / 3.0;
        features[52] = content.matches("cost").count() as f32 / 5.0;
        features[53] = content.matches("time").count() as f32 / 5.0;
        features[54] = content.matches("resource").count() as f32 / 5.0;
        features[55] = content.matches("limit").count() as f32 / 3.0;
        // L2-normalize (skip the all-zero vector to avoid division by ~0).
        let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for f in &mut features {
                *f /= norm;
            }
        }
        features
    }
    /// Evaluate a planning solution against the task spec.
    ///
    /// Score = 60% correctness (goal coverage blended with dependency
    /// ordering) + 25% efficiency (step/cost budgets) + 15% elegance (low
    /// redundancy, parallelism bonus). Falls back to shallow text analysis
    /// when no structured plan can be parsed.
    fn score_plan(&self, spec: &PlanningTaskSpec, solution: &Solution) -> Evaluation {
        let content = &solution.content;
        let mut correctness = 0.0f32;
        // Stays at the neutral 0.5 when the spec imposes no step budget.
        let mut efficiency = 0.5f32;
        let mut notes = Vec::new();
        // Parse plan from solution: structured `data` first, then content text.
        let plan: Option<Plan> = serde_json::from_str(&solution.data.to_string())
            .ok()
            .or_else(|| serde_json::from_str(content).ok());
        let plan = match plan {
            Some(p) => p,
            None => {
                // Fall back to text analysis
                let has_steps = content.contains("step") || content.contains("action");
                if has_steps {
                    correctness = 0.2;
                }
                return Evaluation {
                    score: correctness * 0.6,
                    correctness,
                    efficiency: 0.0,
                    elegance: 0.0,
                    constraint_results: Vec::new(),
                    notes: vec!["Could not parse structured plan".into()],
                };
            }
        };
        // Check plan is non-empty
        if plan.steps.is_empty() {
            return Evaluation::zero(vec!["Empty plan".into()]);
        }
        // Check goal coverage: how many goal predicates are addressed
        let goal_coverage = spec
            .goal_state
            .iter()
            .filter(|goal| {
                plan.steps.iter().any(|step| {
                    let action_name = &step.action;
                    // Check if any action's effects mention this goal
                    spec.available_actions
                        .iter()
                        .any(|a| a.name == *action_name && a.effects.iter().any(|e| e == *goal))
                })
            })
            .count() as f32
            / spec.goal_state.len().max(1) as f32;
        correctness = goal_coverage;
        // Check dependency ordering (first occurrence of each action decides).
        let mut dep_violations = 0;
        for dep in &spec.dependencies {
            let from_pos = plan.steps.iter().position(|s| s.action == dep.from);
            let to_pos = plan.steps.iter().position(|s| s.action == dep.to);
            if let (Some(f), Some(t)) = (from_pos, to_pos) {
                if f >= t {
                    dep_violations += 1;
                    notes.push(format!(
                        "Dependency violation: {} must come before {}",
                        dep.from, dep.to
                    ));
                }
            }
        }
        if !spec.dependencies.is_empty() {
            let dep_score = 1.0 - (dep_violations as f32 / spec.dependencies.len() as f32);
            correctness = correctness * 0.5 + dep_score * 0.5;
        }
        // Efficiency: compare to max allowed steps/cost
        if let Some(max_steps) = spec.max_steps {
            let step_ratio = plan.steps.len() as f32 / max_steps as f32;
            efficiency = if step_ratio <= 1.0 {
                1.0 - (step_ratio * 0.5) // Fewer steps = better
            } else {
                0.5 / step_ratio // Penalty for exceeding max
            };
        }
        if let Some(max_cost) = spec.max_cost {
            // Steps naming unknown actions contribute no cost.
            let total_cost: f32 = plan
                .steps
                .iter()
                .filter_map(|step| {
                    spec.available_actions
                        .iter()
                        .find(|a| a.name == step.action)
                        .map(|a| a.cost)
                })
                .sum();
            if total_cost > max_cost {
                efficiency *= 0.5;
                notes.push(format!(
                    "Plan cost {:.1} exceeds budget {:.1}",
                    total_cost, max_cost
                ));
            }
        }
        // Elegance: minimal redundancy, good parallelism.
        // Declared here rather than at the top: every early-return path above
        // supplies its own elegance value, so an earlier initializer would be
        // a dead store (rustc's unused_assignments warning).
        let unique_actions: std::collections::HashSet<&str> =
            plan.steps.iter().map(|s| s.action.as_str()).collect();
        let redundancy = 1.0 - (unique_actions.len() as f32 / plan.steps.len().max(1) as f32);
        let mut elegance = 1.0 - redundancy * 0.5;
        // Bonus for parallel scheduling
        if plan
            .steps
            .windows(2)
            .any(|w| w[0].start_time == w[1].start_time)
        {
            elegance += 0.1;
        }
        elegance = elegance.clamp(0.0, 1.0);
        let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
        Evaluation {
            score: score.clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes,
        }
    }
}
impl Default for PlanningDomain {
    /// Equivalent to [`PlanningDomain::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Domain for PlanningDomain {
    fn id(&self) -> &DomainId {
        &self.id
    }
    fn name(&self) -> &str {
        "Structured Planning"
    }
    /// Generate `count` tasks, sampling a category per task:
    /// 35% resource allocation, 35% dependency scheduling, 30% state-space search.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);
        (0..count)
            .map(|i| {
                let category_roll: f32 = rng.gen();
                let spec = if category_roll < 0.35 {
                    self.gen_resource_allocation(difficulty)
                } else if category_roll < 0.7 {
                    self.gen_dependency_scheduling(difficulty)
                } else {
                    self.gen_state_space_search(difficulty)
                };
                Task {
                    // Note: ids are only unique within one generate_tasks call.
                    id: format!("planning_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    // Serialization of the spec is infallible in practice;
                    // a failure degrades to an empty spec value.
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: Vec::new(),
                }
            })
            .collect()
    }
    /// Deserialize the spec carried by the task and score the plan against it.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        let spec: PlanningTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        self.score_plan(&spec, solution)
    }
    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }
    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }
    /// Naive baseline: run every available action exactly once, sequentially,
    /// with step index as start time. Returns None only if serialization fails.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: PlanningTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;
        // Generate a naive sequential plan that executes all actions in order
        let steps: Vec<PlanStep> = spec
            .available_actions
            .iter()
            .enumerate()
            .map(|(i, a)| PlanStep {
                action: a.name.clone(),
                args: Vec::new(),
                start_time: Some(i as u32),
            })
            .collect();
        let plan = Plan { steps };
        let content = serde_json::to_string_pretty(&plan).ok()?;
        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::to_value(&plan).ok()?,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Generated tasks carry this domain's id and honor the requested count.
    #[test]
    fn test_generate_planning_tasks() {
        let domain = PlanningDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        for task in &tasks {
            assert_eq!(task.domain_id, domain.id);
        }
    }
    // Every generated task must admit a naive sequential reference plan.
    #[test]
    fn test_reference_solution_exists() {
        let domain = PlanningDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            let ref_sol = domain.reference_solution(task);
            assert!(ref_sol.is_some(), "Should produce reference solution");
        }
    }
    // Reference plans evaluate to a score within [0, 1].
    #[test]
    fn test_evaluate_reference() {
        let domain = PlanningDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            if let Some(solution) = domain.reference_solution(task) {
                let eval = domain.evaluate(task, &solution);
                assert!(eval.score >= 0.0 && eval.score <= 1.0);
            }
        }
    }
    // Embedding width is fixed regardless of solution content.
    #[test]
    fn test_embed_planning() {
        let domain = PlanningDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "allocate cpu to task_0, schedule job_1 after job_0".into(),
            data: serde_json::json!({ "steps": [] }),
        };
        let embedding = domain.embed(&solution);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
    }
    // Higher difficulty must never shrink the action space.
    #[test]
    fn test_difficulty_scaling() {
        let domain = PlanningDomain::new();
        let easy = domain.generate_tasks(1, 0.1);
        let hard = domain.generate_tasks(1, 0.9);
        let easy_spec: PlanningTaskSpec = serde_json::from_value(easy[0].spec.clone()).unwrap();
        let hard_spec: PlanningTaskSpec = serde_json::from_value(hard[0].spec.clone()).unwrap();
        assert!(
            hard_spec.available_actions.len() >= easy_spec.available_actions.len(),
            "Harder tasks should have more actions"
        );
    }
}

View File

@@ -0,0 +1,468 @@
//! PolicyKernel: Population-Based Policy Search
//!
//! Run a small population of policy variants in parallel.
//! Each variant changes a small set of knobs:
//! - skip mode policy
//! - prepass mode
//! - speculation trigger thresholds
//! - budget allocation
//!
//! Selection: keep top performers on holdouts, mutate knobs, repeat.
//! Only merge deltas that pass replay-verify.
use crate::domain::DomainId;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Configuration knobs that a PolicyKernel can tune.
///
/// Continuous knobs are kept inside fixed bounds by `mutate`; `crossover`
/// mixes fields from two parents uniformly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyKnobs {
    /// Whether to skip low-value operations.
    pub skip_mode: bool,
    /// Run a cheaper prepass before full execution.
    pub prepass_enabled: bool,
    /// Threshold for triggering speculative dual-path [0.0, 1.0].
    pub speculation_threshold: f32,
    /// Budget fraction allocated to exploration vs exploitation [0.0, 1.0].
    pub exploration_budget: f32,
    /// Maximum retries on failure.
    pub max_retries: u32,
    /// Batch size for parallel evaluation.
    pub batch_size: usize,
    /// Cost decay factor for EMA.
    pub cost_decay: f32,
    /// Minimum confidence to skip uncertainty check.
    pub confidence_floor: f32,
}
impl PolicyKnobs {
    /// Sensible defaults.
    pub fn default_knobs() -> Self {
        Self {
            skip_mode: false,
            prepass_enabled: true,
            speculation_threshold: 0.15,
            exploration_budget: 0.2,
            max_retries: 2,
            batch_size: 8,
            cost_decay: 0.9,
            confidence_floor: 0.7,
        }
    }
    /// Mutate knobs with small random perturbations.
    ///
    /// Each knob is perturbed independently with probability `mutation_rate`:
    /// booleans flip, integer knobs are resampled, and continuous knobs move
    /// by a small bounded delta and are clamped to their valid range.
    pub fn mutate(&self, rng: &mut impl Rng, mutation_rate: f32) -> Self {
        let mut next = self.clone();
        if rng.gen::<f32>() < mutation_rate {
            next.skip_mode = !next.skip_mode;
        }
        if rng.gen::<f32>() < mutation_rate {
            next.prepass_enabled = !next.prepass_enabled;
        }
        if rng.gen::<f32>() < mutation_rate {
            let step: f32 = rng.gen_range(-0.1..0.1);
            next.speculation_threshold = (next.speculation_threshold + step).clamp(0.01, 0.5);
        }
        if rng.gen::<f32>() < mutation_rate {
            let step: f32 = rng.gen_range(-0.1..0.1);
            next.exploration_budget = (next.exploration_budget + step).clamp(0.01, 0.5);
        }
        if rng.gen::<f32>() < mutation_rate {
            next.max_retries = rng.gen_range(0..5);
        }
        if rng.gen::<f32>() < mutation_rate {
            next.batch_size = rng.gen_range(1..32);
        }
        if rng.gen::<f32>() < mutation_rate {
            let step: f32 = rng.gen_range(-0.05..0.05);
            next.cost_decay = (next.cost_decay + step).clamp(0.5, 0.99);
        }
        if rng.gen::<f32>() < mutation_rate {
            let step: f32 = rng.gen_range(-0.1..0.1);
            next.confidence_floor = (next.confidence_floor + step).clamp(0.3, 0.95);
        }
        next
    }
    /// Crossover two parent knobs to produce a child.
    ///
    /// Every field is drawn uniformly (one fair coin flip each, in field
    /// declaration order) from either `self` or `other`.
    pub fn crossover(&self, other: &PolicyKnobs, rng: &mut impl Rng) -> Self {
        // Coin-flip selector shared by all fields.
        fn pick<T>(rng: &mut impl Rng, ours: T, theirs: T) -> T {
            if rng.gen() {
                ours
            } else {
                theirs
            }
        }
        Self {
            skip_mode: pick(&mut *rng, self.skip_mode, other.skip_mode),
            prepass_enabled: pick(&mut *rng, self.prepass_enabled, other.prepass_enabled),
            speculation_threshold: pick(
                &mut *rng,
                self.speculation_threshold,
                other.speculation_threshold,
            ),
            exploration_budget: pick(&mut *rng, self.exploration_budget, other.exploration_budget),
            max_retries: pick(&mut *rng, self.max_retries, other.max_retries),
            batch_size: pick(&mut *rng, self.batch_size, other.batch_size),
            cost_decay: pick(&mut *rng, self.cost_decay, other.cost_decay),
            confidence_floor: pick(&mut *rng, self.confidence_floor, other.confidence_floor),
        }
    }
}
/// A PolicyKernel is a versioned policy configuration with performance history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyKernel {
    /// Unique identifier.
    pub id: String,
    /// Configuration knobs.
    pub knobs: PolicyKnobs,
    /// Performance on holdout tasks (domain_id -> score).
    /// Re-recording a domain overwrites its previous score.
    pub holdout_scores: HashMap<DomainId, f32>,
    /// Total cost incurred.
    pub total_cost: f32,
    /// Number of evaluation cycles.
    pub cycles: u64,
    /// Generation (0 = initial, increments on mutation).
    pub generation: u32,
    /// Parent kernel ID (for lineage tracking).
    pub parent_id: Option<String>,
    /// Whether this kernel has been verified via replay.
    pub replay_verified: bool,
}
impl PolicyKernel {
    /// Create a new kernel with default knobs.
    pub fn new(id: String) -> Self {
        Self {
            id,
            knobs: PolicyKnobs::default_knobs(),
            holdout_scores: HashMap::new(),
            total_cost: 0.0,
            cycles: 0,
            generation: 0,
            parent_id: None,
            replay_verified: false,
        }
    }
    /// Create a mutated child kernel.
    ///
    /// The child gets perturbed knobs, a fresh (empty) evaluation history,
    /// one generation deeper, and this kernel recorded as its parent.
    pub fn mutate(&self, child_id: String, rng: &mut impl Rng) -> Self {
        Self {
            id: child_id,
            knobs: self.knobs.mutate(rng, 0.3),
            generation: self.generation + 1,
            parent_id: Some(self.id.clone()),
            holdout_scores: HashMap::new(),
            total_cost: 0.0,
            cycles: 0,
            replay_verified: false,
        }
    }
    /// Record a holdout score for a domain (overwriting any previous score
    /// for that domain) and accumulate cost and cycle count.
    pub fn record_score(&mut self, domain_id: DomainId, score: f32, cost: f32) {
        self.holdout_scores.insert(domain_id, score);
        self.total_cost += cost;
        self.cycles += 1;
    }
    /// Fitness: average holdout score across all evaluated domains.
    /// An unevaluated kernel has fitness 0.
    pub fn fitness(&self) -> f32 {
        let evaluated = self.holdout_scores.len();
        if evaluated == 0 {
            return 0.0;
        }
        self.holdout_scores.values().sum::<f32>() / evaluated as f32
    }
    /// Cost-adjusted fitness: penalizes expensive kernels.
    pub fn cost_adjusted_fitness(&self) -> f32 {
        // Mean per-cycle cost, capped at 1.0 so the penalty stays bounded.
        let mean_cost = (self.total_cost / self.cycles.max(1) as f32).min(1.0);
        // 30% weight on cost
        self.fitness() * (1.0 - 0.3 * mean_cost)
    }
}
/// Population-based policy search engine.
///
/// `population.len()` equals `pop_size` after construction and after every
/// `evolve` call.
#[derive(Clone)]
pub struct PopulationSearch {
    /// Current population of kernels.
    population: Vec<PolicyKernel>,
    /// Population size.
    pop_size: usize,
    /// Best kernel seen so far.
    best_kernel: Option<PolicyKernel>,
    /// Generation counter.
    generation: u32,
}
impl PopulationSearch {
    /// Create a new population search with initial random population.
    /// Initial kernels start from default knobs heavily mutated (rate 0.8).
    pub fn new(pop_size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let population: Vec<PolicyKernel> = (0..pop_size)
            .map(|i| {
                let mut kernel = PolicyKernel::new(format!("kernel_g0_{}", i));
                // Random initial knobs
                kernel.knobs = PolicyKnobs::default_knobs().mutate(&mut rng, 0.8);
                kernel
            })
            .collect();
        Self {
            population,
            pop_size,
            best_kernel: None,
            generation: 0,
        }
    }
    /// Get current population for evaluation.
    pub fn population(&self) -> &[PolicyKernel] {
        &self.population
    }
    /// Get mutable reference to a kernel by index.
    pub fn kernel_mut(&mut self, index: usize) -> Option<&mut PolicyKernel> {
        self.population.get_mut(index)
    }
    /// Evolve to next generation: select top performers, mutate, fill population.
    ///
    /// Elites (top 25%, at least one) are carried over with reset stats and
    /// fresh generation-stamped ids; the remainder is filled with ~30%
    /// crossovers (when more than one elite exists) and ~70% mutations.
    pub fn evolve(&mut self) {
        let mut rng = rand::thread_rng();
        self.generation += 1;
        // Sort by cost-adjusted fitness (descending)
        self.population.sort_by(|a, b| {
            b.cost_adjusted_fitness()
                .partial_cmp(&a.cost_adjusted_fitness())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        // Track best — note selection sorts by cost-adjusted fitness, while
        // the best-ever record compares raw fitness.
        if let Some(best) = self.population.first() {
            if self
                .best_kernel
                .as_ref()
                .map_or(true, |b| best.fitness() > b.fitness())
            {
                self.best_kernel = Some(best.clone());
            }
        }
        // Elite selection: keep top 25%
        let elite_count = (self.pop_size / 4).max(1);
        let elites: Vec<PolicyKernel> = self.population[..elite_count].to_vec();
        // Build next generation
        let mut next_gen = Vec::with_capacity(self.pop_size);
        // Keep elites (with cleared evaluation stats so scores don't carry over)
        for elite in &elites {
            let mut kept = elite.clone();
            kept.id = format!("kernel_g{}_{}", self.generation, next_gen.len());
            kept.holdout_scores.clear();
            kept.total_cost = 0.0;
            kept.cycles = 0;
            next_gen.push(kept);
        }
        // Fill rest with mutations and crossovers
        while next_gen.len() < self.pop_size {
            let parent_idx = rng.gen_range(0..elites.len());
            let child_id = format!("kernel_g{}_{}", self.generation, next_gen.len());
            let child = if rng.gen::<f32>() < 0.3 && elites.len() > 1 {
                // Crossover: pick a second, distinct elite as the other parent.
                let other_idx =
                    (parent_idx + 1 + rng.gen_range(0..elites.len() - 1)) % elites.len();
                let mut child = PolicyKernel::new(child_id);
                child.knobs = elites[parent_idx]
                    .knobs
                    .crossover(&elites[other_idx].knobs, &mut rng);
                child.generation = self.generation;
                child.parent_id = Some(elites[parent_idx].id.clone());
                child
            } else {
                // Mutation
                elites[parent_idx].mutate(child_id, &mut rng)
            };
            next_gen.push(child);
        }
        self.population = next_gen;
    }
    /// Get the best kernel found so far.
    pub fn best(&self) -> Option<&PolicyKernel> {
        self.best_kernel.as_ref()
    }
    /// Current generation number.
    pub fn generation(&self) -> u32 {
        self.generation
    }
    /// Get fitness statistics for the current population.
    ///
    /// NOTE(review): with an empty population (pop_size 0) the folds yield
    /// ±infinity for max/min — callers appear to always use pop_size > 0.
    pub fn stats(&self) -> PopulationStats {
        let fitnesses: Vec<f32> = self.population.iter().map(|k| k.fitness()).collect();
        let mean = fitnesses.iter().sum::<f32>() / fitnesses.len().max(1) as f32;
        let max = fitnesses.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let min = fitnesses.iter().cloned().fold(f32::INFINITY, f32::min);
        let variance = fitnesses.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
            / fitnesses.len().max(1) as f32;
        PopulationStats {
            generation: self.generation,
            pop_size: self.population.len(),
            mean_fitness: mean,
            max_fitness: max,
            min_fitness: min,
            fitness_variance: variance,
            best_ever_fitness: self
                .best_kernel
                .as_ref()
                .map(|k| k.fitness())
                .unwrap_or(0.0),
        }
    }
}
/// Statistics about the current population.
///
/// Snapshot produced by [`PopulationSearch::stats`]; all fitness figures use
/// raw (not cost-adjusted) fitness.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PopulationStats {
    /// Generation the snapshot was taken at.
    pub generation: u32,
    /// Number of kernels currently in the population.
    pub pop_size: usize,
    /// Mean fitness over the population.
    pub mean_fitness: f32,
    /// Highest fitness in the population.
    pub max_fitness: f32,
    /// Lowest fitness in the population.
    pub min_fitness: f32,
    /// Population fitness variance.
    pub fitness_variance: f32,
    /// Fitness of the best kernel ever observed (0.0 if none yet).
    pub best_ever_fitness: f32,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Default knobs match the documented starting configuration.
    #[test]
    fn test_policy_knobs_default() {
        let knobs = PolicyKnobs::default_knobs();
        assert!(!knobs.skip_mode);
        assert!(knobs.prepass_enabled);
        assert!(knobs.speculation_threshold > 0.0);
    }
    // Mutation keeps continuous knobs within their clamped bounds.
    #[test]
    fn test_policy_knobs_mutate() {
        let knobs = PolicyKnobs::default_knobs();
        let mut rng = rand::thread_rng();
        let mutated = knobs.mutate(&mut rng, 1.0); // high mutation rate
        // At least something should differ (probabilistically)
        // Can't guarantee due to randomness, but bounds should hold
        assert!(mutated.speculation_threshold >= 0.01 && mutated.speculation_threshold <= 0.5);
        assert!(mutated.exploration_budget >= 0.01 && mutated.exploration_budget <= 0.5);
    }
    // Fitness is 0 before any scores, then the mean of recorded scores.
    #[test]
    fn test_policy_kernel_fitness() {
        let mut kernel = PolicyKernel::new("test".into());
        assert_eq!(kernel.fitness(), 0.0);
        kernel.record_score(DomainId("d1".into()), 0.8, 1.0);
        kernel.record_score(DomainId("d2".into()), 0.6, 1.0);
        assert!((kernel.fitness() - 0.7).abs() < 1e-6);
    }
    // Evolution preserves population size, bumps the generation, and tracks a best.
    #[test]
    fn test_population_search_evolve() {
        let mut search = PopulationSearch::new(8);
        assert_eq!(search.population().len(), 8);
        // Simulate evaluation
        for i in 0..8 {
            if let Some(kernel) = search.kernel_mut(i) {
                let score = 0.3 + (i as f32) * 0.08;
                kernel.record_score(DomainId("test".into()), score, 1.0);
            }
        }
        search.evolve();
        assert_eq!(search.population().len(), 8);
        assert_eq!(search.generation(), 1);
        assert!(search.best().is_some());
    }
    // Stats must satisfy min <= mean <= max.
    #[test]
    fn test_population_stats() {
        let mut search = PopulationSearch::new(4);
        for i in 0..4 {
            if let Some(kernel) = search.kernel_mut(i) {
                kernel.record_score(DomainId("test".into()), (i as f32) * 0.25, 1.0);
            }
        }
        let stats = search.stats();
        assert_eq!(stats.pop_size, 4);
        assert!(stats.max_fitness >= stats.min_fitness);
        assert!(stats.mean_fitness >= stats.min_fitness);
        assert!(stats.mean_fitness <= stats.max_fitness);
    }
    // Each crossover field must come verbatim from one of the two parents.
    #[test]
    fn test_crossover() {
        let a = PolicyKnobs {
            skip_mode: true,
            prepass_enabled: false,
            speculation_threshold: 0.1,
            exploration_budget: 0.1,
            max_retries: 1,
            batch_size: 4,
            cost_decay: 0.8,
            confidence_floor: 0.5,
        };
        let b = PolicyKnobs {
            skip_mode: false,
            prepass_enabled: true,
            speculation_threshold: 0.4,
            exploration_budget: 0.4,
            max_retries: 4,
            batch_size: 16,
            cost_decay: 0.95,
            confidence_floor: 0.9,
        };
        let mut rng = rand::thread_rng();
        let child = a.crossover(&b, &mut rng);
        // Child values should come from one parent or the other
        assert!(child.max_retries == 1 || child.max_retries == 4);
        assert!(child.batch_size == 4 || child.batch_size == 16);
    }
}

View File

@@ -0,0 +1,603 @@
//! Rust Program Synthesis Domain
//!
//! Generates tasks that require synthesizing Rust programs from specifications.
//! Task types include:
//!
//! - **Transform**: Apply a function to data (map, filter, fold)
//! - **DataStructure**: Implement a data structure with specific operations
//! - **Algorithm**: Implement a named algorithm (sorting, searching, graph)
//! - **TypeLevel**: Express constraints via Rust's type system
//! - **Concurrency**: Safe concurrent data access patterns
//!
//! Solutions are evaluated on correctness (do test cases pass?),
//! efficiency (complexity class), and elegance (idiomatic Rust patterns).
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
use rand::Rng;
use serde::{Deserialize, Serialize};
/// Embedding dimension for Rust synthesis domain.
///
/// Keep in sync with `extract_features`, which fills exactly this many slots.
const EMBEDDING_DIM: usize = 64;
/// Categories of Rust synthesis tasks.
///
/// Fieldless, so the value-semantics derives (`Copy`, `PartialEq`, `Eq`,
/// `Hash`) are free and let callers compare, copy, and key on categories
/// without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RustTaskCategory {
    /// Transform data: map, filter, fold, scan.
    Transform,
    /// Implement a data structure with trait impls.
    DataStructure,
    /// Implement a named algorithm.
    Algorithm,
    /// Type-level programming: generics, trait bounds, associated types.
    TypeLevel,
    /// Concurrent programming: Arc, Mutex, channels, atomics.
    Concurrency,
}
/// Specification for a Rust synthesis task.
///
/// Serialized into `Task::spec` as JSON by `generate_tasks` and decoded back
/// in `evaluate`, so all fields must remain JSON-round-trippable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RustTaskSpec {
    /// Task category.
    pub category: RustTaskCategory,
    /// Function signature that must be implemented.
    pub signature: String,
    /// Natural language description of the required behavior.
    pub description: String,
    /// Test cases as (input_json, expected_output_json) pairs.
    ///
    /// NOTE(review): scoring only checks token overlap with the inputs (see
    /// `score_solution`); the cases are never executed.
    pub test_cases: Vec<(String, String)>,
    /// Required traits the solution must implement.
    pub required_traits: Vec<String>,
    /// Banned patterns (e.g., "unsafe", "unwrap").
    pub banned_patterns: Vec<String>,
    /// Expected complexity class (e.g., "O(n log n)").
    pub expected_complexity: Option<String>,
}
/// Rust program synthesis domain.
///
/// Stateless apart from its identifier; task generation draws fresh
/// randomness per call via `rand::thread_rng()`.
pub struct RustSynthesisDomain {
    id: DomainId,
}
impl RustSynthesisDomain {
/// Create a new Rust synthesis domain with the fixed `rust_synthesis` id.
pub fn new() -> Self {
    let id = DomainId("rust_synthesis".to_string());
    Self { id }
}
/// Generate a transform task at the given difficulty.
///
/// Difficulty tiers:
/// - `< 0.3`: element-wise map (`double`/`negate`/`abs`/`square`, random)
/// - `< 0.7`: filter + fold (`sum_positives`)
/// - otherwise: Kadane's maximum-subarray scan
///
/// Every tier bans `unsafe` and expects an O(n) solution.
fn gen_transform(&self, difficulty: f32, rng: &mut impl Rng) -> RustTaskSpec {
    let (signature, description, tests, complexity) = if difficulty < 0.3 {
        // Easy: simple map
        let ops = ["double", "negate", "abs", "square"];
        let op = ops[rng.gen_range(0..ops.len())];
        (
            format!("fn {}(values: &[i64]) -> Vec<i64>", op),
            format!("Apply {} to each element in the slice.", op),
            // (input JSON, expected JSON) pairs; the `_` arm is "square".
            match op {
                "double" => vec![
                    ("[1, 2, 3]".into(), "[2, 4, 6]".into()),
                    ("[-1, 0, 5]".into(), "[-2, 0, 10]".into()),
                ],
                "negate" => vec![
                    ("[1, -2, 3]".into(), "[-1, 2, -3]".into()),
                    ("[0]".into(), "[0]".into()),
                ],
                "abs" => vec![
                    ("[-1, 2, -3]".into(), "[1, 2, 3]".into()),
                    ("[0, -0]".into(), "[0, 0]".into()),
                ],
                _ => vec![
                    ("[2, 3, 4]".into(), "[4, 9, 16]".into()),
                    ("[0, -1]".into(), "[0, 1]".into()),
                ],
            },
            "O(n)",
        )
    } else if difficulty < 0.7 {
        // Medium: filter + fold combos
        (
            "fn sum_positives(values: &[i64]) -> i64".into(),
            "Sum all positive values in the slice.".into(),
            vec![
                ("[1, -2, 3, -4, 5]".into(), "9".into()),
                ("[-1, -2, -3]".into(), "0".into()),
                ("[]".into(), "0".into()),
            ],
            "O(n)",
        )
    } else {
        // Hard: sliding window / scan
        (
            "fn max_subarray_sum(values: &[i64]) -> i64".into(),
            "Find the maximum sum contiguous subarray (Kadane's algorithm).".into(),
            vec![
                ("[-2, 1, -3, 4, -1, 2, 1, -5, 4]".into(), "6".into()),
                ("[-1, -2, -3]".into(), "-1".into()),
                ("[5]".into(), "5".into()),
            ],
            "O(n)",
        )
    };
    RustTaskSpec {
        category: RustTaskCategory::Transform,
        signature,
        description,
        test_cases: tests,
        required_traits: Vec::new(),
        banned_patterns: vec!["unsafe".into()],
        expected_complexity: Some(complexity.into()),
    }
}
/// Generate a data structure task.
///
/// Difficulty tiers: `< 0.4` generic stack, `< 0.7` binary min-heap (with
/// the std `BinaryHeap` banned so it must be hand-rolled), otherwise an LRU
/// cache with O(1) get/put. The rng parameter is currently unused.
fn gen_data_structure(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec {
    if difficulty < 0.4 {
        RustTaskSpec {
            category: RustTaskCategory::DataStructure,
            signature: "struct Stack<T>".into(),
            description: "Implement a generic stack with push, pop, peek, is_empty, len."
                .into(),
            test_cases: vec![
                ("push(1); push(2); pop()".into(), "Some(2)".into()),
                ("is_empty()".into(), "true".into()),
                ("push(1); len()".into(), "1".into()),
            ],
            required_traits: vec!["Default".into()],
            banned_patterns: vec!["unsafe".into()],
            expected_complexity: Some("O(1) per operation".into()),
        }
    } else if difficulty < 0.7 {
        RustTaskSpec {
            category: RustTaskCategory::DataStructure,
            signature: "struct MinHeap<T: Ord>".into(),
            description: "Implement a binary min-heap with insert, extract_min, peek_min."
                .into(),
            test_cases: vec![
                (
                    "insert(3); insert(1); insert(2); extract_min()".into(),
                    "Some(1)".into(),
                ),
                ("peek_min() on empty".into(), "None".into()),
            ],
            required_traits: vec!["Default".into()],
            // Banning "BinaryHeap" forbids delegating to the std heap.
            banned_patterns: vec!["unsafe".into(), "BinaryHeap".into()],
            expected_complexity: Some("O(log n) insert/extract".into()),
        }
    } else {
        RustTaskSpec {
            category: RustTaskCategory::DataStructure,
            signature: "struct LRUCache<K: Hash + Eq, V>".into(),
            description: "Implement an LRU cache with get, put, and capacity eviction.".into(),
            test_cases: vec![
                (
                    "cap=2; put(1,'a'); put(2,'b'); get(1); put(3,'c'); get(2)".into(),
                    "None".into(),
                ),
                (
                    "cap=1; put(1,'a'); put(2,'b'); get(1)".into(),
                    "None".into(),
                ),
            ],
            required_traits: Vec::new(),
            banned_patterns: vec!["unsafe".into()],
            expected_complexity: Some("O(1) get/put".into()),
        }
    }
}
/// Generate an algorithm task.
///
/// Difficulty tiers: `< 0.4` binary search, `< 0.7` merge sort (with
/// `.sort` banned so the std sort cannot be delegated to), otherwise
/// Dijkstra's shortest path. The rng parameter is currently unused.
fn gen_algorithm(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec {
    if difficulty < 0.4 {
        RustTaskSpec {
            category: RustTaskCategory::Algorithm,
            signature: "fn binary_search(sorted: &[i64], target: i64) -> Option<usize>".into(),
            description: "Implement binary search on a sorted slice.".into(),
            test_cases: vec![
                ("[1,3,5,7,9], 5".into(), "Some(2)".into()),
                ("[1,3,5,7,9], 4".into(), "None".into()),
                ("[], 1".into(), "None".into()),
            ],
            required_traits: Vec::new(),
            banned_patterns: vec!["unsafe".into()],
            expected_complexity: Some("O(log n)".into()),
        }
    } else if difficulty < 0.7 {
        RustTaskSpec {
            category: RustTaskCategory::Algorithm,
            signature: "fn merge_sort(values: &mut [i64])".into(),
            description: "Implement stable merge sort in-place.".into(),
            test_cases: vec![
                ("[3,1,4,1,5,9,2,6]".into(), "[1,1,2,3,4,5,6,9]".into()),
                ("[1]".into(), "[1]".into()),
                ("[]".into(), "[]".into()),
            ],
            required_traits: Vec::new(),
            banned_patterns: vec!["unsafe".into(), ".sort".into()],
            expected_complexity: Some("O(n log n)".into()),
        }
    } else {
        RustTaskSpec {
            category: RustTaskCategory::Algorithm,
            signature: "fn shortest_path(adj: &[Vec<(usize, u64)>], src: usize, dst: usize) -> Option<u64>".into(),
            description: "Implement Dijkstra's shortest path on a weighted directed graph.".into(),
            test_cases: vec![
                ("3 nodes, 0->1:2, 1->2:3, 0->2:10; src=0, dst=2".into(), "Some(5)".into()),
                ("2 nodes, no edges; src=0, dst=1".into(), "None".into()),
            ],
            required_traits: Vec::new(),
            banned_patterns: vec!["unsafe".into()],
            expected_complexity: Some("O((V + E) log V)".into()),
        }
    }
}
/// Extract structural features from a Rust solution for embedding.
///
/// Builds a 64-dimensional vector of substring-count features over the
/// solution source, grouped into themed bands (control flow, type system,
/// functional style, ownership, concurrency, structure, error handling,
/// algorithm hints), then L2-normalizes the result.
///
/// Table-driven replacement for the original hand-indexed assignments,
/// which were prone to index drift when a feature was added or removed:
/// each `(pattern, scale)` entry produces `count(pattern) / scale` at its
/// positional index. Patterns and scales are unchanged.
fn extract_features(&self, solution: &Solution) -> Vec<f32> {
    // Features 0..=39, preceding the two line-structure features at 40/41.
    const HEAD: [(&str, f32); 40] = [
        // 0-7: control flow complexity
        ("if ", 10.0),
        ("for ", 5.0),
        ("while ", 5.0),
        ("match ", 5.0),
        ("loop ", 3.0),
        ("return ", 5.0),
        ("break", 3.0),
        ("continue", 3.0),
        // 8-15: type system usage
        ("impl ", 5.0),
        ("trait ", 3.0),
        ("struct ", 3.0),
        ("enum ", 3.0),
        ("where ", 3.0),
        ("dyn ", 3.0),
        ("Box<", 3.0),
        ("Rc<", 3.0),
        // 16-23: functional patterns
        (".map(", 5.0),
        (".filter(", 5.0),
        (".fold(", 3.0),
        (".collect()", 3.0),
        (".iter()", 5.0),
        ("|", 10.0), // closures
        ("Some(", 5.0),
        ("None", 5.0),
        // 24-31: memory/ownership patterns
        ("&mut ", 5.0),
        ("&self", 5.0),
        ("mut ", 10.0),
        (".clone()", 5.0),
        ("Vec<", 5.0),
        ("HashMap", 3.0),
        ("String", 5.0),
        ("Result<", 3.0),
        // 32-39: concurrency patterns
        ("Arc<", 3.0),
        ("Mutex<", 3.0),
        ("RwLock", 3.0),
        ("async ", 3.0),
        ("await", 5.0),
        ("spawn", 3.0),
        ("channel", 3.0),
        ("Atomic", 3.0),
    ];
    // Features 42..=63.
    const TAIL: [(&str, f32); 22] = [
        // 42-47: code structure metrics
        ("fn ", 10.0),
        ("pub ", 10.0),
        ("mod ", 5.0),
        ("use ", 10.0),
        ("#[", 5.0),   // attributes
        ("///", 10.0), // doc comments
        // 48-55: error handling patterns
        ("unwrap()", 5.0),
        ("expect(", 5.0),
        ("?;", 5.0), // error propagation
        ("Err(", 5.0),
        ("Ok(", 5.0),
        ("panic!", 3.0),
        ("assert", 5.0),
        ("debug_assert", 3.0),
        // 56-63: algorithm indicators
        ("sort", 3.0),
        ("binary_search", 2.0),
        ("push", 5.0),
        ("pop", 5.0),
        ("swap", 5.0),
        ("len()", 5.0),
        ("is_empty", 3.0),
        ("contains", 3.0),
    ];
    let code = &solution.content;
    let mut features = vec![0.0f32; EMBEDDING_DIM];
    for (i, &(pattern, scale)) in HEAD.iter().enumerate() {
        features[i] = code.matches(pattern).count() as f32 / scale;
    }
    // Features 40-41: overall size and blank-line ratio.
    let lines: Vec<&str> = code.lines().collect();
    features[40] = (lines.len() as f32) / 100.0;
    features[41] = lines.iter().filter(|l| l.trim().is_empty()).count() as f32
        / (lines.len().max(1) as f32);
    for (i, &(pattern, scale)) in TAIL.iter().enumerate() {
        features[42 + i] = code.matches(pattern).count() as f32 / scale;
    }
    // Normalize to unit length so embeddings are comparable across code sizes.
    let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for f in &mut features {
            *f /= norm;
        }
    }
    features
}
/// Score a Rust solution based on pattern matching heuristics.
///
/// Produces an `Evaluation` whose overall score is a weighted blend of
/// correctness (0.6), efficiency (0.25), and elegance (0.15). Scoring is
/// purely lexical — the candidate code is never compiled or executed.
fn score_solution(&self, spec: &RustTaskSpec, solution: &Solution) -> Evaluation {
    let code = &solution.content;
    let mut correctness = 0.0f32;
    let mut efficiency = 0.5f32;
    let mut elegance = 0.5f32;
    let mut notes = Vec::new();
    // Banned patterns halve elegance once, regardless of how many hit.
    let mut banned_found = false;
    for pattern in &spec.banned_patterns {
        if code.contains(pattern.as_str()) {
            notes.push(format!("Banned pattern found: {}", pattern));
            banned_found = true;
        }
    }
    if banned_found {
        elegance *= 0.5;
    }
    // The defined identifier is the last whitespace-separated token before
    // the first '(' (e.g. "sum_positives", or "Stack<T>" for structs).
    let sig_name = spec
        .signature
        .split('(')
        .next()
        .unwrap_or("")
        .split_whitespace()
        .last()
        .unwrap_or("");
    // BUG FIX: guard against an empty name — `code.contains("")` is always
    // true, which previously awarded the 0.3 credit for any solution when
    // the signature was malformed or empty.
    if !sig_name.is_empty() && code.contains(sig_name) {
        correctness += 0.3;
    } else {
        notes.push(format!("Missing expected identifier: {}", sig_name));
    }
    // Any fn definition at all earns partial credit.
    if code.contains("fn ") {
        correctness += 0.2;
    }
    // Heuristic coverage: fraction of test inputs whose alphanumeric tokens
    // appear somewhere in the solution text.
    let test_coverage = spec
        .test_cases
        .iter()
        .filter(|(input, _)| {
            let key_tokens: Vec<&str> = input.split(|c: char| !c.is_alphanumeric()).collect();
            key_tokens.iter().any(|t| !t.is_empty() && code.contains(t))
        })
        .count() as f32
        / spec.test_cases.len().max(1) as f32;
    correctness += test_coverage * 0.5;
    correctness = correctness.clamp(0.0, 1.0);
    // Efficiency: penalize likely-nested loops when a sub-quadratic bound is
    // expected. BUG FIX: the original condition `count > 1 && count > 2`
    // collapsed to `count > 2`, so the classic two-`for` nested case escaped
    // the penalty; two or more `for` loops now trigger it.
    let for_count = code.matches("for ").count();
    if for_count > 1 {
        if let Some(ref expected) = spec.expected_complexity {
            if expected.contains("O(n)") || expected.contains("O(log") {
                efficiency *= 0.5;
                notes.push("Possible O(n^2) when O(n) or O(log n) expected".into());
            }
        }
    }
    // Elegance: reward iterator pipelines over manual loops.
    let iterator_usage = code.matches(".iter()").count()
        + code.matches(".map(").count()
        + code.matches(".filter(").count()
        + code.matches(".fold(").count();
    if iterator_usage > 0 {
        elegance += 0.2;
    }
    // Penalize excessive unwrap.
    let unwrap_count = code.matches("unwrap()").count();
    if unwrap_count > 3 {
        elegance -= 0.2;
        notes.push("Excessive unwrap() usage".into());
    }
    // Small bonus for explicit error handling.
    if code.contains("Result<") || code.contains("?;") {
        elegance += 0.1;
    }
    elegance = elegance.clamp(0.0, 1.0);
    // One pass/fail entry per banned pattern, in spec order.
    let constraint_results = spec
        .banned_patterns
        .iter()
        .map(|p| !code.contains(p.as_str()))
        .collect();
    let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
    Evaluation {
        score: score.clamp(0.0, 1.0),
        correctness,
        efficiency,
        elegance,
        constraint_results,
        notes,
    }
}
}
impl Default for RustSynthesisDomain {
    /// Equivalent to [`RustSynthesisDomain::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Domain for RustSynthesisDomain {
    fn id(&self) -> &DomainId {
        &self.id
    }
    fn name(&self) -> &str {
        "Rust Program Synthesis"
    }
    /// Generate `count` tasks at the clamped difficulty.
    ///
    /// Category mix per task: roll < 0.4 transform, < 0.7 data structure,
    /// otherwise algorithm — drawn independently from `thread_rng`.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);
        (0..count)
            .map(|i| {
                let category_roll: f32 = rng.gen();
                let spec = if category_roll < 0.4 {
                    self.gen_transform(difficulty, &mut rng)
                } else if category_roll < 0.7 {
                    self.gen_data_structure(difficulty, &mut rng)
                } else {
                    self.gen_algorithm(difficulty, &mut rng)
                };
                Task {
                    id: format!("rust_synth_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    // On (unexpected) serialization failure this yields a
                    // null spec, which `evaluate` then reports as invalid.
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: spec.banned_patterns.clone(),
                }
            })
            .collect()
    }
    /// Decode the task spec and delegate to the heuristic scorer.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        let spec: RustTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        self.score_solution(&spec, solution)
    }
    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }
    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }
    /// Known-good reference implementations, currently for the Transform
    /// category only; every other category returns `None`.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: RustTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;
        let content = match spec.category {
            RustTaskCategory::Transform => {
                if spec.signature.contains("sum_positives") {
                    "fn sum_positives(values: &[i64]) -> i64 {\n values.iter().filter(|&&x| x > 0).sum()\n}".to_string()
                } else if spec.signature.contains("max_subarray_sum") {
                    // NOTE(review): this reference indexes values[0] and
                    // would panic on an empty slice; the generated test
                    // cases never pass one, but confirm before reuse.
                    "fn max_subarray_sum(values: &[i64]) -> i64 {\n let mut max_so_far = values[0];\n let mut max_ending = values[0];\n for &v in &values[1..] {\n max_ending = v.max(max_ending + v);\n max_so_far = max_so_far.max(max_ending);\n }\n max_so_far\n}".to_string()
                } else {
                    // Map-style ops get a TODO template, not a full solution.
                    format!(
                        "{} {{\n values.iter().map(|&x| x /* TODO */).collect()\n}}",
                        spec.signature
                    )
                }
            }
            _ => return None,
        };
        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::Value::Null,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // generate_tasks returns exactly `count` tasks, tagged with this
    // domain's id and the requested difficulty.
    #[test]
    fn test_generate_tasks() {
        let domain = RustSynthesisDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        for task in &tasks {
            assert_eq!(task.domain_id, domain.id);
            assert!((task.difficulty - 0.5).abs() < 1e-6);
        }
    }
    // An idiomatic map-based solution earns a nonzero score; the task
    // category is random, so only score > 0 is asserted.
    #[test]
    fn test_evaluate_good_solution() {
        let domain = RustSynthesisDomain::new();
        let tasks = domain.generate_tasks(1, 0.0);
        let task = &tasks[0];
        let solution = Solution {
            task_id: task.id.clone(),
            content: "fn double(values: &[i64]) -> Vec<i64> {\n values.iter().map(|&x| x * 2).collect()\n}".to_string(),
            data: serde_json::Value::Null,
        };
        let eval = domain.evaluate(task, &solution);
        assert!(eval.score > 0.0);
    }
    // Embeddings always carry exactly EMBEDDING_DIM components.
    #[test]
    fn test_embed_produces_correct_dim() {
        let domain = RustSynthesisDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "fn foo() { let x = 1; }".into(),
            data: serde_json::Value::Null,
        };
        let embedding = domain.embed(&solution);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
        assert_eq!(embedding.vector.len(), EMBEDDING_DIM);
    }
    // Non-trivial code yields an L2-normalized (unit-length) vector.
    #[test]
    fn test_embedding_normalized() {
        let domain = RustSynthesisDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "fn foo() { for i in 0..10 { if i > 5 { println!(\"{}\", i); } } }".into(),
            data: serde_json::Value::Null,
        };
        let embedding = domain.embed(&solution);
        let norm: f32 = embedding.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4);
    }
    // Both ends of the difficulty range must produce decodable, non-empty
    // specs (exercises all three tier branches in the generators).
    #[test]
    fn test_difficulty_range() {
        let domain = RustSynthesisDomain::new();
        // Easy tasks
        let easy = domain.generate_tasks(3, 0.1);
        for t in &easy {
            let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            assert!(!spec.signature.is_empty());
        }
        // Hard tasks
        let hard = domain.generate_tasks(3, 0.9);
        for t in &hard {
            let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            assert!(!spec.signature.is_empty());
        }
    }
}

View File

@@ -0,0 +1,715 @@
//! RVF Integration Bridge
//!
//! Connects the domain expansion engine to the RuVector Format (RVF):
//! - Serializes `TransferPrior`, `PolicyKernel`, `CostCurve` into RVF segments
//! - Creates SHAKE-256 witness chains for transfer verification
//! - Packages domain expansion artifacts into AGI container TLV entries
//! - Bridges priors to/from the rvf-solver-wasm `PolicyKernel`
//!
//! Requires the `rvf` feature to be enabled.
use rvf_types::{SegmentFlags, SegmentType};
use rvf_wire::reader::{read_segment, validate_segment};
use rvf_wire::writer::write_segment;
use crate::cost_curve::{AccelerationScoreboard, CostCurve};
use crate::domain::DomainId;
use crate::policy_kernel::PolicyKernel;
use crate::transfer::{ArmId, BetaParams, ContextBucket, MetaThompsonEngine, TransferPrior};
// ─── Wire-format wrappers ───────────────────────────────────────────────────
//
// JSON requires string keys for objects. TransferPrior uses HashMap<ContextBucket, _>
// which can't be directly serialized. These wrappers convert to/from Vec<(K,V)> form.
/// Wire-format representation of a TransferPrior (JSON-safe).
///
/// Field-for-field mirror of `TransferPrior`, with the HashMaps flattened
/// into `Vec<(K, V)>` pairs so non-string keys survive JSON encoding.
#[derive(serde::Serialize, serde::Deserialize)]
struct WireTransferPrior {
    source_domain: DomainId,
    // (bucket, [(arm, beta-params), ...]) pairs replacing the nested maps.
    bucket_priors: Vec<(ContextBucket, Vec<(ArmId, BetaParams)>)>,
    // (bucket, cost-EMA) pairs.
    cost_ema_priors: Vec<(ContextBucket, f32)>,
    training_cycles: u64,
    witness_hash: String,
}
impl From<&TransferPrior> for WireTransferPrior {
    /// Flatten the prior's HashMaps into JSON-safe vectors of pairs.
    fn from(prior: &TransferPrior) -> Self {
        let bucket_priors = prior
            .bucket_priors
            .iter()
            .map(|(bucket, arms)| {
                let flattened: Vec<(ArmId, BetaParams)> = arms
                    .iter()
                    .map(|(arm, params)| (arm.clone(), params.clone()))
                    .collect();
                (bucket.clone(), flattened)
            })
            .collect();
        let cost_ema_priors = prior
            .cost_ema_priors
            .iter()
            .map(|(bucket, cost)| (bucket.clone(), *cost))
            .collect();
        Self {
            source_domain: prior.source_domain.clone(),
            bucket_priors,
            cost_ema_priors,
            training_cycles: prior.training_cycles,
            witness_hash: prior.witness_hash.clone(),
        }
    }
}
impl From<WireTransferPrior> for TransferPrior {
    /// Rebuild the HashMap-based prior from the flattened wire form.
    fn from(wire: WireTransferPrior) -> Self {
        let bucket_priors = wire
            .bucket_priors
            .into_iter()
            .map(|(bucket, arms)| (bucket, arms.into_iter().collect()))
            .collect();
        Self {
            source_domain: wire.source_domain,
            bucket_priors,
            cost_ema_priors: wire.cost_ema_priors.into_iter().collect(),
            training_cycles: wire.training_cycles,
            witness_hash: wire.witness_hash,
        }
    }
}
/// Wire-format representation of a PolicyKernel (JSON-safe).
///
/// Mirrors `PolicyKernel`, with the holdout-score map flattened into
/// `Vec<(DomainId, f32)>` pairs for JSON round-tripping.
#[derive(serde::Serialize, serde::Deserialize)]
struct WirePolicyKernel {
    id: String,
    knobs: crate::policy_kernel::PolicyKnobs,
    holdout_scores: Vec<(DomainId, f32)>,
    total_cost: f32,
    cycles: u64,
    generation: u32,
    parent_id: Option<String>,
    replay_verified: bool,
}
impl From<&PolicyKernel> for WirePolicyKernel {
    /// Flatten the kernel's holdout-score map into JSON-safe pairs.
    fn from(kernel: &PolicyKernel) -> Self {
        let holdout_scores = kernel
            .holdout_scores
            .iter()
            .map(|(domain, score)| (domain.clone(), *score))
            .collect();
        Self {
            id: kernel.id.clone(),
            knobs: kernel.knobs.clone(),
            holdout_scores,
            total_cost: kernel.total_cost,
            cycles: kernel.cycles,
            generation: kernel.generation,
            parent_id: kernel.parent_id.clone(),
            replay_verified: kernel.replay_verified,
        }
    }
}
impl From<WirePolicyKernel> for PolicyKernel {
fn from(w: WirePolicyKernel) -> Self {
Self {
id: w.id,
knobs: w.knobs,
holdout_scores: w.holdout_scores.into_iter().collect(),
total_cost: w.total_cost,
cycles: w.cycles,
generation: w.generation,
parent_id: w.parent_id,
replay_verified: w.replay_verified,
}
}
}
// ─── Segment serialization ──────────────────────────────────────────────────
/// Serialize a `TransferPrior` into an RVF TRANSFER_PRIOR segment.
///
/// Wire format: JSON payload (using Vec-of-tuples for map keys) inside a
/// 64-byte-aligned RVF segment. Type: `SegmentType::TransferPrior` (0x30).
pub fn transfer_prior_to_segment(prior: &TransferPrior, segment_id: u64) -> Vec<u8> {
let wire: WireTransferPrior = prior.into();
let payload = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
write_segment(
SegmentType::TransferPrior as u8,
&payload,
SegmentFlags::empty(),
segment_id,
)
}
/// Deserialize a `TransferPrior` from an RVF segment's raw bytes.
///
/// Validates the segment header, checks the content hash, and deserializes
/// the JSON payload.
pub fn transfer_prior_from_segment(data: &[u8]) -> Result<TransferPrior, RvfBridgeError> {
let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
if header.seg_type != SegmentType::TransferPrior as u8 {
return Err(RvfBridgeError::WrongSegmentType {
expected: SegmentType::TransferPrior as u8,
got: header.seg_type,
});
}
validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
let wire: WireTransferPrior = serde_json::from_slice(payload).map_err(RvfBridgeError::Json)?;
Ok(wire.into())
}
/// Serialize a `PolicyKernel` into an RVF POLICY_KERNEL segment.
pub fn policy_kernel_to_segment(kernel: &PolicyKernel, segment_id: u64) -> Vec<u8> {
let wire: WirePolicyKernel = kernel.into();
let payload = serde_json::to_vec(&wire).expect("WirePolicyKernel serialization cannot fail");
write_segment(
SegmentType::PolicyKernel as u8,
&payload,
SegmentFlags::empty(),
segment_id,
)
}
/// Deserialize a `PolicyKernel` from an RVF segment.
pub fn policy_kernel_from_segment(data: &[u8]) -> Result<PolicyKernel, RvfBridgeError> {
let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
if header.seg_type != SegmentType::PolicyKernel as u8 {
return Err(RvfBridgeError::WrongSegmentType {
expected: SegmentType::PolicyKernel as u8,
got: header.seg_type,
});
}
validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
let wire: WirePolicyKernel = serde_json::from_slice(payload).map_err(RvfBridgeError::Json)?;
Ok(wire.into())
}
/// Serialize a `CostCurve` into an RVF COST_CURVE segment.
///
/// Unlike the prior/kernel types, `CostCurve` serializes directly — it has
/// no map-keyed fields, so no wire wrapper is needed.
pub fn cost_curve_to_segment(curve: &CostCurve, segment_id: u64) -> Vec<u8> {
    let body = serde_json::to_vec(curve).expect("CostCurve serialization cannot fail");
    write_segment(
        SegmentType::CostCurve as u8,
        body.as_slice(),
        SegmentFlags::empty(),
        segment_id,
    )
}
/// Deserialize a `CostCurve` from an RVF segment.
pub fn cost_curve_from_segment(data: &[u8]) -> Result<CostCurve, RvfBridgeError> {
    let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
    let expected = SegmentType::CostCurve as u8;
    if header.seg_type != expected {
        return Err(RvfBridgeError::WrongSegmentType {
            expected,
            got: header.seg_type,
        });
    }
    validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
    serde_json::from_slice(payload).map_err(RvfBridgeError::Json)
}
// ─── Witness chain ──────────────────────────────────────────────────────────
/// Witness type for transfer prior verification.
pub const WITNESS_TRANSFER: u8 = 0x10;
/// Witness type for policy kernel promotion.
pub const WITNESS_POLICY_PROMOTION: u8 = 0x11;
/// Witness type for cost curve convergence checkpoint.
pub const WITNESS_CONVERGENCE: u8 = 0x12;
/// Create a SHAKE-256 witness hash for a transfer prior.
///
/// The witness hash covers: source domain, training cycles, and the serialized
/// bucket priors. This replaces the old string-based `witness_hash` field.
pub fn compute_transfer_witness_hash(prior: &TransferPrior) -> [u8; 32] {
let wire: WireTransferPrior = prior.into();
let payload = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
rvf_crypto::shake256_256(&payload)
}
/// Build witness entries for a transfer verification event.
///
/// Returns entries suitable for `rvf_crypto::create_witness_chain()`.
pub fn build_transfer_witness_entries(
    prior: &TransferPrior,
    source: &DomainId,
    target: &DomainId,
    acceleration_factor: f32,
    timestamp_ns: u64,
) -> Vec<rvf_crypto::WitnessEntry> {
    // Entry 1: hash of the serialized transfer prior itself.
    let prior_entry = rvf_crypto::WitnessEntry {
        prev_hash: [0u8; 32],
        action_hash: compute_transfer_witness_hash(prior),
        timestamp_ns,
        witness_type: WITNESS_TRANSFER,
    };
    // Entry 2: hash binding source→target to the measured acceleration.
    let accel_payload = format!(
        "{}->{}:accel={:.6}",
        source.0, target.0, acceleration_factor
    );
    let accel_entry = rvf_crypto::WitnessEntry {
        prev_hash: [0u8; 32], // chaining handled by create_witness_chain
        action_hash: rvf_crypto::shake256_256(accel_payload.as_bytes()),
        timestamp_ns: timestamp_ns + 1,
        witness_type: WITNESS_CONVERGENCE,
    };
    vec![prior_entry, accel_entry]
}
// ─── AGI Container TLV packaging ────────────────────────────────────────────
/// A TLV (Tag-Length-Value) entry for AGI container manifest packaging.
///
/// Encoded by `encode_tlv_entries` as `[tag: u16 LE][len: u32 LE][value]`.
#[derive(Debug, Clone)]
pub struct AgiTlvEntry {
    /// TLV tag (see `AGI_TAG_*` constants in rvf-types).
    pub tag: u16,
    /// Serialized value payload.
    pub value: Vec<u8>,
}
/// Package domain expansion artifacts into AGI container TLV entries.
///
/// Returns a vector of TLV entries ready for inclusion in an AGI container
/// manifest segment. Each entry uses the corresponding `AGI_TAG_*` constant.
pub fn package_for_agi_container(
priors: &[TransferPrior],
kernels: &[PolicyKernel],
scoreboard: &AccelerationScoreboard,
) -> Vec<AgiTlvEntry> {
let mut entries = Vec::new();
// Transfer priors (use wire format for JSON-safe serialization)
for prior in priors {
let wire: WireTransferPrior = prior.into();
let value = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
entries.push(AgiTlvEntry {
tag: rvf_types::AGI_TAG_TRANSFER_PRIOR,
value,
});
}
// Policy kernels (use wire format for JSON-safe serialization)
for kernel in kernels {
let wire: WirePolicyKernel = kernel.into();
let value = serde_json::to_vec(&wire).expect("WirePolicyKernel serialization cannot fail");
entries.push(AgiTlvEntry {
tag: rvf_types::AGI_TAG_POLICY_KERNEL,
value,
});
}
// Cost curves from the scoreboard
for curve in scoreboard.curves.values() {
let value = serde_json::to_vec(curve).expect("CostCurve serialization cannot fail");
entries.push(AgiTlvEntry {
tag: rvf_types::AGI_TAG_COST_CURVE,
value,
});
}
entries
}
/// Encode TLV entries into a binary payload for inclusion in a META segment.
///
/// Wire format per entry: `[tag: u16 LE][length: u32 LE][value: length bytes]`
pub fn encode_tlv_entries(entries: &[AgiTlvEntry]) -> Vec<u8> {
    // 2-byte tag + 4-byte length precede each value.
    let capacity: usize = entries.iter().map(|e| 6 + e.value.len()).sum();
    entries
        .iter()
        .fold(Vec::with_capacity(capacity), |mut buf, entry| {
            buf.extend_from_slice(&entry.tag.to_le_bytes());
            buf.extend_from_slice(&(entry.value.len() as u32).to_le_bytes());
            buf.extend_from_slice(&entry.value);
            buf
        })
}
/// Decode TLV entries from a binary payload.
///
/// NOTE(review): trailing bytes shorter than a 6-byte header are silently
/// ignored (presumably alignment padding) — only a declared value length
/// that extends past the buffer is reported as `TruncatedTlv`. Confirm this
/// tolerance is intentional before tightening it.
pub fn decode_tlv_entries(data: &[u8]) -> Result<Vec<AgiTlvEntry>, RvfBridgeError> {
    let mut entries = Vec::new();
    let mut rest = data;
    while rest.len() >= 6 {
        let tag = u16::from_le_bytes([rest[0], rest[1]]);
        let length = u32::from_le_bytes([rest[2], rest[3], rest[4], rest[5]]) as usize;
        rest = &rest[6..];
        if length > rest.len() {
            return Err(RvfBridgeError::TruncatedTlv);
        }
        let (value, tail) = rest.split_at(length);
        entries.push(AgiTlvEntry {
            tag,
            value: value.to_vec(),
        });
        rest = tail;
    }
    Ok(entries)
}
// ─── Solver bridge ──────────────────────────────────────────────────────────
/// Compact prior exchange format bridging domain expansion's `MetaThompsonEngine`
/// to the rvf-solver-wasm `PolicyKernel`.
///
/// The solver-wasm uses per-bucket `SkipModeStats` with `(alpha_safety, beta_safety)`
/// and `cost_ema`. The domain expansion uses per-bucket `BetaParams` with
/// `(alpha, beta)` and `cost_ema_priors`. This type converts between them.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SolverPriorExchange {
    /// Context bucket key (e.g. "medium:some:clean").
    ///
    /// NOTE(review): `extract_solver_priors` emits two-part "tier:category"
    /// keys and `import_solver_priors` splits on at most one ':'; confirm
    /// whether the three-part example above is still accurate.
    pub bucket_key: String,
    /// Per-arm alpha/beta pairs mapping arm name to (alpha, beta).
    pub arm_params: Vec<(String, f32, f32)>,
    /// Cost EMA for this bucket.
    pub cost_ema: f32,
    /// Training cycle count for confidence estimation.
    pub training_cycles: u64,
}
/// Extract solver-compatible prior exchange data from the Thompson engine.
///
/// Flattens the domain expansion's hierarchical buckets into the solver's
/// flat "range:distractor:noise" keys for the specified domain.
pub fn extract_solver_priors(
    engine: &MetaThompsonEngine,
    domain_id: &DomainId,
) -> Vec<SolverPriorExchange> {
    // No prior for this domain means nothing to exchange.
    let Some(prior) = engine.extract_prior(domain_id) else {
        return Vec::new();
    };
    prior
        .bucket_priors
        .iter()
        .map(|(bucket, arms)| SolverPriorExchange {
            bucket_key: format!("{}:{}", bucket.difficulty_tier, bucket.category),
            arm_params: arms
                .iter()
                .map(|(arm, params)| (arm.0.clone(), params.alpha, params.beta))
                .collect(),
            // Buckets without a recorded EMA default to unit cost.
            cost_ema: prior.cost_ema_priors.get(bucket).copied().unwrap_or(1.0),
            training_cycles: prior.training_cycles,
        })
        .collect()
}
/// Import solver prior exchange data back into the Thompson engine.
///
/// Seeds the specified domain with the exchanged priors, enabling
/// cross-system transfer.
///
/// The synthetic prior's `training_cycles` is the maximum across all
/// exchanges. (The original kept only the last entry's value, making the
/// result depend on exchange ordering; the max is order-independent and
/// equal to the old value whenever all exchanges agree, which they do when
/// produced by `extract_solver_priors` from a single prior.)
pub fn import_solver_priors(
    engine: &mut MetaThompsonEngine,
    domain_id: &DomainId,
    exchanges: &[SolverPriorExchange],
) {
    // Build a synthetic TransferPrior from the exchange data.
    let mut prior = TransferPrior::uniform(domain_id.clone());
    for exchange in exchanges {
        // Bucket keys are "difficulty_tier:category"; a missing part falls
        // back to a sensible default.
        let mut parts = exchange.bucket_key.splitn(2, ':');
        let bucket = ContextBucket {
            difficulty_tier: parts.next().unwrap_or("medium").to_string(),
            category: parts.next().unwrap_or("general").to_string(),
        };
        let arm_map: std::collections::HashMap<_, _> = exchange
            .arm_params
            .iter()
            .map(|(arm_name, alpha, beta)| {
                (
                    crate::transfer::ArmId(arm_name.clone()),
                    BetaParams {
                        alpha: *alpha,
                        beta: *beta,
                    },
                )
            })
            .collect();
        prior.bucket_priors.insert(bucket.clone(), arm_map);
        prior.cost_ema_priors.insert(bucket, exchange.cost_ema);
        prior.training_cycles = prior.training_cycles.max(exchange.training_cycles);
    }
    engine.init_domain_with_transfer(domain_id.clone(), &prior);
}
// ─── Multi-segment file assembly ────────────────────────────────────────────
/// Assemble a complete RVF byte stream containing all domain expansion segments.
///
/// Outputs concatenated segments: transfer priors, then policy kernels, then
/// cost curves. Each gets a unique segment ID starting from `base_segment_id`.
///
/// The returned bytes can be appended to an existing RVF file or written as
/// a standalone domain expansion archive.
pub fn assemble_domain_expansion_segments(
    priors: &[TransferPrior],
    kernels: &[PolicyKernel],
    curves: &[CostCurve],
    base_segment_id: u64,
) -> Vec<u8> {
    // Segment IDs are assigned contiguously: priors first, then kernels,
    // then curves — compute each section's starting ID up front.
    let kernel_base = base_segment_id + priors.len() as u64;
    let curve_base = kernel_base + kernels.len() as u64;
    let mut stream = Vec::new();
    for (offset, prior) in priors.iter().enumerate() {
        let id = base_segment_id + offset as u64;
        stream.extend_from_slice(&transfer_prior_to_segment(prior, id));
    }
    for (offset, kernel) in kernels.iter().enumerate() {
        let id = kernel_base + offset as u64;
        stream.extend_from_slice(&policy_kernel_to_segment(kernel, id));
    }
    for (offset, curve) in curves.iter().enumerate() {
        let id = curve_base + offset as u64;
        stream.extend_from_slice(&cost_curve_to_segment(curve, id));
    }
    stream
}
// ─── Errors ─────────────────────────────────────────────────────────────────
/// Errors specific to the RVF bridge operations.
///
/// Wraps the underlying format and serialization errors and adds the
/// bridge-level failure modes hit while decoding segments.
#[derive(Debug)]
pub enum RvfBridgeError {
    /// Underlying RVF format error.
    Rvf(rvf_types::RvfError),
    /// JSON serialization/deserialization error.
    Json(serde_json::Error),
    /// Segment type mismatch: a segment was decoded as the wrong kind
    /// (e.g. a policy-kernel segment fed to the transfer-prior decoder).
    WrongSegmentType {
        /// Expected segment type discriminant.
        expected: u8,
        /// Actual segment type discriminant.
        got: u8,
    },
    /// TLV payload truncated.
    TruncatedTlv,
}
impl std::fmt::Display for RvfBridgeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Rvf(e) => write!(f, "RVF error: {e}"),
Self::Json(e) => write!(f, "JSON error: {e}"),
Self::WrongSegmentType { expected, got } => {
write!(
f,
"wrong segment type: expected 0x{expected:02X}, got 0x{got:02X}"
)
}
Self::TruncatedTlv => write!(f, "TLV payload truncated"),
}
}
}
impl std::error::Error for RvfBridgeError {
    /// Expose the underlying cause for chained error reporting.
    ///
    /// NOTE(review): only the `Json` variant is chained here; `Rvf` is not —
    /// presumably because `rvf_types::RvfError` does not implement
    /// `std::error::Error`. Confirm, and chain it too if it does.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Json(e) => Some(e),
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cost_curve::{ConvergenceThresholds, CostCurvePoint};
    // Segment encode/decode preserves a TransferPrior's identity fields.
    #[test]
    fn transfer_prior_round_trip() {
        let mut prior = TransferPrior::uniform(DomainId("test".into()));
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algo".into(),
        };
        prior.update_posterior(bucket, crate::transfer::ArmId("greedy".into()), 0.85);
        let segment = transfer_prior_to_segment(&prior, 1);
        let decoded = transfer_prior_from_segment(&segment).unwrap();
        assert_eq!(decoded.source_domain, prior.source_domain);
        assert_eq!(decoded.training_cycles, prior.training_cycles);
    }
    // Segment encode/decode preserves a PolicyKernel's id and generation.
    #[test]
    fn policy_kernel_round_trip() {
        let kernel = PolicyKernel::new("test_kernel".into());
        let segment = policy_kernel_to_segment(&kernel, 2);
        let decoded = policy_kernel_from_segment(&segment).unwrap();
        assert_eq!(decoded.id, "test_kernel");
        assert_eq!(decoded.generation, 0);
    }
    // Segment encode/decode preserves a CostCurve's domain and data points.
    #[test]
    fn cost_curve_round_trip() {
        let mut curve = CostCurve::new(DomainId("test".into()), ConvergenceThresholds::default());
        curve.record(CostCurvePoint {
            cycle: 0,
            accuracy: 0.3,
            cost_per_solve: 0.1,
            robustness: 0.3,
            policy_violations: 0,
            timestamp: 0.0,
        });
        let segment = cost_curve_to_segment(&curve, 3);
        let decoded = cost_curve_from_segment(&segment).unwrap();
        assert_eq!(decoded.domain_id, DomainId("test".into()));
        assert_eq!(decoded.points.len(), 1);
    }
    // Decoding a kernel segment as a transfer prior must fail loudly.
    #[test]
    fn wrong_segment_type_detected() {
        let kernel = PolicyKernel::new("k".into());
        let segment = policy_kernel_to_segment(&kernel, 1);
        let result = transfer_prior_from_segment(&segment);
        assert!(matches!(
            result,
            Err(RvfBridgeError::WrongSegmentType { .. })
        ));
    }
    // Hashing the same prior twice yields identical, non-zero digests.
    #[test]
    fn witness_hash_is_deterministic() {
        let prior = TransferPrior::uniform(DomainId("test".into()));
        let h1 = compute_transfer_witness_hash(&prior);
        let h2 = compute_transfer_witness_hash(&prior);
        assert_eq!(h1, h2);
        assert_ne!(h1, [0u8; 32]);
    }
    // Transfer + convergence witness entries form a verifiable chain.
    #[test]
    fn witness_entries_chain() {
        let prior = TransferPrior::uniform(DomainId("d1".into()));
        let entries = build_transfer_witness_entries(
            &prior,
            &DomainId("d1".into()),
            &DomainId("d2".into()),
            2.5,
            1_000_000_000,
        );
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].witness_type, WITNESS_TRANSFER);
        assert_eq!(entries[1].witness_type, WITNESS_CONVERGENCE);
        // Verify the chain is valid after linking
        let chain_bytes = rvf_crypto::create_witness_chain(&entries);
        let verified = rvf_crypto::verify_witness_chain(&chain_bytes).unwrap();
        assert_eq!(verified.len(), 2);
    }
    // TLV encode/decode preserves tags and values, in order.
    #[test]
    fn tlv_round_trip() {
        let entries = vec![
            AgiTlvEntry {
                tag: rvf_types::AGI_TAG_TRANSFER_PRIOR,
                value: b"hello".to_vec(),
            },
            AgiTlvEntry {
                tag: rvf_types::AGI_TAG_POLICY_KERNEL,
                value: b"world".to_vec(),
            },
        ];
        let encoded = encode_tlv_entries(&entries);
        let decoded = decode_tlv_entries(&encoded).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].tag, rvf_types::AGI_TAG_TRANSFER_PRIOR);
        assert_eq!(decoded[0].value, b"hello");
        assert_eq!(decoded[1].tag, rvf_types::AGI_TAG_POLICY_KERNEL);
        assert_eq!(decoded[1].value, b"world");
    }
    // Packaging produces one TLV entry per prior and per kernel; an empty
    // scoreboard contributes no curve entries.
    #[test]
    fn agi_container_packaging() {
        let prior = TransferPrior::uniform(DomainId("test".into()));
        let kernel = PolicyKernel::new("k0".into());
        let scoreboard = crate::cost_curve::AccelerationScoreboard::new();
        let entries = package_for_agi_container(&[prior], &[kernel], &scoreboard);
        assert_eq!(entries.len(), 2); // 1 prior + 1 kernel, 0 curves
        let encoded = encode_tlv_entries(&entries);
        let decoded = decode_tlv_entries(&encoded).unwrap();
        assert_eq!(decoded.len(), 2);
    }
    // Priors extracted from a trained engine can seed a fresh engine.
    #[test]
    fn solver_prior_exchange_round_trip() {
        let arms = vec!["greedy".into(), "exploratory".into()];
        let mut engine = MetaThompsonEngine::new(arms);
        let domain = DomainId("test".into());
        engine.init_domain_uniform(domain.clone());
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };
        // Accumulate enough observations for the bucket to be significant.
        for _ in 0..20 {
            engine.record_outcome(
                &domain,
                bucket.clone(),
                crate::transfer::ArmId("greedy".into()),
                0.9,
                1.0,
            );
        }
        let exchanges = extract_solver_priors(&engine, &domain);
        assert!(!exchanges.is_empty());
        // Import into a fresh engine
        let new_arms = vec!["greedy".into(), "exploratory".into()];
        let mut new_engine = MetaThompsonEngine::new(new_arms);
        let target = DomainId("target".into());
        new_engine.init_domain_uniform(target.clone());
        import_solver_priors(&mut new_engine, &target, &exchanges);
        // Should have transferred priors
        let extracted = new_engine.extract_prior(&target);
        assert!(extracted.is_some());
    }
    // Assembled multi-segment output is 64-byte aligned and starts with the
    // RVF segment magic.
    #[test]
    fn multi_segment_assembly() {
        let prior = TransferPrior::uniform(DomainId("d1".into()));
        let kernel = PolicyKernel::new("k0".into());
        let mut curve = CostCurve::new(DomainId("d1".into()), ConvergenceThresholds::default());
        curve.record(CostCurvePoint {
            cycle: 0,
            accuracy: 0.5,
            cost_per_solve: 0.05,
            robustness: 0.5,
            policy_violations: 0,
            timestamp: 0.0,
        });
        let assembled = assemble_domain_expansion_segments(&[prior], &[kernel], &[curve], 100);
        // Should contain 3 segments, each 64-byte aligned
        assert!(assembled.len() >= 3 * 64);
        assert_eq!(assembled.len() % 64, 0);
        // Verify first segment header magic
        let magic = u32::from_le_bytes([assembled[0], assembled[1], assembled[2], assembled[3]]);
        assert_eq!(magic, rvf_types::SEGMENT_MAGIC);
    }
}

View File

@@ -0,0 +1,727 @@
//! Tool Orchestration Problems Domain
//!
//! Generates tasks requiring coordinating multiple tools/agents to achieve goals.
//! Task types include:
//!
//! - **PipelineConstruction**: Build a data processing pipeline from available tools
//! - **ErrorRecovery**: Handle failures in multi-step tool chains
//! - **ParallelCoordination**: Execute independent tool calls concurrently
//! - **ResourceNegotiation**: Manage shared resources across tool invocations
//! - **AdaptiveRouting**: Select tools dynamically based on intermediate results
//!
//! Cross-domain transfer is strongest here: planning decomposes goals,
//! Rust synthesis provides execution patterns, and orchestration combines them.
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
use rand::Rng;
use serde::{Deserialize, Serialize};
const EMBEDDING_DIM: usize = 64;
/// Categories of tool orchestration tasks.
///
/// Serialized into each task's spec; drives both generation and scoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrchestrationCategory {
    /// Build a pipeline: chain tools to transform input to desired output.
    PipelineConstruction,
    /// Handle failure: detect errors and apply fallback strategies.
    ErrorRecovery,
    /// Coordinate parallel: dispatch independent calls and merge results.
    ParallelCoordination,
    /// Negotiate resources: manage rate limits, quotas, shared state.
    ResourceNegotiation,
    /// Adaptive routing: choose tool based on intermediate result properties.
    AdaptiveRouting,
}
/// A tool available in the orchestration environment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSpec {
    /// Unique tool name, referenced by `ToolCall::tool_name`.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Input type signature (e.g., "text", "json", "binary").
    pub input_type: String,
    /// Output type signature.
    pub output_type: String,
    /// Average latency in milliseconds.
    pub latency_ms: u32,
    /// Failure rate [0.0, 1.0].
    pub failure_rate: f32,
    /// Cost per invocation.
    pub cost: f32,
    /// Rate limit (max calls per minute), 0 = unlimited.
    pub rate_limit: u32,
}
/// An orchestration task specification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationTaskSpec {
    /// Which kind of orchestration challenge this task poses.
    pub category: OrchestrationCategory,
    /// Human-readable task statement.
    pub description: String,
    /// Available tools in the environment.
    pub available_tools: Vec<ToolSpec>,
    /// Input to the pipeline.
    pub input: serde_json::Value,
    /// Expected output type/shape.
    pub expected_output_type: String,
    /// Maximum total latency budget (ms).
    pub latency_budget_ms: u32,
    /// Maximum total cost budget.
    pub cost_budget: f32,
    /// Required reliability (min success rate).
    pub min_reliability: f32,
    /// Error scenarios that must be handled.
    pub error_scenarios: Vec<String>,
}
/// A tool call in an orchestration solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name of the tool to invoke (must match a `ToolSpec::name`).
    pub tool_name: String,
    /// Input to this tool call (ref to previous output or literal).
    pub input_ref: String,
    /// Whether this can run in parallel with other calls.
    /// Calls sharing the same group number run concurrently.
    pub parallel_group: Option<u32>,
    /// Fallback tool if this one fails.
    pub fallback: Option<String>,
    /// Retry count on failure.
    pub retries: u32,
}
/// A parsed orchestration plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationPlan {
    /// Ordered tool calls making up the plan.
    pub calls: Vec<ToolCall>,
    /// Error handling strategy description.
    pub error_strategy: String,
}
/// Tool orchestration domain.
pub struct ToolOrchestrationDomain {
    // Stable domain identifier ("tool_orchestration").
    id: DomainId,
}
impl ToolOrchestrationDomain {
    /// Create the domain with its stable identifier.
    pub fn new() -> Self {
        Self {
            id: DomainId("tool_orchestration".to_string()),
        }
    }

    /// Fixed catalog of tools every generated task draws from.
    ///
    /// Latency/failure/cost figures are synthetic but ordered realistically
    /// (cache < validator < search < ... < llm_generate).
    fn base_tools() -> Vec<ToolSpec> {
        vec![
            ToolSpec {
                name: "text_extract".into(),
                description: "Extract text from documents".into(),
                input_type: "binary".into(),
                output_type: "text".into(),
                latency_ms: 50,
                failure_rate: 0.02,
                cost: 0.001,
                rate_limit: 100,
            },
            ToolSpec {
                name: "text_embed".into(),
                description: "Generate embeddings from text".into(),
                input_type: "text".into(),
                output_type: "vector".into(),
                latency_ms: 30,
                failure_rate: 0.01,
                cost: 0.002,
                rate_limit: 200,
            },
            ToolSpec {
                name: "vector_search".into(),
                description: "Search vector index for similar items".into(),
                input_type: "vector".into(),
                output_type: "json".into(),
                latency_ms: 10,
                failure_rate: 0.005,
                cost: 0.0005,
                rate_limit: 500,
            },
            ToolSpec {
                name: "llm_generate".into(),
                description: "Generate text using language model".into(),
                input_type: "text".into(),
                output_type: "text".into(),
                latency_ms: 2000,
                failure_rate: 0.05,
                cost: 0.01,
                rate_limit: 30,
            },
            ToolSpec {
                name: "json_transform".into(),
                description: "Apply JQ-like transformations to JSON".into(),
                input_type: "json".into(),
                output_type: "json".into(),
                latency_ms: 5,
                failure_rate: 0.001,
                cost: 0.0001,
                rate_limit: 0,
            },
            ToolSpec {
                name: "code_execute".into(),
                description: "Execute code in sandboxed environment".into(),
                input_type: "text".into(),
                output_type: "json".into(),
                latency_ms: 500,
                failure_rate: 0.1,
                cost: 0.005,
                rate_limit: 20,
            },
            ToolSpec {
                name: "http_fetch".into(),
                description: "Fetch data from external HTTP endpoint".into(),
                input_type: "text".into(),
                output_type: "json".into(),
                latency_ms: 300,
                failure_rate: 0.15,
                cost: 0.0,
                rate_limit: 60,
            },
            ToolSpec {
                name: "cache_lookup".into(),
                description: "Check local cache for previously computed results".into(),
                input_type: "text".into(),
                output_type: "json".into(),
                latency_ms: 1,
                failure_rate: 0.0,
                cost: 0.0,
                rate_limit: 0,
            },
            ToolSpec {
                name: "validator".into(),
                description: "Validate output against schema".into(),
                input_type: "json".into(),
                output_type: "json".into(),
                latency_ms: 2,
                failure_rate: 0.0,
                cost: 0.0,
                rate_limit: 0,
            },
            ToolSpec {
                name: "aggregator".into(),
                description: "Merge multiple results into one".into(),
                input_type: "json".into(),
                output_type: "json".into(),
                latency_ms: 5,
                failure_rate: 0.0,
                cost: 0.0001,
                rate_limit: 0,
            },
        ]
    }

    /// Generate a pipeline-construction task. Higher difficulty exposes more
    /// tools while tightening the latency, cost, and reliability budgets.
    fn gen_pipeline(&self, difficulty: f32) -> OrchestrationTaskSpec {
        let tools = Self::base_tools();
        let num_tools = if difficulty < 0.3 {
            3
        } else if difficulty < 0.7 {
            6
        } else {
            10
        };
        OrchestrationTaskSpec {
            category: OrchestrationCategory::PipelineConstruction,
            description: format!(
                "Build a RAG pipeline using {} tools: extract, embed, search, generate.",
                num_tools
            ),
            available_tools: tools[..num_tools.min(tools.len())].to_vec(),
            input: serde_json::json!({"type": "binary", "format": "pdf"}),
            expected_output_type: "text".into(),
            latency_budget_ms: if difficulty < 0.5 { 5000 } else { 2000 },
            cost_budget: if difficulty < 0.5 { 0.1 } else { 0.02 },
            min_reliability: if difficulty < 0.5 { 0.9 } else { 0.99 },
            error_scenarios: Vec::new(),
        }
    }

    /// Generate an error-recovery task. Difficulty scales the number of
    /// failure scenarios the plan must survive (1, 3, or 5).
    fn gen_error_recovery(&self, difficulty: f32) -> OrchestrationTaskSpec {
        let tools = Self::base_tools();
        let error_scenarios = if difficulty < 0.3 {
            vec!["timeout on llm_generate".into()]
        } else if difficulty < 0.7 {
            vec![
                "timeout on llm_generate".into(),
                "http_fetch returns 429".into(),
                "code_execute sandbox OOM".into(),
            ]
        } else {
            vec![
                "timeout on llm_generate".into(),
                "http_fetch returns 429".into(),
                "code_execute sandbox OOM".into(),
                "vector_search index corruption".into(),
                "cascading failure: embed + search both down".into(),
            ]
        };
        OrchestrationTaskSpec {
            category: OrchestrationCategory::ErrorRecovery,
            description: format!(
                "Handle {} error scenarios in a multi-tool pipeline with graceful degradation.",
                error_scenarios.len()
            ),
            available_tools: tools,
            input: serde_json::json!({"type": "text", "content": "query"}),
            expected_output_type: "json".into(),
            latency_budget_ms: 10000,
            cost_budget: 0.1,
            min_reliability: 0.95,
            error_scenarios,
        }
    }

    /// Generate a parallel-coordination task. Difficulty raises the fan-out
    /// (2, 4, or 8 chains) and tightens the latency budget.
    fn gen_parallel_coordination(&self, difficulty: f32) -> OrchestrationTaskSpec {
        let tools = Self::base_tools();
        let parallelism = if difficulty < 0.3 {
            2
        } else if difficulty < 0.7 {
            4
        } else {
            8
        };
        OrchestrationTaskSpec {
            category: OrchestrationCategory::ParallelCoordination,
            description: format!(
                "Execute {} independent tool chains in parallel, merge results within latency budget.",
                parallelism
            ),
            available_tools: tools,
            input: serde_json::json!({"queries": (0..parallelism).map(|i| format!("query_{}", i)).collect::<Vec<_>>()}),
            expected_output_type: "json".into(),
            latency_budget_ms: if difficulty < 0.5 { 3000 } else { 1000 },
            cost_budget: 0.05 * parallelism as f32,
            min_reliability: 0.95,
            error_scenarios: Vec::new(),
        }
    }

    /// Project a solution into the 64-dim embedding space.
    ///
    /// Feature groups: 0-7 plan structure, 8-15 tool-type usage, 16-23 text
    /// patterns, 32-39 error-handling patterns, 48-55 coordination patterns.
    /// The vector is L2-normalized before return.
    fn extract_features(&self, solution: &Solution) -> Vec<f32> {
        let content = &solution.content;
        let mut features = vec![0.0f32; EMBEDDING_DIM];
        // Prefer the structured `data` payload; fall back to parsing the
        // free-text content, then to an empty plan. (Deserializing the Value
        // directly avoids the previous to_string()/from_str round-trip.)
        let plan: OrchestrationPlan = serde_json::from_value(solution.data.clone())
            .or_else(|_| serde_json::from_str(content))
            .unwrap_or(OrchestrationPlan {
                calls: Vec::new(),
                error_strategy: String::new(),
            });
        // Feature 0-7: Plan structure
        features[0] = plan.calls.len() as f32 / 20.0;
        let unique_tools: std::collections::HashSet<&str> =
            plan.calls.iter().map(|c| c.tool_name.as_str()).collect();
        features[1] = unique_tools.len() as f32 / 10.0;
        // Parallelism ratio
        let parallel_calls = plan
            .calls
            .iter()
            .filter(|c| c.parallel_group.is_some())
            .count();
        features[2] = parallel_calls as f32 / plan.calls.len().max(1) as f32;
        // Fallback coverage
        let fallback_calls = plan.calls.iter().filter(|c| c.fallback.is_some()).count();
        features[3] = fallback_calls as f32 / plan.calls.len().max(1) as f32;
        // Average retries
        let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum();
        features[4] = total_retries as f32 / plan.calls.len().max(1) as f32 / 5.0;
        // Feature 8-15: Tool type usage
        let tool_names = [
            "extract",
            "embed",
            "search",
            "generate",
            "transform",
            "execute",
            "fetch",
            "cache",
        ];
        for (i, name) in tool_names.iter().enumerate() {
            features[8 + i] = plan
                .calls
                .iter()
                .filter(|c| c.tool_name.contains(name))
                .count() as f32
                / plan.calls.len().max(1) as f32;
        }
        // Feature 16-23: Text pattern features
        features[16] = content.matches("pipeline").count() as f32 / 3.0;
        features[17] = content.matches("parallel").count() as f32 / 5.0;
        features[18] = content.matches("fallback").count() as f32 / 5.0;
        features[19] = content.matches("retry").count() as f32 / 5.0;
        features[20] = content.matches("cache").count() as f32 / 5.0;
        features[21] = content.matches("timeout").count() as f32 / 3.0;
        features[22] = content.matches("merge").count() as f32 / 3.0;
        features[23] = content.matches("validate").count() as f32 / 3.0;
        // Feature 32-39: Error handling patterns
        features[32] = content.matches("error").count() as f32 / 5.0;
        features[33] = content.matches("recover").count() as f32 / 3.0;
        features[34] = content.matches("degrade").count() as f32 / 3.0;
        features[35] = content.matches("circuit_break").count() as f32 / 2.0;
        features[36] = content.matches("rate_limit").count() as f32 / 3.0;
        features[37] = content.matches("backoff").count() as f32 / 3.0;
        features[38] = content.matches("health_check").count() as f32 / 2.0;
        features[39] = content.matches("monitor").count() as f32 / 3.0;
        // Feature 48-55: Coordination patterns
        features[48] = content.matches("scatter").count() as f32 / 2.0;
        features[49] = content.matches("gather").count() as f32 / 2.0;
        features[50] = content.matches("fan_out").count() as f32 / 2.0;
        features[51] = content.matches("aggregate").count() as f32 / 3.0;
        features[52] = content.matches("route").count() as f32 / 3.0;
        features[53] = content.matches("dispatch").count() as f32 / 3.0;
        features[54] = content.matches("await").count() as f32 / 5.0;
        features[55] = content.matches("join").count() as f32 / 3.0;
        // Normalize to unit length (skip if effectively zero to avoid NaN).
        let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for f in &mut features {
                *f /= norm;
            }
        }
        features
    }

    /// First ~10 bytes of a scenario string, snapped back to a char boundary
    /// so slicing can never panic on multi-byte UTF-8 scenario text.
    fn scenario_prefix(scenario: &str) -> &str {
        let mut end = scenario.len().min(10);
        while !scenario.is_char_boundary(end) {
            end -= 1;
        }
        &scenario[..end]
    }

    /// Score a solution against a task spec.
    ///
    /// Final score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance.
    /// Correctness checks type-chain validity, output-type coverage, and
    /// error-scenario handling; efficiency checks estimated latency/cost
    /// against budget; elegance rewards parallelism, caching, validation.
    fn score_orchestration(&self, spec: &OrchestrationTaskSpec, solution: &Solution) -> Evaluation {
        let content = &solution.content;
        let mut correctness = 0.0f32;
        let mut efficiency = 0.5f32;
        let mut elegance = 0.5f32;
        let mut notes = Vec::new();
        // Structured payload first, then the raw content as JSON.
        let plan: Option<OrchestrationPlan> = serde_json::from_value(solution.data.clone())
            .ok()
            .or_else(|| serde_json::from_str(content).ok());
        let plan = match plan {
            Some(p) => p,
            None => {
                // Unparseable plan: give partial credit for mentioning tools.
                let has_tools = spec
                    .available_tools
                    .iter()
                    .any(|t| content.contains(&t.name));
                if has_tools {
                    correctness = 0.2;
                }
                return Evaluation {
                    score: correctness * 0.6,
                    correctness,
                    efficiency: 0.0,
                    elegance: 0.0,
                    constraint_results: Vec::new(),
                    notes: vec!["Could not parse orchestration plan".into()],
                };
            }
        };
        if plan.calls.is_empty() {
            return Evaluation::zero(vec!["Empty orchestration plan".into()]);
        }
        // Correctness: type chain validity. Sequential adjacent calls must
        // have compatible output -> input types; parallel calls are exempt.
        let mut type_errors = 0;
        for window in plan.calls.windows(2) {
            let output_tool = spec
                .available_tools
                .iter()
                .find(|t| t.name == window[0].tool_name);
            let input_tool = spec
                .available_tools
                .iter()
                .find(|t| t.name == window[1].tool_name);
            if let (Some(out_t), Some(in_t)) = (output_tool, input_tool) {
                if window[1].parallel_group.is_none() && out_t.output_type != in_t.input_type {
                    type_errors += 1;
                    notes.push(format!(
                        "Type mismatch: {} outputs {} but {} expects {}",
                        out_t.name, out_t.output_type, in_t.name, in_t.input_type
                    ));
                }
            }
        }
        let chain_len = (plan.calls.len() - 1).max(1);
        correctness = 1.0 - (type_errors as f32 / chain_len as f32);
        // Tool coverage: do we use tools that produce the expected output?
        let produces_output = plan.calls.iter().any(|c| {
            spec.available_tools
                .iter()
                .any(|t| t.name == c.tool_name && t.output_type == spec.expected_output_type)
        });
        if !produces_output {
            correctness *= 0.5;
            notes.push("No tool produces the expected output type".into());
        }
        // Error handling coverage: a scenario counts as handled if any call
        // has a fallback/retries, or the strategy text mentions its prefix.
        if !spec.error_scenarios.is_empty() {
            let handled = spec
                .error_scenarios
                .iter()
                .filter(|scenario| {
                    plan.calls
                        .iter()
                        .any(|c| c.fallback.is_some() || c.retries > 0)
                        || plan
                            .error_strategy
                            .contains(Self::scenario_prefix(scenario))
                })
                .count() as f32
                / spec.error_scenarios.len() as f32;
            correctness = correctness * 0.7 + handled * 0.3;
        }
        // Efficiency: estimated latency (parallel groups contribute only
        // their slowest member) and total cost including retries.
        let est_latency: u32 = {
            let mut groups: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
            let mut sequential_latency = 0u32;
            for call in &plan.calls {
                let tool_latency = spec
                    .available_tools
                    .iter()
                    .find(|t| t.name == call.tool_name)
                    .map(|t| t.latency_ms)
                    .unwrap_or(100);
                if let Some(group) = call.parallel_group {
                    let entry = groups.entry(group).or_insert(0);
                    *entry = (*entry).max(tool_latency);
                } else {
                    sequential_latency += tool_latency;
                }
            }
            sequential_latency + groups.values().sum::<u32>()
        };
        if est_latency <= spec.latency_budget_ms {
            efficiency = 1.0 - (est_latency as f32 / spec.latency_budget_ms as f32 * 0.5);
        } else {
            efficiency = spec.latency_budget_ms as f32 / est_latency as f32 * 0.5;
            notes.push(format!(
                "Estimated latency {}ms exceeds budget {}ms",
                est_latency, spec.latency_budget_ms
            ));
        }
        let est_cost: f32 = plan
            .calls
            .iter()
            .filter_map(|c| {
                spec.available_tools
                    .iter()
                    .find(|t| t.name == c.tool_name)
                    .map(|t| t.cost * (1.0 + c.retries as f32))
            })
            .sum();
        if est_cost > spec.cost_budget {
            efficiency *= 0.7;
            notes.push(format!(
                "Cost {:.4} exceeds budget {:.4}",
                est_cost, spec.cost_budget
            ));
        }
        // Elegance: parallelism, caching, minimal redundancy
        let parallelism_used = plan.calls.iter().any(|c| c.parallel_group.is_some());
        if parallelism_used {
            elegance += 0.15;
        }
        let cache_used = plan.calls.iter().any(|c| c.tool_name.contains("cache"));
        if cache_used {
            elegance += 0.1;
        }
        let validation_used = plan.calls.iter().any(|c| c.tool_name.contains("validat"));
        if validation_used {
            elegance += 0.1;
        }
        // Penalize excessive retries
        let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum();
        if total_retries > plan.calls.len() as u32 * 2 {
            elegance -= 0.2;
            notes.push("Excessive retry configuration".into());
        }
        elegance = elegance.clamp(0.0, 1.0);
        let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
        Evaluation {
            score: score.clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes,
        }
    }
}
impl Default for ToolOrchestrationDomain {
    /// Equivalent to [`ToolOrchestrationDomain::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Domain for ToolOrchestrationDomain {
    fn id(&self) -> &DomainId {
        &self.id
    }
    fn name(&self) -> &str {
        "Tool Orchestration"
    }
    // Task mix by uniform roll: ~40% pipeline construction, ~30% error
    // recovery, ~30% parallel coordination.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);
        (0..count)
            .map(|i| {
                let roll: f32 = rng.gen();
                let spec = if roll < 0.4 {
                    self.gen_pipeline(difficulty)
                } else if roll < 0.7 {
                    self.gen_error_recovery(difficulty)
                } else {
                    self.gen_parallel_coordination(difficulty)
                };
                Task {
                    // e.g. "orch_3_d50" for task 3 at difficulty 0.5.
                    id: format!("orch_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: Vec::new(),
                }
            })
            .collect()
    }
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        // Re-parse the spec stored in the task; a malformed spec scores zero.
        let spec: OrchestrationTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        self.score_orchestration(&spec, solution)
    }
    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }
    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }
    // Baseline plan: chain every available tool sequentially, adding two
    // retries only for tools whose failure rate exceeds 5%.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: OrchestrationTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;
        // Build a sequential pipeline through available tools
        let calls: Vec<ToolCall> = spec
            .available_tools
            .iter()
            .map(|t| ToolCall {
                tool_name: t.name.clone(),
                input_ref: "previous".into(),
                parallel_group: None,
                fallback: None,
                retries: if t.failure_rate > 0.05 { 2 } else { 0 },
            })
            .collect();
        let plan = OrchestrationPlan {
            calls,
            error_strategy: "retry with exponential backoff".into(),
        };
        let content = serde_json::to_string_pretty(&plan).ok()?;
        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::to_value(&plan).ok()?,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Requested number of tasks is produced, all tagged with this domain.
    #[test]
    fn test_generate_orchestration_tasks() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        for task in &tasks {
            assert_eq!(task.domain_id, domain.id);
        }
    }
    // Every generated task admits a baseline reference solution.
    #[test]
    fn test_reference_solution() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            let ref_sol = domain.reference_solution(task);
            assert!(ref_sol.is_some());
        }
    }
    // Scoring the reference solution yields a score in [0, 1].
    #[test]
    fn test_evaluate_reference() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            if let Some(solution) = domain.reference_solution(task) {
                let eval = domain.evaluate(task, &solution);
                assert!(eval.score >= 0.0 && eval.score <= 1.0);
            }
        }
    }
    // Embedding a minimal plan produces a vector of the declared dimension.
    #[test]
    fn test_embed_orchestration() {
        let domain = ToolOrchestrationDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "pipeline: extract -> embed -> search with fallback and retry".into(),
            data: serde_json::json!({
                "calls": [
                    {"tool_name": "text_extract", "input_ref": "input", "retries": 1}
                ],
                "error_strategy": "retry"
            }),
        };
        let embedding = domain.embed(&solution);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
    }
    // Sampling enough hard tasks should hit the error-recovery category,
    // which carries non-empty error scenarios.
    #[test]
    fn test_difficulty_affects_error_scenarios() {
        let domain = ToolOrchestrationDomain::new();
        // Generate many tasks at high difficulty to get error recovery tasks
        let hard = domain.generate_tasks(20, 0.9);
        let has_error_tasks = hard.iter().any(|t| {
            let spec: OrchestrationTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            !spec.error_scenarios.is_empty()
        });
        assert!(
            has_error_tasks,
            "High difficulty should produce error scenarios"
        );
    }
}

View File

@@ -0,0 +1,583 @@
//! Cross-Domain Transfer Engine with Meta Thompson Sampling
//!
//! Transfer happens through priors, not raw memories.
//! Ship compact priors and verified kernels between domains.
//!
//! ## Two-Layer Learning Architecture
//!
//! **Policy learning layer**: Chooses strategies, budgets, and tool paths
//! using uncertainty-aware selection (Thompson Sampling with Beta priors).
//!
//! **Operator layer**: Executes deterministic kernels and graders,
//! logs witnesses, and commits state through gates.
//!
//! ## Meta Thompson Sampling
//!
//! After each cycle, compute posterior summary per bucket and arm.
//! Store as TransferPrior. When a new domain starts, initialize its
//! buckets with these priors instead of uniform, enabling faster adaptation.
//!
//! ## Cross-Domain Transfer Protocol
//!
//! A delta is promotable only if it improves Domain 2 without regressing
//! Domain 1, or improves Domain 1 without regressing Domain 2.
//! That is generalization.
use crate::domain::DomainId;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Beta distribution parameters for Thompson Sampling.
/// Represents uncertainty about an arm's reward probability.
///
/// Both parameters must stay positive for a proper Beta distribution;
/// `uniform()` starts them at 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetaParams {
    /// Success count + prior (alpha).
    pub alpha: f32,
    /// Failure count + prior (beta).
    pub beta: f32,
}
impl BetaParams {
    /// Uniform (uninformative) prior: Beta(1, 1).
    pub fn uniform() -> Self {
        Self {
            alpha: 1.0,
            beta: 1.0,
        }
    }

    /// Create from observed successes and failures, adding the Beta(1, 1)
    /// prior pseudo-counts on top of the observations.
    pub fn from_observations(successes: f32, failures: f32) -> Self {
        Self {
            alpha: successes + 1.0,
            beta: failures + 1.0,
        }
    }

    /// Mean of the Beta distribution: E[X] = alpha / (alpha + beta).
    pub fn mean(&self) -> f32 {
        self.alpha / (self.alpha + self.beta)
    }

    /// Variance: measures uncertainty. Lower = more confident.
    pub fn variance(&self) -> f32 {
        let total = self.alpha + self.beta;
        (self.alpha * self.beta) / (total * total * (total + 1.0))
    }

    /// Draw an approximate sample for Thompson Sampling.
    ///
    /// Uses a moment-matched normal approximation when both parameters
    /// exceed 1, and a Kumaraswamy inverse CDF otherwise. Fast, no special
    /// functions needed, good enough for arm selection.
    pub fn sample(&self, rng: &mut impl Rng) -> f32 {
        // Clip u away from {0, 1} so ln() / powf() below stay finite.
        let u: f32 = rng.gen_range(0.001..0.999);
        Self::beta_inv_approx(u, self.alpha, self.beta).clamp(0.0, 1.0)
    }

    /// Approximate inverse CDF of the Beta(a, b) distribution at `p`.
    fn beta_inv_approx(p: f32, a: f32, b: f32) -> f32 {
        if a > 1.0 && b > 1.0 {
            // Normal approximation: match the Beta mean and variance, then
            // invert the standard normal CDF for the quantile.
            let mean = a / (a + b);
            let var = (a * b) / ((a + b) * (a + b) * (a + b + 1.0));
            let std = var.sqrt();
            let z = Self::inv_normal_cdf_approx(p);
            (mean + std * z).clamp(0.001, 0.999)
        } else {
            // Heavy-tailed / U-shaped regime (a <= 1 or b <= 1) where the
            // normal approximation breaks down: use the true Kumaraswamy(a, b)
            // inverse CDF, x = (1 - (1 - p)^(1/b))^(1/a), whose shape closely
            // tracks Beta(a, b). (The previous expression here could return
            // values above 1 and piled probability mass at the clamp.)
            (1.0 - (1.0 - p).powf(1.0 / b))
                .powf(1.0 / a)
                .clamp(0.001, 0.999)
        }
    }

    /// Inverse standard normal CDF via Abramowitz & Stegun 26.2.23
    /// (rational approximation, |error| < 4.5e-4), valid for 0 < p < 1.
    ///
    /// The previous code used the raw tail statistic `t` as the z-score,
    /// omitting the rational correction term entirely.
    fn inv_normal_cdf_approx(p: f32) -> f32 {
        // Exploit symmetry: compute the magnitude from the smaller tail.
        let (tail, sign) = if p < 0.5 { (p, -1.0) } else { (1.0 - p, 1.0) };
        let t = (-2.0 * tail.ln()).sqrt();
        let numerator = 2.515517 + t * (0.802853 + t * 0.010328);
        let denominator = 1.0 + t * (1.432788 + t * (0.189269 + t * 0.001308));
        sign * (t - numerator / denominator)
    }

    /// Bayesian posterior update. Assumes `reward` lies in [0.0, 1.0];
    /// fractional rewards split mass between alpha and beta.
    pub fn update(&mut self, reward: f32) {
        self.alpha += reward;
        self.beta += 1.0 - reward;
    }

    /// Merge two Beta distributions (approximate: sum the evidence, then
    /// subtract one copy of the shared Beta(1, 1) prior so it is not
    /// double-counted).
    pub fn merge(&self, other: &BetaParams) -> BetaParams {
        BetaParams {
            alpha: self.alpha + other.alpha - 1.0,
            beta: self.beta + other.beta - 1.0,
        }
    }
}
/// A context bucket groups similar problem instances for targeted learning.
///
/// Used as a `HashMap` key, so both fields participate in `Hash`/`Eq`.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContextBucket {
    /// Difficulty tier: "easy", "medium", "hard".
    pub difficulty_tier: String,
    /// Problem category within the domain.
    pub category: String,
}
/// An arm in the multi-armed bandit: a strategy choice.
///
/// Newtype over the strategy's string name.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArmId(pub String);
/// Transfer prior: compact posterior summary from a source domain.
/// This is what gets shipped between domains — not raw trajectories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferPrior {
    /// Source domain that generated this prior.
    pub source_domain: DomainId,
    /// Per-bucket, per-arm Beta parameters (posterior summaries).
    pub bucket_priors: HashMap<ContextBucket, HashMap<ArmId, BetaParams>>,
    /// Cost EMA (exponential moving average) priors per bucket.
    pub cost_ema_priors: HashMap<ContextBucket, f32>,
    /// Number of cycles this prior was trained on.
    pub training_cycles: u64,
    /// Witness hash: proof of how this prior was derived.
    /// Empty string until a witness is attached.
    pub witness_hash: String,
}
impl TransferPrior {
    /// Create an empty (uniform) prior for a domain.
    pub fn uniform(source_domain: DomainId) -> Self {
        Self {
            source_domain,
            bucket_priors: HashMap::new(),
            cost_ema_priors: HashMap::new(),
            training_cycles: 0,
            witness_hash: String::new(),
        }
    }

    /// Get the prior for a specific bucket and arm, defaulting to uniform.
    pub fn get_prior(&self, bucket: &ContextBucket, arm: &ArmId) -> BetaParams {
        match self.bucket_priors.get(bucket).and_then(|arms| arms.get(arm)) {
            Some(params) => params.clone(),
            None => BetaParams::uniform(),
        }
    }

    /// Update the posterior for a bucket/arm with a new observation,
    /// creating the bucket/arm entries on first sight.
    pub fn update_posterior(&mut self, bucket: ContextBucket, arm: ArmId, reward: f32) {
        self.bucket_priors
            .entry(bucket)
            .or_default()
            .entry(arm)
            .or_insert_with(BetaParams::uniform)
            .update(reward);
        self.training_cycles += 1;
    }

    /// Update cost EMA for a bucket. The first observation seeds the EMA
    /// with the cost itself (the decay formula then leaves it unchanged).
    pub fn update_cost_ema(&mut self, bucket: ContextBucket, cost: f32, decay: f32) {
        let ema = self.cost_ema_priors.entry(bucket).or_insert(cost);
        *ema = decay * *ema + (1.0 - decay) * cost;
    }

    /// Extract a compact summary suitable for shipping to another domain.
    ///
    /// Only buckets with sufficient evidence survive: an arm is kept when
    /// alpha + beta > 12.0, i.e. more than 10 observations beyond the
    /// Beta(1, 1) prior. Cost EMAs are shipped for all buckets.
    pub fn extract_summary(&self) -> TransferPrior {
        let mut filtered: HashMap<ContextBucket, HashMap<ArmId, BetaParams>> = HashMap::new();
        for (bucket, arms) in &self.bucket_priors {
            let mut significant: HashMap<ArmId, BetaParams> = HashMap::new();
            for (arm, params) in arms {
                if params.alpha + params.beta > 12.0 {
                    significant.insert(arm.clone(), params.clone());
                }
            }
            if !significant.is_empty() {
                filtered.insert(bucket.clone(), significant);
            }
        }
        TransferPrior {
            source_domain: self.source_domain.clone(),
            bucket_priors: filtered,
            cost_ema_priors: self.cost_ema_priors.clone(),
            training_cycles: self.training_cycles,
            witness_hash: self.witness_hash.clone(),
        }
    }
}
/// Meta Thompson Sampling engine that manages priors across domains.
pub struct MetaThompsonEngine {
    /// Active priors per domain.
    domain_priors: HashMap<DomainId, TransferPrior>,
    /// Available arms (strategies) shared across domains.
    arms: Vec<ArmId>,
    /// Difficulty tiers for bucketing ("easy"/"medium"/"hard"); set in
    /// `new`. NOTE(review): not read by any method visible in this file —
    /// presumably consumed by callers constructing buckets; confirm.
    difficulty_tiers: Vec<String>,
}
impl MetaThompsonEngine {
/// Create a new engine with the given strategy arms.
pub fn new(arms: Vec<String>) -> Self {
Self {
domain_priors: HashMap::new(),
arms: arms.into_iter().map(ArmId).collect(),
difficulty_tiers: vec!["easy".into(), "medium".into(), "hard".into()],
}
}
/// Initialize a domain with uniform priors.
pub fn init_domain_uniform(&mut self, domain_id: DomainId) {
self.domain_priors
.insert(domain_id.clone(), TransferPrior::uniform(domain_id));
}
/// Initialize a domain using transfer priors from a source domain.
/// This is the key mechanism: Meta-TS seeds new domains with learned priors.
pub fn init_domain_with_transfer(
&mut self,
target_domain: DomainId,
source_prior: &TransferPrior,
) {
let mut prior = TransferPrior::uniform(target_domain.clone());
// Copy bucket priors from source, scaling by confidence
for (bucket, arms) in &source_prior.bucket_priors {
for (arm, params) in arms {
// Dampen the prior: don't fully trust cross-domain evidence.
// Use sqrt scaling: reduces confidence while preserving mean.
let dampened = BetaParams {
alpha: 1.0 + (params.alpha - 1.0).sqrt(),
beta: 1.0 + (params.beta - 1.0).sqrt(),
};
prior
.bucket_priors
.entry(bucket.clone())
.or_default()
.insert(arm.clone(), dampened);
}
}
// Transfer cost EMAs with dampening
for (bucket, &cost) in &source_prior.cost_ema_priors {
prior.cost_ema_priors.insert(bucket.clone(), cost * 1.5); // pessimistic transfer
}
prior.witness_hash = format!("transfer_from_{}", source_prior.source_domain);
self.domain_priors.insert(target_domain, prior);
}
/// Select an arm for a given domain and context using Thompson Sampling.
pub fn select_arm(
&self,
domain_id: &DomainId,
bucket: &ContextBucket,
rng: &mut impl Rng,
) -> Option<ArmId> {
let prior = self.domain_priors.get(domain_id)?;
let mut best_arm = None;
let mut best_sample = f32::NEG_INFINITY;
for arm in &self.arms {
let params = prior.get_prior(bucket, arm);
let sample = params.sample(rng);
if sample > best_sample {
best_sample = sample;
best_arm = Some(arm.clone());
}
}
best_arm
}
/// Record the outcome of using an arm in a domain.
pub fn record_outcome(
&mut self,
domain_id: &DomainId,
bucket: ContextBucket,
arm: ArmId,
reward: f32,
cost: f32,
) {
if let Some(prior) = self.domain_priors.get_mut(domain_id) {
prior.update_posterior(bucket.clone(), arm, reward);
prior.update_cost_ema(bucket, cost, 0.9);
}
}
/// Extract transfer prior from a domain (for shipping to another domain).
pub fn extract_prior(&self, domain_id: &DomainId) -> Option<TransferPrior> {
self.domain_priors
.get(domain_id)
.map(|p| p.extract_summary())
}
/// Get all domain IDs currently tracked.
pub fn domain_ids(&self) -> Vec<&DomainId> {
self.domain_priors.keys().collect()
}
/// Check if posterior variance is high (triggers speculative dual-path).
pub fn is_uncertain(
&self,
domain_id: &DomainId,
bucket: &ContextBucket,
threshold: f32,
) -> bool {
let prior = match self.domain_priors.get(domain_id) {
Some(p) => p,
None => return true, // No data = maximum uncertainty
};
// Check if top two arms are within delta of each other
let mut samples: Vec<(f32, &ArmId)> = self
.arms
.iter()
.map(|arm| {
let params = prior.get_prior(bucket, arm);
(params.mean(), arm)
})
.collect();
samples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
if samples.len() < 2 {
return true;
}
let gap = samples[0].0 - samples[1].0;
gap < threshold
}
}
/// Speculative dual-path execution for high-uncertainty decisions.
/// When the top two arms are within delta, run both and pick the winner.
///
/// Pure result record: construction and winner selection happen elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DualPathResult {
    /// Primary arm and its outcome (reward score).
    pub primary: (ArmId, f32),
    /// Secondary arm and its outcome (reward score).
    pub secondary: (ArmId, f32),
    /// Which arm won.
    pub winner: ArmId,
    /// The loser becomes a counterexample for that context.
    pub counterexample: ArmId,
}
/// Cross-domain transfer verification result.
///
/// Produced by [`TransferVerification::verify`]; captures both the
/// promote/reject decision and the raw before/after scores behind it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferVerification {
    /// Source domain.
    pub source: DomainId,
    /// Target domain.
    pub target: DomainId,
    /// Did transfer improve the target domain?
    pub improved_target: bool,
    /// Did transfer regress the source domain? (0.01 tolerance applied.)
    pub regressed_source: bool,
    /// Is this delta promotable? (improved target AND not regressed source).
    pub promotable: bool,
    /// Acceleration factor: ratio of convergence speeds
    /// (baseline cycles / transfer cycles; >1.0 means transfer helped).
    pub acceleration_factor: f32,
    /// Source score before/after.
    pub source_scores: (f32, f32),
    /// Target score before/after.
    pub target_scores: (f32, f32),
}
impl TransferVerification {
/// Verify a transfer delta against the generalization rule:
/// promotable iff it improves Domain 2 without regressing Domain 1.
pub fn verify(
source: DomainId,
target: DomainId,
source_before: f32,
source_after: f32,
target_before: f32,
target_after: f32,
target_baseline_cycles: u64,
target_transfer_cycles: u64,
) -> Self {
let improved_target = target_after > target_before;
let regressed_source = source_after < source_before - 0.01; // small tolerance
let promotable = improved_target && !regressed_source;
// Acceleration = baseline_cycles / transfer_cycles (higher = better transfer)
let acceleration_factor = if target_transfer_cycles > 0 {
target_baseline_cycles as f32 / target_transfer_cycles as f32
} else {
1.0
};
Self {
source,
target,
improved_target,
regressed_source,
promotable,
acceleration_factor,
source_scores: (source_before, source_after),
target_scores: (target_before, target_after),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_beta_params_uniform() {
        // A fresh prior is Beta(1, 1) with mean 0.5.
        let params = BetaParams::uniform();
        assert_eq!(params.alpha, 1.0);
        assert_eq!(params.beta, 1.0);
        assert!((params.mean() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_beta_params_update() {
        // A single success bumps alpha and pulls the mean above 0.5.
        let mut params = BetaParams::uniform();
        params.update(1.0);
        assert_eq!(params.alpha, 2.0);
        assert_eq!(params.beta, 1.0);
        assert!(params.mean() > 0.5);
    }

    #[test]
    fn test_beta_params_sample_in_range() {
        let params = BetaParams::from_observations(10.0, 5.0);
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let draw = params.sample(&mut rng);
            assert!(draw >= 0.0 && draw <= 1.0, "Sample {} out of [0,1]", draw);
        }
    }

    #[test]
    fn test_transfer_prior_round_trip() {
        let mut prior = TransferPrior::uniform(DomainId("test".into()));
        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "transform".into(),
        };
        let arm = ArmId("strategy_a".into());
        // 20 observations clears the >10-observation evidence bar.
        for _ in 0..20 {
            prior.update_posterior(bucket.clone(), arm.clone(), 0.8);
        }
        let shipped = prior.extract_summary();
        assert!(!shipped.bucket_priors.is_empty());
        assert!(shipped.get_prior(&bucket, &arm).mean() > 0.5);
    }

    #[test]
    fn test_meta_thompson_engine() {
        let strategies = vec![
            "strategy_a".to_string(),
            "strategy_b".to_string(),
            "strategy_c".to_string(),
        ];
        let mut engine = MetaThompsonEngine::new(strategies);
        let domain1 = DomainId("rust_synthesis".into());
        engine.init_domain_uniform(domain1.clone());
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };
        let mut rng = rand::thread_rng();
        // Drive the bandit: strategy_a pays off, everything else is poor.
        for _ in 0..50 {
            let arm = engine.select_arm(&domain1, &bucket, &mut rng).unwrap();
            let reward = match arm.0.as_str() {
                "strategy_a" => 0.9,
                _ => 0.3,
            };
            engine.record_outcome(&domain1, bucket.clone(), arm, reward, 1.0);
        }
        // Ship domain1's posterior into a fresh domain.
        let shipped = engine.extract_prior(&domain1).unwrap();
        let domain2 = DomainId("planning".into());
        engine.init_domain_with_transfer(domain2.clone(), &shipped);
        // Domain2 should now carry an informative (if dampened) prior.
        let transferred = engine
            .domain_priors
            .get(&domain2)
            .unwrap()
            .get_prior(&bucket, &ArmId("strategy_a".into()));
        assert!(
            transferred.mean() > 0.5,
            "Transferred prior should favor strategy_a"
        );
    }

    #[test]
    fn test_transfer_verification() {
        // Source dips 0.01 (inside tolerance); target jumps 0.3 -> 0.7;
        // convergence took 40 cycles instead of the 100-cycle baseline.
        let verdict = TransferVerification::verify(
            DomainId("d1".into()),
            DomainId("d2".into()),
            0.8,
            0.79,
            0.3,
            0.7,
            100,
            40,
        );
        assert!(verdict.improved_target);
        assert!(!verdict.regressed_source); // within tolerance
        assert!(verdict.promotable);
        assert!((verdict.acceleration_factor - 2.5).abs() < 1e-4);
    }

    #[test]
    fn test_transfer_not_promotable_on_regression() {
        // Source drops 0.8 -> 0.5: clear regression, so even a big
        // target win must not be promoted.
        let verdict = TransferVerification::verify(
            DomainId("d1".into()),
            DomainId("d2".into()),
            0.8,
            0.5,
            0.3,
            0.7,
            100,
            40,
        );
        assert!(verdict.improved_target);
        assert!(verdict.regressed_source);
        assert!(!verdict.promotable);
    }

    #[test]
    fn test_uncertainty_detection() {
        let mut engine = MetaThompsonEngine::new(vec!["a".into(), "b".into()]);
        let domain = DomainId("test".into());
        engine.init_domain_uniform(domain.clone());
        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "test".into(),
        };
        // Uniform priors: top-two gap is zero, so we must be uncertain.
        assert!(engine.is_uncertain(&domain, &bucket, 0.1));
        // Heavy one-sided evidence separates the arms decisively.
        for _ in 0..100 {
            engine.record_outcome(&domain, bucket.clone(), ArmId("a".into()), 0.95, 1.0);
            engine.record_outcome(&domain, bucket.clone(), ArmId("b".into()), 0.1, 1.0);
        }
        assert!(!engine.is_uncertain(&domain, &bucket, 0.1));
    }
}