187 lines
6.4 KiB
JavaScript
187 lines
6.4 KiB
JavaScript
/**
|
|
 * Integration tests for the ruvector-attention-wasm package.
 *
 * Exercises every attention mechanism exported by the published npm package.
|
|
*/
|
|
|
|
import { before, describe, test } from 'node:test';
import assert from 'node:assert';

// Import from published WASM package
import init, {
  scaled_dot_attention,
  WasmMultiHeadAttention,
  WasmHyperbolicAttention,
  WasmLinearAttention,
  WasmFlashAttention,
  WasmLocalGlobalAttention,
  WasmMoEAttention,
} from 'ruvector-attention-wasm';
|
|
|
|
/**
 * Seeded pseudo-random generator (mulberry32).
 *
 * Unseeded Math.random() makes a failing input impossible to reproduce;
 * seeding each test makes every run feed the WASM kernels identical data.
 *
 * @param {number} seed - Any integer; only its low 32 bits are used.
 * @returns {() => number} Function returning floats in [0, 1).
 */
function createRng(seed) {
  let state = seed >>> 0;
  return () => {
    state = (state + 0x6d2b79f5) >>> 0;
    let t = state;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

/**
 * Builds `count` rows of `dim` pseudo-random numbers in [0, scale).
 *
 * Rows are plain arrays (not typed arrays), the same shape these tests pass
 * for keys and values.
 *
 * @param {() => number} rng - Generator returned by createRng().
 * @param {number} count - Number of rows.
 * @param {number} dim - Length of each row.
 * @param {number} [scale=1] - Multiplier applied to every sample.
 * @returns {number[][]} `count` arrays of length `dim`.
 */
function randomRows(rng, count, dim, scale = 1) {
  return Array.from({ length: count }, () =>
    Array.from({ length: dim }, () => rng() * scale),
  );
}

/**
 * Output checks shared by every mechanism: the result must be a Float32Array
 * of length `dim` whose entries are all finite. Without the finiteness check a
 * NaN/Infinity result would pass the type and length checks unnoticed.
 *
 * @param {unknown} result - Value returned by an attention call.
 * @param {number} dim - Expected output length.
 */
function assertAttentionOutput(result, dim) {
  assert.ok(result instanceof Float32Array, 'Result should be Float32Array');
  assert.strictEqual(result.length, dim, `Result dimension should be ${dim}`);
  assert.ok(
    result.every((x) => Number.isFinite(x)),
    'Result should contain only finite values',
  );
}

/**
 * Runs `fn(instance)` and then releases the instance, even if `fn` throws.
 *
 * NOTE(review): assumes the Wasm* classes are wasm-bindgen wrappers exposing
 * free() for their WASM-side memory; the optional call is a no-op otherwise.
 *
 * @template T, R
 * @param {T} instance - Object to release afterwards.
 * @param {(instance: T) => R} fn - Work to perform with the instance.
 * @returns {R} Whatever `fn` returns.
 */
function withInstance(instance, fn) {
  try {
    return fn(instance);
  } finally {
    instance.free?.();
  }
}

describe('WASM Attention Package Tests', () => {
  // Initialize the WASM module once, before any test touches an export.
  before(async () => {
    await init();
  });

  test('Scaled Dot-Product Attention', () => {
    const dim = 64;
    const rng = createRng(1);
    const query = new Float32Array(dim).fill(0.5);
    const keys = randomRows(rng, 3, dim);
    const values = randomRows(rng, 3, dim);

    // NOTE(review): fourth argument is presumably an optional scale override — confirm.
    const result = scaled_dot_attention(query, keys, values, null);
    assertAttentionOutput(result, dim);
    console.log(' ✓ Scaled dot-product attention works correctly');
  });

  test('Multi-Head Attention', () => {
    const dim = 64;
    const numHeads = 8;

    withInstance(new WasmMultiHeadAttention(dim, numHeads), (mha) => {
      assert.strictEqual(mha.dim, dim, 'Dimension should match');
      assert.strictEqual(mha.num_heads, numHeads, 'Number of heads should match');

      const rng = createRng(2);
      const query = new Float32Array(dim).fill(0.5);
      const keys = randomRows(rng, 2, dim);
      const values = randomRows(rng, 2, dim);

      assertAttentionOutput(mha.compute(query, keys, values), dim);
    });
    console.log(' ✓ Multi-head attention works correctly');
  });

  test('Hyperbolic Attention', () => {
    const dim = 64;
    const curvature = 1.0;

    withInstance(new WasmHyperbolicAttention(dim, curvature), (hyperbolic) => {
      assert.strictEqual(hyperbolic.curvature, curvature, 'Curvature should match');

      const rng = createRng(3);
      // Query and keys are kept small (<= 0.1 per component); presumably this
      // keeps them inside the model's valid hyperbolic domain — confirm.
      const query = new Float32Array(dim).fill(0.1);
      const keys = randomRows(rng, 2, dim, 0.1);
      const values = randomRows(rng, 2, dim);

      assertAttentionOutput(hyperbolic.compute(query, keys, values), dim);
    });
    console.log(' ✓ Hyperbolic attention works correctly');
  });

  test('Linear Attention (Performer-style)', () => {
    const dim = 64;
    const numFeatures = 128;

    withInstance(new WasmLinearAttention(dim, numFeatures), (linear) => {
      const rng = createRng(4);
      const query = new Float32Array(dim).fill(0.5);
      const keys = randomRows(rng, 2, dim);
      const values = randomRows(rng, 2, dim);

      assertAttentionOutput(linear.compute(query, keys, values), dim);
    });
    console.log(' ✓ Linear attention works correctly');
  });

  test('Flash Attention', () => {
    const dim = 64;
    const blockSize = 16;

    withInstance(new WasmFlashAttention(dim, blockSize), (flash) => {
      const rng = createRng(5);
      const query = new Float32Array(dim).fill(0.5);
      const keys = randomRows(rng, 2, dim);
      const values = randomRows(rng, 2, dim);

      assertAttentionOutput(flash.compute(query, keys, values), dim);
    });
    console.log(' ✓ Flash attention works correctly');
  });

  test('Local-Global Attention', () => {
    const dim = 64;
    const localWindow = 4;
    const globalTokens = 2;

    withInstance(
      new WasmLocalGlobalAttention(dim, localWindow, globalTokens),
      (localGlobal) => {
        const rng = createRng(6);
        const query = new Float32Array(dim).fill(0.5);
        const keys = randomRows(rng, 4, dim);
        const values = randomRows(rng, 4, dim);

        assertAttentionOutput(localGlobal.compute(query, keys, values), dim);
      },
    );
    console.log(' ✓ Local-global attention works correctly');
  });

  test('Mixture of Experts (MoE) Attention', () => {
    const dim = 64;
    const numExperts = 4;
    const topK = 2;

    withInstance(new WasmMoEAttention(dim, numExperts, topK), (moe) => {
      const rng = createRng(7);
      const query = new Float32Array(dim).fill(0.5);
      const keys = randomRows(rng, 2, dim);
      const values = randomRows(rng, 2, dim);

      assertAttentionOutput(moe.compute(query, keys, values), dim);
    });
    console.log(' ✓ MoE attention works correctly');
  });
});
|
|
|
|
// node:test runs test bodies only after this module has finished evaluating,
// so a bare console.log here printed "all passed" before any test ran and even
// when tests failed. Defer the summary to process exit, and print it only when
// the exit code is 0. The test runner sets a non-zero exit code on failure.
process.on('exit', (code) => {
  if (code === 0) {
    console.log('\n✅ All WASM attention tests passed!\n');
  }
});
|