/**
 * Training Pipeline for SONA
 *
 * Comprehensive training infrastructure with metrics tracking,
 * learning rate scheduling, and checkpoint management.
 *
 * @example
 * ```typescript
 * import { TrainingPipeline, TrainingConfig } from '@ruvector/ruvllm';
 *
 * const pipeline = new TrainingPipeline({
 *   learningRate: 0.001,
 *   batchSize: 32,
 *   epochs: 10,
 * });
 *
 * // Add training data
 * pipeline.addBatch(inputs, targets, qualities);
 *
 * // Run training
 * const result = pipeline.train();
 * console.log(`Final loss: ${result.finalLoss}`);
 * ```
 */

import { Embedding, TrainingConfig, TrainingResult } from './types';
import { LoraAdapter } from './lora';
import { EwcManager } from './sona';

/**
 * Default value for every `TrainingConfig` field.
 */
const DEFAULT_TRAINING_CONFIG: Required<TrainingConfig> = {
  learningRate: 0.001,
  batchSize: 32,
  epochs: 10,
  scheduler: 'cosine',
  warmupSteps: 100,
  weightDecay: 0.01,
  gradientClip: 1.0,
  earlyStoppingPatience: 3,
  checkpointInterval: 1,
  ewcLambda: 2000,
  validationSplit: 0.1,
};

/**
 * Training metrics
 */
export interface TrainingMetrics {
  /** Current epoch */
  epoch: number;
  /** Current step */
  step: number;
  /** Training loss */
  trainLoss: number;
  /** Validation loss */
  valLoss: number;
  /** Learning rate */
  learningRate: number;
  /** Gradient norm */
  gradNorm: number;
  /** Steps per second */
  stepsPerSecond: number;
  /** ETA in seconds */
  etaSeconds: number;
}

/**
 * Training data batch
 */
export interface TrainingBatch {
  /** Input embeddings */
  inputs: Embedding[];
  /** Target outputs */
  targets: Embedding[];
  /** Quality scores */
  qualities: number[];
}

/**
 * Checkpoint data
 */
export interface Checkpoint {
  /** Epoch number */
  epoch: number;
  /** Step number */
  step: number;
  /** Training loss at checkpoint */
  loss: number;
  /** Model weights (serialized) */
  weights: string;
  /** Timestamp */
  timestamp: number;
}

/**
 * Learning Rate Scheduler
 *
 * Maps the current step to a learning rate according to `config.scheduler`.
 * The caller advances it with `step()` once per optimizer update.
 */
export class LRScheduler {
  private readonly config: Required<TrainingConfig>;
  private readonly initialLR: number;
  private currentStep: number = 0;
  private readonly totalSteps: number;

  /**
   * @param config - Full training config; `learningRate`, `scheduler` and `warmupSteps` are read.
   * @param totalSteps - Number of steps the schedule spans.
   */
  constructor(config: Required<TrainingConfig>, totalSteps: number) {
    this.config = config;
    this.initialLR = config.learningRate;
    this.totalSteps = totalSteps;
  }

  /**
   * Get learning rate for current step.
   *
   * - `constant`: always the initial LR.
   * - `linear`: decays linearly from the initial LR to 0 over `totalSteps`.
   * - `cosine`: half-cosine decay from the initial LR to 0 over `totalSteps`.
   * - `warmup`: linear ramp from 0 over `warmupSteps`, then cosine decay to 0
   *   over the remaining `totalSteps - warmupSteps` steps.
   * - anything else: the initial LR.
   *
   * Schedule progress is clamped to [0, 1]. Stepping past the end holds the
   * LR at its final value (0 for the decaying schedules) instead of going
   * negative (linear) or climbing back up (cosine). A zero-length schedule,
   * or a warmup that covers the whole run, counts as already complete
   * instead of producing NaN.
   */
  getLR(): number {
    switch (this.config.scheduler) {
      case 'constant':
        return this.initialLR;
      case 'linear':
        return this.initialLR * (1 - LRScheduler.progress(this.currentStep, this.totalSteps));
      case 'cosine':
        return this.cosineDecay(LRScheduler.progress(this.currentStep, this.totalSteps));
      case 'warmup': {
        const warmupSteps = this.config.warmupSteps;
        if (this.currentStep < warmupSteps) {
          return this.initialLR * (this.currentStep / warmupSteps);
        }
        // Cosine decay after warmup
        return this.cosineDecay(
          LRScheduler.progress(this.currentStep - warmupSteps, this.totalSteps - warmupSteps)
        );
      }
      default:
        return this.initialLR;
    }
  }

