Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
commit d803bfe2b1
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
# WASM build artifacts
pkg/
pkg-node/
pkg-bundler/
# Dependencies
node_modules/
# Build artifacts
target/
Cargo.lock
# Editor files
.vscode/
.idea/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Test coverage
coverage/
*.profraw

View File

@@ -0,0 +1,40 @@
[package]
name = "ruvector-attention-wasm"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
description = "High-performance WebAssembly attention mechanisms: Multi-Head, Flash, Hyperbolic, MoE, CGT Sheaf Attention with GPU acceleration for transformers and LLMs"
homepage = "https://ruv.io/ruvector"
documentation = "https://docs.rs/ruvector-attention-wasm"
keywords = ["wasm", "attention", "transformer", "flash-attention", "llm"]
categories = ["wasm", "algorithms", "science"]
readme = "README.md"
[lib]
crate-type = ["cdylib", "rlib"]
[dependencies]
ruvector-attention = { version = "2.0", path = "../ruvector-attention", default-features = false, features = ["wasm"] }
wasm-bindgen = "0.2"
js-sys = "0.3"
web-sys = { version = "0.3", features = ["console"] }
serde = { version = "1.0", features = ["derive"] }
serde-wasm-bindgen = "0.6"
console_error_panic_hook = { version = "0.1", optional = true }
getrandom = { version = "0.2", features = ["js"] }
[dev-dependencies]
wasm-bindgen-test = "0.3"
[features]
default = ["console_error_panic_hook"]
[profile.release]
opt-level = "s"
lto = true
codegen-units = 1
[package.metadata.wasm-pack.profile.release]
wasm-opt = false

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 rUv
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,220 @@
# ruvector-attention-wasm
WebAssembly bindings for the ruvector-attention package, providing high-performance attention mechanisms for browser and Node.js environments.
## Features
- **Multiple Attention Mechanisms**:
- Scaled Dot-Product Attention
- Multi-Head Attention
- Hyperbolic Attention (for hierarchical data)
- Linear Attention (Performer-style)
- Flash Attention (memory-efficient)
- Local-Global Attention
- Mixture of Experts (MoE) Attention
- **CGT Sheaf Attention** (coherence-gated via Prime-Radiant)
- **Training Utilities**:
- InfoNCE contrastive loss
- Adam optimizer
- AdamW optimizer (with decoupled weight decay)
- Learning rate scheduler (warmup + cosine decay)
- **TypeScript Support**: Full type definitions and modern API
## Installation
```bash
npm install ruvector-attention-wasm
```
## Usage
### TypeScript/JavaScript
```typescript
import { initialize, MultiHeadAttention, utils } from 'ruvector-attention-wasm';
// Initialize WASM module
await initialize();
// Create multi-head attention
const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });
// Prepare inputs
const query = new Float32Array(64);
const keys = [new Float32Array(64), new Float32Array(64)];
const values = [new Float32Array(64), new Float32Array(64)];
// Compute attention
const output = attention.compute(query, keys, values);
// Use utilities
const similarity = utils.cosineSimilarity(query, keys[0]);
```
### Advanced Examples
#### Hyperbolic Attention
```typescript
import { HyperbolicAttention } from 'ruvector-attention-wasm';
const hyperbolic = new HyperbolicAttention({
dim: 128,
curvature: 1.0
});
const output = hyperbolic.compute(query, keys, values);
```
#### MoE Attention with Expert Stats
```typescript
import { MoEAttention } from 'ruvector-attention-wasm';
const moe = new MoEAttention({
dim: 64,
numExperts: 4,
topK: 2
});
const output = moe.compute(query, keys, values);
// Get expert utilization
const stats = moe.getExpertStats();
console.log('Load balance:', stats.loadBalance);
```
#### Training with InfoNCE Loss
```typescript
import { InfoNCELoss, Adam } from 'ruvector-attention-wasm';
const loss = new InfoNCELoss(0.07);
const optimizer = new Adam(paramCount, {
learningRate: 0.001,
beta1: 0.9,
beta2: 0.999,
});
// Training loop
const lossValue = loss.compute(anchor, positive, negatives);
optimizer.step(params, gradients);
```
#### Learning Rate Scheduling
```typescript
import { LRScheduler, AdamW } from 'ruvector-attention-wasm';
const scheduler = new LRScheduler({
initialLR: 0.001,
warmupSteps: 1000,
totalSteps: 10000,
});
const optimizer = new AdamW(paramCount, {
learningRate: scheduler.getLR(),
weightDecay: 0.01,
});
// Training loop
for (let step = 0; step < 10000; step++) {
optimizer.learningRate = scheduler.getLR();
optimizer.step(params, gradients);
scheduler.step();
}
```
## Building from Source
### Prerequisites
- Rust 1.70+
- wasm-pack
### Build Commands
```bash
# Build for web (ES modules)
wasm-pack build --target web --out-dir pkg
# Build for Node.js
wasm-pack build --target nodejs --out-dir pkg-node
# Build for bundlers (webpack, vite, etc.)
wasm-pack build --target bundler --out-dir pkg-bundler
# Run tests
wasm-pack test --headless --firefox
```
## API Reference
### Attention Mechanisms
- `MultiHeadAttention` - Standard multi-head attention
- `HyperbolicAttention` - Attention in hyperbolic space
- `LinearAttention` - Linear complexity attention (Performer)
- `FlashAttention` - Memory-efficient attention
- `LocalGlobalAttention` - Combined local and global attention
- `MoEAttention` - Mixture of Experts attention
- `CGTSheafAttention` - Coherence-gated via Prime-Radiant energy
- `scaledDotAttention()` - Functional API for basic attention
### CGT Sheaf Attention (Prime-Radiant Integration)
The CGT (Coherence-Gated Transformer) Sheaf Attention mechanism uses Prime-Radiant's sheaf Laplacian energy to gate attention based on mathematical consistency:
```typescript
import { CGTSheafAttention } from 'ruvector-attention-wasm';
const cgtAttention = new CGTSheafAttention({
dim: 128,
numHeads: 8,
coherenceThreshold: 0.3, // Block if energy > threshold
});
// Attention is gated by coherence energy
const result = cgtAttention.compute(query, keys, values);
console.log('Coherence energy:', result.energy);
console.log('Is coherent:', result.isCoherent);
```
**Key features:**
- Energy-weighted attention: Lower coherence energy → higher attention
- Automatic hallucination detection via residual analysis
- GPU-accelerated with wgpu WGSL shaders (vec4 optimized)
- SIMD fallback (AVX-512/AVX2/NEON)
### Training
- `InfoNCELoss` - Contrastive loss function
- `Adam` - Adam optimizer
- `AdamW` - AdamW optimizer with weight decay
- `LRScheduler` - Learning rate scheduler
### Utilities
- `utils.cosineSimilarity()` - Cosine similarity between vectors
- `utils.l2Norm()` - L2 norm of a vector
- `utils.normalize()` - Normalize vector to unit length
- `utils.softmax()` - Apply softmax transformation
- `utils.attentionWeights()` - Compute attention weights from scores
- `utils.batchNormalize()` - Batch normalization
- `utils.randomOrthogonalMatrix()` - Generate random orthogonal matrix
- `utils.pairwiseDistances()` - Compute pairwise distances
## Performance
The WASM bindings provide near-native performance for attention computations:
- Optimized with `opt-level = "s"` and LTO
- SIMD acceleration where available
- Efficient memory management
- Zero-copy data transfer where possible
## License
MIT OR Apache-2.0

View File

@@ -0,0 +1,412 @@
/**
* TypeScript wrapper for ruvector-attention-wasm
* Provides a clean, type-safe API for attention mechanisms
*/
import init, * as wasm from '../pkg/ruvector_attention_wasm';
import type {
AttentionConfig,
MultiHeadConfig,
HyperbolicConfig,
LinearAttentionConfig,
FlashAttentionConfig,
LocalGlobalConfig,
MoEConfig,
TrainingConfig,
SchedulerConfig,
ExpertStats,
AttentionType,
} from './types';
export * from './types';
// Caches the in-flight (or completed) WASM initialization so that concurrent
// callers share a single init() invocation instead of racing and loading the
// module twice. Cleared on failure so a later call can retry.
let initPromise: Promise<void> | null = null;

/**
 * Initialize the WASM module.
 * Must be called (and awaited) before using any attention mechanisms.
 *
 * Safe to call multiple times and from concurrent callers: all callers await
 * the same underlying initialization. If loading fails, the cached promise is
 * discarded so initialization can be retried.
 */
export async function initialize(): Promise<void> {
  if (initPromise === null) {
    initPromise = Promise.resolve(init()).then(
      () => undefined,
      (err) => {
        // Do not cache a failed load — allow the next call to retry.
        initPromise = null;
        throw err;
      }
    );
  }
  return initPromise;
}
/**
 * Report the version string of the underlying ruvector-attention-wasm package.
 */
