Files

187 lines
6.4 KiB
JavaScript

/**
* Integration test for ruvector-attention-wasm package
* Tests all attention mechanisms from published npm package
*/
import { test, describe } from 'node:test';
import assert from 'node:assert';
// Import from published WASM package
import init, {
scaled_dot_attention,
WasmMultiHeadAttention,
WasmHyperbolicAttention,
WasmLinearAttention,
WasmFlashAttention,
WasmLocalGlobalAttention,
WasmMoEAttention
} from 'ruvector-attention-wasm';
describe('WASM Attention Package Tests', async () => {
// Initialize WASM before tests
await init();
test('Scaled Dot-Product Attention', () => {
const dim = 64;
const query = new Float32Array(dim).fill(0.5);
const keys = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const values = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const result = scaled_dot_attention(query, keys, values, null);
assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
console.log(' ✓ Scaled dot-product attention works correctly');
});
test('Multi-Head Attention', () => {
const dim = 64;
const numHeads = 8;
const mha = new WasmMultiHeadAttention(dim, numHeads);
assert.strictEqual(mha.dim, dim, 'Dimension should match');
assert.strictEqual(mha.num_heads, numHeads, 'Number of heads should match');
const query = new Float32Array(dim).fill(0.5);
const keys = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const values = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const result = mha.compute(query, keys, values);
assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
console.log(' ✓ Multi-head attention works correctly');
});
test('Hyperbolic Attention', () => {
const dim = 64;
const curvature = 1.0;
const hyperbolic = new WasmHyperbolicAttention(dim, curvature);
assert.strictEqual(hyperbolic.curvature, curvature, 'Curvature should match');
const query = new Float32Array(dim).fill(0.1);
const keys = [
Array.from({ length: dim }, () => Math.random() * 0.1),
Array.from({ length: dim }, () => Math.random() * 0.1)
];
const values = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const result = hyperbolic.compute(query, keys, values);
assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
console.log(' ✓ Hyperbolic attention works correctly');
});
test('Linear Attention (Performer-style)', () => {
const dim = 64;
const numFeatures = 128;
const linear = new WasmLinearAttention(dim, numFeatures);
const query = new Float32Array(dim).fill(0.5);
const keys = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const values = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const result = linear.compute(query, keys, values);
assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
console.log(' ✓ Linear attention works correctly');
});
test('Flash Attention', () => {
const dim = 64;
const blockSize = 16;
const flash = new WasmFlashAttention(dim, blockSize);
const query = new Float32Array(dim).fill(0.5);
const keys = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const values = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const result = flash.compute(query, keys, values);
assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
console.log(' ✓ Flash attention works correctly');
});
test('Local-Global Attention', () => {
const dim = 64;
const localWindow = 4;
const globalTokens = 2;
const localGlobal = new WasmLocalGlobalAttention(dim, localWindow, globalTokens);
const query = new Float32Array(dim).fill(0.5);
const keys = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const values = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const result = localGlobal.compute(query, keys, values);
assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
console.log(' ✓ Local-global attention works correctly');
});
test('Mixture of Experts (MoE) Attention', () => {
const dim = 64;
const numExperts = 4;
const topK = 2;
const moe = new WasmMoEAttention(dim, numExperts, topK);
const query = new Float32Array(dim).fill(0.5);
const keys = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const values = [
Array.from({ length: dim }, () => Math.random()),
Array.from({ length: dim }, () => Math.random())
];
const result = moe.compute(query, keys, values);
assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
console.log(' ✓ MoE attention works correctly');
});
});
console.log('\n✅ All WASM attention tests passed!\n');