  /**
   * Step the scheduler
   */
  step(): void {
    this.currentStep++;
  }

  /**
   * Reset scheduler
   */
  reset(): void {
    this.currentStep = 0;
  }

  /** Half-cosine from the initial LR (progress 0) down to 0 (progress 1). */
  private cosineDecay(progress: number): number {
    return this.initialLR * 0.5 * (1 + Math.cos(Math.PI * progress));
  }

  /** Fraction `done / span` clamped to [0, 1]; a non-positive span counts as finished. */
  private static progress(done: number, span: number): number {
    if (span <= 0) return 1;
    return Math.min(1, Math.max(0, done / span));
  }
}

/**
 * Training Metrics Tracker
 *
 * Accumulates per-step losses, validation losses, gradient norms and step
 * durations, and derives averages, throughput and ETA from them.
 */
export class MetricsTracker {
  private lossHistory: number[] = [];
  private valLossHistory: number[] = [];
  private gradNormHistory: number[] = [];
  private startTime: number = Date.now();
  private stepTimes: number[] = [];

  /**
   * Record training loss
   */
  recordLoss(loss: number): void {
    this.lossHistory.push(loss);
  }

  /**
   * Record validation loss
   */
  recordValLoss(loss: number): void {
    this.valLossHistory.push(loss);
  }

  /**
   * Record gradient norm
   */
  recordGradNorm(norm: number): void {
    this.gradNormHistory.push(norm);
  }

  /**
   * Record step time
   */
  recordStepTime(ms: number): void {
    this.stepTimes.push(ms);
  }

  /**
   * Get average loss over last N steps (0 when nothing has been recorded).
   */
  avgLoss(n: number = 100): number {
    return MetricsTracker.mean(this.lossHistory.slice(-n));
  }

  /**
   * Get average validation loss over the last N records (0 when none).
   */
  avgValLoss(n: number = 10): number {
    return MetricsTracker.mean(this.valLossHistory.slice(-n));
  }

  /**
   * Get steps per second, based on the mean of the last 100 step durations (ms).
   */
  stepsPerSecond(): number {
    if (this.stepTimes.length === 0) return 0;
    const avgStepTime = MetricsTracker.mean(this.stepTimes.slice(-100));
    return avgStepTime > 0 ? 1000 / avgStepTime : 0;
  }

  /**
   * Get ETA in seconds for `remainingSteps` at the current throughput (0 if unknown).
   */
  eta(remainingSteps: number): number {
    const sps = this.stepsPerSecond();
    return sps > 0 ? remainingSteps / sps : 0;
  }

  /**
   * Get best (lowest) validation loss, or Infinity when none has been recorded.
   */
  bestValLoss(): number {
    // reduce instead of Math.min(...arr): spreading a long history can exceed the argument limit.
    return this.valLossHistory.reduce((best, loss) => Math.min(best, loss), Infinity);
  }

  /**
   * Get milliseconds elapsed since construction or the last reset().
   */
  duration(): number {
    return Date.now() - this.startTime;
  }

  /**
   * Get all loss history
   */
  getLossHistory(): number[] {
    return [...this.lossHistory];
  }

  /**
   * Get all validation loss history
   */
  getValLossHistory(): number[] {
    return [...this.valLossHistory];
  }

  /**
   * Get all recorded gradient norms (copy).
   */
  getGradNormHistory(): number[] {
    return [...this.gradNormHistory];
  }

  /**
   * Reset tracker
   */
  reset(): void {
    this.lossHistory = [];
    this.valLossHistory = [];
    this.gradNormHistory = [];
    this.stepTimes = [];
    this.startTime = Date.now();
  }

