Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
26
crates/ruvector-attention-wasm/.gitignore
vendored
Normal file
26
crates/ruvector-attention-wasm/.gitignore
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
# WASM build artifacts
|
||||
pkg/
|
||||
pkg-node/
|
||||
pkg-bundler/
|
||||
|
||||
# Dependencies
|
||||
node_modules/
|
||||
|
||||
# Build artifacts
|
||||
target/
|
||||
Cargo.lock
|
||||
|
||||
# Editor files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Test coverage
|
||||
coverage/
|
||||
*.profraw
|
||||
40
crates/ruvector-attention-wasm/Cargo.toml
Normal file
40
crates/ruvector-attention-wasm/Cargo.toml
Normal file
@@ -0,0 +1,40 @@
|
||||
[package]
|
||||
name = "ruvector-attention-wasm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
description = "High-performance WebAssembly attention mechanisms: Multi-Head, Flash, Hyperbolic, MoE, CGT Sheaf Attention with GPU acceleration for transformers and LLMs"
|
||||
homepage = "https://ruv.io/ruvector"
|
||||
documentation = "https://docs.rs/ruvector-attention-wasm"
|
||||
keywords = ["wasm", "attention", "transformer", "flash-attention", "llm"]
|
||||
categories = ["wasm", "algorithms", "science"]
|
||||
readme = "README.md"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
ruvector-attention = { version = "2.0", path = "../ruvector-attention", default-features = false, features = ["wasm"] }
|
||||
wasm-bindgen = "0.2"
|
||||
js-sys = "0.3"
|
||||
web-sys = { version = "0.3", features = ["console"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde-wasm-bindgen = "0.6"
|
||||
console_error_panic_hook = { version = "0.1", optional = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[features]
|
||||
default = ["console_error_panic_hook"]
|
||||
|
||||
[profile.release]
|
||||
opt-level = "s"
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = false
|
||||
21
crates/ruvector-attention-wasm/LICENSE
Normal file
21
crates/ruvector-attention-wasm/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 rUv
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
220
crates/ruvector-attention-wasm/README.md
Normal file
220
crates/ruvector-attention-wasm/README.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# ruvector-attention-wasm
|
||||
|
||||
WebAssembly bindings for the ruvector-attention package, providing high-performance attention mechanisms for browser and Node.js environments.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Attention Mechanisms**:
|
||||
- Scaled Dot-Product Attention
|
||||
- Multi-Head Attention
|
||||
- Hyperbolic Attention (for hierarchical data)
|
||||
- Linear Attention (Performer-style)
|
||||
- Flash Attention (memory-efficient)
|
||||
- Local-Global Attention
|
||||
- Mixture of Experts (MoE) Attention
|
||||
- **CGT Sheaf Attention** (coherence-gated via Prime-Radiant)
|
||||
|
||||
- **Training Utilities**:
|
||||
- InfoNCE contrastive loss
|
||||
- Adam optimizer
|
||||
- AdamW optimizer (with decoupled weight decay)
|
||||
- Learning rate scheduler (warmup + cosine decay)
|
||||
|
||||
- **TypeScript Support**: Full type definitions and modern API
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-attention-wasm
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
```typescript
|
||||
import { initialize, MultiHeadAttention, utils } from 'ruvector-attention-wasm';
|
||||
|
||||
// Initialize WASM module
|
||||
await initialize();
|
||||
|
||||
// Create multi-head attention
|
||||
const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });
|
||||
|
||||
// Prepare inputs
|
||||
const query = new Float32Array(64);
|
||||
const keys = [new Float32Array(64), new Float32Array(64)];
|
||||
const values = [new Float32Array(64), new Float32Array(64)];
|
||||
|
||||
// Compute attention
|
||||
const output = attention.compute(query, keys, values);
|
||||
|
||||
// Use utilities
|
||||
const similarity = utils.cosineSimilarity(query, keys[0]);
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
|
||||
#### Hyperbolic Attention
|
||||
|
||||
```typescript
|
||||
import { HyperbolicAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const hyperbolic = new HyperbolicAttention({
|
||||
dim: 128,
|
||||
curvature: 1.0
|
||||
});
|
||||
|
||||
const output = hyperbolic.compute(query, keys, values);
|
||||
```
|
||||
|
||||
#### MoE Attention with Expert Stats
|
||||
|
||||
```typescript
|
||||
import { MoEAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const moe = new MoEAttention({
|
||||
dim: 64,
|
||||
numExperts: 4,
|
||||
topK: 2
|
||||
});
|
||||
|
||||
const output = moe.compute(query, keys, values);
|
||||
|
||||
// Get expert utilization
|
||||
const stats = moe.getExpertStats();
|
||||
console.log('Load balance:', stats.loadBalance);
|
||||
```
|
||||
|
||||
#### Training with InfoNCE Loss
|
||||
|
||||
```typescript
|
||||
import { InfoNCELoss, Adam } from 'ruvector-attention-wasm';
|
||||
|
||||
const loss = new InfoNCELoss(0.07);
|
||||
const optimizer = new Adam(paramCount, {
|
||||
learningRate: 0.001,
|
||||
beta1: 0.9,
|
||||
beta2: 0.999,
|
||||
});
|
||||
|
||||
// Training loop
|
||||
const lossValue = loss.compute(anchor, positive, negatives);
|
||||
optimizer.step(params, gradients);
|
||||
```
|
||||
|
||||
#### Learning Rate Scheduling
|
||||
|
||||
```typescript
|
||||
import { LRScheduler, AdamW } from 'ruvector-attention-wasm';
|
||||
|
||||
const scheduler = new LRScheduler({
|
||||
initialLR: 0.001,
|
||||
warmupSteps: 1000,
|
||||
totalSteps: 10000,
|
||||
});
|
||||
|
||||
const optimizer = new AdamW(paramCount, {
|
||||
learningRate: scheduler.getLR(),
|
||||
weightDecay: 0.01,
|
||||
});
|
||||
|
||||
// Training loop
|
||||
for (let step = 0; step < 10000; step++) {
|
||||
optimizer.learningRate = scheduler.getLR();
|
||||
optimizer.step(params, gradients);
|
||||
scheduler.step();
|
||||
}
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust 1.70+
|
||||
- wasm-pack
|
||||
|
||||
### Build Commands
|
||||
|
||||
```bash
|
||||
# Build for web (ES modules)
|
||||
wasm-pack build --target web --out-dir pkg
|
||||
|
||||
# Build for Node.js
|
||||
wasm-pack build --target nodejs --out-dir pkg-node
|
||||
|
||||
# Build for bundlers (webpack, vite, etc.)
|
||||
wasm-pack build --target bundler --out-dir pkg-bundler
|
||||
|
||||
# Run tests
|
||||
wasm-pack test --headless --firefox
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Attention Mechanisms
|
||||
|
||||
- `MultiHeadAttention` - Standard multi-head attention
|
||||
- `HyperbolicAttention` - Attention in hyperbolic space
|
||||
- `LinearAttention` - Linear complexity attention (Performer)
|
||||
- `FlashAttention` - Memory-efficient attention
|
||||
- `LocalGlobalAttention` - Combined local and global attention
|
||||
- `MoEAttention` - Mixture of Experts attention
|
||||
- `CGTSheafAttention` - Coherence-gated via Prime-Radiant energy
|
||||
- `scaledDotAttention()` - Functional API for basic attention
|
||||
|
||||
### CGT Sheaf Attention (Prime-Radiant Integration)
|
||||
|
||||
The CGT (Coherence-Gated Transformer) Sheaf Attention mechanism uses Prime-Radiant's sheaf Laplacian energy to gate attention based on mathematical consistency:
|
||||
|
||||
```typescript
|
||||
import { CGTSheafAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const cgtAttention = new CGTSheafAttention({
|
||||
dim: 128,
|
||||
numHeads: 8,
|
||||
coherenceThreshold: 0.3, // Block if energy > threshold
|
||||
});
|
||||
|
||||
// Attention is gated by coherence energy
|
||||
const result = cgtAttention.compute(query, keys, values);
|
||||
console.log('Coherence energy:', result.energy);
|
||||
console.log('Is coherent:', result.isCoherent);
|
||||
```
|
||||
|
||||
**Key features:**
|
||||
- Energy-weighted attention: Lower coherence energy → higher attention
|
||||
- Automatic hallucination detection via residual analysis
|
||||
- GPU-accelerated with wgpu WGSL shaders (vec4 optimized)
|
||||
- SIMD fallback (AVX-512/AVX2/NEON)
|
||||
|
||||
### Training
|
||||
|
||||
- `InfoNCELoss` - Contrastive loss function
|
||||
- `Adam` - Adam optimizer
|
||||
- `AdamW` - AdamW optimizer with weight decay
|
||||
- `LRScheduler` - Learning rate scheduler
|
||||
|
||||
### Utilities
|
||||
|
||||
- `utils.cosineSimilarity()` - Cosine similarity between vectors
|
||||
- `utils.l2Norm()` - L2 norm of a vector
|
||||
- `utils.normalize()` - Normalize vector to unit length
|
||||
- `utils.softmax()` - Apply softmax transformation
|
||||
- `utils.attentionWeights()` - Compute attention weights from scores
|
||||
- `utils.batchNormalize()` - Batch normalization
|
||||
- `utils.randomOrthogonalMatrix()` - Generate random orthogonal matrix
|
||||
- `utils.pairwiseDistances()` - Compute pairwise distances
|
||||
|
||||
## Performance
|
||||
|
||||
The WASM bindings provide near-native performance for attention computations:
|
||||
|
||||
- Optimized with `opt-level = "s"` and LTO
|
||||
- SIMD acceleration where available
|
||||
- Efficient memory management
|
||||
- Zero-copy data transfer where possible
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
412
crates/ruvector-attention-wasm/js/index.ts
Normal file
412
crates/ruvector-attention-wasm/js/index.ts
Normal file
@@ -0,0 +1,412 @@
|
||||
/**
|
||||
* TypeScript wrapper for ruvector-attention-wasm
|
||||
* Provides a clean, type-safe API for attention mechanisms
|
||||
*/
|
||||
|
||||
import init, * as wasm from '../pkg/ruvector_attention_wasm';
|
||||
import type {
|
||||
AttentionConfig,
|
||||
MultiHeadConfig,
|
||||
HyperbolicConfig,
|
||||
LinearAttentionConfig,
|
||||
FlashAttentionConfig,
|
||||
LocalGlobalConfig,
|
||||
MoEConfig,
|
||||
TrainingConfig,
|
||||
SchedulerConfig,
|
||||
ExpertStats,
|
||||
AttentionType,
|
||||
} from './types';
|
||||
|
||||
export * from './types';
|
||||
|
||||
let initialized = false;
|
||||
|
||||
/**
|
||||
* Initialize the WASM module
|
||||
* Must be called before using any attention mechanisms
|
||||
*/
|
||||
export async function initialize(): Promise<void> {
|
||||
if (!initialized) {
|
||||
await init();
|
||||
initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the version of the ruvector-attention-wasm package
|
||||
*/
|
||||
export function version(): string {
|
||||
return wasm.version();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of available attention mechanisms
|
||||
*/
|
||||
export function availableMechanisms(): AttentionType[] {
|
||||
return wasm.available_mechanisms() as AttentionType[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Multi-head attention mechanism
|
||||
*/
|
||||
export class MultiHeadAttention {
|
||||
private inner: wasm.WasmMultiHeadAttention;
|
||||
|
||||
constructor(config: MultiHeadConfig) {
|
||||
this.inner = new wasm.WasmMultiHeadAttention(config.dim, config.numHeads);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute multi-head attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
get numHeads(): number {
|
||||
return this.inner.num_heads;
|
||||
}
|
||||
|
||||
get dim(): number {
|
||||
return this.inner.dim;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Hyperbolic attention mechanism
|
||||
*/
|
||||
export class HyperbolicAttention {
|
||||
private inner: wasm.WasmHyperbolicAttention;
|
||||
|
||||
constructor(config: HyperbolicConfig) {
|
||||
this.inner = new wasm.WasmHyperbolicAttention(config.dim, config.curvature);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
get curvature(): number {
|
||||
return this.inner.curvature;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Linear attention (Performer-style)
|
||||
*/
|
||||
export class LinearAttention {
|
||||
private inner: wasm.WasmLinearAttention;
|
||||
|
||||
constructor(config: LinearAttentionConfig) {
|
||||
this.inner = new wasm.WasmLinearAttention(config.dim, config.numFeatures);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Flash attention mechanism
|
||||
*/
|
||||
export class FlashAttention {
|
||||
private inner: wasm.WasmFlashAttention;
|
||||
|
||||
constructor(config: FlashAttentionConfig) {
|
||||
this.inner = new wasm.WasmFlashAttention(config.dim, config.blockSize);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Local-global attention mechanism
|
||||
*/
|
||||
export class LocalGlobalAttention {
|
||||
private inner: wasm.WasmLocalGlobalAttention;
|
||||
|
||||
constructor(config: LocalGlobalConfig) {
|
||||
this.inner = new wasm.WasmLocalGlobalAttention(
|
||||
config.dim,
|
||||
config.localWindow,
|
||||
config.globalTokens
|
||||
);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mixture of Experts attention
|
||||
*/
|
||||
export class MoEAttention {
|
||||
private inner: wasm.WasmMoEAttention;
|
||||
|
||||
constructor(config: MoEConfig) {
|
||||
this.inner = new wasm.WasmMoEAttention(config.dim, config.numExperts, config.topK);
|
||||
}
|
||||
|
||||
compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
|
||||
const result = this.inner.compute(query, keys, values);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
getExpertStats(): ExpertStats {
|
||||
return this.inner.expert_stats() as ExpertStats;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* InfoNCE contrastive loss
|
||||
*/
|
||||
export class InfoNCELoss {
|
||||
private inner: wasm.WasmInfoNCELoss;
|
||||
|
||||
constructor(temperature: number = 0.07) {
|
||||
this.inner = new wasm.WasmInfoNCELoss(temperature);
|
||||
}
|
||||
|
||||
compute(anchor: Float32Array, positive: Float32Array, negatives: Float32Array[]): number {
|
||||
return this.inner.compute(anchor, positive, negatives);
|
||||
}
|
||||
|
||||
computeMultiPositive(
|
||||
anchor: Float32Array,
|
||||
positives: Float32Array[],
|
||||
negatives: Float32Array[]
|
||||
): number {
|
||||
return this.inner.compute_multi_positive(anchor, positives, negatives);
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Adam optimizer
|
||||
*/
|
||||
export class Adam {
|
||||
private inner: wasm.WasmAdam;
|
||||
|
||||
constructor(paramCount: number, config: TrainingConfig) {
|
||||
this.inner = new wasm.WasmAdam(
|
||||
paramCount,
|
||||
config.learningRate,
|
||||
config.beta1,
|
||||
config.beta2,
|
||||
config.epsilon
|
||||
);
|
||||
}
|
||||
|
||||
step(params: Float32Array, gradients: Float32Array): void {
|
||||
this.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
this.inner.reset();
|
||||
}
|
||||
|
||||
get learningRate(): number {
|
||||
return this.inner.learning_rate;
|
||||
}
|
||||
|
||||
set learningRate(lr: number) {
|
||||
this.inner.learning_rate = lr;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* AdamW optimizer (Adam with decoupled weight decay)
|
||||
*/
|
||||
export class AdamW {
|
||||
private inner: wasm.WasmAdamW;
|
||||
|
||||
constructor(paramCount: number, config: TrainingConfig) {
|
||||
if (!config.weightDecay) {
|
||||
throw new Error('AdamW requires weightDecay parameter');
|
||||
}
|
||||
|
||||
this.inner = new wasm.WasmAdamW(
|
||||
paramCount,
|
||||
config.learningRate,
|
||||
config.weightDecay,
|
||||
config.beta1,
|
||||
config.beta2,
|
||||
config.epsilon
|
||||
);
|
||||
}
|
||||
|
||||
step(params: Float32Array, gradients: Float32Array): void {
|
||||
this.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
this.inner.reset();
|
||||
}
|
||||
|
||||
get learningRate(): number {
|
||||
return this.inner.learning_rate;
|
||||
}
|
||||
|
||||
set learningRate(lr: number) {
|
||||
this.inner.learning_rate = lr;
|
||||
}
|
||||
|
||||
get weightDecay(): number {
|
||||
return this.inner.weight_decay;
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Learning rate scheduler with warmup and cosine decay
|
||||
*/
|
||||
export class LRScheduler {
|
||||
private inner: wasm.WasmLRScheduler;
|
||||
|
||||
constructor(config: SchedulerConfig) {
|
||||
this.inner = new wasm.WasmLRScheduler(
|
||||
config.initialLR,
|
||||
config.warmupSteps,
|
||||
config.totalSteps
|
||||
);
|
||||
}
|
||||
|
||||
getLR(): number {
|
||||
return this.inner.get_lr();
|
||||
}
|
||||
|
||||
step(): void {
|
||||
this.inner.step();
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
this.inner.reset();
|
||||
}
|
||||
|
||||
free(): void {
|
||||
this.inner.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility functions
|
||||
*/
|
||||
export const utils = {
|
||||
/**
|
||||
* Compute cosine similarity between two vectors
|
||||
*/
|
||||
cosineSimilarity(a: Float32Array, b: Float32Array): number {
|
||||
return wasm.cosine_similarity(a, b);
|
||||
},
|
||||
|
||||
/**
|
||||
* Compute L2 norm of a vector
|
||||
*/
|
||||
l2Norm(vec: Float32Array): number {
|
||||
return wasm.l2_norm(vec);
|
||||
},
|
||||
|
||||
/**
|
||||
* Normalize a vector to unit length (in-place)
|
||||
*/
|
||||
normalize(vec: Float32Array): void {
|
||||
wasm.normalize(vec);
|
||||
},
|
||||
|
||||
/**
|
||||
* Apply softmax to a vector (in-place)
|
||||
*/
|
||||
softmax(vec: Float32Array): void {
|
||||
wasm.softmax(vec);
|
||||
},
|
||||
|
||||
/**
|
||||
* Compute attention weights from scores (in-place)
|
||||
*/
|
||||
attentionWeights(scores: Float32Array, temperature?: number): void {
|
||||
wasm.attention_weights(scores, temperature);
|
||||
},
|
||||
|
||||
/**
|
||||
* Batch normalize vectors
|
||||
*/
|
||||
batchNormalize(vectors: Float32Array[], epsilon?: number): Float32Array {
|
||||
const result = wasm.batch_normalize(vectors, epsilon);
|
||||
return new Float32Array(result);
|
||||
},
|
||||
|
||||
/**
|
||||
* Generate random orthogonal matrix
|
||||
*/
|
||||
randomOrthogonalMatrix(dim: number): Float32Array {
|
||||
const result = wasm.random_orthogonal_matrix(dim);
|
||||
return new Float32Array(result);
|
||||
},
|
||||
|
||||
/**
|
||||
* Compute pairwise distances between vectors
|
||||
*/
|
||||
pairwiseDistances(vectors: Float32Array[]): Float32Array {
|
||||
const result = wasm.pairwise_distances(vectors);
|
||||
return new Float32Array(result);
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Simple scaled dot-product attention (functional API)
|
||||
*/
|
||||
export function scaledDotAttention(
|
||||
query: Float32Array,
|
||||
keys: Float32Array[],
|
||||
values: Float32Array[],
|
||||
scale?: number
|
||||
): Float32Array {
|
||||
const result = wasm.scaled_dot_attention(query, keys, values, scale);
|
||||
return new Float32Array(result);
|
||||
}
|
||||
|
||||
// Re-export WASM module for advanced usage
|
||||
export { wasm };
|
||||
108
crates/ruvector-attention-wasm/js/types.ts
Normal file
108
crates/ruvector-attention-wasm/js/types.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
/**
|
||||
* TypeScript type definitions for ruvector-attention-wasm
|
||||
*/
|
||||
|
||||
export interface AttentionConfig {
|
||||
/** Embedding dimension */
|
||||
dim: number;
|
||||
/** Number of attention heads (for multi-head attention) */
|
||||
numHeads?: number;
|
||||
/** Dropout probability */
|
||||
dropout?: number;
|
||||
/** Scaling factor for attention scores */
|
||||
scale?: number;
|
||||
/** Whether to use causal masking */
|
||||
causal?: boolean;
|
||||
}
|
||||
|
||||
export interface MultiHeadConfig extends AttentionConfig {
|
||||
numHeads: number;
|
||||
}
|
||||
|
||||
export interface HyperbolicConfig extends AttentionConfig {
|
||||
/** Hyperbolic space curvature */
|
||||
curvature: number;
|
||||
}
|
||||
|
||||
export interface LinearAttentionConfig extends AttentionConfig {
|
||||
/** Number of random features for kernel approximation */
|
||||
numFeatures: number;
|
||||
}
|
||||
|
||||
export interface FlashAttentionConfig extends AttentionConfig {
|
||||
/** Block size for tiling */
|
||||
blockSize: number;
|
||||
}
|
||||
|
||||
export interface LocalGlobalConfig extends AttentionConfig {
|
||||
/** Size of local attention window */
|
||||
localWindow: number;
|
||||
/** Number of global attention tokens */
|
||||
globalTokens: number;
|
||||
}
|
||||
|
||||
export interface MoEConfig extends AttentionConfig {
|
||||
/** Number of expert attention mechanisms */
|
||||
numExperts: number;
|
||||
/** Number of experts to use per query */
|
||||
topK: number;
|
||||
/** Maximum capacity per expert */
|
||||
expertCapacity?: number;
|
||||
/** Load balancing coefficient */
|
||||
balanceCoeff?: number;
|
||||
}
|
||||
|
||||
export interface TrainingConfig {
|
||||
/** Learning rate for optimizer */
|
||||
learningRate: number;
|
||||
/** Temperature parameter for contrastive loss */
|
||||
temperature?: number;
|
||||
/** First moment decay rate (Adam/AdamW) */
|
||||
beta1?: number;
|
||||
/** Second moment decay rate (Adam/AdamW) */
|
||||
beta2?: number;
|
||||
/** Weight decay coefficient (AdamW) */
|
||||
weightDecay?: number;
|
||||
/** Numerical stability constant */
|
||||
epsilon?: number;
|
||||
}
|
||||
|
||||
export interface SchedulerConfig {
|
||||
/** Initial learning rate */
|
||||
initialLR: number;
|
||||
/** Number of warmup steps */
|
||||
warmupSteps: number;
|
||||
/** Total training steps */
|
||||
totalSteps: number;
|
||||
}
|
||||
|
||||
export interface ExpertStats {
|
||||
/** Number of times each expert was selected */
|
||||
selectionCounts: number[];
|
||||
/** Average load per expert */
|
||||
averageLoad: number[];
|
||||
/** Load balance factor (lower is better) */
|
||||
loadBalance: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Attention mechanism types
|
||||
*/
|
||||
export type AttentionType =
|
||||
| 'scaled_dot_product'
|
||||
| 'multi_head'
|
||||
| 'hyperbolic'
|
||||
| 'linear'
|
||||
| 'flash'
|
||||
| 'local_global'
|
||||
| 'moe';
|
||||
|
||||
/**
|
||||
* Optimizer types
|
||||
*/
|
||||
export type OptimizerType = 'adam' | 'adamw';
|
||||
|
||||
/**
|
||||
* Loss function types
|
||||
*/
|
||||
export type LossType = 'info_nce';
|
||||
66
crates/ruvector-attention-wasm/package.json
Normal file
66
crates/ruvector-attention-wasm/package.json
Normal file
@@ -0,0 +1,66 @@
|
||||
{
|
||||
"name": "@ruvector/attention-wasm",
|
||||
"version": "0.1.32",
|
||||
"description": "High-performance WebAssembly attention mechanisms for transformers and LLMs: Multi-Head, Flash Attention, Hyperbolic, Linear (Performer), MoE, Local-Global, and CGT Sheaf Attention with coherence gating. GPU-accelerated with SIMD fallback.",
|
||||
"main": "pkg/ruvector_attention_wasm.js",
|
||||
"module": "pkg/ruvector_attention_wasm.js",
|
||||
"types": "pkg/ruvector_attention_wasm.d.ts",
|
||||
"files": [
|
||||
"pkg/",
|
||||
"js/",
|
||||
"README.md"
|
||||
],
|
||||
"scripts": {
|
||||
"build": "wasm-pack build --target web --out-dir pkg",
|
||||
"build:node": "wasm-pack build --target nodejs --out-dir pkg-node",
|
||||
"build:bundler": "wasm-pack build --target bundler --out-dir pkg-bundler",
|
||||
"build:all": "npm run build && npm run build:node && npm run build:bundler",
|
||||
"test": "wasm-pack test --headless --firefox",
|
||||
"test:chrome": "wasm-pack test --headless --chrome",
|
||||
"clean": "rm -rf pkg pkg-node pkg-bundler target",
|
||||
"prepublishOnly": "npm run build"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/ruvnet/ruvector.git"
|
||||
},
|
||||
"keywords": [
|
||||
"wasm",
|
||||
"webassembly",
|
||||
"attention",
|
||||
"transformer",
|
||||
"llm",
|
||||
"machine-learning",
|
||||
"neural-networks",
|
||||
"multi-head-attention",
|
||||
"flash-attention",
|
||||
"hyperbolic",
|
||||
"moe",
|
||||
"mixture-of-experts",
|
||||
"coherence",
|
||||
"cgt",
|
||||
"sheaf-attention",
|
||||
"ai",
|
||||
"deep-learning",
|
||||
"gpu",
|
||||
"simd",
|
||||
"infonce",
|
||||
"contrastive-learning"
|
||||
],
|
||||
"author": "rUv <team@ruvector.dev>",
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"bugs": {
|
||||
"url": "https://github.com/ruvnet/ruvector/issues"
|
||||
},
|
||||
"homepage": "https://ruv.io/ruvector",
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.0.0",
|
||||
"typescript": "^5.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16.0.0"
|
||||
},
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
}
|
||||
}
|
||||
21
crates/ruvector-attention-wasm/pkg/LICENSE
Normal file
21
crates/ruvector-attention-wasm/pkg/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 rUv
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
220
crates/ruvector-attention-wasm/pkg/README.md
Normal file
220
crates/ruvector-attention-wasm/pkg/README.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# ruvector-attention-wasm
|
||||
|
||||
WebAssembly bindings for the ruvector-attention package, providing high-performance attention mechanisms for browser and Node.js environments.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Attention Mechanisms**:
|
||||
- Scaled Dot-Product Attention
|
||||
- Multi-Head Attention
|
||||
- Hyperbolic Attention (for hierarchical data)
|
||||
- Linear Attention (Performer-style)
|
||||
- Flash Attention (memory-efficient)
|
||||
- Local-Global Attention
|
||||
- Mixture of Experts (MoE) Attention
|
||||
- **CGT Sheaf Attention** (coherence-gated via Prime-Radiant)
|
||||
|
||||
- **Training Utilities**:
|
||||
- InfoNCE contrastive loss
|
||||
- Adam optimizer
|
||||
- AdamW optimizer (with decoupled weight decay)
|
||||
- Learning rate scheduler (warmup + cosine decay)
|
||||
|
||||
- **TypeScript Support**: Full type definitions and modern API
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-attention-wasm
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
```typescript
|
||||
import { initialize, MultiHeadAttention, utils } from 'ruvector-attention-wasm';
|
||||
|
||||
// Initialize WASM module
|
||||
await initialize();
|
||||
|
||||
// Create multi-head attention
|
||||
const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });
|
||||
|
||||
// Prepare inputs
|
||||
const query = new Float32Array(64);
|
||||
const keys = [new Float32Array(64), new Float32Array(64)];
|
||||
const values = [new Float32Array(64), new Float32Array(64)];
|
||||
|
||||
// Compute attention
|
||||
const output = attention.compute(query, keys, values);
|
||||
|
||||
// Use utilities
|
||||
const similarity = utils.cosineSimilarity(query, keys[0]);
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
|
||||
#### Hyperbolic Attention
|
||||
|
||||
```typescript
|
||||
import { HyperbolicAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const hyperbolic = new HyperbolicAttention({
|
||||
dim: 128,
|
||||
curvature: 1.0
|
||||
});
|
||||
|
||||
const output = hyperbolic.compute(query, keys, values);
|
||||
```
|
||||
|
||||
#### MoE Attention with Expert Stats
|
||||
|
||||
```typescript
|
||||
import { MoEAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const moe = new MoEAttention({
|
||||
dim: 64,
|
||||
numExperts: 4,
|
||||
topK: 2
|
||||
});
|
||||
|
||||
const output = moe.compute(query, keys, values);
|
||||
|
||||
// Get expert utilization
|
||||
const stats = moe.getExpertStats();
|
||||
console.log('Load balance:', stats.loadBalance);
|
||||
```
|
||||
|
||||
#### Training with InfoNCE Loss
|
||||
|
||||
```typescript
|
||||
import { InfoNCELoss, Adam } from 'ruvector-attention-wasm';
|
||||
|
||||
const loss = new InfoNCELoss(0.07);
|
||||
const optimizer = new Adam(paramCount, {
|
||||
learningRate: 0.001,
|
||||
beta1: 0.9,
|
||||
beta2: 0.999,
|
||||
});
|
||||
|
||||
// Training loop
|
||||
const lossValue = loss.compute(anchor, positive, negatives);
|
||||
optimizer.step(params, gradients);
|
||||
```
|
||||
|
||||
#### Learning Rate Scheduling
|
||||
|
||||
```typescript
|
||||
import { LRScheduler, AdamW } from 'ruvector-attention-wasm';
|
||||
|
||||
const scheduler = new LRScheduler({
|
||||
initialLR: 0.001,
|
||||
warmupSteps: 1000,
|
||||
totalSteps: 10000,
|
||||
});
|
||||
|
||||
const optimizer = new AdamW(paramCount, {
|
||||
learningRate: scheduler.getLR(),
|
||||
weightDecay: 0.01,
|
||||
});
|
||||
|
||||
// Training loop
|
||||
for (let step = 0; step < 10000; step++) {
|
||||
optimizer.learningRate = scheduler.getLR();
|
||||
optimizer.step(params, gradients);
|
||||
scheduler.step();
|
||||
}
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust 1.70+
|
||||
- wasm-pack
|
||||
|
||||
### Build Commands
|
||||
|
||||
```bash
|
||||
# Build for web (ES modules)
|
||||
wasm-pack build --target web --out-dir pkg
|
||||
|
||||
# Build for Node.js
|
||||
wasm-pack build --target nodejs --out-dir pkg-node
|
||||
|
||||
# Build for bundlers (webpack, vite, etc.)
|
||||
wasm-pack build --target bundler --out-dir pkg-bundler
|
||||
|
||||
# Run tests
|
||||
wasm-pack test --headless --firefox
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Attention Mechanisms
|
||||
|
||||
- `MultiHeadAttention` - Standard multi-head attention
|
||||
- `HyperbolicAttention` - Attention in hyperbolic space
|
||||
- `LinearAttention` - Linear complexity attention (Performer)
|
||||
- `FlashAttention` - Memory-efficient attention
|
||||
- `LocalGlobalAttention` - Combined local and global attention
|
||||
- `MoEAttention` - Mixture of Experts attention
|
||||
- `CGTSheafAttention` - Coherence-gated via Prime-Radiant energy
|
||||
- `scaledDotAttention()` - Functional API for basic attention
|
||||
|
||||
### CGT Sheaf Attention (Prime-Radiant Integration)
|
||||
|
||||
The CGT (Coherence-Gated Transformer) Sheaf Attention mechanism uses Prime-Radiant's sheaf Laplacian energy to gate attention based on mathematical consistency:
|
||||
|
||||
```typescript
|
||||
import { CGTSheafAttention } from 'ruvector-attention-wasm';
|
||||
|
||||
const cgtAttention = new CGTSheafAttention({
|
||||
dim: 128,
|
||||
numHeads: 8,
|
||||
coherenceThreshold: 0.3, // Block if energy > threshold
|
||||
});
|
||||
|
||||
// Attention is gated by coherence energy
|
||||
const result = cgtAttention.compute(query, keys, values);
|
||||
console.log('Coherence energy:', result.energy);
|
||||
console.log('Is coherent:', result.isCoherent);
|
||||
```
|
||||
|
||||
**Key features:**
|
||||
- Energy-weighted attention: Lower coherence energy → higher attention
|
||||
- Automatic hallucination detection via residual analysis
|
||||
- GPU-accelerated with wgpu WGSL shaders (vec4 optimized)
|
||||
- SIMD fallback (AVX-512/AVX2/NEON)
|
||||
|
||||
### Training
|
||||
|
||||
- `InfoNCELoss` - Contrastive loss function
|
||||
- `Adam` - Adam optimizer
|
||||
- `AdamW` - AdamW optimizer with weight decay
|
||||
- `LRScheduler` - Learning rate scheduler
|
||||
|
||||
### Utilities
|
||||
|
||||
- `utils.cosineSimilarity()` - Cosine similarity between vectors
|
||||
- `utils.l2Norm()` - L2 norm of a vector
|
||||
- `utils.normalize()` - Normalize vector to unit length
|
||||
- `utils.softmax()` - Apply softmax transformation
|
||||
- `utils.attentionWeights()` - Compute attention weights from scores
|
||||
- `utils.batchNormalize()` - Batch normalization
|
||||
- `utils.randomOrthogonalMatrix()` - Generate random orthogonal matrix
|
||||
- `utils.pairwiseDistances()` - Compute pairwise distances
|
||||
|
||||
## Performance
|
||||
|
||||
The WASM bindings provide near-native performance for attention computations:
|
||||
|
||||
- Optimized with `opt-level = "s"` and LTO
|
||||
- SIMD acceleration where available
|
||||
- Efficient memory management
|
||||
- Zero-copy data transfer where possible
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
28
crates/ruvector-attention-wasm/pkg/package.json
Normal file
28
crates/ruvector-attention-wasm/pkg/package.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"name": "ruvector-attention-wasm",
|
||||
"collaborators": [
|
||||
"Ruvector Team"
|
||||
],
|
||||
"description": "High-performance WebAssembly attention mechanisms: Multi-Head, Flash, Hyperbolic, MoE, CGT Sheaf Attention with GPU acceleration for transformers and LLMs",
|
||||
"version": "2.0.5",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/ruvnet/ruvector"
|
||||
},
|
||||
"files": [
|
||||
"ruvector_attention_wasm_bg.wasm",
|
||||
"ruvector_attention_wasm.js",
|
||||
"ruvector_attention_wasm.d.ts"
|
||||
],
|
||||
"main": "ruvector_attention_wasm.js",
|
||||
"homepage": "https://ruv.io/ruvector",
|
||||
"types": "ruvector_attention_wasm.d.ts",
|
||||
"keywords": [
|
||||
"wasm",
|
||||
"attention",
|
||||
"transformer",
|
||||
"flash-attention",
|
||||
"llm"
|
||||
]
|
||||
}
|
||||
359
crates/ruvector-attention-wasm/pkg/ruvector_attention_wasm.d.ts
vendored
Normal file
359
crates/ruvector-attention-wasm/pkg/ruvector_attention_wasm.d.ts
vendored
Normal file
@@ -0,0 +1,359 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Adam optimizer
|
||||
*/
|
||||
export class WasmAdam {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new Adam optimizer
|
||||
*
|
||||
* # Arguments
|
||||
* * `param_count` - Number of parameters
|
||||
* * `learning_rate` - Learning rate
|
||||
*/
|
||||
constructor(param_count: number, learning_rate: number);
|
||||
/**
|
||||
* Reset optimizer state
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Perform optimization step
|
||||
*
|
||||
* # Arguments
|
||||
* * `params` - Current parameter values (will be updated in-place)
|
||||
* * `gradients` - Gradient values
|
||||
*/
|
||||
step(params: Float32Array, gradients: Float32Array): void;
|
||||
/**
|
||||
* Get current learning rate
|
||||
*/
|
||||
learning_rate: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* AdamW optimizer (Adam with decoupled weight decay)
|
||||
*/
|
||||
export class WasmAdamW {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new AdamW optimizer
|
||||
*
|
||||
* # Arguments
|
||||
* * `param_count` - Number of parameters
|
||||
* * `learning_rate` - Learning rate
|
||||
* * `weight_decay` - Weight decay coefficient
|
||||
*/
|
||||
constructor(param_count: number, learning_rate: number, weight_decay: number);
|
||||
/**
|
||||
* Reset optimizer state
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Perform optimization step with weight decay
|
||||
*/
|
||||
step(params: Float32Array, gradients: Float32Array): void;
|
||||
/**
|
||||
* Get current learning rate
|
||||
*/
|
||||
learning_rate: number;
|
||||
/**
|
||||
* Get weight decay
|
||||
*/
|
||||
readonly weight_decay: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Flash attention mechanism
|
||||
*/
|
||||
export class WasmFlashAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute flash attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new flash attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `block_size` - Block size for tiling
|
||||
*/
|
||||
constructor(dim: number, block_size: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Hyperbolic attention mechanism
|
||||
*/
|
||||
export class WasmHyperbolicAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute hyperbolic attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new hyperbolic attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `curvature` - Hyperbolic curvature parameter
|
||||
*/
|
||||
constructor(dim: number, curvature: number);
|
||||
/**
|
||||
* Get the curvature
|
||||
*/
|
||||
readonly curvature: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* InfoNCE contrastive loss for training
|
||||
*/
|
||||
export class WasmInfoNCELoss {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute InfoNCE loss
|
||||
*
|
||||
* # Arguments
|
||||
* * `anchor` - Anchor embedding
|
||||
* * `positive` - Positive example embedding
|
||||
* * `negatives` - Array of negative example embeddings
|
||||
*/
|
||||
compute(anchor: Float32Array, positive: Float32Array, negatives: any): number;
|
||||
/**
|
||||
* Create a new InfoNCE loss instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `temperature` - Temperature parameter for softmax
|
||||
*/
|
||||
constructor(temperature: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Learning rate scheduler
|
||||
*/
|
||||
export class WasmLRScheduler {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Get learning rate for current step
|
||||
*/
|
||||
get_lr(): number;
|
||||
/**
|
||||
* Create a new learning rate scheduler with warmup and cosine decay
|
||||
*
|
||||
* # Arguments
|
||||
* * `initial_lr` - Initial learning rate
|
||||
* * `warmup_steps` - Number of warmup steps
|
||||
* * `total_steps` - Total training steps
|
||||
*/
|
||||
constructor(initial_lr: number, warmup_steps: number, total_steps: number);
|
||||
/**
|
||||
* Reset scheduler
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Advance to next step
|
||||
*/
|
||||
step(): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Linear attention (Performer-style)
|
||||
*/
|
||||
export class WasmLinearAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute linear attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new linear attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_features` - Number of random features
|
||||
*/
|
||||
constructor(dim: number, num_features: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Local-global attention mechanism
|
||||
*/
|
||||
export class WasmLocalGlobalAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute local-global attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new local-global attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `local_window` - Size of local attention window
|
||||
* * `global_tokens` - Number of global attention tokens
|
||||
*/
|
||||
constructor(dim: number, local_window: number, global_tokens: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Mixture of Experts (MoE) attention
|
||||
*/
|
||||
export class WasmMoEAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute MoE attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new MoE attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_experts` - Number of expert attention mechanisms
|
||||
* * `top_k` - Number of experts to use per query
|
||||
*/
|
||||
constructor(dim: number, num_experts: number, top_k: number);
|
||||
}
|
||||
|
||||
/**
|
||||
* Multi-head attention mechanism
|
||||
*/
|
||||
export class WasmMultiHeadAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Compute multi-head attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Create a new multi-head attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_heads` - Number of attention heads
|
||||
*/
|
||||
constructor(dim: number, num_heads: number);
|
||||
/**
|
||||
* Get the dimension
|
||||
*/
|
||||
readonly dim: number;
|
||||
/**
|
||||
* Get the number of heads
|
||||
*/
|
||||
readonly num_heads: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* SGD optimizer with momentum
|
||||
*/
|
||||
export class WasmSGD {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new SGD optimizer
|
||||
*
|
||||
* # Arguments
|
||||
* * `param_count` - Number of parameters
|
||||
* * `learning_rate` - Learning rate
|
||||
* * `momentum` - Momentum coefficient (default: 0)
|
||||
*/
|
||||
constructor(param_count: number, learning_rate: number, momentum?: number | null);
|
||||
/**
|
||||
* Reset optimizer state
|
||||
*/
|
||||
reset(): void;
|
||||
/**
|
||||
* Perform optimization step
|
||||
*/
|
||||
step(params: Float32Array, gradients: Float32Array): void;
|
||||
/**
|
||||
* Get current learning rate
|
||||
*/
|
||||
learning_rate: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute attention weights from scores
|
||||
*/
|
||||
export function attention_weights(scores: Float32Array, temperature?: number | null): void;
|
||||
|
||||
/**
|
||||
* Get information about available attention mechanisms
|
||||
*/
|
||||
export function available_mechanisms(): any;
|
||||
|
||||
/**
|
||||
* Batch normalize vectors
|
||||
*/
|
||||
export function batch_normalize(vectors: any, epsilon?: number | null): Float32Array;
|
||||
|
||||
/**
|
||||
* Compute cosine similarity between two vectors
|
||||
*/
|
||||
export function cosine_similarity(a: Float32Array, b: Float32Array): number;
|
||||
|
||||
/**
|
||||
* Initialize the WASM module with panic hook
|
||||
*/
|
||||
export function init(): void;
|
||||
|
||||
/**
|
||||
* Compute L2 norm of a vector
|
||||
*/
|
||||
export function l2_norm(vec: Float32Array): number;
|
||||
|
||||
/**
|
||||
* Log a message to the browser console
|
||||
*/
|
||||
export function log(message: string): void;
|
||||
|
||||
/**
|
||||
* Log an error to the browser console
|
||||
*/
|
||||
export function log_error(message: string): void;
|
||||
|
||||
/**
|
||||
* Normalize a vector to unit length
|
||||
*/
|
||||
export function normalize(vec: Float32Array): void;
|
||||
|
||||
/**
|
||||
* Compute pairwise distances between vectors
|
||||
*/
|
||||
export function pairwise_distances(vectors: any): Float32Array;
|
||||
|
||||
/**
|
||||
* Generate random orthogonal matrix (for initialization)
|
||||
*/
|
||||
export function random_orthogonal_matrix(dim: number): Float32Array;
|
||||
|
||||
/**
|
||||
* Compute scaled dot-product attention
|
||||
*
|
||||
* # Arguments
|
||||
* * `query` - Query vector as Float32Array
|
||||
* * `keys` - Array of key vectors
|
||||
* * `values` - Array of value vectors
|
||||
* * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
*/
|
||||
export function scaled_dot_attention(query: Float32Array, keys: any, values: any, scale?: number | null): Float32Array;
|
||||
|
||||
/**
|
||||
* Compute softmax of a vector
|
||||
*/
|
||||
export function softmax(vec: Float32Array): void;
|
||||
|
||||
/**
|
||||
* Get the version of the ruvector-attention-wasm crate
|
||||
*/
|
||||
export function version(): string;
|
||||
1417
crates/ruvector-attention-wasm/pkg/ruvector_attention_wasm.js
Normal file
1417
crates/ruvector-attention-wasm/pkg/ruvector_attention_wasm.js
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
71
crates/ruvector-attention-wasm/pkg/ruvector_attention_wasm_bg.wasm.d.ts
vendored
Normal file
71
crates/ruvector-attention-wasm/pkg/ruvector_attention_wasm_bg.wasm.d.ts
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
export const memory: WebAssembly.Memory;
|
||||
export const __wbg_wasmadam_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmadamw_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmflashattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmhyperbolicattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasminfonceloss_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmlinearattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmmoeattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmmultiheadattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmsgd_free: (a: number, b: number) => void;
|
||||
export const attention_weights: (a: number, b: number, c: number, d: number) => void;
|
||||
export const available_mechanisms: () => number;
|
||||
export const batch_normalize: (a: number, b: number, c: number) => void;
|
||||
export const cosine_similarity: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const l2_norm: (a: number, b: number) => number;
|
||||
export const log: (a: number, b: number) => void;
|
||||
export const log_error: (a: number, b: number) => void;
|
||||
export const normalize: (a: number, b: number, c: number, d: number) => void;
|
||||
export const pairwise_distances: (a: number, b: number) => void;
|
||||
export const random_orthogonal_matrix: (a: number, b: number) => void;
|
||||
export const scaled_dot_attention: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const softmax: (a: number, b: number, c: number) => void;
|
||||
export const version: (a: number) => void;
|
||||
export const wasmadam_learning_rate: (a: number) => number;
|
||||
export const wasmadam_new: (a: number, b: number) => number;
|
||||
export const wasmadam_reset: (a: number) => void;
|
||||
export const wasmadam_set_learning_rate: (a: number, b: number) => void;
|
||||
export const wasmadam_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmadamw_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmadamw_reset: (a: number) => void;
|
||||
export const wasmadamw_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmadamw_weight_decay: (a: number) => number;
|
||||
export const wasmflashattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmflashattention_new: (a: number, b: number) => number;
|
||||
export const wasmhyperbolicattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmhyperbolicattention_curvature: (a: number) => number;
|
||||
export const wasmhyperbolicattention_new: (a: number, b: number) => number;
|
||||
export const wasminfonceloss_compute: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
|
||||
export const wasminfonceloss_new: (a: number) => number;
|
||||
export const wasmlinearattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmlinearattention_new: (a: number, b: number) => number;
|
||||
export const wasmlocalglobalattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmlocalglobalattention_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmlrscheduler_get_lr: (a: number) => number;
|
||||
export const wasmlrscheduler_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmlrscheduler_reset: (a: number) => void;
|
||||
export const wasmlrscheduler_step: (a: number) => void;
|
||||
export const wasmmoeattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmmoeattention_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmmultiheadattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmmultiheadattention_dim: (a: number) => number;
|
||||
export const wasmmultiheadattention_new: (a: number, b: number, c: number) => void;
|
||||
export const wasmmultiheadattention_num_heads: (a: number) => number;
|
||||
export const wasmsgd_learning_rate: (a: number) => number;
|
||||
export const wasmsgd_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmsgd_reset: (a: number) => void;
|
||||
export const wasmsgd_set_learning_rate: (a: number, b: number) => void;
|
||||
export const wasmsgd_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const init: () => void;
|
||||
export const wasmadamw_set_learning_rate: (a: number, b: number) => void;
|
||||
export const wasmadamw_learning_rate: (a: number) => number;
|
||||
export const __wbg_wasmlocalglobalattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmlrscheduler_free: (a: number, b: number) => void;
|
||||
export const __wbindgen_export: (a: number, b: number) => number;
|
||||
export const __wbindgen_export2: (a: number, b: number, c: number, d: number) => number;
|
||||
export const __wbindgen_export3: (a: number) => void;
|
||||
export const __wbindgen_export4: (a: number, b: number, c: number) => void;
|
||||
export const __wbindgen_add_to_stack_pointer: (a: number) => number;
|
||||
export const __wbindgen_start: () => void;
|
||||
308
crates/ruvector-attention-wasm/src/attention.rs
Normal file
308
crates/ruvector-attention-wasm/src/attention.rs
Normal file
@@ -0,0 +1,308 @@
|
||||
use ruvector_attention::{
|
||||
attention::{MultiHeadAttention, ScaledDotProductAttention},
|
||||
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
|
||||
moe::{MoEAttention, MoEConfig},
|
||||
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
|
||||
traits::Attention,
|
||||
};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Compute scaled dot-product attention
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector as Float32Array
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
#[wasm_bindgen]
|
||||
pub fn scaled_dot_attention(
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
scale: Option<f32>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let attention = ScaledDotProductAttention::new(query.len());
|
||||
attention
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Multi-head attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMultiHeadAttention {
|
||||
inner: MultiHeadAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMultiHeadAttention {
|
||||
/// Create a new multi-head attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_heads` - Number of attention heads
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
|
||||
if dim % num_heads != 0 {
|
||||
return Err(JsError::new(&format!(
|
||||
"Dimension {} must be divisible by number of heads {}",
|
||||
dim, num_heads
|
||||
)));
|
||||
}
|
||||
Ok(Self {
|
||||
inner: MultiHeadAttention::new(dim, num_heads),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute multi-head attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the number of heads
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.inner.num_heads()
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dim(&self) -> usize {
|
||||
self.inner.dim()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperbolic attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmHyperbolicAttention {
|
||||
inner: HyperbolicAttention,
|
||||
curvature_value: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmHyperbolicAttention {
|
||||
/// Create a new hyperbolic attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `curvature` - Hyperbolic curvature parameter
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim,
|
||||
curvature,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
inner: HyperbolicAttention::new(config),
|
||||
curvature_value: curvature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hyperbolic attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the curvature
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn curvature(&self) -> f32 {
|
||||
self.curvature_value
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear attention (Performer-style)
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLinearAttention {
|
||||
inner: LinearAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLinearAttention {
|
||||
/// Create a new linear attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_features` - Number of random features
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
|
||||
Self {
|
||||
inner: LinearAttention::new(dim, num_features),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute linear attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Flash attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmFlashAttention {
|
||||
inner: FlashAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmFlashAttention {
|
||||
/// Create a new flash attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `block_size` - Block size for tiling
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
|
||||
Self {
|
||||
inner: FlashAttention::new(dim, block_size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute flash attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Local-global attention mechanism
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLocalGlobalAttention {
|
||||
inner: LocalGlobalAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLocalGlobalAttention {
|
||||
/// Create a new local-global attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `local_window` - Size of local attention window
|
||||
/// * `global_tokens` - Number of global attention tokens
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
|
||||
Self {
|
||||
inner: LocalGlobalAttention::new(dim, local_window, global_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute local-global attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Mixture of Experts (MoE) attention
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMoEAttention {
|
||||
inner: MoEAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMoEAttention {
|
||||
/// Create a new MoE attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_experts` - Number of expert attention mechanisms
|
||||
/// * `top_k` - Number of experts to use per query
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
|
||||
let config = MoEConfig::builder()
|
||||
.dim(dim)
|
||||
.num_experts(num_experts)
|
||||
.top_k(top_k)
|
||||
.build();
|
||||
Self {
|
||||
inner: MoEAttention::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute MoE attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
33
crates/ruvector-attention-wasm/src/lib.rs
Normal file
33
crates/ruvector-attention-wasm/src/lib.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub mod attention;
|
||||
pub mod training;
|
||||
pub mod utils;
|
||||
|
||||
/// Initialize the WASM module with panic hook.
///
/// Runs automatically on module instantiation (`#[wasm_bindgen(start)]`).
/// When the optional `console_error_panic_hook` feature is enabled, installs
/// a hook so Rust panics are reported to the JS console instead of being an
/// opaque `unreachable` trap.
#[wasm_bindgen(start)]
pub fn init() {
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
|
||||
|
||||
/// Get the version of the ruvector-attention-wasm crate
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Get information about available attention mechanisms
|
||||
#[wasm_bindgen]
|
||||
pub fn available_mechanisms() -> JsValue {
|
||||
let mechanisms = vec![
|
||||
"scaled_dot_product",
|
||||
"multi_head",
|
||||
"hyperbolic",
|
||||
"linear",
|
||||
"flash",
|
||||
"local_global",
|
||||
"moe",
|
||||
];
|
||||
serde_wasm_bindgen::to_value(&mechanisms).unwrap()
|
||||
}
|
||||
238
crates/ruvector-attention-wasm/src/training.rs
Normal file
238
crates/ruvector-attention-wasm/src/training.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use ruvector_attention::training::{Adam, AdamW, InfoNCELoss, Loss, Optimizer, SGD};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// InfoNCE contrastive loss for training.
///
/// JS-facing wrapper around the native `InfoNCELoss` from
/// `ruvector-attention::training`.
#[wasm_bindgen]
pub struct WasmInfoNCELoss {
    // Native loss implementation this binding delegates to.
    inner: InfoNCELoss,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmInfoNCELoss {
|
||||
/// Create a new InfoNCE loss instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `temperature` - Temperature parameter for softmax
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(temperature: f32) -> WasmInfoNCELoss {
|
||||
Self {
|
||||
inner: InfoNCELoss::new(temperature),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute InfoNCE loss
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor` - Anchor embedding
|
||||
/// * `positive` - Positive example embedding
|
||||
/// * `negatives` - Array of negative example embeddings
|
||||
pub fn compute(
|
||||
&self,
|
||||
anchor: &[f32],
|
||||
positive: &[f32],
|
||||
negatives: JsValue,
|
||||
) -> Result<f32, JsError> {
|
||||
let negatives_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(negatives)?;
|
||||
let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
|
||||
|
||||
Ok(self.inner.compute(anchor, positive, &negatives_refs))
|
||||
}
|
||||
}
|
||||
|
||||
/// Adam optimizer.
///
/// JS-facing wrapper around the native `Adam` from
/// `ruvector-attention::training`.
#[wasm_bindgen]
pub struct WasmAdam {
    // Native optimizer this binding delegates to.
    inner: Adam,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmAdam {
|
||||
/// Create a new Adam optimizer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_count` - Number of parameters
|
||||
/// * `learning_rate` - Learning rate
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(param_count: usize, learning_rate: f32) -> WasmAdam {
|
||||
Self {
|
||||
inner: Adam::new(param_count, learning_rate),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform optimization step
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `params` - Current parameter values (will be updated in-place)
|
||||
/// * `gradients` - Gradient values
|
||||
pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
|
||||
self.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
/// Reset optimizer state
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
/// Get current learning rate
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn learning_rate(&self) -> f32 {
|
||||
self.inner.learning_rate()
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[wasm_bindgen(setter)]
|
||||
pub fn set_learning_rate(&mut self, lr: f32) {
|
||||
self.inner.set_learning_rate(lr);
|
||||
}
|
||||
}
|
||||
|
||||
/// AdamW optimizer (Adam with decoupled weight decay).
///
/// JS-facing wrapper around the native `AdamW` from
/// `ruvector-attention::training`.
#[wasm_bindgen]
pub struct WasmAdamW {
    // Native optimizer this binding delegates to.
    inner: AdamW,
    // Weight-decay coefficient, cached here because the getter below
    // reports it (the inner type's accessor is not used for this).
    wd: f32,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmAdamW {
|
||||
/// Create a new AdamW optimizer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_count` - Number of parameters
|
||||
/// * `learning_rate` - Learning rate
|
||||
/// * `weight_decay` - Weight decay coefficient
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(param_count: usize, learning_rate: f32, weight_decay: f32) -> WasmAdamW {
|
||||
let optimizer = AdamW::new(param_count, learning_rate).with_weight_decay(weight_decay);
|
||||
Self {
|
||||
inner: optimizer,
|
||||
wd: weight_decay,
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform optimization step with weight decay
|
||||
pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
|
||||
self.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
/// Reset optimizer state
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
/// Get current learning rate
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn learning_rate(&self) -> f32 {
|
||||
self.inner.learning_rate()
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[wasm_bindgen(setter)]
|
||||
pub fn set_learning_rate(&mut self, lr: f32) {
|
||||
self.inner.set_learning_rate(lr);
|
||||
}
|
||||
|
||||
/// Get weight decay
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn weight_decay(&self) -> f32 {
|
||||
self.wd
|
||||
}
|
||||
}
|
||||
|
||||
/// SGD optimizer with momentum.
///
/// JS-facing wrapper around the native `SGD` from
/// `ruvector-attention::training`.
#[wasm_bindgen]
pub struct WasmSGD {
    // Native optimizer this binding delegates to.
    inner: SGD,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmSGD {
|
||||
/// Create a new SGD optimizer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_count` - Number of parameters
|
||||
/// * `learning_rate` - Learning rate
|
||||
/// * `momentum` - Momentum coefficient (default: 0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(param_count: usize, learning_rate: f32, momentum: Option<f32>) -> WasmSGD {
|
||||
let mut optimizer = SGD::new(param_count, learning_rate);
|
||||
if let Some(m) = momentum {
|
||||
optimizer = optimizer.with_momentum(m);
|
||||
}
|
||||
Self { inner: optimizer }
|
||||
}
|
||||
|
||||
/// Perform optimization step
|
||||
pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
|
||||
self.inner.step(params, gradients);
|
||||
}
|
||||
|
||||
/// Reset optimizer state
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
/// Get current learning rate
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn learning_rate(&self) -> f32 {
|
||||
self.inner.learning_rate()
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[wasm_bindgen(setter)]
|
||||
pub fn set_learning_rate(&mut self, lr: f32) {
|
||||
self.inner.set_learning_rate(lr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Learning rate scheduler: linear warmup followed by cosine decay.
#[wasm_bindgen]
pub struct WasmLRScheduler {
    // Peak learning rate, reached at the end of warmup.
    initial_lr: f32,
    // Steps taken so far; advanced by `step`, cleared by `reset`.
    current_step: usize,
    // Number of linear warmup steps.
    warmup_steps: usize,
    // Total steps; the cosine decay phase spans (warmup_steps, total_steps].
    total_steps: usize,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLRScheduler {
|
||||
/// Create a new learning rate scheduler with warmup and cosine decay
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `initial_lr` - Initial learning rate
|
||||
/// * `warmup_steps` - Number of warmup steps
|
||||
/// * `total_steps` - Total training steps
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(initial_lr: f32, warmup_steps: usize, total_steps: usize) -> WasmLRScheduler {
|
||||
Self {
|
||||
initial_lr,
|
||||
current_step: 0,
|
||||
warmup_steps,
|
||||
total_steps,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get learning rate for current step
|
||||
pub fn get_lr(&self) -> f32 {
|
||||
if self.current_step < self.warmup_steps {
|
||||
// Linear warmup
|
||||
self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
|
||||
} else {
|
||||
// Cosine decay
|
||||
let progress = (self.current_step - self.warmup_steps) as f32
|
||||
/ (self.total_steps - self.warmup_steps) as f32;
|
||||
let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
|
||||
self.initial_lr * cosine
|
||||
}
|
||||
}
|
||||
|
||||
/// Advance to next step
|
||||
pub fn step(&mut self) {
|
||||
self.current_step += 1;
|
||||
}
|
||||
|
||||
/// Reset scheduler
|
||||
pub fn reset(&mut self) {
|
||||
self.current_step = 0;
|
||||
}
|
||||
}
|
||||
201
crates/ruvector-attention-wasm/src/utils.rs
Normal file
201
crates/ruvector-attention-wasm/src/utils.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::console;
|
||||
|
||||
/// Log a message to the browser console (`console.log`).
#[wasm_bindgen]
pub fn log(message: &str) {
    console::log_1(&message.into());
}
|
||||
|
||||
/// Log an error to the browser console (`console.error`).
#[wasm_bindgen]
pub fn log_error(message: &str) {
    console::error_1(&message.into());
}
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, JsError> {
|
||||
if a.len() != b.len() {
|
||||
return Err(JsError::new("Vectors must have same length"));
|
||||
}
|
||||
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
return Err(JsError::new("Cannot compute similarity for zero vector"));
|
||||
}
|
||||
|
||||
Ok(dot / (norm_a * norm_b))
|
||||
}
|
||||
|
||||
/// Compute L2 norm of a vector
|
||||
#[wasm_bindgen]
|
||||
pub fn l2_norm(vec: &[f32]) -> f32 {
|
||||
vec.iter().map(|x| x * x).sum::<f32>().sqrt()
|
||||
}
|
||||
|
||||
/// Normalize a vector to unit length
|
||||
#[wasm_bindgen]
|
||||
pub fn normalize(vec: &mut [f32]) -> Result<(), JsError> {
|
||||
let norm = l2_norm(vec);
|
||||
if norm == 0.0 {
|
||||
return Err(JsError::new("Cannot normalize zero vector"));
|
||||
}
|
||||
|
||||
for x in vec.iter_mut() {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute softmax of a vector
|
||||
#[wasm_bindgen]
|
||||
pub fn softmax(vec: &mut [f32]) {
|
||||
// Subtract max for numerical stability
|
||||
let max = vec.iter().copied().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
// Compute exp and sum
|
||||
let mut sum = 0.0;
|
||||
for x in vec.iter_mut() {
|
||||
*x = (*x - max).exp();
|
||||
sum += *x;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for x in vec.iter_mut() {
|
||||
*x /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention weights from scores
|
||||
#[wasm_bindgen]
|
||||
pub fn attention_weights(scores: &mut [f32], temperature: Option<f32>) {
|
||||
let temp = temperature.unwrap_or(1.0);
|
||||
|
||||
// Scale by temperature
|
||||
for score in scores.iter_mut() {
|
||||
*score /= temp;
|
||||
}
|
||||
|
||||
// Apply softmax
|
||||
softmax(scores);
|
||||
}
|
||||
|
||||
/// Batch normalize vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn batch_normalize(vectors: JsValue, epsilon: Option<f32>) -> Result<Vec<f32>, JsError> {
|
||||
let eps = epsilon.unwrap_or(1e-8);
|
||||
let mut vecs: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(vectors)?;
|
||||
|
||||
if vecs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let dim = vecs[0].len();
|
||||
let batch_size = vecs.len();
|
||||
|
||||
// Compute mean
|
||||
let mut mean = vec![0.0; dim];
|
||||
for vec in &vecs {
|
||||
for (i, &val) in vec.iter().enumerate() {
|
||||
mean[i] += val;
|
||||
}
|
||||
}
|
||||
for m in &mut mean {
|
||||
*m /= batch_size as f32;
|
||||
}
|
||||
|
||||
// Compute variance
|
||||
let mut variance = vec![0.0; dim];
|
||||
for vec in &vecs {
|
||||
for (i, &val) in vec.iter().enumerate() {
|
||||
let diff = val - mean[i];
|
||||
variance[i] += diff * diff;
|
||||
}
|
||||
}
|
||||
for v in &mut variance {
|
||||
*v /= batch_size as f32;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for vec in &mut vecs {
|
||||
for (i, val) in vec.iter_mut().enumerate() {
|
||||
*val = (*val - mean[i]) / (variance[i] + eps).sqrt();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vecs.into_iter().flatten().collect())
|
||||
}
|
||||
|
||||
/// Generate random orthogonal matrix (for initialization).
///
/// Returns a row-major `dim x dim` matrix whose columns are orthonormal,
/// produced by filling the matrix with uniform random values in (-1, 1)
/// and running a modified Gram-Schmidt pass over its columns.
///
/// NOTE(review): randomness comes from `js_sys::Math::random`, so results
/// are non-deterministic and this only works in a JS host environment.
/// NOTE(review): if a column's norm underflows to 0 (vanishingly unlikely
/// for random input, but possible), the division below produces NaN —
/// there is no re-draw or guard; confirm acceptable for callers.
#[wasm_bindgen]
pub fn random_orthogonal_matrix(dim: usize) -> Vec<f32> {
    use js_sys::Math;

    // Row-major storage: element (row j, col i) lives at j * dim + i.
    let mut matrix = vec![0.0; dim * dim];

    // Generate random matrix with entries uniform in (-1, 1).
    for i in 0..dim {
        for j in 0..dim {
            matrix[i * dim + j] = (Math::random() as f32 - 0.5) * 2.0;
        }
    }

    // QR decomposition (simplified Gram-Schmidt), column by column.
    for i in 0..dim {
        // Normalize column i to unit length.
        let mut norm = 0.0;
        for j in 0..dim {
            let val = matrix[j * dim + i];
            norm += val * val;
        }
        norm = norm.sqrt();

        for j in 0..dim {
            matrix[j * dim + i] /= norm;
        }

        // Orthogonalize remaining columns against the (now unit) column i
        // by subtracting their projection onto it.
        for k in (i + 1)..dim {
            let mut dot = 0.0;
            for j in 0..dim {
                dot += matrix[j * dim + i] * matrix[j * dim + k];
            }

            for j in 0..dim {
                matrix[j * dim + k] -= dot * matrix[j * dim + i];
            }
        }
    }

    matrix
}
|
||||
|
||||
/// Compute pairwise distances between vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn pairwise_distances(vectors: JsValue) -> Result<Vec<f32>, JsError> {
|
||||
let vecs: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(vectors)?;
|
||||
let n = vecs.len();
|
||||
let mut distances = vec![0.0; n * n];
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if i == j {
|
||||
distances[i * n + j] = 0.0;
|
||||
} else {
|
||||
let mut dist = 0.0;
|
||||
for k in 0..vecs[i].len() {
|
||||
let diff = vecs[i][k] - vecs[j][k];
|
||||
dist += diff * diff;
|
||||
}
|
||||
distances[i * n + j] = dist.sqrt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(distances)
|
||||
}
|
||||
185
crates/ruvector-attention-wasm/tests/web.rs
Normal file
185
crates/ruvector-attention-wasm/tests/web.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
//! Test suite for WASM bindings
|
||||
//! Run with: wasm-pack test --headless --firefox
|
||||
|
||||
#![cfg(target_arch = "wasm32")]
|
||||
|
||||
use ruvector_attention_wasm::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_version() {
|
||||
let ver = version();
|
||||
assert!(!ver.is_empty());
|
||||
assert_eq!(ver, env!("CARGO_PKG_VERSION"));
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_available_mechanisms() {
|
||||
let mechanisms = available_mechanisms();
|
||||
assert!(mechanisms.is_array());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
fn test_multi_head_attention() {
    // NOTE(review): constructor is fallible per this call site; the binding
    // itself is defined elsewhere in attention.rs.
    let mha = attention::WasmMultiHeadAttention::new(64, 8).unwrap();

    // 64 total embedding dims split across 8 heads.
    assert_eq!(mha.dim(), 64);
    assert_eq!(mha.num_heads(), 8);
}
|
||||
|
||||
#[wasm_bindgen_test]
fn test_hyperbolic_attention() {
    // Construction should preserve the requested curvature parameter.
    let ha = attention::WasmHyperbolicAttention::new(64, 1.0);
    assert_eq!(ha.curvature(), 1.0);
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_linear_attention() {
|
||||
let la = attention::WasmLinearAttention::new(64, 256);
|
||||
// Linear attention doesn't expose internal state
|
||||
// Just verify it can be created
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_flash_attention() {
|
||||
let fa = attention::WasmFlashAttention::new(64, 16);
|
||||
// Flash attention doesn't expose internal state
|
||||
// Just verify it can be created
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_local_global_attention() {
|
||||
let lga = attention::WasmLocalGlobalAttention::new(64, 8, 4);
|
||||
// Local-global attention doesn't expose internal state
|
||||
// Just verify it can be created
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_moe_attention() {
|
||||
let moe = attention::WasmMoEAttention::new(64, 4, 2).unwrap();
|
||||
|
||||
// Get expert statistics
|
||||
let stats = moe.expert_stats();
|
||||
assert!(stats.is_object());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_info_nce_loss() {
|
||||
let loss = training::WasmInfoNCELoss::new(0.07);
|
||||
// InfoNCE loss doesn't expose temperature
|
||||
// Just verify it can be created
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_adam_optimizer() {
|
||||
let mut adam = training::WasmAdam::new(100, 0.001, Some(0.9), Some(0.999), Some(1e-8));
|
||||
|
||||
assert_eq!(adam.learning_rate(), 0.001);
|
||||
|
||||
adam.set_learning_rate(0.0001);
|
||||
assert_eq!(adam.learning_rate(), 0.0001);
|
||||
|
||||
adam.reset();
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_adamw_optimizer() {
|
||||
let mut adamw = training::WasmAdamW::new(100, 0.001, 0.01, Some(0.9), Some(0.999), Some(1e-8));
|
||||
|
||||
assert_eq!(adamw.learning_rate(), 0.001);
|
||||
assert_eq!(adamw.weight_decay(), 0.01);
|
||||
|
||||
adamw.reset();
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
fn test_lr_scheduler() {
    // 100 warmup steps, 1000 total steps, peak lr 0.001.
    let mut scheduler = training::WasmLRScheduler::new(0.001, 100, 1000);

    // At step 0, should be near 0 (linear warmup starts from 0).
    let lr0 = scheduler.get_lr();
    assert!(lr0 < 0.001);

    // After warmup, should be at the initial (peak) LR — cosine decay
    // progress is 0 at exactly warmup_steps.
    for _ in 0..100 {
        scheduler.step();
    }
    let lr100 = scheduler.get_lr();
    assert!((lr100 - 0.001).abs() < 1e-6);

    // Reset returns the scheduler to step 0 and hence the warmup rate.
    scheduler.reset();
    assert_eq!(scheduler.get_lr(), lr0);
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_cosine_similarity() {
|
||||
let a = vec![1.0, 0.0, 0.0];
|
||||
let b = vec![1.0, 0.0, 0.0];
|
||||
|
||||
let sim = utils::cosine_similarity(&a, &b).unwrap();
|
||||
assert!((sim - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_l2_norm() {
|
||||
let vec = vec![3.0, 4.0];
|
||||
let norm = utils::l2_norm(&vec);
|
||||
assert!((norm - 5.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_normalize() {
|
||||
let mut vec = vec![3.0, 4.0];
|
||||
utils::normalize(&mut vec).unwrap();
|
||||
|
||||
let norm = utils::l2_norm(&vec);
|
||||
assert!((norm - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_softmax() {
|
||||
let mut vec = vec![1.0, 2.0, 3.0];
|
||||
utils::softmax(&mut vec);
|
||||
|
||||
let sum: f32 = vec.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 1e-6);
|
||||
|
||||
// Check monotonicity
|
||||
assert!(vec[0] < vec[1]);
|
||||
assert!(vec[1] < vec[2]);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_attention_weights() {
|
||||
let mut scores = vec![1.0, 2.0, 3.0];
|
||||
utils::attention_weights(&mut scores, Some(1.0));
|
||||
|
||||
let sum: f32 = scores.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
fn test_random_orthogonal_matrix() {
    let dim = 4;
    let matrix = utils::random_orthogonal_matrix(dim);

    assert_eq!(matrix.len(), dim * dim);

    // Check orthogonality: Q^T * Q = I, i.e. every pair of columns has a
    // dot product of 1 (same column) or 0 (different columns), within f32
    // tolerance accumulated by the Gram-Schmidt construction.
    for i in 0..dim {
        for j in 0..dim {
            let mut dot = 0.0;
            for k in 0..dim {
                dot += matrix[k * dim + i] * matrix[k * dim + j];
            }

            if i == j {
                assert!((dot - 1.0).abs() < 1e-4, "Diagonal should be 1");
            } else {
                assert!(dot.abs() < 1e-4, "Off-diagonal should be 0");
            }
        }
    }
}
|
||||
25
crates/ruvector-attention-wasm/tsconfig.json
Normal file
25
crates/ruvector-attention-wasm/tsconfig.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ES2020",
|
||||
"lib": ["ES2020", "DOM"],
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./js",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"moduleResolution": "node",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"noImplicitReturns": true,
|
||||
"noFallthroughCasesInSwitch": true
|
||||
},
|
||||
"include": ["js/**/*"],
|
||||
"exclude": ["node_modules", "pkg", "pkg-node", "pkg-bundler", "target"]
|
||||
}
|
||||
Reference in New Issue
Block a user