export function version(): string {
  const packageVersion: string = wasm.version();
  return packageVersion;
}
/**
 * List the attention mechanism identifiers exposed by the WASM module.
 */
export function availableMechanisms(): AttentionType[] {
  const mechanisms = wasm.available_mechanisms();
  return mechanisms as AttentionType[];
}
/**
 * Multi-head attention mechanism.
 *
 * Thin wrapper over the WASM-backed implementation. Call `free()` when the
 * instance is no longer needed to release WASM-side memory.
 */
export class MultiHeadAttention {
  private native: wasm.WasmMultiHeadAttention;

  constructor(config: MultiHeadConfig) {
    const { dim, numHeads } = config;
    this.native = new wasm.WasmMultiHeadAttention(dim, numHeads);
  }

  /**
   * Compute multi-head attention for one query over the given keys/values.
   * Returns a fresh Float32Array copy of the WASM result.
   */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    return new Float32Array(this.native.compute(query, keys, values));
  }

  /** Number of attention heads. */
  get numHeads(): number {
    return this.native.num_heads;
  }

  /** Embedding dimension. */
  get dim(): number {
    return this.native.dim;
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * Hyperbolic attention mechanism (attention computed in hyperbolic space).
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class HyperbolicAttention {
  private native: wasm.WasmHyperbolicAttention;

  constructor(config: HyperbolicConfig) {
    const { dim, curvature } = config;
    this.native = new wasm.WasmHyperbolicAttention(dim, curvature);
  }

  /** Compute hyperbolic attention; returns a fresh copy of the result. */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    return new Float32Array(this.native.compute(query, keys, values));
  }

  /** The configured hyperbolic curvature. */
  get curvature(): number {
    return this.native.curvature;
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * Linear attention (Performer-style kernel approximation).
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class LinearAttention {
  private native: wasm.WasmLinearAttention;

  constructor(config: LinearAttentionConfig) {
    const { dim, numFeatures } = config;
    this.native = new wasm.WasmLinearAttention(dim, numFeatures);
  }

  /** Compute linear attention; returns a fresh copy of the result. */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    return new Float32Array(this.native.compute(query, keys, values));
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * Flash attention mechanism (block-tiled, memory-efficient).
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class FlashAttention {
  private native: wasm.WasmFlashAttention;

  constructor(config: FlashAttentionConfig) {
    const { dim, blockSize } = config;
    this.native = new wasm.WasmFlashAttention(dim, blockSize);
  }

  /** Compute flash attention; returns a fresh copy of the result. */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    return new Float32Array(this.native.compute(query, keys, values));
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * Local-global attention: a local sliding window combined with a set of
 * global tokens.
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class LocalGlobalAttention {
  private native: wasm.WasmLocalGlobalAttention;

  constructor(config: LocalGlobalConfig) {
    const { dim, localWindow, globalTokens } = config;
    this.native = new wasm.WasmLocalGlobalAttention(dim, localWindow, globalTokens);
  }

  /** Compute local-global attention; returns a fresh copy of the result. */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    return new Float32Array(this.native.compute(query, keys, values));
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * Mixture of Experts (MoE) attention: routes each query to the top-k of
 * several expert attention mechanisms.
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class MoEAttention {
  private native: wasm.WasmMoEAttention;

  constructor(config: MoEConfig) {
    const { dim, numExperts, topK } = config;
    this.native = new wasm.WasmMoEAttention(dim, numExperts, topK);
  }

  /** Compute MoE attention; returns a fresh copy of the result. */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    return new Float32Array(this.native.compute(query, keys, values));
  }

  /** Expert utilization statistics (selection counts, load balance, …). */
  getExpertStats(): ExpertStats {
    return this.native.expert_stats() as ExpertStats;
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * InfoNCE contrastive loss.
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class InfoNCELoss {
  private native: wasm.WasmInfoNCELoss;

  /** @param temperature Softmax temperature (defaults to 0.07). */
  constructor(temperature: number = 0.07) {
    this.native = new wasm.WasmInfoNCELoss(temperature);
  }

  /** Loss for one anchor, one positive, and a set of negatives. */
  compute(anchor: Float32Array, positive: Float32Array, negatives: Float32Array[]): number {
    return this.native.compute(anchor, positive, negatives);
  }

  /** Loss variant that accepts multiple positives per anchor. */
  computeMultiPositive(
    anchor: Float32Array,
    positives: Float32Array[],
    negatives: Float32Array[]
  ): number {
    return this.native.compute_multi_positive(anchor, positives, negatives);
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * Adam optimizer.
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class Adam {
  private native: wasm.WasmAdam;

  /**
   * @param paramCount Number of trainable parameters.
   * @param config Hyperparameters; `learningRate` is required, the betas and
   *               epsilon are optional.
   */
  constructor(paramCount: number, config: TrainingConfig) {
    const { learningRate, beta1, beta2, epsilon } = config;
    // NOTE(review): the generated pkg/*.d.ts declares WasmAdam with only
    // (param_count, learning_rate); passing the extra hyperparameters assumes
    // a matching binding build — confirm the wrapper and pkg are in sync.
    this.native = new wasm.WasmAdam(paramCount, learningRate, beta1, beta2, epsilon);
  }

  /** Apply one optimization step, updating `params` in place. */
  step(params: Float32Array, gradients: Float32Array): void {
    this.native.step(params, gradients);
  }

  /** Clear the optimizer's internal state. */
  reset(): void {
    this.native.reset();
  }

  /** Current learning rate. */
  get learningRate(): number {
    return this.native.learning_rate;
  }

  set learningRate(lr: number) {
    this.native.learning_rate = lr;
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * AdamW optimizer (Adam with decoupled weight decay).
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class AdamW {
  private inner: wasm.WasmAdamW;

  /**
   * @param paramCount Number of trainable parameters.
   * @param config Hyperparameters; `learningRate` and `weightDecay` are
   *               required, the betas and epsilon are optional.
   * @throws Error if `config.weightDecay` is undefined or null.
   */
  constructor(paramCount: number, config: TrainingConfig) {
    // Nullish check, not a falsy check: `weightDecay: 0` is a legitimate
    // value (no decay) and must not be rejected. The previous
    // `!config.weightDecay` incorrectly threw for 0.
    if (config.weightDecay == null) {
      throw new Error('AdamW requires weightDecay parameter');
    }
    this.inner = new wasm.WasmAdamW(
      paramCount,
      config.learningRate,
      config.weightDecay,
      config.beta1,
      config.beta2,
      config.epsilon
    );
  }

  /** Apply one optimization step (with weight decay), updating `params` in place. */
  step(params: Float32Array, gradients: Float32Array): void {
    this.inner.step(params, gradients);
  }

  /** Clear the optimizer's internal state. */
  reset(): void {
    this.inner.reset();
  }

  /** Current learning rate. */
  get learningRate(): number {
    return this.inner.learning_rate;
  }

  set learningRate(lr: number) {
    this.inner.learning_rate = lr;
  }

  /** Configured weight decay coefficient (read-only on the WASM side). */
  get weightDecay(): number {
    return this.inner.weight_decay;
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.inner.free();
  }
}
/**
 * Learning rate scheduler: linear warmup followed by cosine decay.
 *
 * Call `free()` when done to release WASM-side memory.
 */
export class LRScheduler {
  private native: wasm.WasmLRScheduler;

  constructor(config: SchedulerConfig) {
    const { initialLR, warmupSteps, totalSteps } = config;
    this.native = new wasm.WasmLRScheduler(initialLR, warmupSteps, totalSteps);
  }

  /** Learning rate for the current step. */
  getLR(): number {
    return this.native.get_lr();
  }

  /** Advance the schedule by one step. */
  step(): void {
    this.native.step();
  }

  /** Rewind the schedule back to step zero. */
  reset(): void {
    this.native.reset();
  }

  /** Release the underlying WASM resources. */
  free(): void {
    this.native.free();
  }
}
/**
 * Stateless math helpers backed by the WASM module.
 */
export const utils = {
  /** Cosine similarity between two vectors. */
  cosineSimilarity(a: Float32Array, b: Float32Array): number {
    return wasm.cosine_similarity(a, b);
  },

  /** Euclidean (L2) norm of a vector. */
  l2Norm(vec: Float32Array): number {
    return wasm.l2_norm(vec);
  },

  /** Scale a vector to unit length, modifying it in place. */
  normalize(vec: Float32Array): void {
    wasm.normalize(vec);
  },

  /** Apply the softmax transform to a vector in place. */
  softmax(vec: Float32Array): void {
    wasm.softmax(vec);
  },

  /** Turn raw scores into attention weights in place (optional temperature). */
  attentionWeights(scores: Float32Array, temperature?: number): void {
    wasm.attention_weights(scores, temperature);
  },

  /** Batch-normalize a set of vectors; returns a new array. */
  batchNormalize(vectors: Float32Array[], epsilon?: number): Float32Array {
    const normalized = wasm.batch_normalize(vectors, epsilon);
    return new Float32Array(normalized);
  },

  /** Generate a random orthogonal matrix for the given dimension. */
  randomOrthogonalMatrix(dim: number): Float32Array {
    const matrix = wasm.random_orthogonal_matrix(dim);
    return new Float32Array(matrix);
  },

  /** Pairwise distances between the given vectors. */
  pairwiseDistances(vectors: Float32Array[]): Float32Array {
    const distances = wasm.pairwise_distances(vectors);
    return new Float32Array(distances);
  },
};
/**
 * Scaled dot-product attention as a plain function (no instance required).
 * Returns a fresh Float32Array copy of the WASM result.
 */
export function scaledDotAttention(
  query: Float32Array,
  keys: Float32Array[],
  values: Float32Array[],
  scale?: number
): Float32Array {
  const raw = wasm.scaled_dot_attention(query, keys, values, scale);
  return new Float32Array(raw);
}
// Re-export WASM module for advanced usage
export { wasm };

View File

@@ -0,0 +1,108 @@
/**
 * TypeScript type definitions for ruvector-attention-wasm.
 *
 * These interfaces describe the config objects accepted by the wrapper
 * classes in index.ts; only the fields actually forwarded to the WASM
 * constructors are required.
 */

/** Base configuration shared by all attention mechanisms. */
export interface AttentionConfig {
  /** Embedding dimension */
  dim: number;
  /** Number of attention heads (for multi-head attention) */
  numHeads?: number;
  /** Dropout probability */
  dropout?: number;
  /** Scaling factor for attention scores */
  scale?: number;
  /** Whether to use causal masking */
  causal?: boolean;
}

/** Config for MultiHeadAttention; makes `numHeads` mandatory. */
export interface MultiHeadConfig extends AttentionConfig {
  numHeads: number;
}

/** Config for HyperbolicAttention. */
export interface HyperbolicConfig extends AttentionConfig {
  /** Hyperbolic space curvature */
  curvature: number;
}

/** Config for LinearAttention (Performer-style). */
export interface LinearAttentionConfig extends AttentionConfig {
  /** Number of random features for kernel approximation */
  numFeatures: number;
}

/** Config for FlashAttention. */
export interface FlashAttentionConfig extends AttentionConfig {
  /** Block size for tiling */
  blockSize: number;
}

/** Config for LocalGlobalAttention. */
export interface LocalGlobalConfig extends AttentionConfig {
  /** Size of local attention window */
  localWindow: number;
  /** Number of global attention tokens */
  globalTokens: number;
}

/** Config for MoEAttention. */
export interface MoEConfig extends AttentionConfig {
  /** Number of expert attention mechanisms */
  numExperts: number;
  /** Number of experts to use per query */
  topK: number;
  /** Maximum capacity per expert */
  expertCapacity?: number;
  /** Load balancing coefficient */
  balanceCoeff?: number;
}

/** Hyperparameters for the Adam/AdamW optimizer wrappers. */
export interface TrainingConfig {
  /** Learning rate for optimizer */
  learningRate: number;
  /** Temperature parameter for contrastive loss */
  temperature?: number;
  /** First moment decay rate (Adam/AdamW) */
  beta1?: number;
  /** Second moment decay rate (Adam/AdamW) */
  beta2?: number;
  /** Weight decay coefficient (AdamW) */
  weightDecay?: number;
  /** Numerical stability constant */
  epsilon?: number;
}

/** Config for the LRScheduler (warmup + cosine decay). */
export interface SchedulerConfig {
  /** Initial learning rate */
  initialLR: number;
  /** Number of warmup steps */
  warmupSteps: number;
  /** Total training steps */
  totalSteps: number;
}

/** Shape of the object returned by MoEAttention.getExpertStats(). */
export interface ExpertStats {
  /** Number of times each expert was selected */
  selectionCounts: number[];
  /** Average load per expert */
  averageLoad: number[];
  /** Load balance factor (lower is better) */
  loadBalance: number;
}

/**
 * Attention mechanism types
 */
// NOTE(review): this union should mirror wasm.available_mechanisms(); it has
// no CGT/sheaf variant even though the README documents CGTSheafAttention —
// confirm against the WASM build.
export type AttentionType =
  | 'scaled_dot_product'
  | 'multi_head'
  | 'hyperbolic'
  | 'linear'
  | 'flash'
  | 'local_global'
  | 'moe';

/**
 * Optimizer types
 */
export type OptimizerType = 'adam' | 'adamw';

/**
 * Loss function types
 */
export type LossType = 'info_nce';

View File

@@ -0,0 +1,66 @@
{
"name": "@ruvector/attention-wasm",
"version": "0.1.32",
"description": "High-performance WebAssembly attention mechanisms for transformers and LLMs: Multi-Head, Flash Attention, Hyperbolic, Linear (Performer), MoE, Local-Global, and CGT Sheaf Attention with coherence gating. GPU-accelerated with SIMD fallback.",
"main": "pkg/ruvector_attention_wasm.js",
"module": "pkg/ruvector_attention_wasm.js",
"types": "pkg/ruvector_attention_wasm.d.ts",
"files": [
"pkg/",
"js/",
"README.md"
],
"scripts": {
"build": "wasm-pack build --target web --out-dir pkg",
"build:node": "wasm-pack build --target nodejs --out-dir pkg-node",
"build:bundler": "wasm-pack build --target bundler --out-dir pkg-bundler",
"build:all": "npm run build && npm run build:node && npm run build:bundler",
"test": "wasm-pack test --headless --firefox",
"test:chrome": "wasm-pack test --headless --chrome",
"clean": "rm -rf pkg pkg-node pkg-bundler target",
"prepublishOnly": "npm run build"
},
"repository": {
"type": "git",
"url": "git+https://github.com/ruvnet/ruvector.git"
},
"keywords": [
"wasm",
"webassembly",
"attention",
"transformer",
"llm",
"machine-learning",
"neural-networks",
"multi-head-attention",
"flash-attention",
"hyperbolic",
"moe",
"mixture-of-experts",
"coherence",
"cgt",
"sheaf-attention",
"ai",
"deep-learning",
"gpu",
"simd",
"infonce",
"contrastive-learning"
],
"author": "rUv <team@ruvector.dev>",
"license": "MIT OR Apache-2.0",
"bugs": {
"url": "https://github.com/ruvnet/ruvector/issues"
},
"homepage": "https://ruv.io/ruvector",
"devDependencies": {
"@types/node": "^20.0.0",
"typescript": "^5.0.0"
},
"engines": {
"node": ">=16.0.0"
},
"publishConfig": {
"access": "public"
}
}

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 rUv
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,220 @@
# ruvector-attention-wasm
WebAssembly bindings for the ruvector-attention package, providing high-performance attention mechanisms for browser and Node.js environments.
## Features
- **Multiple Attention Mechanisms**:
- Scaled Dot-Product Attention
- Multi-Head Attention
- Hyperbolic Attention (for hierarchical data)
- Linear Attention (Performer-style)
- Flash Attention (memory-efficient)
- Local-Global Attention
- Mixture of Experts (MoE) Attention
- **CGT Sheaf Attention** (coherence-gated via Prime-Radiant)
- **Training Utilities**:
- InfoNCE contrastive loss
- Adam optimizer
- AdamW optimizer (with decoupled weight decay)
- Learning rate scheduler (warmup + cosine decay)
- **TypeScript Support**: Full type definitions and modern API
## Installation
```bash
npm install ruvector-attention-wasm
```
## Usage
### TypeScript/JavaScript
```typescript
import { initialize, MultiHeadAttention, utils } from 'ruvector-attention-wasm';
// Initialize WASM module
await initialize();
// Create multi-head attention
const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });
// Prepare inputs
const query = new Float32Array(64);
const keys = [new Float32Array(64), new Float32Array(64)];
const values = [new Float32Array(64), new Float32Array(64)];
// Compute attention
const output = attention.compute(query, keys, values);
// Use utilities
const similarity = utils.cosineSimilarity(query, keys[0]);
```
### Advanced Examples
#### Hyperbolic Attention
```typescript
import { HyperbolicAttention } from 'ruvector-attention-wasm';
const hyperbolic = new HyperbolicAttention({
dim: 128,
curvature: 1.0
});
const output = hyperbolic.compute(query, keys, values);
```
#### MoE Attention with Expert Stats
```typescript
import { MoEAttention } from 'ruvector-attention-wasm';
const moe = new MoEAttention({
dim: 64,
numExperts: 4,
topK: 2
});
const output = moe.compute(query, keys, values);
// Get expert utilization
const stats = moe.getExpertStats();
console.log('Load balance:', stats.loadBalance);
```
#### Training with InfoNCE Loss
```typescript
import { InfoNCELoss, Adam } from 'ruvector-attention-wasm';
const loss = new InfoNCELoss(0.07);
const optimizer = new Adam(paramCount, {
learningRate: 0.001,
beta1: 0.9,
beta2: 0.999,
});
// Training loop
const lossValue = loss.compute(anchor, positive, negatives);
optimizer.step(params, gradients);
```
#### Learning Rate Scheduling
```typescript
import { LRScheduler, AdamW } from 'ruvector-attention-wasm';
const scheduler = new LRScheduler({
initialLR: 0.001,
warmupSteps: 1000,
totalSteps: 10000,
});
const optimizer = new AdamW(paramCount, {
learningRate: scheduler.getLR(),
weightDecay: 0.01,
});
// Training loop
for (let step = 0; step < 10000; step++) {
optimizer.learningRate = scheduler.getLR();
optimizer.step(params, gradients);
scheduler.step();
}
```
## Building from Source
### Prerequisites
- Rust 1.70+
- wasm-pack
### Build Commands
```bash
# Build for web (ES modules)
wasm-pack build --target web --out-dir pkg
# Build for Node.js
wasm-pack build --target nodejs --out-dir pkg-node
# Build for bundlers (webpack, vite, etc.)
wasm-pack build --target bundler --out-dir pkg-bundler
# Run tests
wasm-pack test --headless --firefox
```
## API Reference
### Attention Mechanisms
- `MultiHeadAttention` - Standard multi-head attention
- `HyperbolicAttention` - Attention in hyperbolic space
- `LinearAttention` - Linear complexity attention (Performer)
- `FlashAttention` - Memory-efficient attention
- `LocalGlobalAttention` - Combined local and global attention
- `MoEAttention` - Mixture of Experts attention
- `CGTSheafAttention` - Coherence-gated via Prime-Radiant energy
- `scaledDotAttention()` - Functional API for basic attention
### CGT Sheaf Attention (Prime-Radiant Integration)
The CGT (Coherence-Gated Transformer) Sheaf Attention mechanism uses Prime-Radiant's sheaf Laplacian energy to gate attention based on mathematical consistency:
```typescript
import { CGTSheafAttention } from 'ruvector-attention-wasm';
const cgtAttention = new CGTSheafAttention({
dim: 128,
numHeads: 8,
coherenceThreshold: 0.3, // Block if energy > threshold
});
// Attention is gated by coherence energy
const result = cgtAttention.compute(query, keys, values);
console.log('Coherence energy:', result.energy);
console.log('Is coherent:', result.isCoherent);
```
**Key features:**
- Energy-weighted attention: Lower coherence energy → higher attention
- Automatic hallucination detection via residual analysis
- GPU-accelerated with wgpu WGSL shaders (vec4 optimized)
- SIMD fallback (AVX-512/AVX2/NEON)
### Training
- `InfoNCELoss` - Contrastive loss function
- `Adam` - Adam optimizer
- `AdamW` - AdamW optimizer with weight decay
- `LRScheduler` - Learning rate scheduler
### Utilities
- `utils.cosineSimilarity()` - Cosine similarity between vectors
- `utils.l2Norm()` - L2 norm of a vector
- `utils.normalize()` - Normalize vector to unit length
- `utils.softmax()` - Apply softmax transformation
- `utils.attentionWeights()` - Compute attention weights from scores
- `utils.batchNormalize()` - Batch normalization
- `utils.randomOrthogonalMatrix()` - Generate random orthogonal matrix
- `utils.pairwiseDistances()` - Compute pairwise distances
## Performance
The WASM bindings provide near-native performance for attention computations:
- Optimized with `opt-level = "s"` and LTO
- SIMD acceleration where available
- Efficient memory management
- Zero-copy data transfer where possible
## License
MIT OR Apache-2.0

View File

@@ -0,0 +1,28 @@
{
"name": "ruvector-attention-wasm",
"collaborators": [
"Ruvector Team"
],
"description": "High-performance WebAssembly attention mechanisms: Multi-Head, Flash, Hyperbolic, MoE, CGT Sheaf Attention with GPU acceleration for transformers and LLMs",
"version": "2.0.5",
"license": "MIT",
"repository": {
"type": "git",
"url": "https://github.com/ruvnet/ruvector"
},
"files": [
"ruvector_attention_wasm_bg.wasm",
"ruvector_attention_wasm.js",
"ruvector_attention_wasm.d.ts"
],
"main": "ruvector_attention_wasm.js",
"homepage": "https://ruv.io/ruvector",
"types": "ruvector_attention_wasm.d.ts",
"keywords": [
"wasm",
"attention",
"transformer",
"flash-attention",
"llm"
]
}

View File

@@ -0,0 +1,359 @@
/* tslint:disable */
/* eslint-disable */
/**
* Adam optimizer
*/
export class WasmAdam {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new Adam optimizer
*
* # Arguments
* * `param_count` - Number of parameters
* * `learning_rate` - Learning rate
*/
constructor(param_count: number, learning_rate: number);
/**
* Reset optimizer state
*/
reset(): void;
/**
* Perform optimization step
*
* # Arguments
* * `params` - Current parameter values (will be updated in-place)
* * `gradients` - Gradient values
*/
step(params: Float32Array, gradients: Float32Array): void;
/**
* Get current learning rate
*/
learning_rate: number;
}
/**
* AdamW optimizer (Adam with decoupled weight decay)
*/
export class WasmAdamW {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new AdamW optimizer
*
* # Arguments
* * `param_count` - Number of parameters
* * `learning_rate` - Learning rate
* * `weight_decay` - Weight decay coefficient
*/
constructor(param_count: number, learning_rate: number, weight_decay: number);
/**
* Reset optimizer state
*/
reset(): void;
/**
* Perform optimization step with weight decay
*/
step(params: Float32Array, gradients: Float32Array): void;
/**
* Get current learning rate
*/
learning_rate: number;
/**
* Get weight decay
*/
readonly weight_decay: number;
}
/**
* Flash attention mechanism
*/
export class WasmFlashAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Compute flash attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
/**
* Create a new flash attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `block_size` - Block size for tiling
*/
constructor(dim: number, block_size: number);
}
/**
* Hyperbolic attention mechanism
*/
export class WasmHyperbolicAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Compute hyperbolic attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
/**
* Create a new hyperbolic attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `curvature` - Hyperbolic curvature parameter
*/
constructor(dim: number, curvature: number);
/**
* Get the curvature
*/
readonly curvature: number;
}
/**
* InfoNCE contrastive loss for training
*/
export class WasmInfoNCELoss {
free(): void;
[Symbol.dispose](): void;
/**
* Compute InfoNCE loss
*
* # Arguments
* * `anchor` - Anchor embedding
* * `positive` - Positive example embedding
* * `negatives` - Array of negative example embeddings
*/
compute(anchor: Float32Array, positive: Float32Array, negatives: any): number;
/**
* Create a new InfoNCE loss instance
*
* # Arguments
* * `temperature` - Temperature parameter for softmax
*/
constructor(temperature: number);
}
/**
* Learning rate scheduler
*/
export class WasmLRScheduler {
free(): void;
[Symbol.dispose](): void;
/**
* Get learning rate for current step
*/
get_lr(): number;
/**
* Create a new learning rate scheduler with warmup and cosine decay
*
* # Arguments
* * `initial_lr` - Initial learning rate
* * `warmup_steps` - Number of warmup steps
* * `total_steps` - Total training steps
*/
constructor(initial_lr: number, warmup_steps: number, total_steps: number);
/**
* Reset scheduler
*/
reset(): void;
/**
* Advance to next step
*/
step(): void;
}
/**
 * Linear attention (Performer-style)
 */
export class WasmLinearAttention {
  free(): void;
  [Symbol.dispose](): void;
  /**
   * Compute linear attention
   *
   * `keys` and `values` are arrays of vectors; returns the attended
   * output vector.
   */
  compute(query: Float32Array, keys: any, values: any): Float32Array;
  /**
   * Create a new linear attention instance
   *
   * # Arguments
   * * `dim` - Embedding dimension
   * * `num_features` - Number of random features
   */
  constructor(dim: number, num_features: number);
}
/**
 * Local-global attention mechanism
 */
export class WasmLocalGlobalAttention {
  free(): void;
  [Symbol.dispose](): void;
  /**
   * Compute local-global attention
   *
   * `keys` and `values` are arrays of vectors; returns the attended
   * output vector.
   */
  compute(query: Float32Array, keys: any, values: any): Float32Array;
  /**
   * Create a new local-global attention instance
   *
   * # Arguments
   * * `dim` - Embedding dimension
   * * `local_window` - Size of local attention window
   * * `global_tokens` - Number of global attention tokens
   */
  constructor(dim: number, local_window: number, global_tokens: number);
}
/**
 * Mixture of Experts (MoE) attention
 */
export class WasmMoEAttention {
  free(): void;
  [Symbol.dispose](): void;
  /**
   * Compute MoE attention
   *
   * `keys` and `values` are arrays of vectors; returns the attended
   * output vector.
   */
  compute(query: Float32Array, keys: any, values: any): Float32Array;
  /**
   * Create a new MoE attention instance
   *
   * # Arguments
   * * `dim` - Embedding dimension
   * * `num_experts` - Number of expert attention mechanisms
   * * `top_k` - Number of experts to use per query
   */
  constructor(dim: number, num_experts: number, top_k: number);
}
/**
 * Multi-head attention mechanism
 */
export class WasmMultiHeadAttention {
  free(): void;
  [Symbol.dispose](): void;
  /**
   * Compute multi-head attention
   *
   * `keys` and `values` are arrays of vectors; returns the attended
   * output vector.
   */
  compute(query: Float32Array, keys: any, values: any): Float32Array;
  /**
   * Create a new multi-head attention instance
   *
   * # Arguments
   * * `dim` - Embedding dimension
   * * `num_heads` - Number of attention heads
   *
   * Throws when `dim` is not evenly divisible by `num_heads`.
   */
  constructor(dim: number, num_heads: number);
  /**
   * Get the dimension
   */
  readonly dim: number;
  /**
   * Get the number of heads
   */
  readonly num_heads: number;
}
/**
 * SGD optimizer with momentum
 */
export class WasmSGD {
  free(): void;
  [Symbol.dispose](): void;
  /**
   * Create a new SGD optimizer
   *
   * # Arguments
   * * `param_count` - Number of parameters
   * * `learning_rate` - Learning rate
   * * `momentum` - Momentum coefficient (default: 0)
   */
  constructor(param_count: number, learning_rate: number, momentum?: number | null);
  /**
   * Reset optimizer state
   */
  reset(): void;
  /**
   * Perform optimization step
   *
   * Updates `params` in place using `gradients`.
   */
  step(params: Float32Array, gradients: Float32Array): void;
  /**
   * Get current learning rate
   */
  learning_rate: number;
}
/**
 * Compute attention weights from scores
 *
 * Divides `scores` by `temperature` (default 1.0) and applies softmax;
 * the array is normalized in place rather than returned.
 */
export function attention_weights(scores: Float32Array, temperature?: number | null): void;
/**
 * Get information about available attention mechanisms
 *
 * Returns a JS array of mechanism name strings.
 */
export function available_mechanisms(): any;
/**
 * Batch normalize vectors
 *
 * Returns the feature-wise normalized rows flattened into one Float32Array.
 */
export function batch_normalize(vectors: any, epsilon?: number | null): Float32Array;
/**
 * Compute cosine similarity between two vectors
 *
 * Throws when the vectors differ in length or either has zero norm.
 */
export function cosine_similarity(a: Float32Array, b: Float32Array): number;
/**
 * Initialize the WASM module with panic hook
 */
export function init(): void;
/**
 * Compute L2 norm of a vector
 */
export function l2_norm(vec: Float32Array): number;
/**
 * Log a message to the browser console
 */
export function log(message: string): void;
/**
 * Log an error to the browser console
 */
export function log_error(message: string): void;
/**
 * Normalize a vector to unit length
 *
 * Mutates `vec` in place; throws on a zero vector.
 */
export function normalize(vec: Float32Array): void;
/**
 * Compute pairwise distances between vectors
 *
 * Returns a flattened n x n matrix of Euclidean distances.
 */
export function pairwise_distances(vectors: any): Float32Array;
/**
 * Generate random orthogonal matrix (for initialization)
 *
 * Returns the dim x dim matrix flattened in row-major order.
 */
export function random_orthogonal_matrix(dim: number): Float32Array;
/**
 * Compute scaled dot-product attention
 *
 * # Arguments
 * * `query` - Query vector as Float32Array
 * * `keys` - Array of key vectors
 * * `values` - Array of value vectors
 * * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
 */
export function scaled_dot_attention(query: Float32Array, keys: any, values: any, scale?: number | null): Float32Array;
/**
 * Compute softmax of a vector
 *
 * Mutates `vec` in place so the entries sum to 1.
 */
export function softmax(vec: Float32Array): void;
/**
 * Get the version of the ruvector-attention-wasm crate
 */
export function version(): string;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,71 @@
/* tslint:disable */
/* eslint-disable */
// Low-level wasm-bindgen ABI surface (appears to be auto-generated glue).
// Every signature takes and returns raw numbers — pointers/lengths into
// wasm linear memory — so these are not meant to be called directly;
// use the typed wrappers from the package's main entry point instead.
export const memory: WebAssembly.Memory;
export const __wbg_wasmadam_free: (a: number, b: number) => void;
export const __wbg_wasmadamw_free: (a: number, b: number) => void;
export const __wbg_wasmflashattention_free: (a: number, b: number) => void;
export const __wbg_wasmhyperbolicattention_free: (a: number, b: number) => void;
export const __wbg_wasminfonceloss_free: (a: number, b: number) => void;
export const __wbg_wasmlinearattention_free: (a: number, b: number) => void;
export const __wbg_wasmmoeattention_free: (a: number, b: number) => void;
export const __wbg_wasmmultiheadattention_free: (a: number, b: number) => void;
export const __wbg_wasmsgd_free: (a: number, b: number) => void;
export const attention_weights: (a: number, b: number, c: number, d: number) => void;
export const available_mechanisms: () => number;
export const batch_normalize: (a: number, b: number, c: number) => void;
export const cosine_similarity: (a: number, b: number, c: number, d: number, e: number) => void;
export const l2_norm: (a: number, b: number) => number;
export const log: (a: number, b: number) => void;
export const log_error: (a: number, b: number) => void;
export const normalize: (a: number, b: number, c: number, d: number) => void;
export const pairwise_distances: (a: number, b: number) => void;
export const random_orthogonal_matrix: (a: number, b: number) => void;
export const scaled_dot_attention: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const softmax: (a: number, b: number, c: number) => void;
export const version: (a: number) => void;
export const wasmadam_learning_rate: (a: number) => number;
export const wasmadam_new: (a: number, b: number) => number;
export const wasmadam_reset: (a: number) => void;
export const wasmadam_set_learning_rate: (a: number, b: number) => void;
export const wasmadam_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmadamw_new: (a: number, b: number, c: number) => number;
export const wasmadamw_reset: (a: number) => void;
export const wasmadamw_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmadamw_weight_decay: (a: number) => number;
export const wasmflashattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmflashattention_new: (a: number, b: number) => number;
export const wasmhyperbolicattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmhyperbolicattention_curvature: (a: number) => number;
export const wasmhyperbolicattention_new: (a: number, b: number) => number;
export const wasminfonceloss_compute: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
export const wasminfonceloss_new: (a: number) => number;
export const wasmlinearattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmlinearattention_new: (a: number, b: number) => number;
export const wasmlocalglobalattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmlocalglobalattention_new: (a: number, b: number, c: number) => number;
export const wasmlrscheduler_get_lr: (a: number) => number;
export const wasmlrscheduler_new: (a: number, b: number, c: number) => number;
export const wasmlrscheduler_reset: (a: number) => void;
export const wasmlrscheduler_step: (a: number) => void;
export const wasmmoeattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmmoeattention_new: (a: number, b: number, c: number) => number;
export const wasmmultiheadattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const wasmmultiheadattention_dim: (a: number) => number;
export const wasmmultiheadattention_new: (a: number, b: number, c: number) => void;
export const wasmmultiheadattention_num_heads: (a: number) => number;
export const wasmsgd_learning_rate: (a: number) => number;
export const wasmsgd_new: (a: number, b: number, c: number) => number;
export const wasmsgd_reset: (a: number) => void;
export const wasmsgd_set_learning_rate: (a: number, b: number) => void;
export const wasmsgd_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
export const init: () => void;
export const wasmadamw_set_learning_rate: (a: number, b: number) => void;
export const wasmadamw_learning_rate: (a: number) => number;
export const __wbg_wasmlocalglobalattention_free: (a: number, b: number) => void;
export const __wbg_wasmlrscheduler_free: (a: number, b: number) => void;
export const __wbindgen_export: (a: number, b: number) => number;
export const __wbindgen_export2: (a: number, b: number, c: number, d: number) => number;
export const __wbindgen_export3: (a: number) => void;
export const __wbindgen_export4: (a: number, b: number, c: number) => void;
export const __wbindgen_add_to_stack_pointer: (a: number) => number;
export const __wbindgen_start: () => void;

View File

@@ -0,0 +1,308 @@
use ruvector_attention::{
attention::{MultiHeadAttention, ScaledDotProductAttention},
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
moe::{MoEAttention, MoEConfig},
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
traits::Attention,
};
use wasm_bindgen::prelude::*;
/// Compute scaled dot-product attention
///
/// Deserializes `keys`/`values` from JS into rows of `f32` and delegates to
/// [`ScaledDotProductAttention`] sized by the query length.
///
/// # Arguments
/// * `query` - Query vector as Float32Array
/// * `keys` - Array of key vectors
/// * `values` - Array of value vectors
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
///
/// # Errors
/// Fails when the JS inputs cannot be parsed or the inner computation rejects them.
#[wasm_bindgen]
pub fn scaled_dot_attention(
    query: &[f32],
    keys: JsValue,
    values: JsValue,
    scale: Option<f32>,
) -> Result<Vec<f32>, JsError> {
    // Deserialize the JS arrays-of-arrays into owned rows.
    let key_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
        .map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
    let value_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
        .map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;

    // Borrow each row as a slice for the attention API.
    let key_slices: Vec<&[f32]> = key_rows.iter().map(Vec::as_slice).collect();
    let value_slices: Vec<&[f32]> = value_rows.iter().map(Vec::as_slice).collect();

    // NOTE(review): `scale` is accepted for API compatibility but never
    // forwarded to the inner mechanism — confirm whether
    // ScaledDotProductAttention exposes a custom-scale hook.
    let _ = scale;

    let mechanism = ScaledDotProductAttention::new(query.len());
    match mechanism.compute(query, &key_slices, &value_slices) {
        Ok(out) => Ok(out),
        Err(e) => Err(JsError::new(&e.to_string())),
    }
}
/// Multi-head attention mechanism
#[wasm_bindgen]
pub struct WasmMultiHeadAttention {
    inner: MultiHeadAttention,
}

#[wasm_bindgen]
impl WasmMultiHeadAttention {
    /// Create a new multi-head attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `num_heads` - Number of attention heads
    ///
    /// # Errors
    /// Fails when `dim` is not evenly divisible by `num_heads`.
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
        match dim % num_heads {
            0 => Ok(Self {
                inner: MultiHeadAttention::new(dim, num_heads),
            }),
            _ => Err(JsError::new(&format!(
                "Dimension {} must be divisible by number of heads {}",
                dim, num_heads
            ))),
        }
    }

    /// Compute multi-head attention for one query against the given keys/values.
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let key_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let value_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
        let key_slices: Vec<&[f32]> = key_rows.iter().map(Vec::as_slice).collect();
        let value_slices: Vec<&[f32]> = value_rows.iter().map(Vec::as_slice).collect();
        match self.inner.compute(query, &key_slices, &value_slices) {
            Ok(out) => Ok(out),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }

    /// Number of attention heads this instance was built with.
    #[wasm_bindgen(getter)]
    pub fn num_heads(&self) -> usize {
        self.inner.num_heads()
    }

    /// Embedding dimension this instance was built with.
    #[wasm_bindgen(getter)]
    pub fn dim(&self) -> usize {
        self.inner.dim()
    }
}
/// Hyperbolic attention mechanism
#[wasm_bindgen]
pub struct WasmHyperbolicAttention {
    inner: HyperbolicAttention,
    // Curvature is kept alongside the mechanism so the getter can report it.
    curvature_value: f32,
}

#[wasm_bindgen]
impl WasmHyperbolicAttention {
    /// Create a new hyperbolic attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `curvature` - Hyperbolic curvature parameter
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
        let config = HyperbolicAttentionConfig {
            dim,
            curvature,
            ..Default::default()
        };
        let inner = HyperbolicAttention::new(config);
        Self {
            inner,
            curvature_value: curvature,
        }
    }

    /// Compute hyperbolic attention for one query against the given keys/values.
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let key_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let value_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
        let key_slices: Vec<&[f32]> = key_rows.iter().map(Vec::as_slice).collect();
        let value_slices: Vec<&[f32]> = value_rows.iter().map(Vec::as_slice).collect();
        match self.inner.compute(query, &key_slices, &value_slices) {
            Ok(out) => Ok(out),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }

    /// Curvature value supplied at construction.
    #[wasm_bindgen(getter)]
    pub fn curvature(&self) -> f32 {
        self.curvature_value
    }
}
/// Linear attention (Performer-style)
#[wasm_bindgen]
pub struct WasmLinearAttention {
    inner: LinearAttention,
}

#[wasm_bindgen]
impl WasmLinearAttention {
    /// Create a new linear attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `num_features` - Number of random features
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
        let inner = LinearAttention::new(dim, num_features);
        Self { inner }
    }

    /// Compute linear attention for one query against the given keys/values.
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let key_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let value_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
        let key_slices: Vec<&[f32]> = key_rows.iter().map(Vec::as_slice).collect();
        let value_slices: Vec<&[f32]> = value_rows.iter().map(Vec::as_slice).collect();
        match self.inner.compute(query, &key_slices, &value_slices) {
            Ok(out) => Ok(out),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }
}
/// Flash attention mechanism
#[wasm_bindgen]
pub struct WasmFlashAttention {
    inner: FlashAttention,
}

#[wasm_bindgen]
impl WasmFlashAttention {
    /// Create a new flash attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `block_size` - Block size for tiling
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
        let inner = FlashAttention::new(dim, block_size);
        Self { inner }
    }

    /// Compute flash attention for one query against the given keys/values.
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let key_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let value_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
        let key_slices: Vec<&[f32]> = key_rows.iter().map(Vec::as_slice).collect();
        let value_slices: Vec<&[f32]> = value_rows.iter().map(Vec::as_slice).collect();
        match self.inner.compute(query, &key_slices, &value_slices) {
            Ok(out) => Ok(out),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }
}
/// Local-global attention mechanism
#[wasm_bindgen]
pub struct WasmLocalGlobalAttention {
    inner: LocalGlobalAttention,
}

#[wasm_bindgen]
impl WasmLocalGlobalAttention {
    /// Create a new local-global attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `local_window` - Size of local attention window
    /// * `global_tokens` - Number of global attention tokens
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
        let inner = LocalGlobalAttention::new(dim, local_window, global_tokens);
        Self { inner }
    }

    /// Compute local-global attention for one query against the given keys/values.
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let key_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let value_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
        let key_slices: Vec<&[f32]> = key_rows.iter().map(Vec::as_slice).collect();
        let value_slices: Vec<&[f32]> = value_rows.iter().map(Vec::as_slice).collect();
        match self.inner.compute(query, &key_slices, &value_slices) {
            Ok(out) => Ok(out),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }
}
/// Mixture of Experts (MoE) attention
#[wasm_bindgen]
pub struct WasmMoEAttention {
    inner: MoEAttention,
}

#[wasm_bindgen]
impl WasmMoEAttention {
    /// Create a new MoE attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `num_experts` - Number of expert attention mechanisms
    /// * `top_k` - Number of experts to use per query
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
        let config = MoEConfig::builder()
            .dim(dim)
            .num_experts(num_experts)
            .top_k(top_k)
            .build();
        let inner = MoEAttention::new(config);
        Self { inner }
    }

    /// Compute MoE attention for one query against the given keys/values.
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let key_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let value_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
        let key_slices: Vec<&[f32]> = key_rows.iter().map(Vec::as_slice).collect();
        let value_slices: Vec<&[f32]> = value_rows.iter().map(Vec::as_slice).collect();
        match self.inner.compute(query, &key_slices, &value_slices) {
            Ok(out) => Ok(out),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }
}

View File

@@ -0,0 +1,33 @@
use wasm_bindgen::prelude::*;
pub mod attention;
pub mod training;
pub mod utils;
/// Initialize the WASM module with panic hook
///
/// Runs automatically when the wasm module is instantiated
/// (`#[wasm_bindgen(start)]`). When the `console_error_panic_hook`
/// feature is enabled, Rust panic messages are forwarded to the
/// browser console instead of surfacing as an opaque trap.
#[wasm_bindgen(start)]
pub fn init() {
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
/// Get the version of the ruvector-attention-wasm crate
///
/// The value is baked in at compile time from `Cargo.toml`.
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Get information about available attention mechanisms
///
/// Returns a JS array of mechanism identifier strings.
#[wasm_bindgen]
pub fn available_mechanisms() -> JsValue {
    const MECHANISMS: [&str; 7] = [
        "scaled_dot_product",
        "multi_head",
        "hyperbolic",
        "linear",
        "flash",
        "local_global",
        "moe",
    ];
    serde_wasm_bindgen::to_value(&MECHANISMS).unwrap()
}

View File

@@ -0,0 +1,238 @@
use ruvector_attention::training::{Adam, AdamW, InfoNCELoss, Loss, Optimizer, SGD};
use wasm_bindgen::prelude::*;
/// InfoNCE contrastive loss for training
#[wasm_bindgen]
pub struct WasmInfoNCELoss {
    inner: InfoNCELoss,
}

#[wasm_bindgen]
impl WasmInfoNCELoss {
    /// Create a new InfoNCE loss instance
    ///
    /// # Arguments
    /// * `temperature` - Temperature parameter for softmax
    #[wasm_bindgen(constructor)]
    pub fn new(temperature: f32) -> WasmInfoNCELoss {
        let inner = InfoNCELoss::new(temperature);
        Self { inner }
    }

    /// Compute InfoNCE loss
    ///
    /// # Arguments
    /// * `anchor` - Anchor embedding
    /// * `positive` - Positive example embedding
    /// * `negatives` - JS array of negative example embeddings
    ///
    /// # Errors
    /// Fails when `negatives` cannot be deserialized into rows of `f32`.
    pub fn compute(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: JsValue,
    ) -> Result<f32, JsError> {
        let negative_rows: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(negatives)?;
        let negative_slices: Vec<&[f32]> = negative_rows.iter().map(Vec::as_slice).collect();
        Ok(self.inner.compute(anchor, positive, &negative_slices))
    }
}
/// Adam optimizer
#[wasm_bindgen]
pub struct WasmAdam {
    inner: Adam,
}

#[wasm_bindgen]
impl WasmAdam {
    /// Create a new Adam optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32) -> WasmAdam {
        let inner = Adam::new(param_count, learning_rate);
        Self { inner }
    }

    /// Perform one optimization step
    ///
    /// # Arguments
    /// * `params` - Current parameter values, updated in place
    /// * `gradients` - Gradient for each parameter
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Clear all accumulated optimizer state.
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Current learning rate.
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Replace the learning rate.
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }
}
/// AdamW optimizer (Adam with decoupled weight decay)
#[wasm_bindgen]
pub struct WasmAdamW {
    inner: AdamW,
    // Weight decay is mirrored here so the getter can report it.
    wd: f32,
}

#[wasm_bindgen]
impl WasmAdamW {
    /// Create a new AdamW optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    /// * `weight_decay` - Weight decay coefficient
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32, weight_decay: f32) -> WasmAdamW {
        Self {
            inner: AdamW::new(param_count, learning_rate).with_weight_decay(weight_decay),
            wd: weight_decay,
        }
    }

    /// Perform one optimization step (Adam update plus decoupled weight
    /// decay), mutating `params` in place.
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Clear all accumulated optimizer state.
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Current learning rate.
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Replace the learning rate.
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }

    /// Weight decay coefficient supplied at construction.
    #[wasm_bindgen(getter)]
    pub fn weight_decay(&self) -> f32 {
        self.wd
    }
}
/// SGD optimizer with momentum
#[wasm_bindgen]
pub struct WasmSGD {
    inner: SGD,
}

