Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
405
vendor/ruvector/crates/sona/src/engine.rs
vendored
Normal file
405
vendor/ruvector/crates/sona/src/engine.rs
vendored
Normal file
@@ -0,0 +1,405 @@
|
||||
//! SONA Engine - Main interface for self-optimizing neural architecture
|
||||
|
||||
use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
|
||||
use crate::trajectory::TrajectoryBuilder;
|
||||
use crate::types::{QueryTrajectory, SonaConfig};
|
||||
|
||||
/// Main SONA engine integrating all components.
///
/// Wraps the [`LoopCoordinator`] that drives the learning loops, keeps the
/// [`SonaConfig`] it was built with, and carries an `enabled` flag that gates
/// trajectory submission and LoRA application (see `end_trajectory`,
/// `apply_micro_lora`).
pub struct SonaEngine {
    /// Loop coordinator owning the learning loops, LoRA adapters, and
    /// reasoning bank accessed by the methods below.
    coordinator: LoopCoordinator,
    /// Configuration the engine was constructed with (exposed via `config()`).
    config: SonaConfig,
    /// Whether engine is enabled; when `false`, trajectory submission and
    /// LoRA application become no-ops.
    enabled: bool,
}
|
||||
|
||||
impl std::fmt::Debug for SonaEngine {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SonaEngine")
|
||||
.field("config", &self.config)
|
||||
.field("enabled", &self.enabled)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl SonaEngine {
|
||||
/// Create new SONA engine with default config
|
||||
pub fn new(hidden_dim: usize) -> Self {
|
||||
Self::with_config(SonaConfig {
|
||||
hidden_dim,
|
||||
embedding_dim: hidden_dim,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom config
|
||||
pub fn with_config(config: SonaConfig) -> Self {
|
||||
Self {
|
||||
coordinator: LoopCoordinator::with_config(config.clone()),
|
||||
config,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start trajectory recording for a query
|
||||
pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
|
||||
let id = self.coordinator.next_trajectory_id();
|
||||
TrajectoryBuilder::new(id, query_embedding)
|
||||
}
|
||||
|
||||
/// Complete trajectory and submit for learning
|
||||
pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
let trajectory = builder.build(quality);
|
||||
self.coordinator.on_inference(trajectory);
|
||||
}
|
||||
|
||||
/// Submit pre-built trajectory
|
||||
pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
|
||||
if self.enabled {
|
||||
self.coordinator.on_inference(trajectory);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply micro-LoRA to hidden states
|
||||
pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(lora) = self.coordinator.micro_lora().try_read() {
|
||||
lora.forward(input, output);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply base-LoRA to layer output
|
||||
pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(lora) = self.coordinator.base_lora().try_read() {
|
||||
lora.forward_layer(layer_idx, input, output);
|
||||
}
|
||||
}
|
||||
|
||||
/// Run background learning cycle if due
|
||||
pub fn tick(&self) -> Option<String> {
|
||||
if !self.enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(result) = self.coordinator.maybe_run_background() {
|
||||
Some(format!(
|
||||
"Background cycle: {} trajectories -> {} patterns in {:?}",
|
||||
result.trajectories_processed, result.patterns_extracted, result.elapsed
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Force background learning cycle
|
||||
pub fn force_learn(&self) -> String {
|
||||
let result = self.coordinator.force_background();
|
||||
format!(
|
||||
"Forced learning: {} trajectories -> {} patterns, status: {}",
|
||||
result.trajectories_processed, result.patterns_extracted, result.status
|
||||
)
|
||||
}
|
||||
|
||||
/// Flush instant loop updates
|
||||
pub fn flush(&self) {
|
||||
self.coordinator.flush_instant();
|
||||
}
|
||||
|
||||
/// Find similar patterns to query
|
||||
pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
|
||||
self.coordinator
|
||||
.reasoning_bank()
|
||||
.read()
|
||||
.find_similar(query_embedding, k)
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get engine statistics
|
||||
pub fn stats(&self) -> CoordinatorStats {
|
||||
self.coordinator.stats()
|
||||
}
|
||||
|
||||
/// Enable/disable engine
|
||||
pub fn set_enabled(&mut self, enabled: bool) {
|
||||
self.enabled = enabled;
|
||||
}
|
||||
|
||||
/// Check if enabled
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
/// Get config
|
||||
pub fn config(&self) -> &SonaConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get all learned patterns from the reasoning bank
|
||||
#[cfg(feature = "serde-support")]
|
||||
pub fn get_all_patterns(&self) -> Vec<crate::LearnedPattern> {
|
||||
self.coordinator.reasoning_bank().read().get_all_patterns()
|
||||
}
|
||||
|
||||
/// Export LoRA state for serialization
|
||||
#[cfg(feature = "serde-support")]
|
||||
pub fn export_lora_state(&self) -> crate::export::safetensors::LoRAState {
|
||||
use crate::export::safetensors::{LoRALayerState, LoRAState};
|
||||
|
||||
let mut state = LoRAState::default();
|
||||
|
||||
// Export MicroLoRA (single layer)
|
||||
if let Some(lora) = self.coordinator.micro_lora().try_read() {
|
||||
let (down, up) = lora.get_weights();
|
||||
state.micro_lora_layers.push(LoRALayerState {
|
||||
lora_a: down.clone(),
|
||||
lora_b: up.clone(),
|
||||
rank: self.config.micro_lora_rank,
|
||||
input_dim: self.config.hidden_dim,
|
||||
output_dim: self.config.hidden_dim,
|
||||
});
|
||||
}
|
||||
|
||||
// Export BaseLoRA (multi-layer)
|
||||
if let Some(lora) = self.coordinator.base_lora().try_read() {
|
||||
for idx in 0..lora.num_layers() {
|
||||
if let Some((down, up)) = lora.get_layer_weights(idx) {
|
||||
state.base_lora_layers.push(LoRALayerState {
|
||||
lora_a: down.clone(),
|
||||
lora_b: up.clone(),
|
||||
rank: lora.rank,
|
||||
input_dim: lora.hidden_dim,
|
||||
output_dim: lora.hidden_dim,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state
|
||||
}
|
||||
|
||||
/// Get quality trajectories for preference learning export
|
||||
#[cfg(feature = "serde-support")]
|
||||
pub fn get_quality_trajectories(&self) -> Vec<crate::export::dataset::QualityTrajectory> {
|
||||
use crate::export::dataset::QualityTrajectory;
|
||||
|
||||
// Get buffered trajectories from the instant loop via coordinator
|
||||
let trajectories = self.coordinator.reasoning_bank().read().get_all_patterns();
|
||||
|
||||
trajectories
|
||||
.iter()
|
||||
.map(|p| {
|
||||
QualityTrajectory {
|
||||
query_embedding: p.centroid.clone(),
|
||||
response_embedding: p.centroid.clone(), // Use centroid as proxy
|
||||
route: p.pattern_type.to_string(),
|
||||
quality: p.avg_quality,
|
||||
context_ids: vec![],
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get routing decisions for distillation export
|
||||
#[cfg(feature = "serde-support")]
|
||||
pub fn get_routing_decisions(&self) -> Vec<crate::export::dataset::RoutingDecision> {
|
||||
use crate::export::dataset::RoutingDecision;
|
||||
|
||||
let patterns = self.coordinator.reasoning_bank().read().get_all_patterns();
|
||||
|
||||
patterns
|
||||
.iter()
|
||||
.map(|p| {
|
||||
RoutingDecision {
|
||||
query_embedding: p.centroid.clone(),
|
||||
routing_logits: vec![p.avg_quality], // Simplified
|
||||
selected_route: p.pattern_type.to_string(),
|
||||
confidence: p.avg_quality,
|
||||
quality: p.avg_quality,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`SonaEngine`].
///
/// Starts from `SonaConfig::default()` and lets callers override individual
/// fields via chainable setters before calling `build()`.
pub struct SonaEngineBuilder {
    /// Accumulated configuration handed to `SonaEngine::with_config` on build.
    config: SonaConfig,
}
|
||||
|
||||
impl SonaEngineBuilder {
|
||||
/// Create new builder
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: SonaConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set hidden dimension
|
||||
pub fn hidden_dim(mut self, dim: usize) -> Self {
|
||||
self.config.hidden_dim = dim;
|
||||
self.config.embedding_dim = dim;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set micro-LoRA rank
|
||||
pub fn micro_lora_rank(mut self, rank: usize) -> Self {
|
||||
self.config.micro_lora_rank = rank.clamp(1, 2);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set base-LoRA rank
|
||||
pub fn base_lora_rank(mut self, rank: usize) -> Self {
|
||||
self.config.base_lora_rank = rank;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set micro-LoRA learning rate
|
||||
pub fn micro_lr(mut self, lr: f32) -> Self {
|
||||
self.config.micro_lora_lr = lr;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set base-LoRA learning rate
|
||||
pub fn base_lr(mut self, lr: f32) -> Self {
|
||||
self.config.base_lora_lr = lr;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set EWC lambda
|
||||
pub fn ewc_lambda(mut self, lambda: f32) -> Self {
|
||||
self.config.ewc_lambda = lambda;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set pattern clusters
|
||||
pub fn pattern_clusters(mut self, k: usize) -> Self {
|
||||
self.config.pattern_clusters = k;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set trajectory buffer capacity
|
||||
pub fn buffer_capacity(mut self, capacity: usize) -> Self {
|
||||
self.config.trajectory_capacity = capacity;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set quality threshold
|
||||
pub fn quality_threshold(mut self, threshold: f32) -> Self {
|
||||
self.config.quality_threshold = threshold;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the engine
|
||||
pub fn build(self) -> SonaEngine {
|
||||
SonaEngine::with_config(self.config)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SonaEngineBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrajectoryStep;

    /// A freshly created engine starts enabled.
    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    /// Builder setters propagate into the engine's config.
    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    /// begin -> add steps -> end buffers exactly one trajectory.
    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        // Begin trajectory
        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        // End trajectory
        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    /// Applying micro-LoRA after some training must not panic.
    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        // Train a bit first (loop index unused — was a warning as `i`)
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        // Apply LoRA
        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // Output may or may not be modified depending on accumulated gradients
    }

    /// force_learn reports all buffered trajectories in its summary string.
    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    /// A disabled engine drops submitted trajectories instead of buffering.
    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        // Should not record when disabled
        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}
|
||||
Reference in New Issue
Block a user