Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
37
vendor/ruvector/crates/ruvector-domain-expansion/Cargo.toml
vendored
Normal file
37
vendor/ruvector/crates/ruvector-domain-expansion/Cargo.toml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
[package]
|
||||
name = "ruvector-domain-expansion"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
repository.workspace = true
|
||||
description = "Cross-domain transfer learning engine: Rust synthesis, structured planning, tool orchestration"
|
||||
keywords = ["transfer-learning", "domain-expansion", "generalization", "rust-synthesis", "planning"]
|
||||
categories = ["algorithms", "science"]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
rvf = ["dep:rvf-types", "dep:rvf-wire", "dep:rvf-crypto"]
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
|
||||
# RVF integration (optional, behind "rvf" feature)
|
||||
rvf-types = { version = "0.2.0", path = "../rvf/rvf-types", optional = true }
|
||||
rvf-wire = { version = "0.1.0", path = "../rvf/rvf-wire", optional = true }
|
||||
rvf-crypto = { version = "0.2.0", path = "../rvf/rvf-crypto", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
proptest = { workspace = true }
|
||||
criterion = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
name = "domain_expansion_bench"
|
||||
harness = false
|
||||
|
||||
[lib]
|
||||
crate-type = ["rlib"]
|
||||
221
vendor/ruvector/crates/ruvector-domain-expansion/README.md
vendored
Normal file
221
vendor/ruvector/crates/ruvector-domain-expansion/README.md
vendored
Normal file
@@ -0,0 +1,221 @@
|
||||
# ruvector-domain-expansion
|
||||
|
||||
[](https://crates.io/crates/ruvector-domain-expansion)
|
||||
[](https://docs.rs/ruvector-domain-expansion)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.rust-lang.org)
|
||||
|
||||
**Cross-domain transfer learning — train on one problem, get better at a different one automatically.**
|
||||
|
||||
```toml
|
||||
ruvector-domain-expansion = "0.1"
|
||||
```
|
||||
|
||||
Most AI systems learn one task at a time. Train a model on genomics and it can't trade stocks. Teach it quantum circuits and it won't plan workflows. `ruvector-domain-expansion` changes that: knowledge learned in one domain automatically transfers to other domains — and it **proves** the transfer actually helped before committing it. Genomics priors seed molecular design. Trading risk models improve resource allocation. Quantum noise detection accelerates signal processing. This is how real generalization works. Part of the [RuVector](https://github.com/ruvnet/ruvector) ecosystem.
|
||||
|
||||
| | ruvector-domain-expansion | Traditional Fine-Tuning |
|
||||
|---|---|---|
|
||||
| **Learning scope** | Learns across 13+ domains — genomics, trading, quantum, code, planning | One task at a time |
|
||||
| **Transfer** | Automatic: priors from Domain 1 seed Domain 2 | Manual: retrain from scratch per domain |
|
||||
| **Verification** | Transfer only accepted if it helps target without hurting source | No verification — hope it works |
|
||||
| **Strategy selection** | Thompson Sampling picks the best approach per context | Fixed strategy for all inputs |
|
||||
| **Population search** | 8 policy variants evolve in parallel, best survives | Single model, single strategy |
|
||||
| **Curiosity** | Explores under-visited areas automatically | Only learns from data you provide |
|
||||
|
||||
## Quick Start
|
||||
|
||||
```rust
|
||||
use ruvector_domain_expansion::{
|
||||
DomainExpansionEngine, DomainId, ContextBucket, ArmId,
|
||||
};
|
||||
|
||||
let mut engine = DomainExpansionEngine::new();
|
||||
|
||||
// Generate training tasks in any domain
|
||||
let domain = DomainId("rust_synthesis".into());
|
||||
let tasks = engine.generate_tasks(&domain, 10, 0.5); // 10 tasks, medium difficulty
|
||||
|
||||
// Select strategy using Thompson Sampling
|
||||
let bucket = ContextBucket { difficulty_tier: "medium".into(), category: "algorithm".into() };
|
||||
let arm = engine.select_arm(&domain, &bucket).unwrap();
|
||||
|
||||
// Evaluate and learn
|
||||
let eval = engine.evaluate_and_record(&domain, &tasks[0], &solution, bucket, arm);
|
||||
|
||||
// Transfer knowledge to a completely different domain
|
||||
let target = DomainId("structured_planning".into());
|
||||
engine.initiate_transfer(&domain, &target);
|
||||
// Planning now starts at 0.70 accuracy instead of 0.30 — transfer verified and promoted
|
||||
```
|
||||
|
||||
## Key Features
|
||||
|
||||
| Feature | What It Does | Why It Matters |
|
||||
|---------|-------------|----------------|
|
||||
| **Meta Thompson Sampling** | Picks the best strategy per context using uncertainty-aware selection | Explores when unsure, exploits when confident — no manual tuning |
|
||||
| **Cross-Domain Transfer** | Extracts compact priors from one domain, seeds another | New domains learn faster by starting with knowledge from related domains |
|
||||
| **Transfer Verification** | Accepts a transfer only if target improves without source regressing | Guarantees generalization — no silent regressions |
|
||||
| **Population-Based Search** | Evolves 8 policy kernel variants in parallel | Finds optimal strategies faster than single-model training |
|
||||
| **Curiosity-Driven Exploration** | UCB-style bonus for under-visited contexts | Automatically explores blind spots instead of getting stuck |
|
||||
| **Pareto Front Tracking** | Tracks non-dominated kernels across accuracy, cost, and robustness | See the best tradeoffs, not just the single "best" model |
|
||||
| **Plateau Detection** | Detects when learning stalls and recommends actions | Automatically switches strategies instead of wasting compute |
|
||||
| **Counterexample Tracking** | Records failed solutions to inform future decisions | Learns from mistakes, not just successes |
|
||||
| **Cost Curve & Scoreboard** | Tracks convergence speed per domain with acceleration metrics | Proves that transfer actually accelerated learning |
|
||||
| **RVF Integration** | Package trained models as cognitive containers (optional `rvf` feature) | Ship a trained domain expansion engine as a single `.rvf` file |
|
||||
|
||||
## Domain Ecosystem
|
||||
|
||||
Domain expansion draws on the full RuVector capability stack. Each domain contributes unique knowledge that transfers to others through shared embedding spaces.
|
||||
|
||||
### Core Domains (Built-In)
|
||||
|
||||
| Domain | What It Generates | What It Evaluates |
|
||||
|--------|------------------|-------------------|
|
||||
| **Rust Synthesis** | Rust function specs (transforms, filters, searches) | Correctness, efficiency, idiomatic style |
|
||||
| **Structured Planning** | Multi-step plans with dependencies and resources | Feasibility, completeness, dependency ordering |
|
||||
| **Tool Orchestration** | Tool coordination tasks (parallel, error handling) | Correct sequencing, parallelism, failure recovery |
|
||||
|
||||
### Specialized Domains (via RuVector Crates & Examples)
|
||||
|
||||
| Domain | Crate / Example | What It Brings | Transfer Value |
|
||||
|--------|----------------|----------------|----------------|
|
||||
| **Genomics** | [rvDNA](../../examples/dna/) | Variant calling, k-mer HNSW embeddings, 64-dim SNP risk profiles | Sparse structured features seed any domain needing compact representations |
|
||||
| **Algorithmic Trading** | [neural-trader](../../examples/neural-trader/) | Kelly sizing, LSTM-Transformer prediction, DRL portfolio ensembles | Rich reward signals (Sharpe, drawdown) map directly to evaluation scoring |
|
||||
| **Quantum Computing** | [ruQu](../../crates/ruQu/) | Coherence gating, circuit optimization, noise drift detection | Verification methodology — "is it safe to act?" — inspired TransferVerification |
|
||||
| **Neuromorphic AI** | [spiking-neural](../../examples/meta-cognition-spiking-neural-network/) | STDP learning, meta-plasticity, hyperbolic attention | Proves cross-domain acceleration is biologically real and measurable |
|
||||
| **Graph Intelligence** | [graph-transformer](../ruvector-graph-transformer/) | Proof-gated mutation, Nash equilibrium attention, causal Granger layers | Formal proofs before committing changes — same pattern as transfer acceptance |
|
||||
| **Nervous Systems** | [nervous-system](../ruvector-nervous-system/) | One-shot BTSP learning, hyperdimensional computing, circadian duty cycles | Cold-start acceleration — learn from single examples, like transfer priors |
|
||||
| **Scientific OCR** | [scipix](../../examples/scipix/) | LaTeX/MathML extraction, equation vectorization at 50ms/image | Structured mathematical knowledge bootstraps reasoning patterns |
|
||||
| **Knowledge Graphs** | [graph](../../examples/graph/) | Cypher queries, hybrid vector+graph search, community detection | Graph structure reveals which domain clusters should share priors |
|
||||
| **Self-Learning Search** | [ruvector-gnn](../ruvector-gnn/) | GCN/GAT/GraphSAGE on HNSW topology | GraphSAGE handles new domains without retraining — inductive generalization |
|
||||
| **Online Adaptation** | [sona](../sona/) | MicroLoRA (<1ms), EWC++ memory preservation, trajectory tracking | Fast-path arm updates + slow-path prior consolidation without forgetting |
|
||||
|
||||
### How Domains Connect
|
||||
|
||||
```
|
||||
┌──────────────┐
|
||||
│ Domain │
|
||||
│ Expansion │
|
||||
│ Engine │
|
||||
└──────┬───────┘
|
||||
│
|
||||
┌──────────────┼──────────────┐
|
||||
│ │ │
|
||||
┌──────▼──────┐ ┌────▼─────┐ ┌──────▼──────┐
|
||||
│ Genomics │ │ Trading │ │ Quantum │
|
||||
│ 64-dim SNP │ │ Sharpe │ │ Coherence │
|
||||
│ profiles │ │ rewards │ │ gates │
|
||||
└──────┬──────┘ └────┬─────┘ └──────┬──────┘
|
||||
│ │ │
|
||||
└──────┬───────┘──────┬───────┘
|
||||
│ │
|
||||
┌──────▼──────┐ ┌────▼──────────┐
|
||||
│ Shared │ │ Transfer │
|
||||
│ Embedding │ │ Verification │
|
||||
│ Space │ │ Gate │
|
||||
└──────┬──────┘ └────┬──────────┘
|
||||
│ │
|
||||
┌──────▼──────────────▼──────┐
|
||||
│ SONA (MicroLoRA + EWC++) │
|
||||
│ Live adaptation without │
|
||||
│ forgetting old domains │
|
||||
└───────────────────────────┘
|
||||
```
|
||||
|
||||
Every domain produces embeddings in the same vector space. When you transfer from genomics to planning, the engine extracts compact priors (Beta posteriors from Thompson Sampling), seeds them into the target domain, and verifies the transfer helped — using the same coherence metrics that quantum computing uses to decide "is this circuit safe to run?"
|
||||
|
||||
## How Transfer Works
|
||||
|
||||
```
|
||||
Domain 1 (Genomics) Domain 2 (Drug Design)
|
||||
┌─────────────────────┐ ┌─────────────────────┐
|
||||
│ Train on 100 tasks │ │ Start from scratch │
|
||||
│ Extract posteriors │───prior──▶│ Seed with priors │
|
||||
│ Score: 0.85 │ │ Score after 45 runs: │
|
||||
│ │ │ 0.70 (vs 0.30 │
|
||||
│ k-mer embeddings │ │ without transfer) │
|
||||
│ SNP risk profiles │ │ │
|
||||
└─────────────────────┘ └─────────────────────┘
|
||||
│
|
||||
Verification Gate:
|
||||
✓ Target improved (coherence check)
|
||||
✓ Source didn't regress (EWC++ protected)
|
||||
✓ Acceleration > 1.0 (scoreboard)
|
||||
→ Transfer PROMOTED
|
||||
```
|
||||
|
||||
### Cross-Domain Transfer Examples
|
||||
|
||||
| Source Domain | Target Domain | What Transfers | Why It Works |
|
||||
|--------------|---------------|----------------|--------------|
|
||||
| Genomics | Molecular Design | Sequence similarity priors, structural feature embeddings | Both work with sparse biological feature vectors |
|
||||
| Trading | Resource Allocation | Risk/reward tradeoff models, Kelly-style sizing | Same math — allocate limited budget across uncertain options |
|
||||
| Quantum | Signal Processing | Noise detection patterns, drift thresholds | Both need to separate signal from noise in noisy data |
|
||||
| Spiking Neural | Attention Design | STDP timing rules, lateral inhibition patterns | Biological attention and AI attention share structural principles |
|
||||
| Graph Transformer | Code Synthesis | Dependency ordering, proof-gated mutation logic | Code compilation and graph mutation both require valid ordering |
|
||||
| Scientific OCR | Planning | Equation structure, logical step decomposition | Mathematical proofs and multi-step plans share sequential reasoning |
|
||||
|
||||
## Feature Flags
|
||||
|
||||
| Flag | Default | What It Enables |
|
||||
|------|---------|-----------------|
|
||||
| `rvf` | No | RVF cognitive container integration — serialize engines to `.rvf` format |
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
ruvector-domain-expansion = { version = "0.1", features = ["rvf"] }
|
||||
```
|
||||
|
||||
## API Overview
|
||||
|
||||
### Core Types
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `DomainExpansionEngine` | Main orchestrator — manages domains, transfer, population search |
|
||||
| `Domain` (trait) | Implement to add custom domains — generate tasks, evaluate, embed |
|
||||
| `DomainId` | Unique identifier for a domain |
|
||||
| `Task` | A problem instance with difficulty, constraints, and spec |
|
||||
| `Solution` | A candidate answer with content and structured data |
|
||||
| `Evaluation` | Score (0.0–1.0) with correctness, efficiency, and elegance breakdown |
|
||||
|
||||
### Transfer & Strategy
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `MetaThompsonEngine` | Thompson Sampling with Beta priors across context buckets |
|
||||
| `TransferPrior` | Compact posterior summary extracted from a trained domain |
|
||||
| `TransferVerification` | Result of verifying a transfer — promotable only if both domains benefit |
|
||||
| `PolicyKernel` | A strategy configuration with tunable knobs |
|
||||
| `PopulationSearch` | Evolutionary search across policy kernel variants |
|
||||
|
||||
### Meta-Learning
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `MetaLearningEngine` | Regret tracking, plateau detection, Pareto front, curiosity bonuses |
|
||||
| `CostCurve` | Convergence trajectory per domain |
|
||||
| `AccelerationScoreboard` | Measures how much faster transfer makes learning |
|
||||
| `ParetoFront` | Non-dominated set of kernels across accuracy/cost/robustness |
|
||||
|
||||
## Underlying Infrastructure
|
||||
|
||||
The domain expansion engine is built on top of these RuVector primitives:
|
||||
|
||||
| Layer | Crate | Role in Domain Expansion |
|
||||
|-------|-------|--------------------------|
|
||||
| **Retrieval** | [ruvector-gnn](../ruvector-gnn/) | GraphSAGE finds similar contexts across domains without retraining |
|
||||
| **Adaptation** | [sona](../sona/) | MicroLoRA applies arm updates in <1ms; EWC++ prevents forgetting |
|
||||
| **Verification** | [ruvector-coherence](../ruvector-coherence/) | Measures whether transfer preserved semantic quality (95% CI) |
|
||||
| **Attention** | [ruvector-attn-mincut](../ruvector-attn-mincut/) | Min-cut prunes irrelevant domain connections before transfer |
|
||||
| **Computation** | [ruvector-solver](../ruvector-solver/) | Forward Push PPR finds localized relevance across domain knowledge graphs |
|
||||
| **Graph** | [ruvector-graph-transformer](../ruvector-graph-transformer/) | Proof-gated mutations ensure only verified knowledge transfers |
|
||||
| **Packaging** | [rvf](../rvf/) | Ship a trained engine as a single `.rvf` cognitive container |
|
||||
|
||||
## License
|
||||
|
||||
**MIT License** — see [LICENSE](../../LICENSE) for details.
|
||||
|
||||
---
|
||||
|
||||
Part of [RuVector](https://github.com/ruvnet/ruvector) — the self-learning vector database.
|
||||
363
vendor/ruvector/crates/ruvector-domain-expansion/benches/domain_expansion_bench.rs
vendored
Normal file
363
vendor/ruvector/crates/ruvector-domain-expansion/benches/domain_expansion_bench.rs
vendored
Normal file
@@ -0,0 +1,363 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use ruvector_domain_expansion::{
|
||||
AccelerationScoreboard, ArmId, ContextBucket, ConvergenceThresholds, CostCurve, CostCurvePoint,
|
||||
CuriosityBonus, DecayingBeta, DomainExpansionEngine, DomainId, MetaLearningEngine,
|
||||
MetaThompsonEngine, ParetoFront, ParetoPoint, PlateauDetector, PolicyKnobs, PopulationSearch,
|
||||
RegretTracker, Solution, TransferPrior,
|
||||
};
|
||||
|
||||
fn bench_task_generation(c: &mut Criterion) {
|
||||
let engine = DomainExpansionEngine::new();
|
||||
let domains = engine.domain_ids();
|
||||
|
||||
let mut group = c.benchmark_group("task_generation");
|
||||
|
||||
for domain_id in &domains {
|
||||
group.bench_function(format!("{}", domain_id), |b| {
|
||||
b.iter(|| engine.generate_tasks(black_box(domain_id), black_box(10), black_box(0.5)))
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_evaluation(c: &mut Criterion) {
|
||||
let engine = DomainExpansionEngine::new();
|
||||
let rust_id = DomainId("rust_synthesis".into());
|
||||
let tasks = engine.generate_tasks(&rust_id, 10, 0.5);
|
||||
|
||||
let solution = Solution {
|
||||
task_id: tasks[0].id.clone(),
|
||||
content:
|
||||
"fn sum_positives(values: &[i64]) -> i64 { values.iter().filter(|&&x| x > 0).sum() }"
|
||||
.into(),
|
||||
data: serde_json::Value::Null,
|
||||
};
|
||||
|
||||
c.bench_function("evaluate_rust_solution", |b| {
|
||||
b.iter(|| {
|
||||
let mut eng = DomainExpansionEngine::new();
|
||||
eng.evaluate_and_record(
|
||||
black_box(&rust_id),
|
||||
black_box(&tasks[0]),
|
||||
black_box(&solution),
|
||||
ContextBucket {
|
||||
difficulty_tier: "medium".into(),
|
||||
category: "transform".into(),
|
||||
},
|
||||
ArmId("greedy".into()),
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_embedding(c: &mut Criterion) {
|
||||
let engine = DomainExpansionEngine::new();
|
||||
let rust_id = DomainId("rust_synthesis".into());
|
||||
|
||||
let solution = Solution {
|
||||
task_id: "bench".into(),
|
||||
content: "fn foo() { for i in 0..10 { if i > 5 { let x = i.max(3); } } }".into(),
|
||||
data: serde_json::Value::Null,
|
||||
};
|
||||
|
||||
c.bench_function("embed_solution", |b| {
|
||||
b.iter(|| engine.embed(black_box(&rust_id), black_box(&solution)))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_thompson_sampling(c: &mut Criterion) {
|
||||
let mut engine = MetaThompsonEngine::new(vec![
|
||||
"greedy".into(),
|
||||
"exploratory".into(),
|
||||
"conservative".into(),
|
||||
"speculative".into(),
|
||||
]);
|
||||
|
||||
let domain = DomainId("bench".into());
|
||||
engine.init_domain_uniform(domain.clone());
|
||||
|
||||
let bucket = ContextBucket {
|
||||
difficulty_tier: "medium".into(),
|
||||
category: "algorithm".into(),
|
||||
};
|
||||
|
||||
// Pre-populate with data
|
||||
for i in 0..100 {
|
||||
let arm = ArmId(format!(
|
||||
"{}",
|
||||
["greedy", "exploratory", "conservative", "speculative"][i % 4]
|
||||
));
|
||||
let reward = if i % 4 == 0 { 0.9 } else { 0.4 };
|
||||
engine.record_outcome(&domain, bucket.clone(), arm, reward, 1.0);
|
||||
}
|
||||
|
||||
c.bench_function("thompson_select_arm", |b| {
|
||||
b.iter(|| {
|
||||
let mut rng = rand::thread_rng();
|
||||
engine.select_arm(black_box(&domain), black_box(&bucket), &mut rng)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_population_evolve(c: &mut Criterion) {
|
||||
let mut search = PopulationSearch::new(16);
|
||||
|
||||
// Pre-populate fitness
|
||||
for i in 0..16 {
|
||||
if let Some(kernel) = search.kernel_mut(i) {
|
||||
kernel.record_score(DomainId("bench".into()), i as f32 / 16.0, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
c.bench_function("population_evolve_16", |b| {
|
||||
b.iter(|| {
|
||||
let mut s = search.clone();
|
||||
s.evolve();
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_knobs_mutate(c: &mut Criterion) {
|
||||
let knobs = PolicyKnobs::default_knobs();
|
||||
c.bench_function("knobs_mutate", |b| {
|
||||
b.iter(|| {
|
||||
let mut rng = rand::thread_rng();
|
||||
black_box(knobs.mutate(&mut rng, 0.3))
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_cost_curve_auc(c: &mut Criterion) {
|
||||
let mut curve = CostCurve::new(DomainId("bench".into()), ConvergenceThresholds::default());
|
||||
for i in 0..1000 {
|
||||
curve.record(CostCurvePoint {
|
||||
cycle: i,
|
||||
accuracy: (i as f32 / 1000.0).min(1.0),
|
||||
cost_per_solve: 1.0 / (i as f32 + 1.0),
|
||||
robustness: (i as f32 / 1000.0).min(1.0),
|
||||
policy_violations: 0,
|
||||
timestamp: i as f64,
|
||||
});
|
||||
}
|
||||
|
||||
c.bench_function("cost_curve_auc_1000pts", |b| {
|
||||
b.iter(|| black_box(curve.auc_accuracy()))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_transfer_prior_extract(c: &mut Criterion) {
|
||||
let domain = DomainId("bench".into());
|
||||
let mut prior = TransferPrior::uniform(domain);
|
||||
|
||||
// Populate with 100 buckets x 4 arms
|
||||
for b in 0..100 {
|
||||
for a in 0..4 {
|
||||
let bucket = ContextBucket {
|
||||
difficulty_tier: format!("tier_{}", b % 3),
|
||||
category: format!("cat_{}", b),
|
||||
};
|
||||
let arm = ArmId(format!("arm_{}", a));
|
||||
for _ in 0..20 {
|
||||
prior.update_posterior(bucket.clone(), arm.clone(), 0.7);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.bench_function("transfer_prior_extract_100buckets", |b| {
|
||||
b.iter(|| black_box(prior.extract_summary()))
|
||||
});
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════
|
||||
// Meta-Learning Benchmarks
|
||||
// ═══════════════════════════════════════════════════════════════════
|
||||
|
||||
fn bench_regret_tracker(c: &mut Criterion) {
|
||||
let bucket = ContextBucket {
|
||||
difficulty_tier: "medium".into(),
|
||||
category: "algo".into(),
|
||||
};
|
||||
let arms: Vec<ArmId> = (0..4).map(|i| ArmId(format!("arm_{}", i))).collect();
|
||||
|
||||
let mut group = c.benchmark_group("meta_learning");
|
||||
|
||||
group.bench_function("regret_record_1k", |b| {
|
||||
b.iter(|| {
|
||||
let mut tracker = RegretTracker::new(50);
|
||||
for i in 0..1000 {
|
||||
let arm = &arms[i % 4];
|
||||
let reward = if i % 4 == 0 { 0.9 } else { 0.4 };
|
||||
tracker.record(black_box(&bucket), black_box(arm), black_box(reward));
|
||||
}
|
||||
black_box(tracker.average_regret())
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function("regret_summary", |b| {
|
||||
let mut tracker = RegretTracker::new(50);
|
||||
for i in 0..1000 {
|
||||
let arm = &arms[i % 4];
|
||||
tracker.record(&bucket, arm, if i % 4 == 0 { 0.9 } else { 0.4 });
|
||||
}
|
||||
b.iter(|| black_box(tracker.summary()))
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_decaying_beta(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("decaying_beta");
|
||||
|
||||
group.bench_function("update_1k", |b| {
|
||||
b.iter(|| {
|
||||
let mut db = DecayingBeta::new(0.995);
|
||||
for i in 0..1000 {
|
||||
let reward = if i % 3 == 0 { 0.9 } else { 0.4 };
|
||||
db.update(black_box(reward));
|
||||
}
|
||||
black_box(db.mean())
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function("update_vs_standard", |b| {
|
||||
b.iter(|| {
|
||||
// Compare DecayingBeta vs standard BetaParams
|
||||
let mut db = DecayingBeta::new(0.995);
|
||||
let mut std_beta = ruvector_domain_expansion::BetaParams::uniform();
|
||||
for i in 0..500 {
|
||||
let reward = if i % 3 == 0 { 0.9 } else { 0.4 };
|
||||
db.update(reward);
|
||||
std_beta.update(reward);
|
||||
}
|
||||
black_box((db.mean(), std_beta.mean()))
|
||||
})
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_plateau_detector(c: &mut Criterion) {
|
||||
let points: Vec<CostCurvePoint> = (0..100)
|
||||
.map(|i| CostCurvePoint {
|
||||
cycle: i,
|
||||
accuracy: 0.80 + (i as f32 * 0.001),
|
||||
cost_per_solve: 0.1 / (i as f32 + 1.0),
|
||||
robustness: 0.8,
|
||||
policy_violations: 0,
|
||||
timestamp: i as f64,
|
||||
})
|
||||
.collect();
|
||||
|
||||
c.bench_function("plateau_check_100pts", |b| {
|
||||
b.iter(|| {
|
||||
let mut detector = PlateauDetector::new(10, 0.005);
|
||||
black_box(detector.check(black_box(&points)))
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_pareto_front(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("pareto_front");
|
||||
|
||||
group.bench_function("insert_100_points", |b| {
|
||||
b.iter(|| {
|
||||
let mut front = ParetoFront::new();
|
||||
for i in 0..100 {
|
||||
let acc = (i as f32) / 100.0;
|
||||
let cost = -((100 - i) as f32) / 100.0;
|
||||
let rob = ((i * 7 + 13) % 100) as f32 / 100.0;
|
||||
front.insert(ParetoPoint {
|
||||
kernel_id: format!("k{}", i),
|
||||
objectives: vec![acc, cost, rob],
|
||||
generation: 0,
|
||||
});
|
||||
}
|
||||
black_box(front.len())
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function("hypervolume_2d", |b| {
|
||||
let mut front = ParetoFront::new();
|
||||
for i in 0..20 {
|
||||
let x = (i as f32 + 1.0) / 21.0;
|
||||
front.insert(ParetoPoint {
|
||||
kernel_id: format!("k{}", i),
|
||||
objectives: vec![x, 1.0 - x],
|
||||
generation: 0,
|
||||
});
|
||||
}
|
||||
b.iter(|| black_box(front.hypervolume(&[0.0, 0.0])))
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_curiosity_bonus(c: &mut Criterion) {
|
||||
let arms: Vec<ArmId> = (0..4).map(|i| ArmId(format!("arm_{}", i))).collect();
|
||||
let buckets: Vec<ContextBucket> = (0..18)
|
||||
.map(|i| ContextBucket {
|
||||
difficulty_tier: ["easy", "medium", "hard"][i / 6].into(),
|
||||
category: format!("cat_{}", i % 6),
|
||||
})
|
||||
.collect();
|
||||
|
||||
c.bench_function("curiosity_bonus_18buckets", |b| {
|
||||
let mut curiosity = CuriosityBonus::new(1.41);
|
||||
for _ in 0..500 {
|
||||
for bucket in &buckets {
|
||||
for arm in &arms {
|
||||
curiosity.record_visit(bucket, arm);
|
||||
}
|
||||
}
|
||||
}
|
||||
b.iter(|| {
|
||||
let mut total = 0.0f32;
|
||||
for bucket in &buckets {
|
||||
for arm in &arms {
|
||||
total += curiosity.bonus(black_box(bucket), black_box(arm));
|
||||
}
|
||||
}
|
||||
black_box(total)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_meta_engine_full_cycle(c: &mut Criterion) {
|
||||
c.bench_function("meta_engine_100_decisions", |b| {
|
||||
b.iter(|| {
|
||||
let mut engine = MetaLearningEngine::new();
|
||||
let bucket = ContextBucket {
|
||||
difficulty_tier: "medium".into(),
|
||||
category: "algo".into(),
|
||||
};
|
||||
let arm = ArmId("greedy".into());
|
||||
|
||||
for i in 0..100 {
|
||||
let reward = if i % 3 == 0 { 0.9 } else { 0.5 };
|
||||
engine.record_decision(&bucket, &arm, reward);
|
||||
}
|
||||
|
||||
engine.record_kernel("k1", 0.9, 0.2, 0.8, 1);
|
||||
black_box(engine.health_check())
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_task_generation,
|
||||
bench_evaluation,
|
||||
bench_embedding,
|
||||
bench_thompson_sampling,
|
||||
bench_population_evolve,
|
||||
bench_knobs_mutate,
|
||||
bench_cost_curve_auc,
|
||||
bench_transfer_prior_extract,
|
||||
bench_regret_tracker,
|
||||
bench_decaying_beta,
|
||||
bench_plateau_detector,
|
||||
bench_pareto_front,
|
||||
bench_curiosity_bonus,
|
||||
bench_meta_engine_full_cycle,
|
||||
);
|
||||
criterion_main!(benches);
|
||||
241
vendor/ruvector/crates/ruvector-domain-expansion/docs/README.md
vendored
Normal file
241
vendor/ruvector/crates/ruvector-domain-expansion/docs/README.md
vendored
Normal file
@@ -0,0 +1,241 @@
|
||||
# ruvector-domain-expansion
|
||||
|
||||
Cross-domain transfer learning engine for general problem-solving capability.
|
||||
|
||||
## Core Insight
|
||||
|
||||
> True IQ growth appears when a kernel trained on Domain 1 improves Domain 2 faster than Domain 2 alone. That is generalization.
|
||||
|
||||
If cost curves compress faster in each new domain, you are increasing general problem-solving capability.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two-Layer Learning
|
||||
|
||||
```
|
||||
Policy Learning Layer (Meta Thompson Sampling)
|
||||
|
|
||||
| TransferPrior: compact Beta posteriors per bucket/arm
|
||||
| NOT raw trajectories. Ship priors, not memories.
|
||||
|
|
||||
v
|
||||
Operator Layer (Domain Kernels)
|
||||
|
|
||||
| Rust Synthesis | Planning | Tool Orchestration
|
||||
| Generate tasks, evaluate solutions, produce embeddings
|
||||
|
|
||||
v
|
||||
Shared Embedding Space (64-dim)
|
||||
Cross-domain similarity via cosine distance
|
||||
```
|
||||
|
||||
### Domains
|
||||
|
||||
| Domain | Description | Task Types |
|
||||
|--------|-------------|------------|
|
||||
| **Rust Program Synthesis** | Synthesize Rust functions from specs | Transform, DataStructure, Algorithm, TypeLevel, Concurrency |
|
||||
| **Structured Planning** | Multi-step plans with constraints | ResourceAllocation, DependencyScheduling, StateSpaceSearch, ConstraintSatisfaction |
|
||||
| **Tool Orchestration** | Coordinate multiple tools/agents | PipelineConstruction, ErrorRecovery, ParallelCoordination, ResourceNegotiation |
|
||||
|
||||
### Transfer Protocol
|
||||
|
||||
1. Train on Domain 1, extract `TransferPrior` (posterior summaries)
|
||||
2. Initialize Domain 2 with dampened priors from Domain 1
|
||||
3. Measure acceleration: cycles to convergence with vs without transfer
|
||||
4. **Generalization rule**: A delta is promotable only if it improves Domain 2 without regressing Domain 1
|
||||
|
||||
### Population-Based Policy Search
|
||||
|
||||
Run a population of `PolicyKernel` variants in parallel. Each variant tunes knobs:
|
||||
- Skip mode policy
|
||||
- Prepass mode
|
||||
- Speculation trigger thresholds
|
||||
- Budget allocation
|
||||
|
||||
Selection: keep top performers on holdouts, mutate knobs, repeat. Only merge deltas that pass replay-verify.
|
||||
|
||||
### Speculative Dual-Path
|
||||
|
||||
When posterior variance is high (top two arms within delta), run both strategies with bounded budgets. Pick the first correct, log the loser as a counterexample.
|
||||
|
||||
## Usage
|
||||
|
||||
### Rust
|
||||
|
||||
```rust
|
||||
use ruvector_domain_expansion::{
|
||||
DomainExpansionEngine, DomainId, ArmId, ContextBucket,
|
||||
};
|
||||
|
||||
// Create engine with 3 core domains
|
||||
let mut engine = DomainExpansionEngine::new();
|
||||
|
||||
// Generate tasks
|
||||
let tasks = engine.generate_tasks(
|
||||
&DomainId("rust_synthesis".into()),
|
||||
10, // count
|
||||
0.5, // difficulty
|
||||
);
|
||||
|
||||
// Select arm via Thompson Sampling
|
||||
let bucket = ContextBucket {
|
||||
difficulty_tier: "medium".into(),
|
||||
category: "algorithm".into(),
|
||||
};
|
||||
let arm = engine.select_arm(
|
||||
&DomainId("rust_synthesis".into()),
|
||||
&bucket,
|
||||
).unwrap();
|
||||
|
||||
// Evaluate and record
|
||||
let eval = engine.evaluate_and_record(
|
||||
&DomainId("rust_synthesis".into()),
|
||||
&tasks[0],
|
||||
&solution,
|
||||
bucket,
|
||||
arm,
|
||||
);
|
||||
|
||||
// Transfer learning
|
||||
engine.initiate_transfer(
|
||||
&DomainId("rust_synthesis".into()),
|
||||
&DomainId("structured_planning".into()),
|
||||
);
|
||||
|
||||
// Verify generalization
|
||||
let v = engine.verify_transfer(
|
||||
&DomainId("rust_synthesis".into()),
|
||||
&DomainId("structured_planning".into()),
|
||||
0.85, 0.84, // source before/after
|
||||
0.3, 0.7, // target before/after
|
||||
100, 40, // baseline/transfer cycles
|
||||
);
|
||||
assert!(v.promotable); // improved target without regressing source
|
||||
assert!(v.acceleration_factor > 1.0); // 2.5x faster convergence
|
||||
```
|
||||
|
||||
### WASM (JavaScript)
|
||||
|
||||
```javascript
|
||||
import { WasmDomainExpansionEngine } from 'ruvector-domain-expansion-wasm';
|
||||
|
||||
const engine = new WasmDomainExpansionEngine();
|
||||
|
||||
// List domains
|
||||
console.log(engine.domainIds());
|
||||
// ["rust_synthesis", "structured_planning", "tool_orchestration"]
|
||||
|
||||
// Generate tasks
|
||||
const tasks = engine.generateTasks("rust_synthesis", 10, 0.5);
|
||||
|
||||
// Select strategy via Thompson Sampling
|
||||
const arm = engine.selectArm("rust_synthesis", "medium", "algorithm");
|
||||
|
||||
// Check if dual-path speculation needed
|
||||
if (engine.shouldSpeculate("rust_synthesis", "medium", "algorithm")) {
|
||||
// Run both strategies, pick winner
|
||||
}
|
||||
|
||||
// Transfer priors between domains
|
||||
engine.initiateTransfer("rust_synthesis", "structured_planning");
|
||||
|
||||
// Evolve policy kernels
|
||||
engine.generateHoldouts(10, 0.5);
|
||||
engine.evaluatePopulation();
|
||||
engine.evolvePopulation();
|
||||
console.log(engine.populationStats());
|
||||
|
||||
// Acceleration scoreboard
|
||||
console.log(engine.scoreboardSummary());
|
||||
```
|
||||
|
||||
## Acceptance Test
|
||||
|
||||
Domain 2 must converge faster than Domain 1. Measure cycles to reach:
|
||||
- 95% accuracy
|
||||
- Target cost per solve
|
||||
- Target robustness
|
||||
- Zero policy violations
|
||||
|
||||
```rust
|
||||
use ruvector_domain_expansion::{AccelerationScoreboard, CostCurve, DomainId};
|
||||
|
||||
let mut board = AccelerationScoreboard::new();
|
||||
|
||||
// Add baseline and transfer curves
|
||||
board.add_curve(baseline_curve);
|
||||
board.add_curve(transfer_curve);
|
||||
|
||||
// Compute acceleration
|
||||
let entry = board.compute_acceleration(
|
||||
&DomainId("baseline".into()),
|
||||
&DomainId("transfer".into()),
|
||||
).unwrap();
|
||||
|
||||
assert!(entry.acceleration > 1.0); // transfer helped
|
||||
assert!(entry.generalization_passed);
|
||||
|
||||
// Check progressive improvement across multiple domains
|
||||
assert!(board.progressive_acceleration());
|
||||
```
|
||||
|
||||
## RVF Packaging
|
||||
|
||||
Transfer artifacts are designed for RVF segment packaging:
|
||||
|
||||
| Segment | Content | Purpose |
|
||||
|---------|---------|---------|
|
||||
| `TransferPrior` | Beta posteriors per bucket/arm | Seeds new domain initialization |
|
||||
| `PolicyKernel` | Knob configuration + fitness history | Best policy for a domain |
|
||||
| `CostCurve` | Convergence data points | Acceleration measurement |
|
||||
| `WitnessChain` | Hash of derivation + holdout results | Audit trail |
|
||||
| `Counterexamples` | Failed solutions per context | Negative signal for future decisions |
|
||||
|
||||
## Benchmarks
|
||||
|
||||
```bash
|
||||
cargo bench -p ruvector-domain-expansion
|
||||
```
|
||||
|
||||
Benchmarks cover:
|
||||
- Task generation (per domain)
|
||||
- Solution evaluation
|
||||
- Embedding extraction
|
||||
- Thompson Sampling arm selection
|
||||
- Population evolution
|
||||
- PolicyKnobs mutation
|
||||
- Cost curve AUC computation
|
||||
- TransferPrior extraction
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
src/
  lib.rs                 -- Orchestrator: DomainExpansionEngine
  domain.rs              -- Core Domain trait, Task, Solution, Evaluation, Embedding
  error.rs               -- DomainError type for expansion operations
  meta_learning.rs       -- Regret tracking, decaying priors, plateau detection, Pareto, curiosity
  rust_synthesis.rs      -- Rust program synthesis domain
  planning.rs            -- Structured planning tasks domain
  tool_orchestration.rs  -- Tool orchestration problems domain
  transfer.rs            -- Meta Thompson Sampling, TransferPrior, verification
  policy_kernel.rs       -- PolicyKernel, PopulationSearch, PolicyKnobs
  cost_curve.rs          -- CostCurve, AccelerationScoreboard
|
||||
```
|
||||
|
||||
## Tests
|
||||
|
||||
49 unit tests covering all modules:
|
||||
|
||||
```bash
|
||||
cargo test -p ruvector-domain-expansion
|
||||
```
|
||||
|
||||
| Module | Tests |
|
||||
|--------|-------|
|
||||
| `domain` | 5 tests: types, embedding cosine similarity, evaluation |
|
||||
| `rust_synthesis` | 5 tests: generation, evaluation, embedding, difficulty |
|
||||
| `planning` | 5 tests: generation, reference, evaluation, embedding, scaling |
|
||||
| `tool_orchestration` | 5 tests: generation, reference, evaluation, embedding, errors |
|
||||
| `transfer` | 6 tests: Beta params, Thompson engine, prior extraction, verification |
|
||||
| `policy_kernel` | 5 tests: knobs, fitness, evolution, stats, crossover |
|
||||
| `cost_curve` | 5 tests: convergence, compression, AUC, acceleration, scoreboard |
|
||||
| `lib` (integration) | 8 tests: engine, tasks, arms, evaluation, embedding, transfer, population |
|
||||
482
vendor/ruvector/crates/ruvector-domain-expansion/src/cost_curve.rs
vendored
Normal file
482
vendor/ruvector/crates/ruvector-domain-expansion/src/cost_curve.rs
vendored
Normal file
@@ -0,0 +1,482 @@
|
||||
//! Cost Curve Compression Tracker and Acceleration Scoreboard
|
||||
//!
|
||||
//! Measures whether cost curves compress faster in each new domain.
|
||||
//! If they do, you are increasing general problem-solving capability.
|
||||
//!
|
||||
//! ## Acceptance Test
|
||||
//!
|
||||
//! Domain 2 must converge faster than Domain 1.
|
||||
//! Measure cycles to reach:
|
||||
//! - 95% accuracy
|
||||
//! - Target cost per solve
|
||||
//! - Target robustness
|
||||
//! - Zero policy violations
|
||||
|
||||
use crate::domain::DomainId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A single data point on the cost curve.
///
/// One point is recorded per training cycle; the `cycles_to_*` queries on
/// `CostCurve` scan points in insertion order, so callers are expected to
/// record them in nondecreasing `cycle` order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostCurvePoint {
    /// Cycle number (training iteration).
    pub cycle: u64,
    /// Current accuracy [0.0, 1.0].
    pub accuracy: f32,
    /// Cost per solve at this point (lower is better).
    pub cost_per_solve: f32,
    /// Robustness score [0.0, 1.0].
    pub robustness: f32,
    /// Number of policy violations in this cycle.
    pub policy_violations: u32,
    /// Wall-clock timestamp (seconds since epoch).
    pub timestamp: f64,
}
|
||||
|
||||
/// Convergence thresholds for the acceptance test.
///
/// A curve counts as converged only when a single recorded point satisfies
/// ALL four criteria at once (see `CostCurve::has_converged`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceThresholds {
    /// Target accuracy (default: 0.95).
    pub target_accuracy: f32,
    /// Target cost per solve (default: 0.01).
    pub target_cost: f32,
    /// Target robustness (default: 0.90).
    pub target_robustness: f32,
    /// Maximum allowed policy violations (default: 0).
    pub max_violations: u32,
}

impl Default for ConvergenceThresholds {
    // Defaults mirror the crate-level acceptance test: 95% accuracy,
    // 0.01 cost per solve, 0.90 robustness, zero policy violations.
    fn default() -> Self {
        Self {
            target_accuracy: 0.95,
            target_cost: 0.01,
            target_robustness: 0.90,
            max_violations: 0,
        }
    }
}
|
||||
|
||||
/// Cost curve for a single domain, tracking convergence over cycles.
///
/// NOTE(review): `points` is assumed to be in nondecreasing `cycle` order —
/// `auc_accuracy` subtracts adjacent cycle numbers; confirm recorders uphold
/// this.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostCurve {
    /// Domain this curve belongs to.
    pub domain_id: DomainId,
    /// Whether this was trained with transfer priors.
    pub used_transfer: bool,
    /// Source domain for transfer (if any); `None` for baseline runs.
    pub transfer_source: Option<DomainId>,
    /// Ordered data points, one per recorded training cycle.
    pub points: Vec<CostCurvePoint>,
    /// Convergence thresholds used by all `cycles_to_*` queries.
    pub thresholds: ConvergenceThresholds,
}
|
||||
|
||||
impl CostCurve {
|
||||
/// Create a new cost curve for a domain.
|
||||
pub fn new(domain_id: DomainId, thresholds: ConvergenceThresholds) -> Self {
|
||||
Self {
|
||||
domain_id,
|
||||
used_transfer: false,
|
||||
transfer_source: None,
|
||||
points: Vec::new(),
|
||||
thresholds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a cost curve with transfer metadata.
|
||||
pub fn with_transfer(
|
||||
domain_id: DomainId,
|
||||
source: DomainId,
|
||||
thresholds: ConvergenceThresholds,
|
||||
) -> Self {
|
||||
Self {
|
||||
domain_id,
|
||||
used_transfer: true,
|
||||
transfer_source: Some(source),
|
||||
points: Vec::new(),
|
||||
thresholds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a new data point.
|
||||
pub fn record(&mut self, point: CostCurvePoint) {
|
||||
self.points.push(point);
|
||||
}
|
||||
|
||||
/// Check if all convergence criteria are met at the latest point.
|
||||
pub fn has_converged(&self) -> bool {
|
||||
self.points.last().map_or(false, |p| {
|
||||
p.accuracy >= self.thresholds.target_accuracy
|
||||
&& p.cost_per_solve <= self.thresholds.target_cost
|
||||
&& p.robustness >= self.thresholds.target_robustness
|
||||
&& p.policy_violations <= self.thresholds.max_violations
|
||||
})
|
||||
}
|
||||
|
||||
/// Cycles to reach target accuracy (None if not yet reached).
|
||||
pub fn cycles_to_accuracy(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| p.accuracy >= self.thresholds.target_accuracy)
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Cycles to reach target cost (None if not yet reached).
|
||||
pub fn cycles_to_cost(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| p.cost_per_solve <= self.thresholds.target_cost)
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Cycles to reach target robustness.
|
||||
pub fn cycles_to_robustness(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| p.robustness >= self.thresholds.target_robustness)
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Cycles to full convergence (all criteria met).
|
||||
pub fn cycles_to_convergence(&self) -> Option<u64> {
|
||||
self.points
|
||||
.iter()
|
||||
.find(|p| {
|
||||
p.accuracy >= self.thresholds.target_accuracy
|
||||
&& p.cost_per_solve <= self.thresholds.target_cost
|
||||
&& p.robustness >= self.thresholds.target_robustness
|
||||
&& p.policy_violations <= self.thresholds.max_violations
|
||||
})
|
||||
.map(|p| p.cycle)
|
||||
}
|
||||
|
||||
/// Area under the accuracy curve (higher = faster learning).
|
||||
pub fn auc_accuracy(&self) -> f32 {
|
||||
if self.points.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
self.points
|
||||
.windows(2)
|
||||
.map(|w| {
|
||||
let dx = (w[1].cycle - w[0].cycle) as f32;
|
||||
let avg_y = (w[0].accuracy + w[1].accuracy) / 2.0;
|
||||
dx * avg_y
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Compression ratio: how fast the cost curve drops.
|
||||
/// Computed as initial_cost / final_cost (higher = more compression).
|
||||
pub fn compression_ratio(&self) -> f32 {
|
||||
if self.points.len() < 2 {
|
||||
return 1.0;
|
||||
}
|
||||
let initial = self.points.first().unwrap().cost_per_solve;
|
||||
let final_cost = self.points.last().unwrap().cost_per_solve;
|
||||
if final_cost > 1e-10 {
|
||||
initial / final_cost
|
||||
} else {
|
||||
initial / 1e-10
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Acceleration scoreboard comparing domain learning curves.
/// Shows acceleration, not just improvement.
///
/// Curves are keyed by domain id; each call to `compute_acceleration`
/// appends an entry to `accelerations`, preserving evaluation order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelerationScoreboard {
    /// Per-domain cost curves, keyed by `CostCurve::domain_id`.
    pub curves: HashMap<DomainId, CostCurve>,
    /// Pairwise acceleration factors, in the order they were computed.
    pub accelerations: Vec<AccelerationEntry>,
}
|
||||
|
||||
/// An entry showing how transfer from source to target affected convergence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelerationEntry {
    /// Source domain (the curve's `transfer_source`, or "none").
    pub source: DomainId,
    /// Target domain.
    pub target: DomainId,
    /// Cycles to convergence without transfer (baseline); None if the
    /// baseline never converged.
    pub baseline_cycles: Option<u64>,
    /// Cycles to convergence with transfer; None if never converged.
    pub transfer_cycles: Option<u64>,
    /// Acceleration factor: baseline / transfer (>1 = transfer helped).
    /// Defaults to 1.0 when either side never converged.
    pub acceleration: f32,
    /// AUC of the baseline accuracy curve (higher = better learning curve).
    pub auc_baseline: f32,
    /// AUC of the transfer accuracy curve.
    pub auc_transfer: f32,
    /// Compression ratio of the baseline cost curve.
    pub compression_baseline: f32,
    /// Compression ratio of the transfer cost curve.
    pub compression_transfer: f32,
    /// Whether generalization test passed (acceleration > 1.0).
    pub generalization_passed: bool,
}
|
||||
|
||||
impl AccelerationScoreboard {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
curves: HashMap::new(),
|
||||
accelerations: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a cost curve for a domain.
|
||||
pub fn add_curve(&mut self, curve: CostCurve) {
|
||||
self.curves.insert(curve.domain_id.clone(), curve);
|
||||
}
|
||||
|
||||
/// Compute acceleration between a baseline (no transfer) and transfer curve.
|
||||
pub fn compute_acceleration(
|
||||
&mut self,
|
||||
baseline_domain: &DomainId,
|
||||
transfer_domain: &DomainId,
|
||||
) -> Option<AccelerationEntry> {
|
||||
let baseline = self.curves.get(baseline_domain)?;
|
||||
let transfer = self.curves.get(transfer_domain)?;
|
||||
|
||||
let baseline_cycles = baseline.cycles_to_convergence();
|
||||
let transfer_cycles = transfer.cycles_to_convergence();
|
||||
|
||||
let acceleration = match (baseline_cycles, transfer_cycles) {
|
||||
(Some(b), Some(t)) if t > 0 => b as f32 / t as f32,
|
||||
_ => 1.0, // No measurable acceleration
|
||||
};
|
||||
|
||||
let entry = AccelerationEntry {
|
||||
source: transfer
|
||||
.transfer_source
|
||||
.clone()
|
||||
.unwrap_or_else(|| DomainId("none".into())),
|
||||
target: transfer_domain.clone(),
|
||||
baseline_cycles,
|
||||
transfer_cycles,
|
||||
acceleration,
|
||||
auc_baseline: baseline.auc_accuracy(),
|
||||
auc_transfer: transfer.auc_accuracy(),
|
||||
compression_baseline: baseline.compression_ratio(),
|
||||
compression_transfer: transfer.compression_ratio(),
|
||||
generalization_passed: acceleration > 1.0,
|
||||
};
|
||||
|
||||
self.accelerations.push(entry.clone());
|
||||
Some(entry)
|
||||
}
|
||||
|
||||
/// Check whether each successive domain converges faster (the IQ growth test).
|
||||
pub fn progressive_acceleration(&self) -> bool {
|
||||
if self.accelerations.len() < 2 {
|
||||
return true; // Not enough data to judge
|
||||
}
|
||||
|
||||
self.accelerations
|
||||
.windows(2)
|
||||
.all(|w| w[1].acceleration >= w[0].acceleration)
|
||||
}
|
||||
|
||||
/// Summary report of all domains.
|
||||
pub fn summary(&self) -> ScoreboardSummary {
|
||||
let domain_summaries: Vec<DomainSummary> = self
|
||||
.curves
|
||||
.iter()
|
||||
.map(|(id, curve)| DomainSummary {
|
||||
domain_id: id.clone(),
|
||||
total_cycles: curve.points.last().map(|p| p.cycle).unwrap_or(0),
|
||||
final_accuracy: curve.points.last().map(|p| p.accuracy).unwrap_or(0.0),
|
||||
final_cost: curve
|
||||
.points
|
||||
.last()
|
||||
.map(|p| p.cost_per_solve)
|
||||
.unwrap_or(f32::MAX),
|
||||
converged: curve.has_converged(),
|
||||
cycles_to_convergence: curve.cycles_to_convergence(),
|
||||
compression_ratio: curve.compression_ratio(),
|
||||
used_transfer: curve.used_transfer,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let overall_acceleration = if self.accelerations.is_empty() {
|
||||
1.0
|
||||
} else {
|
||||
self.accelerations
|
||||
.iter()
|
||||
.map(|a| a.acceleration)
|
||||
.sum::<f32>()
|
||||
/ self.accelerations.len() as f32
|
||||
};
|
||||
|
||||
ScoreboardSummary {
|
||||
domains: domain_summaries,
|
||||
accelerations: self.accelerations.clone(),
|
||||
overall_acceleration,
|
||||
progressive_improvement: self.progressive_acceleration(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AccelerationScoreboard {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary of a single domain's learning, derived from its cost curve.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainSummary {
    /// Domain identifier.
    pub domain_id: DomainId,
    /// Cycle number of the last recorded point (0 if the curve is empty).
    pub total_cycles: u64,
    /// Accuracy at the last recorded point (0.0 if the curve is empty).
    pub final_accuracy: f32,
    /// Cost per solve at the last recorded point (f32::MAX if empty).
    pub final_cost: f32,
    /// Whether the latest point meets all convergence thresholds.
    pub converged: bool,
    /// First cycle at which all thresholds were met, if any.
    pub cycles_to_convergence: Option<u64>,
    /// initial_cost / final_cost for the curve.
    pub compression_ratio: f32,
    /// Whether the curve was seeded with transfer priors.
    pub used_transfer: bool,
}

/// Full scoreboard summary across all registered domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreboardSummary {
    /// One summary per registered domain (unordered: HashMap iteration).
    pub domains: Vec<DomainSummary>,
    /// All acceleration entries computed so far, in computation order.
    pub accelerations: Vec<AccelerationEntry>,
    /// Mean acceleration factor (1.0 when no entries exist).
    pub overall_acceleration: f32,
    /// True if each new domain converges faster than the previous.
    pub progressive_improvement: bool,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a curve from (cycle, accuracy, cost) triples.
    /// Robustness is derived as accuracy * 0.95 and violations are zero,
    /// so accuracy >= 0.95 implies robustness >= 0.9025 (past the default
    /// 0.90 threshold).
    fn make_curve(domain: &str, transfer: bool, accuracy_steps: &[(u64, f32, f32)]) -> CostCurve {
        let mut curve = if transfer {
            CostCurve::with_transfer(
                DomainId(domain.into()),
                DomainId("source".into()),
                ConvergenceThresholds::default(),
            )
        } else {
            CostCurve::new(DomainId(domain.into()), ConvergenceThresholds::default())
        };

        for &(cycle, accuracy, cost) in accuracy_steps {
            curve.record(CostCurvePoint {
                cycle,
                accuracy,
                cost_per_solve: cost,
                robustness: accuracy * 0.95,
                policy_violations: 0,
                timestamp: cycle as f64,
            });
        }
        curve
    }

    #[test]
    fn test_cost_curve_convergence() {
        // Final point (30, 0.95, 0.008) meets all default thresholds.
        let curve = make_curve(
            "test",
            false,
            &[
                (0, 0.3, 0.1),
                (10, 0.6, 0.05),
                (20, 0.8, 0.02),
                (30, 0.95, 0.008),
            ],
        );

        assert!(curve.has_converged());
        assert_eq!(curve.cycles_to_accuracy(), Some(30));
        assert_eq!(curve.cycles_to_cost(), Some(30));
    }

    #[test]
    fn test_cost_curve_not_converged() {
        // Accuracy never reaches the 0.95 target.
        let curve = make_curve("test", false, &[(0, 0.3, 0.1), (10, 0.6, 0.05)]);

        assert!(!curve.has_converged());
        assert_eq!(curve.cycles_to_accuracy(), None);
    }

    #[test]
    fn test_compression_ratio() {
        let curve = make_curve(
            "test",
            false,
            &[(0, 0.3, 1.0), (10, 0.6, 0.5), (20, 0.9, 0.1)],
        );

        let ratio = curve.compression_ratio();
        assert!((ratio - 10.0).abs() < 1e-4); // 1.0 / 0.1 = 10x
    }

    #[test]
    fn test_acceleration_scoreboard() {
        let mut board = AccelerationScoreboard::new();

        // Domain 1: baseline (slow convergence, 100 cycles)
        let baseline = make_curve(
            "d1_baseline",
            false,
            &[
                (0, 0.2, 0.1),
                (20, 0.5, 0.05),
                (50, 0.8, 0.02),
                (100, 0.95, 0.008),
            ],
        );

        // Domain 2: with transfer (fast convergence, 40 cycles)
        let transfer = make_curve(
            "d2_transfer",
            true,
            &[
                (0, 0.4, 0.08),
                (10, 0.7, 0.03),
                (20, 0.9, 0.01),
                (40, 0.96, 0.007),
            ],
        );

        board.add_curve(baseline);
        board.add_curve(transfer);

        let entry = board
            .compute_acceleration(
                &DomainId("d1_baseline".into()),
                &DomainId("d2_transfer".into()),
            )
            .unwrap();

        // 100 baseline cycles / 40 transfer cycles = 2.5x acceleration.
        assert!(entry.acceleration > 1.0, "Transfer should accelerate");
        assert_eq!(entry.baseline_cycles, Some(100));
        assert_eq!(entry.transfer_cycles, Some(40));
        assert!((entry.acceleration - 2.5).abs() < 1e-4);
        assert!(entry.generalization_passed);
    }

    #[test]
    fn test_scoreboard_summary() {
        let mut board = AccelerationScoreboard::new();
        let curve = make_curve("d1", false, &[(0, 0.5, 0.1), (50, 0.96, 0.005)]);
        board.add_curve(curve);

        let summary = board.summary();
        assert_eq!(summary.domains.len(), 1);
        assert!(summary.domains[0].converged);
    }

    #[test]
    fn test_auc_accuracy() {
        let curve = make_curve(
            "test",
            false,
            &[(0, 0.0, 1.0), (10, 0.5, 0.5), (20, 1.0, 0.1)],
        );

        let auc = curve.auc_accuracy();
        // Trapezoid: (10*(0+0.5)/2) + (10*(0.5+1.0)/2) = 2.5 + 7.5 = 10.0
        assert!((auc - 10.0).abs() < 1e-4);
    }
}
|
||||
212
vendor/ruvector/crates/ruvector-domain-expansion/src/domain.rs
vendored
Normal file
212
vendor/ruvector/crates/ruvector-domain-expansion/src/domain.rs
vendored
Normal file
@@ -0,0 +1,212 @@
|
||||
//! Core domain trait and types for cross-domain transfer learning.
|
||||
//!
|
||||
//! A domain defines a problem space with:
|
||||
//! - A task generator (produces training instances)
|
||||
//! - An evaluator (scores solutions on [0.0, 1.0])
|
||||
//! - Embedding extraction (maps solutions into a shared representation space)
|
||||
//!
|
||||
//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2
|
||||
//! faster than Domain 2 alone. That is generalization.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Unique identifier for a domain.
///
/// Newtype over `String`; derives `Hash`/`Eq` so it can serve as a map key
/// (e.g. in the scoreboard's per-domain curve map).
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct DomainId(pub String);

impl fmt::Display for DomainId {
    // Renders the bare identifier with no surrounding decoration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
|
||||
|
||||
/// A single task instance within a domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Unique task identifier.
    pub id: String,
    /// Domain this task belongs to.
    pub domain_id: DomainId,
    /// Difficulty level [0.0, 1.0].
    pub difficulty: f32,
    /// Structured task specification (domain-specific JSON; schema is
    /// defined by the generating `Domain` implementation).
    pub spec: serde_json::Value,
    /// Optional constraints the solution must satisfy; evaluators report
    /// per-constraint pass/fail in `Evaluation::constraint_results`.
    pub constraints: Vec<String>,
}
|
||||
|
||||
/// A candidate solution to a domain task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Solution {
    /// The task this solves (matches `Task::id`).
    pub task_id: String,
    /// Raw solution content (e.g., Rust source, plan steps, tool calls).
    pub content: String,
    /// Structured solution data (domain-specific JSON; schema is defined
    /// by the producing domain).
    pub data: serde_json::Value,
}
|
||||
|
||||
/// Evaluation result for a solution.
///
/// `score` is the composite of the three sub-scores (see
/// `Evaluation::composite` for the weighting).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evaluation {
    /// Overall score [0.0, 1.0] where 1.0 is perfect.
    pub score: f32,
    /// Correctness: does it produce the right answer?
    pub correctness: f32,
    /// Efficiency: resource usage relative to optimal.
    pub efficiency: f32,
    /// Elegance: structural quality, idiomatic patterns.
    pub elegance: f32,
    /// Per-constraint pass/fail results, parallel to `Task::constraints`.
    pub constraint_results: Vec<bool>,
    /// Diagnostic notes from the evaluator.
    pub notes: Vec<String>,
}
|
||||
|
||||
impl Evaluation {
|
||||
/// Create a zero-score evaluation (failure).
|
||||
pub fn zero(notes: Vec<String>) -> Self {
|
||||
Self {
|
||||
score: 0.0,
|
||||
correctness: 0.0,
|
||||
efficiency: 0.0,
|
||||
elegance: 0.0,
|
||||
constraint_results: Vec::new(),
|
||||
notes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute composite score from weighted sub-scores.
|
||||
pub fn composite(correctness: f32, efficiency: f32, elegance: f32) -> Self {
|
||||
let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
|
||||
Self {
|
||||
score: score.clamp(0.0, 1.0),
|
||||
correctness,
|
||||
efficiency,
|
||||
elegance,
|
||||
constraint_results: Vec::new(),
|
||||
notes: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding vector for cross-domain representation.
/// Solutions from different domains are projected into a shared space
/// so that transfer learning can identify structural similarities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainEmbedding {
    /// The embedding vector.
    pub vector: Vec<f32>,
    /// Which domain produced this embedding.
    pub domain_id: DomainId,
    /// Dimensionality (kept equal to `vector.len()` by the constructor).
    pub dim: usize,
}
|
||||
|
||||
impl DomainEmbedding {
|
||||
/// Create a new embedding.
|
||||
pub fn new(vector: Vec<f32>, domain_id: DomainId) -> Self {
|
||||
let dim = vector.len();
|
||||
Self {
|
||||
vector,
|
||||
domain_id,
|
||||
dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine similarity with another embedding.
|
||||
pub fn cosine_similarity(&self, other: &DomainEmbedding) -> f32 {
|
||||
assert_eq!(self.dim, other.dim, "Embedding dimensions must match");
|
||||
|
||||
let mut dot = 0.0f32;
|
||||
let mut norm_a = 0.0f32;
|
||||
let mut norm_b = 0.0f32;
|
||||
|
||||
for i in 0..self.dim {
|
||||
dot += self.vector[i] * other.vector[i];
|
||||
norm_a += self.vector[i] * self.vector[i];
|
||||
norm_b += other.vector[i] * other.vector[i];
|
||||
}
|
||||
|
||||
let denom = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
|
||||
dot / denom
|
||||
}
|
||||
}
|
||||
|
||||
/// Core trait that every domain must implement.
///
/// Domains are problem spaces: Rust program synthesis, structured planning,
/// tool orchestration, etc. Each domain knows how to generate tasks,
/// evaluate solutions, and embed solutions into a shared representation space.
pub trait Domain: Send + Sync {
    /// Unique identifier for this domain.
    fn id(&self) -> &DomainId;

    /// Human-readable name.
    fn name(&self) -> &str;

    /// Generate a batch of tasks at the given difficulty level.
    ///
    /// # Arguments
    /// * `count` - Number of tasks to generate
    /// * `difficulty` - Target difficulty [0.0, 1.0]
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task>;

    /// Evaluate a solution against its task.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation;

    /// Project a solution into the shared embedding space.
    /// This enables cross-domain transfer by finding structural similarities
    /// between solutions across different problem domains.
    ///
    /// NOTE(review): implementations are presumably expected to return a
    /// vector of length `embedding_dim()` — the trait does not enforce it;
    /// confirm against the concrete domains.
    fn embed(&self, solution: &Solution) -> DomainEmbedding;

    /// Embedding dimensionality for this domain.
    fn embedding_dim(&self) -> usize;

    /// Generate a reference (optimal or near-optimal) solution for a task.
    /// Used for computing efficiency ratios and as training signal.
    /// Returns `None` when no reference is available.
    fn reference_solution(&self, task: &Task) -> Option<Solution>;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_domain_id_display() {
        let id = DomainId("rust_synthesis".to_string());
        assert_eq!(format!("{}", id), "rust_synthesis");
    }

    #[test]
    fn test_evaluation_zero() {
        let eval = Evaluation::zero(vec!["compile error".to_string()]);
        assert_eq!(eval.score, 0.0);
        assert_eq!(eval.notes.len(), 1);
    }

    #[test]
    fn test_evaluation_composite() {
        let eval = Evaluation::composite(1.0, 0.8, 0.6);
        // 0.6*1.0 + 0.25*0.8 + 0.15*0.6 = 0.6 + 0.2 + 0.09 = 0.89
        assert!((eval.score - 0.89).abs() < 1e-4);
    }

    #[test]
    fn test_embedding_cosine_similarity() {
        let id = DomainId("test".to_string());
        // Identical vectors -> similarity 1.0.
        let a = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone());
        let b = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone());
        assert!((a.cosine_similarity(&b) - 1.0).abs() < 1e-6);

        // Orthogonal vectors -> similarity 0.0.
        let c = DomainEmbedding::new(vec![0.0, 1.0, 0.0], id);
        assert!(a.cosine_similarity(&c).abs() < 1e-6);
    }

    #[test]
    fn test_evaluation_clamp() {
        // Weights sum to 1.0, so all-1.0 sub-scores must clamp to <= 1.0.
        let eval = Evaluation::composite(1.0, 1.0, 1.0);
        assert!(eval.score <= 1.0);
    }
}
|
||||
39
vendor/ruvector/crates/ruvector-domain-expansion/src/error.rs
vendored
Normal file
39
vendor/ruvector/crates/ruvector-domain-expansion/src/error.rs
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
//! Error types for domain expansion.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur during domain expansion operations.
#[derive(Error, Debug)]
pub enum DomainError {
    /// Problem generation failed.
    #[error("problem generation failed: {0}")]
    Generation(String),

    /// Solution evaluation failed.
    #[error("evaluation failed: {0}")]
    Evaluation(String),

    /// Dimension mismatch between domains.
    #[error("dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },

    /// Domain not found in the expansion engine.
    #[error("domain not found: {0}")]
    DomainNotFound(String),

    /// Transfer failed between domains.
    ///
    /// NOTE(review): thiserror gives special meaning to a field named
    /// `source` (it becomes the value returned by `Error::source()` and
    /// must implement `std::error::Error`, which `String` does not) —
    /// verify this variant compiles as intended, or consider renaming the
    /// field (a breaking change) in a future version.
    #[error("transfer failed from {source} to {target}: {reason}")]
    TransferFailed {
        source: String,
        target: String,
        reason: String,
    },

    /// Kernel has not been trained on any domain yet.
    #[error("kernel not initialized: {0}")]
    KernelNotInitialized(String),

    /// Invalid configuration.
    #[error("invalid config: {0}")]
    InvalidConfig(String),
}
|
||||
591
vendor/ruvector/crates/ruvector-domain-expansion/src/lib.rs
vendored
Normal file
591
vendor/ruvector/crates/ruvector-domain-expansion/src/lib.rs
vendored
Normal file
@@ -0,0 +1,591 @@
|
||||
//! # Domain Expansion Engine
|
||||
//!
|
||||
//! Cross-domain transfer learning for general problem-solving capability.
|
||||
//!
|
||||
//! ## Core Insight
|
||||
//!
|
||||
//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2
|
||||
//! faster than Domain 2 alone. That is generalization.
|
||||
//!
|
||||
//! ## Two-Layer Architecture
|
||||
//!
|
||||
//! **Policy learning layer**: Meta Thompson Sampling with Beta priors across
|
||||
//! context buckets. Chooses strategies via uncertainty-aware selection.
|
||||
//! Transfer happens through compact priors — not raw trajectories.
|
||||
//!
|
||||
//! **Operator layer**: Deterministic domain kernels (Rust synthesis, planning,
|
||||
//! tool orchestration) that generate tasks, evaluate solutions, and produce
|
||||
//! embeddings into a shared representation space.
|
||||
//!
|
||||
//! ## Domains
|
||||
//!
|
||||
//! - **Rust Program Synthesis**: Generate Rust functions from specifications
|
||||
//! - **Structured Planning**: Multi-step plans with dependencies and resources
|
||||
//! - **Tool Orchestration**: Coordinate multiple tools/agents for complex goals
|
||||
//!
|
||||
//! ## Transfer Protocol
|
||||
//!
|
||||
//! 1. Train on Domain 1, extract `TransferPrior` (posterior summaries)
|
||||
//! 2. Initialize Domain 2 with dampened priors from Domain 1
|
||||
//! 3. Measure acceleration: cycles to convergence with/without transfer
|
||||
//! 4. A delta is promotable only if it improves target without regressing source
|
||||
//!
|
||||
//! ## Population-Based Policy Search
|
||||
//!
|
||||
//! Run a population of `PolicyKernel` variants in parallel.
|
||||
//! Each variant tunes knobs (skip mode, prepass, speculation thresholds).
|
||||
//! Keep top performers on holdouts, mutate, repeat.
|
||||
//!
|
||||
//! ## Acceptance Test
|
||||
//!
|
||||
//! Domain 2 must converge faster than Domain 1 to target accuracy, cost,
|
||||
//! robustness, and zero policy violations.
|
||||
|
||||
#![warn(missing_docs)]
|
||||
|
||||
pub mod cost_curve;
|
||||
pub mod domain;
|
||||
pub mod meta_learning;
|
||||
pub mod planning;
|
||||
pub mod policy_kernel;
|
||||
pub mod rust_synthesis;
|
||||
pub mod tool_orchestration;
|
||||
pub mod transfer;
|
||||
|
||||
/// RVF format integration: segment serialization, witness chains, AGI packaging.
|
||||
///
|
||||
/// Requires the `rvf` feature to be enabled.
|
||||
#[cfg(feature = "rvf")]
|
||||
pub mod rvf_bridge;
|
||||
|
||||
// Re-export core types.
|
||||
pub use cost_curve::{
|
||||
AccelerationEntry, AccelerationScoreboard, ConvergenceThresholds, CostCurve, CostCurvePoint,
|
||||
ScoreboardSummary,
|
||||
};
|
||||
pub use domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
pub use meta_learning::{
|
||||
CuriosityBonus, DecayingBeta, MetaLearningEngine, MetaLearningHealth, ParetoFront, ParetoPoint,
|
||||
PlateauAction, PlateauDetector, RegretSummary, RegretTracker,
|
||||
};
|
||||
pub use planning::PlanningDomain;
|
||||
pub use policy_kernel::{PolicyKernel, PolicyKnobs, PopulationSearch, PopulationStats};
|
||||
pub use rust_synthesis::RustSynthesisDomain;
|
||||
pub use tool_orchestration::ToolOrchestrationDomain;
|
||||
pub use transfer::{
|
||||
ArmId, BetaParams, ContextBucket, DualPathResult, MetaThompsonEngine, TransferPrior,
|
||||
TransferVerification,
|
||||
};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// The domain expansion orchestrator.
///
/// Manages multiple domains, transfer learning between them,
/// population-based policy search, and the acceleration scoreboard.
///
/// The `meta` field provides five composable learning improvements:
/// regret tracking, decaying priors, plateau detection, Pareto front
/// optimization, and curiosity-driven exploration.
pub struct DomainExpansionEngine {
    /// Registered domains, keyed by their [`DomainId`].
    domains: HashMap<DomainId, Box<dyn Domain>>,
    /// Meta Thompson Sampling engine for cross-domain transfer.
    pub thompson: MetaThompsonEngine,
    /// Population-based policy search.
    pub population: PopulationSearch,
    /// Acceleration scoreboard tracking convergence across domains.
    pub scoreboard: AccelerationScoreboard,
    /// Meta-learning engine: regret, plateau, Pareto, curiosity, decay.
    pub meta: MetaLearningEngine,
    /// Holdout tasks per domain for verification.
    /// Empty until `generate_holdouts` is called.
    holdouts: HashMap<DomainId, Vec<Task>>,
    /// Counterexample set: failed solutions (score below 0.3 in
    /// `evaluate_and_record`) that inform future decisions.
    counterexamples: HashMap<DomainId, Vec<(Task, Solution, Evaluation)>>,
}
|
||||
|
||||
impl DomainExpansionEngine {
|
||||
/// Create a new domain expansion engine with default configuration.
|
||||
///
|
||||
/// Initializes the three core domains and the transfer engine.
|
||||
pub fn new() -> Self {
|
||||
let arms = vec![
|
||||
"greedy".into(),
|
||||
"exploratory".into(),
|
||||
"conservative".into(),
|
||||
"speculative".into(),
|
||||
];
|
||||
|
||||
let mut engine = Self {
|
||||
domains: HashMap::new(),
|
||||
thompson: MetaThompsonEngine::new(arms),
|
||||
population: PopulationSearch::new(8),
|
||||
scoreboard: AccelerationScoreboard::new(),
|
||||
meta: MetaLearningEngine::new(),
|
||||
holdouts: HashMap::new(),
|
||||
counterexamples: HashMap::new(),
|
||||
};
|
||||
|
||||
// Register the three core domains.
|
||||
engine.register_domain(Box::new(RustSynthesisDomain::new()));
|
||||
engine.register_domain(Box::new(PlanningDomain::new()));
|
||||
engine.register_domain(Box::new(ToolOrchestrationDomain::new()));
|
||||
|
||||
engine
|
||||
}
|
||||
|
||||
/// Register a new domain.
|
||||
pub fn register_domain(&mut self, domain: Box<dyn Domain>) {
|
||||
let id = domain.id().clone();
|
||||
self.thompson.init_domain_uniform(id.clone());
|
||||
self.domains.insert(id, domain);
|
||||
}
|
||||
|
||||
/// Generate holdout tasks for verification.
|
||||
pub fn generate_holdouts(&mut self, tasks_per_domain: usize, difficulty: f32) {
|
||||
for (id, domain) in &self.domains {
|
||||
let tasks = domain.generate_tasks(tasks_per_domain, difficulty);
|
||||
self.holdouts.insert(id.clone(), tasks);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate training tasks for a specific domain.
|
||||
pub fn generate_tasks(&self, domain_id: &DomainId, count: usize, difficulty: f32) -> Vec<Task> {
|
||||
self.domains
|
||||
.get(domain_id)
|
||||
.map(|d| d.generate_tasks(count, difficulty))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Evaluate a solution and record the outcome.
|
||||
pub fn evaluate_and_record(
|
||||
&mut self,
|
||||
domain_id: &DomainId,
|
||||
task: &Task,
|
||||
solution: &Solution,
|
||||
bucket: ContextBucket,
|
||||
arm: ArmId,
|
||||
) -> Evaluation {
|
||||
let eval = self
|
||||
.domains
|
||||
.get(domain_id)
|
||||
.map(|d| d.evaluate(task, solution))
|
||||
.unwrap_or_else(|| Evaluation::zero(vec!["Domain not found".into()]));
|
||||
|
||||
// Record outcome in Thompson engine.
|
||||
self.thompson.record_outcome(
|
||||
domain_id,
|
||||
bucket.clone(),
|
||||
arm.clone(),
|
||||
eval.score,
|
||||
1.0, // unit cost for now
|
||||
);
|
||||
|
||||
// Record in meta-learning engine (regret + curiosity + decaying beta).
|
||||
self.meta.record_decision(&bucket, &arm, eval.score);
|
||||
|
||||
// Store counterexamples for poor solutions.
|
||||
if eval.score < 0.3 {
|
||||
self.counterexamples
|
||||
.entry(domain_id.clone())
|
||||
.or_default()
|
||||
.push((task.clone(), solution.clone(), eval.clone()));
|
||||
}
|
||||
|
||||
eval
|
||||
}
|
||||
|
||||
/// Embed a solution into the shared representation space.
|
||||
pub fn embed(&self, domain_id: &DomainId, solution: &Solution) -> Option<DomainEmbedding> {
|
||||
self.domains.get(domain_id).map(|d| d.embed(solution))
|
||||
}
|
||||
|
||||
/// Initiate transfer from source domain to target domain.
|
||||
/// Extracts priors from source and seeds target.
|
||||
pub fn initiate_transfer(&mut self, source: &DomainId, target: &DomainId) {
|
||||
if let Some(prior) = self.thompson.extract_prior(source) {
|
||||
self.thompson
|
||||
.init_domain_with_transfer(target.clone(), &prior);
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify a transfer delta: did it improve target without regressing source?
|
||||
pub fn verify_transfer(
|
||||
&self,
|
||||
source: &DomainId,
|
||||
target: &DomainId,
|
||||
source_before: f32,
|
||||
source_after: f32,
|
||||
target_before: f32,
|
||||
target_after: f32,
|
||||
baseline_cycles: u64,
|
||||
transfer_cycles: u64,
|
||||
) -> TransferVerification {
|
||||
TransferVerification::verify(
|
||||
source.clone(),
|
||||
target.clone(),
|
||||
source_before,
|
||||
source_after,
|
||||
target_before,
|
||||
target_after,
|
||||
baseline_cycles,
|
||||
transfer_cycles,
|
||||
)
|
||||
}
|
||||
|
||||
/// Evaluate all policy kernels on holdout tasks.
|
||||
pub fn evaluate_population(&mut self) {
|
||||
let holdout_snapshot: HashMap<DomainId, Vec<Task>> = self.holdouts.clone();
|
||||
let domain_ids: Vec<DomainId> = self.domains.keys().cloned().collect();
|
||||
|
||||
for i in 0..self.population.population().len() {
|
||||
for domain_id in &domain_ids {
|
||||
if let Some(holdout_tasks) = holdout_snapshot.get(domain_id) {
|
||||
let mut total_score = 0.0f32;
|
||||
let mut count = 0;
|
||||
|
||||
for task in holdout_tasks {
|
||||
if let Some(domain) = self.domains.get(domain_id) {
|
||||
if let Some(ref_sol) = domain.reference_solution(task) {
|
||||
let eval = domain.evaluate(task, &ref_sol);
|
||||
total_score += eval.score;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let avg_score = if count > 0 {
|
||||
total_score / count as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if let Some(kernel) = self.population.kernel_mut(i) {
|
||||
kernel.record_score(domain_id.clone(), avg_score, 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evolve the policy kernel population and update Pareto front.
|
||||
pub fn evolve_population(&mut self) {
|
||||
// Record current population into Pareto front before evolving.
|
||||
let gen = self.population.generation();
|
||||
for kernel in self.population.population() {
|
||||
let accuracy = kernel.fitness();
|
||||
let cost = if kernel.cycles > 0 {
|
||||
kernel.total_cost / kernel.cycles as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
// Robustness approximated by consistency across domains.
|
||||
let robustness = if kernel.holdout_scores.len() > 1 {
|
||||
let mean = accuracy;
|
||||
let var: f32 = kernel
|
||||
.holdout_scores
|
||||
.values()
|
||||
.map(|s| (s - mean).powi(2))
|
||||
.sum::<f32>()
|
||||
/ kernel.holdout_scores.len() as f32;
|
||||
(1.0 - var.sqrt()).max(0.0)
|
||||
} else {
|
||||
accuracy
|
||||
};
|
||||
self.meta
|
||||
.record_kernel(&kernel.id, accuracy, cost, robustness, gen);
|
||||
}
|
||||
|
||||
self.population.evolve();
|
||||
}
|
||||
|
||||
/// Get the best policy kernel found so far.
|
||||
pub fn best_kernel(&self) -> Option<&PolicyKernel> {
|
||||
self.population.best()
|
||||
}
|
||||
|
||||
/// Get population statistics.
|
||||
pub fn population_stats(&self) -> PopulationStats {
|
||||
self.population.stats()
|
||||
}
|
||||
|
||||
/// Get the scoreboard summary.
|
||||
pub fn scoreboard_summary(&self) -> ScoreboardSummary {
|
||||
self.scoreboard.summary()
|
||||
}
|
||||
|
||||
/// Get registered domain IDs.
|
||||
pub fn domain_ids(&self) -> Vec<DomainId> {
|
||||
self.domains.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get counterexamples for a domain.
|
||||
pub fn counterexamples(&self, domain_id: &DomainId) -> &[(Task, Solution, Evaluation)] {
|
||||
self.counterexamples
|
||||
.get(domain_id)
|
||||
.map(|v| v.as_slice())
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Select best arm for a context using Thompson Sampling.
|
||||
pub fn select_arm(&self, domain_id: &DomainId, bucket: &ContextBucket) -> Option<ArmId> {
|
||||
let mut rng = rand::thread_rng();
|
||||
self.thompson.select_arm(domain_id, bucket, &mut rng)
|
||||
}
|
||||
|
||||
/// Check if dual-path speculation should be triggered.
|
||||
pub fn should_speculate(&self, domain_id: &DomainId, bucket: &ContextBucket) -> bool {
|
||||
self.thompson.is_uncertain(domain_id, bucket, 0.15)
|
||||
}
|
||||
|
||||
/// Select arm with curiosity-boosted Thompson Sampling.
|
||||
///
|
||||
/// Combines the standard Thompson sample with a UCB-style exploration
|
||||
/// bonus that favors under-visited bucket/arm combinations.
|
||||
pub fn select_arm_curious(
|
||||
&self,
|
||||
domain_id: &DomainId,
|
||||
bucket: &ContextBucket,
|
||||
) -> Option<ArmId> {
|
||||
let mut rng = rand::thread_rng();
|
||||
// Get all arms and compute boosted scores
|
||||
let prior = self.thompson.extract_prior(domain_id)?;
|
||||
let arms: Vec<ArmId> = prior
|
||||
.bucket_priors
|
||||
.get(bucket)
|
||||
.map(|m| m.keys().cloned().collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
if arms.is_empty() {
|
||||
return self.thompson.select_arm(domain_id, bucket, &mut rng);
|
||||
}
|
||||
|
||||
let mut best_arm = None;
|
||||
let mut best_score = f32::NEG_INFINITY;
|
||||
|
||||
for arm in &arms {
|
||||
let params = prior.get_prior(bucket, arm);
|
||||
let sample = params.sample(&mut rng);
|
||||
let boosted = self.meta.boosted_score(bucket, arm, sample);
|
||||
|
||||
if boosted > best_score {
|
||||
best_score = boosted;
|
||||
best_arm = Some(arm.clone());
|
||||
}
|
||||
}
|
||||
|
||||
best_arm.or_else(|| self.thompson.select_arm(domain_id, bucket, &mut rng))
|
||||
}
|
||||
|
||||
/// Get meta-learning health diagnostics.
|
||||
pub fn meta_health(&self) -> MetaLearningHealth {
|
||||
self.meta.health_check()
|
||||
}
|
||||
|
||||
/// Check cost curve for plateau and get recommended action.
|
||||
pub fn check_plateau(&mut self, domain_id: &DomainId) -> PlateauAction {
|
||||
if let Some(curve) = self.scoreboard.curves.get(domain_id) {
|
||||
self.meta.check_plateau(&curve.points)
|
||||
} else {
|
||||
PlateauAction::Continue
|
||||
}
|
||||
}
|
||||
|
||||
/// Get regret summary across all learning contexts.
|
||||
pub fn regret_summary(&self) -> RegretSummary {
|
||||
self.meta.regret.summary()
|
||||
}
|
||||
|
||||
/// Get the Pareto front of non-dominated policy kernels.
|
||||
pub fn pareto_front(&self) -> &ParetoFront {
|
||||
&self.meta.pareto
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DomainExpansionEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Construction registers exactly the three core domains.
    #[test]
    fn test_engine_creation() {
        let engine = DomainExpansionEngine::new();
        let ids = engine.domain_ids();
        assert_eq!(ids.len(), 3);
    }

    // Every registered domain honors the requested task count.
    #[test]
    fn test_generate_tasks_all_domains() {
        let engine = DomainExpansionEngine::new();
        for domain_id in engine.domain_ids() {
            let tasks = engine.generate_tasks(&domain_id, 5, 0.5);
            assert_eq!(tasks.len(), 5);
        }
    }

    // Thompson sampling yields an arm for every domain even before any
    // outcomes are recorded (uniform priors from registration).
    #[test]
    fn test_arm_selection() {
        let engine = DomainExpansionEngine::new();
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "general".into(),
        };
        for domain_id in engine.domain_ids() {
            let arm = engine.select_arm(&domain_id, &bucket);
            assert!(arm.is_some());
        }
    }

    // A plausible Rust solution evaluates to a score in [0, 1].
    #[test]
    fn test_evaluate_and_record() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.3);
        let task = &tasks[0];

        let solution = Solution {
            task_id: task.id.clone(),
            content:
                "fn double(values: &[i64]) -> Vec<i64> { values.iter().map(|&x| x * 2).collect() }"
                    .into(),
            data: serde_json::Value::Null,
        };

        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "transform".into(),
        };
        let arm = ArmId("greedy".into());

        let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm);
        assert!(eval.score >= 0.0 && eval.score <= 1.0);
    }

    // Solutions from different domains embed into the same-dimensional
    // space, so cross-domain cosine similarity is well-defined.
    #[test]
    fn test_cross_domain_embedding() {
        let engine = DomainExpansionEngine::new();

        let rust_sol = Solution {
            task_id: "rust".into(),
            content: "fn foo() { for i in 0..10 { if i > 5 { } } }".into(),
            data: serde_json::Value::Null,
        };

        let plan_sol = Solution {
            task_id: "plan".into(),
            content: "allocate cpu then schedule parallel jobs".into(),
            data: serde_json::json!({"steps": []}),
        };

        let rust_emb = engine
            .embed(&DomainId("rust_synthesis".into()), &rust_sol)
            .unwrap();
        let plan_emb = engine
            .embed(&DomainId("structured_planning".into()), &plan_sol)
            .unwrap();

        // Embeddings should be same dimension.
        assert_eq!(rust_emb.dim, plan_emb.dim);

        // Cross-domain similarity should be defined.
        let sim = rust_emb.cosine_similarity(&plan_emb);
        assert!(sim >= -1.0 && sim <= 1.0);
    }

    // End-to-end transfer: train source, transfer to target, then verify
    // that the delta is promotable and accelerated relative to baseline.
    #[test]
    fn test_transfer_flow() {
        let mut engine = DomainExpansionEngine::new();
        let source = DomainId("rust_synthesis".into());
        let target = DomainId("structured_planning".into());

        // Record some outcomes in source domain.
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };

        for _ in 0..30 {
            engine.thompson.record_outcome(
                &source,
                bucket.clone(),
                ArmId("greedy".into()),
                0.85,
                1.0,
            );
        }

        // Initiate transfer.
        engine.initiate_transfer(&source, &target);

        // Verify the transfer.
        let verification = engine.verify_transfer(
            &source, &target, 0.85, // source before
            0.845, // source after (within tolerance)
            0.3, // target before
            0.7, // target after
            100, // baseline cycles
            45, // transfer cycles
        );

        assert!(verification.promotable);
        assert!(verification.acceleration_factor > 1.0);
    }

    // One call to evolve_population advances the generation counter by one.
    #[test]
    fn test_population_evolution() {
        let mut engine = DomainExpansionEngine::new();
        engine.generate_holdouts(3, 0.3);
        engine.evaluate_population();

        let stats_before = engine.population_stats();
        assert_eq!(stats_before.generation, 0);

        engine.evolve_population();
        let stats_after = engine.population_stats();
        assert_eq!(stats_after.generation, 1);
    }

    // Untrained (uniform) posteriors should count as uncertain and
    // therefore trigger dual-path speculation.
    #[test]
    fn test_speculation_trigger() {
        let engine = DomainExpansionEngine::new();
        let bucket = ContextBucket {
            difficulty_tier: "hard".into(),
            category: "unknown".into(),
        };

        // With uniform priors, should be uncertain.
        assert!(engine.should_speculate(&DomainId("rust_synthesis".into()), &bucket,));
    }

    // A scoring failure (score < 0.3) is captured in the counterexample set.
    #[test]
    fn test_counterexample_tracking() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.9);
        let task = &tasks[0];

        // Submit a terrible solution.
        let solution = Solution {
            task_id: task.id.clone(),
            content: "".into(), // empty = bad
            data: serde_json::Value::Null,
        };

        let bucket = ContextBucket {
            difficulty_tier: "hard".into(),
            category: "algorithm".into(),
        };
        let arm = ArmId("speculative".into());

        let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm);
        assert!(eval.score < 0.3);

        // Should be recorded as counterexample.
        assert!(!engine.counterexamples(&domain_id).is_empty());
    }
}
|
||||
1398
vendor/ruvector/crates/ruvector-domain-expansion/src/meta_learning.rs
vendored
Normal file
1398
vendor/ruvector/crates/ruvector-domain-expansion/src/meta_learning.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
647
vendor/ruvector/crates/ruvector-domain-expansion/src/planning.rs
vendored
Normal file
647
vendor/ruvector/crates/ruvector-domain-expansion/src/planning.rs
vendored
Normal file
@@ -0,0 +1,647 @@
|
||||
//! Structured Planning Tasks Domain
|
||||
//!
|
||||
//! Generates tasks that require multi-step reasoning and plan construction.
|
||||
//! Task types include:
|
||||
//!
|
||||
//! - **ResourceAllocation**: Assign limited resources to maximize objective
|
||||
//! - **DependencyScheduling**: Order tasks respecting dependencies and deadlines
|
||||
//! - **StateSpaceSearch**: Navigate from initial to goal state
|
||||
//! - **ConstraintSatisfaction**: Find assignments satisfying all constraints
|
||||
//! - **HierarchicalDecomposition**: Break complex goals into sub-goals
|
||||
//!
|
||||
//! Solutions are plans: ordered sequences of actions with preconditions and effects.
|
||||
//! Cross-domain transfer from Rust synthesis helps because both require:
|
||||
//! structured decomposition, constraint satisfaction, and efficient search.
|
||||
|
||||
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Length of the feature vector built by `extract_features`; must match
/// the shared representation dimension used by the other domains.
const EMBEDDING_DIM: usize = 64;
|
||||
|
||||
/// Categories of planning tasks generated by `PlanningDomain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlanningCategory {
    /// Assign limited resources to competing demands.
    ResourceAllocation,
    /// Schedule tasks with precedence constraints and deadlines.
    DependencyScheduling,
    /// Find a path from initial state to goal state.
    StateSpaceSearch,
    /// Assign values to variables satisfying all constraints.
    ConstraintSatisfaction,
    /// Decompose a high-level goal into achievable sub-tasks.
    HierarchicalDecomposition,
}
|
||||
|
||||
/// A resource in the planning world.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Resource {
    /// Resource identifier, e.g. "cpu" or "memory".
    pub name: String,
    /// Maximum concurrent usage allowed for this resource.
    pub capacity: u32,
}
|
||||
|
||||
/// An action in a plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanAction {
    /// Action identifier referenced by plan steps.
    pub name: String,
    /// State predicates that must hold before the action can run.
    pub preconditions: Vec<String>,
    /// State predicates made true by executing the action.
    pub effects: Vec<String>,
    /// Cost contribution of this action toward the plan budget.
    pub cost: f32,
    /// Duration of the action in abstract time units.
    pub duration: u32,
}
|
||||
|
||||
/// A dependency edge: task A must complete before task B.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dependency {
    /// Name of the action that must complete first.
    pub from: String,
    /// Name of the action that may only start afterwards.
    pub to: String,
}
|
||||
|
||||
/// Specification for a planning task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanningTaskSpec {
    /// Which planning category this task belongs to.
    pub category: PlanningCategory,
    /// Human-readable statement of the task.
    pub description: String,
    /// Available actions in the planning domain.
    pub available_actions: Vec<PlanAction>,
    /// Resources with capacity limits.
    pub resources: Vec<Resource>,
    /// Dependency constraints.
    pub dependencies: Vec<Dependency>,
    /// Initial state predicates.
    pub initial_state: Vec<String>,
    /// Goal state predicates.
    pub goal_state: Vec<String>,
    /// Maximum allowed plan cost.
    pub max_cost: Option<f32>,
    /// Maximum allowed plan steps.
    pub max_steps: Option<usize>,
}
|
||||
|
||||
/// A parsed plan from a solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plan {
    /// Ordered sequence of steps making up the plan.
    pub steps: Vec<PlanStep>,
}
|
||||
|
||||
/// A single step in a plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// Name of the action being executed (matches `PlanAction::name`).
    pub action: String,
    /// Arguments supplied to the action.
    pub args: Vec<String>,
    /// Optional scheduled start time; `None` for purely ordered plans.
    /// Steps sharing a start time are treated as parallel by the
    /// feature extractor.
    pub start_time: Option<u32>,
}
|
||||
|
||||
/// Structured planning domain.
pub struct PlanningDomain {
    // Stable identifier ("structured_planning") used for registration
    // and lookup; set in `PlanningDomain::new`.
    id: DomainId,
}
|
||||
|
||||
impl PlanningDomain {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: DomainId("structured_planning".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_resource_allocation(&self, difficulty: f32) -> PlanningTaskSpec {
|
||||
let num_tasks = if difficulty < 0.3 {
|
||||
3
|
||||
} else if difficulty < 0.7 {
|
||||
6
|
||||
} else {
|
||||
10
|
||||
};
|
||||
|
||||
let actions: Vec<PlanAction> = (0..num_tasks)
|
||||
.map(|i| PlanAction {
|
||||
name: format!("task_{}", i),
|
||||
preconditions: vec![format!("resource_available_{}", i % 3)],
|
||||
effects: vec![format!("task_{}_complete", i)],
|
||||
cost: (i as f32 + 1.0) * 10.0,
|
||||
duration: (i as u32 % 5) + 1,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resources = vec![
|
||||
Resource {
|
||||
name: "cpu".into(),
|
||||
capacity: if difficulty < 0.5 { 10 } else { 5 },
|
||||
},
|
||||
Resource {
|
||||
name: "memory".into(),
|
||||
capacity: if difficulty < 0.5 { 8 } else { 3 },
|
||||
},
|
||||
Resource {
|
||||
name: "io".into(),
|
||||
capacity: if difficulty < 0.5 { 6 } else { 2 },
|
||||
},
|
||||
];
|
||||
|
||||
let goal_state: Vec<String> = (0..num_tasks)
|
||||
.map(|i| format!("task_{}_complete", i))
|
||||
.collect();
|
||||
|
||||
PlanningTaskSpec {
|
||||
category: PlanningCategory::ResourceAllocation,
|
||||
description: format!(
|
||||
"Allocate {} resources to complete {} tasks within capacity.",
|
||||
resources.len(),
|
||||
num_tasks
|
||||
),
|
||||
available_actions: actions,
|
||||
resources,
|
||||
dependencies: Vec::new(),
|
||||
initial_state: vec![
|
||||
"resource_available_0".into(),
|
||||
"resource_available_1".into(),
|
||||
"resource_available_2".into(),
|
||||
],
|
||||
goal_state,
|
||||
max_cost: Some(num_tasks as f32 * 50.0),
|
||||
max_steps: Some(num_tasks * 2),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_dependency_scheduling(&self, difficulty: f32) -> PlanningTaskSpec {
|
||||
let num_tasks = if difficulty < 0.3 {
|
||||
4
|
||||
} else if difficulty < 0.7 {
|
||||
7
|
||||
} else {
|
||||
12
|
||||
};
|
||||
|
||||
let actions: Vec<PlanAction> = (0..num_tasks)
|
||||
.map(|i| PlanAction {
|
||||
name: format!("job_{}", i),
|
||||
preconditions: if i > 0 {
|
||||
vec![format!("job_{}_done", i - 1)]
|
||||
} else {
|
||||
Vec::new()
|
||||
},
|
||||
effects: vec![format!("job_{}_done", i)],
|
||||
cost: 1.0,
|
||||
duration: (i as u32 % 3) + 1,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create dependency chain with some parallelism
|
||||
let mut dependencies = Vec::new();
|
||||
for i in 1..num_tasks {
|
||||
// Linear chain
|
||||
dependencies.push(Dependency {
|
||||
from: format!("job_{}", i - 1),
|
||||
to: format!("job_{}", i),
|
||||
});
|
||||
// Add cross-dependencies at higher difficulty
|
||||
if difficulty > 0.5 && i >= 3 && i % 2 == 0 {
|
||||
dependencies.push(Dependency {
|
||||
from: format!("job_{}", i - 3),
|
||||
to: format!("job_{}", i),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
PlanningTaskSpec {
|
||||
category: PlanningCategory::DependencyScheduling,
|
||||
description: format!(
|
||||
"Schedule {} jobs respecting {} dependencies, minimizing makespan.",
|
||||
num_tasks,
|
||||
dependencies.len()
|
||||
),
|
||||
available_actions: actions,
|
||||
resources: vec![Resource {
|
||||
name: "worker".into(),
|
||||
capacity: if difficulty < 0.5 { 3 } else { 2 },
|
||||
}],
|
||||
dependencies,
|
||||
initial_state: Vec::new(),
|
||||
goal_state: (0..num_tasks).map(|i| format!("job_{}_done", i)).collect(),
|
||||
max_cost: None,
|
||||
max_steps: Some(num_tasks + 5),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_state_space_search(&self, difficulty: f32) -> PlanningTaskSpec {
|
||||
let grid_size = if difficulty < 0.3 {
|
||||
3
|
||||
} else if difficulty < 0.7 {
|
||||
5
|
||||
} else {
|
||||
8
|
||||
};
|
||||
|
||||
let actions = vec![
|
||||
PlanAction {
|
||||
name: "move_up".into(),
|
||||
preconditions: vec!["not_top_edge".into()],
|
||||
effects: vec!["moved_up".into()],
|
||||
cost: 1.0,
|
||||
duration: 1,
|
||||
},
|
||||
PlanAction {
|
||||
name: "move_down".into(),
|
||||
preconditions: vec!["not_bottom_edge".into()],
|
||||
effects: vec!["moved_down".into()],
|
||||
cost: 1.0,
|
||||
duration: 1,
|
||||
},
|
||||
PlanAction {
|
||||
name: "move_left".into(),
|
||||
preconditions: vec!["not_left_edge".into()],
|
||||
effects: vec!["moved_left".into()],
|
||||
cost: 1.0,
|
||||
duration: 1,
|
||||
},
|
||||
PlanAction {
|
||||
name: "move_right".into(),
|
||||
preconditions: vec!["not_right_edge".into()],
|
||||
effects: vec!["moved_right".into()],
|
||||
cost: 1.0,
|
||||
duration: 1,
|
||||
},
|
||||
];
|
||||
|
||||
PlanningTaskSpec {
|
||||
category: PlanningCategory::StateSpaceSearch,
|
||||
description: format!(
|
||||
"Navigate a {}x{} grid from (0,0) to ({},{}) avoiding obstacles.",
|
||||
grid_size,
|
||||
grid_size,
|
||||
grid_size - 1,
|
||||
grid_size - 1
|
||||
),
|
||||
available_actions: actions,
|
||||
resources: Vec::new(),
|
||||
dependencies: Vec::new(),
|
||||
initial_state: vec!["at(0,0)".into()],
|
||||
goal_state: vec![format!("at({},{})", grid_size - 1, grid_size - 1)],
|
||||
max_cost: Some((grid_size as f32) * 4.0),
|
||||
max_steps: Some(grid_size * grid_size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract structural features from a planning solution.
|
||||
fn extract_features(&self, solution: &Solution) -> Vec<f32> {
|
||||
let content = &solution.content;
|
||||
let mut features = vec![0.0f32; EMBEDDING_DIM];
|
||||
|
||||
// Parse the plan
|
||||
let plan: Plan = serde_json::from_str(&solution.data.to_string())
|
||||
.or_else(|_| serde_json::from_str(content))
|
||||
.unwrap_or(Plan { steps: Vec::new() });
|
||||
|
||||
// Feature 0-7: Plan structure
|
||||
features[0] = plan.steps.len() as f32 / 20.0;
|
||||
features[1] = {
|
||||
let unique_actions: std::collections::HashSet<&str> =
|
||||
plan.steps.iter().map(|s| s.action.as_str()).collect();
|
||||
unique_actions.len() as f32 / plan.steps.len().max(1) as f32
|
||||
};
|
||||
// Sequential vs parallel indicator
|
||||
features[2] = plan
|
||||
.steps
|
||||
.windows(2)
|
||||
.filter(|w| w[0].start_time == w[1].start_time)
|
||||
.count() as f32
|
||||
/ plan.steps.len().max(1) as f32;
|
||||
// Average args per step
|
||||
features[3] = plan.steps.iter().map(|s| s.args.len() as f32).sum::<f32>()
|
||||
/ plan.steps.len().max(1) as f32
|
||||
/ 5.0;
|
||||
|
||||
// Feature 8-15: Action type distribution
|
||||
let action_counts: std::collections::HashMap<&str, usize> =
|
||||
plan.steps
|
||||
.iter()
|
||||
.fold(std::collections::HashMap::new(), |mut acc, s| {
|
||||
*acc.entry(s.action.as_str()).or_insert(0) += 1;
|
||||
acc
|
||||
});
|
||||
let max_count = action_counts.values().max().copied().unwrap_or(0);
|
||||
features[8] = action_counts.len() as f32 / 10.0;
|
||||
features[9] = max_count as f32 / plan.steps.len().max(1) as f32;
|
||||
|
||||
// Feature 16-23: Text-based features from content
|
||||
features[16] = content.matches("allocate").count() as f32 / 5.0;
|
||||
features[17] = content.matches("schedule").count() as f32 / 5.0;
|
||||
features[18] = content.matches("move").count() as f32 / 10.0;
|
||||
features[19] = content.matches("assign").count() as f32 / 5.0;
|
||||
features[20] = content.matches("wait").count() as f32 / 5.0;
|
||||
features[21] = content.matches("parallel").count() as f32 / 3.0;
|
||||
features[22] = content.matches("constraint").count() as f32 / 5.0;
|
||||
features[23] = content.matches("deadline").count() as f32 / 3.0;
|
||||
|
||||
// Feature 32-39: Structural complexity indicators
|
||||
features[32] = content.matches("->").count() as f32 / 10.0;
|
||||
features[33] = content.matches("if ").count() as f32 / 5.0;
|
||||
features[34] = content.matches("then ").count() as f32 / 5.0;
|
||||
features[35] = content.matches("before").count() as f32 / 5.0;
|
||||
features[36] = content.matches("after").count() as f32 / 5.0;
|
||||
features[37] = content.matches("while").count() as f32 / 3.0;
|
||||
features[38] = content.matches("until").count() as f32 / 3.0;
|
||||
features[39] = content.matches("complete").count() as f32 / 5.0;
|
||||
|
||||
// Feature 48-55: Resource usage indicators
|
||||
features[48] = content.matches("cpu").count() as f32 / 3.0;
|
||||
features[49] = content.matches("memory").count() as f32 / 3.0;
|
||||
features[50] = content.matches("worker").count() as f32 / 3.0;
|
||||
features[51] = content.matches("capacity").count() as f32 / 3.0;
|
||||
features[52] = content.matches("cost").count() as f32 / 5.0;
|
||||
features[53] = content.matches("time").count() as f32 / 5.0;
|
||||
features[54] = content.matches("resource").count() as f32 / 5.0;
|
||||
features[55] = content.matches("limit").count() as f32 / 3.0;
|
||||
|
||||
// Normalize
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-10 {
|
||||
for f in &mut features {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Evaluate a planning solution.
|
||||
fn score_plan(&self, spec: &PlanningTaskSpec, solution: &Solution) -> Evaluation {
|
||||
let content = &solution.content;
|
||||
let mut correctness = 0.0f32;
|
||||
let mut efficiency = 0.5f32;
|
||||
let mut elegance = 0.5f32;
|
||||
let mut notes = Vec::new();
|
||||
|
||||
// Parse plan from solution
|
||||
let plan: Option<Plan> = serde_json::from_str(&solution.data.to_string())
|
||||
.ok()
|
||||
.or_else(|| serde_json::from_str(content).ok());
|
||||
|
||||
let plan = match plan {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
// Fall back to text analysis
|
||||
let has_steps = content.contains("step") || content.contains("action");
|
||||
if has_steps {
|
||||
correctness = 0.2;
|
||||
}
|
||||
return Evaluation {
|
||||
score: correctness * 0.6,
|
||||
correctness,
|
||||
efficiency: 0.0,
|
||||
elegance: 0.0,
|
||||
constraint_results: Vec::new(),
|
||||
notes: vec!["Could not parse structured plan".into()],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Check plan is non-empty
|
||||
if plan.steps.is_empty() {
|
||||
return Evaluation::zero(vec!["Empty plan".into()]);
|
||||
}
|
||||
|
||||
// Check goal coverage: how many goal predicates are addressed
|
||||
let goal_coverage = spec
|
||||
.goal_state
|
||||
.iter()
|
||||
.filter(|goal| {
|
||||
plan.steps.iter().any(|step| {
|
||||
let action_name = &step.action;
|
||||
// Check if any action's effects mention this goal
|
||||
spec.available_actions
|
||||
.iter()
|
||||
.any(|a| a.name == *action_name && a.effects.iter().any(|e| e == *goal))
|
||||
})
|
||||
})
|
||||
.count() as f32
|
||||
/ spec.goal_state.len().max(1) as f32;
|
||||
|
||||
correctness = goal_coverage;
|
||||
|
||||
// Check dependency ordering
|
||||
let mut dep_violations = 0;
|
||||
for dep in &spec.dependencies {
|
||||
let from_pos = plan.steps.iter().position(|s| s.action == dep.from);
|
||||
let to_pos = plan.steps.iter().position(|s| s.action == dep.to);
|
||||
if let (Some(f), Some(t)) = (from_pos, to_pos) {
|
||||
if f >= t {
|
||||
dep_violations += 1;
|
||||
notes.push(format!(
|
||||
"Dependency violation: {} must come before {}",
|
||||
dep.from, dep.to
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
if !spec.dependencies.is_empty() {
|
||||
let dep_score = 1.0 - (dep_violations as f32 / spec.dependencies.len() as f32);
|
||||
correctness = correctness * 0.5 + dep_score * 0.5;
|
||||
}
|
||||
|
||||
// Efficiency: compare to max allowed steps/cost
|
||||
if let Some(max_steps) = spec.max_steps {
|
||||
let step_ratio = plan.steps.len() as f32 / max_steps as f32;
|
||||
efficiency = if step_ratio <= 1.0 {
|
||||
1.0 - (step_ratio * 0.5) // Fewer steps = better
|
||||
} else {
|
||||
0.5 / step_ratio // Penalty for exceeding max
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(max_cost) = spec.max_cost {
|
||||
let total_cost: f32 = plan
|
||||
.steps
|
||||
.iter()
|
||||
.filter_map(|step| {
|
||||
spec.available_actions
|
||||
.iter()
|
||||
.find(|a| a.name == step.action)
|
||||
.map(|a| a.cost)
|
||||
})
|
||||
.sum();
|
||||
if total_cost > max_cost {
|
||||
efficiency *= 0.5;
|
||||
notes.push(format!(
|
||||
"Plan cost {:.1} exceeds budget {:.1}",
|
||||
total_cost, max_cost
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Elegance: minimal redundancy, good parallelism
|
||||
let unique_actions: std::collections::HashSet<&str> =
|
||||
plan.steps.iter().map(|s| s.action.as_str()).collect();
|
||||
let redundancy = 1.0 - (unique_actions.len() as f32 / plan.steps.len().max(1) as f32);
|
||||
elegance = 1.0 - redundancy * 0.5;
|
||||
|
||||
// Bonus for parallel scheduling
|
||||
if plan
|
||||
.steps
|
||||
.windows(2)
|
||||
.any(|w| w[0].start_time == w[1].start_time)
|
||||
{
|
||||
elegance += 0.1;
|
||||
}
|
||||
elegance = elegance.clamp(0.0, 1.0);
|
||||
|
||||
let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
|
||||
Evaluation {
|
||||
score: score.clamp(0.0, 1.0),
|
||||
correctness,
|
||||
efficiency,
|
||||
elegance,
|
||||
constraint_results: Vec::new(),
|
||||
notes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PlanningDomain {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Domain for PlanningDomain {
|
||||
fn id(&self) -> &DomainId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Structured Planning"
|
||||
}
|
||||
|
||||
fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let difficulty = difficulty.clamp(0.0, 1.0);
|
||||
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
let category_roll: f32 = rng.gen();
|
||||
let spec = if category_roll < 0.35 {
|
||||
self.gen_resource_allocation(difficulty)
|
||||
} else if category_roll < 0.7 {
|
||||
self.gen_dependency_scheduling(difficulty)
|
||||
} else {
|
||||
self.gen_state_space_search(difficulty)
|
||||
};
|
||||
|
||||
Task {
|
||||
id: format!("planning_{}_d{:.0}", i, difficulty * 100.0),
|
||||
domain_id: self.id.clone(),
|
||||
difficulty,
|
||||
spec: serde_json::to_value(&spec).unwrap_or_default(),
|
||||
constraints: Vec::new(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
|
||||
let spec: PlanningTaskSpec = match serde_json::from_value(task.spec.clone()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
|
||||
};
|
||||
self.score_plan(&spec, solution)
|
||||
}
|
||||
|
||||
fn embed(&self, solution: &Solution) -> DomainEmbedding {
|
||||
let features = self.extract_features(solution);
|
||||
DomainEmbedding::new(features, self.id.clone())
|
||||
}
|
||||
|
||||
fn embedding_dim(&self) -> usize {
|
||||
EMBEDDING_DIM
|
||||
}
|
||||
|
||||
fn reference_solution(&self, task: &Task) -> Option<Solution> {
|
||||
let spec: PlanningTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;
|
||||
|
||||
// Generate a naive sequential plan that executes all actions in order
|
||||
let steps: Vec<PlanStep> = spec
|
||||
.available_actions
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, a)| PlanStep {
|
||||
action: a.name.clone(),
|
||||
args: Vec::new(),
|
||||
start_time: Some(i as u32),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let plan = Plan { steps };
|
||||
let content = serde_json::to_string_pretty(&plan).ok()?;
|
||||
|
||||
Some(Solution {
|
||||
task_id: task.id.clone(),
|
||||
content,
|
||||
data: serde_json::to_value(&plan).ok()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_planning_tasks() {
        let domain = PlanningDomain::new();
        let batch = domain.generate_tasks(5, 0.5);
        assert_eq!(batch.len(), 5);
        // Every generated task must carry this domain's id.
        assert!(batch.iter().all(|task| task.domain_id == domain.id));
    }

    #[test]
    fn test_reference_solution_exists() {
        let domain = PlanningDomain::new();
        for task in domain.generate_tasks(3, 0.3) {
            assert!(
                domain.reference_solution(&task).is_some(),
                "Should produce reference solution"
            );
        }
    }

    #[test]
    fn test_evaluate_reference() {
        let domain = PlanningDomain::new();
        for task in domain.generate_tasks(3, 0.3) {
            if let Some(solution) = domain.reference_solution(&task) {
                // Evaluation scores are always normalized to [0, 1].
                let eval = domain.evaluate(&task, &solution);
                assert!((0.0..=1.0).contains(&eval.score));
            }
        }
    }

    #[test]
    fn test_embed_planning() {
        let domain = PlanningDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "allocate cpu to task_0, schedule job_1 after job_0".into(),
            data: serde_json::json!({ "steps": [] }),
        };
        assert_eq!(domain.embed(&solution).dim, EMBEDDING_DIM);
    }

    #[test]
    fn test_difficulty_scaling() {
        let domain = PlanningDomain::new();
        let easy = domain.generate_tasks(1, 0.1);
        let hard = domain.generate_tasks(1, 0.9);

        let easy_spec: PlanningTaskSpec = serde_json::from_value(easy[0].spec.clone()).unwrap();
        let hard_spec: PlanningTaskSpec = serde_json::from_value(hard[0].spec.clone()).unwrap();

        assert!(
            hard_spec.available_actions.len() >= easy_spec.available_actions.len(),
            "Harder tasks should have more actions"
        );
    }
}
|
||||
468
vendor/ruvector/crates/ruvector-domain-expansion/src/policy_kernel.rs
vendored
Normal file
468
vendor/ruvector/crates/ruvector-domain-expansion/src/policy_kernel.rs
vendored
Normal file
@@ -0,0 +1,468 @@
|
||||
//! PolicyKernel: Population-Based Policy Search
|
||||
//!
|
||||
//! Run a small population of policy variants in parallel.
|
||||
//! Each variant changes a small set of knobs:
|
||||
//! - skip mode policy
|
||||
//! - prepass mode
|
||||
//! - speculation trigger thresholds
|
||||
//! - budget allocation
|
||||
//!
|
||||
//! Selection: keep top performers on holdouts, mutate knobs, repeat.
|
||||
//! Only merge deltas that pass replay-verify.
|
||||
|
||||
use crate::domain::DomainId;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration knobs that a PolicyKernel can tune.
///
/// The bracketed ranges below are the clamp bounds enforced by
/// [`PolicyKnobs::mutate`]; `default_knobs()` starts every value inside them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyKnobs {
    /// Whether to skip low-value operations.
    pub skip_mode: bool,
    /// Run a cheaper prepass before full execution.
    pub prepass_enabled: bool,
    /// Threshold for triggering speculative dual-path; mutation keeps it in [0.01, 0.5].
    pub speculation_threshold: f32,
    /// Budget fraction allocated to exploration vs exploitation; mutation keeps it in [0.01, 0.5].
    pub exploration_budget: f32,
    /// Maximum retries on failure; mutation draws from 0..5.
    pub max_retries: u32,
    /// Batch size for parallel evaluation; mutation draws from 1..32.
    pub batch_size: usize,
    /// Cost decay factor for EMA; mutation keeps it in [0.5, 0.99].
    pub cost_decay: f32,
    /// Minimum confidence to skip uncertainty check; mutation keeps it in [0.3, 0.95].
    pub confidence_floor: f32,
}
|
||||
|
||||
impl PolicyKnobs {
|
||||
/// Sensible defaults.
|
||||
pub fn default_knobs() -> Self {
|
||||
Self {
|
||||
skip_mode: false,
|
||||
prepass_enabled: true,
|
||||
speculation_threshold: 0.15,
|
||||
exploration_budget: 0.2,
|
||||
max_retries: 2,
|
||||
batch_size: 8,
|
||||
cost_decay: 0.9,
|
||||
confidence_floor: 0.7,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutate knobs with small random perturbations.
|
||||
pub fn mutate(&self, rng: &mut impl Rng, mutation_rate: f32) -> Self {
|
||||
let mut knobs = self.clone();
|
||||
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.skip_mode = !knobs.skip_mode;
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.prepass_enabled = !knobs.prepass_enabled;
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.1..0.1);
|
||||
knobs.speculation_threshold = (knobs.speculation_threshold + delta).clamp(0.01, 0.5);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.1..0.1);
|
||||
knobs.exploration_budget = (knobs.exploration_budget + delta).clamp(0.01, 0.5);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.max_retries = rng.gen_range(0..5);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
knobs.batch_size = rng.gen_range(1..32);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.05..0.05);
|
||||
knobs.cost_decay = (knobs.cost_decay + delta).clamp(0.5, 0.99);
|
||||
}
|
||||
if rng.gen::<f32>() < mutation_rate {
|
||||
let delta: f32 = rng.gen_range(-0.1..0.1);
|
||||
knobs.confidence_floor = (knobs.confidence_floor + delta).clamp(0.3, 0.95);
|
||||
}
|
||||
|
||||
knobs
|
||||
}
|
||||
|
||||
/// Crossover two parent knobs to produce a child.
|
||||
pub fn crossover(&self, other: &PolicyKnobs, rng: &mut impl Rng) -> Self {
|
||||
Self {
|
||||
skip_mode: if rng.gen() {
|
||||
self.skip_mode
|
||||
} else {
|
||||
other.skip_mode
|
||||
},
|
||||
prepass_enabled: if rng.gen() {
|
||||
self.prepass_enabled
|
||||
} else {
|
||||
other.prepass_enabled
|
||||
},
|
||||
speculation_threshold: if rng.gen() {
|
||||
self.speculation_threshold
|
||||
} else {
|
||||
other.speculation_threshold
|
||||
},
|
||||
exploration_budget: if rng.gen() {
|
||||
self.exploration_budget
|
||||
} else {
|
||||
other.exploration_budget
|
||||
},
|
||||
max_retries: if rng.gen() {
|
||||
self.max_retries
|
||||
} else {
|
||||
other.max_retries
|
||||
},
|
||||
batch_size: if rng.gen() {
|
||||
self.batch_size
|
||||
} else {
|
||||
other.batch_size
|
||||
},
|
||||
cost_decay: if rng.gen() {
|
||||
self.cost_decay
|
||||
} else {
|
||||
other.cost_decay
|
||||
},
|
||||
confidence_floor: if rng.gen() {
|
||||
self.confidence_floor
|
||||
} else {
|
||||
other.confidence_floor
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A PolicyKernel is a versioned policy configuration with performance history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyKernel {
    /// Unique identifier.
    pub id: String,
    /// Configuration knobs.
    pub knobs: PolicyKnobs,
    /// Performance on holdout tasks (domain_id -> score); `record_score`
    /// overwrites, so only the latest score per domain is kept.
    pub holdout_scores: HashMap<DomainId, f32>,
    /// Total cost incurred, accumulated across `record_score` calls.
    pub total_cost: f32,
    /// Number of evaluation cycles (one per `record_score` call).
    pub cycles: u64,
    /// Generation (0 = initial, increments on mutation).
    pub generation: u32,
    /// Parent kernel ID (for lineage tracking); `None` for generation 0.
    pub parent_id: Option<String>,
    /// Whether this kernel has been verified via replay.
    pub replay_verified: bool,
}
|
||||
|
||||
impl PolicyKernel {
|
||||
/// Create a new kernel with default knobs.
|
||||
pub fn new(id: String) -> Self {
|
||||
Self {
|
||||
id,
|
||||
knobs: PolicyKnobs::default_knobs(),
|
||||
holdout_scores: HashMap::new(),
|
||||
total_cost: 0.0,
|
||||
cycles: 0,
|
||||
generation: 0,
|
||||
parent_id: None,
|
||||
replay_verified: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a mutated child kernel.
|
||||
pub fn mutate(&self, child_id: String, rng: &mut impl Rng) -> Self {
|
||||
Self {
|
||||
id: child_id,
|
||||
knobs: self.knobs.mutate(rng, 0.3),
|
||||
holdout_scores: HashMap::new(),
|
||||
total_cost: 0.0,
|
||||
cycles: 0,
|
||||
generation: self.generation + 1,
|
||||
parent_id: Some(self.id.clone()),
|
||||
replay_verified: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a holdout score for a domain.
|
||||
pub fn record_score(&mut self, domain_id: DomainId, score: f32, cost: f32) {
|
||||
self.holdout_scores.insert(domain_id, score);
|
||||
self.total_cost += cost;
|
||||
self.cycles += 1;
|
||||
}
|
||||
|
||||
/// Fitness: average holdout score across all evaluated domains.
|
||||
pub fn fitness(&self) -> f32 {
|
||||
if self.holdout_scores.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let total: f32 = self.holdout_scores.values().sum();
|
||||
total / self.holdout_scores.len() as f32
|
||||
}
|
||||
|
||||
/// Cost-adjusted fitness: penalizes expensive kernels.
|
||||
pub fn cost_adjusted_fitness(&self) -> f32 {
|
||||
let raw = self.fitness();
|
||||
let cost_penalty = (self.total_cost / self.cycles.max(1) as f32).min(1.0);
|
||||
raw * (1.0 - cost_penalty * 0.3) // 30% weight on cost
|
||||
}
|
||||
}
|
||||
|
||||
/// Population-based policy search engine.
///
/// Maintains a fixed-size pool of [`PolicyKernel`]s; callers score kernels
/// via [`PopulationSearch::kernel_mut`] + `record_score`, then call
/// [`PopulationSearch::evolve`] to advance a generation.
#[derive(Clone)]
pub struct PopulationSearch {
    /// Current population of kernels.
    population: Vec<PolicyKernel>,
    /// Target population size maintained by `evolve`.
    pop_size: usize,
    /// Best kernel (by raw fitness) seen so far across all generations.
    best_kernel: Option<PolicyKernel>,
    /// Generation counter, incremented by `evolve`.
    generation: u32,
}
|
||||
|
||||
impl PopulationSearch {
|
||||
/// Create a new population search with initial random population.
|
||||
pub fn new(pop_size: usize) -> Self {
|
||||
let mut rng = rand::thread_rng();
|
||||
let population: Vec<PolicyKernel> = (0..pop_size)
|
||||
.map(|i| {
|
||||
let mut kernel = PolicyKernel::new(format!("kernel_g0_{}", i));
|
||||
// Random initial knobs
|
||||
kernel.knobs = PolicyKnobs::default_knobs().mutate(&mut rng, 0.8);
|
||||
kernel
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
population,
|
||||
pop_size,
|
||||
best_kernel: None,
|
||||
generation: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current population for evaluation.
|
||||
pub fn population(&self) -> &[PolicyKernel] {
|
||||
&self.population
|
||||
}
|
||||
|
||||
/// Get mutable reference to a kernel by index.
|
||||
pub fn kernel_mut(&mut self, index: usize) -> Option<&mut PolicyKernel> {
|
||||
self.population.get_mut(index)
|
||||
}
|
||||
|
||||
/// Evolve to next generation: select top performers, mutate, fill population.
|
||||
pub fn evolve(&mut self) {
|
||||
let mut rng = rand::thread_rng();
|
||||
self.generation += 1;
|
||||
|
||||
// Sort by cost-adjusted fitness (descending)
|
||||
self.population.sort_by(|a, b| {
|
||||
b.cost_adjusted_fitness()
|
||||
.partial_cmp(&a.cost_adjusted_fitness())
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// Track best
|
||||
if let Some(best) = self.population.first() {
|
||||
if self
|
||||
.best_kernel
|
||||
.as_ref()
|
||||
.map_or(true, |b| best.fitness() > b.fitness())
|
||||
{
|
||||
self.best_kernel = Some(best.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Elite selection: keep top 25%
|
||||
let elite_count = (self.pop_size / 4).max(1);
|
||||
let elites: Vec<PolicyKernel> = self.population[..elite_count].to_vec();
|
||||
|
||||
// Build next generation
|
||||
let mut next_gen = Vec::with_capacity(self.pop_size);
|
||||
|
||||
// Keep elites
|
||||
for elite in &elites {
|
||||
let mut kept = elite.clone();
|
||||
kept.id = format!("kernel_g{}_{}", self.generation, next_gen.len());
|
||||
kept.holdout_scores.clear();
|
||||
kept.total_cost = 0.0;
|
||||
kept.cycles = 0;
|
||||
next_gen.push(kept);
|
||||
}
|
||||
|
||||
// Fill rest with mutations and crossovers
|
||||
while next_gen.len() < self.pop_size {
|
||||
let parent_idx = rng.gen_range(0..elites.len());
|
||||
let child_id = format!("kernel_g{}_{}", self.generation, next_gen.len());
|
||||
|
||||
let child = if rng.gen::<f32>() < 0.3 && elites.len() > 1 {
|
||||
// Crossover
|
||||
let other_idx =
|
||||
(parent_idx + 1 + rng.gen_range(0..elites.len() - 1)) % elites.len();
|
||||
let mut child = PolicyKernel::new(child_id);
|
||||
child.knobs = elites[parent_idx]
|
||||
.knobs
|
||||
.crossover(&elites[other_idx].knobs, &mut rng);
|
||||
child.generation = self.generation;
|
||||
child.parent_id = Some(elites[parent_idx].id.clone());
|
||||
child
|
||||
} else {
|
||||
// Mutation
|
||||
elites[parent_idx].mutate(child_id, &mut rng)
|
||||
};
|
||||
|
||||
next_gen.push(child);
|
||||
}
|
||||
|
||||
self.population = next_gen;
|
||||
}
|
||||
|
||||
/// Get the best kernel found so far.
|
||||
pub fn best(&self) -> Option<&PolicyKernel> {
|
||||
self.best_kernel.as_ref()
|
||||
}
|
||||
|
||||
/// Current generation number.
|
||||
pub fn generation(&self) -> u32 {
|
||||
self.generation
|
||||
}
|
||||
|
||||
/// Get fitness statistics for the current population.
|
||||
pub fn stats(&self) -> PopulationStats {
|
||||
let fitnesses: Vec<f32> = self.population.iter().map(|k| k.fitness()).collect();
|
||||
let mean = fitnesses.iter().sum::<f32>() / fitnesses.len().max(1) as f32;
|
||||
let max = fitnesses.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let min = fitnesses.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||
let variance = fitnesses.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
|
||||
/ fitnesses.len().max(1) as f32;
|
||||
|
||||
PopulationStats {
|
||||
generation: self.generation,
|
||||
pop_size: self.population.len(),
|
||||
mean_fitness: mean,
|
||||
max_fitness: max,
|
||||
min_fitness: min,
|
||||
fitness_variance: variance,
|
||||
best_ever_fitness: self
|
||||
.best_kernel
|
||||
.as_ref()
|
||||
.map(|k| k.fitness())
|
||||
.unwrap_or(0.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics about the current population, as produced by
/// `PopulationSearch::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PopulationStats {
    /// Generation the snapshot was taken at.
    pub generation: u32,
    /// Number of kernels currently in the population.
    pub pop_size: usize,
    /// Mean raw fitness across the population.
    pub mean_fitness: f32,
    /// Highest raw fitness in the population.
    pub max_fitness: f32,
    /// Lowest raw fitness in the population.
    pub min_fitness: f32,
    /// Population variance of raw fitness.
    pub fitness_variance: f32,
    /// Raw fitness of the best kernel ever observed (0.0 if none yet).
    pub best_ever_fitness: f32,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_policy_knobs_default() {
        let defaults = PolicyKnobs::default_knobs();
        assert!(!defaults.skip_mode);
        assert!(defaults.prepass_enabled);
        assert!(defaults.speculation_threshold > 0.0);
    }

    #[test]
    fn test_policy_knobs_mutate() {
        let mut rng = rand::thread_rng();
        // Rate 1.0 perturbs every knob; exact values are random, but the
        // clamped ranges must always hold.
        let mutated = PolicyKnobs::default_knobs().mutate(&mut rng, 1.0);
        assert!((0.01..=0.5).contains(&mutated.speculation_threshold));
        assert!((0.01..=0.5).contains(&mutated.exploration_budget));
    }

    #[test]
    fn test_policy_kernel_fitness() {
        let mut kernel = PolicyKernel::new("test".into());
        assert_eq!(kernel.fitness(), 0.0);

        kernel.record_score(DomainId("d1".into()), 0.8, 1.0);
        kernel.record_score(DomainId("d2".into()), 0.6, 1.0);
        // Fitness is the mean holdout score: (0.8 + 0.6) / 2.
        assert!((kernel.fitness() - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_population_search_evolve() {
        let mut search = PopulationSearch::new(8);
        assert_eq!(search.population().len(), 8);

        // Give each kernel a distinct simulated holdout score.
        for idx in 0..8 {
            let score = 0.3 + (idx as f32) * 0.08;
            if let Some(kernel) = search.kernel_mut(idx) {
                kernel.record_score(DomainId("test".into()), score, 1.0);
            }
        }

        search.evolve();
        assert_eq!(search.population().len(), 8);
        assert_eq!(search.generation(), 1);
        assert!(search.best().is_some());
    }

    #[test]
    fn test_population_stats() {
        let mut search = PopulationSearch::new(4);

        for idx in 0..4 {
            if let Some(kernel) = search.kernel_mut(idx) {
                kernel.record_score(DomainId("test".into()), (idx as f32) * 0.25, 1.0);
            }
        }

        let stats = search.stats();
        assert_eq!(stats.pop_size, 4);
        // min <= mean <= max must always hold.
        assert!(stats.min_fitness <= stats.mean_fitness);
        assert!(stats.mean_fitness <= stats.max_fitness);
        assert!(stats.max_fitness >= stats.min_fitness);
    }

    #[test]
    fn test_crossover() {
        let parent_a = PolicyKnobs {
            skip_mode: true,
            prepass_enabled: false,
            speculation_threshold: 0.1,
            exploration_budget: 0.1,
            max_retries: 1,
            batch_size: 4,
            cost_decay: 0.8,
            confidence_floor: 0.5,
        };
        let parent_b = PolicyKnobs {
            skip_mode: false,
            prepass_enabled: true,
            speculation_threshold: 0.4,
            exploration_budget: 0.4,
            max_retries: 4,
            batch_size: 16,
            cost_decay: 0.95,
            confidence_floor: 0.9,
        };

        let mut rng = rand::thread_rng();
        let child = parent_a.crossover(&parent_b, &mut rng);

        // Uniform crossover: every field must come from one of the parents.
        assert!(matches!(child.max_retries, 1 | 4));
        assert!(matches!(child.batch_size, 4 | 16));
    }
}
|
||||
603
vendor/ruvector/crates/ruvector-domain-expansion/src/rust_synthesis.rs
vendored
Normal file
603
vendor/ruvector/crates/ruvector-domain-expansion/src/rust_synthesis.rs
vendored
Normal file
@@ -0,0 +1,603 @@
|
||||
//! Rust Program Synthesis Domain
|
||||
//!
|
||||
//! Generates tasks that require synthesizing Rust programs from specifications.
|
||||
//! Task types include:
|
||||
//!
|
||||
//! - **Transform**: Apply a function to data (map, filter, fold)
|
||||
//! - **DataStructure**: Implement a data structure with specific operations
|
||||
//! - **Algorithm**: Implement a named algorithm (sorting, searching, graph)
|
||||
//! - **TypeLevel**: Express constraints via Rust's type system
|
||||
//! - **Concurrency**: Safe concurrent data access patterns
|
||||
//!
|
||||
//! Solutions are evaluated on correctness (do test cases pass?),
|
||||
//! efficiency (complexity class), and elegance (idiomatic Rust patterns).
|
||||
|
||||
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Embedding dimension for the Rust synthesis domain: length of the feature
/// vector built by `extract_features` and reported by `embedding_dim`.
const EMBEDDING_DIM: usize = 64;
|
||||
|
||||
/// Categories of Rust synthesis tasks.
///
/// Serialized as part of [`RustTaskSpec`] via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RustTaskCategory {
    /// Transform data: map, filter, fold, scan.
    Transform,
    /// Implement a data structure with trait impls.
    DataStructure,
    /// Implement a named algorithm.
    Algorithm,
    /// Type-level programming: generics, trait bounds, associated types.
    TypeLevel,
    /// Concurrent programming: Arc, Mutex, channels, atomics.
    Concurrency,
}
|
||||
|
||||
/// Specification for a Rust synthesis task.
///
/// Produced by the `gen_*` helpers on [`RustSynthesisDomain`] and serialized
/// into `Task::spec` as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RustTaskSpec {
    /// Task category.
    pub category: RustTaskCategory,
    /// Function signature that must be implemented.
    pub signature: String,
    /// Natural language description of the required behavior.
    pub description: String,
    /// Test cases as (input_json, expected_output_json) pairs.
    pub test_cases: Vec<(String, String)>,
    /// Required traits the solution must implement (may be empty).
    pub required_traits: Vec<String>,
    /// Banned patterns (e.g., "unsafe", "unwrap").
    pub banned_patterns: Vec<String>,
    /// Expected complexity class (e.g., "O(n log n)"), if one is specified.
    pub expected_complexity: Option<String>,
}
|
||||
|
||||
/// Rust program synthesis domain.
///
/// Stateless apart from its stable identifier (`"rust_synthesis"`, set by
/// `new`).
pub struct RustSynthesisDomain {
    /// Stable domain identifier used to tag generated tasks and embeddings.
    id: DomainId,
}
|
||||
|
||||
impl RustSynthesisDomain {
|
||||
/// Create a new Rust synthesis domain.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: DomainId("rust_synthesis".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Generate a transform task at the given difficulty.
    ///
    /// Difficulty buckets:
    /// - `< 0.3`: element-wise map over `&[i64]` (op chosen at random from
    ///   double / negate / abs / square — this is the only RNG use here);
    /// - `< 0.7`: filter + fold (sum of positives);
    /// - otherwise: maximum contiguous subarray sum (Kadane's algorithm).
    ///
    /// All variants ban `unsafe` and expect O(n).
    fn gen_transform(&self, difficulty: f32, rng: &mut impl Rng) -> RustTaskSpec {
        // Tuple is (signature, description, test cases, complexity class).
        let (signature, description, tests, complexity) = if difficulty < 0.3 {
            // Easy: simple map
            let ops = ["double", "negate", "abs", "square"];
            let op = ops[rng.gen_range(0..ops.len())];
            (
                format!("fn {}(values: &[i64]) -> Vec<i64>", op),
                format!("Apply {} to each element in the slice.", op),
                // Hand-written test vectors per operation; `_` covers "square".
                match op {
                    "double" => vec![
                        ("[1, 2, 3]".into(), "[2, 4, 6]".into()),
                        ("[-1, 0, 5]".into(), "[-2, 0, 10]".into()),
                    ],
                    "negate" => vec![
                        ("[1, -2, 3]".into(), "[-1, 2, -3]".into()),
                        ("[0]".into(), "[0]".into()),
                    ],
                    "abs" => vec![
                        ("[-1, 2, -3]".into(), "[1, 2, 3]".into()),
                        ("[0, -0]".into(), "[0, 0]".into()),
                    ],
                    _ => vec![
                        ("[2, 3, 4]".into(), "[4, 9, 16]".into()),
                        ("[0, -1]".into(), "[0, 1]".into()),
                    ],
                },
                "O(n)",
            )
        } else if difficulty < 0.7 {
            // Medium: filter + fold combos
            (
                "fn sum_positives(values: &[i64]) -> i64".into(),
                "Sum all positive values in the slice.".into(),
                vec![
                    ("[1, -2, 3, -4, 5]".into(), "9".into()),
                    ("[-1, -2, -3]".into(), "0".into()),
                    ("[]".into(), "0".into()),
                ],
                "O(n)",
            )
        } else {
            // Hard: sliding window / scan
            (
                "fn max_subarray_sum(values: &[i64]) -> i64".into(),
                "Find the maximum sum contiguous subarray (Kadane's algorithm).".into(),
                vec![
                    ("[-2, 1, -3, 4, -1, 2, 1, -5, 4]".into(), "6".into()),
                    ("[-1, -2, -3]".into(), "-1".into()),
                    ("[5]".into(), "5".into()),
                ],
                "O(n)",
            )
        };

        RustTaskSpec {
            category: RustTaskCategory::Transform,
            signature,
            description,
            test_cases: tests,
            required_traits: Vec::new(),
            banned_patterns: vec!["unsafe".into()],
            expected_complexity: Some(complexity.into()),
        }
    }
|
||||
|
||||
    /// Generate a data structure task.
    ///
    /// Deterministic per difficulty bucket (the `_rng` parameter is unused,
    /// kept for signature parity with the other generators):
    /// - `< 0.4`: generic stack;
    /// - `< 0.7`: binary min-heap (std `BinaryHeap` banned);
    /// - otherwise: LRU cache with O(1) get/put.
    fn gen_data_structure(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec {
        if difficulty < 0.4 {
            // Easy: generic stack with the basic operations.
            RustTaskSpec {
                category: RustTaskCategory::DataStructure,
                signature: "struct Stack<T>".into(),
                description: "Implement a generic stack with push, pop, peek, is_empty, len."
                    .into(),
                test_cases: vec![
                    ("push(1); push(2); pop()".into(), "Some(2)".into()),
                    ("is_empty()".into(), "true".into()),
                    ("push(1); len()".into(), "1".into()),
                ],
                required_traits: vec!["Default".into()],
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O(1) per operation".into()),
            }
        } else if difficulty < 0.7 {
            // Medium: min-heap implemented from scratch (BinaryHeap banned).
            RustTaskSpec {
                category: RustTaskCategory::DataStructure,
                signature: "struct MinHeap<T: Ord>".into(),
                description: "Implement a binary min-heap with insert, extract_min, peek_min."
                    .into(),
                test_cases: vec![
                    (
                        "insert(3); insert(1); insert(2); extract_min()".into(),
                        "Some(1)".into(),
                    ),
                    ("peek_min() on empty".into(), "None".into()),
                ],
                required_traits: vec!["Default".into()],
                banned_patterns: vec!["unsafe".into(), "BinaryHeap".into()],
                expected_complexity: Some("O(log n) insert/extract".into()),
            }
        } else {
            // Hard: LRU cache; both test cases exercise capacity eviction.
            RustTaskSpec {
                category: RustTaskCategory::DataStructure,
                signature: "struct LRUCache<K: Hash + Eq, V>".into(),
                description: "Implement an LRU cache with get, put, and capacity eviction.".into(),
                test_cases: vec![
                    (
                        "cap=2; put(1,'a'); put(2,'b'); get(1); put(3,'c'); get(2)".into(),
                        "None".into(),
                    ),
                    (
                        "cap=1; put(1,'a'); put(2,'b'); get(1)".into(),
                        "None".into(),
                    ),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O(1) get/put".into()),
            }
        }
    }
|
||||
|
||||
    /// Generate an algorithm task.
    ///
    /// Deterministic per difficulty bucket (the `_rng` parameter is unused,
    /// kept for signature parity with the other generators):
    /// - `< 0.4`: binary search;
    /// - `< 0.7`: stable merge sort (`.sort` banned);
    /// - otherwise: Dijkstra's shortest path.
    fn gen_algorithm(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec {
        if difficulty < 0.4 {
            // Easy: binary search, including the empty-slice edge case.
            RustTaskSpec {
                category: RustTaskCategory::Algorithm,
                signature: "fn binary_search(sorted: &[i64], target: i64) -> Option<usize>".into(),
                description: "Implement binary search on a sorted slice.".into(),
                test_cases: vec![
                    ("[1,3,5,7,9], 5".into(), "Some(2)".into()),
                    ("[1,3,5,7,9], 4".into(), "None".into()),
                    ("[], 1".into(), "None".into()),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O(log n)".into()),
            }
        } else if difficulty < 0.7 {
            // Medium: merge sort; std `.sort` is banned to force a real impl.
            RustTaskSpec {
                category: RustTaskCategory::Algorithm,
                signature: "fn merge_sort(values: &mut [i64])".into(),
                description: "Implement stable merge sort in-place.".into(),
                test_cases: vec![
                    ("[3,1,4,1,5,9,2,6]".into(), "[1,1,2,3,4,5,6,9]".into()),
                    ("[1]".into(), "[1]".into()),
                    ("[]".into(), "[]".into()),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into(), ".sort".into()],
                expected_complexity: Some("O(n log n)".into()),
            }
        } else {
            // Hard: Dijkstra on an adjacency list of (neighbor, weight) pairs.
            RustTaskSpec {
                category: RustTaskCategory::Algorithm,
                signature: "fn shortest_path(adj: &[Vec<(usize, u64)>], src: usize, dst: usize) -> Option<u64>".into(),
                description: "Implement Dijkstra's shortest path on a weighted directed graph.".into(),
                test_cases: vec![
                    ("3 nodes, 0->1:2, 1->2:3, 0->2:10; src=0, dst=2".into(), "Some(5)".into()),
                    ("2 nodes, no edges; src=0, dst=1".into(), "None".into()),
                ],
                required_traits: Vec::new(),
                banned_patterns: vec!["unsafe".into()],
                expected_complexity: Some("O((V + E) log V)".into()),
            }
        }
    }
|
||||
|
||||
/// Extract structural features from a Rust solution for embedding.
|
||||
fn extract_features(&self, solution: &Solution) -> Vec<f32> {
|
||||
let code = &solution.content;
|
||||
let mut features = vec![0.0f32; EMBEDDING_DIM];
|
||||
|
||||
// Feature 0-7: Control flow complexity
|
||||
features[0] = code.matches("if ").count() as f32 / 10.0;
|
||||
features[1] = code.matches("for ").count() as f32 / 5.0;
|
||||
features[2] = code.matches("while ").count() as f32 / 5.0;
|
||||
features[3] = code.matches("match ").count() as f32 / 5.0;
|
||||
features[4] = code.matches("loop ").count() as f32 / 3.0;
|
||||
features[5] = code.matches("return ").count() as f32 / 5.0;
|
||||
features[6] = code.matches("break").count() as f32 / 3.0;
|
||||
features[7] = code.matches("continue").count() as f32 / 3.0;
|
||||
|
||||
// Feature 8-15: Type system usage
|
||||
features[8] = code.matches("impl ").count() as f32 / 5.0;
|
||||
features[9] = code.matches("trait ").count() as f32 / 3.0;
|
||||
features[10] = code.matches("struct ").count() as f32 / 3.0;
|
||||
features[11] = code.matches("enum ").count() as f32 / 3.0;
|
||||
features[12] = code.matches("where ").count() as f32 / 3.0;
|
||||
features[13] = code.matches("dyn ").count() as f32 / 3.0;
|
||||
features[14] = code.matches("Box<").count() as f32 / 3.0;
|
||||
features[15] = code.matches("Rc<").count() as f32 / 3.0;
|
||||
|
||||
// Feature 16-23: Functional patterns
|
||||
features[16] = code.matches(".map(").count() as f32 / 5.0;
|
||||
features[17] = code.matches(".filter(").count() as f32 / 5.0;
|
||||
features[18] = code.matches(".fold(").count() as f32 / 3.0;
|
||||
features[19] = code.matches(".collect()").count() as f32 / 3.0;
|
||||
features[20] = code.matches(".iter()").count() as f32 / 5.0;
|
||||
features[21] = code.matches("|").count() as f32 / 10.0; // closures
|
||||
features[22] = code.matches("Some(").count() as f32 / 5.0;
|
||||
features[23] = code.matches("None").count() as f32 / 5.0;
|
||||
|
||||
// Feature 24-31: Memory/ownership patterns
|
||||
features[24] = code.matches("&mut ").count() as f32 / 5.0;
|
||||
features[25] = code.matches("&self").count() as f32 / 5.0;
|
||||
features[26] = code.matches("mut ").count() as f32 / 10.0;
|
||||
features[27] = code.matches(".clone()").count() as f32 / 5.0;
|
||||
features[28] = code.matches("Vec<").count() as f32 / 5.0;
|
||||
features[29] = code.matches("HashMap").count() as f32 / 3.0;
|
||||
features[30] = code.matches("String").count() as f32 / 5.0;
|
||||
features[31] = code.matches("Result<").count() as f32 / 3.0;
|
||||
|
||||
// Feature 32-39: Concurrency patterns
|
||||
features[32] = code.matches("Arc<").count() as f32 / 3.0;
|
||||
features[33] = code.matches("Mutex<").count() as f32 / 3.0;
|
||||
features[34] = code.matches("RwLock").count() as f32 / 3.0;
|
||||
features[35] = code.matches("async ").count() as f32 / 3.0;
|
||||
features[36] = code.matches("await").count() as f32 / 5.0;
|
||||
features[37] = code.matches("spawn").count() as f32 / 3.0;
|
||||
features[38] = code.matches("channel").count() as f32 / 3.0;
|
||||
features[39] = code.matches("Atomic").count() as f32 / 3.0;
|
||||
|
||||
// Feature 40-47: Code structure metrics
|
||||
let lines: Vec<&str> = code.lines().collect();
|
||||
features[40] = (lines.len() as f32) / 100.0;
|
||||
features[41] = lines.iter().filter(|l| l.trim().is_empty()).count() as f32
|
||||
/ (lines.len().max(1) as f32);
|
||||
features[42] = code.matches("fn ").count() as f32 / 10.0;
|
||||
features[43] = code.matches("pub ").count() as f32 / 10.0;
|
||||
features[44] = code.matches("mod ").count() as f32 / 5.0;
|
||||
features[45] = code.matches("use ").count() as f32 / 10.0;
|
||||
features[46] = code.matches("#[").count() as f32 / 5.0; // attributes
|
||||
features[47] = code.matches("///").count() as f32 / 10.0; // doc comments
|
||||
|
||||
// Feature 48-55: Error handling patterns
|
||||
features[48] = code.matches("unwrap()").count() as f32 / 5.0;
|
||||
features[49] = code.matches("expect(").count() as f32 / 5.0;
|
||||
features[50] = code.matches("?;").count() as f32 / 5.0; // error propagation
|
||||
features[51] = code.matches("Err(").count() as f32 / 5.0;
|
||||
features[52] = code.matches("Ok(").count() as f32 / 5.0;
|
||||
features[53] = code.matches("panic!").count() as f32 / 3.0;
|
||||
features[54] = code.matches("assert").count() as f32 / 5.0;
|
||||
features[55] = code.matches("debug_assert").count() as f32 / 3.0;
|
||||
|
||||
// Feature 56-63: Algorithm indicators
|
||||
features[56] = code.matches("sort").count() as f32 / 3.0;
|
||||
features[57] = code.matches("binary_search").count() as f32 / 2.0;
|
||||
features[58] = code.matches("push").count() as f32 / 5.0;
|
||||
features[59] = code.matches("pop").count() as f32 / 5.0;
|
||||
features[60] = code.matches("swap").count() as f32 / 5.0;
|
||||
features[61] = code.matches("len()").count() as f32 / 5.0;
|
||||
features[62] = code.matches("is_empty").count() as f32 / 3.0;
|
||||
features[63] = code.matches("contains").count() as f32 / 3.0;
|
||||
|
||||
// Normalize to unit length
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-10 {
|
||||
for f in &mut features {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Score a Rust solution based on pattern matching heuristics.
|
||||
fn score_solution(&self, spec: &RustTaskSpec, solution: &Solution) -> Evaluation {
|
||||
let code = &solution.content;
|
||||
let mut correctness = 0.0f32;
|
||||
let mut efficiency = 0.5f32;
|
||||
let mut elegance = 0.5f32;
|
||||
let mut notes = Vec::new();
|
||||
|
||||
// Check for banned patterns
|
||||
let mut banned_found = false;
|
||||
for pattern in &spec.banned_patterns {
|
||||
if code.contains(pattern.as_str()) {
|
||||
notes.push(format!("Banned pattern found: {}", pattern));
|
||||
banned_found = true;
|
||||
}
|
||||
}
|
||||
|
||||
if banned_found {
|
||||
elegance *= 0.5;
|
||||
}
|
||||
|
||||
// Check that the solution contains the expected signature
|
||||
let sig_name = spec
|
||||
.signature
|
||||
.split('(')
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.split_whitespace()
|
||||
.last()
|
||||
.unwrap_or("");
|
||||
|
||||
if code.contains(sig_name) {
|
||||
correctness += 0.3;
|
||||
} else {
|
||||
notes.push(format!("Missing expected identifier: {}", sig_name));
|
||||
}
|
||||
|
||||
// Check for fn definition
|
||||
if code.contains("fn ") {
|
||||
correctness += 0.2;
|
||||
}
|
||||
|
||||
// Check for test case coverage hints
|
||||
let test_coverage = spec
|
||||
.test_cases
|
||||
.iter()
|
||||
.filter(|(input, _)| {
|
||||
// Heuristic: solution likely handles the input pattern
|
||||
let key_tokens: Vec<&str> = input.split(|c: char| !c.is_alphanumeric()).collect();
|
||||
key_tokens.iter().any(|t| !t.is_empty() && code.contains(t))
|
||||
})
|
||||
.count() as f32
|
||||
/ spec.test_cases.len().max(1) as f32;
|
||||
correctness += test_coverage * 0.5;
|
||||
correctness = correctness.clamp(0.0, 1.0);
|
||||
|
||||
// Efficiency: penalize obviously quadratic patterns
|
||||
let nested_loops = code.matches("for ").count() > 1 && code.matches("for ").count() > 2;
|
||||
if nested_loops {
|
||||
if let Some(ref expected) = spec.expected_complexity {
|
||||
if expected.contains("O(n)") || expected.contains("O(log") {
|
||||
efficiency *= 0.5;
|
||||
notes.push("Possible O(n^2) when O(n) or O(log n) expected".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Elegance: favor idiomatic Rust
|
||||
let iterator_usage = code.matches(".iter()").count()
|
||||
+ code.matches(".map(").count()
|
||||
+ code.matches(".filter(").count()
|
||||
+ code.matches(".fold(").count();
|
||||
if iterator_usage > 0 {
|
||||
elegance += 0.2;
|
||||
}
|
||||
|
||||
// Penalize excessive unwrap
|
||||
let unwrap_count = code.matches("unwrap()").count();
|
||||
if unwrap_count > 3 {
|
||||
elegance -= 0.2;
|
||||
notes.push("Excessive unwrap() usage".into());
|
||||
}
|
||||
|
||||
// Proper error handling bonus
|
||||
if code.contains("Result<") || code.contains("?;") {
|
||||
elegance += 0.1;
|
||||
}
|
||||
|
||||
elegance = elegance.clamp(0.0, 1.0);
|
||||
|
||||
// Constraint results
|
||||
let constraint_results = spec
|
||||
.banned_patterns
|
||||
.iter()
|
||||
.map(|p| !code.contains(p.as_str()))
|
||||
.collect();
|
||||
|
||||
let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
|
||||
|
||||
Evaluation {
|
||||
score: score.clamp(0.0, 1.0),
|
||||
correctness,
|
||||
efficiency,
|
||||
elegance,
|
||||
constraint_results,
|
||||
notes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RustSynthesisDomain {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Domain for RustSynthesisDomain {
    // Stable identifier for this domain instance.
    fn id(&self) -> &DomainId {
        &self.id
    }

    fn name(&self) -> &str {
        "Rust Program Synthesis"
    }

    // Generate `count` tasks at the given difficulty (clamped to [0, 1]).
    // Category mix per task, by a uniform roll: ~40% transform,
    // ~30% data structure, ~30% algorithm.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);

        (0..count)
            .map(|i| {
                let category_roll: f32 = rng.gen();
                let spec = if category_roll < 0.4 {
                    self.gen_transform(difficulty, &mut rng)
                } else if category_roll < 0.7 {
                    self.gen_data_structure(difficulty, &mut rng)
                } else {
                    self.gen_algorithm(difficulty, &mut rng)
                };

                Task {
                    // e.g. "rust_synth_3_d50" for task 3 at difficulty 0.5.
                    id: format!("rust_synth_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    // NOTE(review): a serialization failure silently degrades
                    // to a null spec; evaluate() will then report "Invalid
                    // task spec" rather than failing here.
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: spec.banned_patterns.clone(),
                }
            })
            .collect()
    }

    // Decode the task spec back into a RustTaskSpec and delegate to the
    // heuristic scorer; an undecodable spec scores zero with a note.
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        let spec: RustTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        self.score_solution(&spec, solution)
    }

    // Embed a solution as a unit-norm structural feature vector.
    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }

    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }

    // Canonical solutions are only provided for Transform tasks; all other
    // categories return None.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: RustTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;

        let content = match spec.category {
            RustTaskCategory::Transform => {
                if spec.signature.contains("sum_positives") {
                    "fn sum_positives(values: &[i64]) -> i64 {\n values.iter().filter(|&&x| x > 0).sum()\n}".to_string()
                } else if spec.signature.contains("max_subarray_sum") {
                    // Kadane's algorithm as reference text.
                    "fn max_subarray_sum(values: &[i64]) -> i64 {\n let mut max_so_far = values[0];\n let mut max_ending = values[0];\n for &v in &values[1..] {\n max_ending = v.max(max_ending + v);\n max_so_far = max_so_far.max(max_ending);\n }\n max_so_far\n}".to_string()
                } else {
                    // Fallback: skeleton built from the task's own signature.
                    format!(
                        "{} {{\n values.iter().map(|&x| x /* TODO */).collect()\n}}",
                        spec.signature
                    )
                }
            }
            _ => return None,
        };

        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::Value::Null,
        })
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Generation yields the requested count, tagged with this domain's id
    // and the (already in-range) difficulty.
    #[test]
    fn test_generate_tasks() {
        let domain = RustSynthesisDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        for task in &tasks {
            assert_eq!(task.domain_id, domain.id);
            assert!((task.difficulty - 0.5).abs() < 1e-6);
        }
    }

    // An idiomatic iterator-based solution should score above zero
    // regardless of which task category was rolled.
    #[test]
    fn test_evaluate_good_solution() {
        let domain = RustSynthesisDomain::new();
        let tasks = domain.generate_tasks(1, 0.0);
        let task = &tasks[0];

        let solution = Solution {
            task_id: task.id.clone(),
            content: "fn double(values: &[i64]) -> Vec<i64> {\n values.iter().map(|&x| x * 2).collect()\n}".to_string(),
            data: serde_json::Value::Null,
        };

        let eval = domain.evaluate(task, &solution);
        assert!(eval.score > 0.0);
    }

    // Embedding dimensionality must match the declared EMBEDDING_DIM.
    #[test]
    fn test_embed_produces_correct_dim() {
        let domain = RustSynthesisDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "fn foo() { let x = 1; }".into(),
            data: serde_json::Value::Null,
        };
        let embedding = domain.embed(&solution);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
        assert_eq!(embedding.vector.len(), EMBEDDING_DIM);
    }

    // Non-trivial code must embed to a unit-length vector.
    #[test]
    fn test_embedding_normalized() {
        let domain = RustSynthesisDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "fn foo() { for i in 0..10 { if i > 5 { println!(\"{}\", i); } } }".into(),
            data: serde_json::Value::Null,
        };
        let embedding = domain.embed(&solution);
        let norm: f32 = embedding.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4);
    }

    // Specs at both ends of the difficulty range deserialize cleanly and
    // carry a non-empty signature.
    #[test]
    fn test_difficulty_range() {
        let domain = RustSynthesisDomain::new();
        // Easy tasks
        let easy = domain.generate_tasks(3, 0.1);
        for t in &easy {
            let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            assert!(!spec.signature.is_empty());
        }
        // Hard tasks
        let hard = domain.generate_tasks(3, 0.9);
        for t in &hard {
            let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            assert!(!spec.signature.is_empty());
        }
    }
}
|
||||
715
vendor/ruvector/crates/ruvector-domain-expansion/src/rvf_bridge.rs
vendored
Normal file
715
vendor/ruvector/crates/ruvector-domain-expansion/src/rvf_bridge.rs
vendored
Normal file
@@ -0,0 +1,715 @@
|
||||
//! RVF Integration Bridge
|
||||
//!
|
||||
//! Connects the domain expansion engine to the RuVector Format (RVF):
|
||||
//! - Serializes `TransferPrior`, `PolicyKernel`, `CostCurve` into RVF segments
|
||||
//! - Creates SHAKE-256 witness chains for transfer verification
|
||||
//! - Packages domain expansion artifacts into AGI container TLV entries
|
||||
//! - Bridges priors to/from the rvf-solver-wasm `PolicyKernel`
|
||||
//!
|
||||
//! Requires the `rvf` feature to be enabled.
|
||||
|
||||
use rvf_types::{SegmentFlags, SegmentType};
|
||||
use rvf_wire::reader::{read_segment, validate_segment};
|
||||
use rvf_wire::writer::write_segment;
|
||||
|
||||
use crate::cost_curve::{AccelerationScoreboard, CostCurve};
|
||||
use crate::domain::DomainId;
|
||||
use crate::policy_kernel::PolicyKernel;
|
||||
use crate::transfer::{ArmId, BetaParams, ContextBucket, MetaThompsonEngine, TransferPrior};
|
||||
|
||||
// ─── Wire-format wrappers ───────────────────────────────────────────────────
|
||||
//
|
||||
// JSON requires string keys for objects. TransferPrior uses HashMap<ContextBucket, _>
|
||||
// which can't be directly serialized. These wrappers convert to/from Vec<(K,V)> form.
|
||||
|
||||
/// Wire-format representation of a TransferPrior (JSON-safe).
#[derive(serde::Serialize, serde::Deserialize)]
struct WireTransferPrior {
    // Domain the prior was trained in.
    source_domain: DomainId,
    // HashMap fields are flattened to Vec-of-tuples: JSON object keys must
    // be strings, and ContextBucket is a struct key.
    bucket_priors: Vec<(ContextBucket, Vec<(ArmId, BetaParams)>)>,
    cost_ema_priors: Vec<(ContextBucket, f32)>,
    // Number of training cycles behind these priors.
    training_cycles: u64,
    // String witness hash carried through serialization unchanged.
    witness_hash: String,
}
|
||||
|
||||
impl From<&TransferPrior> for WireTransferPrior {
    // Flatten the nested HashMaps into Vec-of-tuples. Order is whatever the
    // map iterators yield (not stable across runs); re-import rebuilds maps,
    // so ordering does not affect round-tripping.
    fn from(p: &TransferPrior) -> Self {
        Self {
            source_domain: p.source_domain.clone(),
            bucket_priors: p
                .bucket_priors
                .iter()
                .map(|(b, arms)| {
                    let arm_vec: Vec<(ArmId, BetaParams)> =
                        arms.iter().map(|(a, bp)| (a.clone(), bp.clone())).collect();
                    (b.clone(), arm_vec)
                })
                .collect(),
            cost_ema_priors: p
                .cost_ema_priors
                .iter()
                .map(|(b, c)| (b.clone(), *c))
                .collect(),
            training_cycles: p.training_cycles,
            witness_hash: p.witness_hash.clone(),
        }
    }
}
|
||||
|
||||
impl From<WireTransferPrior> for TransferPrior {
    /// Rebuild the in-memory maps from the Vec-of-tuples wire form.
    fn from(w: WireTransferPrior) -> Self {
        // Each bucket's arm list collects straight into its HashMap; the
        // outer collect rebuilds the bucket map in one pass.
        let bucket_priors = w
            .bucket_priors
            .into_iter()
            .map(|(bucket, arms)| (bucket, arms.into_iter().collect()))
            .collect();
        Self {
            source_domain: w.source_domain,
            bucket_priors,
            cost_ema_priors: w.cost_ema_priors.into_iter().collect(),
            training_cycles: w.training_cycles,
            witness_hash: w.witness_hash,
        }
    }
}
|
||||
|
||||
/// Wire-format representation of a PolicyKernel (JSON-safe).
#[derive(serde::Serialize, serde::Deserialize)]
struct WirePolicyKernel {
    // Unique kernel identifier.
    id: String,
    // Policy knob settings, serialized as-is.
    knobs: crate::policy_kernel::PolicyKnobs,
    // HashMap<DomainId, f32> flattened to tuples for JSON object-key safety.
    holdout_scores: Vec<(DomainId, f32)>,
    total_cost: f32,
    cycles: u64,
    generation: u32,
    // Lineage pointer to the parent kernel, if any.
    parent_id: Option<String>,
    replay_verified: bool,
}
|
||||
|
||||
impl From<&PolicyKernel> for WirePolicyKernel {
    // Clone into wire form; only holdout_scores needs restructuring
    // (HashMap -> Vec of tuples).
    fn from(k: &PolicyKernel) -> Self {
        Self {
            id: k.id.clone(),
            knobs: k.knobs.clone(),
            holdout_scores: k
                .holdout_scores
                .iter()
                .map(|(d, s)| (d.clone(), *s))
                .collect(),
            total_cost: k.total_cost,
            cycles: k.cycles,
            generation: k.generation,
            parent_id: k.parent_id.clone(),
            replay_verified: k.replay_verified,
        }
    }
}
|
||||
|
||||
impl From<WirePolicyKernel> for PolicyKernel {
    // Inverse of the serialization: collect the tuple list back into the
    // holdout-score map; everything else moves across unchanged.
    fn from(w: WirePolicyKernel) -> Self {
        Self {
            id: w.id,
            knobs: w.knobs,
            holdout_scores: w.holdout_scores.into_iter().collect(),
            total_cost: w.total_cost,
            cycles: w.cycles,
            generation: w.generation,
            parent_id: w.parent_id,
            replay_verified: w.replay_verified,
        }
    }
}
|
||||
|
||||
// ─── Segment serialization ──────────────────────────────────────────────────
|
||||
|
||||
/// Serialize a `TransferPrior` into an RVF TRANSFER_PRIOR segment.
///
/// Wire format: JSON payload (using Vec-of-tuples for map keys) inside a
/// 64-byte-aligned RVF segment. Type: `SegmentType::TransferPrior` (0x30).
pub fn transfer_prior_to_segment(prior: &TransferPrior, segment_id: u64) -> Vec<u8> {
    let wire: WireTransferPrior = prior.into();
    // NOTE(review): the expect assumes the priors contain no non-finite
    // floats — serde_json rejects NaN/inf; confirm upstream values are
    // always finite.
    let payload = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
    write_segment(
        SegmentType::TransferPrior as u8,
        &payload,
        SegmentFlags::empty(),
        segment_id,
    )
}
|
||||
|
||||
/// Deserialize a `TransferPrior` from an RVF segment's raw bytes.
///
/// Validates the segment header, checks the content hash, and deserializes
/// the JSON payload.
pub fn transfer_prior_from_segment(data: &[u8]) -> Result<TransferPrior, RvfBridgeError> {
    let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
    // Reject payloads written under a different segment type before
    // attempting any parsing.
    if header.seg_type != SegmentType::TransferPrior as u8 {
        return Err(RvfBridgeError::WrongSegmentType {
            expected: SegmentType::TransferPrior as u8,
            got: header.seg_type,
        });
    }
    // Content-hash validation happens before JSON decoding.
    validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
    let wire: WireTransferPrior = serde_json::from_slice(payload).map_err(RvfBridgeError::Json)?;
    Ok(wire.into())
}
|
||||
|
||||
/// Serialize a `PolicyKernel` into an RVF POLICY_KERNEL segment.
pub fn policy_kernel_to_segment(kernel: &PolicyKernel, segment_id: u64) -> Vec<u8> {
    let wire: WirePolicyKernel = kernel.into();
    // NOTE(review): as with the prior serializer, this expect assumes no
    // non-finite floats in the kernel scores.
    let payload = serde_json::to_vec(&wire).expect("WirePolicyKernel serialization cannot fail");
    write_segment(
        SegmentType::PolicyKernel as u8,
        &payload,
        SegmentFlags::empty(),
        segment_id,
    )
}
|
||||
|
||||
/// Deserialize a `PolicyKernel` from an RVF segment.
///
/// Validates the header's segment type and content hash before decoding.
pub fn policy_kernel_from_segment(data: &[u8]) -> Result<PolicyKernel, RvfBridgeError> {
    let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
    if header.seg_type != SegmentType::PolicyKernel as u8 {
        return Err(RvfBridgeError::WrongSegmentType {
            expected: SegmentType::PolicyKernel as u8,
            got: header.seg_type,
        });
    }
    validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
    let wire: WirePolicyKernel = serde_json::from_slice(payload).map_err(RvfBridgeError::Json)?;
    Ok(wire.into())
}
|
||||
|
||||
/// Serialize a `CostCurve` into an RVF COST_CURVE segment.
///
/// CostCurve serializes directly (no wire wrapper needed).
pub fn cost_curve_to_segment(curve: &CostCurve, segment_id: u64) -> Vec<u8> {
    let payload = serde_json::to_vec(curve).expect("CostCurve serialization cannot fail");
    write_segment(
        SegmentType::CostCurve as u8,
        &payload,
        SegmentFlags::empty(),
        segment_id,
    )
}
|
||||
|
||||
/// Deserialize a `CostCurve` from an RVF segment.
///
/// Validates the header's segment type and content hash before decoding.
pub fn cost_curve_from_segment(data: &[u8]) -> Result<CostCurve, RvfBridgeError> {
    let (header, payload) = read_segment(data).map_err(RvfBridgeError::Rvf)?;
    if header.seg_type != SegmentType::CostCurve as u8 {
        return Err(RvfBridgeError::WrongSegmentType {
            expected: SegmentType::CostCurve as u8,
            got: header.seg_type,
        });
    }
    validate_segment(&header, payload).map_err(RvfBridgeError::Rvf)?;
    serde_json::from_slice(payload).map_err(RvfBridgeError::Json)
}
|
||||
|
||||
// ─── Witness chain ──────────────────────────────────────────────────────────
|
||||
|
||||
// Witness type constants for domain expansion operations (0x10–0x12).

/// Witness type for a transfer-prior application event.
pub const WITNESS_TRANSFER: u8 = 0x10;
/// Witness type for policy kernel promotion.
pub const WITNESS_POLICY_PROMOTION: u8 = 0x11;
/// Witness type for cost curve convergence checkpoint.
pub const WITNESS_CONVERGENCE: u8 = 0x12;
|
||||
|
||||
/// Create a SHAKE-256 witness hash for a transfer prior.
///
/// The hash covers the full JSON wire serialization of the prior: source
/// domain, bucket priors, cost EMAs, training cycles, and the existing
/// string `witness_hash` field itself. This replaces the old string-based
/// `witness_hash` field.
pub fn compute_transfer_witness_hash(prior: &TransferPrior) -> [u8; 32] {
    let wire: WireTransferPrior = prior.into();
    // NOTE(review): the Vec-of-tuples wire form inherits HashMap iteration
    // order, which is not stable across processes — two hashes of the same
    // prior may differ. Confirm whether canonical (sorted) ordering is
    // required for verification to be reproducible.
    let payload = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
    rvf_crypto::shake256_256(&payload)
}
|
||||
|
||||
/// Build witness entries for a transfer verification event.
///
/// Returns entries suitable for `rvf_crypto::create_witness_chain()`.
pub fn build_transfer_witness_entries(
    prior: &TransferPrior,
    source: &DomainId,
    target: &DomainId,
    acceleration_factor: f32,
    timestamp_ns: u64,
) -> Vec<rvf_crypto::WitnessEntry> {
    let mut entries = Vec::with_capacity(2);

    // Entry 1: Transfer prior hash
    let prior_hash = compute_transfer_witness_hash(prior);
    entries.push(rvf_crypto::WitnessEntry {
        // prev_hash left zeroed in both entries; the chain links are
        // filled in by create_witness_chain.
        prev_hash: [0u8; 32],
        action_hash: prior_hash,
        timestamp_ns,
        witness_type: WITNESS_TRANSFER,
    });

    // Entry 2: Acceleration verification (hash of source→target + factor)
    let accel_payload = format!(
        "{}->{}:accel={:.6}",
        source.0, target.0, acceleration_factor
    );
    let accel_hash = rvf_crypto::shake256_256(accel_payload.as_bytes());
    entries.push(rvf_crypto::WitnessEntry {
        prev_hash: [0u8; 32], // chaining handled by create_witness_chain
        action_hash: accel_hash,
        // +1ns keeps the second entry strictly after the first.
        timestamp_ns: timestamp_ns + 1,
        witness_type: WITNESS_CONVERGENCE,
    });

    entries
}
|
||||
|
||||
// ─── AGI Container TLV packaging ────────────────────────────────────────────
|
||||
|
||||
/// A TLV (Tag-Length-Value) entry for AGI container manifest packaging.
///
/// Serialized by `encode_tlv_entries` as `[tag: u16 LE][len: u32 LE][value]`.
#[derive(Debug, Clone)]
pub struct AgiTlvEntry {
    /// TLV tag (see `AGI_TAG_*` constants in rvf-types).
    pub tag: u16,
    /// Serialized value payload.
    pub value: Vec<u8>,
}
|
||||
|
||||
/// Package domain expansion artifacts into AGI container TLV entries.
|
||||
///
|
||||
/// Returns a vector of TLV entries ready for inclusion in an AGI container
|
||||
/// manifest segment. Each entry uses the corresponding `AGI_TAG_*` constant.
|
||||
pub fn package_for_agi_container(
|
||||
priors: &[TransferPrior],
|
||||
kernels: &[PolicyKernel],
|
||||
scoreboard: &AccelerationScoreboard,
|
||||
) -> Vec<AgiTlvEntry> {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
// Transfer priors (use wire format for JSON-safe serialization)
|
||||
for prior in priors {
|
||||
let wire: WireTransferPrior = prior.into();
|
||||
let value = serde_json::to_vec(&wire).expect("WireTransferPrior serialization cannot fail");
|
||||
entries.push(AgiTlvEntry {
|
||||
tag: rvf_types::AGI_TAG_TRANSFER_PRIOR,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
// Policy kernels (use wire format for JSON-safe serialization)
|
||||
for kernel in kernels {
|
||||
let wire: WirePolicyKernel = kernel.into();
|
||||
let value = serde_json::to_vec(&wire).expect("WirePolicyKernel serialization cannot fail");
|
||||
entries.push(AgiTlvEntry {
|
||||
tag: rvf_types::AGI_TAG_POLICY_KERNEL,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
// Cost curves from the scoreboard
|
||||
for curve in scoreboard.curves.values() {
|
||||
let value = serde_json::to_vec(curve).expect("CostCurve serialization cannot fail");
|
||||
entries.push(AgiTlvEntry {
|
||||
tag: rvf_types::AGI_TAG_COST_CURVE,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
entries
|
||||
}
|
||||
|
||||
/// Encode TLV entries into a binary payload for inclusion in a META segment.
|
||||
///
|
||||
/// Wire format per entry: `[tag: u16 LE][length: u32 LE][value: length bytes]`
|
||||
pub fn encode_tlv_entries(entries: &[AgiTlvEntry]) -> Vec<u8> {
|
||||
let total_size: usize = entries.iter().map(|e| 6 + e.value.len()).sum();
|
||||
let mut buf = Vec::with_capacity(total_size);
|
||||
for entry in entries {
|
||||
buf.extend_from_slice(&entry.tag.to_le_bytes());
|
||||
buf.extend_from_slice(&(entry.value.len() as u32).to_le_bytes());
|
||||
buf.extend_from_slice(&entry.value);
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Decode TLV entries from a binary payload.
|
||||
pub fn decode_tlv_entries(data: &[u8]) -> Result<Vec<AgiTlvEntry>, RvfBridgeError> {
|
||||
let mut entries = Vec::new();
|
||||
let mut offset = 0;
|
||||
while offset + 6 <= data.len() {
|
||||
let tag = u16::from_le_bytes([data[offset], data[offset + 1]]);
|
||||
let length = u32::from_le_bytes([
|
||||
data[offset + 2],
|
||||
data[offset + 3],
|
||||
data[offset + 4],
|
||||
data[offset + 5],
|
||||
]) as usize;
|
||||
offset += 6;
|
||||
if offset + length > data.len() {
|
||||
return Err(RvfBridgeError::TruncatedTlv);
|
||||
}
|
||||
entries.push(AgiTlvEntry {
|
||||
tag,
|
||||
value: data[offset..offset + length].to_vec(),
|
||||
});
|
||||
offset += length;
|
||||
}
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
// ─── Solver bridge ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Compact prior exchange format bridging domain expansion's `MetaThompsonEngine`
/// to the rvf-solver-wasm `PolicyKernel`.
///
/// The solver-wasm uses per-bucket `SkipModeStats` with `(alpha_safety, beta_safety)`
/// and `cost_ema`. The domain expansion uses per-bucket `BetaParams` with
/// `(alpha, beta)` and `cost_ema_priors`. This type converts between them.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SolverPriorExchange {
    /// Context bucket key (e.g. "medium:some:clean").
    pub bucket_key: String,
    /// Per-arm alpha/beta pairs mapping arm name to (alpha, beta).
    pub arm_params: Vec<(String, f32, f32)>,
    /// Cost EMA for this bucket.
    pub cost_ema: f32,
    /// Training cycle count for confidence estimation.
    pub training_cycles: u64,
}
|
||||
|
||||
/// Extract solver-compatible prior exchange data from the Thompson engine.
///
/// Flattens the domain expansion's hierarchical buckets into the solver's
/// flat "range:distractor:noise" keys for the specified domain.
pub fn extract_solver_priors(
    engine: &MetaThompsonEngine,
    domain_id: &DomainId,
) -> Vec<SolverPriorExchange> {
    // No prior recorded for this domain means nothing to export.
    let prior = match engine.extract_prior(domain_id) {
        Some(p) => p,
        None => return Vec::new(),
    };

    prior
        .bucket_priors
        .iter()
        .map(|(bucket, arms)| {
            // Flat key is "<difficulty_tier>:<category>" — the inverse of
            // the splitn(2, ':') parse in import_solver_priors.
            let bucket_key = format!("{}:{}", bucket.difficulty_tier, bucket.category);
            let arm_params: Vec<(String, f32, f32)> = arms
                .iter()
                .map(|(arm, params)| (arm.0.clone(), params.alpha, params.beta))
                .collect();
            // A bucket with no recorded cost EMA exports the neutral 1.0.
            let cost_ema = prior.cost_ema_priors.get(bucket).copied().unwrap_or(1.0);

            SolverPriorExchange {
                bucket_key,
                arm_params,
                cost_ema,
                training_cycles: prior.training_cycles,
            }
        })
        .collect()
}
|
||||
|
||||
/// Import solver prior exchange data back into the Thompson engine.
|
||||
///
|
||||
/// Seeds the specified domain with the exchanged priors, enabling
|
||||
/// cross-system transfer.
|
||||
pub fn import_solver_priors(
|
||||
engine: &mut MetaThompsonEngine,
|
||||
domain_id: &DomainId,
|
||||
exchanges: &[SolverPriorExchange],
|
||||
) {
|
||||
// Build a synthetic TransferPrior from the exchange data.
|
||||
let mut prior = TransferPrior::uniform(domain_id.clone());
|
||||
|
||||
for exchange in exchanges {
|
||||
let parts: Vec<&str> = exchange.bucket_key.splitn(2, ':').collect();
|
||||
let bucket = ContextBucket {
|
||||
difficulty_tier: parts.first().unwrap_or(&"medium").to_string(),
|
||||
category: parts.get(1).unwrap_or(&"general").to_string(),
|
||||
};
|
||||
|
||||
let mut arm_map = std::collections::HashMap::new();
|
||||
for (arm_name, alpha, beta) in &exchange.arm_params {
|
||||
arm_map.insert(
|
||||
crate::transfer::ArmId(arm_name.clone()),
|
||||
BetaParams {
|
||||
alpha: *alpha,
|
||||
beta: *beta,
|
||||
},
|
||||
);
|
||||
}
|
||||
prior.bucket_priors.insert(bucket.clone(), arm_map);
|
||||
prior.cost_ema_priors.insert(bucket, exchange.cost_ema);
|
||||
prior.training_cycles = exchange.training_cycles;
|
||||
}
|
||||
|
||||
engine.init_domain_with_transfer(domain_id.clone(), &prior);
|
||||
}
|
||||
|
||||
// ─── Multi-segment file assembly ────────────────────────────────────────────
|
||||
|
||||
/// Assemble a complete RVF byte stream containing all domain expansion segments.
|
||||
///
|
||||
/// Outputs concatenated segments: transfer priors, then policy kernels, then
|
||||
/// cost curves. Each gets a unique segment ID starting from `base_segment_id`.
|
||||
///
|
||||
/// The returned bytes can be appended to an existing RVF file or written as
|
||||
/// a standalone domain expansion archive.
|
||||
pub fn assemble_domain_expansion_segments(
|
||||
priors: &[TransferPrior],
|
||||
kernels: &[PolicyKernel],
|
||||
curves: &[CostCurve],
|
||||
base_segment_id: u64,
|
||||
) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
let mut seg_id = base_segment_id;
|
||||
|
||||
for prior in priors {
|
||||
buf.extend_from_slice(&transfer_prior_to_segment(prior, seg_id));
|
||||
seg_id += 1;
|
||||
}
|
||||
for kernel in kernels {
|
||||
buf.extend_from_slice(&policy_kernel_to_segment(kernel, seg_id));
|
||||
seg_id += 1;
|
||||
}
|
||||
for curve in curves {
|
||||
buf.extend_from_slice(&cost_curve_to_segment(curve, seg_id));
|
||||
seg_id += 1;
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
// ─── Errors ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Errors specific to the RVF bridge operations.
///
/// Variants either wrap a lower-level failure (RVF format, JSON codec) or
/// report a structural problem found while decoding a segment.
#[derive(Debug)]
pub enum RvfBridgeError {
    /// Underlying RVF format error.
    Rvf(rvf_types::RvfError),
    /// JSON serialization/deserialization error.
    Json(serde_json::Error),
    /// Segment type mismatch.
    ///
    /// Raised when a segment's type discriminant does not match the decoder
    /// that was invoked (e.g. decoding a policy-kernel segment as a prior).
    WrongSegmentType {
        /// Expected segment type discriminant.
        expected: u8,
        /// Actual segment type discriminant.
        got: u8,
    },
    /// TLV payload truncated.
    TruncatedTlv,
}
|
||||
|
||||
impl std::fmt::Display for RvfBridgeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Rvf(e) => write!(f, "RVF error: {e}"),
|
||||
Self::Json(e) => write!(f, "JSON error: {e}"),
|
||||
Self::WrongSegmentType { expected, got } => {
|
||||
write!(
|
||||
f,
|
||||
"wrong segment type: expected 0x{expected:02X}, got 0x{got:02X}"
|
||||
)
|
||||
}
|
||||
Self::TruncatedTlv => write!(f, "TLV payload truncated"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for RvfBridgeError {
    // Expose the underlying cause for error-chain reporting.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Json(e) => Some(e),
            // NOTE(review): the `Rvf(e)` variant is not chained here, so the
            // inner rvf_types::RvfError is dropped from `source()`. If
            // RvfError implements std::error::Error, it should probably be
            // returned too — TODO confirm against rvf-types.
            _ => None,
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cost_curve::{ConvergenceThresholds, CostCurvePoint};

    // A TransferPrior survives encoding into a segment and decoding back.
    #[test]
    fn transfer_prior_round_trip() {
        let mut prior = TransferPrior::uniform(DomainId("test".into()));
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algo".into(),
        };
        prior.update_posterior(bucket, crate::transfer::ArmId("greedy".into()), 0.85);

        let segment = transfer_prior_to_segment(&prior, 1);
        let decoded = transfer_prior_from_segment(&segment).unwrap();

        assert_eq!(decoded.source_domain, prior.source_domain);
        assert_eq!(decoded.training_cycles, prior.training_cycles);
    }

    // A freshly-created PolicyKernel round-trips with generation 0 intact.
    #[test]
    fn policy_kernel_round_trip() {
        let kernel = PolicyKernel::new("test_kernel".into());
        let segment = policy_kernel_to_segment(&kernel, 2);
        let decoded = policy_kernel_from_segment(&segment).unwrap();

        assert_eq!(decoded.id, "test_kernel");
        assert_eq!(decoded.generation, 0);
    }

    // A CostCurve with one recorded point round-trips through a segment.
    #[test]
    fn cost_curve_round_trip() {
        let mut curve = CostCurve::new(DomainId("test".into()), ConvergenceThresholds::default());
        curve.record(CostCurvePoint {
            cycle: 0,
            accuracy: 0.3,
            cost_per_solve: 0.1,
            robustness: 0.3,
            policy_violations: 0,
            timestamp: 0.0,
        });

        let segment = cost_curve_to_segment(&curve, 3);
        let decoded = cost_curve_from_segment(&segment).unwrap();

        assert_eq!(decoded.domain_id, DomainId("test".into()));
        assert_eq!(decoded.points.len(), 1);
    }

    // Decoding a kernel segment with the prior decoder must fail with
    // WrongSegmentType, not silently succeed or panic.
    #[test]
    fn wrong_segment_type_detected() {
        let kernel = PolicyKernel::new("k".into());
        let segment = policy_kernel_to_segment(&kernel, 1);
        let result = transfer_prior_from_segment(&segment);
        assert!(matches!(
            result,
            Err(RvfBridgeError::WrongSegmentType { .. })
        ));
    }

    // Witness hashing is a pure function of the prior (and never all-zero).
    #[test]
    fn witness_hash_is_deterministic() {
        let prior = TransferPrior::uniform(DomainId("test".into()));
        let h1 = compute_transfer_witness_hash(&prior);
        let h2 = compute_transfer_witness_hash(&prior);
        assert_eq!(h1, h2);
        assert_ne!(h1, [0u8; 32]);
    }

    // Two witness entries (transfer + convergence) are produced and form a
    // verifiable hash chain.
    #[test]
    fn witness_entries_chain() {
        let prior = TransferPrior::uniform(DomainId("d1".into()));
        let entries = build_transfer_witness_entries(
            &prior,
            &DomainId("d1".into()),
            &DomainId("d2".into()),
            2.5,
            1_000_000_000,
        );
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].witness_type, WITNESS_TRANSFER);
        assert_eq!(entries[1].witness_type, WITNESS_CONVERGENCE);

        // Verify the chain is valid after linking
        let chain_bytes = rvf_crypto::create_witness_chain(&entries);
        let verified = rvf_crypto::verify_witness_chain(&chain_bytes).unwrap();
        assert_eq!(verified.len(), 2);
    }

    // TLV entries preserve tag and payload through encode/decode.
    #[test]
    fn tlv_round_trip() {
        let entries = vec![
            AgiTlvEntry {
                tag: rvf_types::AGI_TAG_TRANSFER_PRIOR,
                value: b"hello".to_vec(),
            },
            AgiTlvEntry {
                tag: rvf_types::AGI_TAG_POLICY_KERNEL,
                value: b"world".to_vec(),
            },
        ];

        let encoded = encode_tlv_entries(&entries);
        let decoded = decode_tlv_entries(&encoded).unwrap();

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].tag, rvf_types::AGI_TAG_TRANSFER_PRIOR);
        assert_eq!(decoded[0].value, b"hello");
        assert_eq!(decoded[1].tag, rvf_types::AGI_TAG_POLICY_KERNEL);
        assert_eq!(decoded[1].value, b"world");
    }

    // Packaging one prior + one kernel (no curves) yields exactly two TLV
    // entries that survive a round trip.
    #[test]
    fn agi_container_packaging() {
        let prior = TransferPrior::uniform(DomainId("test".into()));
        let kernel = PolicyKernel::new("k0".into());
        let scoreboard = crate::cost_curve::AccelerationScoreboard::new();

        let entries = package_for_agi_container(&[prior], &[kernel], &scoreboard);
        assert_eq!(entries.len(), 2); // 1 prior + 1 kernel, 0 curves

        let encoded = encode_tlv_entries(&entries);
        let decoded = decode_tlv_entries(&encoded).unwrap();
        assert_eq!(decoded.len(), 2);
    }

    // End-to-end: train one engine, extract its priors, import them into a
    // fresh engine under a different domain, and confirm they transferred.
    #[test]
    fn solver_prior_exchange_round_trip() {
        let arms = vec!["greedy".into(), "exploratory".into()];
        let mut engine = MetaThompsonEngine::new(arms);
        let domain = DomainId("test".into());
        engine.init_domain_uniform(domain.clone());

        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };
        for _ in 0..20 {
            engine.record_outcome(
                &domain,
                bucket.clone(),
                crate::transfer::ArmId("greedy".into()),
                0.9,
                1.0,
            );
        }

        let exchanges = extract_solver_priors(&engine, &domain);
        assert!(!exchanges.is_empty());

        // Import into a fresh engine
        let new_arms = vec!["greedy".into(), "exploratory".into()];
        let mut new_engine = MetaThompsonEngine::new(new_arms);
        let target = DomainId("target".into());
        new_engine.init_domain_uniform(target.clone());
        import_solver_priors(&mut new_engine, &target, &exchanges);

        // Should have transferred priors
        let extracted = new_engine.extract_prior(&target);
        assert!(extracted.is_some());
    }

    // Assembling one of each segment kind yields a 64-byte-aligned stream
    // that starts with a valid segment header magic.
    #[test]
    fn multi_segment_assembly() {
        let prior = TransferPrior::uniform(DomainId("d1".into()));
        let kernel = PolicyKernel::new("k0".into());
        let mut curve = CostCurve::new(DomainId("d1".into()), ConvergenceThresholds::default());
        curve.record(CostCurvePoint {
            cycle: 0,
            accuracy: 0.5,
            cost_per_solve: 0.05,
            robustness: 0.5,
            policy_violations: 0,
            timestamp: 0.0,
        });

        let assembled = assemble_domain_expansion_segments(&[prior], &[kernel], &[curve], 100);

        // Should contain 3 segments, each 64-byte aligned
        assert!(assembled.len() >= 3 * 64);
        assert_eq!(assembled.len() % 64, 0);

        // Verify first segment header magic
        let magic = u32::from_le_bytes([assembled[0], assembled[1], assembled[2], assembled[3]]);
        assert_eq!(magic, rvf_types::SEGMENT_MAGIC);
    }
}
|
||||
727
vendor/ruvector/crates/ruvector-domain-expansion/src/tool_orchestration.rs
vendored
Normal file
727
vendor/ruvector/crates/ruvector-domain-expansion/src/tool_orchestration.rs
vendored
Normal file
@@ -0,0 +1,727 @@
|
||||
//! Tool Orchestration Problems Domain
|
||||
//!
|
||||
//! Generates tasks requiring coordinating multiple tools/agents to achieve goals.
|
||||
//! Task types include:
|
||||
//!
|
||||
//! - **PipelineConstruction**: Build a data processing pipeline from available tools
|
||||
//! - **ErrorRecovery**: Handle failures in multi-step tool chains
|
||||
//! - **ParallelCoordination**: Execute independent tool calls concurrently
|
||||
//! - **ResourceNegotiation**: Manage shared resources across tool invocations
|
||||
//! - **AdaptiveRouting**: Select tools dynamically based on intermediate results
|
||||
//!
|
||||
//! Cross-domain transfer is strongest here: planning decomposes goals,
|
||||
//! Rust synthesis provides execution patterns, and orchestration combines them.
|
||||
|
||||
use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Dimensionality of the feature vector produced by `extract_features`.
const EMBEDDING_DIM: usize = 64;
|
||||
|
||||
/// Categories of tool orchestration tasks.
///
/// NOTE(review): `generate_tasks` currently only emits
/// `PipelineConstruction`, `ErrorRecovery`, and `ParallelCoordination`;
/// `ResourceNegotiation` and `AdaptiveRouting` have no generator yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrchestrationCategory {
    /// Build a pipeline: chain tools to transform input to desired output.
    PipelineConstruction,
    /// Handle failure: detect errors and apply fallback strategies.
    ErrorRecovery,
    /// Coordinate parallel: dispatch independent calls and merge results.
    ParallelCoordination,
    /// Negotiate resources: manage rate limits, quotas, shared state.
    ResourceNegotiation,
    /// Adaptive routing: choose tool based on intermediate result properties.
    AdaptiveRouting,
}
|
||||
|
||||
/// A tool available in the orchestration environment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSpec {
    /// Unique tool name, referenced by `ToolCall::tool_name`.
    pub name: String,
    /// Human-readable summary of what the tool does.
    pub description: String,
    /// Input type signature (e.g., "text", "json", "binary").
    pub input_type: String,
    /// Output type signature.
    pub output_type: String,
    /// Average latency in milliseconds.
    pub latency_ms: u32,
    /// Failure rate [0.0, 1.0].
    pub failure_rate: f32,
    /// Cost per invocation.
    pub cost: f32,
    /// Rate limit (max calls per minute), 0 = unlimited.
    pub rate_limit: u32,
}
|
||||
|
||||
/// An orchestration task specification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationTaskSpec {
    /// Which orchestration problem family this task belongs to.
    pub category: OrchestrationCategory,
    /// Natural-language statement of the task.
    pub description: String,
    /// Available tools in the environment.
    pub available_tools: Vec<ToolSpec>,
    /// Input to the pipeline.
    pub input: serde_json::Value,
    /// Expected output type/shape.
    pub expected_output_type: String,
    /// Maximum total latency budget (ms).
    pub latency_budget_ms: u32,
    /// Maximum total cost budget.
    pub cost_budget: f32,
    /// Required reliability (min success rate).
    pub min_reliability: f32,
    /// Error scenarios that must be handled.
    pub error_scenarios: Vec<String>,
}
|
||||
|
||||
/// A tool call in an orchestration solution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name of the tool to invoke; should match a `ToolSpec::name`.
    pub tool_name: String,
    /// Input to this tool call (ref to previous output or literal).
    pub input_ref: String,
    /// Whether this can run in parallel with other calls.
    pub parallel_group: Option<u32>,
    /// Fallback tool if this one fails.
    pub fallback: Option<String>,
    /// Retry count on failure.
    pub retries: u32,
}
|
||||
|
||||
/// A parsed orchestration plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationPlan {
    /// Ordered tool invocations making up the plan.
    pub calls: Vec<ToolCall>,
    /// Error handling strategy description.
    pub error_strategy: String,
}
|
||||
|
||||
/// Tool orchestration domain.
pub struct ToolOrchestrationDomain {
    // Stable identifier ("tool_orchestration") returned by `Domain::id`.
    id: DomainId,
}
|
||||
|
||||
impl ToolOrchestrationDomain {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: DomainId("tool_orchestration".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Static catalog of the ten simulated tools available to every task.
    ///
    /// Latency, failure-rate, cost, and rate-limit values are fixed fixture
    /// data; `score_orchestration` uses them to estimate plan latency/cost.
    fn base_tools() -> Vec<ToolSpec> {
        vec![
            // Core RAG chain: extract -> embed -> search -> generate.
            ToolSpec {
                name: "text_extract".into(),
                description: "Extract text from documents".into(),
                input_type: "binary".into(),
                output_type: "text".into(),
                latency_ms: 50,
                failure_rate: 0.02,
                cost: 0.001,
                rate_limit: 100,
            },
            ToolSpec {
                name: "text_embed".into(),
                description: "Generate embeddings from text".into(),
                input_type: "text".into(),
                output_type: "vector".into(),
                latency_ms: 30,
                failure_rate: 0.01,
                cost: 0.002,
                rate_limit: 200,
            },
            ToolSpec {
                name: "vector_search".into(),
                description: "Search vector index for similar items".into(),
                input_type: "vector".into(),
                output_type: "json".into(),
                latency_ms: 10,
                failure_rate: 0.005,
                cost: 0.0005,
                rate_limit: 500,
            },
            // Slowest, most expensive, and most failure-prone tool.
            ToolSpec {
                name: "llm_generate".into(),
                description: "Generate text using language model".into(),
                input_type: "text".into(),
                output_type: "text".into(),
                latency_ms: 2000,
                failure_rate: 0.05,
                cost: 0.01,
                rate_limit: 30,
            },
            ToolSpec {
                name: "json_transform".into(),
                description: "Apply JQ-like transformations to JSON".into(),
                input_type: "json".into(),
                output_type: "json".into(),
                latency_ms: 5,
                failure_rate: 0.001,
                cost: 0.0001,
                rate_limit: 0,
            },
            ToolSpec {
                name: "code_execute".into(),
                description: "Execute code in sandboxed environment".into(),
                input_type: "text".into(),
                output_type: "json".into(),
                latency_ms: 500,
                failure_rate: 0.1,
                cost: 0.005,
                rate_limit: 20,
            },
            ToolSpec {
                name: "http_fetch".into(),
                description: "Fetch data from external HTTP endpoint".into(),
                input_type: "text".into(),
                output_type: "json".into(),
                latency_ms: 300,
                failure_rate: 0.15,
                cost: 0.0,
                rate_limit: 60,
            },
            // Free, reliable utility tools; using them earns elegance points.
            ToolSpec {
                name: "cache_lookup".into(),
                description: "Check local cache for previously computed results".into(),
                input_type: "text".into(),
                output_type: "json".into(),
                latency_ms: 1,
                failure_rate: 0.0,
                cost: 0.0,
                rate_limit: 0,
            },
            ToolSpec {
                name: "validator".into(),
                description: "Validate output against schema".into(),
                input_type: "json".into(),
                output_type: "json".into(),
                latency_ms: 2,
                failure_rate: 0.0,
                cost: 0.0,
                rate_limit: 0,
            },
            ToolSpec {
                name: "aggregator".into(),
                description: "Merge multiple results into one".into(),
                input_type: "json".into(),
                output_type: "json".into(),
                latency_ms: 5,
                failure_rate: 0.0,
                cost: 0.0001,
                rate_limit: 0,
            },
        ]
    }
|
||||
|
||||
fn gen_pipeline(&self, difficulty: f32) -> OrchestrationTaskSpec {
|
||||
let tools = Self::base_tools();
|
||||
let num_tools = if difficulty < 0.3 {
|
||||
3
|
||||
} else if difficulty < 0.7 {
|
||||
6
|
||||
} else {
|
||||
10
|
||||
};
|
||||
|
||||
OrchestrationTaskSpec {
|
||||
category: OrchestrationCategory::PipelineConstruction,
|
||||
description: format!(
|
||||
"Build a RAG pipeline using {} tools: extract, embed, search, generate.",
|
||||
num_tools
|
||||
),
|
||||
available_tools: tools[..num_tools.min(tools.len())].to_vec(),
|
||||
input: serde_json::json!({"type": "binary", "format": "pdf"}),
|
||||
expected_output_type: "text".into(),
|
||||
latency_budget_ms: if difficulty < 0.5 { 5000 } else { 2000 },
|
||||
cost_budget: if difficulty < 0.5 { 0.1 } else { 0.02 },
|
||||
min_reliability: if difficulty < 0.5 { 0.9 } else { 0.99 },
|
||||
error_scenarios: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Generate an ErrorRecovery task.
    ///
    /// The number of error scenarios scales with difficulty: 1 below 0.3,
    /// 3 below 0.7, 5 otherwise. All base tools are always available.
    fn gen_error_recovery(&self, difficulty: f32) -> OrchestrationTaskSpec {
        let tools = Self::base_tools();
        let error_scenarios = if difficulty < 0.3 {
            vec!["timeout on llm_generate".into()]
        } else if difficulty < 0.7 {
            vec![
                "timeout on llm_generate".into(),
                "http_fetch returns 429".into(),
                "code_execute sandbox OOM".into(),
            ]
        } else {
            vec![
                "timeout on llm_generate".into(),
                "http_fetch returns 429".into(),
                "code_execute sandbox OOM".into(),
                "vector_search index corruption".into(),
                "cascading failure: embed + search both down".into(),
            ]
        };

        OrchestrationTaskSpec {
            category: OrchestrationCategory::ErrorRecovery,
            description: format!(
                "Handle {} error scenarios in a multi-tool pipeline with graceful degradation.",
                error_scenarios.len()
            ),
            available_tools: tools,
            input: serde_json::json!({"type": "text", "content": "query"}),
            expected_output_type: "json".into(),
            // Budgets are fixed for this category regardless of difficulty.
            latency_budget_ms: 10000,
            cost_budget: 0.1,
            min_reliability: 0.95,
            error_scenarios,
        }
    }
|
||||
|
||||
fn gen_parallel_coordination(&self, difficulty: f32) -> OrchestrationTaskSpec {
|
||||
let tools = Self::base_tools();
|
||||
let parallelism = if difficulty < 0.3 {
|
||||
2
|
||||
} else if difficulty < 0.7 {
|
||||
4
|
||||
} else {
|
||||
8
|
||||
};
|
||||
|
||||
OrchestrationTaskSpec {
|
||||
category: OrchestrationCategory::ParallelCoordination,
|
||||
description: format!(
|
||||
"Execute {} independent tool chains in parallel, merge results within latency budget.",
|
||||
parallelism
|
||||
),
|
||||
available_tools: tools,
|
||||
input: serde_json::json!({"queries": (0..parallelism).map(|i| format!("query_{}", i)).collect::<Vec<_>>()}),
|
||||
expected_output_type: "json".into(),
|
||||
latency_budget_ms: if difficulty < 0.5 { 3000 } else { 1000 },
|
||||
cost_budget: 0.05 * parallelism as f32,
|
||||
min_reliability: 0.95,
|
||||
error_scenarios: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_features(&self, solution: &Solution) -> Vec<f32> {
|
||||
let content = &solution.content;
|
||||
let mut features = vec![0.0f32; EMBEDDING_DIM];
|
||||
|
||||
let plan: OrchestrationPlan = serde_json::from_str(&solution.data.to_string())
|
||||
.or_else(|_| serde_json::from_str(content))
|
||||
.unwrap_or(OrchestrationPlan {
|
||||
calls: Vec::new(),
|
||||
error_strategy: String::new(),
|
||||
});
|
||||
|
||||
// Feature 0-7: Plan structure
|
||||
features[0] = plan.calls.len() as f32 / 20.0;
|
||||
let unique_tools: std::collections::HashSet<&str> =
|
||||
plan.calls.iter().map(|c| c.tool_name.as_str()).collect();
|
||||
features[1] = unique_tools.len() as f32 / 10.0;
|
||||
// Parallelism ratio
|
||||
let parallel_calls = plan
|
||||
.calls
|
||||
.iter()
|
||||
.filter(|c| c.parallel_group.is_some())
|
||||
.count();
|
||||
features[2] = parallel_calls as f32 / plan.calls.len().max(1) as f32;
|
||||
// Fallback coverage
|
||||
let fallback_calls = plan.calls.iter().filter(|c| c.fallback.is_some()).count();
|
||||
features[3] = fallback_calls as f32 / plan.calls.len().max(1) as f32;
|
||||
// Average retries
|
||||
let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum();
|
||||
features[4] = total_retries as f32 / plan.calls.len().max(1) as f32 / 5.0;
|
||||
|
||||
// Feature 8-15: Tool type usage
|
||||
let tool_names = [
|
||||
"extract",
|
||||
"embed",
|
||||
"search",
|
||||
"generate",
|
||||
"transform",
|
||||
"execute",
|
||||
"fetch",
|
||||
"cache",
|
||||
];
|
||||
for (i, name) in tool_names.iter().enumerate() {
|
||||
features[8 + i] = plan
|
||||
.calls
|
||||
.iter()
|
||||
.filter(|c| c.tool_name.contains(name))
|
||||
.count() as f32
|
||||
/ plan.calls.len().max(1) as f32;
|
||||
}
|
||||
|
||||
// Feature 16-23: Text pattern features
|
||||
features[16] = content.matches("pipeline").count() as f32 / 3.0;
|
||||
features[17] = content.matches("parallel").count() as f32 / 5.0;
|
||||
features[18] = content.matches("fallback").count() as f32 / 5.0;
|
||||
features[19] = content.matches("retry").count() as f32 / 5.0;
|
||||
features[20] = content.matches("cache").count() as f32 / 5.0;
|
||||
features[21] = content.matches("timeout").count() as f32 / 3.0;
|
||||
features[22] = content.matches("merge").count() as f32 / 3.0;
|
||||
features[23] = content.matches("validate").count() as f32 / 3.0;
|
||||
|
||||
// Feature 32-39: Error handling patterns
|
||||
features[32] = content.matches("error").count() as f32 / 5.0;
|
||||
features[33] = content.matches("recover").count() as f32 / 3.0;
|
||||
features[34] = content.matches("degrade").count() as f32 / 3.0;
|
||||
features[35] = content.matches("circuit_break").count() as f32 / 2.0;
|
||||
features[36] = content.matches("rate_limit").count() as f32 / 3.0;
|
||||
features[37] = content.matches("backoff").count() as f32 / 3.0;
|
||||
features[38] = content.matches("health_check").count() as f32 / 2.0;
|
||||
features[39] = content.matches("monitor").count() as f32 / 3.0;
|
||||
|
||||
// Feature 48-55: Coordination patterns
|
||||
features[48] = content.matches("scatter").count() as f32 / 2.0;
|
||||
features[49] = content.matches("gather").count() as f32 / 2.0;
|
||||
features[50] = content.matches("fan_out").count() as f32 / 2.0;
|
||||
features[51] = content.matches("aggregate").count() as f32 / 3.0;
|
||||
features[52] = content.matches("route").count() as f32 / 3.0;
|
||||
features[53] = content.matches("dispatch").count() as f32 / 3.0;
|
||||
features[54] = content.matches("await").count() as f32 / 5.0;
|
||||
features[55] = content.matches("join").count() as f32 / 3.0;
|
||||
|
||||
// Normalize
|
||||
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-10 {
|
||||
for f in &mut features {
|
||||
*f /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
    /// Score an orchestration plan against the task spec.
    ///
    /// Final score is a weighted blend: 60% correctness (type-chain
    /// validity, expected-output coverage, error-scenario handling),
    /// 25% efficiency (estimated latency and cost vs. budgets), and
    /// 15% elegance (parallelism, caching, validation, retry discipline).
    fn score_orchestration(&self, spec: &OrchestrationTaskSpec, solution: &Solution) -> Evaluation {
        let content = &solution.content;
        let mut correctness = 0.0f32;
        let mut efficiency = 0.5f32;
        let mut elegance = 0.5f32;
        let mut notes = Vec::new();

        // Try the structured payload first, then the text content.
        // NOTE(review): `from_str(&solution.data.to_string())` stringifies a
        // Value only to re-parse it; `serde_json::from_value` would avoid
        // the round trip.
        let plan: Option<OrchestrationPlan> = serde_json::from_str(&solution.data.to_string())
            .ok()
            .or_else(|| serde_json::from_str(content).ok());

        let plan = match plan {
            Some(p) => p,
            None => {
                // Unparseable plan: grant partial credit if the text at
                // least mentions one of the available tools.
                let has_tools = spec
                    .available_tools
                    .iter()
                    .any(|t| content.contains(&t.name));
                if has_tools {
                    correctness = 0.2;
                }
                return Evaluation {
                    score: correctness * 0.6,
                    correctness,
                    efficiency: 0.0,
                    elegance: 0.0,
                    constraint_results: Vec::new(),
                    notes: vec!["Could not parse orchestration plan".into()],
                };
            }
        };

        if plan.calls.is_empty() {
            return Evaluation::zero(vec!["Empty orchestration plan".into()]);
        }

        // Correctness: type chain validity. Adjacent sequential calls must
        // have compatible output/input type signatures; parallel-group calls
        // are exempt from this adjacency check.
        let mut type_errors = 0;
        for window in plan.calls.windows(2) {
            let output_tool = spec
                .available_tools
                .iter()
                .find(|t| t.name == window[0].tool_name);
            let input_tool = spec
                .available_tools
                .iter()
                .find(|t| t.name == window[1].tool_name);

            if let (Some(out_t), Some(in_t)) = (output_tool, input_tool) {
                if window[1].parallel_group.is_none() && out_t.output_type != in_t.input_type {
                    type_errors += 1;
                    notes.push(format!(
                        "Type mismatch: {} outputs {} but {} expects {}",
                        out_t.name, out_t.output_type, in_t.name, in_t.input_type
                    ));
                }
            }
        }
        let chain_len = (plan.calls.len() - 1).max(1);
        correctness = 1.0 - (type_errors as f32 / chain_len as f32);

        // Tool coverage: do we use tools that produce the expected output?
        let produces_output = plan.calls.iter().any(|c| {
            spec.available_tools
                .iter()
                .any(|t| t.name == c.tool_name && t.output_type == spec.expected_output_type)
        });
        if !produces_output {
            correctness *= 0.5;
            notes.push("No tool produces the expected output type".into());
        }

        // Error handling coverage.
        // NOTE(review): the first arm of this predicate (any call with a
        // fallback or retries) is independent of `scenario`, so one
        // fallback/retry anywhere marks *every* scenario handled — confirm
        // this coarse credit is intentional. Also, the 10-byte slice of
        // `scenario` assumes ASCII; a multi-byte char at that boundary
        // would panic.
        if !spec.error_scenarios.is_empty() {
            let handled = spec
                .error_scenarios
                .iter()
                .filter(|scenario| {
                    plan.calls
                        .iter()
                        .any(|c| c.fallback.is_some() || c.retries > 0)
                        || plan
                            .error_strategy
                            .contains(&scenario.as_str()[..scenario.len().min(10)])
                })
                .count() as f32
                / spec.error_scenarios.len() as f32;
            correctness = correctness * 0.7 + handled * 0.3;
        }

        // Efficiency: estimated latency. Sequential calls add up; each
        // parallel group contributes only its slowest member. Unknown tools
        // are charged a 100ms default.
        let est_latency: u32 = {
            let mut groups: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
            let mut sequential_latency = 0u32;
            for call in &plan.calls {
                let tool_latency = spec
                    .available_tools
                    .iter()
                    .find(|t| t.name == call.tool_name)
                    .map(|t| t.latency_ms)
                    .unwrap_or(100);

                if let Some(group) = call.parallel_group {
                    let entry = groups.entry(group).or_insert(0);
                    *entry = (*entry).max(tool_latency);
                } else {
                    sequential_latency += tool_latency;
                }
            }
            sequential_latency + groups.values().sum::<u32>()
        };

        if est_latency <= spec.latency_budget_ms {
            // Within budget: scale from 1.0 (free) down to 0.5 (at budget).
            efficiency = 1.0 - (est_latency as f32 / spec.latency_budget_ms as f32 * 0.5);
        } else {
            // Over budget: efficiency shrinks with the overshoot ratio.
            efficiency = spec.latency_budget_ms as f32 / est_latency as f32 * 0.5;
            notes.push(format!(
                "Estimated latency {}ms exceeds budget {}ms",
                est_latency, spec.latency_budget_ms
            ));
        }

        // Estimated cost: per-call cost inflated by retry count; calls to
        // unknown tools are skipped (filter_map drops them).
        let est_cost: f32 = plan
            .calls
            .iter()
            .filter_map(|c| {
                spec.available_tools
                    .iter()
                    .find(|t| t.name == c.tool_name)
                    .map(|t| t.cost * (1.0 + c.retries as f32))
            })
            .sum();

        if est_cost > spec.cost_budget {
            efficiency *= 0.7;
            notes.push(format!(
                "Cost {:.4} exceeds budget {:.4}",
                est_cost, spec.cost_budget
            ));
        }

        // Elegance: parallelism, caching, minimal redundancy.
        let parallelism_used = plan.calls.iter().any(|c| c.parallel_group.is_some());
        if parallelism_used {
            elegance += 0.15;
        }

        let cache_used = plan.calls.iter().any(|c| c.tool_name.contains("cache"));
        if cache_used {
            elegance += 0.1;
        }

        let validation_used = plan.calls.iter().any(|c| c.tool_name.contains("validat"));
        if validation_used {
            elegance += 0.1;
        }

        // Penalize excessive retries (more than 2 per call on average).
        let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum();
        if total_retries > plan.calls.len() as u32 * 2 {
            elegance -= 0.2;
            notes.push("Excessive retry configuration".into());
        }

        elegance = elegance.clamp(0.0, 1.0);

        let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance;
        Evaluation {
            score: score.clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes,
        }
    }
|
||||
}
|
||||
|
||||
impl Default for ToolOrchestrationDomain {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Domain for ToolOrchestrationDomain {
|
||||
    /// Stable domain identifier ("tool_orchestration").
    fn id(&self) -> &DomainId {
        &self.id
    }
|
||||
|
||||
    /// Human-readable domain name.
    fn name(&self) -> &str {
        "Tool Orchestration"
    }
|
||||
|
||||
fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let difficulty = difficulty.clamp(0.0, 1.0);
|
||||
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
let roll: f32 = rng.gen();
|
||||
let spec = if roll < 0.4 {
|
||||
self.gen_pipeline(difficulty)
|
||||
} else if roll < 0.7 {
|
||||
self.gen_error_recovery(difficulty)
|
||||
} else {
|
||||
self.gen_parallel_coordination(difficulty)
|
||||
};
|
||||
|
||||
Task {
|
||||
id: format!("orch_{}_d{:.0}", i, difficulty * 100.0),
|
||||
domain_id: self.id.clone(),
|
||||
difficulty,
|
||||
spec: serde_json::to_value(&spec).unwrap_or_default(),
|
||||
constraints: Vec::new(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
|
||||
let spec: OrchestrationTaskSpec = match serde_json::from_value(task.spec.clone()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
|
||||
};
|
||||
self.score_orchestration(&spec, solution)
|
||||
}
|
||||
|
||||
fn embed(&self, solution: &Solution) -> DomainEmbedding {
|
||||
let features = self.extract_features(solution);
|
||||
DomainEmbedding::new(features, self.id.clone())
|
||||
}
|
||||
|
||||
    /// Embedding dimensionality (fixed at 64).
    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }
|
||||
|
||||
    /// Baseline solution: a naive sequential pipeline through every
    /// available tool in declaration order, with 2 retries on tools whose
    /// failure rate exceeds 5%.
    ///
    /// NOTE(review): chaining *all* tools sequentially can produce
    /// input/output type mismatches that `score_orchestration` penalizes —
    /// presumably intentional as a weak baseline; confirm.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: OrchestrationTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;

        // Build a sequential pipeline through available tools
        let calls: Vec<ToolCall> = spec
            .available_tools
            .iter()
            .map(|t| ToolCall {
                tool_name: t.name.clone(),
                input_ref: "previous".into(),
                parallel_group: None,
                fallback: None,
                retries: if t.failure_rate > 0.05 { 2 } else { 0 },
            })
            .collect();

        let plan = OrchestrationPlan {
            calls,
            error_strategy: "retry with exponential backoff".into(),
        };

        // Emit the plan both as pretty text content and structured data.
        let content = serde_json::to_string_pretty(&plan).ok()?;
        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::to_value(&plan).ok()?,
        })
    }
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Generated tasks come back with the requested count and this domain's id.
    #[test]
    fn test_generate_orchestration_tasks() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(5, 0.5);
        assert_eq!(tasks.len(), 5);
        for task in &tasks {
            assert_eq!(task.domain_id, domain.id);
        }
    }

    // Every generated task has a constructible baseline solution.
    #[test]
    fn test_reference_solution() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            let ref_sol = domain.reference_solution(task);
            assert!(ref_sol.is_some());
        }
    }

    // Scoring the baseline solution stays within the [0, 1] score range.
    #[test]
    fn test_evaluate_reference() {
        let domain = ToolOrchestrationDomain::new();
        let tasks = domain.generate_tasks(3, 0.3);
        for task in &tasks {
            if let Some(solution) = domain.reference_solution(task) {
                let eval = domain.evaluate(task, &solution);
                assert!(eval.score >= 0.0 && eval.score <= 1.0);
            }
        }
    }

    // Embedding an ad-hoc solution yields the advertised dimensionality.
    #[test]
    fn test_embed_orchestration() {
        let domain = ToolOrchestrationDomain::new();
        let solution = Solution {
            task_id: "test".into(),
            content: "pipeline: extract -> embed -> search with fallback and retry".into(),
            data: serde_json::json!({
                "calls": [
                    {"tool_name": "text_extract", "input_ref": "input", "retries": 1}
                ],
                "error_strategy": "retry"
            }),
        };
        let embedding = domain.embed(&solution);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
    }

    // High difficulty should eventually yield ErrorRecovery tasks carrying
    // non-empty scenario lists; generation is random, so sample many tasks.
    #[test]
    fn test_difficulty_affects_error_scenarios() {
        let domain = ToolOrchestrationDomain::new();
        // Generate many tasks at high difficulty to get error recovery tasks
        let hard = domain.generate_tasks(20, 0.9);
        let has_error_tasks = hard.iter().any(|t| {
            let spec: OrchestrationTaskSpec = serde_json::from_value(t.spec.clone()).unwrap();
            !spec.error_scenarios.is_empty()
        });
        assert!(
            has_error_tasks,
            "High difficulty should produce error scenarios"
        );
    }
}
|
||||
583
vendor/ruvector/crates/ruvector-domain-expansion/src/transfer.rs
vendored
Normal file
583
vendor/ruvector/crates/ruvector-domain-expansion/src/transfer.rs
vendored
Normal file
@@ -0,0 +1,583 @@
|
||||
//! Cross-Domain Transfer Engine with Meta Thompson Sampling
|
||||
//!
|
||||
//! Transfer happens through priors, not raw memories.
|
||||
//! Ship compact priors and verified kernels between domains.
|
||||
//!
|
||||
//! ## Two-Layer Learning Architecture
|
||||
//!
|
||||
//! **Policy learning layer**: Chooses strategies, budgets, and tool paths
|
||||
//! using uncertainty-aware selection (Thompson Sampling with Beta priors).
|
||||
//!
|
||||
//! **Operator layer**: Executes deterministic kernels and graders,
|
||||
//! logs witnesses, and commits state through gates.
|
||||
//!
|
||||
//! ## Meta Thompson Sampling
|
||||
//!
|
||||
//! After each cycle, compute posterior summary per bucket and arm.
|
||||
//! Store as TransferPrior. When a new domain starts, initialize its
|
||||
//! buckets with these priors instead of uniform, enabling faster adaptation.
|
||||
//!
|
||||
//! ## Cross-Domain Transfer Protocol
|
||||
//!
|
||||
//! A delta is promotable only if it improves Domain 2 without regressing
|
||||
//! Domain 1, or improves Domain 1 without regressing Domain 2.
|
||||
//! That is generalization.
|
||||
|
||||
use crate::domain::DomainId;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Beta distribution parameters for Thompson Sampling.
/// Represents uncertainty about an arm's reward probability; the
/// distribution concentrates around the empirical mean as evidence
/// accumulates via [`BetaParams::update`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetaParams {
    /// Success count + prior (alpha). Starts at 1.0 under the uniform prior.
    pub alpha: f32,
    /// Failure count + prior (beta). Starts at 1.0 under the uniform prior.
    pub beta: f32,
}
|
||||
|
||||
impl BetaParams {
|
||||
/// Uniform (uninformative) prior: Beta(1, 1).
|
||||
pub fn uniform() -> Self {
|
||||
Self {
|
||||
alpha: 1.0,
|
||||
beta: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from observed successes and failures.
|
||||
pub fn from_observations(successes: f32, failures: f32) -> Self {
|
||||
Self {
|
||||
alpha: successes + 1.0,
|
||||
beta: failures + 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mean of the Beta distribution: E[X] = alpha / (alpha + beta).
|
||||
pub fn mean(&self) -> f32 {
|
||||
self.alpha / (self.alpha + self.beta)
|
||||
}
|
||||
|
||||
/// Variance: measures uncertainty. Lower = more confident.
|
||||
pub fn variance(&self) -> f32 {
|
||||
let total = self.alpha + self.beta;
|
||||
(self.alpha * self.beta) / (total * total * (total + 1.0))
|
||||
}
|
||||
|
||||
/// Sample from the Beta distribution using the Kumaraswamy approximation.
|
||||
/// Fast, no special functions needed, good enough for Thompson Sampling.
|
||||
pub fn sample(&self, rng: &mut impl Rng) -> f32 {
|
||||
// Use inverse CDF of Beta via simple approximation
|
||||
let u: f32 = rng.gen_range(0.001..0.999);
|
||||
// Kumaraswamy approximation: x = (1 - (1 - u^(1/b))^(1/a))
|
||||
// Better approximation using ratio of gammas via the normal approach
|
||||
let x = Self::beta_inv_approx(u, self.alpha, self.beta);
|
||||
x.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Approximate inverse CDF of Beta distribution.
|
||||
fn beta_inv_approx(p: f32, a: f32, b: f32) -> f32 {
|
||||
// Use normal approximation for Beta when a,b are not too small
|
||||
if a > 1.0 && b > 1.0 {
|
||||
let mean = a / (a + b);
|
||||
let var = (a * b) / ((a + b) * (a + b) * (a + b + 1.0));
|
||||
let std = var.sqrt();
|
||||
// Inverse normal approximation (Abramowitz & Stegun)
|
||||
let t = if p < 0.5 {
|
||||
(-2.0 * (p).ln()).sqrt()
|
||||
} else {
|
||||
(-2.0 * (1.0 - p).ln()).sqrt()
|
||||
};
|
||||
let x = if p < 0.5 {
|
||||
mean - std * t
|
||||
} else {
|
||||
mean + std * t
|
||||
};
|
||||
x.clamp(0.001, 0.999)
|
||||
} else {
|
||||
// Fallback: simple power approximation
|
||||
p.powf(1.0 / a) * (1.0 - (1.0 - p).powf(1.0 / b)) + p.powf(1.0 / a) * 0.5
|
||||
}
|
||||
}
|
||||
|
||||
/// Update with an observation (Bayesian posterior update).
|
||||
pub fn update(&mut self, reward: f32) {
|
||||
self.alpha += reward;
|
||||
self.beta += 1.0 - reward;
|
||||
}
|
||||
|
||||
/// Merge two Beta distributions (approximate: sum parameters).
|
||||
pub fn merge(&self, other: &BetaParams) -> BetaParams {
|
||||
BetaParams {
|
||||
alpha: self.alpha + other.alpha - 1.0, // subtract uniform prior
|
||||
beta: self.beta + other.beta - 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A context bucket groups similar problem instances for targeted learning.
/// Used as a `HashMap` key in [`TransferPrior`], hence the `Hash`/`Eq` derives.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContextBucket {
    /// Difficulty tier: "easy", "medium", "hard".
    pub difficulty_tier: String,
    /// Problem category within the domain.
    pub category: String,
}
|
||||
|
||||
/// An arm in the multi-armed bandit: a named strategy choice.
/// Newtype over `String` so arm identifiers cannot be confused with
/// other strings at compile time.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArmId(pub String);
|
||||
|
||||
/// Transfer prior: compact posterior summary from a source domain.
/// This is what gets shipped between domains — not raw trajectories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferPrior {
    /// Source domain that generated this prior.
    pub source_domain: DomainId,
    /// Per-bucket, per-arm Beta parameters (posterior summaries).
    pub bucket_priors: HashMap<ContextBucket, HashMap<ArmId, BetaParams>>,
    /// Cost EMA (exponential moving average) priors per bucket.
    pub cost_ema_priors: HashMap<ContextBucket, f32>,
    /// Number of posterior updates this prior was trained on
    /// (incremented once per `update_posterior` call).
    pub training_cycles: u64,
    /// Witness hash: provenance record of how this prior was derived.
    /// NOTE(review): currently a plain format string, not a cryptographic
    /// hash — confirm intended semantics against the rvf-crypto feature.
    pub witness_hash: String,
}
|
||||
|
||||
impl TransferPrior {
|
||||
/// Create an empty (uniform) prior for a domain.
|
||||
pub fn uniform(source_domain: DomainId) -> Self {
|
||||
Self {
|
||||
source_domain,
|
||||
bucket_priors: HashMap::new(),
|
||||
cost_ema_priors: HashMap::new(),
|
||||
training_cycles: 0,
|
||||
witness_hash: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the prior for a specific bucket and arm, defaulting to uniform.
|
||||
pub fn get_prior(&self, bucket: &ContextBucket, arm: &ArmId) -> BetaParams {
|
||||
self.bucket_priors
|
||||
.get(bucket)
|
||||
.and_then(|arms| arms.get(arm))
|
||||
.cloned()
|
||||
.unwrap_or_else(BetaParams::uniform)
|
||||
}
|
||||
|
||||
/// Update the posterior for a bucket/arm with a new observation.
|
||||
pub fn update_posterior(&mut self, bucket: ContextBucket, arm: ArmId, reward: f32) {
|
||||
let arms = self.bucket_priors.entry(bucket.clone()).or_default();
|
||||
let params = arms.entry(arm).or_insert_with(BetaParams::uniform);
|
||||
params.update(reward);
|
||||
self.training_cycles += 1;
|
||||
}
|
||||
|
||||
/// Update cost EMA for a bucket.
|
||||
pub fn update_cost_ema(&mut self, bucket: ContextBucket, cost: f32, decay: f32) {
|
||||
let entry = self.cost_ema_priors.entry(bucket).or_insert(cost);
|
||||
*entry = decay * (*entry) + (1.0 - decay) * cost;
|
||||
}
|
||||
|
||||
/// Extract a compact summary suitable for shipping to another domain.
|
||||
pub fn extract_summary(&self) -> TransferPrior {
|
||||
// Only ship buckets with sufficient evidence (>10 observations)
|
||||
let filtered: HashMap<ContextBucket, HashMap<ArmId, BetaParams>> = self
|
||||
.bucket_priors
|
||||
.iter()
|
||||
.filter_map(|(bucket, arms)| {
|
||||
let significant_arms: HashMap<ArmId, BetaParams> = arms
|
||||
.iter()
|
||||
.filter(|(_, params)| (params.alpha + params.beta) > 12.0)
|
||||
.map(|(arm, params)| (arm.clone(), params.clone()))
|
||||
.collect();
|
||||
if significant_arms.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((bucket.clone(), significant_arms))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
TransferPrior {
|
||||
source_domain: self.source_domain.clone(),
|
||||
bucket_priors: filtered,
|
||||
cost_ema_priors: self.cost_ema_priors.clone(),
|
||||
training_cycles: self.training_cycles,
|
||||
witness_hash: self.witness_hash.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Meta Thompson Sampling engine that manages priors across domains.
pub struct MetaThompsonEngine {
    /// Active priors per domain.
    domain_priors: HashMap<DomainId, TransferPrior>,
    /// Available arms (strategies) shared across domains.
    arms: Vec<ArmId>,
    /// Difficulty tiers for bucketing. Fixed to "easy"/"medium"/"hard"
    /// by `new`; not read by any method in this impl — presumably consumed
    /// by callers when constructing `ContextBucket`s (verify).
    difficulty_tiers: Vec<String>,
}
|
||||
|
||||
impl MetaThompsonEngine {
|
||||
/// Create a new engine with the given strategy arms.
|
||||
pub fn new(arms: Vec<String>) -> Self {
|
||||
Self {
|
||||
domain_priors: HashMap::new(),
|
||||
arms: arms.into_iter().map(ArmId).collect(),
|
||||
difficulty_tiers: vec!["easy".into(), "medium".into(), "hard".into()],
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a domain with uniform priors.
|
||||
pub fn init_domain_uniform(&mut self, domain_id: DomainId) {
|
||||
self.domain_priors
|
||||
.insert(domain_id.clone(), TransferPrior::uniform(domain_id));
|
||||
}
|
||||
|
||||
/// Initialize a domain using transfer priors from a source domain.
|
||||
/// This is the key mechanism: Meta-TS seeds new domains with learned priors.
|
||||
pub fn init_domain_with_transfer(
|
||||
&mut self,
|
||||
target_domain: DomainId,
|
||||
source_prior: &TransferPrior,
|
||||
) {
|
||||
let mut prior = TransferPrior::uniform(target_domain.clone());
|
||||
|
||||
// Copy bucket priors from source, scaling by confidence
|
||||
for (bucket, arms) in &source_prior.bucket_priors {
|
||||
for (arm, params) in arms {
|
||||
// Dampen the prior: don't fully trust cross-domain evidence.
|
||||
// Use sqrt scaling: reduces confidence while preserving mean.
|
||||
let dampened = BetaParams {
|
||||
alpha: 1.0 + (params.alpha - 1.0).sqrt(),
|
||||
beta: 1.0 + (params.beta - 1.0).sqrt(),
|
||||
};
|
||||
prior
|
||||
.bucket_priors
|
||||
.entry(bucket.clone())
|
||||
.or_default()
|
||||
.insert(arm.clone(), dampened);
|
||||
}
|
||||
}
|
||||
|
||||
// Transfer cost EMAs with dampening
|
||||
for (bucket, &cost) in &source_prior.cost_ema_priors {
|
||||
prior.cost_ema_priors.insert(bucket.clone(), cost * 1.5); // pessimistic transfer
|
||||
}
|
||||
|
||||
prior.witness_hash = format!("transfer_from_{}", source_prior.source_domain);
|
||||
self.domain_priors.insert(target_domain, prior);
|
||||
}
|
||||
|
||||
/// Select an arm for a given domain and context using Thompson Sampling.
|
||||
pub fn select_arm(
|
||||
&self,
|
||||
domain_id: &DomainId,
|
||||
bucket: &ContextBucket,
|
||||
rng: &mut impl Rng,
|
||||
) -> Option<ArmId> {
|
||||
let prior = self.domain_priors.get(domain_id)?;
|
||||
|
||||
let mut best_arm = None;
|
||||
let mut best_sample = f32::NEG_INFINITY;
|
||||
|
||||
for arm in &self.arms {
|
||||
let params = prior.get_prior(bucket, arm);
|
||||
let sample = params.sample(rng);
|
||||
if sample > best_sample {
|
||||
best_sample = sample;
|
||||
best_arm = Some(arm.clone());
|
||||
}
|
||||
}
|
||||
|
||||
best_arm
|
||||
}
|
||||
|
||||
/// Record the outcome of using an arm in a domain.
|
||||
pub fn record_outcome(
|
||||
&mut self,
|
||||
domain_id: &DomainId,
|
||||
bucket: ContextBucket,
|
||||
arm: ArmId,
|
||||
reward: f32,
|
||||
cost: f32,
|
||||
) {
|
||||
if let Some(prior) = self.domain_priors.get_mut(domain_id) {
|
||||
prior.update_posterior(bucket.clone(), arm, reward);
|
||||
prior.update_cost_ema(bucket, cost, 0.9);
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract transfer prior from a domain (for shipping to another domain).
|
||||
pub fn extract_prior(&self, domain_id: &DomainId) -> Option<TransferPrior> {
|
||||
self.domain_priors
|
||||
.get(domain_id)
|
||||
.map(|p| p.extract_summary())
|
||||
}
|
||||
|
||||
/// Get all domain IDs currently tracked.
|
||||
pub fn domain_ids(&self) -> Vec<&DomainId> {
|
||||
self.domain_priors.keys().collect()
|
||||
}
|
||||
|
||||
/// Check if posterior variance is high (triggers speculative dual-path).
|
||||
pub fn is_uncertain(
|
||||
&self,
|
||||
domain_id: &DomainId,
|
||||
bucket: &ContextBucket,
|
||||
threshold: f32,
|
||||
) -> bool {
|
||||
let prior = match self.domain_priors.get(domain_id) {
|
||||
Some(p) => p,
|
||||
None => return true, // No data = maximum uncertainty
|
||||
};
|
||||
|
||||
// Check if top two arms are within delta of each other
|
||||
let mut samples: Vec<(f32, &ArmId)> = self
|
||||
.arms
|
||||
.iter()
|
||||
.map(|arm| {
|
||||
let params = prior.get_prior(bucket, arm);
|
||||
(params.mean(), arm)
|
||||
})
|
||||
.collect();
|
||||
samples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
if samples.len() < 2 {
|
||||
return true;
|
||||
}
|
||||
|
||||
let gap = samples[0].0 - samples[1].0;
|
||||
gap < threshold
|
||||
}
|
||||
}
|
||||
|
||||
/// Speculative dual-path execution for high-uncertainty decisions.
/// When the top two arms are within delta, run both and pick the winner.
/// NOTE(review): nothing in this module constructs or consumes this type —
/// presumably callers build it after racing both arms; verify at call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DualPathResult {
    /// Primary arm and its outcome.
    pub primary: (ArmId, f32),
    /// Secondary arm and its outcome.
    pub secondary: (ArmId, f32),
    /// Which arm won.
    pub winner: ArmId,
    /// The loser becomes a counterexample for that context.
    pub counterexample: ArmId,
}
|
||||
|
||||
/// Cross-domain transfer verification result.
/// Produced by [`TransferVerification::verify`]; records before/after scores
/// for both domains plus the derived promotability verdict.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferVerification {
    /// Source domain.
    pub source: DomainId,
    /// Target domain.
    pub target: DomainId,
    /// Did transfer improve the target domain?
    pub improved_target: bool,
    /// Did transfer regress the source domain (beyond a 0.01 tolerance)?
    pub regressed_source: bool,
    /// Is this delta promotable? (improved target AND not regressed source).
    pub promotable: bool,
    /// Acceleration factor: ratio of baseline to transfer convergence cycles.
    pub acceleration_factor: f32,
    /// Source score before/after.
    pub source_scores: (f32, f32),
    /// Target score before/after.
    pub target_scores: (f32, f32),
}
|
||||
|
||||
impl TransferVerification {
|
||||
/// Verify a transfer delta against the generalization rule:
|
||||
/// promotable iff it improves Domain 2 without regressing Domain 1.
|
||||
pub fn verify(
|
||||
source: DomainId,
|
||||
target: DomainId,
|
||||
source_before: f32,
|
||||
source_after: f32,
|
||||
target_before: f32,
|
||||
target_after: f32,
|
||||
target_baseline_cycles: u64,
|
||||
target_transfer_cycles: u64,
|
||||
) -> Self {
|
||||
let improved_target = target_after > target_before;
|
||||
let regressed_source = source_after < source_before - 0.01; // small tolerance
|
||||
|
||||
let promotable = improved_target && !regressed_source;
|
||||
|
||||
// Acceleration = baseline_cycles / transfer_cycles (higher = better transfer)
|
||||
let acceleration_factor = if target_transfer_cycles > 0 {
|
||||
target_baseline_cycles as f32 / target_transfer_cycles as f32
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
Self {
|
||||
source,
|
||||
target,
|
||||
improved_target,
|
||||
regressed_source,
|
||||
promotable,
|
||||
acceleration_factor,
|
||||
source_scores: (source_before, source_after),
|
||||
target_scores: (target_before, target_after),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Beta(1, 1) is the uniform prior with mean 0.5.
    #[test]
    fn test_beta_params_uniform() {
        let p = BetaParams::uniform();
        assert_eq!(p.alpha, 1.0);
        assert_eq!(p.beta, 1.0);
        assert!((p.mean() - 0.5).abs() < 1e-6);
    }

    // A full success moves one unit of evidence onto alpha only.
    #[test]
    fn test_beta_params_update() {
        let mut p = BetaParams::uniform();
        p.update(1.0); // success
        assert_eq!(p.alpha, 2.0);
        assert_eq!(p.beta, 1.0);
        assert!(p.mean() > 0.5);
    }

    // Samples must stay inside the [0, 1] support regardless of the
    // internal approximation used.
    #[test]
    fn test_beta_params_sample_in_range() {
        let p = BetaParams::from_observations(10.0, 5.0);
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let s = p.sample(&mut rng);
            assert!(s >= 0.0 && s <= 1.0, "Sample {} out of [0,1]", s);
        }
    }

    // 20 observations exceed the >10-observation evidence threshold, so the
    // bucket must survive extract_summary and keep its learned mean.
    #[test]
    fn test_transfer_prior_round_trip() {
        let domain = DomainId("test".into());
        let mut prior = TransferPrior::uniform(domain);

        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "transform".into(),
        };
        let arm = ArmId("strategy_a".into());

        for _ in 0..20 {
            prior.update_posterior(bucket.clone(), arm.clone(), 0.8);
        }

        let summary = prior.extract_summary();
        assert!(!summary.bucket_priors.is_empty());

        let retrieved = summary.get_prior(&bucket, &arm);
        assert!(retrieved.mean() > 0.5);
    }

    // End-to-end: learn in domain1, extract the prior, seed domain2 with it,
    // and check the transferred (dampened) prior still favors the good arm.
    #[test]
    fn test_meta_thompson_engine() {
        let mut engine = MetaThompsonEngine::new(vec![
            "strategy_a".into(),
            "strategy_b".into(),
            "strategy_c".into(),
        ]);

        let domain1 = DomainId("rust_synthesis".into());
        engine.init_domain_uniform(domain1.clone());

        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "algorithm".into(),
        };

        let mut rng = rand::thread_rng();

        // Record some outcomes
        for _ in 0..50 {
            let arm = engine.select_arm(&domain1, &bucket, &mut rng).unwrap();
            let reward = if arm.0 == "strategy_a" { 0.9 } else { 0.3 };
            engine.record_outcome(&domain1, bucket.clone(), arm, reward, 1.0);
        }

        // Extract prior and transfer to domain2
        let prior = engine.extract_prior(&domain1).unwrap();
        let domain2 = DomainId("planning".into());
        engine.init_domain_with_transfer(domain2.clone(), &prior);

        // Domain2 should now have informative priors
        let d2_prior = engine.domain_priors.get(&domain2).unwrap();
        let a_params = d2_prior.get_prior(&bucket, &ArmId("strategy_a".into()));
        assert!(
            a_params.mean() > 0.5,
            "Transferred prior should favor strategy_a"
        );
    }

    // Source dip of 0.01 sits exactly at the tolerance boundary, so the
    // transfer counts as non-regressing and is promotable; acceleration
    // is 100 / 40 = 2.5.
    #[test]
    fn test_transfer_verification() {
        let v = TransferVerification::verify(
            DomainId("d1".into()),
            DomainId("d2".into()),
            0.8, // source before
            0.79, // source after (slight decrease, within tolerance)
            0.3, // target before
            0.7, // target after (big improvement)
            100, // baseline cycles
            40, // transfer cycles
        );

        assert!(v.improved_target);
        assert!(!v.regressed_source); // within tolerance
        assert!(v.promotable);
        assert!((v.acceleration_factor - 2.5).abs() < 1e-4);
    }

    // A real source regression (0.8 -> 0.5) blocks promotion even though
    // the target improved: generalization requires no regression.
    #[test]
    fn test_transfer_not_promotable_on_regression() {
        let v = TransferVerification::verify(
            DomainId("d1".into()),
            DomainId("d2".into()),
            0.8, // source before
            0.5, // source after (regression!)
            0.3, // target before
            0.7, // target after
            100,
            40,
        );

        assert!(v.improved_target);
        assert!(v.regressed_source);
        assert!(!v.promotable);
    }

    // Uncertainty should flip from true (uniform priors: identical means)
    // to false once one arm clearly dominates the posterior means.
    #[test]
    fn test_uncertainty_detection() {
        let mut engine = MetaThompsonEngine::new(vec!["a".into(), "b".into()]);

        let domain = DomainId("test".into());
        engine.init_domain_uniform(domain.clone());

        let bucket = ContextBucket {
            difficulty_tier: "easy".into(),
            category: "test".into(),
        };

        // With uniform priors, should be uncertain
        assert!(engine.is_uncertain(&domain, &bucket, 0.1));

        // After many observations favoring one arm, should be certain
        for _ in 0..100 {
            engine.record_outcome(&domain, bucket.clone(), ArmId("a".into()), 0.95, 1.0);
            engine.record_outcome(&domain, bucket.clone(), ArmId("b".into()), 0.1, 1.0);
        }

        assert!(!engine.is_uncertain(&domain, &bucket, 0.1));
    }
}
|
||||
Reference in New Issue
Block a user