#[wasm_bindgen]
impl WasmSGD {
    /// Create a new SGD optimizer
    ///
    /// # Arguments
    /// * `param_count` - Number of parameters
    /// * `learning_rate` - Learning rate
    /// * `momentum` - Momentum coefficient (default: 0)
    #[wasm_bindgen(constructor)]
    pub fn new(param_count: usize, learning_rate: f32, momentum: Option<f32>) -> WasmSGD {
        let base = SGD::new(param_count, learning_rate);
        let inner = match momentum {
            Some(m) => base.with_momentum(m),
            None => base,
        };
        Self { inner }
    }

    /// Perform one optimization step, mutating `params` in place.
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        self.inner.step(params, gradients);
    }

    /// Clear all accumulated optimizer state (momentum buffers).
    pub fn reset(&mut self) {
        self.inner.reset();
    }

    /// Current learning rate.
    #[wasm_bindgen(getter)]
    pub fn learning_rate(&self) -> f32 {
        self.inner.learning_rate()
    }

    /// Replace the learning rate.
    #[wasm_bindgen(setter)]
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.inner.set_learning_rate(lr);
    }
}
/// Learning rate scheduler
///
/// Linear warmup from 0 up to `initial_lr` over `warmup_steps`, followed by
/// cosine decay down to 0 at `total_steps`.
#[wasm_bindgen]
pub struct WasmLRScheduler {
    initial_lr: f32,
    current_step: usize,
    warmup_steps: usize,
    total_steps: usize,
}

