/**
 * Training Pipeline for SONA
 *
 * Comprehensive training infrastructure with metrics tracking,
 * learning rate scheduling, and checkpoint management.
 *
 * @example
 * ```typescript
 * import { TrainingPipeline, TrainingConfig } from '@ruvector/ruvllm';
 *
 * const pipeline = new TrainingPipeline({
 *   learningRate: 0.001,
 *   batchSize: 32,
 *   epochs: 10,
 * });
 *
 * // Add training data
 * pipeline.addBatch(inputs, targets, qualities);
 *
 * // Run training
 * const result = pipeline.train();
 * console.log(`Final loss: ${result.finalLoss}`);
 * ```
 */
import { Embedding, TrainingConfig, TrainingResult } from './types';
import { LoraAdapter } from './lora';
import { EwcManager } from './sona';
/**
|
|
* Default training config
|
|
*/
|
|
const DEFAULT_TRAINING_CONFIG: Required<TrainingConfig> = {
|
|
learningRate: 0.001,
|
|
batchSize: 32,
|
|
epochs: 10,
|
|
scheduler: 'cosine',
|
|
warmupSteps: 100,
|
|
weightDecay: 0.01,
|
|
gradientClip: 1.0,
|
|
earlyStoppingPatience: 3,
|
|
checkpointInterval: 1,
|
|
ewcLambda: 2000,
|
|
validationSplit: 0.1,
|
|
};
|
|
|
|
/**
|
|
* Training metrics
|
|
*/
|
|
export interface TrainingMetrics {
|
|
/** Current epoch */
|
|
epoch: number;
|
|
/** Current step */
|
|
step: number;
|
|
/** Training loss */
|
|
trainLoss: number;
|
|
/** Validation loss */
|
|
valLoss: number;
|
|
/** Learning rate */
|
|
learningRate: number;
|
|
/** Gradient norm */
|
|
gradNorm: number;
|
|
/** Steps per second */
|
|
stepsPerSecond: number;
|
|
/** ETA in seconds */
|
|
etaSeconds: number;
|
|
}
|
|
|
|
/**
|
|
* Training data batch
|
|
*/
|
|
export interface TrainingBatch {
|
|
/** Input embeddings */
|
|
inputs: Embedding[];
|
|
/** Target outputs */
|
|
targets: Embedding[];
|
|
/** Quality scores */
|
|
qualities: number[];
|
|
}
|
|
|
|
/**
|
|
* Checkpoint data
|
|
*/
|
|
export interface Checkpoint {
|
|
/** Epoch number */
|
|
epoch: number;
|
|
/** Step number */
|
|
step: number;
|
|
/** Training loss at checkpoint */
|
|
loss: number;
|
|
/** Model weights (serialized) */
|
|
weights: string;
|
|
/** Timestamp */
|
|
timestamp: number;
|
|
}
|
|
|
|
/**
|
|
* Learning Rate Scheduler
|
|
*/
|
|
export class LRScheduler {
|
|
private config: Required<TrainingConfig>;
|
|
private initialLR: number;
|
|
private currentStep: number = 0;
|
|
private totalSteps: number;
|
|
|
|
constructor(config: Required<TrainingConfig>, totalSteps: number) {
|
|
this.config = config;
|
|
this.initialLR = config.learningRate;
|
|
this.totalSteps = totalSteps;
|
|
}
|
|
|
|
/**
|
|
* Get learning rate for current step
|
|
*/
|
|
getLR(): number {
|
|
switch (this.config.scheduler) {
|
|
case 'constant':
|
|
return this.initialLR;
|
|
|
|
case 'linear':
|
|
return this.initialLR * (1 - this.currentStep / this.totalSteps);
|
|
|
|
case 'cosine':
|
|
return this.initialLR * 0.5 * (1 + Math.cos(Math.PI * this.currentStep / this.totalSteps));
|
|
|
|
case 'warmup':
|
|
if (this.currentStep < this.config.warmupSteps) {
|
|
return this.initialLR * (this.currentStep / this.config.warmupSteps);
|
|
}
|
|
// Cosine decay after warmup
|
|
const decaySteps = this.totalSteps - this.config.warmupSteps;
|
|
const decayProgress = (this.currentStep - this.config.warmupSteps) / decaySteps;
|
|
return this.initialLR * 0.5 * (1 + Math.cos(Math.PI * decayProgress));
|
|
|
|
default:
|
|
return this.initialLR;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Step the scheduler
|
|
*/
|
|
step(): void {
|
|
this.currentStep++;
|
|
}
|
|
|
|
/**
|
|
* Reset scheduler
|
|
*/
|
|
reset(): void {
|
|
this.currentStep = 0;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Training Metrics Tracker
|
|
*/
|
|
export class MetricsTracker {
|
|
private lossHistory: number[] = [];
|
|
private valLossHistory: number[] = [];
|
|
private gradNormHistory: number[] = [];
|
|
private startTime: number = Date.now();
|
|
private stepTimes: number[] = [];
|
|
|
|
/**
|
|
* Record training loss
|
|
*/
|
|
recordLoss(loss: number): void {
|
|
this.lossHistory.push(loss);
|
|
}
|
|
|
|
/**
|
|
* Record validation loss
|
|
*/
|
|
recordValLoss(loss: number): void {
|
|
this.valLossHistory.push(loss);
|
|
}
|
|
|
|
/**
|
|
* Record gradient norm
|
|
*/
|
|
recordGradNorm(norm: number): void {
|
|
this.gradNormHistory.push(norm);
|
|
}
|
|
|
|
/**
|
|
* Record step time
|
|
*/
|
|
recordStepTime(ms: number): void {
|
|
this.stepTimes.push(ms);
|
|
}
|
|
|
|
/**
|
|
* Get average loss over last N steps
|
|
*/
|
|
avgLoss(n: number = 100): number {
|
|
const recent = this.lossHistory.slice(-n);
|
|
return recent.length > 0 ? recent.reduce((a, b) => a + b, 0) / recent.length : 0;
|
|
}
|
|
|
|
/**
|
|
* Get average validation loss
|
|
*/
|
|
avgValLoss(n: number = 10): number {
|
|
const recent = this.valLossHistory.slice(-n);
|
|
return recent.length > 0 ? recent.reduce((a, b) => a + b, 0) / recent.length : 0;
|
|
}
|
|
|
|
/**
|
|
* Get steps per second
|
|
*/
|
|
stepsPerSecond(): number {
|
|
if (this.stepTimes.length === 0) return 0;
|
|
const avgStepTime = this.stepTimes.slice(-100).reduce((a, b) => a + b, 0) / Math.min(this.stepTimes.length, 100);
|
|
return avgStepTime > 0 ? 1000 / avgStepTime : 0;
|
|
}
|
|
|
|
/**
|
|
* Get ETA in seconds
|
|
*/
|
|
eta(remainingSteps: number): number {
|
|
const sps = this.stepsPerSecond();
|
|
return sps > 0 ? remainingSteps / sps : 0;
|
|
}
|
|
|
|
/**
|
|
* Get best validation loss
|
|
*/
|
|
bestValLoss(): number {
|
|
return this.valLossHistory.length > 0 ? Math.min(...this.valLossHistory) : Infinity;
|
|
}
|
|
|
|
/**
|
|
* Get total duration
|
|
*/
|
|
duration(): number {
|
|
return Date.now() - this.startTime;
|
|
}
|
|
|
|
/**
|
|
* Get all loss history
|
|
*/
|
|
getLossHistory(): number[] {
|
|
return [...this.lossHistory];
|
|
}
|
|
|
|
/**
|
|
* Get all validation loss history
|
|
*/
|
|
getValLossHistory(): number[] {
|
|
return [...this.valLossHistory];
|
|
}
|
|
|
|
/**
|
|
* Reset tracker
|
|
*/
|
|
reset(): void {
|
|
this.lossHistory = [];
|
|
this.valLossHistory = [];
|
|
this.gradNormHistory = [];
|
|
this.stepTimes = [];
|
|
this.startTime = Date.now();
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Training Pipeline
|
|
*
|
|
* Full training infrastructure for SONA models.
|
|
*/
|
|
export class TrainingPipeline {
|
|
private config: Required<TrainingConfig>;
|
|
private adapter: LoraAdapter;
|
|
private ewcManager: EwcManager;
|
|
private metrics: MetricsTracker;
|
|
private scheduler: LRScheduler | null = null;
|
|
private batches: TrainingBatch[] = [];
|
|
private checkpoints: Checkpoint[] = [];
|
|
private currentEpoch: number = 0;
|
|
private currentStep: number = 0;
|
|
private bestValLoss: number = Infinity;
|
|
private patienceCounter: number = 0;
|
|
|
|
constructor(config?: TrainingConfig, adapter?: LoraAdapter) {
|
|
this.config = { ...DEFAULT_TRAINING_CONFIG, ...config };
|
|
this.adapter = adapter || new LoraAdapter({ rank: 8 });
|
|
this.ewcManager = new EwcManager(this.config.ewcLambda);
|
|
this.metrics = new MetricsTracker();
|
|
}
|
|
|
|
/**
|
|
* Add training batch
|
|
*/
|
|
addBatch(inputs: Embedding[], targets: Embedding[], qualities: number[]): void {
|
|
this.batches.push({ inputs, targets, qualities });
|
|
}
|
|
|
|
/**
|
|
* Add training data
|
|
*/
|
|
addData(data: Array<{ input: Embedding; target: Embedding; quality: number }>): void {
|
|
// Group into batches
|
|
for (let i = 0; i < data.length; i += this.config.batchSize) {
|
|
const batch = data.slice(i, i + this.config.batchSize);
|
|
this.addBatch(
|
|
batch.map(d => d.input),
|
|
batch.map(d => d.target),
|
|
batch.map(d => d.quality)
|
|
);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Run training
|
|
*/
|
|
train(): TrainingResult {
|
|
const totalSteps = this.batches.length * this.config.epochs;
|
|
this.scheduler = new LRScheduler(this.config, totalSteps);
|
|
this.metrics.reset();
|
|
this.adapter.startTraining(this.config.learningRate);
|
|
|
|
let earlyStopped = false;
|
|
|
|
for (let epoch = 0; epoch < this.config.epochs; epoch++) {
|
|
this.currentEpoch = epoch;
|
|
|
|
// Shuffle batches
|
|
const shuffledBatches = this.shuffleBatches();
|
|
|
|
// Split into train/val
|
|
const valSize = Math.floor(shuffledBatches.length * this.config.validationSplit);
|
|
const trainBatches = shuffledBatches.slice(valSize);
|
|
const valBatches = shuffledBatches.slice(0, valSize);
|
|
|
|
// Training epoch
|
|
for (const batch of trainBatches) {
|
|
const stepStart = Date.now();
|
|
const loss = this.trainStep(batch);
|
|
this.metrics.recordLoss(loss);
|
|
this.metrics.recordStepTime(Date.now() - stepStart);
|
|
this.scheduler.step();
|
|
this.currentStep++;
|
|
}
|
|
|
|
// Validation
|
|
if (valBatches.length > 0) {
|
|
const valLoss = this.validate(valBatches);
|
|
this.metrics.recordValLoss(valLoss);
|
|
|
|
// Early stopping
|
|
if (valLoss < this.bestValLoss) {
|
|
this.bestValLoss = valLoss;
|
|
this.patienceCounter = 0;
|
|
} else {
|
|
this.patienceCounter++;
|
|
if (this.patienceCounter >= this.config.earlyStoppingPatience) {
|
|
earlyStopped = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Checkpoint
|
|
if ((epoch + 1) % this.config.checkpointInterval === 0) {
|
|
this.saveCheckpoint();
|
|
}
|
|
}
|
|
|
|
this.adapter.endTraining();
|
|
|
|
// Register with EWC for continual learning
|
|
const weights = this.adapter.merge().flat();
|
|
this.ewcManager.registerTask(`task-${Date.now()}`, weights);
|
|
|
|
return {
|
|
epochs: this.currentEpoch + 1,
|
|
steps: this.currentStep,
|
|
finalLoss: this.metrics.avgLoss(100),
|
|
bestValLoss: this.bestValLoss,
|
|
durationMs: this.metrics.duration(),
|
|
lossHistory: this.metrics.getLossHistory(),
|
|
valLossHistory: this.metrics.getValLossHistory(),
|
|
earlyStopped,
|
|
};
|
|
}
|
|
|
|
/**
|
|
* Single training step
|
|
*/
|
|
private trainStep(batch: TrainingBatch): number {
|
|
let totalLoss = 0;
|
|
const lr = this.scheduler?.getLR() || this.config.learningRate;
|
|
|
|
for (let i = 0; i < batch.inputs.length; i++) {
|
|
const input = batch.inputs[i];
|
|
const target = batch.targets[i];
|
|
const quality = batch.qualities[i];
|
|
|
|
// Forward pass
|
|
const output = this.adapter.forward(input);
|
|
|
|
// Compute loss (MSE weighted by quality)
|
|
const gradOutput: number[] = [];
|
|
let loss = 0;
|
|
for (let j = 0; j < output.length; j++) {
|
|
const diff = output[j] - (target[j] || 0);
|
|
loss += diff * diff;
|
|
gradOutput.push(2 * diff * quality); // Quality-weighted gradient
|
|
}
|
|
loss = (loss / output.length) * quality;
|
|
|
|
// Add EWC penalty
|
|
const ewcPenalty = this.ewcManager.computePenalty(this.adapter.merge().flat());
|
|
loss += ewcPenalty * 0.001;
|
|
|
|
// Backward pass
|
|
this.adapter.backward(input, gradOutput, lr);
|
|
|
|
totalLoss += loss;
|
|
}
|
|
|
|
return totalLoss / batch.inputs.length;
|
|
}
|
|
|
|
/**
|
|
* Validation pass
|
|
*/
|
|
private validate(batches: TrainingBatch[]): number {
|
|
let totalLoss = 0;
|
|
let count = 0;
|
|
|
|
for (const batch of batches) {
|
|
for (let i = 0; i < batch.inputs.length; i++) {
|
|
const output = this.adapter.forward(batch.inputs[i]);
|
|
const target = batch.targets[i];
|
|
|
|
let loss = 0;
|
|
for (let j = 0; j < output.length; j++) {
|
|
const diff = output[j] - (target[j] || 0);
|
|
loss += diff * diff;
|
|
}
|
|
totalLoss += loss / output.length;
|
|
count++;
|
|
}
|
|
}
|
|
|
|
return count > 0 ? totalLoss / count : 0;
|
|
}
|
|
|
|
/**
|
|
* Save checkpoint
|
|
*/
|
|
private saveCheckpoint(): void {
|
|
this.checkpoints.push({
|
|
epoch: this.currentEpoch,
|
|
step: this.currentStep,
|
|
loss: this.metrics.avgLoss(100),
|
|
weights: this.adapter.toJSON(),
|
|
timestamp: Date.now(),
|
|
});
|
|
}
|
|
|
|
/**
|
|
* Load checkpoint
|
|
*/
|
|
loadCheckpoint(index: number): boolean {
|
|
const checkpoint = this.checkpoints[index];
|
|
if (!checkpoint) return false;
|
|
|
|
this.adapter = LoraAdapter.fromJSON(checkpoint.weights);
|
|
this.currentEpoch = checkpoint.epoch;
|
|
this.currentStep = checkpoint.step;
|
|
return true;
|
|
}
|
|
|
|
/**
|
|
* Get current metrics
|
|
*/
|
|
getMetrics(): TrainingMetrics {
|
|
return {
|
|
epoch: this.currentEpoch,
|
|
step: this.currentStep,
|
|
trainLoss: this.metrics.avgLoss(100),
|
|
valLoss: this.metrics.avgValLoss(10),
|
|
learningRate: this.scheduler?.getLR() || this.config.learningRate,
|
|
gradNorm: 0,
|
|
stepsPerSecond: this.metrics.stepsPerSecond(),
|
|
etaSeconds: this.metrics.eta(
|
|
(this.config.epochs - this.currentEpoch) * this.batches.length
|
|
),
|
|
};
|
|
}
|
|
|
|
/**
|
|
* Get adapter
|
|
*/
|
|
getAdapter(): LoraAdapter {
|
|
return this.adapter;
|
|
}
|
|
|
|
/**
|
|
* Get EWC manager
|
|
*/
|
|
getEwcManager(): EwcManager {
|
|
return this.ewcManager;
|
|
}
|
|
|
|
/**
|
|
* Get checkpoints
|
|
*/
|
|
getCheckpoints(): Checkpoint[] {
|
|
return [...this.checkpoints];
|
|
}
|
|
|
|
/**
|
|
* Reset pipeline
|
|
*/
|
|
reset(): void {
|
|
this.batches = [];
|
|
this.checkpoints = [];
|
|
this.currentEpoch = 0;
|
|
this.currentStep = 0;
|
|
this.bestValLoss = Infinity;
|
|
this.patienceCounter = 0;
|
|
this.metrics.reset();
|
|
this.adapter.reset();
|
|
}
|
|
|
|
private shuffleBatches(): TrainingBatch[] {
|
|
const shuffled = [...this.batches];
|
|
for (let i = shuffled.length - 1; i > 0; i--) {
|
|
const j = Math.floor(Math.random() * (i + 1));
|
|
[shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
|
|
}
|
|
return shuffled;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Training Factory
|
|
*
|
|
* Create pre-configured training pipelines for common scenarios.
|
|
*/
|
|
export class TrainingFactory {
|
|
/**
|
|
* Create pipeline for quick fine-tuning
|
|
*/
|
|
static quickFinetune(): TrainingPipeline {
|
|
return new TrainingPipeline({
|
|
learningRate: 0.01,
|
|
epochs: 3,
|
|
batchSize: 16,
|
|
scheduler: 'constant',
|
|
});
|
|
}
|
|
|
|
/**
|
|
* Create pipeline for deep training
|
|
*/
|
|
static deepTraining(): TrainingPipeline {
|
|
return new TrainingPipeline({
|
|
learningRate: 0.001,
|
|
epochs: 50,
|
|
batchSize: 32,
|
|
scheduler: 'warmup',
|
|
warmupSteps: 500,
|
|
earlyStoppingPatience: 5,
|
|
});
|
|
}
|
|
|
|
/**
|
|
* Create pipeline for continual learning
|
|
*/
|
|
static continualLearning(ewcLambda: number = 5000): TrainingPipeline {
|
|
return new TrainingPipeline({
|
|
learningRate: 0.0005,
|
|
epochs: 10,
|
|
batchSize: 16,
|
|
scheduler: 'cosine',
|
|
ewcLambda,
|
|
earlyStoppingPatience: 10,
|
|
});
|
|
}
|
|
|
|
/**
|
|
* Create pipeline for federated aggregation
|
|
*/
|
|
static federatedAggregation(): TrainingPipeline {
|
|
return new TrainingPipeline({
|
|
learningRate: 0.0001,
|
|
epochs: 5,
|
|
batchSize: 64,
|
|
scheduler: 'linear',
|
|
ewcLambda: 2000,
|
|
});
|
|
}
|
|
}
|