  /** Arithmetic mean, or 0 for an empty list. */
  private static mean(values: number[]): number {
    return values.length > 0 ? values.reduce((a, b) => a + b, 0) / values.length : 0;
  }
}

/**
 * Training Pipeline
 *
 * Full training infrastructure for SONA models.
*/
export class TrainingPipeline {
  private config: Required<TrainingConfig>;
  private adapter: LoraAdapter;
  private ewcManager: EwcManager;
  private metrics: MetricsTracker;
  private scheduler: LRScheduler | null = null;
  /** Optimizer steps planned for the current/last `train()` run (used for the ETA). */
  private totalSteps: number = 0;
  private batches: TrainingBatch[] = [];
  private checkpoints: Checkpoint[] = [];
  private currentEpoch: number = 0;
  private currentStep: number = 0;
  private bestValLoss: number = Infinity;
  private patienceCounter: number = 0;

  /**
   * @param config - Training options; any field not present is taken from the defaults.
   * @param adapter - LoRA adapter to train; a rank-8 adapter is created when omitted.
   */
  constructor(config?: TrainingConfig, adapter?: LoraAdapter) {
    this.config = { ...DEFAULT_TRAINING_CONFIG, ...config };
    this.adapter = adapter ?? new LoraAdapter({ rank: 8 });
    this.ewcManager = new EwcManager(this.config.ewcLambda);
    this.metrics = new MetricsTracker();
  }

  /**
   * Add training batch.
   *
   * @throws RangeError if `inputs`, `targets` and `qualities` differ in length
   *   (they are consumed index-by-index during training and validation).
   */
  addBatch(inputs: Embedding[], targets: Embedding[], qualities: number[]): void {
    if (inputs.length !== targets.length || inputs.length !== qualities.length) {
      throw new RangeError(
        `addBatch: inputs (${inputs.length}), targets (${targets.length}) and qualities (${qualities.length}) must have the same length`
      );
    }
    this.batches.push({ inputs, targets, qualities });
  }

  /**
   * Add training data, grouped into batches of `config.batchSize` samples.
   */
  addData(data: Array<{ input: Embedding; target: Embedding; quality: number }>): void {
    for (let i = 0; i < data.length; i += this.config.batchSize) {
      const batch = data.slice(i, i + this.config.batchSize);
      this.addBatch(
        batch.map(d => d.input),
        batch.map(d => d.target),
        batch.map(d => d.quality)
      );
    }
  }

  /**
   * Run training over all added batches.
   *
   * The batches are shuffled and split once into a fixed validation set
   * (`validationSplit` of the batches, rounded down) and a training set. The
   * training set is reshuffled every epoch. Keeping the split fixed means no
   * validation batch is ever trained on, so the validation loss used for early
   * stopping stays a held-out estimate.
   *
   * Per-run state is reset at the start of every call: metrics, step counter,
   * best validation loss and patience. After training, the merged adapter
   * weights are registered with the EWC manager.
   *
   * @returns Summary of the run, including loss histories and whether early stopping triggered.
   */
  train(): TrainingResult {
    this.metrics.reset();
    this.currentEpoch = 0;
    this.currentStep = 0;
    this.bestValLoss = Infinity;
    this.patienceCounter = 0;

    const shuffled = this.shuffleBatches();
    const valSize = Math.floor(shuffled.length * this.config.validationSplit);
    const valBatches = shuffled.slice(0, valSize);
    const trainBatches = shuffled.slice(valSize);

    // Only training batches advance the scheduler, so size the schedule on them;
    // otherwise the decay never reaches its end.
    this.totalSteps = trainBatches.length * this.config.epochs;
    const scheduler = new LRScheduler(this.config, this.totalSteps);
    this.scheduler = scheduler;

    this.adapter.startTraining(this.config.learningRate);

    let epochsRun = 0;
    let earlyStopped = false;

    for (let epoch = 0; epoch < this.config.epochs; epoch++) {
      this.currentEpoch = epoch;
      epochsRun = epoch + 1;

      // Training epoch
      for (const batch of this.shuffleBatches(trainBatches)) {
        const stepStart = Date.now();
        const loss = this.trainStep(batch);
        this.metrics.recordLoss(loss);
        this.metrics.recordStepTime(Date.now() - stepStart);
        scheduler.step();
        this.currentStep++;
      }

      // Validation + early stopping
      if (valBatches.length > 0) {
        const valLoss = this.validate(valBatches);
        this.metrics.recordValLoss(valLoss);

        if (valLoss < this.bestValLoss) {
          this.bestValLoss = valLoss;
          this.patienceCounter = 0;
        } else {
          this.patienceCounter++;
          if (this.patienceCounter >= this.config.earlyStoppingPatience) {
            earlyStopped = true;
            break; // the stopping epoch is not checkpointed
          }
        }
      }

      // Checkpoint
      if ((epoch + 1) % this.config.checkpointInterval === 0) {
        this.saveCheckpoint();
      }
    }

    this.adapter.endTraining();

    // Register with EWC for continual learning
    const weights = this.adapter.merge().flat();
    this.ewcManager.registerTask(`task-${Date.now()}`, weights);

    return {
      epochs: epochsRun,
      steps: this.currentStep,
      finalLoss: this.metrics.avgLoss(100),
      bestValLoss: this.bestValLoss,
      durationMs: this.metrics.duration(),
      lossHistory: this.metrics.getLossHistory(),
      valLossHistory: this.metrics.getValLossHistory(),
      earlyStopped,
    };
  }