#[wasm_bindgen]
impl WasmLRScheduler {
    /// Create a new learning rate scheduler with warmup and cosine decay
    ///
    /// # Arguments
    /// * `initial_lr` - Initial learning rate
    /// * `warmup_steps` - Number of warmup steps
    /// * `total_steps` - Total training steps
    #[wasm_bindgen(constructor)]
    pub fn new(initial_lr: f32, warmup_steps: usize, total_steps: usize) -> WasmLRScheduler {
        Self {
            initial_lr,
            current_step: 0,
            warmup_steps,
            total_steps,
        }
    }

    /// Get learning rate for current step
    pub fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            // Linear warmup: ramps from 0 toward initial_lr.
            self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
        } else {
            // Cosine decay from initial_lr down to 0.
            // Guard a zero-length decay window (total_steps <= warmup_steps),
            // which previously divided by zero and produced NaN.
            let decay_span = self.total_steps.saturating_sub(self.warmup_steps);
            if decay_span == 0 {
                return 0.0;
            }
            // Clamp progress at 1.0 so the rate stays at 0 once total_steps
            // is passed; without the clamp the periodic cosine made the LR
            // climb back up after the schedule ended.
            let progress = ((self.current_step - self.warmup_steps) as f32
                / decay_span as f32)
                .min(1.0);
            let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
            self.initial_lr * cosine
        }
    }

    /// Advance to next step
    pub fn step(&mut self) {
        self.current_step += 1;
    }

    /// Reset scheduler
    pub fn reset(&mut self) {
        self.current_step = 0;
    }
}

