Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
510
vendor/ruvector/crates/sona/src/training/factory.rs
vendored
Normal file
510
vendor/ruvector/crates/sona/src/training/factory.rs
vendored
Normal file
@@ -0,0 +1,510 @@
|
||||
//! Agent Factory for SONA
|
||||
//!
|
||||
//! Create and manage multiple specialized agents.
|
||||
|
||||
use super::metrics::TrainingMetrics;
|
||||
use super::templates::{AgentType, TrainingTemplate};
|
||||
use crate::engine::SonaEngine;
|
||||
use crate::time_compat::SystemTime;
|
||||
use crate::types::SonaConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Handle to a managed agent
///
/// Lightweight, cloneable identity record for an agent; the heavy state
/// (engine, metrics) lives in [`ManagedAgent`].
#[derive(Clone, Debug)]
pub struct AgentHandle {
    /// Agent identifier
    pub id: String,
    /// Agent type
    pub agent_type: AgentType,
    /// Creation timestamp (seconds since epoch, see `ManagedAgent::new`)
    pub created_at: u64,
}
|
||||
|
||||
/// Managed agent with engine and metadata
///
/// Owns one `SonaEngine` plus bookkeeping used by `AgentFactory`.
pub struct ManagedAgent {
    /// Agent handle (id, type, creation time)
    pub handle: AgentHandle,
    /// SONA engine backing this agent
    pub engine: SonaEngine,
    /// Training metrics accumulated across sessions
    pub metrics: TrainingMetrics,
    /// Purpose/description (human-readable)
    pub purpose: String,
    /// Training count (number of completed `train_agent` batches)
    pub training_count: u64,
    /// Tags for organization (e.g. "custom" for non-template agents)
    pub tags: Vec<String>,
}
|
||||
|
||||
impl ManagedAgent {
|
||||
/// Create a new managed agent
|
||||
pub fn new(
|
||||
id: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
config: SonaConfig,
|
||||
purpose: impl Into<String>,
|
||||
) -> Self {
|
||||
let now = SystemTime::now().duration_since_epoch().as_secs();
|
||||
|
||||
let id = id.into();
|
||||
Self {
|
||||
handle: AgentHandle {
|
||||
id: id.clone(),
|
||||
agent_type,
|
||||
created_at: now,
|
||||
},
|
||||
engine: SonaEngine::with_config(config),
|
||||
metrics: TrainingMetrics::new(&id),
|
||||
purpose: purpose.into(),
|
||||
training_count: 0,
|
||||
tags: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get agent stats
|
||||
pub fn stats(&self) -> AgentStats {
|
||||
AgentStats {
|
||||
id: self.handle.id.clone(),
|
||||
agent_type: self.handle.agent_type.clone(),
|
||||
training_count: self.training_count,
|
||||
patterns_learned: self.metrics.patterns_learned,
|
||||
avg_quality: self.metrics.avg_quality(),
|
||||
total_examples: self.metrics.total_examples,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Agent statistics
///
/// Serializable snapshot produced by `ManagedAgent::stats`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentStats {
    /// Agent ID
    pub id: String,
    /// Agent type
    pub agent_type: AgentType,
    /// Number of training sessions completed
    pub training_count: u64,
    /// Patterns learned
    pub patterns_learned: usize,
    /// Average quality score
    pub avg_quality: f32,
    /// Total examples processed
    pub total_examples: usize,
}
|
||||
|
||||
/// Factory for creating and managing agents
///
/// Agents are stored by name; creating an agent under an existing name
/// replaces the previous one (see `create_from_template`).
pub struct AgentFactory {
    /// Base configuration for all agents (tweaked per agent type)
    base_config: SonaConfig,
    /// Managed agents keyed by name
    agents: HashMap<String, ManagedAgent>,
    /// Default hidden dimension, captured from `base_config` at construction
    default_hidden_dim: usize,
}
|
||||
|
||||
impl Default for AgentFactory {
|
||||
fn default() -> Self {
|
||||
Self::new(SonaConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentFactory {
    /// Create a new agent factory
    ///
    /// `default_hidden_dim` is captured from `base_config` so the
    /// convenience constructors (`create_code_agent`, …) build templates
    /// of matching size.
    pub fn new(base_config: SonaConfig) -> Self {
        let default_hidden_dim = base_config.hidden_dim;
        Self {
            base_config,
            agents: HashMap::new(),
            default_hidden_dim,
        }
    }

    /// Create factory with specific hidden dimension
    ///
    /// Embedding dimension is kept equal to the hidden dimension.
    pub fn with_hidden_dim(hidden_dim: usize) -> Self {
        let config = SonaConfig {
            hidden_dim,
            embedding_dim: hidden_dim,
            ..SonaConfig::default()
        };
        Self::new(config)
    }

    /// Create an agent from a template
    ///
    /// NOTE(review): inserting under an existing name silently replaces the
    /// previous agent — confirm this overwrite semantic is intended.
    pub fn create_from_template(
        &mut self,
        name: impl Into<String>,
        template: &TrainingTemplate,
    ) -> &ManagedAgent {
        let name = name.into();
        let agent = ManagedAgent::new(
            name.clone(),
            template.agent_type.clone(),
            template.sona_config.clone(),
            &template.name,
        );
        self.agents.insert(name.clone(), agent);
        // unwrap is safe: the key was inserted on the previous line.
        self.agents.get(&name).unwrap()
    }

    /// Create an agent with custom configuration
    ///
    /// Configuration is derived from `base_config` via
    /// `config_for_agent_type`; the agent is tagged "custom".
    pub fn create_agent(
        &mut self,
        name: impl Into<String>,
        agent_type: AgentType,
        purpose: impl Into<String>,
    ) -> &ManagedAgent {
        let name = name.into();
        let config = self.config_for_agent_type(&agent_type);
        let mut agent = ManagedAgent::new(name.clone(), agent_type, config, purpose);
        agent.tags.push("custom".into());
        self.agents.insert(name.clone(), agent);
        // unwrap is safe: the key was inserted on the previous line.
        self.agents.get(&name).unwrap()
    }

    /// Create a code agent
    pub fn create_code_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
        let template = TrainingTemplate::code_agent().with_hidden_dim(self.default_hidden_dim);
        self.create_from_template(name, &template)
    }

    /// Create a chat agent
    pub fn create_chat_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
        let template = TrainingTemplate::chat_agent().with_hidden_dim(self.default_hidden_dim);
        self.create_from_template(name, &template)
    }

    /// Create a RAG agent
    pub fn create_rag_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
        let template = TrainingTemplate::rag_agent().with_hidden_dim(self.default_hidden_dim);
        self.create_from_template(name, &template)
    }

    /// Create a task planner agent
    pub fn create_task_planner(&mut self, name: impl Into<String>) -> &ManagedAgent {
        let template = TrainingTemplate::task_planner().with_hidden_dim(self.default_hidden_dim);
        self.create_from_template(name, &template)
    }

    /// Create a reasoning agent
    pub fn create_reasoning_agent(&mut self, name: impl Into<String>) -> &ManagedAgent {
        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(self.default_hidden_dim);
        self.create_from_template(name, &template)
    }

    /// Create a codebase helper agent
    pub fn create_codebase_helper(&mut self, name: impl Into<String>) -> &ManagedAgent {
        let template = TrainingTemplate::codebase_helper().with_hidden_dim(self.default_hidden_dim);
        self.create_from_template(name, &template)
    }

    /// Get an agent by name
    pub fn get_agent(&self, name: &str) -> Option<&ManagedAgent> {
        self.agents.get(name)
    }

    /// Get a mutable agent by name
    pub fn get_agent_mut(&mut self, name: &str) -> Option<&mut ManagedAgent> {
        self.agents.get_mut(name)
    }

    /// Remove an agent, returning it if it existed
    pub fn remove_agent(&mut self, name: &str) -> Option<ManagedAgent> {
        self.agents.remove(name)
    }

    /// List all agents (iteration order is unspecified — HashMap backed)
    pub fn list_agents(&self) -> Vec<AgentStats> {
        self.agents.values().map(|a| a.stats()).collect()
    }

    /// Get agent count
    pub fn agent_count(&self) -> usize {
        self.agents.len()
    }

    /// Train an agent with examples
    ///
    /// Replays each example through the engine's trajectory builder, then
    /// forces one learning pass over the whole batch. Returns the number of
    /// examples processed, or an error if `name` is unknown.
    pub fn train_agent<E>(
        &mut self,
        name: &str,
        examples: impl Iterator<Item = E>,
    ) -> Result<usize, String>
    where
        E: TrainingExample,
    {
        let agent = self
            .agents
            .get_mut(name)
            .ok_or_else(|| format!("Agent '{}' not found", name))?;

        let mut count = 0;
        for example in examples {
            // Use builder-based trajectory API
            let mut builder = agent.engine.begin_trajectory(example.embedding());

            // Set route if available
            if let Some(route) = example.route() {
                builder.set_model_route(&route);
            }

            // Add context if available
            for ctx in example.context() {
                builder.add_context(&ctx);
            }

            // Add step with activations
            builder.add_step(example.activations(), example.attention(), example.reward());

            // End trajectory with quality
            agent.engine.end_trajectory(builder, example.quality());

            count += 1;
            agent.metrics.total_examples += 1;
            agent.metrics.add_quality_sample(example.quality());
        }

        // Force learning after batch
        agent.engine.force_learn();
        agent.training_count += 1;
        agent.metrics.training_sessions += 1;

        Ok(count)
    }

    /// Get configuration for agent type
    ///
    /// Starts from `base_config` and applies per-type tuning (LoRA rank,
    /// cluster count, quality threshold, capacity, EWC strength).
    fn config_for_agent_type(&self, agent_type: &AgentType) -> SonaConfig {
        let mut config = self.base_config.clone();

        match agent_type {
            AgentType::CodeAgent | AgentType::CodebaseHelper => {
                config.base_lora_rank = 16;
                config.pattern_clusters = 200;
                config.quality_threshold = 0.2;
            }
            AgentType::ChatAgent => {
                config.base_lora_rank = 8;
                config.pattern_clusters = 50;
                config.quality_threshold = 0.4;
            }
            AgentType::RagAgent => {
                config.pattern_clusters = 200;
                config.trajectory_capacity = 10000;
            }
            AgentType::TaskPlanner => {
                config.base_lora_rank = 16;
                config.ewc_lambda = 2000.0;
            }
            AgentType::ReasoningAgent => {
                config.base_lora_rank = 16;
                config.ewc_lambda = 3000.0;
                config.pattern_clusters = 150;
            }
            AgentType::DomainExpert => {
                config.quality_threshold = 0.1;
                config.trajectory_capacity = 20000;
            }
            AgentType::DataAnalyst => {
                config.base_lora_rank = 8;
                config.pattern_clusters = 100;
            }
            AgentType::CreativeWriter => {
                config.base_lora_rank = 8;
                config.pattern_clusters = 50;
                config.quality_threshold = 0.5;
            }
            // Any remaining agent types use the base configuration unchanged.
            _ => {}
        }

        config
    }
}
|
||||
|
||||
/// Trait for training examples
///
/// Only `embedding` and `quality` must be provided; every other accessor
/// has a default derived from those two (or a neutral constant).
pub trait TrainingExample {
    /// Embedding vector for the example
    fn embedding(&self) -> Vec<f32>;

    /// Quality score for the example
    fn quality(&self) -> f32;

    /// Activations — defaults to the embedding itself
    fn activations(&self) -> Vec<f32> {
        self.embedding()
    }

    /// Attention weights — defaults to a uniform 64-way distribution
    fn attention(&self) -> Vec<f32> {
        std::iter::repeat(1.0 / 64.0).take(64).collect()
    }

    /// Reward signal — defaults to the quality score
    fn reward(&self) -> f32 {
        self.quality()
    }

    /// Optional model route — defaults to none
    fn route(&self) -> Option<String> {
        None
    }

    /// Context identifiers — defaults to empty
    fn context(&self) -> Vec<String> {
        Vec::new()
    }
}
|
||||
|
||||
/// Simple training example implementation
///
/// Minimal owned implementation of [`TrainingExample`]; build with
/// `SimpleExample::new` and the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct SimpleExample {
    /// Embedding vector
    pub embedding: Vec<f32>,
    /// Quality score
    pub quality: f32,
    /// Optional route
    pub route: Option<String>,
    /// Context IDs
    pub context: Vec<String>,
}
|
||||
|
||||
impl SimpleExample {
|
||||
/// Create a new simple example
|
||||
pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
|
||||
Self {
|
||||
embedding,
|
||||
quality,
|
||||
route: None,
|
||||
context: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set route
|
||||
pub fn with_route(mut self, route: impl Into<String>) -> Self {
|
||||
self.route = Some(route.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add context
|
||||
pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
|
||||
self.context.push(ctx.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl TrainingExample for SimpleExample {
    /// Returns an owned copy of the stored embedding.
    fn embedding(&self) -> Vec<f32> {
        self.embedding.to_vec()
    }

    /// Returns the stored quality score.
    fn quality(&self) -> f32 {
        self.quality
    }

    /// Returns a copy of the stored route, if any.
    fn route(&self) -> Option<String> {
        self.route.as_ref().cloned()
    }

    /// Returns an owned copy of the stored context identifiers.
    fn context(&self) -> Vec<String> {
        self.context.to_vec()
    }
}
|
||||
|
||||
/// Thread-safe agent factory wrapper
///
/// Wraps an `AgentFactory` in `Arc<RwLock<..>>` so it can be shared and
/// mutated across threads; clones share the same underlying factory.
pub struct SharedAgentFactory {
    // Shared, lock-protected factory state.
    inner: Arc<RwLock<AgentFactory>>,
}
|
||||
|
||||
impl SharedAgentFactory {
|
||||
/// Create a new shared factory
|
||||
pub fn new(config: SonaConfig) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(RwLock::new(AgentFactory::new(config))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get read access to factory
|
||||
pub fn read(&self) -> std::sync::RwLockReadGuard<'_, AgentFactory> {
|
||||
self.inner.read().unwrap()
|
||||
}
|
||||
|
||||
/// Get write access to factory
|
||||
pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, AgentFactory> {
|
||||
self.inner.write().unwrap()
|
||||
}
|
||||
|
||||
/// Clone the Arc
|
||||
pub fn clone_arc(&self) -> Self {
|
||||
Self {
|
||||
inner: Arc::clone(&self.inner),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for SharedAgentFactory {
|
||||
fn clone(&self) -> Self {
|
||||
self.clone_arc()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Empty factory starts with no agents.
    #[test]
    fn test_factory_creation() {
        let factory = AgentFactory::default();
        assert_eq!(factory.agent_count(), 0);
    }

    // Convenience constructors register agents under their given names.
    #[test]
    fn test_create_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);

        factory.create_code_agent("code-1");
        factory.create_chat_agent("chat-1");
        factory.create_rag_agent("rag-1");

        assert_eq!(factory.agent_count(), 3);
        assert!(factory.get_agent("code-1").is_some());
        assert!(factory.get_agent("unknown").is_none());
    }

    // Template-based creation preserves the template's agent type.
    #[test]
    fn test_agent_from_template() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        let template = TrainingTemplate::reasoning_agent().with_hidden_dim(256);

        factory.create_from_template("reasoner", &template);

        let agent = factory.get_agent("reasoner").unwrap();
        assert_eq!(agent.handle.agent_type, AgentType::ReasoningAgent);
    }

    // One train_agent call counts as one session and processes every example.
    #[test]
    fn test_train_agent() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_chat_agent("bot");

        let examples = vec![
            SimpleExample::new(vec![0.1; 256], 0.8).with_route("greeting"),
            SimpleExample::new(vec![0.2; 256], 0.9).with_route("question"),
            SimpleExample::new(vec![0.3; 256], 0.7).with_route("farewell"),
        ];

        let count = factory.train_agent("bot", examples.into_iter()).unwrap();
        assert_eq!(count, 3);

        let agent = factory.get_agent("bot").unwrap();
        assert_eq!(agent.training_count, 1);
        assert_eq!(agent.metrics.total_examples, 3);
    }

    // list_agents returns one stats entry per registered agent.
    #[test]
    fn test_list_agents() {
        let mut factory = AgentFactory::with_hidden_dim(256);
        factory.create_code_agent("code");
        factory.create_chat_agent("chat");

        let agents = factory.list_agents();
        assert_eq!(agents.len(), 2);
    }
}
|
||||
681
vendor/ruvector/crates/sona/src/training/federated.rs
vendored
Normal file
681
vendor/ruvector/crates/sona/src/training/federated.rs
vendored
Normal file
@@ -0,0 +1,681 @@
|
||||
//! Federated Learning for SONA
|
||||
//!
|
||||
//! Enable distributed learning across ephemeral agents that share
|
||||
//! trajectories with a central coordinator.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||
//! │ Agent A │ │ Agent B │ │ Agent C │
|
||||
//! │ (ephemeral) │ │ (ephemeral) │ │ (ephemeral) │
|
||||
//! └──────┬──────┘ └──────┬──────┘ └──────┬──────┘
|
||||
//! │ │ │
|
||||
//! │ export() │ export() │ export()
|
||||
//! ▼ ▼ ▼
|
||||
//! ┌────────────────────────────────────────────────┐
|
||||
//! │ Federated Coordinator │
|
||||
//! │ (persistent, large capacity) │
|
||||
//! └────────────────────────────────────────────────┘
|
||||
//! ```
|
||||
|
||||
use super::metrics::TrainingMetrics;
|
||||
use crate::engine::SonaEngine;
|
||||
use crate::time_compat::SystemTime;
|
||||
use crate::types::{LearnedPattern, SonaConfig};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Exported state from an ephemeral agent
///
/// Produced by `EphemeralAgent::export_state` and consumed by
/// `FederatedCoordinator::aggregate`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentExport {
    /// Agent identifier
    pub agent_id: String,
    /// Exported trajectories (embedding, quality pairs)
    pub trajectories: Vec<TrajectoryExport>,
    /// Agent statistics at export time
    pub stats: AgentExportStats,
    /// Session duration in milliseconds
    pub session_duration_ms: u64,
    /// Export timestamp (milliseconds since epoch)
    pub timestamp: u64,
}
|
||||
|
||||
/// Single trajectory export
///
/// A quality-scored query embedding plus routing/context metadata.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrajectoryExport {
    /// Query embedding
    pub embedding: Vec<f32>,
    /// Quality score
    pub quality: f32,
    /// Model route (if any)
    pub route: Option<String>,
    /// Context identifiers
    pub context: Vec<String>,
    /// Timestamp (milliseconds since epoch)
    pub timestamp: u64,
}
|
||||
|
||||
/// Agent export statistics
///
/// Summary counters shipped alongside an `AgentExport`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AgentExportStats {
    /// Total trajectories processed
    pub total_trajectories: usize,
    /// Average quality over all processed trajectories
    pub avg_quality: f32,
    /// Patterns learned locally by the agent's engine
    pub patterns_learned: usize,
}
|
||||
|
||||
/// Ephemeral agent for federated learning
///
/// Collects trajectories during its session and exports state before termination.
pub struct EphemeralAgent {
    /// Agent identifier
    agent_id: String,
    /// SONA engine (small, per-session configuration)
    engine: SonaEngine,
    /// Collected trajectories, kept for export
    trajectories: Vec<TrajectoryExport>,
    /// Session start time (milliseconds since epoch)
    start_time: u64,
    /// Quality samples, one per processed trajectory
    quality_samples: Vec<f32>,
}
|
||||
|
||||
impl EphemeralAgent {
|
||||
/// Create a new ephemeral agent
|
||||
pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
|
||||
let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
|
||||
|
||||
Self {
|
||||
agent_id: agent_id.into(),
|
||||
engine: SonaEngine::with_config(config),
|
||||
trajectories: Vec::new(),
|
||||
start_time: now,
|
||||
quality_samples: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default config for federated learning
|
||||
pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
|
||||
Self::new(
|
||||
agent_id,
|
||||
SonaConfig {
|
||||
hidden_dim,
|
||||
embedding_dim: hidden_dim,
|
||||
micro_lora_rank: 2,
|
||||
base_lora_rank: 8,
|
||||
micro_lora_lr: 0.002,
|
||||
trajectory_capacity: 500, // Small buffer per agent
|
||||
pattern_clusters: 25,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Get agent ID
|
||||
pub fn agent_id(&self) -> &str {
|
||||
&self.agent_id
|
||||
}
|
||||
|
||||
/// Get engine reference
|
||||
pub fn engine(&self) -> &SonaEngine {
|
||||
&self.engine
|
||||
}
|
||||
|
||||
/// Get mutable engine reference
|
||||
pub fn engine_mut(&mut self) -> &mut SonaEngine {
|
||||
&mut self.engine
|
||||
}
|
||||
|
||||
/// Process a task and record trajectory
|
||||
pub fn process_trajectory(
|
||||
&mut self,
|
||||
embedding: Vec<f32>,
|
||||
activations: Vec<f32>,
|
||||
quality: f32,
|
||||
route: Option<String>,
|
||||
context: Vec<String>,
|
||||
) {
|
||||
let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
|
||||
|
||||
// Record in SONA engine
|
||||
let mut builder = self.engine.begin_trajectory(embedding.clone());
|
||||
if let Some(ref r) = route {
|
||||
builder.set_model_route(r);
|
||||
}
|
||||
for ctx in &context {
|
||||
builder.add_context(ctx);
|
||||
}
|
||||
builder.add_step(activations, vec![], quality);
|
||||
self.engine.end_trajectory(builder, quality);
|
||||
|
||||
// Store for export
|
||||
self.trajectories.push(TrajectoryExport {
|
||||
embedding,
|
||||
quality,
|
||||
route,
|
||||
context,
|
||||
timestamp: now,
|
||||
});
|
||||
|
||||
self.quality_samples.push(quality);
|
||||
}
|
||||
|
||||
/// Apply micro-LoRA to hidden states
|
||||
pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
|
||||
self.engine.apply_micro_lora(input, output);
|
||||
}
|
||||
|
||||
/// Get number of collected trajectories
|
||||
pub fn trajectory_count(&self) -> usize {
|
||||
self.trajectories.len()
|
||||
}
|
||||
|
||||
/// Get average quality
|
||||
pub fn avg_quality(&self) -> f32 {
|
||||
if self.quality_samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Force local learning
|
||||
pub fn force_learn(&self) -> String {
|
||||
self.engine.force_learn()
|
||||
}
|
||||
|
||||
/// Simple process task method
|
||||
pub fn process_task(&mut self, embedding: Vec<f32>, quality: f32) {
|
||||
self.process_trajectory(embedding.clone(), embedding, quality, None, vec![]);
|
||||
}
|
||||
|
||||
/// Process task with route information
|
||||
pub fn process_task_with_route(&mut self, embedding: Vec<f32>, quality: f32, route: &str) {
|
||||
self.process_trajectory(
|
||||
embedding.clone(),
|
||||
embedding,
|
||||
quality,
|
||||
Some(route.to_string()),
|
||||
vec![],
|
||||
);
|
||||
}
|
||||
|
||||
/// Get average quality (alias for avg_quality)
|
||||
pub fn average_quality(&self) -> f32 {
|
||||
self.avg_quality()
|
||||
}
|
||||
|
||||
/// Get uptime in seconds
|
||||
pub fn uptime_seconds(&self) -> u64 {
|
||||
let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
|
||||
(now - self.start_time) / 1000
|
||||
}
|
||||
|
||||
/// Get agent stats
|
||||
pub fn stats(&self) -> AgentExportStats {
|
||||
let engine_stats = self.engine.stats();
|
||||
AgentExportStats {
|
||||
total_trajectories: self.trajectories.len(),
|
||||
avg_quality: self.avg_quality(),
|
||||
patterns_learned: engine_stats.patterns_stored,
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear trajectories (after export)
|
||||
pub fn clear(&mut self) {
|
||||
self.trajectories.clear();
|
||||
self.quality_samples.clear();
|
||||
}
|
||||
|
||||
/// Get learned patterns from agent
|
||||
pub fn get_patterns(&self) -> Vec<LearnedPattern> {
|
||||
self.engine.find_patterns(&[], 0)
|
||||
}
|
||||
|
||||
/// Export agent state for federation
|
||||
///
|
||||
/// Call this before terminating the agent.
|
||||
pub fn export_state(&self) -> AgentExport {
|
||||
let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
|
||||
|
||||
// Force learning before export
|
||||
self.engine.force_learn();
|
||||
|
||||
let stats = self.engine.stats();
|
||||
|
||||
AgentExport {
|
||||
agent_id: self.agent_id.clone(),
|
||||
trajectories: self.trajectories.clone(),
|
||||
stats: AgentExportStats {
|
||||
total_trajectories: self.trajectories.len(),
|
||||
avg_quality: self.avg_quality(),
|
||||
patterns_learned: stats.patterns_stored,
|
||||
},
|
||||
session_duration_ms: now - self.start_time,
|
||||
timestamp: now,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Agent contribution record
///
/// Per-agent entry kept by the coordinator; re-exports from the same
/// agent id replace the previous record.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentContribution {
    /// Number of trajectories contributed
    pub trajectory_count: usize,
    /// Average quality of contributions
    pub avg_quality: f32,
    /// Contribution timestamp (milliseconds since epoch)
    pub timestamp: u64,
    /// Session duration in milliseconds
    pub session_duration_ms: u64,
}
|
||||
|
||||
/// Federated learning coordinator
///
/// Aggregates learning from multiple ephemeral agents.
pub struct FederatedCoordinator {
    /// Coordinator identifier
    coordinator_id: String,
    /// Master SONA engine for aggregation
    master_engine: SonaEngine,
    /// Agent contributions, keyed by agent id (latest export wins)
    contributions: HashMap<String, AgentContribution>,
    /// Quality threshold for accepting trajectories (default 0.4)
    quality_threshold: f32,
    /// Total trajectories aggregated (accepted only)
    total_trajectories: usize,
    /// Consolidation interval, in number of distinct contributing agents
    consolidation_interval: usize,
    /// Metrics over accepted trajectories
    metrics: TrainingMetrics,
}
|
||||
|
||||
impl FederatedCoordinator {
|
||||
/// Create a new federated coordinator
|
||||
pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
|
||||
let id = coordinator_id.into();
|
||||
Self {
|
||||
coordinator_id: id.clone(),
|
||||
master_engine: SonaEngine::with_config(config),
|
||||
contributions: HashMap::new(),
|
||||
quality_threshold: 0.4,
|
||||
total_trajectories: 0,
|
||||
consolidation_interval: 50,
|
||||
metrics: TrainingMetrics::new(&id),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default config for coordination
|
||||
pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
|
||||
Self::new(
|
||||
coordinator_id,
|
||||
SonaConfig {
|
||||
hidden_dim,
|
||||
embedding_dim: hidden_dim,
|
||||
micro_lora_rank: 2,
|
||||
base_lora_rank: 16, // Deeper for aggregation
|
||||
trajectory_capacity: 50000, // Large central buffer
|
||||
pattern_clusters: 200,
|
||||
ewc_lambda: 2000.0, // Strong regularization
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Get coordinator ID
|
||||
pub fn coordinator_id(&self) -> &str {
|
||||
&self.coordinator_id
|
||||
}
|
||||
|
||||
/// Set quality threshold for accepting trajectories
|
||||
pub fn set_quality_threshold(&mut self, threshold: f32) {
|
||||
self.quality_threshold = threshold;
|
||||
}
|
||||
|
||||
/// Set consolidation interval
|
||||
pub fn set_consolidation_interval(&mut self, interval: usize) {
|
||||
self.consolidation_interval = interval;
|
||||
}
|
||||
|
||||
/// Get master engine reference
|
||||
pub fn master_engine(&self) -> &SonaEngine {
|
||||
&self.master_engine
|
||||
}
|
||||
|
||||
/// Aggregate agent export into coordinator
|
||||
pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
|
||||
let mut accepted = 0;
|
||||
let mut rejected = 0;
|
||||
|
||||
// Replay trajectories into master engine
|
||||
for traj in &export.trajectories {
|
||||
if traj.quality >= self.quality_threshold {
|
||||
let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
|
||||
if let Some(ref route) = traj.route {
|
||||
builder.set_model_route(route);
|
||||
}
|
||||
for ctx in &traj.context {
|
||||
builder.add_context(ctx);
|
||||
}
|
||||
self.master_engine.end_trajectory(builder, traj.quality);
|
||||
|
||||
self.metrics.add_quality_sample(traj.quality);
|
||||
accepted += 1;
|
||||
} else {
|
||||
rejected += 1;
|
||||
}
|
||||
}
|
||||
|
||||
self.total_trajectories += accepted;
|
||||
|
||||
// Record contribution
|
||||
let now = SystemTime::now().duration_since_epoch().as_millis() as u64;
|
||||
|
||||
self.contributions.insert(
|
||||
export.agent_id.clone(),
|
||||
AgentContribution {
|
||||
trajectory_count: export.trajectories.len(),
|
||||
avg_quality: export.stats.avg_quality,
|
||||
timestamp: now,
|
||||
session_duration_ms: export.session_duration_ms,
|
||||
},
|
||||
);
|
||||
|
||||
// Auto-consolidate if needed
|
||||
let consolidated = if self.should_consolidate() {
|
||||
self.master_engine.force_learn();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
AggregationResult {
|
||||
agent_id: export.agent_id,
|
||||
trajectories_accepted: accepted,
|
||||
trajectories_rejected: rejected,
|
||||
consolidated,
|
||||
total_agents: self.contributions.len(),
|
||||
total_trajectories: self.total_trajectories,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if consolidation is needed
|
||||
fn should_consolidate(&self) -> bool {
|
||||
self.contributions.len() % self.consolidation_interval == 0
|
||||
}
|
||||
|
||||
/// Force consolidation
|
||||
pub fn force_consolidate(&self) -> String {
|
||||
self.master_engine.force_learn()
|
||||
}
|
||||
|
||||
/// Get initial state for new agents
|
||||
///
|
||||
/// Returns learned patterns that new agents can use for warm start.
|
||||
pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
|
||||
// Find patterns similar to a general query (empty or average)
|
||||
// Since we don't have a specific query, get all patterns
|
||||
self.master_engine
|
||||
.find_patterns(&[], 0)
|
||||
.into_iter()
|
||||
.take(k)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all learned patterns
|
||||
pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
|
||||
self.master_engine.find_patterns(&[], 0)
|
||||
}
|
||||
|
||||
/// Get coordinator statistics
|
||||
pub fn stats(&self) -> CoordinatorStats {
|
||||
let engine_stats = self.master_engine.stats();
|
||||
|
||||
CoordinatorStats {
|
||||
coordinator_id: self.coordinator_id.clone(),
|
||||
total_agents: self.contributions.len(),
|
||||
total_trajectories: self.total_trajectories,
|
||||
patterns_learned: engine_stats.patterns_stored,
|
||||
avg_quality: self.metrics.avg_quality(),
|
||||
quality_threshold: self.quality_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get contribution history
|
||||
pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
|
||||
&self.contributions
|
||||
}
|
||||
|
||||
/// Get metrics
|
||||
pub fn metrics(&self) -> &TrainingMetrics {
|
||||
&self.metrics
|
||||
}
|
||||
|
||||
/// Get total number of contributing agents
|
||||
pub fn agent_count(&self) -> usize {
|
||||
self.contributions.len()
|
||||
}
|
||||
|
||||
/// Get total trajectories aggregated
|
||||
pub fn total_trajectories(&self) -> usize {
|
||||
self.total_trajectories
|
||||
}
|
||||
|
||||
/// Find similar patterns
|
||||
pub fn find_patterns(&self, query: &[f32], k: usize) -> Vec<LearnedPattern> {
|
||||
self.master_engine.find_patterns(query, k)
|
||||
}
|
||||
|
||||
/// Apply coordinator's LoRA to input
|
||||
pub fn apply_lora(&self, input: &[f32]) -> Vec<f32> {
|
||||
let mut output = vec![0.0; input.len()];
|
||||
self.master_engine.apply_micro_lora(input, &mut output);
|
||||
output
|
||||
}
|
||||
|
||||
/// Consolidate learning (alias for force_consolidate)
|
||||
pub fn consolidate(&self) -> String {
|
||||
self.force_consolidate()
|
||||
}
|
||||
|
||||
/// Clear all contributions
|
||||
pub fn clear(&mut self) {
|
||||
self.contributions.clear();
|
||||
self.total_trajectories = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of aggregating an agent export
///
/// Returned by `FederatedCoordinator::aggregate`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AggregationResult {
    /// Agent ID that was aggregated
    pub agent_id: String,
    /// Number of trajectories accepted
    pub trajectories_accepted: usize,
    /// Number of trajectories rejected (below quality threshold)
    pub trajectories_rejected: usize,
    /// Whether consolidation was triggered by this aggregation
    pub consolidated: bool,
    /// Total number of contributing agents after this aggregation
    pub total_agents: usize,
    /// Total trajectories in coordinator after this aggregation
    pub total_trajectories: usize,
}
|
||||
|
||||
/// Coordinator statistics
///
/// Snapshot produced by `FederatedCoordinator::stats`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CoordinatorStats {
    /// Coordinator identifier
    pub coordinator_id: String,
    /// Number of contributing agents
    pub total_agents: usize,
    /// Total trajectories aggregated (accepted only)
    pub total_trajectories: usize,
    /// Patterns learned by the master engine
    pub patterns_learned: usize,
    /// Average quality across all accepted contributions
    pub avg_quality: f32,
    /// Quality threshold in effect
    pub quality_threshold: f32,
}
|
||||
|
||||
impl std::fmt::Display for CoordinatorStats {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
|
||||
self.coordinator_id,
|
||||
self.total_agents,
|
||||
self.total_trajectories,
|
||||
self.patterns_learned,
|
||||
self.avg_quality
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Federated learning topology
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub enum FederatedTopology {
|
||||
/// Agents -> Central Coordinator (simple, single aggregation point)
|
||||
#[default]
|
||||
Star,
|
||||
/// Agents -> Regional -> Global (multi-datacenter)
|
||||
Hierarchical {
|
||||
/// Number of regional coordinators
|
||||
regions: usize,
|
||||
},
|
||||
/// Agents share directly (edge deployment)
|
||||
PeerToPeer,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ephemeral_agent_creation() {
|
||||
let agent = EphemeralAgent::default_federated("agent-1", 256);
|
||||
assert_eq!(agent.agent_id(), "agent-1");
|
||||
assert_eq!(agent.trajectory_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trajectory_collection() {
|
||||
let mut agent = EphemeralAgent::default_federated("agent-1", 256);
|
||||
|
||||
agent.process_trajectory(
|
||||
vec![0.1; 256],
|
||||
vec![0.5; 256],
|
||||
0.8,
|
||||
Some("code".into()),
|
||||
vec!["file:main.rs".into()],
|
||||
);
|
||||
|
||||
assert_eq!(agent.trajectory_count(), 1);
|
||||
assert!((agent.avg_quality() - 0.8).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_export() {
|
||||
let mut agent = EphemeralAgent::default_federated("agent-1", 256);
|
||||
|
||||
for i in 0..5 {
|
||||
agent.process_trajectory(
|
||||
vec![i as f32 * 0.1; 256],
|
||||
vec![0.5; 256],
|
||||
0.7 + i as f32 * 0.05,
|
||||
None,
|
||||
vec![],
|
||||
);
|
||||
}
|
||||
|
||||
let export = agent.export_state();
|
||||
assert_eq!(export.agent_id, "agent-1");
|
||||
assert_eq!(export.trajectories.len(), 5);
|
||||
assert!(export.stats.avg_quality > 0.7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_creation() {
|
||||
let coord = FederatedCoordinator::default_coordinator("coord-1", 256);
|
||||
assert_eq!(coord.coordinator_id(), "coord-1");
|
||||
|
||||
let stats = coord.stats();
|
||||
assert_eq!(stats.total_agents, 0);
|
||||
assert_eq!(stats.total_trajectories, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregation() {
|
||||
let mut coord = FederatedCoordinator::default_coordinator("coord-1", 256);
|
||||
coord.set_quality_threshold(0.5);
|
||||
|
||||
// Create agent export
|
||||
let export = AgentExport {
|
||||
agent_id: "agent-1".into(),
|
||||
trajectories: vec![
|
||||
TrajectoryExport {
|
||||
embedding: vec![0.1; 256],
|
||||
quality: 0.8,
|
||||
route: Some("code".into()),
|
||||
context: vec![],
|
||||
timestamp: 0,
|
||||
},
|
||||
TrajectoryExport {
|
||||
embedding: vec![0.2; 256],
|
||||
quality: 0.3, // Below threshold
|
||||
route: None,
|
||||
context: vec![],
|
||||
timestamp: 0,
|
||||
},
|
||||
],
|
||||
stats: AgentExportStats {
|
||||
total_trajectories: 2,
|
||||
avg_quality: 0.55,
|
||||
patterns_learned: 0,
|
||||
},
|
||||
session_duration_ms: 1000,
|
||||
timestamp: 0,
|
||||
};
|
||||
|
||||
let result = coord.aggregate(export);
|
||||
assert_eq!(result.trajectories_accepted, 1);
|
||||
assert_eq!(result.trajectories_rejected, 1);
|
||||
assert_eq!(result.total_agents, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_agent_aggregation() {
|
||||
let mut coord = FederatedCoordinator::default_coordinator("coord-1", 256);
|
||||
coord.set_consolidation_interval(2); // Consolidate every 2 agents
|
||||
|
||||
for i in 0..3 {
|
||||
let export = AgentExport {
|
||||
agent_id: format!("agent-{}", i),
|
||||
trajectories: vec![TrajectoryExport {
|
||||
embedding: vec![i as f32 * 0.1; 256],
|
||||
quality: 0.8,
|
||||
route: None,
|
||||
context: vec![],
|
||||
timestamp: 0,
|
||||
}],
|
||||
stats: AgentExportStats::default(),
|
||||
session_duration_ms: 1000,
|
||||
timestamp: 0,
|
||||
};
|
||||
|
||||
let result = coord.aggregate(export);
|
||||
// Second agent should trigger consolidation
|
||||
if i == 1 {
|
||||
assert!(result.consolidated);
|
||||
}
|
||||
}
|
||||
|
||||
let stats = coord.stats();
|
||||
assert_eq!(stats.total_agents, 3);
|
||||
assert_eq!(stats.total_trajectories, 3);
|
||||
}
|
||||
}
|
||||
468
vendor/ruvector/crates/sona/src/training/metrics.rs
vendored
Normal file
468
vendor/ruvector/crates/sona/src/training/metrics.rs
vendored
Normal file
@@ -0,0 +1,468 @@
|
||||
//! Training Metrics for SONA
|
||||
//!
|
||||
//! Comprehensive analytics for training sessions.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Training metrics collection
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct TrainingMetrics {
|
||||
/// Pipeline/agent name
|
||||
pub name: String,
|
||||
/// Total examples processed
|
||||
pub total_examples: usize,
|
||||
/// Total training sessions
|
||||
pub training_sessions: u64,
|
||||
/// Patterns learned
|
||||
pub patterns_learned: usize,
|
||||
/// Quality samples for averaging
|
||||
pub quality_samples: Vec<f32>,
|
||||
/// Validation quality (if validation was run)
|
||||
pub validation_quality: Option<f32>,
|
||||
/// Performance metrics
|
||||
pub performance: PerformanceMetrics,
|
||||
}
|
||||
|
||||
impl TrainingMetrics {
|
||||
/// Create new metrics
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Add quality sample
|
||||
pub fn add_quality_sample(&mut self, quality: f32) {
|
||||
self.quality_samples.push(quality);
|
||||
// Keep last 10000 samples
|
||||
if self.quality_samples.len() > 10000 {
|
||||
self.quality_samples.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average quality
|
||||
pub fn avg_quality(&self) -> f32 {
|
||||
if self.quality_samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Get quality percentile
|
||||
pub fn quality_percentile(&self, percentile: f32) -> f32 {
|
||||
if self.quality_samples.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut sorted = self.quality_samples.clone();
|
||||
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let idx = ((percentile / 100.0) * (sorted.len() - 1) as f32) as usize;
|
||||
sorted[idx.min(sorted.len() - 1)]
|
||||
}
|
||||
|
||||
/// Get quality statistics
|
||||
pub fn quality_stats(&self) -> QualityMetrics {
|
||||
if self.quality_samples.is_empty() {
|
||||
return QualityMetrics::default();
|
||||
}
|
||||
|
||||
let avg = self.avg_quality();
|
||||
let min = self
|
||||
.quality_samples
|
||||
.iter()
|
||||
.cloned()
|
||||
.fold(f32::MAX, f32::min);
|
||||
let max = self
|
||||
.quality_samples
|
||||
.iter()
|
||||
.cloned()
|
||||
.fold(f32::MIN, f32::max);
|
||||
|
||||
let variance = self
|
||||
.quality_samples
|
||||
.iter()
|
||||
.map(|q| (q - avg).powi(2))
|
||||
.sum::<f32>()
|
||||
/ self.quality_samples.len() as f32;
|
||||
let std_dev = variance.sqrt();
|
||||
|
||||
QualityMetrics {
|
||||
avg,
|
||||
min,
|
||||
max,
|
||||
std_dev,
|
||||
p25: self.quality_percentile(25.0),
|
||||
p50: self.quality_percentile(50.0),
|
||||
p75: self.quality_percentile(75.0),
|
||||
p95: self.quality_percentile(95.0),
|
||||
sample_count: self.quality_samples.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset metrics
|
||||
pub fn reset(&mut self) {
|
||||
self.total_examples = 0;
|
||||
self.training_sessions = 0;
|
||||
self.patterns_learned = 0;
|
||||
self.quality_samples.clear();
|
||||
self.validation_quality = None;
|
||||
self.performance = PerformanceMetrics::default();
|
||||
}
|
||||
|
||||
/// Merge with another metrics instance
|
||||
pub fn merge(&mut self, other: &TrainingMetrics) {
|
||||
self.total_examples += other.total_examples;
|
||||
self.training_sessions += other.training_sessions;
|
||||
self.patterns_learned = other.patterns_learned; // Take latest
|
||||
self.quality_samples.extend(&other.quality_samples);
|
||||
|
||||
// Keep last 10000
|
||||
if self.quality_samples.len() > 10000 {
|
||||
let excess = self.quality_samples.len() - 10000;
|
||||
self.quality_samples.drain(0..excess);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Quality metrics summary
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct QualityMetrics {
|
||||
/// Average quality
|
||||
pub avg: f32,
|
||||
/// Minimum quality
|
||||
pub min: f32,
|
||||
/// Maximum quality
|
||||
pub max: f32,
|
||||
/// Standard deviation
|
||||
pub std_dev: f32,
|
||||
/// 25th percentile
|
||||
pub p25: f32,
|
||||
/// 50th percentile (median)
|
||||
pub p50: f32,
|
||||
/// 75th percentile
|
||||
pub p75: f32,
|
||||
/// 95th percentile
|
||||
pub p95: f32,
|
||||
/// Number of samples
|
||||
pub sample_count: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for QualityMetrics {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"avg={:.4}, std={:.4}, min={:.4}, max={:.4}, p50={:.4}, p95={:.4} (n={})",
|
||||
self.avg, self.std_dev, self.min, self.max, self.p50, self.p95, self.sample_count
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Performance metrics
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct PerformanceMetrics {
|
||||
/// Total training time in seconds
|
||||
pub total_training_secs: f64,
|
||||
/// Average batch processing time in milliseconds
|
||||
pub avg_batch_time_ms: f64,
|
||||
/// Average example processing time in microseconds
|
||||
pub avg_example_time_us: f64,
|
||||
/// Peak memory usage in MB
|
||||
pub peak_memory_mb: usize,
|
||||
/// Examples per second throughput
|
||||
pub examples_per_sec: f64,
|
||||
/// Pattern extraction time in milliseconds
|
||||
pub pattern_extraction_ms: f64,
|
||||
}
|
||||
|
||||
impl PerformanceMetrics {
|
||||
/// Calculate throughput
|
||||
pub fn calculate_throughput(&mut self, examples: usize, duration_secs: f64) {
|
||||
if duration_secs > 0.0 {
|
||||
self.examples_per_sec = examples as f64 / duration_secs;
|
||||
self.avg_example_time_us = (duration_secs * 1_000_000.0) / examples as f64;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Epoch statistics
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct EpochStats {
|
||||
/// Epoch number (0-indexed)
|
||||
pub epoch: usize,
|
||||
/// Examples processed in this epoch
|
||||
pub examples_processed: usize,
|
||||
/// Average quality for this epoch
|
||||
pub avg_quality: f32,
|
||||
/// Duration in seconds
|
||||
pub duration_secs: f64,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EpochStats {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Epoch {}: {} examples, avg_quality={:.4}, {:.2}s",
|
||||
self.epoch + 1,
|
||||
self.examples_processed,
|
||||
self.avg_quality,
|
||||
self.duration_secs
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Training result summary
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TrainingResult {
|
||||
/// Pipeline name
|
||||
pub pipeline_name: String,
|
||||
/// Number of epochs completed
|
||||
pub epochs_completed: usize,
|
||||
/// Total examples processed
|
||||
pub total_examples: usize,
|
||||
/// Patterns learned
|
||||
pub patterns_learned: usize,
|
||||
/// Final average quality
|
||||
pub final_avg_quality: f32,
|
||||
/// Total duration in seconds
|
||||
pub total_duration_secs: f64,
|
||||
/// Per-epoch statistics
|
||||
pub epoch_stats: Vec<EpochStats>,
|
||||
/// Validation quality (if validation was run)
|
||||
pub validation_quality: Option<f32>,
|
||||
}
|
||||
|
||||
impl TrainingResult {
|
||||
/// Get examples per second
|
||||
pub fn examples_per_sec(&self) -> f64 {
|
||||
if self.total_duration_secs > 0.0 {
|
||||
self.total_examples as f64 / self.total_duration_secs
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average epoch duration
|
||||
pub fn avg_epoch_duration(&self) -> f64 {
|
||||
if self.epochs_completed > 0 {
|
||||
self.total_duration_secs / self.epochs_completed as f64
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if training improved quality
|
||||
pub fn quality_improved(&self) -> bool {
|
||||
if self.epoch_stats.len() < 2 {
|
||||
return false;
|
||||
}
|
||||
let first = self.epoch_stats.first().unwrap().avg_quality;
|
||||
let last = self.epoch_stats.last().unwrap().avg_quality;
|
||||
last > first
|
||||
}
|
||||
|
||||
/// Get quality improvement
|
||||
pub fn quality_improvement(&self) -> f32 {
|
||||
if self.epoch_stats.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
let first = self.epoch_stats.first().unwrap().avg_quality;
|
||||
let last = self.epoch_stats.last().unwrap().avg_quality;
|
||||
last - first
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TrainingResult {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"TrainingResult(pipeline={}, epochs={}, examples={}, patterns={}, \
|
||||
final_quality={:.4}, duration={:.2}s, throughput={:.1}/s)",
|
||||
self.pipeline_name,
|
||||
self.epochs_completed,
|
||||
self.total_examples,
|
||||
self.patterns_learned,
|
||||
self.final_avg_quality,
|
||||
self.total_duration_secs,
|
||||
self.examples_per_sec()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Comparison metrics between training runs
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct TrainingComparison {
|
||||
/// Baseline result name
|
||||
pub baseline_name: String,
|
||||
/// Comparison result name
|
||||
pub comparison_name: String,
|
||||
/// Quality difference (comparison - baseline)
|
||||
pub quality_diff: f32,
|
||||
/// Quality improvement percentage
|
||||
pub quality_improvement_pct: f32,
|
||||
/// Throughput difference
|
||||
pub throughput_diff: f64,
|
||||
/// Duration difference in seconds
|
||||
pub duration_diff: f64,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl TrainingComparison {
|
||||
/// Compare two training results
|
||||
pub fn compare(baseline: &TrainingResult, comparison: &TrainingResult) -> Self {
|
||||
let quality_diff = comparison.final_avg_quality - baseline.final_avg_quality;
|
||||
let quality_improvement_pct = if baseline.final_avg_quality > 0.0 {
|
||||
(quality_diff / baseline.final_avg_quality) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Self {
|
||||
baseline_name: baseline.pipeline_name.clone(),
|
||||
comparison_name: comparison.pipeline_name.clone(),
|
||||
quality_diff,
|
||||
quality_improvement_pct,
|
||||
throughput_diff: comparison.examples_per_sec() - baseline.examples_per_sec(),
|
||||
duration_diff: comparison.total_duration_secs - baseline.total_duration_secs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TrainingComparison {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let quality_sign = if self.quality_diff >= 0.0 { "+" } else { "" };
|
||||
let throughput_sign = if self.throughput_diff >= 0.0 { "+" } else { "" };
|
||||
|
||||
write!(
|
||||
f,
|
||||
"Comparison {} vs {}: quality {}{:.4} ({}{:.1}%), throughput {}{:.1}/s",
|
||||
self.comparison_name,
|
||||
self.baseline_name,
|
||||
quality_sign,
|
||||
self.quality_diff,
|
||||
quality_sign,
|
||||
self.quality_improvement_pct,
|
||||
throughput_sign,
|
||||
self.throughput_diff
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_metrics_creation() {
|
||||
let metrics = TrainingMetrics::new("test");
|
||||
assert_eq!(metrics.name, "test");
|
||||
assert_eq!(metrics.total_examples, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quality_samples() {
|
||||
let mut metrics = TrainingMetrics::new("test");
|
||||
|
||||
for i in 0..10 {
|
||||
metrics.add_quality_sample(i as f32 / 10.0);
|
||||
}
|
||||
|
||||
assert_eq!(metrics.quality_samples.len(), 10);
|
||||
assert!((metrics.avg_quality() - 0.45).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quality_percentiles() {
|
||||
let mut metrics = TrainingMetrics::new("test");
|
||||
|
||||
for i in 0..100 {
|
||||
metrics.add_quality_sample(i as f32 / 100.0);
|
||||
}
|
||||
|
||||
assert!((metrics.quality_percentile(50.0) - 0.5).abs() < 0.02);
|
||||
assert!((metrics.quality_percentile(95.0) - 0.95).abs() < 0.02);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quality_stats() {
|
||||
let mut metrics = TrainingMetrics::new("test");
|
||||
metrics.add_quality_sample(0.5);
|
||||
metrics.add_quality_sample(0.7);
|
||||
metrics.add_quality_sample(0.9);
|
||||
|
||||
let stats = metrics.quality_stats();
|
||||
assert!((stats.avg - 0.7).abs() < 0.01);
|
||||
assert!((stats.min - 0.5).abs() < 0.01);
|
||||
assert!((stats.max - 0.9).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_training_result() {
|
||||
let result = TrainingResult {
|
||||
pipeline_name: "test".into(),
|
||||
epochs_completed: 3,
|
||||
total_examples: 1000,
|
||||
patterns_learned: 50,
|
||||
final_avg_quality: 0.85,
|
||||
total_duration_secs: 10.0,
|
||||
epoch_stats: vec![
|
||||
EpochStats {
|
||||
epoch: 0,
|
||||
examples_processed: 333,
|
||||
avg_quality: 0.75,
|
||||
duration_secs: 3.0,
|
||||
},
|
||||
EpochStats {
|
||||
epoch: 1,
|
||||
examples_processed: 333,
|
||||
avg_quality: 0.80,
|
||||
duration_secs: 3.5,
|
||||
},
|
||||
EpochStats {
|
||||
epoch: 2,
|
||||
examples_processed: 334,
|
||||
avg_quality: 0.85,
|
||||
duration_secs: 3.5,
|
||||
},
|
||||
],
|
||||
validation_quality: Some(0.82),
|
||||
};
|
||||
|
||||
assert_eq!(result.examples_per_sec(), 100.0);
|
||||
assert!(result.quality_improved());
|
||||
assert!((result.quality_improvement() - 0.10).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_training_comparison() {
|
||||
let baseline = TrainingResult {
|
||||
pipeline_name: "baseline".into(),
|
||||
epochs_completed: 2,
|
||||
total_examples: 500,
|
||||
patterns_learned: 25,
|
||||
final_avg_quality: 0.70,
|
||||
total_duration_secs: 5.0,
|
||||
epoch_stats: vec![],
|
||||
validation_quality: None,
|
||||
};
|
||||
|
||||
let improved = TrainingResult {
|
||||
pipeline_name: "improved".into(),
|
||||
epochs_completed: 2,
|
||||
total_examples: 500,
|
||||
patterns_learned: 30,
|
||||
final_avg_quality: 0.85,
|
||||
total_duration_secs: 4.0,
|
||||
epoch_stats: vec![],
|
||||
validation_quality: None,
|
||||
};
|
||||
|
||||
let comparison = TrainingComparison::compare(&baseline, &improved);
|
||||
assert!((comparison.quality_diff - 0.15).abs() < 0.01);
|
||||
assert!(comparison.quality_improvement_pct > 20.0);
|
||||
assert!(comparison.throughput_diff > 0.0);
|
||||
}
|
||||
}
|
||||
70
vendor/ruvector/crates/sona/src/training/mod.rs
vendored
Normal file
70
vendor/ruvector/crates/sona/src/training/mod.rs
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
//! SONA Training System
|
||||
//!
|
||||
//! Templated training pipelines for specialized model adaptation.
|
||||
//!
|
||||
//! ## Overview
|
||||
//!
|
||||
//! The training module provides:
|
||||
//! - **Training Templates**: Pre-configured training setups for common use cases
|
||||
//! - **Agent Factory**: Create and manage multiple specialized agents
|
||||
//! - **Training Pipelines**: Structured workflows for different verticals
|
||||
//! - **Federated Learning**: Distributed training across ephemeral agents
|
||||
//! - **Metrics & Results**: Comprehensive training analytics
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_sona::training::{TrainingTemplate, AgentFactory, TrainingPipeline};
|
||||
//!
|
||||
//! // Use a preset template
|
||||
//! let template = TrainingTemplate::code_agent();
|
||||
//! let pipeline = TrainingPipeline::from_template(template);
|
||||
//!
|
||||
//! // Train on examples
|
||||
//! for example in examples {
|
||||
//! pipeline.add_example(example);
|
||||
//! }
|
||||
//! let results = pipeline.train()?;
|
||||
//! ```
|
||||
//!
|
||||
//! ## Federated Learning
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_sona::training::{EphemeralAgent, FederatedCoordinator};
|
||||
//!
|
||||
//! // Create coordinator
|
||||
//! let mut coordinator = FederatedCoordinator::default_coordinator("main", 3072);
|
||||
//!
|
||||
//! // Ephemeral agents process tasks
|
||||
//! let mut agent = EphemeralAgent::default_federated("agent-1", 3072);
|
||||
//! agent.process_trajectory(embedding, activations, quality, route, context);
|
||||
//!
|
||||
//! // Export state before termination
|
||||
//! let export = agent.export_state();
|
||||
//! coordinator.aggregate(export);
|
||||
//! ```
|
||||
|
||||
mod factory;
|
||||
mod federated;
|
||||
mod metrics;
|
||||
mod pipeline;
|
||||
mod templates;
|
||||
|
||||
pub use factory::{
|
||||
AgentFactory, AgentHandle, AgentStats, ManagedAgent, SharedAgentFactory, SimpleExample,
|
||||
TrainingExample as FactoryTrainingExample,
|
||||
};
|
||||
pub use federated::{
|
||||
AgentContribution, AgentExport, AgentExportStats, AggregationResult, CoordinatorStats,
|
||||
EphemeralAgent, FederatedCoordinator, FederatedTopology, TrajectoryExport,
|
||||
};
|
||||
pub use metrics::{
|
||||
EpochStats, PerformanceMetrics, QualityMetrics, TrainingMetrics, TrainingResult,
|
||||
};
|
||||
pub use pipeline::{
|
||||
BatchConfig, PipelineStage, TrainingCallback, TrainingExample, TrainingPipeline,
|
||||
};
|
||||
pub use templates::{
|
||||
AgentType, DataSizeHint, TaskDomain, TemplatePreset, TrainingMethod, TrainingTemplate,
|
||||
VerticalConfig,
|
||||
};
|
||||
709
vendor/ruvector/crates/sona/src/training/pipeline.rs
vendored
Normal file
709
vendor/ruvector/crates/sona/src/training/pipeline.rs
vendored
Normal file
@@ -0,0 +1,709 @@
|
||||
//! Training Pipeline for SONA
|
||||
//!
|
||||
//! Structured training workflows with batching and callbacks.
|
||||
|
||||
use super::metrics::{EpochStats, TrainingMetrics, TrainingResult};
|
||||
use super::templates::{DataSizeHint, TrainingMethod, TrainingTemplate};
|
||||
use crate::engine::SonaEngine;
|
||||
use crate::time_compat::Instant;
|
||||
use crate::types::SonaConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Training example with all data needed for learning
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TrainingExample {
|
||||
/// Input embedding
|
||||
pub embedding: Vec<f32>,
|
||||
/// Hidden activations (optional, defaults to embedding)
|
||||
pub activations: Option<Vec<f32>>,
|
||||
/// Attention weights (optional)
|
||||
pub attention: Option<Vec<f32>>,
|
||||
/// Quality score [0.0, 1.0]
|
||||
pub quality: f32,
|
||||
/// Reward signal (optional, defaults to quality)
|
||||
pub reward: Option<f32>,
|
||||
/// Model route identifier
|
||||
pub route: Option<String>,
|
||||
/// Context identifiers
|
||||
pub context: Vec<String>,
|
||||
/// Example weight for importance sampling
|
||||
pub weight: f32,
|
||||
/// Tags for filtering
|
||||
pub tags: Vec<String>,
|
||||
}
|
||||
|
||||
impl TrainingExample {
|
||||
/// Create a new training example
|
||||
pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
|
||||
Self {
|
||||
embedding,
|
||||
activations: None,
|
||||
attention: None,
|
||||
quality,
|
||||
reward: None,
|
||||
route: None,
|
||||
context: Vec::new(),
|
||||
weight: 1.0,
|
||||
tags: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set activations
|
||||
pub fn with_activations(mut self, activations: Vec<f32>) -> Self {
|
||||
self.activations = Some(activations);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set attention
|
||||
pub fn with_attention(mut self, attention: Vec<f32>) -> Self {
|
||||
self.attention = Some(attention);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set reward
|
||||
pub fn with_reward(mut self, reward: f32) -> Self {
|
||||
self.reward = Some(reward);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set route
|
||||
pub fn with_route(mut self, route: impl Into<String>) -> Self {
|
||||
self.route = Some(route.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add context
|
||||
pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
|
||||
self.context.push(ctx.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set weight
|
||||
pub fn with_weight(mut self, weight: f32) -> Self {
|
||||
self.weight = weight;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add tag
|
||||
pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
|
||||
self.tags.push(tag.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Get activations or default to embedding
|
||||
pub fn get_activations(&self) -> Vec<f32> {
|
||||
self.activations
|
||||
.clone()
|
||||
.unwrap_or_else(|| self.embedding.clone())
|
||||
}
|
||||
|
||||
/// Get attention or default
|
||||
pub fn get_attention(&self) -> Vec<f32> {
|
||||
self.attention
|
||||
.clone()
|
||||
.unwrap_or_else(|| vec![1.0 / 64.0; 64])
|
||||
}
|
||||
|
||||
/// Get reward or default to quality
|
||||
pub fn get_reward(&self) -> f32 {
|
||||
self.reward.unwrap_or(self.quality)
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch configuration for training
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct BatchConfig {
|
||||
/// Batch size
|
||||
pub batch_size: usize,
|
||||
/// Shuffle examples
|
||||
pub shuffle: bool,
|
||||
/// Drop incomplete last batch
|
||||
pub drop_last: bool,
|
||||
/// Number of epochs
|
||||
pub epochs: usize,
|
||||
/// Early stopping patience (None = disabled)
|
||||
pub early_stopping_patience: Option<usize>,
|
||||
/// Minimum quality improvement for early stopping
|
||||
pub min_quality_improvement: f32,
|
||||
}
|
||||
|
||||
impl Default for BatchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
batch_size: 32,
|
||||
shuffle: true,
|
||||
drop_last: false,
|
||||
epochs: 1,
|
||||
early_stopping_patience: None,
|
||||
min_quality_improvement: 0.001,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BatchConfig {
|
||||
/// Create config for single pass (no batching)
|
||||
pub fn single_pass() -> Self {
|
||||
Self {
|
||||
batch_size: usize::MAX,
|
||||
shuffle: false,
|
||||
drop_last: false,
|
||||
epochs: 1,
|
||||
early_stopping_patience: None,
|
||||
min_quality_improvement: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config optimized for size hint
|
||||
pub fn for_data_size(hint: &DataSizeHint) -> Self {
|
||||
match hint {
|
||||
DataSizeHint::Tiny => Self {
|
||||
batch_size: 8,
|
||||
epochs: 10,
|
||||
early_stopping_patience: Some(3),
|
||||
..Default::default()
|
||||
},
|
||||
DataSizeHint::Small => Self {
|
||||
batch_size: 16,
|
||||
epochs: 5,
|
||||
early_stopping_patience: Some(2),
|
||||
..Default::default()
|
||||
},
|
||||
DataSizeHint::Medium => Self {
|
||||
batch_size: 32,
|
||||
epochs: 3,
|
||||
early_stopping_patience: Some(2),
|
||||
..Default::default()
|
||||
},
|
||||
DataSizeHint::Large => Self {
|
||||
batch_size: 64,
|
||||
epochs: 2,
|
||||
..Default::default()
|
||||
},
|
||||
DataSizeHint::Massive => Self {
|
||||
batch_size: 128,
|
||||
epochs: 1,
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pipeline stage for tracking progress
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum PipelineStage {
|
||||
/// Not started
|
||||
Idle,
|
||||
/// Loading and preprocessing data
|
||||
Preprocessing,
|
||||
/// Training in progress
|
||||
Training,
|
||||
/// Running validation
|
||||
Validation,
|
||||
/// Extracting patterns
|
||||
PatternExtraction,
|
||||
/// Exporting results
|
||||
Export,
|
||||
/// Completed successfully
|
||||
Completed,
|
||||
/// Failed with error
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PipelineStage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PipelineStage::Idle => write!(f, "idle"),
|
||||
PipelineStage::Preprocessing => write!(f, "preprocessing"),
|
||||
PipelineStage::Training => write!(f, "training"),
|
||||
PipelineStage::Validation => write!(f, "validation"),
|
||||
PipelineStage::PatternExtraction => write!(f, "pattern_extraction"),
|
||||
PipelineStage::Export => write!(f, "export"),
|
||||
PipelineStage::Completed => write!(f, "completed"),
|
||||
PipelineStage::Failed => write!(f, "failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Callback trait for training events
|
||||
pub trait TrainingCallback: Send + Sync {
|
||||
/// Called when stage changes
|
||||
fn on_stage_change(&self, _stage: &PipelineStage) {}
|
||||
|
||||
/// Called after each batch
|
||||
fn on_batch_complete(&self, _batch_idx: usize, _total_batches: usize, _avg_quality: f32) {}
|
||||
|
||||
/// Called after each epoch
|
||||
fn on_epoch_complete(&self, _epoch: usize, _stats: &EpochStats) {}
|
||||
|
||||
/// Called when training completes
|
||||
fn on_training_complete(&self, _result: &TrainingResult) {}
|
||||
|
||||
/// Called on error
|
||||
fn on_error(&self, _error: &str) {}
|
||||
}
|
||||
|
||||
/// No-op callback implementation
|
||||
pub struct NoOpCallback;
|
||||
impl TrainingCallback for NoOpCallback {}
|
||||
|
||||
/// Logging callback implementation
|
||||
#[allow(dead_code)]
|
||||
pub struct LoggingCallback {
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl LoggingCallback {
|
||||
/// Create with prefix
|
||||
pub fn new(prefix: impl Into<String>) -> Self {
|
||||
Self {
|
||||
prefix: prefix.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TrainingCallback for LoggingCallback {
|
||||
fn on_stage_change(&self, stage: &PipelineStage) {
|
||||
println!("[{}] Stage: {}", self.prefix, stage);
|
||||
}
|
||||
|
||||
fn on_batch_complete(&self, batch_idx: usize, total_batches: usize, avg_quality: f32) {
|
||||
if batch_idx % 10 == 0 || batch_idx == total_batches - 1 {
|
||||
println!(
|
||||
"[{}] Batch {}/{}: avg_quality={:.4}",
|
||||
self.prefix,
|
||||
batch_idx + 1,
|
||||
total_batches,
|
||||
avg_quality
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn on_epoch_complete(&self, epoch: usize, stats: &EpochStats) {
|
||||
println!(
|
||||
"[{}] Epoch {}: examples={}, avg_quality={:.4}, duration={:.2}s",
|
||||
self.prefix,
|
||||
epoch + 1,
|
||||
stats.examples_processed,
|
||||
stats.avg_quality,
|
||||
stats.duration_secs
|
||||
);
|
||||
}
|
||||
|
||||
fn on_training_complete(&self, result: &TrainingResult) {
|
||||
println!(
|
||||
"[{}] Training complete: epochs={}, patterns={}, final_quality={:.4}",
|
||||
self.prefix, result.epochs_completed, result.patterns_learned, result.final_avg_quality
|
||||
);
|
||||
}
|
||||
|
||||
fn on_error(&self, error: &str) {
|
||||
eprintln!("[{}] ERROR: {}", self.prefix, error);
|
||||
}
|
||||
}
|
||||
|
||||
/// Training pipeline for structured training workflows.
///
/// Owns a [`SonaEngine`] plus buffered training/validation examples and
/// drives them through the staged workflow implemented in
/// `TrainingPipeline::train`.
pub struct TrainingPipeline {
    /// Pipeline name (also used as the metrics identifier)
    name: String,
    /// SONA engine that receives trajectories during training
    engine: SonaEngine,
    /// Batch configuration (batch size, epochs, shuffling, early stopping)
    batch_config: BatchConfig,
    /// Training method
    training_method: TrainingMethod,
    /// Current stage of the pipeline state machine
    stage: PipelineStage,
    /// Training examples buffer
    examples: Vec<TrainingExample>,
    /// Validation examples (validation stage is skipped when empty)
    validation_examples: Vec<TrainingExample>,
    /// Training metrics accumulated across epochs
    metrics: TrainingMetrics,
    /// Callback invoked on stage/batch/epoch/completion events
    callback: Box<dyn TrainingCallback>,
    /// Enable pattern extraction after training
    extract_patterns: bool,
}
|
||||
|
||||
impl TrainingPipeline {
    /// Create a new training pipeline with default batch configuration,
    /// the default (online) training method, and a no-op callback.
    pub fn new(name: impl Into<String>, config: SonaConfig) -> Self {
        let name = name.into();
        Self {
            name: name.clone(),
            engine: SonaEngine::with_config(config),
            batch_config: BatchConfig::default(),
            training_method: TrainingMethod::default(),
            stage: PipelineStage::Idle,
            examples: Vec::new(),
            validation_examples: Vec::new(),
            metrics: TrainingMetrics::new(&name),
            callback: Box::new(NoOpCallback),
            extract_patterns: true,
        }
    }

    /// Create from a template: batch sizing is derived from the template's
    /// expected data size and the template's training method is adopted.
    pub fn from_template(template: TrainingTemplate) -> Self {
        let batch_config = BatchConfig::for_data_size(&template.expected_data_size);
        let mut pipeline = Self::new(&template.name, template.sona_config);
        pipeline.batch_config = batch_config;
        pipeline.training_method = template.training_method;
        pipeline
    }

    /// Set batch configuration (builder style).
    pub fn with_batch_config(mut self, config: BatchConfig) -> Self {
        self.batch_config = config;
        self
    }

    /// Set training method (builder style).
    pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
        self.training_method = method;
        self
    }

    /// Set the event callback, replacing the default no-op one.
    pub fn with_callback<C: TrainingCallback + 'static>(mut self, callback: C) -> Self {
        self.callback = Box::new(callback);
        self
    }

    /// Enable/disable pattern extraction after the training stage.
    pub fn with_pattern_extraction(mut self, enabled: bool) -> Self {
        self.extract_patterns = enabled;
        self
    }

    /// Add a training example to the buffer.
    pub fn add_example(&mut self, example: TrainingExample) {
        self.examples.push(example);
    }

    /// Add multiple training examples at once.
    pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
        self.examples.extend(examples);
    }

    /// Add a validation example; providing any enables the validation stage.
    pub fn add_validation_example(&mut self, example: TrainingExample) {
        self.validation_examples.push(example);
    }

    /// Get the current pipeline stage.
    pub fn stage(&self) -> &PipelineStage {
        &self.stage
    }

    /// Get the number of buffered training examples.
    pub fn example_count(&self) -> usize {
        self.examples.len()
    }

    /// Get the accumulated training metrics.
    pub fn metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }

    /// Get a shared reference to the underlying engine.
    pub fn engine(&self) -> &SonaEngine {
        &self.engine
    }

    /// Get a mutable reference to the underlying engine.
    pub fn engine_mut(&mut self) -> &mut SonaEngine {
        &mut self.engine
    }

    /// Run the training pipeline.
    ///
    /// Stages run in order: preprocessing -> training -> validation (only if
    /// validation examples were added) -> pattern extraction (only if
    /// enabled) -> completed. The callback is notified at each transition.
    ///
    /// # Errors
    /// Returns `Err` if no training examples were provided (from
    /// `preprocess`) or if a training stage fails.
    pub fn train(&mut self) -> Result<TrainingResult, String> {
        let start = Instant::now();

        // Preprocessing
        self.set_stage(PipelineStage::Preprocessing);
        self.preprocess()?;

        // Training
        self.set_stage(PipelineStage::Training);
        let epoch_stats = self.run_training()?;

        // Validation (if examples provided)
        if !self.validation_examples.is_empty() {
            self.set_stage(PipelineStage::Validation);
            self.run_validation()?;
        }

        // Pattern extraction
        if self.extract_patterns {
            self.set_stage(PipelineStage::PatternExtraction);
            self.engine.force_learn();
        }

        self.set_stage(PipelineStage::Completed);

        let result = TrainingResult {
            pipeline_name: self.name.clone(),
            epochs_completed: epoch_stats.len(),
            total_examples: self.metrics.total_examples,
            patterns_learned: self.metrics.patterns_learned,
            final_avg_quality: self.metrics.avg_quality(),
            total_duration_secs: start.elapsed().as_secs_f64(),
            epoch_stats,
            validation_quality: self.metrics.validation_quality,
        };

        self.callback.on_training_complete(&result);
        Ok(result)
    }

    /// Set the current stage and notify the callback of the transition.
    fn set_stage(&mut self, stage: PipelineStage) {
        self.stage = stage.clone();
        self.callback.on_stage_change(&stage);
    }

    /// Preprocess examples: reject an empty buffer and (optionally) shuffle.
    fn preprocess(&mut self) -> Result<(), String> {
        if self.examples.is_empty() {
            return Err("No training examples provided".into());
        }

        // Shuffle if configured
        if self.batch_config.shuffle {
            use rand::seq::SliceRandom;
            let mut rng = rand::thread_rng();
            self.examples.shuffle(&mut rng);
        }

        Ok(())
    }

    /// Run training epochs, returning per-epoch statistics.
    ///
    /// Supports early stopping: when `early_stopping_patience` is set,
    /// training breaks once `patience` consecutive epochs fail to improve
    /// average quality by at least `min_quality_improvement`.
    fn run_training(&mut self) -> Result<Vec<EpochStats>, String> {
        let mut all_epoch_stats = Vec::new();
        let mut best_quality = 0.0f32;
        let mut patience_counter = 0usize;

        for epoch in 0..self.batch_config.epochs {
            let epoch_start = Instant::now();
            let mut epoch_quality_sum = 0.0f32;
            let mut epoch_examples = 0usize;

            // Create batch indices (to avoid borrow checker issues)
            let batch_size = self.batch_config.batch_size;
            let total_examples = self.examples.len();
            let mut batch_indices: Vec<(usize, usize)> = Vec::new();
            let mut start = 0;
            while start < total_examples {
                let end = (start + batch_size).min(total_examples);
                // `drop_last` discards a trailing partial batch
                if end > start && (!self.batch_config.drop_last || end - start == batch_size) {
                    batch_indices.push((start, end));
                }
                start = end;
            }
            let total_batches = batch_indices.len();

            for (batch_idx, (start, end)) in batch_indices.into_iter().enumerate() {
                let batch_quality = self.train_batch_range(start, end)?;
                let batch_len = end - start;
                // Weight the batch mean by batch length for a running average
                epoch_quality_sum += batch_quality * batch_len as f32;
                epoch_examples += batch_len;

                self.callback.on_batch_complete(
                    batch_idx,
                    total_batches,
                    epoch_quality_sum / epoch_examples as f32,
                );
            }

            let epoch_avg_quality = if epoch_examples > 0 {
                epoch_quality_sum / epoch_examples as f32
            } else {
                0.0
            };

            let epoch_stats = EpochStats {
                epoch,
                examples_processed: epoch_examples,
                avg_quality: epoch_avg_quality,
                duration_secs: epoch_start.elapsed().as_secs_f64(),
            };

            self.callback.on_epoch_complete(epoch, &epoch_stats);
            all_epoch_stats.push(epoch_stats);

            // Early stopping check
            if let Some(patience) = self.batch_config.early_stopping_patience {
                let improvement = epoch_avg_quality - best_quality;
                if improvement > self.batch_config.min_quality_improvement {
                    best_quality = epoch_avg_quality;
                    patience_counter = 0;
                } else {
                    patience_counter += 1;
                    if patience_counter >= patience {
                        break; // Early stop
                    }
                }
            }

            // Reshuffle for next epoch
            if self.batch_config.shuffle && epoch + 1 < self.batch_config.epochs {
                use rand::seq::SliceRandom;
                let mut rng = rand::thread_rng();
                self.examples.shuffle(&mut rng);
            }
        }

        Ok(all_epoch_stats)
    }

    /// Train on examples in the half-open index range `[start, end)`,
    /// returning the batch's mean example quality.
    ///
    /// Callers must guarantee `end > start` (the batch builder does), since
    /// the mean divides by `end - start`.
    fn train_batch_range(&mut self, start: usize, end: usize) -> Result<f32, String> {
        let mut quality_sum = 0.0f32;
        let batch_len = end - start;

        for idx in start..end {
            let example = &self.examples[idx];

            // Begin trajectory using builder API
            let mut builder = self.engine.begin_trajectory(example.embedding.clone());

            // Set route
            if let Some(ref route) = example.route {
                builder.set_model_route(route);
            }

            // Add context
            for ctx in &example.context {
                builder.add_context(ctx);
            }

            // Add step; reward is scaled by the example's importance weight
            builder.add_step(
                example.get_activations(),
                example.get_attention(),
                example.get_reward() * example.weight,
            );

            // End trajectory
            self.engine.end_trajectory(builder, example.quality);

            quality_sum += example.quality;
            self.metrics.total_examples += 1;
            self.metrics.add_quality_sample(example.quality);
        }

        // Run tick to process accumulated trajectories
        self.engine.tick();

        Ok(quality_sum / batch_len as f32)
    }

    /// Run validation over the buffered validation examples.
    ///
    /// Only called when `validation_examples` is non-empty (see `train`),
    /// so the final division is safe.
    fn run_validation(&mut self) -> Result<(), String> {
        let mut quality_sum = 0.0f32;

        for example in &self.validation_examples {
            // Apply learned transformations
            let mut output = vec![0.0f32; example.embedding.len()];
            self.engine
                .apply_micro_lora(&example.embedding, &mut output);

            // In a real scenario, you'd evaluate the model output
            // For now, we track the expected quality
            quality_sum += example.quality;
        }

        self.metrics.validation_quality = Some(quality_sum / self.validation_examples.len() as f32);

        Ok(())
    }

    /// Clear the training and validation buffers (engine state is kept).
    pub fn clear_examples(&mut self) {
        self.examples.clear();
        self.validation_examples.clear();
    }

    /// Reset the pipeline: clear examples, reset metrics, return to Idle.
    pub fn reset(&mut self) {
        self.clear_examples();
        self.metrics = TrainingMetrics::new(&self.name);
        self.stage = PipelineStage::Idle;
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Builder methods on TrainingExample should set the corresponding fields.
    #[test]
    fn test_training_example() {
        let example = TrainingExample::new(vec![0.1; 256], 0.8)
            .with_route("test")
            .with_context("ctx1")
            .with_weight(1.5)
            .with_tag("test");

        assert_eq!(example.quality, 0.8);
        assert_eq!(example.route, Some("test".into()));
        assert_eq!(example.weight, 1.5);
    }

    // Small data sets should map to the small batch preset (16 x 5 epochs).
    #[test]
    fn test_batch_config() {
        let config = BatchConfig::for_data_size(&DataSizeHint::Small);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.epochs, 5);
    }

    // A freshly built pipeline starts Idle with no examples buffered.
    #[test]
    fn test_pipeline_creation() {
        let pipeline = TrainingPipeline::new("test", SonaConfig::default());
        assert_eq!(pipeline.stage(), &PipelineStage::Idle);
        assert_eq!(pipeline.example_count(), 0);
    }

    // from_template should carry over the template's name.
    #[test]
    fn test_pipeline_from_template() {
        let template = TrainingTemplate::code_agent().with_hidden_dim(256);
        let pipeline = TrainingPipeline::from_template(template);
        assert_eq!(pipeline.name, "code-agent");
    }

    // End-to-end: 5 examples, batch size 2, 2 epochs -> both epochs run.
    #[test]
    fn test_pipeline_training() {
        let mut pipeline =
            TrainingPipeline::new("test", SonaConfig::default()).with_batch_config(BatchConfig {
                batch_size: 2,
                epochs: 2,
                ..Default::default()
            });

        // Add examples
        for i in 0..5 {
            pipeline.add_example(TrainingExample::new(
                vec![i as f32 * 0.1; 256],
                0.7 + i as f32 * 0.05,
            ));
        }

        let result = pipeline.train().unwrap();
        assert_eq!(result.epochs_completed, 2);
        assert!(result.total_examples > 0);
    }

    // Adding a validation example should populate validation_quality.
    #[test]
    fn test_pipeline_with_validation() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig::single_pass());

        pipeline.add_example(TrainingExample::new(vec![0.1; 256], 0.8));
        pipeline.add_validation_example(TrainingExample::new(vec![0.2; 256], 0.9));

        let result = pipeline.train().unwrap();
        assert!(result.validation_quality.is_some());
    }
}
|
||||
656
vendor/ruvector/crates/sona/src/training/templates.rs
vendored
Normal file
656
vendor/ruvector/crates/sona/src/training/templates.rs
vendored
Normal file
@@ -0,0 +1,656 @@
|
||||
//! Training Templates for SONA
|
||||
//!
|
||||
//! Pre-configured training setups optimized for different use cases.
|
||||
|
||||
use crate::types::SonaConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Agent specialization types.
///
/// Used by [`TrainingTemplate`] to select agent-specific configuration
/// optimizations; the [`std::fmt::Display`] impl renders each variant as a
/// stable kebab-case identifier.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AgentType {
    /// Code generation and assistance
    CodeAgent,
    /// General chat and conversation
    ChatAgent,
    /// Document retrieval and Q&A
    RagAgent,
    /// Task decomposition and planning
    TaskPlanner,
    /// Domain-specific expert
    DomainExpert,
    /// Codebase-aware assistant
    CodebaseHelper,
    /// Data analysis and insights
    DataAnalyst,
    /// Creative writing and content
    CreativeWriter,
    /// Reasoning and logic
    ReasoningAgent,
    /// Multi-modal understanding
    MultiModal,
    /// Custom agent type (the string names the custom specialization)
    Custom(String),
}
|
||||
|
||||
impl std::fmt::Display for AgentType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AgentType::CodeAgent => write!(f, "code-agent"),
|
||||
AgentType::ChatAgent => write!(f, "chat-agent"),
|
||||
AgentType::RagAgent => write!(f, "rag-agent"),
|
||||
AgentType::TaskPlanner => write!(f, "task-planner"),
|
||||
AgentType::DomainExpert => write!(f, "domain-expert"),
|
||||
AgentType::CodebaseHelper => write!(f, "codebase-helper"),
|
||||
AgentType::DataAnalyst => write!(f, "data-analyst"),
|
||||
AgentType::CreativeWriter => write!(f, "creative-writer"),
|
||||
AgentType::ReasoningAgent => write!(f, "reasoning-agent"),
|
||||
AgentType::MultiModal => write!(f, "multi-modal"),
|
||||
AgentType::Custom(name) => write!(f, "custom-{}", name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Task domain for training focus.
///
/// Selects the vertical an agent is specialized for; also drives the
/// default compliance level chosen for domain-expert templates.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskDomain {
    /// Software development
    SoftwareDevelopment,
    /// Customer support
    CustomerSupport,
    /// Healthcare
    Healthcare,
    /// Finance
    Finance,
    /// Legal
    Legal,
    /// Education
    Education,
    /// Research
    Research,
    /// Marketing
    Marketing,
    /// General purpose
    General,
    /// Custom domain (the string names the domain)
    Custom(String),
}
|
||||
|
||||
/// Training method configuration.
///
/// Each variant carries the hyperparameters specific to that method.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TrainingMethod {
    /// Standard supervised learning
    Supervised {
        /// Batch size for training
        batch_size: usize,
        /// Number of epochs
        epochs: usize,
    },
    /// Reinforcement learning from feedback
    RLHF {
        /// Reward model weight
        reward_weight: f32,
        /// KL divergence penalty (higher = stay closer to reference policy)
        kl_penalty: f32,
    },
    /// Direct preference optimization
    DPO {
        /// Beta parameter for DPO
        beta: f32,
        /// Reference model weight
        ref_weight: f32,
    },
    /// Continuous online learning
    Online {
        /// Learning rate decay applied over time
        lr_decay: f32,
        /// Window size for recent examples
        window_size: usize,
    },
    /// Few-shot adaptation
    FewShot {
        /// Number of examples per class
        k_shot: usize,
        /// Meta-learning rate
        meta_lr: f32,
    },
}
|
||||
|
||||
impl Default for TrainingMethod {
|
||||
fn default() -> Self {
|
||||
TrainingMethod::Online {
|
||||
lr_decay: 0.999,
|
||||
window_size: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Vertical-specific configuration.
///
/// Attached to a [`TrainingTemplate`] to specialize it for a particular
/// industry vertical.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VerticalConfig {
    /// Domain focus
    pub domain: TaskDomain,
    /// Specialized vocabulary size (extra domain terms to emphasize)
    pub vocab_boost: usize,
    /// Domain-specific quality metrics (metric names, e.g. "accuracy")
    pub quality_metrics: Vec<String>,
    /// Compliance requirements for regulated industries
    pub compliance_level: ComplianceLevel,
}
|
||||
|
||||
/// Compliance level for regulated industries.
///
/// Defaults to `None` (no compliance requirements).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum ComplianceLevel {
    /// No compliance requirements (the default)
    #[default]
    None,
    /// Basic audit logging
    Basic,
    /// HIPAA compliance
    Hipaa,
    /// SOC2 compliance
    Soc2,
    /// GDPR compliance
    Gdpr,
    /// Custom compliance (the string names the regime)
    Custom(String),
}
|
||||
|
||||
/// Template preset for quick configuration.
///
/// Consumed by `TrainingTemplate::from_preset` to pick a SONA config,
/// memory budget, and expected data size in one step.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TemplatePreset {
    /// Minimal configuration for testing
    Minimal,
    /// Balanced for general use
    Balanced,
    /// High performance for production
    Production,
    /// Maximum quality regardless of speed
    MaxQuality,
    /// Edge deployment (<5MB)
    Edge,
    /// Research and experimentation
    Research,
}
|
||||
|
||||
/// Training template with full configuration.
///
/// A pre-configured bundle of SONA settings, training method, and
/// deployment budgets; build one via the `code_agent()`-style factory
/// methods, `from_preset`, or `new` plus the `with_*` builders.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingTemplate {
    /// Template name
    pub name: String,
    /// Agent type
    pub agent_type: AgentType,
    /// SONA configuration
    pub sona_config: SonaConfig,
    /// Training method
    pub training_method: TrainingMethod,
    /// Vertical configuration (None for general-purpose templates)
    pub vertical: Option<VerticalConfig>,
    /// Expected training data size (drives batch-config selection)
    pub expected_data_size: DataSizeHint,
    /// Memory budget in MB
    pub memory_budget_mb: usize,
    /// Target latency in microseconds
    pub target_latency_us: u64,
    /// Enable continuous learning
    pub continuous_learning: bool,
    /// Auto-export trained adapters
    pub auto_export: bool,
    /// Tags for organization
    pub tags: Vec<String>,
}
|
||||
|
||||
/// Hint about training data size.
///
/// Used by `BatchConfig::for_data_size` to choose batch size and epoch
/// counts; defaults to `Medium`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DataSizeHint {
    /// <100 examples (few-shot)
    Tiny,
    /// 100-1000 examples
    Small,
    /// 1000-10000 examples (the default)
    #[default]
    Medium,
    /// 10000-100000 examples
    Large,
    /// >100000 examples
    Massive,
}
|
||||
|
||||
impl TrainingTemplate {
|
||||
/// Create a new training template
|
||||
pub fn new(name: impl Into<String>, agent_type: AgentType) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
agent_type,
|
||||
sona_config: SonaConfig::default(),
|
||||
training_method: TrainingMethod::default(),
|
||||
vertical: None,
|
||||
expected_data_size: DataSizeHint::default(),
|
||||
memory_budget_mb: 100,
|
||||
target_latency_us: 1000,
|
||||
continuous_learning: true,
|
||||
auto_export: false,
|
||||
tags: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from preset
|
||||
pub fn from_preset(preset: TemplatePreset, agent_type: AgentType) -> Self {
|
||||
let mut template = Self::new(format!("{:?}-{}", preset, agent_type), agent_type.clone());
|
||||
|
||||
match preset {
|
||||
TemplatePreset::Minimal => {
|
||||
template.sona_config = SonaConfig::edge_deployment();
|
||||
template.memory_budget_mb = 10;
|
||||
template.expected_data_size = DataSizeHint::Tiny;
|
||||
}
|
||||
TemplatePreset::Balanced => {
|
||||
template.sona_config = SonaConfig::default();
|
||||
template.memory_budget_mb = 100;
|
||||
}
|
||||
TemplatePreset::Production => {
|
||||
template.sona_config = SonaConfig::max_throughput();
|
||||
template.memory_budget_mb = 200;
|
||||
template.auto_export = true;
|
||||
}
|
||||
TemplatePreset::MaxQuality => {
|
||||
template.sona_config = SonaConfig::max_quality();
|
||||
template.memory_budget_mb = 500;
|
||||
template.expected_data_size = DataSizeHint::Large;
|
||||
}
|
||||
TemplatePreset::Edge => {
|
||||
template.sona_config = SonaConfig::edge_deployment();
|
||||
template.memory_budget_mb = 5;
|
||||
template.target_latency_us = 500;
|
||||
}
|
||||
TemplatePreset::Research => {
|
||||
template.sona_config = SonaConfig::max_quality();
|
||||
template.sona_config.trajectory_capacity = 50000;
|
||||
template.memory_budget_mb = 1000;
|
||||
template.expected_data_size = DataSizeHint::Massive;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply agent-specific optimizations
|
||||
template.apply_agent_optimizations();
|
||||
template
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Pre-built Templates for Common Use Cases
|
||||
//------------------------------------------------------------------
|
||||
|
||||
/// Code agent template - optimized for code generation
|
||||
///
|
||||
/// **Best for**: Code completion, bug fixes, refactoring
|
||||
/// **Config**: baseLoraRank=16, clusters=200, capacity=10000
|
||||
/// **Training data**: Code completions, fixes, reviews
|
||||
pub fn code_agent() -> Self {
|
||||
let mut template = Self::new("code-agent", AgentType::CodeAgent);
|
||||
template.sona_config.base_lora_rank = 16; // Deeper for code patterns
|
||||
template.sona_config.pattern_clusters = 200; // Many code patterns
|
||||
template.sona_config.trajectory_capacity = 10000;
|
||||
template.sona_config.quality_threshold = 0.2; // Learn from most examples
|
||||
template.training_method = TrainingMethod::Online {
|
||||
lr_decay: 0.9995,
|
||||
window_size: 5000,
|
||||
};
|
||||
template.tags = vec!["code".into(), "development".into(), "completion".into()];
|
||||
template
|
||||
}
|
||||
|
||||
/// Chat agent template - optimized for conversational AI
|
||||
///
|
||||
/// **Best for**: Customer support, general chat, assistants
|
||||
/// **Config**: baseLoraRank=8, clusters=50, fast response
|
||||
/// **Training data**: Conversation histories, feedback
|
||||
pub fn chat_agent() -> Self {
|
||||
let mut template = Self::new("chat-agent", AgentType::ChatAgent);
|
||||
template.sona_config.base_lora_rank = 8;
|
||||
template.sona_config.pattern_clusters = 50;
|
||||
template.sona_config.quality_threshold = 0.4;
|
||||
template.target_latency_us = 500; // Fast responses
|
||||
template.training_method = TrainingMethod::RLHF {
|
||||
reward_weight: 0.5,
|
||||
kl_penalty: 0.1,
|
||||
};
|
||||
template.tags = vec!["chat".into(), "conversation".into(), "support".into()];
|
||||
template
|
||||
}
|
||||
|
||||
/// RAG agent template - optimized for retrieval-augmented generation
|
||||
///
|
||||
/// **Best for**: Document Q&A, knowledge bases, search
|
||||
/// **Config**: clusters=200, capacity=10000, high pattern storage
|
||||
/// **Training data**: Document chunks, Q&A pairs
|
||||
pub fn rag_agent() -> Self {
|
||||
let mut template = Self::new("rag-agent", AgentType::RagAgent);
|
||||
template.sona_config.pattern_clusters = 200; // Many document patterns
|
||||
template.sona_config.trajectory_capacity = 10000;
|
||||
template.sona_config.embedding_dim = 512; // Larger embeddings for retrieval
|
||||
template.sona_config.hidden_dim = 512;
|
||||
template.training_method = TrainingMethod::Supervised {
|
||||
batch_size: 32,
|
||||
epochs: 10,
|
||||
};
|
||||
template.tags = vec!["rag".into(), "retrieval".into(), "documents".into()];
|
||||
template
|
||||
}
|
||||
|
||||
/// Task planner template - optimized for task decomposition
|
||||
///
|
||||
/// **Best for**: Project planning, task breakdown, scheduling
|
||||
/// **Config**: baseLoraRank=16, ewcLambda=2000, multi-task
|
||||
/// **Training data**: Task decompositions, planning examples
|
||||
pub fn task_planner() -> Self {
|
||||
let mut template = Self::new("task-planner", AgentType::TaskPlanner);
|
||||
template.sona_config.base_lora_rank = 16;
|
||||
template.sona_config.ewc_lambda = 2000.0; // Important for multi-task
|
||||
template.sona_config.pattern_clusters = 100;
|
||||
template.training_method = TrainingMethod::DPO {
|
||||
beta: 0.1,
|
||||
ref_weight: 0.5,
|
||||
};
|
||||
template.tags = vec!["planning".into(), "tasks".into(), "decomposition".into()];
|
||||
template
|
||||
}
|
||||
|
||||
/// Domain expert template - optimized for specialized knowledge
|
||||
///
|
||||
/// **Best for**: Legal, medical, financial expertise
|
||||
/// **Config**: qualityThreshold=0.1, high capacity, compliance
|
||||
/// **Training data**: Domain-specific Q&A, expert responses
|
||||
pub fn domain_expert(domain: TaskDomain) -> Self {
|
||||
let domain_name = format!("{:?}", domain).to_lowercase();
|
||||
let mut template = Self::new(
|
||||
format!("domain-expert-{}", domain_name),
|
||||
AgentType::DomainExpert,
|
||||
);
|
||||
template.sona_config.quality_threshold = 0.1; // Learn from all domain examples
|
||||
template.sona_config.trajectory_capacity = 20000;
|
||||
template.sona_config.base_lora_rank = 16;
|
||||
template.vertical = Some(VerticalConfig {
|
||||
domain: domain.clone(),
|
||||
vocab_boost: 10000,
|
||||
quality_metrics: vec!["accuracy".into(), "relevance".into(), "compliance".into()],
|
||||
compliance_level: match domain {
|
||||
TaskDomain::Healthcare => ComplianceLevel::Hipaa,
|
||||
TaskDomain::Finance => ComplianceLevel::Soc2,
|
||||
TaskDomain::Legal => ComplianceLevel::Basic,
|
||||
_ => ComplianceLevel::None,
|
||||
},
|
||||
});
|
||||
template.tags = vec!["domain".into(), "expert".into(), domain_name];
|
||||
template
|
||||
}
|
||||
|
||||
/// Codebase helper template - learns your specific codebase
|
||||
///
|
||||
/// **Best for**: Repository-specific assistance, code navigation
|
||||
/// **Config**: clusters=200, capacity=10000, high pattern storage
|
||||
/// **Training data**: Your repo's code, documentation
|
||||
pub fn codebase_helper() -> Self {
|
||||
let mut template = Self::new("codebase-helper", AgentType::CodebaseHelper);
|
||||
template.sona_config.pattern_clusters = 200;
|
||||
template.sona_config.trajectory_capacity = 10000;
|
||||
template.sona_config.quality_threshold = 0.2;
|
||||
template.sona_config.base_lora_rank = 16;
|
||||
template.expected_data_size = DataSizeHint::Large;
|
||||
template.training_method = TrainingMethod::Online {
|
||||
lr_decay: 0.999,
|
||||
window_size: 10000,
|
||||
};
|
||||
template.tags = vec!["codebase".into(), "repository".into(), "navigation".into()];
|
||||
template
|
||||
}
|
||||
|
||||
/// Data analyst template - optimized for data insights
|
||||
///
|
||||
/// **Best for**: Data analysis, visualization, statistics
|
||||
/// **Config**: baseLoraRank=8, clusters=100, reasoning focus
|
||||
pub fn data_analyst() -> Self {
|
||||
let mut template = Self::new("data-analyst", AgentType::DataAnalyst);
|
||||
template.sona_config.base_lora_rank = 8;
|
||||
template.sona_config.pattern_clusters = 100;
|
||||
template.vertical = Some(VerticalConfig {
|
||||
domain: TaskDomain::Research,
|
||||
vocab_boost: 5000,
|
||||
quality_metrics: vec!["accuracy".into(), "insight_quality".into()],
|
||||
compliance_level: ComplianceLevel::None,
|
||||
});
|
||||
template.tags = vec!["data".into(), "analysis".into(), "insights".into()];
|
||||
template
|
||||
}
|
||||
|
||||
/// Creative writer template - optimized for content generation
|
||||
///
|
||||
/// **Best for**: Marketing copy, blog posts, creative writing
|
||||
/// **Config**: High diversity, quality focus
|
||||
pub fn creative_writer() -> Self {
|
||||
let mut template = Self::new("creative-writer", AgentType::CreativeWriter);
|
||||
template.sona_config.base_lora_rank = 8;
|
||||
template.sona_config.pattern_clusters = 50; // Fewer clusters for diversity
|
||||
template.sona_config.quality_threshold = 0.5; // Only learn from high quality
|
||||
template.training_method = TrainingMethod::RLHF {
|
||||
reward_weight: 0.7,
|
||||
kl_penalty: 0.05, // Less constraint for creativity
|
||||
};
|
||||
template.vertical = Some(VerticalConfig {
|
||||
domain: TaskDomain::Marketing,
|
||||
vocab_boost: 0,
|
||||
quality_metrics: vec!["creativity".into(), "engagement".into(), "clarity".into()],
|
||||
compliance_level: ComplianceLevel::None,
|
||||
});
|
||||
template.tags = vec!["creative".into(), "writing".into(), "content".into()];
|
||||
template
|
||||
}
|
||||
|
||||
/// Reasoning agent template - optimized for logical reasoning
|
||||
///
|
||||
/// **Best for**: Math, logic, chain-of-thought reasoning
|
||||
/// **Config**: High rank, strong EWC, accuracy focus
|
||||
pub fn reasoning_agent() -> Self {
|
||||
let mut template = Self::new("reasoning-agent", AgentType::ReasoningAgent);
|
||||
template.sona_config.base_lora_rank = 16;
|
||||
template.sona_config.ewc_lambda = 3000.0; // Strong protection
|
||||
template.sona_config.pattern_clusters = 150;
|
||||
template.sona_config.quality_threshold = 0.3;
|
||||
template.training_method = TrainingMethod::DPO {
|
||||
beta: 0.15,
|
||||
ref_weight: 0.4,
|
||||
};
|
||||
template.tags = vec!["reasoning".into(), "logic".into(), "math".into()];
|
||||
template
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Builder Methods
|
||||
//------------------------------------------------------------------
|
||||
|
||||
/// Set SONA configuration
|
||||
pub fn with_sona_config(mut self, config: SonaConfig) -> Self {
|
||||
self.sona_config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set training method
|
||||
pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
|
||||
self.training_method = method;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set vertical configuration
|
||||
pub fn with_vertical(mut self, vertical: VerticalConfig) -> Self {
|
||||
self.vertical = Some(vertical);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set memory budget
|
||||
pub fn with_memory_budget(mut self, mb: usize) -> Self {
|
||||
self.memory_budget_mb = mb;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set target latency
|
||||
pub fn with_target_latency(mut self, us: u64) -> Self {
|
||||
self.target_latency_us = us;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable continuous learning
|
||||
pub fn with_continuous_learning(mut self, enabled: bool) -> Self {
|
||||
self.continuous_learning = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable auto-export
|
||||
pub fn with_auto_export(mut self, enabled: bool) -> Self {
|
||||
self.auto_export = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add tags
|
||||
pub fn with_tags(mut self, tags: Vec<String>) -> Self {
|
||||
self.tags = tags;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set hidden dimension
|
||||
pub fn with_hidden_dim(mut self, dim: usize) -> Self {
|
||||
self.sona_config.hidden_dim = dim;
|
||||
self.sona_config.embedding_dim = dim;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set LoRA ranks
|
||||
pub fn with_lora_ranks(mut self, micro: usize, base: usize) -> Self {
|
||||
self.sona_config.micro_lora_rank = micro.min(2); // MicroLoRA max rank is 2
|
||||
self.sona_config.base_lora_rank = base;
|
||||
self
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Internal Methods
|
||||
//------------------------------------------------------------------
|
||||
|
||||
/// Apply agent-specific optimizations
|
||||
fn apply_agent_optimizations(&mut self) {
|
||||
match &self.agent_type {
|
||||
AgentType::CodeAgent | AgentType::CodebaseHelper => {
|
||||
self.sona_config.pattern_clusters = 200;
|
||||
self.sona_config.base_lora_rank = 16;
|
||||
}
|
||||
AgentType::ChatAgent => {
|
||||
self.sona_config.pattern_clusters = 50;
|
||||
self.target_latency_us = 500;
|
||||
}
|
||||
AgentType::RagAgent => {
|
||||
self.sona_config.pattern_clusters = 200;
|
||||
self.sona_config.trajectory_capacity = 10000;
|
||||
}
|
||||
AgentType::ReasoningAgent => {
|
||||
self.sona_config.ewc_lambda = 3000.0;
|
||||
self.sona_config.base_lora_rank = 16;
|
||||
}
|
||||
AgentType::DomainExpert => {
|
||||
self.sona_config.quality_threshold = 0.1;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate template configuration
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.sona_config.micro_lora_rank > 2 {
|
||||
return Err("MicroLoRA rank must be 1 or 2".into());
|
||||
}
|
||||
if self.sona_config.hidden_dim == 0 {
|
||||
return Err("Hidden dimension must be > 0".into());
|
||||
}
|
||||
if self.memory_budget_mb < 1 {
|
||||
return Err("Memory budget must be >= 1 MB".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get estimated memory usage in MB
|
||||
pub fn estimated_memory_mb(&self) -> usize {
|
||||
let config = &self.sona_config;
|
||||
|
||||
// Base engine memory
|
||||
let engine_mb = 5;
|
||||
|
||||
// LoRA weights: hidden_dim * rank * 2 (A and B matrices) * 4 bytes * 2 (micro + base)
|
||||
let lora_bytes =
|
||||
config.hidden_dim * (config.micro_lora_rank + config.base_lora_rank) * 2 * 4 * 2;
|
||||
let lora_mb = lora_bytes / (1024 * 1024);
|
||||
|
||||
// Trajectory buffer: capacity * ~800 bytes per trajectory
|
||||
let traj_mb = (config.trajectory_capacity * 800) / (1024 * 1024);
|
||||
|
||||
// Pattern storage: clusters * embedding_dim * 4 bytes
|
||||
let pattern_mb = (config.pattern_clusters * config.embedding_dim * 4) / (1024 * 1024);
|
||||
|
||||
engine_mb + lora_mb + traj_mb + pattern_mb + 1
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_template_creation() {
        let tpl = TrainingTemplate::code_agent();
        assert_eq!(tpl.agent_type, AgentType::CodeAgent);
        assert_eq!(tpl.sona_config.base_lora_rank, 16);
        assert_eq!(tpl.sona_config.pattern_clusters, 200);
    }

    #[test]
    fn test_preset_templates() {
        let prod = TrainingTemplate::from_preset(TemplatePreset::Production, AgentType::ChatAgent);
        assert!(prod.auto_export);

        let edge = TrainingTemplate::from_preset(TemplatePreset::Edge, AgentType::ChatAgent);
        assert_eq!(edge.memory_budget_mb, 5);
    }

    #[test]
    fn test_domain_expert() {
        let medical = TrainingTemplate::domain_expert(TaskDomain::Healthcare);
        let vertical = medical
            .vertical
            .as_ref()
            .expect("domain expert templates must carry a vertical config");
        assert!(matches!(vertical.compliance_level, ComplianceLevel::Hipaa));
    }

    #[test]
    fn test_builder_pattern() {
        let tpl = TrainingTemplate::new("custom", AgentType::Custom("test".into()))
            .with_hidden_dim(512)
            .with_lora_ranks(2, 16)
            .with_memory_budget(200)
            .with_continuous_learning(true);

        assert_eq!(tpl.sona_config.hidden_dim, 512);
        assert_eq!(tpl.sona_config.micro_lora_rank, 2);
        assert_eq!(tpl.sona_config.base_lora_rank, 16);
    }

    #[test]
    fn test_validation() {
        let mut tpl = TrainingTemplate::code_agent();
        assert!(tpl.validate().is_ok());

        // A micro rank above the MicroLoRA ceiling must be rejected.
        tpl.sona_config.micro_lora_rank = 5;
        assert!(tpl.validate().is_err());
    }

    #[test]
    fn test_memory_estimation() {
        let tpl = TrainingTemplate::code_agent();
        let estimate = tpl.estimated_memory_mb();
        assert!(estimate > 0);
        assert!(estimate < tpl.memory_budget_mb * 2);
    }
}
Reference in New Issue
Block a user