  /**
   * Single training step: forward, quality-weighted MSE, backward for each sample.
   *
   * @returns Mean per-sample loss over the batch (0 for an empty batch).
   */
  private trainStep(batch: TrainingBatch): number {
    const count = batch.inputs.length;
    if (count === 0) return 0; // avoid 0 / 0 = NaN
    const lr = this.currentLR();
    let totalLoss = 0;

    for (let i = 0; i < count; i++) {
      const input = batch.inputs[i];
      const target = batch.targets[i];
      const quality = batch.qualities[i];

      // Forward pass
      const output = this.adapter.forward(input);

      // Compute loss (MSE weighted by quality); dimensions missing from `target` count as 0.
      const gradOutput: number[] = [];
      let sqErr = 0;
      for (let j = 0; j < output.length; j++) {
        const diff = output[j] - (target[j] ?? 0);
        sqErr += diff * diff;
        // Quality-weighted gradient. Unlike the loss, it is not divided by output.length.
        gradOutput.push(2 * diff * quality);
      }
      let loss = output.length > 0 ? (sqErr / output.length) * quality : 0;

      // The EWC penalty only contributes to the reported loss; it is not part of
      // gradOutput, so it does not influence the update below.
      loss += this.ewcManager.computePenalty(this.adapter.merge().flat()) * 0.001;

      // Backward pass
      this.adapter.backward(input, gradOutput, lr);
      totalLoss += loss;
    }

    return totalLoss / count;
  }

  /**
   * Validation pass: unweighted MSE averaged over every sample in `batches` (0 if none).
   */
  private validate(batches: TrainingBatch[]): number {
    let totalLoss = 0;
    let count = 0;

    for (const batch of batches) {
      for (let i = 0; i < batch.inputs.length; i++) {
        const output = this.adapter.forward(batch.inputs[i]);
        const target = batch.targets[i];
        let sqErr = 0;
        for (let j = 0; j < output.length; j++) {
          const diff = output[j] - (target[j] ?? 0);
          sqErr += diff * diff;
        }
        totalLoss += output.length > 0 ? sqErr / output.length : 0;
        count++;
      }
    }

    return count > 0 ? totalLoss / count : 0;
  }

  /**
   * Learning rate for the next update.
   *
   * Uses the scheduler when one exists. A scheduled value of 0 is kept, as on the
   * first warmup step; an `||` fallback would wrongly substitute the peak LR there.
   * The configured LR is used only when there is no scheduler or the schedule
   * yields a non-finite value.
   */
  private currentLR(): number {
    const lr = this.scheduler?.getLR();
    return lr !== undefined && Number.isFinite(lr) ? lr : this.config.learningRate;
  }