View File

@@ -0,0 +1,201 @@
use wasm_bindgen::prelude::*;
use web_sys::console;
/// Log a message to the browser console
#[wasm_bindgen]
pub fn log(message: &str) {
console::log_1(&message.into());
}
/// Log an error to the browser console
#[wasm_bindgen]
pub fn log_error(message: &str) {
console::error_1(&message.into());
}
/// Compute cosine similarity between two vectors
///
/// # Errors
/// Fails when the vectors differ in length or either has zero norm.
#[wasm_bindgen]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, JsError> {
    if a.len() != b.len() {
        return Err(JsError::new("Vectors must have same length"));
    }
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return Err(JsError::new("Cannot compute similarity for zero vector"));
    }
    Ok(dot / (norm_a * norm_b))
}
/// Compute L2 norm of a vector
#[wasm_bindgen]
pub fn l2_norm(vec: &[f32]) -> f32 {
    vec.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt()
}
/// Normalize a vector to unit length, in place.
///
/// # Errors
/// Fails on a zero vector (no direction to normalize toward).
#[wasm_bindgen]
pub fn normalize(vec: &mut [f32]) -> Result<(), JsError> {
    match l2_norm(vec) {
        n if n == 0.0 => Err(JsError::new("Cannot normalize zero vector")),
        n => {
            vec.iter_mut().for_each(|x| *x /= n);
            Ok(())
        }
    }
}
/// Compute softmax of a vector, in place.
#[wasm_bindgen]
pub fn softmax(vec: &mut [f32]) {
    // Shift by the maximum so exp() cannot overflow.
    let max = vec.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Exponentiate in place while accumulating the partition sum.
    let mut total = 0.0f32;
    vec.iter_mut().for_each(|x| {
        *x = (*x - max).exp();
        total += *x;
    });
    // Normalize so the entries sum to 1.
    vec.iter_mut().for_each(|x| *x /= total);
}
/// Compute attention weights from scores, in place.
///
/// Divides every score by `temperature` (default 1.0) and then applies
/// softmax so the weights sum to 1.
#[wasm_bindgen]
pub fn attention_weights(scores: &mut [f32], temperature: Option<f32>) {
    let temp = temperature.unwrap_or(1.0);
    scores.iter_mut().for_each(|s| *s /= temp);
    softmax(scores);
}
/// Batch normalize vectors
#[wasm_bindgen]
pub fn batch_normalize(vectors: JsValue, epsilon: Option<f32>) -> Result<Vec<f32>, JsError> {
let eps = epsilon.unwrap_or(1e-8);
let mut vecs: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(vectors)?;
if vecs.is_empty() {
return Ok(Vec::new());
}
let dim = vecs[0].len();
let batch_size = vecs.len();
// Compute mean
let mut mean = vec![0.0; dim];
for vec in &vecs {
for (i, &val) in vec.iter().enumerate() {
mean[i] += val;
}
}
for m in &mut mean {
*m /= batch_size as f32;
}
// Compute variance
let mut variance = vec![0.0; dim];
for vec in &vecs {
for (i, &val) in vec.iter().enumerate() {
let diff = val - mean[i];
variance[i] += diff * diff;
}
}
for v in &mut variance {
*v /= batch_size as f32;
}
// Normalize
for vec in &mut vecs {
for (i, val) in vec.iter_mut().enumerate() {
*val = (*val - mean[i]) / (variance[i] + eps).sqrt();
}
}
Ok(vecs.into_iter().flatten().collect())
}
/// Generate random orthogonal matrix (for initialization)
///
/// Fills a `dim x dim` matrix with entries uniform in (-1, 1) drawn from
/// `js_sys::Math::random`, then orthonormalizes the columns with a
/// Gram-Schmidt sweep. The result is returned flattened in row-major order
/// (`matrix[row * dim + col]`); the orthonormal vectors are the *columns*.
///
/// NOTE(review): if a column degenerates to (numerically) zero after the
/// projections are subtracted, the division by `norm` below yields
/// non-finite values. With random input this is vanishingly unlikely,
/// but it is not guarded against.
#[wasm_bindgen]
pub fn random_orthogonal_matrix(dim: usize) -> Vec<f32> {
    use js_sys::Math;
    let mut matrix = vec![0.0; dim * dim];
    // Generate random matrix with entries in (-1, 1)
    for i in 0..dim {
        for j in 0..dim {
            matrix[i * dim + j] = (Math::random() as f32 - 0.5) * 2.0;
        }
    }
    // QR decomposition (simplified Gram-Schmidt): walk columns left to
    // right, normalizing each and removing its component from all later
    // columns (modified Gram-Schmidt update order).
    for i in 0..dim {
        // Normalize column i
        let mut norm = 0.0;
        for j in 0..dim {
            let val = matrix[j * dim + i];
            norm += val * val;
        }
        norm = norm.sqrt();
        for j in 0..dim {
            matrix[j * dim + i] /= norm;
        }
        // Orthogonalize remaining columns against the (now unit) column i
        for k in (i + 1)..dim {
            let mut dot = 0.0;
            for j in 0..dim {
                dot += matrix[j * dim + i] * matrix[j * dim + k];
            }
            for j in 0..dim {
                matrix[j * dim + k] -= dot * matrix[j * dim + i];
            }
        }
    }
    matrix
}
/// Compute pairwise distances between vectors
#[wasm_bindgen]
pub fn pairwise_distances(vectors: JsValue) -> Result<Vec<f32>, JsError> {
let vecs: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(vectors)?;
let n = vecs.len();
let mut distances = vec![0.0; n * n];
for i in 0..n {
for j in 0..n {
if i == j {
distances[i * n + j] = 0.0;
} else {
let mut dist = 0.0;
for k in 0..vecs[i].len() {
let diff = vecs[i][k] - vecs[j][k];
dist += diff * diff;
}
distances[i * n + j] = dist.sqrt();
}
}
}
Ok(distances)
}

