Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
62
vendor/ruvector/crates/ruvector-attention-unified-wasm/Cargo.toml
vendored
Normal file
62
vendor/ruvector/crates/ruvector-attention-unified-wasm/Cargo.toml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
[package]
|
||||
name = "ruvector-attention-unified-wasm"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["RuVector Team"]
|
||||
description = "Unified WebAssembly bindings for 18+ attention mechanisms: Neural, DAG, Graph, and Mamba SSM"
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/ruvnet/ruvector"
|
||||
keywords = ["attention", "wasm", "neural", "dag", "mamba"]
|
||||
categories = ["wasm", "science", "algorithms"]
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
# Core attention mechanisms (7 neural attention types)
|
||||
ruvector-attention = { version = "2.0", path = "../ruvector-attention", default-features = false, features = ["wasm"] }
|
||||
|
||||
# DAG attention mechanisms (7 DAG-specific attention types)
|
||||
ruvector-dag = { version = "2.0", path = "../ruvector-dag", default-features = false, features = ["wasm"] }
|
||||
|
||||
# GNN/Graph attention (GAT, GCN, GraphSAGE)
|
||||
ruvector-gnn = { version = "2.0", path = "../ruvector-gnn", default-features = false, features = ["wasm"] }
|
||||
|
||||
# WASM bindings
|
||||
wasm-bindgen = "0.2"
|
||||
js-sys = "0.3"
|
||||
web-sys = { version = "0.3", features = ["console"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde-wasm-bindgen = "0.6"
|
||||
serde_json = "1.0"
|
||||
|
||||
# Utils
|
||||
console_error_panic_hook = { version = "0.1", optional = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
# Allocator for smaller binary (optional)
|
||||
wee_alloc = { version = "0.4", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[features]
|
||||
default = ["console_error_panic_hook"]
|
||||
console_error_panic_hook = ["dep:console_error_panic_hook"]
|
||||
# Enable wee_alloc for ~10KB smaller WASM binary
|
||||
wee_alloc = ["dep:wee_alloc"]
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z"
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
strip = true
|
||||
|
||||
[profile.release.package."*"]
|
||||
opt-level = "z"
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = false
|
||||
553
vendor/ruvector/crates/ruvector-attention-unified-wasm/README.md
vendored
Normal file
553
vendor/ruvector/crates/ruvector-attention-unified-wasm/README.md
vendored
Normal file
@@ -0,0 +1,553 @@
|
||||
# ruvector-attention-unified-wasm
|
||||
|
||||
Unified WebAssembly bindings for 18+ attention mechanisms, combining Neural, DAG, Graph, and Mamba SSM attention types into a single npm package.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-attention-unified-wasm
|
||||
# or
|
||||
yarn add ruvector-attention-unified-wasm
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```javascript
|
||||
import init, {
|
||||
// Neural attention
|
||||
WasmScaledDotProductAttention,
|
||||
WasmMultiHeadAttention,
|
||||
|
||||
// DAG attention
|
||||
WasmQueryDag,
|
||||
WasmTopologicalAttention,
|
||||
|
||||
// Graph attention
|
||||
WasmGraphAttention,
|
||||
GraphAttentionType,
|
||||
|
||||
// SSM attention
|
||||
MambaSSMAttention,
|
||||
MambaConfig,
|
||||
|
||||
// Utilities
|
||||
UnifiedAttention,
|
||||
availableMechanisms,
|
||||
version
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Initialize WASM module
|
||||
await init();
|
||||
|
||||
console.log('Version:', version());
|
||||
console.log('Mechanisms:', availableMechanisms());
|
||||
```
|
||||
|
||||
## Attention Mechanism Categories
|
||||
|
||||
### 1. Neural Attention (7 mechanisms)
|
||||
|
||||
Standard transformer-style attention mechanisms for sequence processing.
|
||||
|
||||
#### Scaled Dot-Product Attention
|
||||
|
||||
```javascript
|
||||
import { WasmScaledDotProductAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create attention layer (dimension, dropout_rate)
|
||||
const attention = new WasmScaledDotProductAttention(64, 0.1);
|
||||
|
||||
// Prepare query, key, value vectors (as Float32Array)
|
||||
const query = new Float32Array(64); // [dim]
|
||||
const keys = new Float32Array(320); // [5, dim] = 5 key vectors
|
||||
const values = new Float32Array(320); // [5, dim] = 5 value vectors
|
||||
|
||||
// Fill with your embeddings...
|
||||
for (let i = 0; i < 64; i++) query[i] = Math.random();
|
||||
|
||||
// Compute attention output
|
||||
const output = attention.forward(query, keys, values, 5); // numKeys = 5
|
||||
console.log('Output shape:', output.length); // 64
|
||||
|
||||
// Get attention weights for visualization
|
||||
const weights = attention.getWeights(query, keys, 5);
|
||||
console.log('Attention weights:', weights); // [5] probabilities
|
||||
```
|
||||
|
||||
#### Multi-Head Attention
|
||||
|
||||
```javascript
|
||||
import { WasmMultiHeadAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create with dimensions and number of heads
|
||||
const mha = new WasmMultiHeadAttention(
|
||||
512, // model dimension
|
||||
8, // number of heads
|
||||
0.1 // dropout
|
||||
);
|
||||
|
||||
// Forward pass with batched inputs
|
||||
const queries = new Float32Array(512 * 10); // [batch=10, dim=512]
|
||||
const keys = new Float32Array(512 * 20); // [seq=20, dim=512]
|
||||
const values = new Float32Array(512 * 20);
|
||||
|
||||
const output = mha.forward(queries, keys, values, 10, 20);
|
||||
console.log('Output:', output.length); // 512 * 10 = 5120
|
||||
```
|
||||
|
||||
#### Hyperbolic Attention
|
||||
|
||||
For hierarchical data like trees and taxonomies.
|
||||
|
||||
```javascript
|
||||
import { WasmHyperbolicAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Curvature controls the hyperbolic space geometry
|
||||
const hyperbolic = new WasmHyperbolicAttention(64, -1.0);
|
||||
|
||||
const output = hyperbolic.forward(query, keys, values, 5);
|
||||
```
|
||||
|
||||
#### Linear Attention (Performer-style)
|
||||
|
||||
O(n) complexity for long sequences.
|
||||
|
||||
```javascript
|
||||
import { WasmLinearAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
const linear = new WasmLinearAttention(64);
|
||||
const output = linear.forward(query, keys, values, numKeys);
|
||||
```
|
||||
|
||||
#### Flash Attention
|
||||
|
||||
Memory-efficient blocked attention for large sequences.
|
||||
|
||||
```javascript
|
||||
import { WasmFlashAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Block size controls memory/compute tradeoff
|
||||
const flash = new WasmFlashAttention(64, 256); // dim=64, block_size=256
|
||||
const output = flash.forward(queries, keys, values, seqLen);
|
||||
```
|
||||
|
||||
#### Local-Global Attention
|
||||
|
||||
Sparse attention with global tokens (like Longformer).
|
||||
|
||||
```javascript
|
||||
import { WasmLocalGlobalAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
const lg = new WasmLocalGlobalAttention(
|
||||
64, // dimension
|
||||
128, // local window size
|
||||
4 // number of global tokens
|
||||
);
|
||||
const output = lg.forward(queries, keys, values, seqLen);
|
||||
```
|
||||
|
||||
#### Mixture of Experts Attention
|
||||
|
||||
Route tokens to specialized expert attention heads.
|
||||
|
||||
```javascript
|
||||
import { WasmMoEAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
const moe = new WasmMoEAttention(
|
||||
64, // dimension
|
||||
8, // number of experts
|
||||
2 // top-k experts per token
|
||||
);
|
||||
const output = moe.forward(input, seqLen);
|
||||
```
|
||||
|
||||
### 2. DAG Attention (7 mechanisms)
|
||||
|
||||
Graph-topology-aware attention for directed acyclic graphs.
|
||||
|
||||
#### Building a DAG
|
||||
|
||||
```javascript
|
||||
import { WasmQueryDag } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create DAG for query plan
|
||||
const dag = new WasmQueryDag();
|
||||
|
||||
// Add nodes (operator_type, cost)
|
||||
const scan = dag.addNode("scan", 100.0);
|
||||
const filter = dag.addNode("filter", 20.0);
|
||||
const join = dag.addNode("join", 50.0);
|
||||
const aggregate = dag.addNode("aggregate", 30.0);
|
||||
|
||||
// Add edges (from, to)
|
||||
dag.addEdge(scan, filter);
|
||||
dag.addEdge(filter, join);
|
||||
dag.addEdge(join, aggregate);
|
||||
|
||||
console.log('Nodes:', dag.nodeCount); // 4
|
||||
console.log('Edges:', dag.edgeCount); // 3
|
||||
console.log('JSON:', dag.toJson());
|
||||
```
|
||||
|
||||
#### Topological Attention
|
||||
|
||||
Position-based attention following DAG order.
|
||||
|
||||
```javascript
|
||||
import { WasmTopologicalAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// decay_factor controls position-based decay (0.0-1.0)
|
||||
const topo = new WasmTopologicalAttention(0.9);
|
||||
const scores = topo.forward(dag);
|
||||
console.log('Attention scores:', scores); // [0.35, 0.30, 0.20, 0.15]
|
||||
```
|
||||
|
||||
#### Causal Cone Attention
|
||||
|
||||
Lightcone-based attention respecting causal dependencies.
|
||||
|
||||
```javascript
|
||||
import { WasmCausalConeAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// future_discount, ancestor_weight
|
||||
const causal = new WasmCausalConeAttention(0.8, 0.9);
|
||||
const scores = causal.forward(dag);
|
||||
```
|
||||
|
||||
#### Critical Path Attention
|
||||
|
||||
Weight attention by critical execution path.
|
||||
|
||||
```javascript
|
||||
import { WasmCriticalPathAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// path_weight for critical path nodes, branch_penalty
|
||||
const critical = new WasmCriticalPathAttention(2.0, 0.5);
|
||||
const scores = critical.forward(dag);
|
||||
```
|
||||
|
||||
#### MinCut-Gated Attention
|
||||
|
||||
Flow-based gating through bottleneck nodes.
|
||||
|
||||
```javascript
|
||||
import { WasmMinCutGatedAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// gate_threshold determines bottleneck detection sensitivity
|
||||
const mincut = new WasmMinCutGatedAttention(0.5);
|
||||
const scores = mincut.forward(dag);
|
||||
```
|
||||
|
||||
#### Hierarchical Lorentz Attention
|
||||
|
||||
Multi-scale hyperbolic attention for DAG hierarchies.
|
||||
|
||||
```javascript
|
||||
import { WasmHierarchicalLorentzAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// curvature, temperature
|
||||
const lorentz = new WasmHierarchicalLorentzAttention(-1.0, 0.1);
|
||||
const scores = lorentz.forward(dag);
|
||||
```
|
||||
|
||||
#### Parallel Branch Attention
|
||||
|
||||
Branch-aware attention for parallel DAG structures.
|
||||
|
||||
```javascript
|
||||
import { WasmParallelBranchAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// max_branches, sync_penalty
|
||||
const parallel = new WasmParallelBranchAttention(8, 0.2);
|
||||
const scores = parallel.forward(dag);
|
||||
```
|
||||
|
||||
#### Temporal BTSP Attention
|
||||
|
||||
Behavioral Time-Series Pattern attention for temporal DAGs.
|
||||
|
||||
```javascript
|
||||
import { WasmTemporalBTSPAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// eligibility_decay, baseline_attention
|
||||
const btsp = new WasmTemporalBTSPAttention(0.95, 0.5);
|
||||
const scores = btsp.forward(dag);
|
||||
```
|
||||
|
||||
### 3. Graph Attention (3 mechanisms)
|
||||
|
||||
Graph neural network attention for arbitrary graph structures.
|
||||
|
||||
#### Graph Attention Networks (GAT)
|
||||
|
||||
```javascript
|
||||
import {
|
||||
WasmGraphAttention,
|
||||
GraphAttentionType
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create GAT layer
|
||||
const gat = new WasmGraphAttention(
|
||||
GraphAttentionType.GAT,
|
||||
64, // input dimension
|
||||
32, // output dimension
|
||||
8 // number of heads
|
||||
);
|
||||
|
||||
// Build adjacency list
|
||||
const adjacency = [
|
||||
[1, 2], // node 0 connects to 1, 2
|
||||
[0, 2, 3], // node 1 connects to 0, 2, 3
|
||||
[0, 1, 3], // node 2 connects to 0, 1, 3
|
||||
[1, 2] // node 3 connects to 1, 2
|
||||
];
|
||||
|
||||
// Node features [4 nodes x 64 dims]
|
||||
const features = new Float32Array(4 * 64);
|
||||
// ... fill with node embeddings
|
||||
|
||||
// Forward pass
|
||||
const output = gat.forward(features, adjacency, 4);
|
||||
console.log('Output shape:', output.length); // 4 * 32 = 128
|
||||
```
|
||||
|
||||
#### Graph Convolutional Networks (GCN)
|
||||
|
||||
```javascript
|
||||
const gcn = new WasmGraphAttention(
|
||||
GraphAttentionType.GCN,
|
||||
64,
|
||||
32,
|
||||
1 // GCN typically uses 1 head
|
||||
);
|
||||
|
||||
const output = gcn.forward(features, adjacency, numNodes);
|
||||
```
|
||||
|
||||
#### GraphSAGE
|
||||
|
||||
```javascript
|
||||
const sage = new WasmGraphAttention(
|
||||
GraphAttentionType.GraphSAGE,
|
||||
64,
|
||||
32,
|
||||
1
|
||||
);
|
||||
|
||||
const output = sage.forward(features, adjacency, numNodes);
|
||||
```
|
||||
|
||||
#### Factory Methods
|
||||
|
||||
```javascript
|
||||
import { GraphAttentionFactory } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
console.log(GraphAttentionFactory.availableTypes());
|
||||
// ["gat", "gcn", "graphsage"]
|
||||
|
||||
console.log(GraphAttentionFactory.getDescription("gat"));
|
||||
// "Graph Attention Networks with multi-head attention"
|
||||
|
||||
console.log(GraphAttentionFactory.getUseCases("gat"));
|
||||
// ["Node classification", "Link prediction", ...]
|
||||
```
|
||||
|
||||
### 4. State Space Models (1 mechanism)
|
||||
|
||||
#### Mamba SSM Attention
|
||||
|
||||
Selective State Space Model for efficient sequence modeling.
|
||||
|
||||
```javascript
|
||||
import {
|
||||
MambaSSMAttention,
|
||||
MambaConfig,
|
||||
HybridMambaAttention
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Configure Mamba
|
||||
const config = new MambaConfig(256) // model dimension
|
||||
.withStateDim(16)
|
||||
.withExpandFactor(2)
|
||||
.withConvKernelSize(4);
|
||||
|
||||
// Create Mamba layer
|
||||
const mamba = new MambaSSMAttention(config);
|
||||
|
||||
// Or use defaults
|
||||
const mamba2 = MambaSSMAttention.withDefaults(256);
|
||||
|
||||
// Forward pass
|
||||
const input = new Float32Array(256 * 100); // [seq_len=100, dim=256]
|
||||
const output = mamba.forward(input, 100);
|
||||
|
||||
// Get attention-like scores for visualization
|
||||
const scores = mamba.getAttentionScores(input, 100);
|
||||
```
|
||||
|
||||
#### Hybrid Mamba-Attention
|
||||
|
||||
Combine Mamba efficiency with local attention.
|
||||
|
||||
```javascript
|
||||
import { HybridMambaAttention, MambaConfig } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
const config = new MambaConfig(256);
|
||||
const hybrid = new HybridMambaAttention(config, 64); // local_window=64
|
||||
|
||||
const output = hybrid.forward(input, seqLen);
|
||||
console.log('Local window:', hybrid.localWindow); // 64
|
||||
```
|
||||
|
||||
## Unified Attention Selector
|
||||
|
||||
Select the right mechanism dynamically.
|
||||
|
||||
```javascript
|
||||
import { UnifiedAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create selector for any mechanism
|
||||
const selector = new UnifiedAttention("multi_head");
|
||||
|
||||
// Query mechanism properties
|
||||
console.log(selector.mechanism); // "multi_head"
|
||||
console.log(selector.category); // "neural"
|
||||
console.log(selector.supportsSequences); // true
|
||||
console.log(selector.supportsGraphs); // false
|
||||
console.log(selector.supportsHyperbolic); // false
|
||||
|
||||
// DAG mechanism
|
||||
const dagSelector = new UnifiedAttention("topological");
|
||||
console.log(dagSelector.category); // "dag"
|
||||
console.log(dagSelector.supportsGraphs); // true
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
```javascript
|
||||
import {
|
||||
softmax,
|
||||
temperatureSoftmax,
|
||||
cosineSimilarity,
|
||||
availableMechanisms,
|
||||
getStats
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Softmax normalization
|
||||
const probs = softmax(new Float32Array([1.0, 2.0, 3.0]));
|
||||
console.log(probs); // [0.09, 0.24, 0.67]
|
||||
|
||||
// Temperature-scaled softmax
|
||||
const sharpProbs = temperatureSoftmax(
|
||||
new Float32Array([1.0, 2.0, 3.0]),
|
||||
0.5 // lower temperature = sharper distribution
|
||||
);
|
||||
|
||||
// Cosine similarity
|
||||
const sim = cosineSimilarity(
|
||||
new Float32Array([1, 0, 0]),
|
||||
new Float32Array([0.707, 0.707, 0])
|
||||
);
|
||||
console.log(sim); // 0.707
|
||||
|
||||
// List all mechanisms
|
||||
const mechs = availableMechanisms();
|
||||
console.log(mechs.neural); // ["scaled_dot_product", "multi_head", ...]
|
||||
console.log(mechs.dag); // ["topological", "causal_cone", ...]
|
||||
console.log(mechs.graph); // ["gat", "gcn", "graphsage"]
|
||||
console.log(mechs.ssm); // ["mamba"]
|
||||
|
||||
// Library stats
|
||||
const stats = getStats();
|
||||
console.log(stats.total_mechanisms); // 18
|
||||
console.log(stats.version); // "0.1.0"
|
||||
```
|
||||
|
||||
## TypeScript Support
|
||||
|
||||
Full TypeScript definitions are included. Import types as needed:
|
||||
|
||||
```typescript
|
||||
import type {
|
||||
MambaConfig,
|
||||
GraphAttentionType,
|
||||
WasmQueryDag
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Reuse attention instances** - Creating new instances has overhead
|
||||
2. **Use typed arrays** - Pass `Float32Array` directly, not regular arrays
|
||||
3. **Batch when possible** - Multi-head attention supports batched inputs
|
||||
4. **Choose the right mechanism**:
|
||||
- Sequences: Scaled Dot-Product, Multi-Head, Linear, Flash
|
||||
- Long sequences: Linear, Flash, Mamba
|
||||
- Hierarchical data: Hyperbolic, Hierarchical Lorentz
|
||||
- Graphs: GAT, GCN, GraphSAGE
|
||||
- DAG structures: Topological, Critical Path, MinCut-Gated
|
||||
|
||||
## Browser Usage
|
||||
|
||||
```html
|
||||
<script type="module">
|
||||
import init, {
|
||||
WasmScaledDotProductAttention
|
||||
} from './pkg/ruvector_attention_unified_wasm.js';
|
||||
|
||||
async function run() {
|
||||
await init();
|
||||
|
||||
const attention = new WasmScaledDotProductAttention(64, 0.1);
|
||||
// ... use attention
|
||||
}
|
||||
|
||||
run();
|
||||
</script>
|
||||
```
|
||||
|
||||
## Node.js Usage
|
||||
|
||||
```javascript
|
||||
import { readFile } from 'fs/promises';
|
||||
import { initSync } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Load WASM binary
|
||||
const wasmBuffer = await readFile(
|
||||
'./node_modules/ruvector-attention-unified-wasm/ruvector_attention_unified_wasm_bg.wasm'
|
||||
);
|
||||
initSync(wasmBuffer);
|
||||
|
||||
// Now use the library
|
||||
import { WasmMultiHeadAttention } from 'ruvector-attention-unified-wasm';
|
||||
```
|
||||
|
||||
## Memory Management
|
||||
|
||||
WASM objects need explicit cleanup:
|
||||
|
||||
```javascript
|
||||
const attention = new WasmScaledDotProductAttention(64, 0.1);
|
||||
try {
|
||||
const output = attention.forward(query, keys, values, numKeys);
|
||||
// ... use output
|
||||
} finally {
|
||||
attention.free(); // Release WASM memory
|
||||
}
|
||||
|
||||
// Or use Symbol.dispose (requires TypeScript 5.2+)
|
||||
{
|
||||
using attention = new WasmScaledDotProductAttention(64, 0.1);
|
||||
// Automatically freed at end of block
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
|
||||
## Links
|
||||
|
||||
- [GitHub Repository](https://github.com/ruvnet/ruvector)
|
||||
- [Documentation](https://ruvector.dev/docs)
|
||||
- [NPM Package](https://www.npmjs.com/package/ruvector-attention-unified-wasm)
|
||||
401
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/README.md
vendored
Normal file
401
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/README.md
vendored
Normal file
@@ -0,0 +1,401 @@
|
||||
# @ruvector/attention-unified-wasm - 18+ Attention Mechanisms in WASM
|
||||
|
||||
[](https://www.npmjs.com/package/ruvector-attention-unified-wasm)
|
||||
[](https://github.com/ruvnet/ruvector)
|
||||
[](https://www.npmjs.com/package/ruvector-attention-unified-wasm)
|
||||
[](https://webassembly.org/)
|
||||
|
||||
**Unified WebAssembly library** with 18+ attention mechanisms spanning Neural, DAG, Graph, and State Space Model categories. Single import for all your attention needs in browser and edge environments.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **7 Neural Attention**: Scaled dot-product, multi-head, hyperbolic, linear, flash, local-global, MoE
|
||||
- **7 DAG Attention**: Topological, causal cone, critical path, MinCut-gated, hierarchical Lorentz, parallel branch, temporal BTSP
|
||||
- **3 Graph Attention**: GAT, GCN, GraphSAGE
|
||||
- **1 State Space**: Mamba SSM with hybrid attention
|
||||
- **Unified API**: Single selector for all mechanisms
|
||||
- **WASM-Optimized**: Runs in browsers, Node.js, and edge runtimes
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-attention-unified-wasm
|
||||
# or
|
||||
yarn add ruvector-attention-unified-wasm
|
||||
# or
|
||||
pnpm add ruvector-attention-unified-wasm
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import init, {
|
||||
UnifiedAttention,
|
||||
availableMechanisms,
|
||||
scaledDotAttention,
|
||||
WasmMultiHeadAttention,
|
||||
MambaSSMAttention,
|
||||
MambaConfig
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
await init();
|
||||
|
||||
// List all available mechanisms
|
||||
const mechanisms = availableMechanisms();
|
||||
console.log(mechanisms);
|
||||
// { neural: [...], dag: [...], graph: [...], ssm: [...] }
|
||||
|
||||
// Use unified selector
|
||||
const attention = new UnifiedAttention("multi_head");
|
||||
console.log(`Category: ${attention.category}`); // "neural"
|
||||
console.log(`Supports sequences: ${attention.supportsSequences()}`);
|
||||
|
||||
// Direct attention computation
|
||||
const query = new Float32Array([1.0, 0.5, 0.3, 0.1]);
|
||||
const keys = [new Float32Array([0.9, 0.4, 0.2, 0.1])];
|
||||
const values = [new Float32Array([1.0, 1.0, 1.0, 1.0])];
|
||||
const output = scaledDotAttention(query, keys, values);
|
||||
```
|
||||
|
||||
## Attention Categories
|
||||
|
||||
### Neural Attention (7 mechanisms)
|
||||
|
||||
Standard transformer-style attention mechanisms for sequence processing.
|
||||
|
||||
```typescript
|
||||
import {
|
||||
scaledDotAttention,
|
||||
WasmMultiHeadAttention,
|
||||
WasmHyperbolicAttention,
|
||||
WasmLinearAttention,
|
||||
WasmFlashAttention,
|
||||
WasmLocalGlobalAttention,
|
||||
WasmMoEAttention
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Scaled Dot-Product Attention
|
||||
const output = scaledDotAttention(query, keys, values, scale);
|
||||
|
||||
// Multi-Head Attention
|
||||
const mha = new WasmMultiHeadAttention(256, 8); // 256 dim, 8 heads
|
||||
const attended = mha.compute(query, keys, values);
|
||||
console.log(`Heads: ${mha.numHeads}, Head dim: ${mha.headDim}`);
|
||||
|
||||
// Hyperbolic Attention (for hierarchical data)
|
||||
const hyperbolic = new WasmHyperbolicAttention(64, -1.0); // curvature = -1
|
||||
const hypOut = hyperbolic.compute(query, keys, values);
|
||||
|
||||
// Linear Attention (O(n) complexity)
|
||||
const linear = new WasmLinearAttention(64, 32); // 32 random features
|
||||
const linOut = linear.compute(query, keys, values);
|
||||
|
||||
// Flash Attention (memory-efficient)
|
||||
const flash = new WasmFlashAttention(64, 32); // block size 32
|
||||
const flashOut = flash.compute(query, keys, values);
|
||||
|
||||
// Local-Global Attention (sparse)
|
||||
const localGlobal = new WasmLocalGlobalAttention(64, 128, 4); // window=128, 4 global
|
||||
const lgOut = localGlobal.compute(query, keys, values);
|
||||
|
||||
// Mixture of Experts Attention
|
||||
const moe = new WasmMoEAttention(64, 8, 2); // 8 experts, top-2
|
||||
const moeOut = moe.compute(query, keys, values);
|
||||
```
|
||||
|
||||
### DAG Attention (7 mechanisms)
|
||||
|
||||
Specialized attention for Directed Acyclic Graphs, query plans, and workflow optimization.
|
||||
|
||||
```typescript
|
||||
import {
|
||||
WasmQueryDag,
|
||||
WasmTopologicalAttention,
|
||||
WasmCausalConeAttention,
|
||||
WasmCriticalPathAttention,
|
||||
WasmMinCutGatedAttention,
|
||||
WasmHierarchicalLorentzAttention,
|
||||
WasmParallelBranchAttention,
|
||||
WasmTemporalBTSPAttention
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create a query DAG
|
||||
const dag = new WasmQueryDag();
|
||||
const scan = dag.addNode("scan", 10.0);
|
||||
const filter = dag.addNode("filter", 5.0);
|
||||
const join = dag.addNode("join", 20.0);
|
||||
const aggregate = dag.addNode("aggregate", 15.0);
|
||||
|
||||
dag.addEdge(scan, filter);
|
||||
dag.addEdge(filter, join);
|
||||
dag.addEdge(scan, join);
|
||||
dag.addEdge(join, aggregate);
|
||||
|
||||
// Topological Attention (position-aware)
|
||||
const topo = new WasmTopologicalAttention(0.9); // decay factor
|
||||
const topoScores = topo.forward(dag);
|
||||
|
||||
// Causal Cone Attention (lightcone-based)
|
||||
const causal = new WasmCausalConeAttention(0.8, 0.6); // future discount, ancestor weight
|
||||
const causalScores = causal.forward(dag);
|
||||
|
||||
// Critical Path Attention
|
||||
const critical = new WasmCriticalPathAttention(2.0, 0.5); // path weight, branch penalty
|
||||
const criticalScores = critical.forward(dag);
|
||||
|
||||
// MinCut-Gated Attention (flow-based)
|
||||
const mincut = new WasmMinCutGatedAttention(0.5); // gate threshold
|
||||
const mincutScores = mincut.forward(dag);
|
||||
|
||||
// Hierarchical Lorentz Attention (hyperbolic DAG)
|
||||
const lorentz = new WasmHierarchicalLorentzAttention(-1.0, 0.1); // curvature, temperature
|
||||
const lorentzScores = lorentz.forward(dag);
|
||||
|
||||
// Parallel Branch Attention
|
||||
const parallel = new WasmParallelBranchAttention(4, 0.2); // max branches, sync penalty
|
||||
const parallelScores = parallel.forward(dag);
|
||||
|
||||
// Temporal BTSP Attention
|
||||
const btsp = new WasmTemporalBTSPAttention(0.95, 0.1); // decay, baseline
|
||||
const btspScores = btsp.forward(dag);
|
||||
```
|
||||
|
||||
### Graph Attention (3 mechanisms)
|
||||
|
||||
Attention mechanisms for graph-structured data.
|
||||
|
||||
```typescript
|
||||
import {
|
||||
WasmGNNLayer,
|
||||
GraphAttentionFactory,
|
||||
graphHierarchicalForward,
|
||||
graphDifferentiableSearch,
|
||||
WasmSearchConfig
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create GNN layer with attention
|
||||
const gnn = new WasmGNNLayer(
|
||||
64, // input dimension
|
||||
128, // hidden dimension
|
||||
4, // attention heads
|
||||
0.1 // dropout
|
||||
);
|
||||
|
||||
// Forward pass for a node
|
||||
const nodeEmbed = new Float32Array(64);
|
||||
const neighborEmbeds = [
|
||||
new Float32Array(64),
|
||||
new Float32Array(64)
|
||||
];
|
||||
const edgeWeights = new Float32Array([0.8, 0.6]);
|
||||
|
||||
const updated = gnn.forward(nodeEmbed, neighborEmbeds, edgeWeights);
|
||||
console.log(`Output dim: ${gnn.outputDim}`);
|
||||
|
||||
// Get available graph attention types
|
||||
const types = GraphAttentionFactory.availableTypes(); // ["GAT", "GCN", "GraphSAGE"]
|
||||
|
||||
// Differentiable search
|
||||
const config = new WasmSearchConfig(5, 0.1); // top-5, temperature
|
||||
const candidates = [query, ...keys];
|
||||
const searchResults = graphDifferentiableSearch(query, candidates, config);
|
||||
|
||||
// Hierarchical forward through multiple layers
|
||||
const layers = [gnn, gnn2, gnn3];
|
||||
const final = graphHierarchicalForward(query, layerEmbeddings, layers);
|
||||
```
|
||||
|
||||
### Mamba SSM (State Space Model)
|
||||
|
||||
Selective State Space Model for efficient sequence processing with O(n) complexity.
|
||||
|
||||
```typescript
|
||||
import {
|
||||
MambaConfig,
|
||||
MambaSSMAttention,
|
||||
HybridMambaAttention
|
||||
} from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Configure Mamba
|
||||
const config = new MambaConfig(256) // d_model = 256
|
||||
.withStateDim(16) // state space dimension
|
||||
.withExpandFactor(2) // expansion factor
|
||||
.withConvKernelSize(4); // conv kernel
|
||||
|
||||
console.log(`Dim: ${config.dim}, State: ${config.state_dim}`);
|
||||
|
||||
// Create Mamba SSM Attention
|
||||
const mamba = new MambaSSMAttention(config);
|
||||
console.log(`Inner dim: ${mamba.innerDim}`);
|
||||
|
||||
// Or use defaults
|
||||
const mambaDefault = MambaSSMAttention.withDefaults(128);
|
||||
|
||||
// Forward pass (seq_len, dim) flattened to 1D
|
||||
const seqLen = 32;
|
||||
const input = new Float32Array(seqLen * 256);
|
||||
const output = mamba.forward(input, seqLen);
|
||||
|
||||
// Get pseudo-attention scores for visualization
|
||||
const scores = mamba.getAttentionScores(input, seqLen);
|
||||
|
||||
// Hybrid Mamba + Local Attention
|
||||
const hybrid = new HybridMambaAttention(config, 64); // local window = 64
|
||||
const hybridOut = hybrid.forward(input, seqLen);
|
||||
console.log(`Local window: ${hybrid.localWindow}`);
|
||||
```
|
||||
|
||||
## Unified Selector API
|
||||
|
||||
```typescript
|
||||
import { UnifiedAttention } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Create selector for any mechanism
|
||||
const attention = new UnifiedAttention("mamba");
|
||||
|
||||
// Query capabilities
|
||||
console.log(`Mechanism: ${attention.mechanism}`); // "mamba"
|
||||
console.log(`Category: ${attention.category}`); // "ssm"
|
||||
console.log(`Supports sequences: ${attention.supportsSequences()}`); // true
|
||||
console.log(`Supports graphs: ${attention.supportsGraphs()}`); // false
|
||||
console.log(`Supports hyperbolic: ${attention.supportsHyperbolic()}`); // false
|
||||
|
||||
// Valid mechanisms:
|
||||
// Neural: scaled_dot_product, multi_head, hyperbolic, linear, flash, local_global, moe
|
||||
// DAG: topological, causal_cone, critical_path, mincut_gated, hierarchical_lorentz, parallel_branch, temporal_btsp
|
||||
// Graph: gat, gcn, graphsage
|
||||
// SSM: mamba
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
```typescript
|
||||
import { softmax, temperatureSoftmax, cosineSimilarity, getStats } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
// Softmax normalization
|
||||
const logits = new Float32Array([1.0, 2.0, 3.0]);
|
||||
const probs = softmax(logits);
|
||||
|
||||
// Temperature-scaled softmax
|
||||
const sharper = temperatureSoftmax(logits, 0.5); // More peaked
|
||||
const flatter = temperatureSoftmax(logits, 2.0); // More uniform
|
||||
|
||||
// Cosine similarity
|
||||
const a = new Float32Array([1, 0, 0]);
|
||||
const b = new Float32Array([0.7, 0.7, 0]);
|
||||
const sim = cosineSimilarity(a, b);
|
||||
|
||||
// Library statistics
|
||||
const stats = getStats();
|
||||
console.log(`Total mechanisms: ${stats.total_mechanisms}`); // 18
|
||||
console.log(`Neural: ${stats.neural_count}`); // 7
|
||||
console.log(`DAG: ${stats.dag_count}`); // 7
|
||||
console.log(`Graph: ${stats.graph_count}`); // 3
|
||||
console.log(`SSM: ${stats.ssm_count}`); // 1
|
||||
```
|
||||
|
||||
## Tensor Compression
|
||||
|
||||
```typescript
|
||||
import { WasmTensorCompress } from 'ruvector-attention-unified-wasm';
|
||||
|
||||
const compressor = new WasmTensorCompress();
|
||||
const embedding = new Float32Array(256);
|
||||
|
||||
// Compress based on access frequency
|
||||
const compressed = compressor.compress(embedding, 0.5); // 50% access frequency
|
||||
const decompressed = compressor.decompress(compressed);
|
||||
|
||||
// Or specify compression level directly
|
||||
const pq8 = compressor.compressWithLevel(embedding, "pq8"); // 8-bit product quantization
|
||||
|
||||
// Compression levels: "none", "half", "pq8", "pq4", "binary"
|
||||
const ratio = compressor.getCompressionRatio(0.5);
|
||||
```
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
| Mechanism | Complexity | Latency (256-dim) |
|
||||
|-----------|------------|-------------------|
|
||||
| Scaled Dot-Product | O(n^2) | ~50us |
|
||||
| Multi-Head (8 heads) | O(n^2) | ~200us |
|
||||
| Linear | O(n) | ~30us |
|
||||
| Flash | O(n^2) | ~100us (memory-efficient) |
|
||||
| Mamba SSM | O(n) | ~80us |
|
||||
| Topological DAG | O(V+E) | ~40us |
|
||||
| GAT | O(E*h) | ~150us |
|
||||
|
||||
## API Reference Summary
|
||||
|
||||
### Neural Attention
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `WasmMultiHeadAttention` | Parallel attention heads |
|
||||
| `WasmHyperbolicAttention` | Hyperbolic space attention |
|
||||
| `WasmLinearAttention` | O(n) performer-style |
|
||||
| `WasmFlashAttention` | Memory-efficient blocked |
|
||||
| `WasmLocalGlobalAttention` | Sparse with global tokens |
|
||||
| `WasmMoEAttention` | Mixture of experts |
|
||||
|
||||
### DAG Attention
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `WasmTopologicalAttention` | Position in topological order |
|
||||
| `WasmCausalConeAttention` | Lightcone causality |
|
||||
| `WasmCriticalPathAttention` | Critical path weighting |
|
||||
| `WasmMinCutGatedAttention` | Flow-based gating |
|
||||
| `WasmHierarchicalLorentzAttention` | Multi-scale hyperbolic |
|
||||
| `WasmParallelBranchAttention` | Parallel DAG branches |
|
||||
| `WasmTemporalBTSPAttention` | Temporal eligibility traces |
|
||||
|
||||
### Graph Attention
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `WasmGNNLayer` | Multi-head graph attention |
|
||||
| `GraphAttentionFactory` | Factory for graph attention types |
|
||||
|
||||
### State Space
|
||||
|
||||
| Class | Description |
|
||||
|-------|-------------|
|
||||
| `MambaSSMAttention` | Selective state space model |
|
||||
| `HybridMambaAttention` | Mamba + local attention |
|
||||
| `MambaConfig` | Mamba configuration |
|
||||
|
||||
## Use Cases
|
||||
|
||||
- **Transformers**: Standard and efficient attention variants
|
||||
- **Query Optimization**: DAG-aware attention for SQL planners
|
||||
- **Knowledge Graphs**: Graph attention for entity reasoning
|
||||
- **Long Sequences**: O(n) attention with Mamba SSM
|
||||
- **Hierarchical Data**: Hyperbolic attention for trees
|
||||
- **Sparse Attention**: Local-global for long documents
|
||||
|
||||
## Bundle Size
|
||||
|
||||
- **WASM binary**: ~331KB (uncompressed)
|
||||
- **Gzip compressed**: ~120KB
|
||||
- **JavaScript glue**: ~12KB
|
||||
|
||||
## Related Packages
|
||||
|
||||
- [ruvector-learning-wasm](https://www.npmjs.com/package/ruvector-learning-wasm) - MicroLoRA adaptation
|
||||
- [ruvector-nervous-system-wasm](https://www.npmjs.com/package/ruvector-nervous-system-wasm) - Bio-inspired neural
|
||||
- [ruvector-economy-wasm](https://www.npmjs.com/package/ruvector-economy-wasm) - CRDT credit economy
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
|
||||
## Links
|
||||
|
||||
- [GitHub Repository](https://github.com/ruvnet/ruvector)
|
||||
- [Full Documentation](https://ruv.io)
|
||||
- [Bug Reports](https://github.com/ruvnet/ruvector/issues)
|
||||
|
||||
---
|
||||
|
||||
**Keywords**: attention mechanism, transformer, multi-head attention, DAG attention, graph neural network, GAT, GCN, GraphSAGE, Mamba, SSM, state space model, WebAssembly, WASM, hyperbolic attention, linear attention, flash attention, query optimization, neural network, deep learning, browser ML
|
||||
43
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/package.json
vendored
Normal file
43
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/package.json
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"name": "@ruvector/attention-unified-wasm",
|
||||
"type": "module",
|
||||
"collaborators": [
|
||||
"RuVector Team"
|
||||
],
|
||||
"author": "RuVector Team <ruvnet@users.noreply.github.com>",
|
||||
"description": "Unified WebAssembly bindings for 18+ attention mechanisms: Neural, DAG, Graph, and Mamba SSM",
|
||||
"version": "0.1.29",
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/ruvnet/ruvector"
|
||||
},
|
||||
"bugs": {
|
||||
"url": "https://github.com/ruvnet/ruvector/issues"
|
||||
},
|
||||
"files": [
|
||||
"ruvector_attention_unified_wasm_bg.wasm",
|
||||
"ruvector_attention_unified_wasm.js",
|
||||
"ruvector_attention_unified_wasm.d.ts",
|
||||
"ruvector_attention_unified_wasm_bg.wasm.d.ts",
|
||||
"README.md"
|
||||
],
|
||||
"main": "ruvector_attention_unified_wasm.js",
|
||||
"homepage": "https://ruv.io",
|
||||
"types": "ruvector_attention_unified_wasm.d.ts",
|
||||
"sideEffects": [
|
||||
"./snippets/*"
|
||||
],
|
||||
"keywords": [
|
||||
"attention",
|
||||
"wasm",
|
||||
"neural",
|
||||
"dag",
|
||||
"mamba",
|
||||
"ruvector",
|
||||
"webassembly",
|
||||
"transformer",
|
||||
"graph-attention",
|
||||
"state-space-models"
|
||||
]
|
||||
}
|
||||
790
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/ruvector_attention_unified_wasm.d.ts
vendored
Normal file
790
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/ruvector_attention_unified_wasm.d.ts
vendored
Normal file
@@ -0,0 +1,790 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export class DagAttentionFactory {
|
||||
private constructor();
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Get available DAG attention types
|
||||
*/
|
||||
static availableTypes(): any;
|
||||
/**
|
||||
* Get description for a DAG attention type
|
||||
*/
|
||||
static getDescription(attention_type: string): string;
|
||||
}
|
||||
|
||||
export class GraphAttentionFactory {
|
||||
private constructor();
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Get recommended use cases for a graph attention type
|
||||
*/
|
||||
static getUseCases(attention_type: string): any;
|
||||
/**
|
||||
* Get available graph attention types
|
||||
*/
|
||||
static availableTypes(): any;
|
||||
/**
|
||||
* Get description for a graph attention type
|
||||
*/
|
||||
static getDescription(attention_type: string): string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Graph attention mechanism types
|
||||
*/
|
||||
export enum GraphAttentionType {
|
||||
/**
|
||||
* Graph Attention Networks (Velickovic et al., 2018)
|
||||
*/
|
||||
GAT = 0,
|
||||
/**
|
||||
* Graph Convolutional Networks (Kipf & Welling, 2017)
|
||||
*/
|
||||
GCN = 1,
|
||||
/**
|
||||
* GraphSAGE (Hamilton et al., 2017)
|
||||
*/
|
||||
GraphSAGE = 2,
|
||||
}
|
||||
|
||||
export class HybridMambaAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new hybrid Mamba-Attention layer
|
||||
*/
|
||||
constructor(config: MambaConfig, local_window: number);
|
||||
/**
|
||||
* Forward pass
|
||||
*/
|
||||
forward(input: Float32Array, seq_len: number): Float32Array;
|
||||
/**
|
||||
* Get local window size
|
||||
*/
|
||||
readonly localWindow: number;
|
||||
}
|
||||
|
||||
export class MambaConfig {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Set state space dimension
|
||||
*/
|
||||
withStateDim(state_dim: number): MambaConfig;
|
||||
/**
|
||||
* Set expansion factor
|
||||
*/
|
||||
withExpandFactor(factor: number): MambaConfig;
|
||||
/**
|
||||
* Set convolution kernel size
|
||||
*/
|
||||
withConvKernelSize(size: number): MambaConfig;
|
||||
/**
|
||||
* Create a new Mamba configuration
|
||||
*/
|
||||
constructor(dim: number);
|
||||
/**
|
||||
* Model dimension (d_model)
|
||||
*/
|
||||
dim: number;
|
||||
/**
|
||||
* State space dimension (n)
|
||||
*/
|
||||
state_dim: number;
|
||||
/**
|
||||
* Expansion factor for inner dimension
|
||||
*/
|
||||
expand_factor: number;
|
||||
/**
|
||||
* Convolution kernel size
|
||||
*/
|
||||
conv_kernel_size: number;
|
||||
/**
|
||||
* Delta (discretization step) range minimum
|
||||
*/
|
||||
dt_min: number;
|
||||
/**
|
||||
* Delta range maximum
|
||||
*/
|
||||
dt_max: number;
|
||||
/**
|
||||
* Whether to use learnable D skip connection
|
||||
*/
|
||||
use_d_skip: boolean;
|
||||
}
|
||||
|
||||
export class MambaSSMAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create with default configuration
|
||||
*/
|
||||
static withDefaults(dim: number): MambaSSMAttention;
|
||||
/**
|
||||
* Compute attention-like scores (for visualization/analysis)
|
||||
*
|
||||
* Returns pseudo-attention scores showing which positions influence output
|
||||
*/
|
||||
getAttentionScores(input: Float32Array, seq_len: number): Float32Array;
|
||||
/**
|
||||
* Create a new Mamba SSM attention layer
|
||||
*/
|
||||
constructor(config: MambaConfig);
|
||||
/**
|
||||
* Forward pass through Mamba SSM
|
||||
*
|
||||
* # Arguments
|
||||
* * `input` - Input sequence (seq_len, dim) flattened to 1D
|
||||
* * `seq_len` - Sequence length
|
||||
*
|
||||
* # Returns
|
||||
* Output sequence (seq_len, dim) flattened to 1D
|
||||
*/
|
||||
forward(input: Float32Array, seq_len: number): Float32Array;
|
||||
/**
|
||||
* Get the configuration
|
||||
*/
|
||||
readonly config: MambaConfig;
|
||||
/**
|
||||
* Get the inner dimension
|
||||
*/
|
||||
readonly innerDim: number;
|
||||
}
|
||||
|
||||
export class UnifiedAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Check if this mechanism supports graph/DAG structures
|
||||
*/
|
||||
supportsGraphs(): boolean;
|
||||
/**
|
||||
* Check if this mechanism supports sequence processing
|
||||
*/
|
||||
supportsSequences(): boolean;
|
||||
/**
|
||||
* Check if this mechanism supports hyperbolic geometry
|
||||
*/
|
||||
supportsHyperbolic(): boolean;
|
||||
/**
|
||||
* Create a new unified attention selector
|
||||
*/
|
||||
constructor(mechanism: string);
|
||||
/**
|
||||
* Get the category of the selected mechanism
|
||||
*/
|
||||
readonly category: string;
|
||||
/**
|
||||
* Get the currently selected mechanism type
|
||||
*/
|
||||
readonly mechanism: string;
|
||||
}
|
||||
|
||||
export class WasmCausalConeAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new causal cone attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `future_discount` - Discount for future nodes
|
||||
* * `ancestor_weight` - Weight for ancestor influence
|
||||
*/
|
||||
constructor(future_discount: number, ancestor_weight: number);
|
||||
/**
|
||||
* Compute attention scores for the DAG
|
||||
*/
|
||||
forward(dag: WasmQueryDag): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmCriticalPathAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new critical path attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `path_weight` - Weight for critical path membership
|
||||
* * `branch_penalty` - Penalty for branching nodes
|
||||
*/
|
||||
constructor(path_weight: number, branch_penalty: number);
|
||||
/**
|
||||
* Compute attention scores for the DAG
|
||||
*/
|
||||
forward(dag: WasmQueryDag): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmFlashAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new flash attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `block_size` - Block size for tiled computation
|
||||
*/
|
||||
constructor(dim: number, block_size: number);
|
||||
/**
|
||||
* Compute flash attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmGNNLayer {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new GNN layer with attention
|
||||
*
|
||||
* # Arguments
|
||||
* * `input_dim` - Dimension of input node embeddings
|
||||
* * `hidden_dim` - Dimension of hidden representations
|
||||
* * `heads` - Number of attention heads
|
||||
* * `dropout` - Dropout rate (0.0 to 1.0)
|
||||
*/
|
||||
constructor(input_dim: number, hidden_dim: number, heads: number, dropout: number);
|
||||
/**
|
||||
* Forward pass through the GNN layer
|
||||
*
|
||||
* # Arguments
|
||||
* * `node_embedding` - Current node's embedding (Float32Array)
|
||||
* * `neighbor_embeddings` - Embeddings of neighbor nodes (array of Float32Arrays)
|
||||
* * `edge_weights` - Weights of edges to neighbors (Float32Array)
|
||||
*
|
||||
* # Returns
|
||||
* Updated node embedding (Float32Array)
|
||||
*/
|
||||
forward(node_embedding: Float32Array, neighbor_embeddings: any, edge_weights: Float32Array): Float32Array;
|
||||
/**
|
||||
* Get the output dimension
|
||||
*/
|
||||
readonly outputDim: number;
|
||||
}
|
||||
|
||||
export class WasmHierarchicalLorentzAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new hierarchical Lorentz attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `curvature` - Hyperbolic curvature parameter
|
||||
* * `temperature` - Temperature for softmax
|
||||
*/
|
||||
constructor(curvature: number, temperature: number);
|
||||
/**
|
||||
* Compute attention scores for the DAG
|
||||
*/
|
||||
forward(dag: WasmQueryDag): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmHyperbolicAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new hyperbolic attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `curvature` - Hyperbolic curvature parameter (negative for hyperbolic space)
|
||||
*/
|
||||
constructor(dim: number, curvature: number);
|
||||
/**
|
||||
* Compute hyperbolic attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Get the curvature parameter
|
||||
*/
|
||||
readonly curvature: number;
|
||||
}
|
||||
|
||||
export class WasmLinearAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new linear attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_features` - Number of random features for kernel approximation
|
||||
*/
|
||||
constructor(dim: number, num_features: number);
|
||||
/**
|
||||
* Compute linear attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmLocalGlobalAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new local-global attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `local_window` - Size of local attention window
|
||||
* * `global_tokens` - Number of global attention tokens
|
||||
*/
|
||||
constructor(dim: number, local_window: number, global_tokens: number);
|
||||
/**
|
||||
* Compute local-global attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmMinCutGatedAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new MinCut-gated attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `gate_threshold` - Threshold for gating (0.0-1.0)
|
||||
*/
|
||||
constructor(gate_threshold: number);
|
||||
/**
|
||||
* Compute attention scores for the DAG
|
||||
*/
|
||||
forward(dag: WasmQueryDag): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmMoEAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new MoE attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension
|
||||
* * `num_experts` - Number of expert attention mechanisms
|
||||
* * `top_k` - Number of experts to activate per query
|
||||
*/
|
||||
constructor(dim: number, num_experts: number, top_k: number);
|
||||
/**
|
||||
* Compute MoE attention
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmMultiHeadAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new multi-head attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `dim` - Embedding dimension (must be divisible by num_heads)
|
||||
* * `num_heads` - Number of parallel attention heads
|
||||
*/
|
||||
constructor(dim: number, num_heads: number);
|
||||
/**
|
||||
* Compute multi-head attention
|
||||
*
|
||||
* # Arguments
|
||||
* * `query` - Query vector
|
||||
* * `keys` - Array of key vectors
|
||||
* * `values` - Array of value vectors
|
||||
*/
|
||||
compute(query: Float32Array, keys: any, values: any): Float32Array;
|
||||
/**
|
||||
* Get the embedding dimension
|
||||
*/
|
||||
readonly dim: number;
|
||||
/**
|
||||
* Get the dimension per head
|
||||
*/
|
||||
readonly headDim: number;
|
||||
/**
|
||||
* Get the number of attention heads
|
||||
*/
|
||||
readonly numHeads: number;
|
||||
}
|
||||
|
||||
export class WasmParallelBranchAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new parallel branch attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `max_branches` - Maximum number of branches to consider
|
||||
* * `sync_penalty` - Penalty for synchronization between branches
|
||||
*/
|
||||
constructor(max_branches: number, sync_penalty: number);
|
||||
/**
|
||||
* Compute attention scores for the DAG
|
||||
*/
|
||||
forward(dag: WasmQueryDag): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmQueryDag {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new empty DAG
|
||||
*/
|
||||
constructor();
|
||||
/**
|
||||
* Serialize to JSON
|
||||
*/
|
||||
toJson(): string;
|
||||
/**
|
||||
* Add an edge between nodes
|
||||
*
|
||||
* # Arguments
|
||||
* * `from` - Source node ID
|
||||
* * `to` - Target node ID
|
||||
*
|
||||
* # Returns
|
||||
* True if edge was added successfully
|
||||
*/
|
||||
addEdge(from: number, to: number): boolean;
|
||||
/**
|
||||
* Add a node with operator type and cost
|
||||
*
|
||||
* # Arguments
|
||||
* * `op_type` - Operator type: "scan", "filter", "join", "aggregate", "project", "sort"
|
||||
* * `cost` - Estimated execution cost
|
||||
*
|
||||
* # Returns
|
||||
* Node ID
|
||||
*/
|
||||
addNode(op_type: string, cost: number): number;
|
||||
/**
|
||||
* Get the number of edges
|
||||
*/
|
||||
readonly edgeCount: number;
|
||||
/**
|
||||
* Get the number of nodes
|
||||
*/
|
||||
readonly nodeCount: number;
|
||||
}
|
||||
|
||||
export class WasmSearchConfig {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new search configuration
|
||||
*/
|
||||
constructor(k: number, temperature: number);
|
||||
/**
|
||||
* Number of top results to return
|
||||
*/
|
||||
k: number;
|
||||
/**
|
||||
* Temperature for softmax
|
||||
*/
|
||||
temperature: number;
|
||||
}
|
||||
|
||||
export class WasmTemporalBTSPAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new temporal BTSP attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `eligibility_decay` - Decay rate for eligibility traces (0.0-1.0)
|
||||
* * `baseline_attention` - Baseline attention for nodes without history
|
||||
*/
|
||||
constructor(eligibility_decay: number, baseline_attention: number);
|
||||
/**
|
||||
* Compute attention scores for the DAG
|
||||
*/
|
||||
forward(dag: WasmQueryDag): Float32Array;
|
||||
}
|
||||
|
||||
export class WasmTensorCompress {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Decompress a compressed tensor
|
||||
*/
|
||||
decompress(compressed: any): Float32Array;
|
||||
/**
|
||||
* Compress with explicit compression level
|
||||
*
|
||||
* # Arguments
|
||||
* * `embedding` - The input embedding vector
|
||||
* * `level` - Compression level: "none", "half", "pq8", "pq4", "binary"
|
||||
*/
|
||||
compressWithLevel(embedding: Float32Array, level: string): any;
|
||||
/**
|
||||
* Get compression ratio estimate for a given access frequency
|
||||
*/
|
||||
getCompressionRatio(access_freq: number): number;
|
||||
/**
|
||||
* Create a new tensor compressor
|
||||
*/
|
||||
constructor();
|
||||
/**
|
||||
* Compress an embedding based on access frequency
|
||||
*
|
||||
* # Arguments
|
||||
* * `embedding` - The input embedding vector
|
||||
* * `access_freq` - Access frequency in range [0.0, 1.0]
|
||||
* - f > 0.8: Full precision (hot data)
|
||||
* - f > 0.4: Half precision (warm data)
|
||||
* - f > 0.1: 8-bit PQ (cool data)
|
||||
* - f > 0.01: 4-bit PQ (cold data)
|
||||
* - f <= 0.01: Binary (archive)
|
||||
*/
|
||||
compress(embedding: Float32Array, access_freq: number): any;
|
||||
}
|
||||
|
||||
export class WasmTopologicalAttention {
|
||||
free(): void;
|
||||
[Symbol.dispose](): void;
|
||||
/**
|
||||
* Create a new topological attention instance
|
||||
*
|
||||
* # Arguments
|
||||
* * `decay_factor` - Decay factor for position-based attention (0.0-1.0)
|
||||
*/
|
||||
constructor(decay_factor: number);
|
||||
/**
|
||||
* Compute attention scores for the DAG
|
||||
*
|
||||
* # Returns
|
||||
* Attention scores for each node
|
||||
*/
|
||||
forward(dag: WasmQueryDag): Float32Array;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get information about all available attention mechanisms
|
||||
*/
|
||||
export function availableMechanisms(): any;
|
||||
|
||||
/**
|
||||
* Compute cosine similarity between two vectors
|
||||
*/
|
||||
export function cosineSimilarity(a: Float32Array, b: Float32Array): number;
|
||||
|
||||
/**
|
||||
* Get summary statistics about the unified attention library
|
||||
*/
|
||||
export function getStats(): any;
|
||||
|
||||
/**
|
||||
* Differentiable search using soft attention mechanism
|
||||
*
|
||||
* # Arguments
|
||||
* * `query` - The query vector
|
||||
* * `candidate_embeddings` - List of candidate embedding vectors
|
||||
* * `config` - Search configuration
|
||||
*
|
||||
* # Returns
|
||||
* Object with indices and weights for top-k candidates
|
||||
*/
|
||||
export function graphDifferentiableSearch(query: Float32Array, candidate_embeddings: any, config: WasmSearchConfig): any;
|
||||
|
||||
/**
|
||||
* Hierarchical forward pass through multiple GNN layers
|
||||
*
|
||||
* # Arguments
|
||||
* * `query` - The query vector
|
||||
* * `layer_embeddings` - Embeddings organized by layer
|
||||
* * `gnn_layers` - Array of GNN layers
|
||||
*
|
||||
* # Returns
|
||||
* Final embedding after hierarchical processing
|
||||
*/
|
||||
export function graphHierarchicalForward(query: Float32Array, layer_embeddings: any, gnn_layers: WasmGNNLayer[]): Float32Array;
|
||||
|
||||
/**
|
||||
* Initialize the WASM module with panic hook for better error messages
|
||||
*/
|
||||
export function init(): void;
|
||||
|
||||
/**
|
||||
* Compute scaled dot-product attention
|
||||
*
|
||||
* Standard transformer attention: softmax(QK^T / sqrt(d)) * V
|
||||
*
|
||||
* # Arguments
|
||||
* * `query` - Query vector (Float32Array)
|
||||
* * `keys` - Array of key vectors (JsValue - array of Float32Arrays)
|
||||
* * `values` - Array of value vectors (JsValue - array of Float32Arrays)
|
||||
* * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
*
|
||||
* # Returns
|
||||
* Attention-weighted output vector
|
||||
*/
|
||||
export function scaledDotAttention(query: Float32Array, keys: any, values: any, scale?: number | null): Float32Array;
|
||||
|
||||
/**
|
||||
* Softmax normalization
|
||||
*/
|
||||
export function softmax(values: Float32Array): Float32Array;
|
||||
|
||||
/**
|
||||
* Temperature-scaled softmax
|
||||
*/
|
||||
export function temperatureSoftmax(values: Float32Array, temperature: number): Float32Array;
|
||||
|
||||
/**
|
||||
* Get the version of the unified attention WASM crate
|
||||
*/
|
||||
export function version(): string;
|
||||
|
||||
export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module;
|
||||
|
||||
export interface InitOutput {
|
||||
readonly memory: WebAssembly.Memory;
|
||||
readonly __wbg_dagattentionfactory_free: (a: number, b: number) => void;
|
||||
readonly __wbg_get_mambaconfig_conv_kernel_size: (a: number) => number;
|
||||
readonly __wbg_get_mambaconfig_dim: (a: number) => number;
|
||||
readonly __wbg_get_mambaconfig_dt_max: (a: number) => number;
|
||||
readonly __wbg_get_mambaconfig_dt_min: (a: number) => number;
|
||||
readonly __wbg_get_mambaconfig_expand_factor: (a: number) => number;
|
||||
readonly __wbg_get_mambaconfig_state_dim: (a: number) => number;
|
||||
readonly __wbg_get_mambaconfig_use_d_skip: (a: number) => number;
|
||||
readonly __wbg_get_wasmsearchconfig_temperature: (a: number) => number;
|
||||
readonly __wbg_hybridmambaattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_mambaconfig_free: (a: number, b: number) => void;
|
||||
readonly __wbg_mambassmattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_set_mambaconfig_conv_kernel_size: (a: number, b: number) => void;
|
||||
readonly __wbg_set_mambaconfig_dim: (a: number, b: number) => void;
|
||||
readonly __wbg_set_mambaconfig_dt_max: (a: number, b: number) => void;
|
||||
readonly __wbg_set_mambaconfig_dt_min: (a: number, b: number) => void;
|
||||
readonly __wbg_set_mambaconfig_expand_factor: (a: number, b: number) => void;
|
||||
readonly __wbg_set_mambaconfig_state_dim: (a: number, b: number) => void;
|
||||
readonly __wbg_set_mambaconfig_use_d_skip: (a: number, b: number) => void;
|
||||
readonly __wbg_set_wasmsearchconfig_temperature: (a: number, b: number) => void;
|
||||
readonly __wbg_unifiedattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmcausalconeattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmflashattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmgnnlayer_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmhyperbolicattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmlinearattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmmincutgatedattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmmoeattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmmultiheadattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmquerydag_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmtensorcompress_free: (a: number, b: number) => void;
|
||||
readonly availableMechanisms: () => number;
|
||||
readonly cosineSimilarity: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
readonly dagattentionfactory_availableTypes: () => number;
|
||||
readonly dagattentionfactory_getDescription: (a: number, b: number, c: number) => void;
|
||||
readonly getStats: () => number;
|
||||
readonly graphDifferentiableSearch: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
readonly graphHierarchicalForward: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly graphattentionfactory_availableTypes: () => number;
|
||||
readonly graphattentionfactory_getDescription: (a: number, b: number, c: number) => void;
|
||||
readonly graphattentionfactory_getUseCases: (a: number, b: number) => number;
|
||||
readonly hybridmambaattention_forward: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
readonly hybridmambaattention_localWindow: (a: number) => number;
|
||||
readonly hybridmambaattention_new: (a: number, b: number) => number;
|
||||
readonly mambaconfig_new: (a: number) => number;
|
||||
readonly mambaconfig_withConvKernelSize: (a: number, b: number) => number;
|
||||
readonly mambaconfig_withExpandFactor: (a: number, b: number) => number;
|
||||
readonly mambaconfig_withStateDim: (a: number, b: number) => number;
|
||||
readonly mambassmattention_config: (a: number) => number;
|
||||
readonly mambassmattention_forward: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
readonly mambassmattention_getAttentionScores: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
readonly mambassmattention_innerDim: (a: number) => number;
|
||||
readonly mambassmattention_new: (a: number) => number;
|
||||
readonly mambassmattention_withDefaults: (a: number) => number;
|
||||
readonly scaledDotAttention: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly softmax: (a: number, b: number, c: number) => void;
|
||||
readonly temperatureSoftmax: (a: number, b: number, c: number, d: number) => void;
|
||||
readonly unifiedattention_category: (a: number, b: number) => void;
|
||||
readonly unifiedattention_mechanism: (a: number, b: number) => void;
|
||||
readonly unifiedattention_new: (a: number, b: number, c: number) => void;
|
||||
readonly unifiedattention_supportsGraphs: (a: number) => number;
|
||||
readonly unifiedattention_supportsHyperbolic: (a: number) => number;
|
||||
readonly unifiedattention_supportsSequences: (a: number) => number;
|
||||
readonly version: (a: number) => void;
|
||||
readonly wasmcausalconeattention_forward: (a: number, b: number, c: number) => void;
|
||||
readonly wasmcriticalpathattention_forward: (a: number, b: number, c: number) => void;
|
||||
readonly wasmflashattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly wasmflashattention_new: (a: number, b: number) => number;
|
||||
readonly wasmgnnlayer_forward: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
|
||||
readonly wasmgnnlayer_new: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
readonly wasmgnnlayer_outputDim: (a: number) => number;
|
||||
readonly wasmhierarchicallorentzattention_forward: (a: number, b: number, c: number) => void;
|
||||
readonly wasmhyperbolicattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly wasmhyperbolicattention_curvature: (a: number) => number;
|
||||
readonly wasmhyperbolicattention_new: (a: number, b: number) => number;
|
||||
readonly wasmlinearattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly wasmlinearattention_new: (a: number, b: number) => number;
|
||||
readonly wasmlocalglobalattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly wasmlocalglobalattention_new: (a: number, b: number, c: number) => number;
|
||||
readonly wasmmincutgatedattention_forward: (a: number, b: number, c: number) => void;
|
||||
readonly wasmmoeattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly wasmmoeattention_new: (a: number, b: number, c: number) => number;
|
||||
readonly wasmmultiheadattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly wasmmultiheadattention_dim: (a: number) => number;
|
||||
readonly wasmmultiheadattention_headDim: (a: number) => number;
|
||||
readonly wasmmultiheadattention_new: (a: number, b: number, c: number) => void;
|
||||
readonly wasmmultiheadattention_numHeads: (a: number) => number;
|
||||
readonly wasmparallelbranchattention_forward: (a: number, b: number, c: number) => void;
|
||||
readonly wasmquerydag_addEdge: (a: number, b: number, c: number) => number;
|
||||
readonly wasmquerydag_addNode: (a: number, b: number, c: number, d: number) => number;
|
||||
readonly wasmquerydag_edgeCount: (a: number) => number;
|
||||
readonly wasmquerydag_new: () => number;
|
||||
readonly wasmquerydag_nodeCount: (a: number) => number;
|
||||
readonly wasmquerydag_toJson: (a: number, b: number) => void;
|
||||
readonly wasmtemporalbtspattention_forward: (a: number, b: number, c: number) => void;
|
||||
readonly wasmtensorcompress_compress: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
readonly wasmtensorcompress_compressWithLevel: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
readonly wasmtensorcompress_decompress: (a: number, b: number, c: number) => void;
|
||||
readonly wasmtensorcompress_getCompressionRatio: (a: number, b: number) => number;
|
||||
readonly wasmtensorcompress_new: () => number;
|
||||
readonly wasmtopologicalattention_forward: (a: number, b: number, c: number) => void;
|
||||
readonly init: () => void;
|
||||
readonly wasmmincutgatedattention_new: (a: number) => number;
|
||||
readonly wasmtopologicalattention_new: (a: number) => number;
|
||||
readonly __wbg_set_wasmsearchconfig_k: (a: number, b: number) => void;
|
||||
readonly wasmcausalconeattention_new: (a: number, b: number) => number;
|
||||
readonly wasmcriticalpathattention_new: (a: number, b: number) => number;
|
||||
readonly wasmhierarchicallorentzattention_new: (a: number, b: number) => number;
|
||||
readonly wasmparallelbranchattention_new: (a: number, b: number) => number;
|
||||
readonly wasmsearchconfig_new: (a: number, b: number) => number;
|
||||
readonly wasmtemporalbtspattention_new: (a: number, b: number) => number;
|
||||
readonly __wbg_get_wasmsearchconfig_k: (a: number) => number;
|
||||
readonly __wbg_graphattentionfactory_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmcriticalpathattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmhierarchicallorentzattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmlocalglobalattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmparallelbranchattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmsearchconfig_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmtemporalbtspattention_free: (a: number, b: number) => void;
|
||||
readonly __wbg_wasmtopologicalattention_free: (a: number, b: number) => void;
|
||||
readonly __wbindgen_export: (a: number, b: number) => number;
|
||||
readonly __wbindgen_export2: (a: number, b: number, c: number, d: number) => number;
|
||||
readonly __wbindgen_export3: (a: number) => void;
|
||||
readonly __wbindgen_export4: (a: number, b: number, c: number) => void;
|
||||
readonly __wbindgen_add_to_stack_pointer: (a: number) => number;
|
||||
readonly __wbindgen_start: () => void;
|
||||
}
|
||||
|
||||
export type SyncInitInput = BufferSource | WebAssembly.Module;
|
||||
|
||||
/**
|
||||
* Instantiates the given `module`, which can either be bytes or
|
||||
* a precompiled `WebAssembly.Module`.
|
||||
*
|
||||
* @param {{ module: SyncInitInput }} module - Passing `SyncInitInput` directly is deprecated.
|
||||
*
|
||||
* @returns {InitOutput}
|
||||
*/
|
||||
export function initSync(module: { module: SyncInitInput } | SyncInitInput): InitOutput;
|
||||
|
||||
/**
|
||||
* If `module_or_path` is {RequestInfo} or {URL}, makes a request and
|
||||
* for everything else, calls `WebAssembly.instantiate` directly.
|
||||
*
|
||||
* @param {{ module_or_path: InitInput | Promise<InitInput> }} module_or_path - Passing `InitInput` directly is deprecated.
|
||||
*
|
||||
* @returns {Promise<InitOutput>}
|
||||
*/
|
||||
export default function __wbg_init (module_or_path?: { module_or_path: InitInput | Promise<InitInput> } | InitInput | Promise<InitInput>): Promise<InitOutput>;
|
||||
2751
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/ruvector_attention_unified_wasm.js
vendored
Normal file
2751
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/ruvector_attention_unified_wasm.js
vendored
Normal file
File diff suppressed because it is too large
Load Diff
BIN
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/ruvector_attention_unified_wasm_bg.wasm
vendored
Normal file
BIN
vendor/ruvector/crates/ruvector-attention-unified-wasm/pkg/ruvector_attention_unified_wasm_bg.wasm
vendored
Normal file
Binary file not shown.
@@ -0,0 +1,129 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
export const memory: WebAssembly.Memory;
|
||||
export const __wbg_dagattentionfactory_free: (a: number, b: number) => void;
|
||||
export const __wbg_get_mambaconfig_conv_kernel_size: (a: number) => number;
|
||||
export const __wbg_get_mambaconfig_dim: (a: number) => number;
|
||||
export const __wbg_get_mambaconfig_dt_max: (a: number) => number;
|
||||
export const __wbg_get_mambaconfig_dt_min: (a: number) => number;
|
||||
export const __wbg_get_mambaconfig_expand_factor: (a: number) => number;
|
||||
export const __wbg_get_mambaconfig_state_dim: (a: number) => number;
|
||||
export const __wbg_get_mambaconfig_use_d_skip: (a: number) => number;
|
||||
export const __wbg_get_wasmsearchconfig_temperature: (a: number) => number;
|
||||
export const __wbg_hybridmambaattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_mambaconfig_free: (a: number, b: number) => void;
|
||||
export const __wbg_mambassmattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_set_mambaconfig_conv_kernel_size: (a: number, b: number) => void;
|
||||
export const __wbg_set_mambaconfig_dim: (a: number, b: number) => void;
|
||||
export const __wbg_set_mambaconfig_dt_max: (a: number, b: number) => void;
|
||||
export const __wbg_set_mambaconfig_dt_min: (a: number, b: number) => void;
|
||||
export const __wbg_set_mambaconfig_expand_factor: (a: number, b: number) => void;
|
||||
export const __wbg_set_mambaconfig_state_dim: (a: number, b: number) => void;
|
||||
export const __wbg_set_mambaconfig_use_d_skip: (a: number, b: number) => void;
|
||||
export const __wbg_set_wasmsearchconfig_temperature: (a: number, b: number) => void;
|
||||
export const __wbg_unifiedattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmcausalconeattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmflashattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmgnnlayer_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmhyperbolicattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmlinearattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmmincutgatedattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmmoeattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmmultiheadattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmquerydag_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmtensorcompress_free: (a: number, b: number) => void;
|
||||
export const availableMechanisms: () => number;
|
||||
export const cosineSimilarity: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const dagattentionfactory_availableTypes: () => number;
|
||||
export const dagattentionfactory_getDescription: (a: number, b: number, c: number) => void;
|
||||
export const getStats: () => number;
|
||||
export const graphDifferentiableSearch: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const graphHierarchicalForward: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const graphattentionfactory_availableTypes: () => number;
|
||||
export const graphattentionfactory_getDescription: (a: number, b: number, c: number) => void;
|
||||
export const graphattentionfactory_getUseCases: (a: number, b: number) => number;
|
||||
export const hybridmambaattention_forward: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const hybridmambaattention_localWindow: (a: number) => number;
|
||||
export const hybridmambaattention_new: (a: number, b: number) => number;
|
||||
export const mambaconfig_new: (a: number) => number;
|
||||
export const mambaconfig_withConvKernelSize: (a: number, b: number) => number;
|
||||
export const mambaconfig_withExpandFactor: (a: number, b: number) => number;
|
||||
export const mambaconfig_withStateDim: (a: number, b: number) => number;
|
||||
export const mambassmattention_config: (a: number) => number;
|
||||
export const mambassmattention_forward: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const mambassmattention_getAttentionScores: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const mambassmattention_innerDim: (a: number) => number;
|
||||
export const mambassmattention_new: (a: number) => number;
|
||||
export const mambassmattention_withDefaults: (a: number) => number;
|
||||
export const scaledDotAttention: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const softmax: (a: number, b: number, c: number) => void;
|
||||
export const temperatureSoftmax: (a: number, b: number, c: number, d: number) => void;
|
||||
export const unifiedattention_category: (a: number, b: number) => void;
|
||||
export const unifiedattention_mechanism: (a: number, b: number) => void;
|
||||
export const unifiedattention_new: (a: number, b: number, c: number) => void;
|
||||
export const unifiedattention_supportsGraphs: (a: number) => number;
|
||||
export const unifiedattention_supportsHyperbolic: (a: number) => number;
|
||||
export const unifiedattention_supportsSequences: (a: number) => number;
|
||||
export const version: (a: number) => void;
|
||||
export const wasmcausalconeattention_forward: (a: number, b: number, c: number) => void;
|
||||
export const wasmcriticalpathattention_forward: (a: number, b: number, c: number) => void;
|
||||
export const wasmflashattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmflashattention_new: (a: number, b: number) => number;
|
||||
export const wasmgnnlayer_forward: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
|
||||
export const wasmgnnlayer_new: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const wasmgnnlayer_outputDim: (a: number) => number;
|
||||
export const wasmhierarchicallorentzattention_forward: (a: number, b: number, c: number) => void;
|
||||
export const wasmhyperbolicattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmhyperbolicattention_curvature: (a: number) => number;
|
||||
export const wasmhyperbolicattention_new: (a: number, b: number) => number;
|
||||
export const wasmlinearattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmlinearattention_new: (a: number, b: number) => number;
|
||||
export const wasmlocalglobalattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmlocalglobalattention_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmmincutgatedattention_forward: (a: number, b: number, c: number) => void;
|
||||
export const wasmmoeattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmmoeattention_new: (a: number, b: number, c: number) => number;
|
||||
export const wasmmultiheadattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmmultiheadattention_dim: (a: number) => number;
|
||||
export const wasmmultiheadattention_headDim: (a: number) => number;
|
||||
export const wasmmultiheadattention_new: (a: number, b: number, c: number) => void;
|
||||
export const wasmmultiheadattention_numHeads: (a: number) => number;
|
||||
export const wasmparallelbranchattention_forward: (a: number, b: number, c: number) => void;
|
||||
export const wasmquerydag_addEdge: (a: number, b: number, c: number) => number;
|
||||
export const wasmquerydag_addNode: (a: number, b: number, c: number, d: number) => number;
|
||||
export const wasmquerydag_edgeCount: (a: number) => number;
|
||||
export const wasmquerydag_new: () => number;
|
||||
export const wasmquerydag_nodeCount: (a: number) => number;
|
||||
export const wasmquerydag_toJson: (a: number, b: number) => void;
|
||||
export const wasmtemporalbtspattention_forward: (a: number, b: number, c: number) => void;
|
||||
export const wasmtensorcompress_compress: (a: number, b: number, c: number, d: number, e: number) => void;
|
||||
export const wasmtensorcompress_compressWithLevel: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
|
||||
export const wasmtensorcompress_decompress: (a: number, b: number, c: number) => void;
|
||||
export const wasmtensorcompress_getCompressionRatio: (a: number, b: number) => number;
|
||||
export const wasmtensorcompress_new: () => number;
|
||||
export const wasmtopologicalattention_forward: (a: number, b: number, c: number) => void;
|
||||
export const init: () => void;
|
||||
export const wasmmincutgatedattention_new: (a: number) => number;
|
||||
export const wasmtopologicalattention_new: (a: number) => number;
|
||||
export const __wbg_set_wasmsearchconfig_k: (a: number, b: number) => void;
|
||||
export const wasmcausalconeattention_new: (a: number, b: number) => number;
|
||||
export const wasmcriticalpathattention_new: (a: number, b: number) => number;
|
||||
export const wasmhierarchicallorentzattention_new: (a: number, b: number) => number;
|
||||
export const wasmparallelbranchattention_new: (a: number, b: number) => number;
|
||||
export const wasmsearchconfig_new: (a: number, b: number) => number;
|
||||
export const wasmtemporalbtspattention_new: (a: number, b: number) => number;
|
||||
export const __wbg_get_wasmsearchconfig_k: (a: number) => number;
|
||||
export const __wbg_graphattentionfactory_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmcriticalpathattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmhierarchicallorentzattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmlocalglobalattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmparallelbranchattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmsearchconfig_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmtemporalbtspattention_free: (a: number, b: number) => void;
|
||||
export const __wbg_wasmtopologicalattention_free: (a: number, b: number) => void;
|
||||
export const __wbindgen_export: (a: number, b: number) => number;
|
||||
export const __wbindgen_export2: (a: number, b: number, c: number, d: number) => number;
|
||||
export const __wbindgen_export3: (a: number) => void;
|
||||
export const __wbindgen_export4: (a: number, b: number, c: number) => void;
|
||||
export const __wbindgen_add_to_stack_pointer: (a: number) => number;
|
||||
export const __wbindgen_start: () => void;
|
||||
806
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/dag.rs
vendored
Normal file
806
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/dag.rs
vendored
Normal file
@@ -0,0 +1,806 @@
|
||||
//! DAG Attention Mechanisms (from ruvector-dag)
|
||||
//!
|
||||
//! Re-exports the 7 DAG-specific attention mechanisms:
|
||||
//! - Topological Attention
|
||||
//! - Causal Cone Attention
|
||||
//! - Critical Path Attention
|
||||
//! - MinCut-Gated Attention
|
||||
//! - Hierarchical Lorentz Attention
|
||||
//! - Parallel Branch Attention
|
||||
//! - Temporal BTSP Attention
|
||||
|
||||
use ruvector_dag::{OperatorNode, QueryDag};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Minimal DAG for WASM
|
||||
// ============================================================================
|
||||
|
||||
/// Minimal DAG structure for WASM attention computation
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmQueryDag {
|
||||
inner: QueryDag,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmQueryDag {
|
||||
/// Create a new empty DAG
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> WasmQueryDag {
|
||||
WasmQueryDag {
|
||||
inner: QueryDag::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node with operator type and cost
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `op_type` - Operator type: "scan", "filter", "join", "aggregate", "project", "sort"
|
||||
/// * `cost` - Estimated execution cost
|
||||
///
|
||||
/// # Returns
|
||||
/// Node ID
|
||||
#[wasm_bindgen(js_name = addNode)]
|
||||
pub fn add_node(&mut self, op_type: &str, cost: f32) -> u32 {
|
||||
let table_id = self.inner.node_count() as usize;
|
||||
let mut node = match op_type {
|
||||
"scan" => OperatorNode::seq_scan(table_id, &format!("table_{}", table_id)),
|
||||
"filter" => OperatorNode::filter(table_id, "condition"),
|
||||
"join" => OperatorNode::hash_join(table_id, "join_key"),
|
||||
"aggregate" => OperatorNode::aggregate(table_id, vec!["*".to_string()]),
|
||||
"project" => OperatorNode::project(table_id, vec!["*".to_string()]),
|
||||
"sort" => OperatorNode::sort(table_id, vec!["col".to_string()]),
|
||||
_ => OperatorNode::seq_scan(table_id, "unknown"),
|
||||
};
|
||||
node.estimated_cost = cost as f64;
|
||||
self.inner.add_node(node) as u32
|
||||
}
|
||||
|
||||
/// Add an edge between nodes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `from` - Source node ID
|
||||
/// * `to` - Target node ID
|
||||
///
|
||||
/// # Returns
|
||||
/// True if edge was added successfully
|
||||
#[wasm_bindgen(js_name = addEdge)]
|
||||
pub fn add_edge(&mut self, from: u32, to: u32) -> bool {
|
||||
self.inner.add_edge(from as usize, to as usize).is_ok()
|
||||
}
|
||||
|
||||
/// Get the number of nodes
|
||||
#[wasm_bindgen(getter, js_name = nodeCount)]
|
||||
pub fn node_count(&self) -> u32 {
|
||||
self.inner.node_count() as u32
|
||||
}
|
||||
|
||||
/// Get the number of edges
|
||||
#[wasm_bindgen(getter, js_name = edgeCount)]
|
||||
pub fn edge_count(&self) -> u32 {
|
||||
self.inner.edge_count() as u32
|
||||
}
|
||||
|
||||
/// Serialize to JSON
|
||||
#[wasm_bindgen(js_name = toJson)]
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string(&DagSummary {
|
||||
node_count: self.inner.node_count(),
|
||||
edge_count: self.inner.edge_count(),
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl WasmQueryDag {
|
||||
/// Get internal reference
|
||||
pub(crate) fn inner(&self) -> &QueryDag {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct DagSummary {
|
||||
node_count: usize,
|
||||
edge_count: usize,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper trait for converting HashMap scores to Vec
|
||||
// ============================================================================
|
||||
|
||||
fn hashmap_to_vec(scores: &HashMap<usize, f32>, n: usize) -> Vec<f32> {
|
||||
(0..n)
|
||||
.map(|i| scores.get(&i).copied().unwrap_or(0.0))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Topological Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Topological attention based on DAG position
|
||||
///
|
||||
/// Assigns attention scores based on node position in topological order.
|
||||
/// Earlier nodes (closer to sources) get higher attention.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmTopologicalAttention {
|
||||
decay_factor: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmTopologicalAttention {
|
||||
/// Create a new topological attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `decay_factor` - Decay factor for position-based attention (0.0-1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(decay_factor: f32) -> WasmTopologicalAttention {
|
||||
WasmTopologicalAttention { decay_factor }
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
///
|
||||
/// # Returns
|
||||
/// Attention scores for each node
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let depths = dag.inner.compute_depths();
|
||||
let max_depth = depths.values().max().copied().unwrap_or(0);
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for (&node_id, &depth) in &depths {
|
||||
let normalized_depth = depth as f32 / (max_depth.max(1) as f32);
|
||||
let score = self.decay_factor.powf(1.0 - normalized_depth);
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Causal Cone Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Causal cone attention based on dependency lightcones
|
||||
///
|
||||
/// Nodes can only attend to ancestors in the DAG (causal predecessors).
|
||||
/// Attention strength decays with causal distance.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmCausalConeAttention {
|
||||
future_discount: f32,
|
||||
ancestor_weight: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmCausalConeAttention {
|
||||
/// Create a new causal cone attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `future_discount` - Discount for future nodes
|
||||
/// * `ancestor_weight` - Weight for ancestor influence
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(future_discount: f32, ancestor_weight: f32) -> WasmCausalConeAttention {
|
||||
WasmCausalConeAttention {
|
||||
future_discount,
|
||||
ancestor_weight,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
let depths = dag.inner.compute_depths();
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let ancestors = dag.inner.ancestors(node_id);
|
||||
let ancestor_count = ancestors.len();
|
||||
|
||||
let mut score = 1.0 + (ancestor_count as f32 * self.ancestor_weight);
|
||||
|
||||
if let Some(&depth) = depths.get(&node_id) {
|
||||
score *= self.future_discount.powi(depth as i32);
|
||||
}
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Critical Path Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Critical path attention weighted by path criticality
|
||||
///
|
||||
/// Nodes on or near the critical path (longest execution path)
|
||||
/// receive higher attention scores.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmCriticalPathAttention {
|
||||
path_weight: f32,
|
||||
branch_penalty: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmCriticalPathAttention {
|
||||
/// Create a new critical path attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path_weight` - Weight for critical path membership
|
||||
/// * `branch_penalty` - Penalty for branching nodes
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(path_weight: f32, branch_penalty: f32) -> WasmCriticalPathAttention {
|
||||
WasmCriticalPathAttention {
|
||||
path_weight,
|
||||
branch_penalty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the critical path (longest path by cost)
|
||||
fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
|
||||
let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
|
||||
|
||||
for &leaf in &dag.leaves() {
|
||||
if let Some(node) = dag.get_node(leaf) {
|
||||
longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(topo_order) = dag.topological_sort() {
|
||||
for &node_id in topo_order.iter().rev() {
|
||||
let node = match dag.get_node(node_id) {
|
||||
Some(n) => n,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let mut max_cost = node.estimated_cost;
|
||||
let mut max_path = vec![node_id];
|
||||
|
||||
for &child in dag.children(node_id) {
|
||||
if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
|
||||
let total_cost = node.estimated_cost + child_cost;
|
||||
if total_cost > max_cost {
|
||||
max_cost = total_cost;
|
||||
max_path = vec![node_id];
|
||||
max_path.extend(child_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
longest_path.insert(node_id, (max_cost, max_path));
|
||||
}
|
||||
}
|
||||
|
||||
longest_path
|
||||
.into_iter()
|
||||
.max_by(|a, b| {
|
||||
a.1 .0
|
||||
.partial_cmp(&b.1 .0)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(_, (_, path))| path)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let critical = self.compute_critical_path(&dag.inner);
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let is_on_critical_path = critical.contains(&node_id);
|
||||
let num_children = dag.inner.children(node_id).len();
|
||||
|
||||
let mut score = if is_on_critical_path {
|
||||
self.path_weight
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
if num_children > 1 {
|
||||
score *= 1.0 + (num_children as f32 - 1.0) * self.branch_penalty;
|
||||
}
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MinCut-Gated Attention
|
||||
// ============================================================================
|
||||
|
||||
/// MinCut-gated attention using flow-based bottleneck detection
|
||||
///
|
||||
/// Uses minimum cut analysis to identify bottleneck nodes
|
||||
/// and gates attention through these critical points.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMinCutGatedAttention {
|
||||
gate_threshold: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMinCutGatedAttention {
|
||||
/// Create a new MinCut-gated attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `gate_threshold` - Threshold for gating (0.0-1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(gate_threshold: f32) -> WasmMinCutGatedAttention {
|
||||
WasmMinCutGatedAttention { gate_threshold }
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
// Simple bottleneck detection: nodes with high in-degree and out-degree
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let in_degree = dag.inner.parents(node_id).len();
|
||||
let out_degree = dag.inner.children(node_id).len();
|
||||
|
||||
// Bottleneck score: higher for nodes with high connectivity
|
||||
let connectivity = (in_degree + out_degree) as f32;
|
||||
let is_bottleneck = connectivity >= self.gate_threshold * n as f32;
|
||||
|
||||
let score = if is_bottleneck {
|
||||
2.0 + connectivity * 0.1
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hierarchical Lorentz Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Hierarchical Lorentz attention in hyperbolic space
|
||||
///
|
||||
/// Combines DAG hierarchy with Lorentz (hyperboloid) geometry
|
||||
/// for multi-scale hierarchical attention.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmHierarchicalLorentzAttention {
|
||||
curvature: f32,
|
||||
temperature: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmHierarchicalLorentzAttention {
|
||||
/// Create a new hierarchical Lorentz attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `curvature` - Hyperbolic curvature parameter
|
||||
/// * `temperature` - Temperature for softmax
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(curvature: f32, temperature: f32) -> WasmHierarchicalLorentzAttention {
|
||||
WasmHierarchicalLorentzAttention {
|
||||
curvature,
|
||||
temperature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let depths = dag.inner.compute_depths();
|
||||
let max_depth = depths.values().max().copied().unwrap_or(0);
|
||||
|
||||
// Compute hyperbolic distances from origin
|
||||
let mut distances: Vec<f32> = Vec::with_capacity(n);
|
||||
for node_id in 0..n {
|
||||
let depth = depths.get(&node_id).copied().unwrap_or(0);
|
||||
// In hyperbolic space, distance grows exponentially with depth
|
||||
let radial = (depth as f32 * 0.5).tanh();
|
||||
let distance = (1.0 + radial).acosh() * self.curvature.abs();
|
||||
distances.push(distance);
|
||||
}
|
||||
|
||||
// Convert to attention scores using softmax
|
||||
let max_neg_dist = distances
|
||||
.iter()
|
||||
.map(|&d| -d / self.temperature)
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = distances
|
||||
.iter()
|
||||
.map(|&d| ((-d / self.temperature) - max_neg_dist).exp())
|
||||
.sum();
|
||||
|
||||
let scores: Vec<f32> = distances
|
||||
.iter()
|
||||
.map(|&d| ((-d / self.temperature) - max_neg_dist).exp() / exp_sum.max(1e-10))
|
||||
.collect();
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Parallel Branch Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Parallel branch attention for concurrent DAG branches
|
||||
///
|
||||
/// Identifies parallel branches in the DAG and applies
|
||||
/// attention patterns that respect branch independence.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmParallelBranchAttention {
|
||||
max_branches: usize,
|
||||
sync_penalty: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmParallelBranchAttention {
|
||||
/// Create a new parallel branch attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_branches` - Maximum number of branches to consider
|
||||
/// * `sync_penalty` - Penalty for synchronization between branches
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(max_branches: usize, sync_penalty: f32) -> WasmParallelBranchAttention {
|
||||
WasmParallelBranchAttention {
|
||||
max_branches,
|
||||
sync_penalty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
// Detect branch points (nodes with multiple children)
|
||||
let mut branch_starts: Vec<usize> = Vec::new();
|
||||
for node_id in 0..n {
|
||||
if dag.inner.children(node_id).len() > 1 {
|
||||
branch_starts.push(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if node is part of a parallel branch
|
||||
let parents = dag.inner.parents(node_id);
|
||||
let is_branch_child = parents.iter().any(|&p| branch_starts.contains(&p));
|
||||
|
||||
let children = dag.inner.children(node_id);
|
||||
let is_sync_point = children.len() == 0 && parents.len() > 1;
|
||||
|
||||
let score = if is_branch_child {
|
||||
1.5 // Boost parallel branch nodes
|
||||
} else if is_sync_point {
|
||||
1.0 * (1.0 - self.sync_penalty) // Penalize sync points
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Temporal BTSP Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Temporal BTSP (Behavioral Time-Series Pattern) attention
|
||||
///
|
||||
/// Incorporates temporal patterns and behavioral sequences
|
||||
/// for time-aware DAG attention.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmTemporalBTSPAttention {
|
||||
eligibility_decay: f32,
|
||||
baseline_attention: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmTemporalBTSPAttention {
|
||||
/// Create a new temporal BTSP attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `eligibility_decay` - Decay rate for eligibility traces (0.0-1.0)
|
||||
/// * `baseline_attention` - Baseline attention for nodes without history
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(eligibility_decay: f32, baseline_attention: f32) -> WasmTemporalBTSPAttention {
|
||||
WasmTemporalBTSPAttention {
|
||||
eligibility_decay,
|
||||
baseline_attention,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let mut scores = Vec::with_capacity(n);
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
let node = match dag.inner.get_node(node_id) {
|
||||
Some(n) => n,
|
||||
None => {
|
||||
scores.push(0.0);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Base score from cost and rows
|
||||
let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
|
||||
let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
|
||||
let score = self.baseline_attention * (0.5 * cost_factor + 0.5 * rows_factor + 0.5);
|
||||
|
||||
scores.push(score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
if total > 0.0 {
|
||||
for score in scores.iter_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// DAG Attention Factory
|
||||
// ============================================================================
|
||||
|
||||
/// Factory for creating DAG attention mechanisms
|
||||
#[wasm_bindgen]
|
||||
pub struct DagAttentionFactory;
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl DagAttentionFactory {
|
||||
/// Get available DAG attention types
|
||||
#[wasm_bindgen(js_name = availableTypes)]
|
||||
pub fn available_types() -> JsValue {
|
||||
let types = vec![
|
||||
"topological",
|
||||
"causal_cone",
|
||||
"critical_path",
|
||||
"mincut_gated",
|
||||
"hierarchical_lorentz",
|
||||
"parallel_branch",
|
||||
"temporal_btsp",
|
||||
];
|
||||
serde_wasm_bindgen::to_value(&types).unwrap()
|
||||
}
|
||||
|
||||
/// Get description for a DAG attention type
|
||||
#[wasm_bindgen(js_name = getDescription)]
|
||||
pub fn get_description(attention_type: &str) -> String {
|
||||
match attention_type {
|
||||
"topological" => "Position-based attention following DAG topological order".to_string(),
|
||||
"causal_cone" => "Lightcone-based attention respecting causal dependencies".to_string(),
|
||||
"critical_path" => "Attention weighted by critical execution path distance".to_string(),
|
||||
"mincut_gated" => "Flow-based gating through bottleneck nodes".to_string(),
|
||||
"hierarchical_lorentz" => {
|
||||
"Multi-scale hyperbolic attention for DAG hierarchies".to_string()
|
||||
}
|
||||
"parallel_branch" => "Branch-aware attention for parallel DAG structures".to_string(),
|
||||
"temporal_btsp" => "Time-series pattern attention for temporal DAGs".to_string(),
|
||||
_ => "Unknown attention type".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_dag_creation() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
let n1 = dag.add_node("scan", 1.0);
|
||||
let n2 = dag.add_node("filter", 0.5);
|
||||
dag.add_edge(n1, n2);
|
||||
|
||||
assert_eq!(dag.node_count(), 2);
|
||||
assert_eq!(dag.edge_count(), 1);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_topological_attention() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
dag.add_node("scan", 1.0);
|
||||
dag.add_node("filter", 0.5);
|
||||
dag.add_node("project", 0.3);
|
||||
dag.add_edge(0, 1);
|
||||
dag.add_edge(1, 2);
|
||||
|
||||
let attention = WasmTopologicalAttention::new(0.9);
|
||||
let scores = attention.forward(&dag);
|
||||
assert!(scores.is_ok());
|
||||
let s = scores.unwrap();
|
||||
assert_eq!(s.len(), 3);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_causal_cone_attention() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
dag.add_node("scan", 1.0);
|
||||
dag.add_node("filter", 0.5);
|
||||
dag.add_edge(0, 1);
|
||||
|
||||
let attention = WasmCausalConeAttention::new(0.8, 0.9);
|
||||
let scores = attention.forward(&dag);
|
||||
assert!(scores.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_critical_path_attention() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
dag.add_node("scan", 1.0);
|
||||
dag.add_node("filter", 0.5);
|
||||
dag.add_edge(0, 1);
|
||||
|
||||
let attention = WasmCriticalPathAttention::new(2.0, 0.5);
|
||||
let scores = attention.forward(&dag);
|
||||
assert!(scores.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_mincut_gated_attention() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
dag.add_node("scan", 1.0);
|
||||
dag.add_node("filter", 0.5);
|
||||
dag.add_edge(0, 1);
|
||||
|
||||
let attention = WasmMinCutGatedAttention::new(0.5);
|
||||
let scores = attention.forward(&dag);
|
||||
assert!(scores.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_hierarchical_lorentz_attention() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
dag.add_node("scan", 1.0);
|
||||
dag.add_node("filter", 0.5);
|
||||
dag.add_edge(0, 1);
|
||||
|
||||
let attention = WasmHierarchicalLorentzAttention::new(-1.0, 0.1);
|
||||
let scores = attention.forward(&dag);
|
||||
assert!(scores.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_parallel_branch_attention() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
dag.add_node("scan", 1.0);
|
||||
dag.add_node("filter", 0.5);
|
||||
dag.add_edge(0, 1);
|
||||
|
||||
let attention = WasmParallelBranchAttention::new(8, 0.2);
|
||||
let scores = attention.forward(&dag);
|
||||
assert!(scores.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_temporal_btsp_attention() {
|
||||
let mut dag = WasmQueryDag::new();
|
||||
dag.add_node("scan", 1.0);
|
||||
dag.add_node("filter", 0.5);
|
||||
dag.add_edge(0, 1);
|
||||
|
||||
let attention = WasmTemporalBTSPAttention::new(0.95, 0.5);
|
||||
let scores = attention.forward(&dag);
|
||||
assert!(scores.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_factory_types() {
|
||||
let types_js = DagAttentionFactory::available_types();
|
||||
assert!(!types_js.is_null());
|
||||
}
|
||||
}
|
||||
417
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/graph.rs
vendored
Normal file
417
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/graph.rs
vendored
Normal file
@@ -0,0 +1,417 @@
|
||||
//! Graph Attention Mechanisms (from ruvector-gnn)
|
||||
//!
|
||||
//! Re-exports graph neural network attention mechanisms:
|
||||
//! - GAT (Graph Attention Networks)
|
||||
//! - GCN (Graph Convolutional Networks)
|
||||
//! - GraphSAGE (Sample and Aggregate)
|
||||
|
||||
use ruvector_gnn::{
|
||||
differentiable_search as core_differentiable_search,
|
||||
hierarchical_forward as core_hierarchical_forward, CompressedTensor, CompressionLevel,
|
||||
RuvectorLayer, TensorCompress,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// GNN Layer (GAT-based)
|
||||
// ============================================================================
|
||||
|
||||
/// Graph Neural Network layer with attention mechanism
|
||||
///
|
||||
/// Implements Graph Attention Networks (GAT) for HNSW topology.
|
||||
/// Each node aggregates information from neighbors using learned attention weights.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmGNNLayer {
|
||||
inner: RuvectorLayer,
|
||||
hidden_dim: usize,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmGNNLayer {
|
||||
/// Create a new GNN layer with attention
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_dim` - Dimension of input node embeddings
|
||||
/// * `hidden_dim` - Dimension of hidden representations
|
||||
/// * `heads` - Number of attention heads
|
||||
/// * `dropout` - Dropout rate (0.0 to 1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
heads: usize,
|
||||
dropout: f32,
|
||||
) -> Result<WasmGNNLayer, JsError> {
|
||||
let inner = RuvectorLayer::new(input_dim, hidden_dim, heads, dropout)
|
||||
.map_err(|e| JsError::new(&e.to_string()))?;
|
||||
|
||||
Ok(WasmGNNLayer { inner, hidden_dim })
|
||||
}
|
||||
|
||||
/// Forward pass through the GNN layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `node_embedding` - Current node's embedding (Float32Array)
|
||||
/// * `neighbor_embeddings` - Embeddings of neighbor nodes (array of Float32Arrays)
|
||||
/// * `edge_weights` - Weights of edges to neighbors (Float32Array)
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated node embedding (Float32Array)
|
||||
pub fn forward(
|
||||
&self,
|
||||
node_embedding: Vec<f32>,
|
||||
neighbor_embeddings: JsValue,
|
||||
edge_weights: Vec<f32>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let neighbors: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(neighbor_embeddings)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse neighbor embeddings: {}", e)))?;
|
||||
|
||||
if neighbors.len() != edge_weights.len() {
|
||||
return Err(JsError::new(&format!(
|
||||
"Number of neighbors ({}) must match number of edge weights ({})",
|
||||
neighbors.len(),
|
||||
edge_weights.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.forward(&node_embedding, &neighbors, &edge_weights);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get the output dimension
|
||||
#[wasm_bindgen(getter, js_name = outputDim)]
|
||||
pub fn output_dim(&self) -> usize {
|
||||
self.hidden_dim
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tensor Compression (for efficient GNN)
|
||||
// ============================================================================
|
||||
|
||||
/// Tensor compressor with adaptive level selection
|
||||
///
|
||||
/// Compresses embeddings based on access frequency for memory-efficient GNN
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmTensorCompress {
|
||||
inner: TensorCompress,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmTensorCompress {
|
||||
/// Create a new tensor compressor
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: TensorCompress::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compress an embedding based on access frequency
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - The input embedding vector
|
||||
/// * `access_freq` - Access frequency in range [0.0, 1.0]
|
||||
/// - f > 0.8: Full precision (hot data)
|
||||
/// - f > 0.4: Half precision (warm data)
|
||||
/// - f > 0.1: 8-bit PQ (cool data)
|
||||
/// - f > 0.01: 4-bit PQ (cold data)
|
||||
/// - f <= 0.01: Binary (archive)
|
||||
pub fn compress(&self, embedding: Vec<f32>, access_freq: f32) -> Result<JsValue, JsError> {
|
||||
let compressed = self
|
||||
.inner
|
||||
.compress(&embedding, access_freq)
|
||||
.map_err(|e| JsError::new(&format!("Compression failed: {}", e)))?;
|
||||
|
||||
serde_wasm_bindgen::to_value(&compressed)
|
||||
.map_err(|e| JsError::new(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Compress with explicit compression level
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - The input embedding vector
|
||||
/// * `level` - Compression level: "none", "half", "pq8", "pq4", "binary"
|
||||
#[wasm_bindgen(js_name = compressWithLevel)]
|
||||
pub fn compress_with_level(
|
||||
&self,
|
||||
embedding: Vec<f32>,
|
||||
level: &str,
|
||||
) -> Result<JsValue, JsError> {
|
||||
let compression_level = match level {
|
||||
"none" => CompressionLevel::None,
|
||||
"half" => CompressionLevel::Half { scale: 1.0 },
|
||||
"pq8" => CompressionLevel::PQ8 {
|
||||
subvectors: 8,
|
||||
centroids: 16,
|
||||
},
|
||||
"pq4" => CompressionLevel::PQ4 {
|
||||
subvectors: 8,
|
||||
outlier_threshold: 3.0,
|
||||
},
|
||||
"binary" => CompressionLevel::Binary { threshold: 0.0 },
|
||||
_ => {
|
||||
return Err(JsError::new(&format!(
|
||||
"Unknown compression level: {}",
|
||||
level
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let compressed = self
|
||||
.inner
|
||||
.compress_with_level(&embedding, &compression_level)
|
||||
.map_err(|e| JsError::new(&format!("Compression failed: {}", e)))?;
|
||||
|
||||
serde_wasm_bindgen::to_value(&compressed)
|
||||
.map_err(|e| JsError::new(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Decompress a compressed tensor
|
||||
pub fn decompress(&self, compressed: JsValue) -> Result<Vec<f32>, JsError> {
|
||||
let compressed_tensor: CompressedTensor = serde_wasm_bindgen::from_value(compressed)
|
||||
.map_err(|e| JsError::new(&format!("Deserialization failed: {}", e)))?;
|
||||
|
||||
self.inner
|
||||
.decompress(&compressed_tensor)
|
||||
.map_err(|e| JsError::new(&format!("Decompression failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Get compression ratio estimate for a given access frequency
|
||||
#[wasm_bindgen(js_name = getCompressionRatio)]
|
||||
pub fn get_compression_ratio(&self, access_freq: f32) -> f32 {
|
||||
if access_freq > 0.8 {
|
||||
1.0
|
||||
} else if access_freq > 0.4 {
|
||||
2.0
|
||||
} else if access_freq > 0.1 {
|
||||
4.0
|
||||
} else if access_freq > 0.01 {
|
||||
8.0
|
||||
} else {
|
||||
32.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Search Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Search configuration for differentiable search
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmSearchConfig {
|
||||
/// Number of top results to return
|
||||
pub k: usize,
|
||||
/// Temperature for softmax
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmSearchConfig {
|
||||
/// Create a new search configuration
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(k: usize, temperature: f32) -> Self {
|
||||
Self { k, temperature }
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Differentiable Search
|
||||
// ============================================================================
|
||||
|
||||
/// Differentiable search using soft attention mechanism
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector
|
||||
/// * `candidate_embeddings` - List of candidate embedding vectors
|
||||
/// * `config` - Search configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// Object with indices and weights for top-k candidates
|
||||
#[wasm_bindgen(js_name = graphDifferentiableSearch)]
|
||||
pub fn differentiable_search(
|
||||
query: Vec<f32>,
|
||||
candidate_embeddings: JsValue,
|
||||
config: &WasmSearchConfig,
|
||||
) -> Result<JsValue, JsError> {
|
||||
let candidates: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(candidate_embeddings)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse candidate embeddings: {}", e)))?;
|
||||
|
||||
let (indices, weights) =
|
||||
core_differentiable_search(&query, &candidates, config.k, config.temperature);
|
||||
|
||||
let result = SearchResult { indices, weights };
|
||||
serde_wasm_bindgen::to_value(&result)
|
||||
.map_err(|e| JsError::new(&format!("Failed to serialize result: {}", e)))
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SearchResult {
|
||||
indices: Vec<usize>,
|
||||
weights: Vec<f32>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hierarchical Forward
|
||||
// ============================================================================
|
||||
|
||||
/// Hierarchical forward pass through multiple GNN layers
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector
|
||||
/// * `layer_embeddings` - Embeddings organized by layer
|
||||
/// * `gnn_layers` - Array of GNN layers
|
||||
///
|
||||
/// # Returns
|
||||
/// Final embedding after hierarchical processing
|
||||
#[wasm_bindgen(js_name = graphHierarchicalForward)]
|
||||
pub fn hierarchical_forward(
|
||||
query: Vec<f32>,
|
||||
layer_embeddings: JsValue,
|
||||
gnn_layers: Vec<WasmGNNLayer>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let embeddings: Vec<Vec<Vec<f32>>> = serde_wasm_bindgen::from_value(layer_embeddings)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse layer embeddings: {}", e)))?;
|
||||
|
||||
let core_layers: Vec<RuvectorLayer> = gnn_layers.iter().map(|l| l.inner.clone()).collect();
|
||||
|
||||
let result = core_hierarchical_forward(&query, &embeddings, &core_layers);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Graph Attention Types
|
||||
// ============================================================================
|
||||
|
||||
/// Graph attention mechanism types
|
||||
#[wasm_bindgen]
|
||||
pub enum GraphAttentionType {
|
||||
/// Graph Attention Networks (Velickovic et al., 2018)
|
||||
GAT,
|
||||
/// Graph Convolutional Networks (Kipf & Welling, 2017)
|
||||
GCN,
|
||||
/// GraphSAGE (Hamilton et al., 2017)
|
||||
GraphSAGE,
|
||||
}
|
||||
|
||||
/// Factory for graph attention information
|
||||
#[wasm_bindgen]
|
||||
pub struct GraphAttentionFactory;
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl GraphAttentionFactory {
|
||||
/// Get available graph attention types
|
||||
#[wasm_bindgen(js_name = availableTypes)]
|
||||
pub fn available_types() -> JsValue {
|
||||
let types = vec!["gat", "gcn", "graphsage"];
|
||||
serde_wasm_bindgen::to_value(&types).unwrap()
|
||||
}
|
||||
|
||||
/// Get description for a graph attention type
|
||||
#[wasm_bindgen(js_name = getDescription)]
|
||||
pub fn get_description(attention_type: &str) -> String {
|
||||
match attention_type {
|
||||
"gat" => {
|
||||
"Graph Attention Networks - learns attention weights over neighbors".to_string()
|
||||
}
|
||||
"gcn" => "Graph Convolutional Networks - spectral convolution on graphs".to_string(),
|
||||
"graphsage" => "GraphSAGE - sample and aggregate neighbor features".to_string(),
|
||||
_ => "Unknown graph attention type".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recommended use cases for a graph attention type
|
||||
#[wasm_bindgen(js_name = getUseCases)]
|
||||
pub fn get_use_cases(attention_type: &str) -> JsValue {
|
||||
let cases = match attention_type {
|
||||
"gat" => vec![
|
||||
"Node classification with varying neighbor importance",
|
||||
"Link prediction in heterogeneous graphs",
|
||||
"Knowledge graph reasoning",
|
||||
],
|
||||
"gcn" => vec![
|
||||
"Semi-supervised node classification",
|
||||
"Graph-level classification",
|
||||
"Spectral clustering",
|
||||
],
|
||||
"graphsage" => vec![
|
||||
"Inductive learning on new nodes",
|
||||
"Large-scale graph processing",
|
||||
"Dynamic graphs with new vertices",
|
||||
],
|
||||
_ => vec!["Unknown type"],
|
||||
};
|
||||
serde_wasm_bindgen::to_value(&cases).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_gnn_layer_creation() {
|
||||
let layer = WasmGNNLayer::new(4, 8, 2, 0.1);
|
||||
assert!(layer.is_ok());
|
||||
let l = layer.unwrap();
|
||||
assert_eq!(l.output_dim(), 8);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_gnn_layer_invalid_dropout() {
|
||||
let layer = WasmGNNLayer::new(4, 8, 2, 1.5);
|
||||
assert!(layer.is_err());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_gnn_layer_invalid_heads() {
|
||||
let layer = WasmGNNLayer::new(4, 7, 3, 0.1);
|
||||
assert!(layer.is_err());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_tensor_compress_creation() {
|
||||
let compressor = WasmTensorCompress::new();
|
||||
assert_eq!(compressor.get_compression_ratio(1.0), 1.0);
|
||||
assert_eq!(compressor.get_compression_ratio(0.5), 2.0);
|
||||
assert_eq!(compressor.get_compression_ratio(0.2), 4.0);
|
||||
assert_eq!(compressor.get_compression_ratio(0.05), 8.0);
|
||||
assert_eq!(compressor.get_compression_ratio(0.005), 32.0);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_search_config() {
|
||||
let config = WasmSearchConfig::new(5, 1.0);
|
||||
assert_eq!(config.k, 5);
|
||||
assert_eq!(config.temperature, 1.0);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_factory_types() {
|
||||
let types_js = GraphAttentionFactory::available_types();
|
||||
assert!(!types_js.is_null());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_factory_descriptions() {
|
||||
let desc = GraphAttentionFactory::get_description("gat");
|
||||
assert!(desc.contains("Graph Attention"));
|
||||
|
||||
let desc = GraphAttentionFactory::get_description("gcn");
|
||||
assert!(desc.contains("Graph Convolutional"));
|
||||
|
||||
let desc = GraphAttentionFactory::get_description("graphsage");
|
||||
assert!(desc.contains("GraphSAGE"));
|
||||
}
|
||||
}
|
||||
382
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/lib.rs
vendored
Normal file
382
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,382 @@
|
||||
//! Unified WebAssembly Attention Library
|
||||
//!
|
||||
//! This crate provides a unified WASM interface for 18+ attention mechanisms:
|
||||
//!
|
||||
//! ## Neural Attention (from ruvector-attention)
|
||||
//! - **Scaled Dot-Product**: Standard transformer attention
|
||||
//! - **Multi-Head**: Parallel attention heads
|
||||
//! - **Hyperbolic**: Attention in hyperbolic space for hierarchical data
|
||||
//! - **Linear**: O(n) Performer-style attention
|
||||
//! - **Flash**: Memory-efficient blocked attention
|
||||
//! - **Local-Global**: Sparse attention with global tokens
|
||||
//! - **MoE**: Mixture of Experts attention
|
||||
//!
|
||||
//! ## DAG Attention (from ruvector-dag)
|
||||
//! - **Topological**: Position-aware attention in DAG order
|
||||
//! - **Causal Cone**: Lightcone-based causal attention
|
||||
//! - **Critical Path**: Attention weighted by critical path distance
|
||||
//! - **MinCut-Gated**: Flow-based gating attention
|
||||
//! - **Hierarchical Lorentz**: Multi-scale hyperbolic DAG attention
|
||||
//! - **Parallel Branch**: Attention for parallel DAG branches
|
||||
//! - **Temporal BTSP**: Behavioral Time-Series Pattern attention
|
||||
//!
|
||||
//! ## Graph Attention (from ruvector-gnn)
|
||||
//! - **GAT**: Graph Attention Networks
|
||||
//! - **GCN**: Graph Convolutional Networks
|
||||
//! - **GraphSAGE**: Sampling and Aggregating graph embeddings
|
||||
//!
|
||||
//! ## State Space Models
|
||||
//! - **Mamba SSM**: Selective State Space Model attention
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// Use wee_alloc for smaller WASM binary (~10KB reduction)
|
||||
#[cfg(feature = "wee_alloc")]
|
||||
#[global_allocator]
|
||||
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
|
||||
|
||||
// ============================================================================
|
||||
// Module declarations
|
||||
// ============================================================================
|
||||
|
||||
pub mod mamba;
|
||||
|
||||
mod dag;
|
||||
mod graph;
|
||||
mod neural;
|
||||
|
||||
// ============================================================================
|
||||
// Re-exports for convenient access
|
||||
// ============================================================================
|
||||
|
||||
pub use dag::*;
|
||||
pub use graph::*;
|
||||
pub use mamba::*;
|
||||
pub use neural::*;
|
||||
|
||||
// ============================================================================
|
||||
// Initialization
|
||||
// ============================================================================
|
||||
|
||||
/// Initialize the WASM module with panic hook for better error messages
|
||||
#[wasm_bindgen(start)]
|
||||
pub fn init() {
|
||||
#[cfg(feature = "console_error_panic_hook")]
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Version and Info
|
||||
// ============================================================================
|
||||
|
||||
/// Get the version of the unified attention WASM crate
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Get information about all available attention mechanisms
|
||||
#[wasm_bindgen(js_name = availableMechanisms)]
|
||||
pub fn available_mechanisms() -> JsValue {
|
||||
let mechanisms = AttentionMechanisms {
|
||||
neural: vec![
|
||||
"scaled_dot_product".into(),
|
||||
"multi_head".into(),
|
||||
"hyperbolic".into(),
|
||||
"linear".into(),
|
||||
"flash".into(),
|
||||
"local_global".into(),
|
||||
"moe".into(),
|
||||
],
|
||||
dag: vec![
|
||||
"topological".into(),
|
||||
"causal_cone".into(),
|
||||
"critical_path".into(),
|
||||
"mincut_gated".into(),
|
||||
"hierarchical_lorentz".into(),
|
||||
"parallel_branch".into(),
|
||||
"temporal_btsp".into(),
|
||||
],
|
||||
graph: vec!["gat".into(), "gcn".into(), "graphsage".into()],
|
||||
ssm: vec!["mamba".into()],
|
||||
};
|
||||
serde_wasm_bindgen::to_value(&mechanisms).unwrap()
|
||||
}
|
||||
|
||||
/// Get summary statistics about the unified attention library
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats() -> JsValue {
|
||||
let stats = UnifiedStats {
|
||||
total_mechanisms: 18,
|
||||
neural_count: 7,
|
||||
dag_count: 7,
|
||||
graph_count: 3,
|
||||
ssm_count: 1,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
serde_wasm_bindgen::to_value(&stats).unwrap()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Internal Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct AttentionMechanisms {
|
||||
neural: Vec<String>,
|
||||
dag: Vec<String>,
|
||||
graph: Vec<String>,
|
||||
ssm: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct UnifiedStats {
|
||||
total_mechanisms: usize,
|
||||
neural_count: usize,
|
||||
dag_count: usize,
|
||||
graph_count: usize,
|
||||
ssm_count: usize,
|
||||
version: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Unified Attention Selector
|
||||
// ============================================================================
|
||||
|
||||
/// Unified attention mechanism selector
|
||||
/// Automatically routes to the appropriate attention implementation
|
||||
#[wasm_bindgen]
|
||||
pub struct UnifiedAttention {
|
||||
mechanism_type: String,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl UnifiedAttention {
|
||||
/// Create a new unified attention selector
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(mechanism: &str) -> Result<UnifiedAttention, JsError> {
|
||||
let valid_mechanisms = [
|
||||
// Neural
|
||||
"scaled_dot_product",
|
||||
"multi_head",
|
||||
"hyperbolic",
|
||||
"linear",
|
||||
"flash",
|
||||
"local_global",
|
||||
"moe",
|
||||
// DAG
|
||||
"topological",
|
||||
"causal_cone",
|
||||
"critical_path",
|
||||
"mincut_gated",
|
||||
"hierarchical_lorentz",
|
||||
"parallel_branch",
|
||||
"temporal_btsp",
|
||||
// Graph
|
||||
"gat",
|
||||
"gcn",
|
||||
"graphsage",
|
||||
// SSM
|
||||
"mamba",
|
||||
];
|
||||
|
||||
if !valid_mechanisms.contains(&mechanism) {
|
||||
return Err(JsError::new(&format!(
|
||||
"Unknown mechanism: {}. Valid options: {:?}",
|
||||
mechanism, valid_mechanisms
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
mechanism_type: mechanism.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the currently selected mechanism type
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn mechanism(&self) -> String {
|
||||
self.mechanism_type.clone()
|
||||
}
|
||||
|
||||
/// Get the category of the selected mechanism
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn category(&self) -> String {
|
||||
match self.mechanism_type.as_str() {
|
||||
"scaled_dot_product" | "multi_head" | "hyperbolic" | "linear" | "flash"
|
||||
| "local_global" | "moe" => "neural".to_string(),
|
||||
|
||||
"topological"
|
||||
| "causal_cone"
|
||||
| "critical_path"
|
||||
| "mincut_gated"
|
||||
| "hierarchical_lorentz"
|
||||
| "parallel_branch"
|
||||
| "temporal_btsp" => "dag".to_string(),
|
||||
|
||||
"gat" | "gcn" | "graphsage" => "graph".to_string(),
|
||||
|
||||
"mamba" => "ssm".to_string(),
|
||||
|
||||
_ => "unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this mechanism supports sequence processing
|
||||
#[wasm_bindgen(js_name = supportsSequences)]
|
||||
pub fn supports_sequences(&self) -> bool {
|
||||
matches!(
|
||||
self.mechanism_type.as_str(),
|
||||
"scaled_dot_product" | "multi_head" | "linear" | "flash" | "local_global" | "mamba"
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this mechanism supports graph/DAG structures
|
||||
#[wasm_bindgen(js_name = supportsGraphs)]
|
||||
pub fn supports_graphs(&self) -> bool {
|
||||
matches!(
|
||||
self.mechanism_type.as_str(),
|
||||
"topological"
|
||||
| "causal_cone"
|
||||
| "critical_path"
|
||||
| "mincut_gated"
|
||||
| "hierarchical_lorentz"
|
||||
| "parallel_branch"
|
||||
| "temporal_btsp"
|
||||
| "gat"
|
||||
| "gcn"
|
||||
| "graphsage"
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this mechanism supports hyperbolic geometry
|
||||
#[wasm_bindgen(js_name = supportsHyperbolic)]
|
||||
pub fn supports_hyperbolic(&self) -> bool {
|
||||
matches!(
|
||||
self.mechanism_type.as_str(),
|
||||
"hyperbolic" | "hierarchical_lorentz"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Utility Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
#[wasm_bindgen(js_name = cosineSimilarity)]
|
||||
pub fn cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> Result<f32, JsError> {
|
||||
if a.len() != b.len() {
|
||||
return Err(JsError::new(&format!(
|
||||
"Vector dimensions must match: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
Ok(dot / (norm_a * norm_b))
|
||||
}
|
||||
}
|
||||
|
||||
/// Softmax normalization
|
||||
#[wasm_bindgen]
|
||||
pub fn softmax(values: Vec<f32>) -> Vec<f32> {
|
||||
let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
let exp_values: Vec<f32> = values.iter().map(|&x| (x - max_val).exp()).collect();
|
||||
let sum: f32 = exp_values.iter().sum();
|
||||
exp_values.iter().map(|&x| x / sum).collect()
|
||||
}
|
||||
|
||||
/// Temperature-scaled softmax
|
||||
#[wasm_bindgen(js_name = temperatureSoftmax)]
|
||||
pub fn temperature_softmax(values: Vec<f32>, temperature: f32) -> Vec<f32> {
|
||||
if temperature <= 0.0 {
|
||||
// Return one-hot for the maximum
|
||||
let max_idx = values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(0);
|
||||
let mut result = vec![0.0; values.len()];
|
||||
result[max_idx] = 1.0;
|
||||
return result;
|
||||
}
|
||||
|
||||
let scaled: Vec<f32> = values.iter().map(|&x| x / temperature).collect();
|
||||
softmax(scaled)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_version() {
|
||||
assert!(!version().is_empty());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_unified_attention_creation() {
|
||||
let attention = UnifiedAttention::new("multi_head");
|
||||
assert!(attention.is_ok());
|
||||
|
||||
let invalid = UnifiedAttention::new("invalid_mechanism");
|
||||
assert!(invalid.is_err());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_mechanism_categories() {
|
||||
let neural = UnifiedAttention::new("multi_head").unwrap();
|
||||
assert_eq!(neural.category(), "neural");
|
||||
|
||||
let dag = UnifiedAttention::new("topological").unwrap();
|
||||
assert_eq!(dag.category(), "dag");
|
||||
|
||||
let graph = UnifiedAttention::new("gat").unwrap();
|
||||
assert_eq!(graph.category(), "graph");
|
||||
|
||||
let ssm = UnifiedAttention::new("mamba").unwrap();
|
||||
assert_eq!(ssm.category(), "ssm");
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_softmax() {
|
||||
let input = vec![1.0, 2.0, 3.0];
|
||||
let output = softmax(input);
|
||||
|
||||
// Sum should be 1.0
|
||||
let sum: f32 = output.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 1e-6);
|
||||
|
||||
// Should be monotonically increasing
|
||||
assert!(output[0] < output[1]);
|
||||
assert!(output[1] < output[2]);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_cosine_similarity() {
|
||||
let a = vec![1.0, 0.0, 0.0];
|
||||
let b = vec![1.0, 0.0, 0.0];
|
||||
let sim = cosine_similarity(a, b).unwrap();
|
||||
assert!((sim - 1.0).abs() < 1e-6);
|
||||
|
||||
let c = vec![1.0, 0.0, 0.0];
|
||||
let d = vec![0.0, 1.0, 0.0];
|
||||
let sim2 = cosine_similarity(c, d).unwrap();
|
||||
assert!(sim2.abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
554
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/mamba.rs
vendored
Normal file
554
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/mamba.rs
vendored
Normal file
@@ -0,0 +1,554 @@
|
||||
//! Mamba SSM (Selective State Space Model) Attention Mechanism
|
||||
//!
|
||||
//! Implements the Mamba architecture's selective scan mechanism for efficient
|
||||
//! sequence modeling with linear time complexity O(n).
|
||||
//!
|
||||
//! Key Features:
|
||||
//! - **Selective Scan**: Input-dependent state transitions
|
||||
//! - **Linear Complexity**: O(n) vs O(n^2) for standard attention
|
||||
//! - **Hardware Efficient**: Optimized for parallel scan operations
|
||||
//! - **Long Context**: Handles very long sequences efficiently
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! Mamba uses a selective state space model:
|
||||
//! ```text
|
||||
//! h_t = A_t * h_{t-1} + B_t * x_t
|
||||
//! y_t = C_t * h_t
|
||||
//! ```
|
||||
//!
|
||||
//! Where A_t, B_t, C_t are input-dependent (selective), computed from x_t.
|
||||
//!
|
||||
//! ## References
|
||||
//!
|
||||
//! - Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu & Dao, 2023)
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for Mamba SSM attention
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[wasm_bindgen]
|
||||
pub struct MambaConfig {
|
||||
/// Model dimension (d_model)
|
||||
pub dim: usize,
|
||||
/// State space dimension (n)
|
||||
pub state_dim: usize,
|
||||
/// Expansion factor for inner dimension
|
||||
pub expand_factor: usize,
|
||||
/// Convolution kernel size
|
||||
pub conv_kernel_size: usize,
|
||||
/// Delta (discretization step) range minimum
|
||||
pub dt_min: f32,
|
||||
/// Delta range maximum
|
||||
pub dt_max: f32,
|
||||
/// Whether to use learnable D skip connection
|
||||
pub use_d_skip: bool,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl MambaConfig {
|
||||
/// Create a new Mamba configuration
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize) -> MambaConfig {
|
||||
MambaConfig {
|
||||
dim,
|
||||
state_dim: 16,
|
||||
expand_factor: 2,
|
||||
conv_kernel_size: 4,
|
||||
dt_min: 0.001,
|
||||
dt_max: 0.1,
|
||||
use_d_skip: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set state space dimension
|
||||
#[wasm_bindgen(js_name = withStateDim)]
|
||||
pub fn with_state_dim(mut self, state_dim: usize) -> MambaConfig {
|
||||
self.state_dim = state_dim;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set expansion factor
|
||||
#[wasm_bindgen(js_name = withExpandFactor)]
|
||||
pub fn with_expand_factor(mut self, factor: usize) -> MambaConfig {
|
||||
self.expand_factor = factor;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set convolution kernel size
|
||||
#[wasm_bindgen(js_name = withConvKernelSize)]
|
||||
pub fn with_conv_kernel_size(mut self, size: usize) -> MambaConfig {
|
||||
self.conv_kernel_size = size;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MambaConfig {
|
||||
fn default() -> Self {
|
||||
MambaConfig::new(256)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// State Space Parameters
|
||||
// ============================================================================
|
||||
|
||||
/// Selective state space parameters (input-dependent)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct SelectiveSSMParams {
|
||||
/// Discretized A matrix diagonal (batch, seq_len, state_dim)
|
||||
a_bar: Vec<Vec<Vec<f32>>>,
|
||||
/// Discretized B matrix (batch, seq_len, state_dim)
|
||||
b_bar: Vec<Vec<Vec<f32>>>,
|
||||
/// Output projection C (batch, seq_len, state_dim)
|
||||
c: Vec<Vec<Vec<f32>>>,
|
||||
/// Discretization step delta (batch, seq_len, inner_dim)
|
||||
delta: Vec<Vec<Vec<f32>>>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mamba SSM Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Mamba Selective State Space Model for sequence attention
|
||||
///
|
||||
/// Provides O(n) attention-like mechanism using selective state spaces
|
||||
#[wasm_bindgen]
|
||||
pub struct MambaSSMAttention {
|
||||
config: MambaConfig,
|
||||
/// Inner dimension after expansion
|
||||
inner_dim: usize,
|
||||
/// A parameter (state_dim,) - diagonal of continuous A
|
||||
a_log: Vec<f32>,
|
||||
/// D skip connection (inner_dim,)
|
||||
d_skip: Vec<f32>,
|
||||
/// Projection weights (simplified for WASM)
|
||||
in_proj: Vec<Vec<f32>>,
|
||||
out_proj: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl MambaSSMAttention {
|
||||
/// Create a new Mamba SSM attention layer
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(config: MambaConfig) -> MambaSSMAttention {
|
||||
let inner_dim = config.dim * config.expand_factor;
|
||||
|
||||
// Initialize A as negative values (for stability) - log of eigenvalues
|
||||
let a_log: Vec<f32> = (0..config.state_dim)
|
||||
.map(|i| -((i + 1) as f32).ln())
|
||||
.collect();
|
||||
|
||||
// D skip connection
|
||||
let d_skip = vec![1.0; inner_dim];
|
||||
|
||||
// Simplified projection matrices (identity-like for stub)
|
||||
let in_proj: Vec<Vec<f32>> = (0..inner_dim)
|
||||
.map(|i| {
|
||||
let mut row = vec![0.0; config.dim];
|
||||
if i < config.dim {
|
||||
row[i] = 1.0;
|
||||
}
|
||||
row
|
||||
})
|
||||
.collect();
|
||||
|
||||
let out_proj: Vec<Vec<f32>> = (0..config.dim)
|
||||
.map(|i| {
|
||||
let mut row = vec![0.0; inner_dim];
|
||||
if i < inner_dim {
|
||||
row[i] = 1.0;
|
||||
}
|
||||
row
|
||||
})
|
||||
.collect();
|
||||
|
||||
MambaSSMAttention {
|
||||
config,
|
||||
inner_dim,
|
||||
a_log,
|
||||
d_skip,
|
||||
in_proj,
|
||||
out_proj,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
#[wasm_bindgen(js_name = withDefaults)]
|
||||
pub fn with_defaults(dim: usize) -> MambaSSMAttention {
|
||||
MambaSSMAttention::new(MambaConfig::new(dim))
|
||||
}
|
||||
|
||||
/// Forward pass through Mamba SSM
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input` - Input sequence (seq_len, dim) flattened to 1D
|
||||
/// * `seq_len` - Sequence length
|
||||
///
|
||||
/// # Returns
|
||||
/// Output sequence (seq_len, dim) flattened to 1D
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(&self, input: Vec<f32>, seq_len: usize) -> Result<Vec<f32>, JsError> {
|
||||
let dim = self.config.dim;
|
||||
|
||||
if input.len() != seq_len * dim {
|
||||
return Err(JsError::new(&format!(
|
||||
"Input size mismatch: expected {} ({}x{}), got {}",
|
||||
seq_len * dim,
|
||||
seq_len,
|
||||
dim,
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Reshape input to 2D
|
||||
let input_2d: Vec<Vec<f32>> = (0..seq_len)
|
||||
.map(|t| input[t * dim..(t + 1) * dim].to_vec())
|
||||
.collect();
|
||||
|
||||
// Step 1: Input projection to inner_dim
|
||||
let projected = self.project_in(&input_2d);
|
||||
|
||||
// Step 2: Compute selective SSM parameters from input
|
||||
let ssm_params = self.compute_selective_params(&projected);
|
||||
|
||||
// Step 3: Run selective scan
|
||||
let ssm_output = self.selective_scan(&projected, &ssm_params);
|
||||
|
||||
// Step 4: Apply D skip connection
|
||||
let with_skip: Vec<Vec<f32>> = ssm_output
|
||||
.iter()
|
||||
.zip(projected.iter())
|
||||
.map(|(y, x)| {
|
||||
y.iter()
|
||||
.zip(x.iter())
|
||||
.zip(self.d_skip.iter())
|
||||
.map(|((yi, xi), di)| yi + di * xi)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Step 5: Output projection
|
||||
let output = self.project_out(&with_skip);
|
||||
|
||||
// Flatten output
|
||||
Ok(output.into_iter().flatten().collect())
|
||||
}
|
||||
|
||||
/// Get the configuration
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn config(&self) -> MambaConfig {
|
||||
self.config.clone()
|
||||
}
|
||||
|
||||
/// Get the inner dimension
|
||||
#[wasm_bindgen(getter, js_name = innerDim)]
|
||||
pub fn inner_dim(&self) -> usize {
|
||||
self.inner_dim
|
||||
}
|
||||
|
||||
/// Compute attention-like scores (for visualization/analysis)
|
||||
///
|
||||
/// Returns pseudo-attention scores showing which positions influence output
|
||||
#[wasm_bindgen(js_name = getAttentionScores)]
|
||||
pub fn get_attention_scores(
|
||||
&self,
|
||||
input: Vec<f32>,
|
||||
seq_len: usize,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let dim = self.config.dim;
|
||||
|
||||
if input.len() != seq_len * dim {
|
||||
return Err(JsError::new(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
seq_len * dim,
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Compute approximate attention scores based on state decay
|
||||
// This shows how much each position can "attend to" previous positions
|
||||
let mut scores = vec![0.0f32; seq_len * seq_len];
|
||||
|
||||
for t in 0..seq_len {
|
||||
for s in 0..=t {
|
||||
// Exponential decay based on distance and A parameters
|
||||
let distance = (t - s) as f32;
|
||||
let decay: f32 = self
|
||||
.a_log
|
||||
.iter()
|
||||
.map(|&a| (a * distance).exp())
|
||||
.sum::<f32>()
|
||||
/ self.config.state_dim as f32;
|
||||
|
||||
scores[t * seq_len + s] = decay;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
}
|
||||
|
||||
// Internal implementation methods
|
||||
impl MambaSSMAttention {
|
||||
/// Project input from dim to inner_dim
|
||||
fn project_in(&self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
input
|
||||
.iter()
|
||||
.map(|x| {
|
||||
self.in_proj
|
||||
.iter()
|
||||
.map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum())
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Project from inner_dim back to dim
|
||||
fn project_out(&self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
input
|
||||
.iter()
|
||||
.map(|x| {
|
||||
self.out_proj
|
||||
.iter()
|
||||
.map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum())
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute selective SSM parameters from input
|
||||
fn compute_selective_params(&self, input: &[Vec<f32>]) -> SelectiveSSMParams {
|
||||
let seq_len = input.len();
|
||||
let state_dim = self.config.state_dim;
|
||||
|
||||
// Compute input-dependent delta, B, C
|
||||
// Simplified: use sigmoid/tanh of input projections
|
||||
|
||||
let mut a_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
|
||||
let mut b_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
|
||||
let mut c = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
|
||||
let mut delta = vec![vec![vec![0.0; self.inner_dim]; 1]; seq_len];
|
||||
|
||||
for (t, x) in input.iter().enumerate() {
|
||||
// Compute delta from input (softplus of projection)
|
||||
let dt: Vec<f32> = x
|
||||
.iter()
|
||||
.map(|&xi| {
|
||||
let raw = xi * 0.1; // Simple scaling
|
||||
let dt_val = (1.0 + raw.exp()).ln(); // Softplus
|
||||
dt_val.clamp(self.config.dt_min, self.config.dt_max)
|
||||
})
|
||||
.collect();
|
||||
delta[t][0] = dt.clone();
|
||||
|
||||
for d in 0..self.inner_dim.min(x.len()) {
|
||||
let dt_d = dt[d.min(dt.len() - 1)];
|
||||
|
||||
for n in 0..state_dim {
|
||||
// Discretize A: A_bar = exp(delta * A)
|
||||
let a_continuous = self.a_log[n].exp(); // Negative
|
||||
a_bar[t][d][n] = (dt_d * a_continuous).exp();
|
||||
|
||||
// Discretize B: B_bar = delta * B (simplified)
|
||||
// B is input-dependent
|
||||
let b_input = if d < x.len() { x[d] } else { 0.0 };
|
||||
b_bar[t][d][n] = dt_d * Self::sigmoid(b_input * 0.1);
|
||||
|
||||
// C is input-dependent
|
||||
c[t][d][n] = Self::tanh(b_input * 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SelectiveSSMParams {
|
||||
a_bar,
|
||||
b_bar,
|
||||
c,
|
||||
delta,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run selective scan (parallel associative scan in practice)
|
||||
fn selective_scan(&self, input: &[Vec<f32>], params: &SelectiveSSMParams) -> Vec<Vec<f32>> {
|
||||
let seq_len = input.len();
|
||||
let state_dim = self.config.state_dim;
|
||||
|
||||
// Initialize hidden state
|
||||
let mut hidden = vec![vec![0.0f32; state_dim]; self.inner_dim];
|
||||
let mut output = vec![vec![0.0f32; self.inner_dim]; seq_len];
|
||||
|
||||
for t in 0..seq_len {
|
||||
for d in 0..self.inner_dim {
|
||||
let x_d = if d < input[t].len() { input[t][d] } else { 0.0 };
|
||||
|
||||
// Update hidden state: h_t = A_bar * h_{t-1} + B_bar * x_t
|
||||
for n in 0..state_dim {
|
||||
hidden[d][n] =
|
||||
params.a_bar[t][d][n] * hidden[d][n] + params.b_bar[t][d][n] * x_d;
|
||||
}
|
||||
|
||||
// Compute output: y_t = C * h_t
|
||||
output[t][d] = hidden[d]
|
||||
.iter()
|
||||
.zip(params.c[t][d].iter())
|
||||
.map(|(h, c)| h * c)
|
||||
.sum();
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn sigmoid(x: f32) -> f32 {
|
||||
1.0 / (1.0 + (-x).exp())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn tanh(x: f32) -> f32 {
|
||||
x.tanh()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hybrid Mamba-Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Hybrid layer combining Mamba SSM with standard attention
|
||||
///
|
||||
/// Uses Mamba for long-range dependencies and attention for local patterns
|
||||
#[wasm_bindgen]
|
||||
pub struct HybridMambaAttention {
|
||||
mamba: MambaSSMAttention,
|
||||
local_window: usize,
|
||||
use_attention_for_local: bool,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl HybridMambaAttention {
|
||||
/// Create a new hybrid Mamba-Attention layer
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(config: MambaConfig, local_window: usize) -> HybridMambaAttention {
|
||||
HybridMambaAttention {
|
||||
mamba: MambaSSMAttention::new(config),
|
||||
local_window,
|
||||
use_attention_for_local: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(&self, input: Vec<f32>, seq_len: usize) -> Result<Vec<f32>, JsError> {
|
||||
let dim = self.mamba.config.dim;
|
||||
|
||||
// Run Mamba for global context
|
||||
let mamba_output = self.mamba.forward(input.clone(), seq_len)?;
|
||||
|
||||
// Apply local attention mixing (simplified)
|
||||
let mut output = mamba_output.clone();
|
||||
|
||||
if self.use_attention_for_local {
|
||||
for t in 0..seq_len {
|
||||
let start = t.saturating_sub(self.local_window / 2);
|
||||
let end = (t + self.local_window / 2 + 1).min(seq_len);
|
||||
|
||||
// Simple local averaging
|
||||
for d in 0..dim {
|
||||
let mut local_sum = 0.0;
|
||||
let mut count = 0;
|
||||
for s in start..end {
|
||||
local_sum += input[s * dim + d];
|
||||
count += 1;
|
||||
}
|
||||
// Mix global (Mamba) and local
|
||||
let local_avg = local_sum / count as f32;
|
||||
output[t * dim + d] = 0.7 * output[t * dim + d] + 0.3 * local_avg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Get local window size
|
||||
#[wasm_bindgen(getter, js_name = localWindow)]
|
||||
pub fn local_window(&self) -> usize {
|
||||
self.local_window
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_mamba_config() {
|
||||
let config = MambaConfig::new(256);
|
||||
assert_eq!(config.dim, 256);
|
||||
assert_eq!(config.state_dim, 16);
|
||||
assert_eq!(config.expand_factor, 2);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_mamba_creation() {
|
||||
let config = MambaConfig::new(64);
|
||||
let mamba = MambaSSMAttention::new(config);
|
||||
assert_eq!(mamba.inner_dim(), 128); // 64 * 2
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_mamba_forward() {
|
||||
let config = MambaConfig::new(8);
|
||||
let mamba = MambaSSMAttention::new(config);
|
||||
|
||||
// Input: 4 tokens of dimension 8
|
||||
let input = vec![0.1f32; 32];
|
||||
let output = mamba.forward(input, 4);
|
||||
|
||||
assert!(output.is_ok());
|
||||
let out = output.unwrap();
|
||||
assert_eq!(out.len(), 32); // Same shape as input
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_attention_scores() {
|
||||
let config = MambaConfig::new(8);
|
||||
let mamba = MambaSSMAttention::new(config);
|
||||
|
||||
let input = vec![0.1f32; 24]; // 3 tokens
|
||||
let scores = mamba.get_attention_scores(input, 3);
|
||||
|
||||
assert!(scores.is_ok());
|
||||
let s = scores.unwrap();
|
||||
assert_eq!(s.len(), 9); // 3x3 attention matrix
|
||||
|
||||
// Causal: upper triangle should be 0
|
||||
assert_eq!(s[0 * 3 + 1], 0.0); // t=0 cannot attend to t=1
|
||||
assert_eq!(s[0 * 3 + 2], 0.0); // t=0 cannot attend to t=2
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_hybrid_mamba() {
|
||||
let config = MambaConfig::new(8);
|
||||
let hybrid = HybridMambaAttention::new(config, 4);
|
||||
|
||||
let input = vec![0.5f32; 40]; // 5 tokens
|
||||
let output = hybrid.forward(input, 5);
|
||||
|
||||
assert!(output.is_ok());
|
||||
assert_eq!(output.unwrap().len(), 40);
|
||||
}
|
||||
}
|
||||
439
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/neural.rs
vendored
Normal file
439
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/neural.rs
vendored
Normal file
@@ -0,0 +1,439 @@
|
||||
//! Neural Attention Mechanisms (from ruvector-attention)
|
||||
//!
|
||||
//! Re-exports the 7 core neural attention mechanisms:
|
||||
//! - Scaled Dot-Product Attention
|
||||
//! - Multi-Head Attention
|
||||
//! - Hyperbolic Attention
|
||||
//! - Linear Attention (Performer)
|
||||
//! - Flash Attention
|
||||
//! - Local-Global Attention
|
||||
//! - Mixture of Experts (MoE) Attention
|
||||
|
||||
use ruvector_attention::{
|
||||
attention::{MultiHeadAttention, ScaledDotProductAttention},
|
||||
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
|
||||
moe::{MoEAttention, MoEConfig},
|
||||
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
|
||||
traits::Attention,
|
||||
};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Scaled Dot-Product Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Compute scaled dot-product attention
|
||||
///
|
||||
/// Standard transformer attention: softmax(QK^T / sqrt(d)) * V
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector (Float32Array)
|
||||
/// * `keys` - Array of key vectors (JsValue - array of Float32Arrays)
|
||||
/// * `values` - Array of value vectors (JsValue - array of Float32Arrays)
|
||||
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
///
|
||||
/// # Returns
|
||||
/// Attention-weighted output vector
|
||||
#[wasm_bindgen(js_name = scaledDotAttention)]
|
||||
pub fn scaled_dot_attention(
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
scale: Option<f32>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let attention = ScaledDotProductAttention::new(query.len());
|
||||
attention
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Multi-Head Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Multi-head attention mechanism
|
||||
///
|
||||
/// Splits input into multiple heads, applies attention, and concatenates results
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMultiHeadAttention {
|
||||
inner: MultiHeadAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMultiHeadAttention {
|
||||
/// Create a new multi-head attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension (must be divisible by num_heads)
|
||||
/// * `num_heads` - Number of parallel attention heads
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
|
||||
if dim % num_heads != 0 {
|
||||
return Err(JsError::new(&format!(
|
||||
"Dimension {} must be divisible by number of heads {}",
|
||||
dim, num_heads
|
||||
)));
|
||||
}
|
||||
Ok(Self {
|
||||
inner: MultiHeadAttention::new(dim, num_heads),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute multi-head attention
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the number of attention heads
|
||||
#[wasm_bindgen(getter, js_name = numHeads)]
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.inner.num_heads()
|
||||
}
|
||||
|
||||
/// Get the embedding dimension
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dim(&self) -> usize {
|
||||
self.inner.dim()
|
||||
}
|
||||
|
||||
/// Get the dimension per head
|
||||
#[wasm_bindgen(getter, js_name = headDim)]
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.inner.dim() / self.inner.num_heads()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hyperbolic Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Hyperbolic attention mechanism for hierarchical data
|
||||
///
|
||||
/// Operates in hyperbolic space (Poincare ball model) which naturally
|
||||
/// represents tree-like hierarchical structures with exponential capacity
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmHyperbolicAttention {
|
||||
inner: HyperbolicAttention,
|
||||
curvature_value: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmHyperbolicAttention {
|
||||
/// Create a new hyperbolic attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `curvature` - Hyperbolic curvature parameter (negative for hyperbolic space)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim,
|
||||
curvature,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
inner: HyperbolicAttention::new(config),
|
||||
curvature_value: curvature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hyperbolic attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the curvature parameter
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn curvature(&self) -> f32 {
|
||||
self.curvature_value
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Linear Attention (Performer)
|
||||
// ============================================================================
|
||||
|
||||
/// Linear attention using random feature approximation
|
||||
///
|
||||
/// Achieves O(n) complexity instead of O(n^2) by approximating
|
||||
/// the softmax kernel with random Fourier features
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLinearAttention {
|
||||
inner: LinearAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLinearAttention {
|
||||
/// Create a new linear attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_features` - Number of random features for kernel approximation
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
|
||||
Self {
|
||||
inner: LinearAttention::new(dim, num_features),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute linear attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Flash Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Flash attention with memory-efficient tiling
|
||||
///
|
||||
/// Reduces memory usage from O(n^2) to O(n) by computing attention
|
||||
/// in blocks and fusing operations
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmFlashAttention {
|
||||
inner: FlashAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmFlashAttention {
|
||||
/// Create a new flash attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `block_size` - Block size for tiled computation
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
|
||||
Self {
|
||||
inner: FlashAttention::new(dim, block_size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute flash attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Local-Global Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Local-global sparse attention (Longformer-style)
|
||||
///
|
||||
/// Combines local sliding window attention with global tokens
|
||||
/// for efficient long-range dependencies
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLocalGlobalAttention {
|
||||
inner: LocalGlobalAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLocalGlobalAttention {
|
||||
/// Create a new local-global attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `local_window` - Size of local attention window
|
||||
/// * `global_tokens` - Number of global attention tokens
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
|
||||
Self {
|
||||
inner: LocalGlobalAttention::new(dim, local_window, global_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute local-global attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mixture of Experts (MoE) Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Mixture of Experts attention mechanism
|
||||
///
|
||||
/// Routes queries to specialized expert attention heads based on
|
||||
/// learned gating functions for capacity-efficient computation
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMoEAttention {
|
||||
inner: MoEAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMoEAttention {
|
||||
/// Create a new MoE attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_experts` - Number of expert attention mechanisms
|
||||
/// * `top_k` - Number of experts to activate per query
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
|
||||
let config = MoEConfig::builder()
|
||||
.dim(dim)
|
||||
.num_experts(num_experts)
|
||||
.top_k(top_k)
|
||||
.build();
|
||||
Self {
|
||||
inner: MoEAttention::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute MoE attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_multi_head_creation() {
|
||||
let mha = WasmMultiHeadAttention::new(64, 8);
|
||||
assert!(mha.is_ok());
|
||||
let mha = mha.unwrap();
|
||||
assert_eq!(mha.dim(), 64);
|
||||
assert_eq!(mha.num_heads(), 8);
|
||||
assert_eq!(mha.head_dim(), 8);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_multi_head_invalid_dims() {
|
||||
let mha = WasmMultiHeadAttention::new(65, 8);
|
||||
assert!(mha.is_err());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_hyperbolic_attention() {
|
||||
let hyp = WasmHyperbolicAttention::new(32, -1.0);
|
||||
assert_eq!(hyp.curvature(), -1.0);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_linear_attention_creation() {
|
||||
let linear = WasmLinearAttention::new(64, 128);
|
||||
// Just verify it can be created
|
||||
assert!(true);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_flash_attention_creation() {
|
||||
let flash = WasmFlashAttention::new(64, 16);
|
||||
assert!(true);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_local_global_creation() {
|
||||
let lg = WasmLocalGlobalAttention::new(64, 128, 4);
|
||||
assert!(true);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_moe_attention_creation() {
|
||||
let moe = WasmMoEAttention::new(64, 8, 2);
|
||||
assert!(true);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user