  /**
   * Save checkpoint
   */
  private saveCheckpoint(): void {
    this.checkpoints.push({
      epoch: this.currentEpoch,
      step: this.currentStep,
      loss: this.metrics.avgLoss(100),
      weights: this.adapter.toJSON(),
      timestamp: Date.now(),
    });
  }

  /**
   * Load checkpoint: restores adapter weights and the epoch/step counters.
   *
   * @returns false if no checkpoint exists at `index`.
   */
  loadCheckpoint(index: number): boolean {
    const checkpoint = this.checkpoints[index];
    if (!checkpoint) return false;

    this.adapter = LoraAdapter.fromJSON(checkpoint.weights);
    this.currentEpoch = checkpoint.epoch;
    this.currentStep = checkpoint.step;
    return true;
  }

  /**
   * Get current metrics
   */
  getMetrics(): TrainingMetrics {
    return {
      epoch: this.currentEpoch,
      step: this.currentStep,
      trainLoss: this.metrics.avgLoss(100),
      valLoss: this.metrics.avgValLoss(10),
      learningRate: this.currentLR(),
      gradNorm: 0, // gradient norms are not measured by this pipeline
      stepsPerSecond: this.metrics.stepsPerSecond(),
      // Remaining planned training steps (validation batches never count as steps).
      etaSeconds: this.metrics.eta(Math.max(0, this.totalSteps - this.currentStep)),
    };
  }

  /**
   * Get adapter
   */
  getAdapter(): LoraAdapter {
    return this.adapter;
  }

  /**
   * Get EWC manager
   */
  getEwcManager(): EwcManager {
    return this.ewcManager;
  }

  /**
   * Get checkpoints
   */
  getCheckpoints(): Checkpoint[] {
    return [...this.checkpoints];
  }

  /**
   * Reset pipeline: drops data, checkpoints, counters, metrics and the schedule,
   * and resets the adapter. The EWC manager is left untouched.
   */
  reset(): void {
    this.batches = [];
    this.checkpoints = [];
    this.scheduler = null;
    this.totalSteps = 0;
    this.currentEpoch = 0;
    this.currentStep = 0;
    this.bestValLoss = Infinity;
    this.patienceCounter = 0;
    this.metrics.reset();
    this.adapter.reset();
  }

  /** Fisher-Yates shuffle of a copy of `source` (defaults to all added batches). */
  private shuffleBatches(source: TrainingBatch[] = this.batches): TrainingBatch[] {
    const shuffled = [...source];
    for (let i = shuffled.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
    }
    return shuffled;
  }
}

/**
 * Training Factory
 *
 * Create pre-configured training pipelines for common scenarios.
 */
export class TrainingFactory {
  /**
   * Create pipeline for quick fine-tuning
   */
  static quickFinetune(): TrainingPipeline {
    return new TrainingPipeline({
      learningRate: 0.01,
      epochs: 3,
      batchSize: 16,
      scheduler: 'constant',
    });
  }

  /**
   * Create pipeline for deep training
   */
  static deepTraining(): TrainingPipeline {
    return new TrainingPipeline({
      learningRate: 0.001,
      epochs: 50,
      batchSize: 32,
      scheduler: 'warmup',
      warmupSteps: 500,
      earlyStoppingPatience: 5,
    });
  }

  /**
   * Create pipeline for continual learning
   */
  static continualLearning(ewcLambda: number = 5000): TrainingPipeline {
    return new TrainingPipeline({
      learningRate: 0.0005,
      epochs: 10,
      batchSize: 16,
      scheduler: 'cosine',
      ewcLambda,
      earlyStoppingPatience: 10,
    });
  }

  /**
   * Create pipeline for federated aggregation
   */
  static federatedAggregation(): TrainingPipeline {
    return new TrainingPipeline({
      learningRate: 0.0001,
      epochs: 5,
      batchSize: 64,
      scheduler: 'linear',
      ewcLambda: 2000,
    });
  }
}