View File

@@ -0,0 +1,185 @@
//! Test suite for WASM bindings
//! Run with: wasm-pack test --headless --firefox
#![cfg(target_arch = "wasm32")]
use ruvector_attention_wasm::*;
use wasm_bindgen_test::*;
wasm_bindgen_test_configure!(run_in_browser);
#[wasm_bindgen_test]
fn test_version() {
    // Version is baked in from Cargo.toml and must be non-empty.
    let ver = version();
    assert!(!ver.is_empty());
    assert_eq!(ver, env!("CARGO_PKG_VERSION"));
}

#[wasm_bindgen_test]
fn test_available_mechanisms() {
    // The mechanism list serializes to a JS array.
    let mechanisms = available_mechanisms();
    assert!(mechanisms.is_array());
}

#[wasm_bindgen_test]
fn test_multi_head_attention() {
    let mha = attention::WasmMultiHeadAttention::new(64, 8).unwrap();
    assert_eq!(mha.dim(), 64);
    assert_eq!(mha.num_heads(), 8);
}

#[wasm_bindgen_test]
fn test_hyperbolic_attention() {
    let ha = attention::WasmHyperbolicAttention::new(64, 1.0);
    assert_eq!(ha.curvature(), 1.0);
}

#[wasm_bindgen_test]
fn test_linear_attention() {
    // Linear attention exposes no internal state; constructing it without
    // panicking is the test. `_` prefix silences the unused-variable warning.
    let _la = attention::WasmLinearAttention::new(64, 256);
}

