Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
412
vendor/ruvector/crates/ruvector-attention-wasm/js/index.ts
vendored
Normal file
412
vendor/ruvector/crates/ruvector-attention-wasm/js/index.ts
vendored
Normal file
@@ -0,0 +1,412 @@
|
||||
/**
|
||||
* TypeScript wrapper for ruvector-attention-wasm
|
||||
* Provides a clean, type-safe API for attention mechanisms
|
||||
*/
|
||||
|
||||
import init, * as wasm from '../pkg/ruvector_attention_wasm';
|
||||
import type {
|
||||
AttentionConfig,
|
||||
MultiHeadConfig,
|
||||
HyperbolicConfig,
|
||||
LinearAttentionConfig,
|
||||
FlashAttentionConfig,
|
||||
LocalGlobalConfig,
|
||||
MoEConfig,
|
||||
TrainingConfig,
|
||||
SchedulerConfig,
|
||||
ExpertStats,
|
||||
AttentionType,
|
||||
} from './types';
|
||||
|
||||
export * from './types';
|
||||
|
||||
let initialized = false;
|
||||
|
||||
/**
|
||||
* Initialize the WASM module
|
||||
* Must be called before using any attention mechanisms
|
||||
*/
|
||||
export async function initialize(): Promise<void> {
|
||||
if (!initialized) {
|
||||
await init();
|
||||
initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the version of the ruvector-attention-wasm package
|
||||
*/
|
||||
export function version(): string {
|
||||
return wasm.version();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of available attention mechanisms
|
||||
*/
|
||||
export function availableMechanisms(): AttentionType[] {
|
||||
return wasm.available_mechanisms() as AttentionType[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Multi-head attention mechanism
|
||||
*/
|
||||
export class MultiHeadAttention {
|
||||
private inner: wasm.WasmMultiHeadAttention;
|
||||
|
||||
constructor(config: MultiHeadConfig) {
|
||||
this.inner = new wasm.WasmMultiHeadAttention(config.dim, config.numHeads);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute multi-head attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
get numHeads(): number {
|
||||
return this.inner.num_heads;
|
||||
}
|
||||
|
||||
get dim(): number {
|
||||
return this.inner.dim;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Hyperbolic attention mechanism
|
||||
*/
|
||||
export class HyperbolicAttention {
|
||||
private inner: wasm.WasmHyperbolicAttention;
|
||||
|
||||
constructor(config: HyperbolicConfig) {
|
||||
this.inner = new wasm.WasmHyperbolicAttention(config.dim, config.curvature);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
get curvature(): number {
|
||||
return this.inner.curvature;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Linear attention (Performer-style)
|
||||
*/
|
||||
export class LinearAttention {
|
||||
private inner: wasm.WasmLinearAttention;
|
||||
|
||||
constructor(config: LinearAttentionConfig) {
|
||||
this.inner = new wasm.WasmLinearAttention(config.dim, config.numFeatures);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Flash attention mechanism
|
||||
*/
|
||||
export class FlashAttention {
|
||||
private inner: wasm.WasmFlashAttention;
|
||||
|
||||
constructor(config: FlashAttentionConfig) {
|
||||
this.inner = new wasm.WasmFlashAttention(config.dim, config.blockSize);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Local-global attention mechanism
|
||||
*/
|
||||
export class LocalGlobalAttention {
|
||||
private inner: wasm.WasmLocalGlobalAttention;
|
||||
|
||||
constructor(config: LocalGlobalConfig) {
|
||||
this.inner = new wasm.WasmLocalGlobalAttention(
|
||||
config.dim,
|
||||
config.localWindow,
|
||||
config.globalTokens
|
||||
);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mixture of Experts attention
|
||||
*/
|
||||
export class MoEAttention {
|
||||
private inner: wasm.WasmMoEAttention;
|
||||
|
||||
constructor(config: MoEConfig) {
|
||||
this.inner = new wasm.WasmMoEAttention(config.dim, config.numExperts, config.topK);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
getExpertStats(): ExpertStats {
|
||||
return this.inner.expert_stats() as ExpertStats;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* InfoNCE contrastive loss
|
||||
*/
|
||||
export class InfoNCELoss {
|
||||
private inner: wasm.WasmInfoNCELoss;
|
||||
|
||||
constructor(temperature: number = 0.07) {
|
||||
this.inner = new wasm.WasmInfoNCELoss(temperature);
|
||||
}
|
||||
|
||||
compute(anchor: Float32Array, positive: Float32Array, negatives: Float32Array[]): number {
|
||||
return this.inner.compute(anchor, positive, negatives);
|
||||
}
|
||||
|
||||
computeMultiPositive(
|
||||
anchor: Float32Array,
|
||||
positives: Float32Array[],
|
||||
negatives: Float32Array[]
|
||||
): number {
|
||||
return this.inner.compute_multi_positive(anchor, positives, negatives);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Adam optimizer
|
||||
*/
|
||||
export class Adam {
|
||||
private inner: wasm.WasmAdam;
|
||||
|
||||
constructor(paramCount: number, config: TrainingConfig) {
|
||||
this.inner = new wasm.WasmAdam(
|
||||
paramCount,
|
||||
config.learningRate,
|
||||
config.beta1,
|
||||
config.beta2,
|
||||
config.epsilon
|
||||
);
|
||||
}
|
||||
|
||||
step(params: Float32Array, gradients: Float32Array): void {
|
||||
this.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
this.inner.reset();
|
||||
}
|
||||
|
||||
get learningRate(): number {
|
||||
return this.inner.learning_rate;
|
||||
}
|
||||
|
||||
set learningRate(lr: number) {
|
||||
this.inner.learning_rate = lr;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* AdamW optimizer (Adam with decoupled weight decay)
|
||||
*/
|
||||
export class AdamW {
|
||||
private inner: wasm.WasmAdamW;
|
||||
|
||||
constructor(paramCount: number, config: TrainingConfig) {
|
||||
if (!config.weightDecay) {
|
||||
throw new Error('AdamW requires weightDecay parameter');
|
||||
}
|
||||
|
||||
this.inner = new wasm.WasmAdamW(
|
||||
paramCount,
|
||||
config.learningRate,
|
||||
config.weightDecay,
|
||||
config.beta1,
|
||||
config.beta2,
|
||||
config.epsilon
|
||||
);
|
||||
}
|
||||
|
||||
step(params: Float32Array, gradients: Float32Array): void {
|
||||
this.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
this.inner.reset();
|
||||
}
|
||||
|
||||
get learningRate(): number {
|
||||
return this.inner.learning_rate;
|
||||
}
|
||||
|
||||
set learningRate(lr: number) {
|
||||
this.inner.learning_rate = lr;
|
||||
}
|
||||
|
||||
get weightDecay(): number {
|
||||
return this.inner.weight_decay;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Learning rate scheduler with warmup and cosine decay
|
||||
*/
|
||||
export class LRScheduler {
|
||||
private inner: wasm.WasmLRScheduler;
|
||||
|
||||
constructor(config: SchedulerConfig) {
|
||||
this.inner = new wasm.WasmLRScheduler(
|
||||
config.initialLR,
|
||||
config.warmupSteps,
|
||||
config.totalSteps
|
||||
);
|
||||
}
|
||||
|
||||
getLR(): number {
|
||||
return this.inner.get_lr();
|
||||
}
|
||||
|
||||
step(): void {
|
||||
this.inner.step();
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
this.inner.reset();
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility functions
|
||||
*/
|
||||
export const utils = {
|
||||
/**
|
||||
* Compute cosine similarity between two vectors
|
||||
*/
|
||||
cosineSimilarity(a: Float32Array, b: Float32Array): number {
|
||||
return wasm.cosine_similarity(a, b);
|
||||
},
|
||||
|
||||
/**
|
||||
* Compute L2 norm of a vector
|
||||
*/
|
||||
l2Norm(vec: Float32Array): number {
|
||||
return wasm.l2_norm(vec);
|
||||
},
|
||||
|
||||
/**
|
||||
* Normalize a vector to unit length (in-place)
|
||||
*/
|
||||
normalize(vec: Float32Array): void {
|
||||
wasm.normalize(vec);
|
||||
},
|
||||
|
||||
/**
|
||||
* Apply softmax to a vector (in-place)
|
||||
*/
|
||||
softmax(vec: Float32Array): void {
|
||||
wasm.softmax(vec);
|
||||
},
|
||||
|
||||
/**
|
||||
* Compute attention weights from scores (in-place)
|
||||
*/
|
||||
attentionWeights(scores: Float32Array, temperature?: number): void {
|
||||
wasm.attention_weights(scores, temperature);
|
||||
},
|
||||
|
||||
/**
|
||||
* Batch normalize vectors
|
||||
*/
|
||||
batchNormalize(vectors: Float32Array[], epsilon?: number): Float32Array {
|
||||
const result = wasm.batch_normalize(vectors, epsilon);
|
||||
return new Float32Array(result);
|
||||
},
|
||||
|
||||
/**
|
||||
* Generate random orthogonal matrix
|
||||
*/
|
||||
randomOrthogonalMatrix(dim: number): Float32Array {
|
||||
const result = wasm.random_orthogonal_matrix(dim);
|
||||
return new Float32Array(result);
|
||||
},
|
||||
|
||||
/**
|
||||
* Compute pairwise distances between vectors
|
||||
*/
|
||||
pairwiseDistances(vectors: Float32Array[]): Float32Array {
|
||||
const result = wasm.pairwise_distances(vectors);
|
||||
return new Float32Array(result);
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Simple scaled dot-product attention (functional API)
|
||||
*/
|
||||
export function scaledDotAttention(
|
||||
query: Float32Array,
|
||||
keys: Float32Array[],
|
||||
values: Float32Array[],
|
||||
scale?: number
|
||||
): Float32Array {
|
||||
const result = wasm.scaled_dot_attention(query, keys, values, scale);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
// Re-export WASM module for advanced usage
|
||||
export { wasm };
|
||||
Reference in New Issue
Block a user