#[wasm_bindgen_test]
fn test_flash_attention() {
    // Flash attention exposes no internal state; verify creation only.
    let _fa = attention::WasmFlashAttention::new(64, 16);
}

#[wasm_bindgen_test]
fn test_local_global_attention() {
    // Local-global attention exposes no internal state; verify creation only.
    let _lga = attention::WasmLocalGlobalAttention::new(64, 8, 4);
}
#[wasm_bindgen_test]
fn test_moe_attention() {
    // WasmMoEAttention::new returns the struct directly (not a Result) and
    // the wasm wrapper exposes no expert statistics, so constructing it
    // without panicking is the test.
    let _moe = attention::WasmMoEAttention::new(64, 4, 2);
}

#[wasm_bindgen_test]
fn test_info_nce_loss() {
    // InfoNCE loss does not expose its temperature; verify creation only.
    let _loss = training::WasmInfoNCELoss::new(0.07);
}

#[wasm_bindgen_test]
fn test_adam_optimizer() {
    // WasmAdam::new takes (param_count, learning_rate); beta/epsilon
    // overrides are not part of the wasm API.
    let mut adam = training::WasmAdam::new(100, 0.001);
    assert_eq!(adam.learning_rate(), 0.001);
    adam.set_learning_rate(0.0001);
    assert_eq!(adam.learning_rate(), 0.0001);
    adam.reset();
}

#[wasm_bindgen_test]
fn test_adamw_optimizer() {
    // WasmAdamW::new takes (param_count, learning_rate, weight_decay).
    let mut adamw = training::WasmAdamW::new(100, 0.001, 0.01);
    assert_eq!(adamw.learning_rate(), 0.001);
    assert_eq!(adamw.weight_decay(), 0.01);
    adamw.reset();
}
#[wasm_bindgen_test]
fn test_lr_scheduler() {
    let mut scheduler = training::WasmLRScheduler::new(0.001, 100, 1000);
    // During warmup the rate starts below the configured initial LR.
    let warmup_lr = scheduler.get_lr();
    assert!(warmup_lr < 0.001);
    // Stepping through the full warmup lands on the initial LR.
    (0..100).for_each(|_| scheduler.step());
    let post_warmup_lr = scheduler.get_lr();
    assert!((post_warmup_lr - 0.001).abs() < 1e-6);
    // Resetting restores the step-0 rate.
    scheduler.reset();
    assert_eq!(scheduler.get_lr(), warmup_lr);
}
#[wasm_bindgen_test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
let sim = utils::cosine_similarity(&a, &b).unwrap();
assert!((sim - 1.0).abs() < 1e-6);
}
#[wasm_bindgen_test]
fn test_l2_norm() {
let vec = vec![3.0, 4.0];
let norm = utils::l2_norm(&vec);
assert!((norm - 5.0).abs() < 1e-6);
}
#[wasm_bindgen_test]
fn test_normalize() {
let mut vec = vec![3.0, 4.0];
utils::normalize(&mut vec).unwrap();
let norm = utils::l2_norm(&vec);
assert!((norm - 1.0).abs() < 1e-6);
}
#[wasm_bindgen_test]
fn test_softmax() {
let mut vec = vec![1.0, 2.0, 3.0];
utils::softmax(&mut vec);
let sum: f32 = vec.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
// Check monotonicity
assert!(vec[0] < vec[1]);
assert!(vec[1] < vec[2]);
}
#[wasm_bindgen_test]
fn test_attention_weights() {
let mut scores = vec![1.0, 2.0, 3.0];
utils::attention_weights(&mut scores, Some(1.0));
let sum: f32 = scores.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
}
#[wasm_bindgen_test]
fn test_random_orthogonal_matrix() {
let dim = 4;
let matrix = utils::random_orthogonal_matrix(dim);
assert_eq!(matrix.len(), dim * dim);
// Check orthogonality: Q^T * Q = I
for i in 0..dim {
for j in 0..dim {
let mut dot = 0.0;
for k in 0..dim {
dot += matrix[k * dim + i] * matrix[k * dim + j];
}
if i == j {
assert!((dot - 1.0).abs() < 1e-4, "Diagonal should be 1");
} else {
assert!(dot.abs() < 1e-4, "Off-diagonal should be 0");
}
}
}
}

View File

@@ -0,0 +1,25 @@
{
"compilerOptions": {
"target": "ES2020",
"module": "ES2020",
"lib": ["ES2020", "DOM"],
"declaration": true,
"declarationMap": true,
"sourceMap": true,
"outDir": "./dist",
"rootDir": "./js",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noImplicitReturns": true,
"noFallthroughCasesInSwitch": true
},
"include": ["js/**/*"],
"exclude": ["node_modules", "pkg", "pkg-node", "pkg-bundler", "target"]
}