Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
12
vendor/ruvector/crates/ruvector-attention-node/.npmignore
vendored
Normal file
12
vendor/ruvector/crates/ruvector-attention-node/.npmignore
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
target
|
||||
Cargo.lock
|
||||
.cargo
|
||||
.github
|
||||
*.node
|
||||
*.so
|
||||
*.dylib
|
||||
*.dll
|
||||
build
|
||||
.vscode
|
||||
.idea
|
||||
*.iml
|
||||
28
vendor/ruvector/crates/ruvector-attention-node/Cargo.toml
vendored
Normal file
28
vendor/ruvector/crates/ruvector-attention-node/Cargo.toml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "ruvector-attention-node"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
description = "Node.js bindings for ruvector-attention"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
ruvector-attention = { version = "2.0", path = "../ruvector-attention", default-features = false }
|
||||
napi = { version = "2", default-features = false, features = ["napi9", "async", "serde-json"] }
|
||||
napi-derive = "2"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = ["rt-multi-thread"] }
|
||||
|
||||
[build-dependencies]
|
||||
napi-build = "2"
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
opt-level = 3
|
||||
codegen-units = 1
|
||||
strip = true
|
||||
21
vendor/ruvector/crates/ruvector-attention-node/LICENSE
vendored
Normal file
21
vendor/ruvector/crates/ruvector-attention-node/LICENSE
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 rUv
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
213
vendor/ruvector/crates/ruvector-attention-node/README.md
vendored
Normal file
213
vendor/ruvector/crates/ruvector-attention-node/README.md
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
# @ruvector/attention
|
||||
|
||||
High-performance attention mechanisms for Node.js, powered by Rust.
|
||||
|
||||
## Features
|
||||
|
||||
- **Scaled Dot-Product Attention**: Classic attention mechanism with optional scaling
|
||||
- **Multi-Head Attention**: Parallel attention heads for richer representations
|
||||
- **Flash Attention**: Memory-efficient attention with block-wise computation
|
||||
- **Linear Attention**: O(N) complexity attention using kernel approximations
|
||||
- **Hyperbolic Attention**: Attention in hyperbolic space for hierarchical data
|
||||
- **Mixture-of-Experts (MoE) Attention**: Dynamic expert routing for specialized attention
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install @ruvector/attention
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Dot-Product Attention
|
||||
|
||||
```javascript
|
||||
const { DotProductAttention } = require('@ruvector/attention');
|
||||
|
||||
const attention = new DotProductAttention(512, 1.0);
|
||||
const query = new Float32Array([/* ... */]);
|
||||
const keys = [new Float32Array([/* ... */])];
|
||||
const values = [new Float32Array([/* ... */])];
|
||||
|
||||
const output = attention.compute(query, keys, values);
|
||||
```
|
||||
|
||||
### Multi-Head Attention
|
||||
|
||||
```javascript
|
||||
const { MultiHeadAttention } = require('@ruvector/attention');
|
||||
|
||||
const mha = new MultiHeadAttention(512, 8); // 512 dim, 8 heads
|
||||
const output = mha.compute(query, keys, values);
|
||||
|
||||
// Async version for large computations
|
||||
const outputAsync = await mha.computeAsync(query, keys, values);
|
||||
```
|
||||
|
||||
### Flash Attention
|
||||
|
||||
```javascript
|
||||
const { FlashAttention } = require('@ruvector/attention');
|
||||
|
||||
const flash = new FlashAttention(512, 64); // 512 dim, 64 block size
|
||||
const output = flash.compute(query, keys, values);
|
||||
```
|
||||
|
||||
### Hyperbolic Attention
|
||||
|
||||
```javascript
|
||||
const { HyperbolicAttention } = require('@ruvector/attention');
|
||||
|
||||
const hyperbolic = new HyperbolicAttention(512, -1.0); // negative curvature
|
||||
const output = hyperbolic.compute(query, keys, values);
|
||||
```
|
||||
|
||||
### Mixture-of-Experts Attention
|
||||
|
||||
```javascript
|
||||
const { MoEAttention } = require('@ruvector/attention');
|
||||
|
||||
const moe = new MoEAttention({
|
||||
dim: 512,
|
||||
numExperts: 8,
|
||||
topK: 2,
|
||||
expertCapacity: 1.25
|
||||
});
|
||||
|
||||
const output = moe.compute(query, keys, values);
|
||||
const expertUsage = moe.getExpertUsage();
|
||||
```
|
||||
|
||||
### Training
|
||||
|
||||
```javascript
|
||||
const { Trainer, AdamOptimizer } = require('@ruvector/attention');
|
||||
|
||||
// Configure training
|
||||
const trainer = new Trainer({
|
||||
learningRate: 0.001,
|
||||
batchSize: 32,
|
||||
numEpochs: 100,
|
||||
weightDecay: 0.01,
|
||||
gradientClip: 1.0,
|
||||
warmupSteps: 1000
|
||||
});
|
||||
|
||||
// Training step
|
||||
const loss = trainer.trainStep(inputs, targets);
|
||||
|
||||
// Get metrics
|
||||
const metrics = trainer.getMetrics();
|
||||
console.log(`Loss: ${metrics.loss}, LR: ${metrics.learningRate}`);
|
||||
|
||||
// Custom optimizer
|
||||
const optimizer = new AdamOptimizer(0.001, 0.9, 0.999, 1e-8);
|
||||
const updatedParams = optimizer.step(gradients);
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```javascript
|
||||
const { BatchProcessor, parallelAttentionCompute } = require('@ruvector/attention');
|
||||
|
||||
// Batch processor for efficient batching
|
||||
const processor = new BatchProcessor({
|
||||
batchSize: 32,
|
||||
numWorkers: 4,
|
||||
prefetch: true
|
||||
});
|
||||
|
||||
const results = await processor.processBatch(queries, keys, values);
|
||||
const throughput = processor.getThroughput();
|
||||
|
||||
// Parallel computation with automatic worker management
|
||||
const results = await parallelAttentionCompute(
|
||||
'multi-head',
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
4 // number of workers
|
||||
);
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Classes
|
||||
|
||||
#### `DotProductAttention`
|
||||
- `constructor(dim: number, scale?: number)`
|
||||
- `compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array`
|
||||
|
||||
#### `MultiHeadAttention`
|
||||
- `constructor(dim: number, numHeads: number)`
|
||||
- `compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array`
|
||||
- `computeAsync(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Promise<Float32Array>`
|
||||
|
||||
#### `FlashAttention`
|
||||
- `constructor(dim: number, blockSize: number)`
|
||||
- `compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array`
|
||||
|
||||
#### `LinearAttention`
|
||||
- `constructor(dim: number, numFeatures: number)`
|
||||
- `compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array`
|
||||
|
||||
#### `HyperbolicAttention`
|
||||
- `constructor(dim: number, curvature: number)`
|
||||
- `compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array`
|
||||
|
||||
#### `MoEAttention`
|
||||
- `constructor(config: MoEConfig)`
|
||||
- `compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array`
|
||||
- `getExpertUsage(): number[]`
|
||||
|
||||
#### `Trainer`
|
||||
- `constructor(config: TrainingConfig)`
|
||||
- `trainStep(inputs: Float32Array[], targets: Float32Array[]): number`
|
||||
- `trainStepAsync(inputs: Float32Array[], targets: Float32Array[]): Promise<number>`
|
||||
- `getMetrics(): TrainingMetrics`
|
||||
|
||||
#### `AdamOptimizer`
|
||||
- `constructor(learningRate: number, beta1?: number, beta2?: number, epsilon?: number)`
|
||||
- `step(gradients: Float32Array[]): Float32Array[]`
|
||||
- `getLearningRate(): number`
|
||||
- `setLearningRate(lr: number): void`
|
||||
|
||||
#### `BatchProcessor`
|
||||
- `constructor(config: BatchConfig)`
|
||||
- `processBatch(queries: Float32Array[], keys: Float32Array[][], values: Float32Array[][]): Promise<Float32Array[]>`
|
||||
- `getThroughput(): number`
|
||||
|
||||
### Functions
|
||||
|
||||
#### `parallelAttentionCompute`
|
||||
```typescript
|
||||
function parallelAttentionCompute(
|
||||
attentionType: string,
|
||||
queries: Float32Array[],
|
||||
keys: Float32Array[][],
|
||||
values: Float32Array[][],
|
||||
numWorkers?: number
|
||||
): Promise<Float32Array[]>
|
||||
```
|
||||
|
||||
#### `version`
|
||||
Returns the package version string.
|
||||
|
||||
## Performance
|
||||
|
||||
This package uses Rust under the hood for optimal performance:
|
||||
- Zero-copy data transfer where possible
|
||||
- SIMD optimizations for vector operations
|
||||
- Multi-threaded batch processing
|
||||
- Memory-efficient attention mechanisms
|
||||
|
||||
## Platform Support
|
||||
|
||||
Pre-built binaries are provided for:
|
||||
- macOS (x64, ARM64)
|
||||
- Linux (x64, ARM64, musl)
|
||||
- Windows (x64, ARM64)
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
5
vendor/ruvector/crates/ruvector-attention-node/build.rs
vendored
Normal file
5
vendor/ruvector/crates/ruvector-attention-node/build.rs
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
extern crate napi_build;
|
||||
|
||||
fn main() {
|
||||
napi_build::setup();
|
||||
}
|
||||
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-arm64/attention.darwin-arm64.node
vendored
Normal file
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-arm64/attention.darwin-arm64.node
vendored
Normal file
Binary file not shown.
19
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-arm64/package.json
vendored
Normal file
19
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-arm64/package.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"name": "@ruvector/attention-darwin-arm64",
|
||||
"version": "0.1.1",
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"main": "attention.darwin-arm64.node",
|
||||
"files": [
|
||||
"attention.darwin-arm64.node"
|
||||
],
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"repository": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-x64/attention.darwin-x64.node
vendored
Normal file
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-x64/attention.darwin-x64.node
vendored
Normal file
Binary file not shown.
19
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-x64/package.json
vendored
Normal file
19
vendor/ruvector/crates/ruvector-attention-node/npm/darwin-x64/package.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"name": "@ruvector/attention-darwin-x64",
|
||||
"version": "0.1.4",
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"main": "attention.darwin-x64.node",
|
||||
"files": [
|
||||
"attention.darwin-x64.node"
|
||||
],
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"repository": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/linux-arm64-gnu/attention.linux-arm64-gnu.node
vendored
Normal file
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/linux-arm64-gnu/attention.linux-arm64-gnu.node
vendored
Normal file
Binary file not shown.
22
vendor/ruvector/crates/ruvector-attention-node/npm/linux-arm64-gnu/package.json
vendored
Normal file
22
vendor/ruvector/crates/ruvector-attention-node/npm/linux-arm64-gnu/package.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"name": "@ruvector/attention-linux-arm64-gnu",
|
||||
"version": "0.1.1",
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"main": "attention.linux-arm64-gnu.node",
|
||||
"files": [
|
||||
"attention.linux-arm64-gnu.node"
|
||||
],
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"libc": [
|
||||
"glibc"
|
||||
],
|
||||
"repository": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/linux-x64-gnu/attention.linux-x64-gnu.node
vendored
Normal file
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/linux-x64-gnu/attention.linux-x64-gnu.node
vendored
Normal file
Binary file not shown.
22
vendor/ruvector/crates/ruvector-attention-node/npm/linux-x64-gnu/package.json
vendored
Normal file
22
vendor/ruvector/crates/ruvector-attention-node/npm/linux-x64-gnu/package.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"name": "@ruvector/attention-linux-x64-gnu",
|
||||
"version": "0.1.4",
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"main": "attention.linux-x64-gnu.node",
|
||||
"files": [
|
||||
"attention.linux-x64-gnu.node"
|
||||
],
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"libc": [
|
||||
"glibc"
|
||||
],
|
||||
"repository": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
22
vendor/ruvector/crates/ruvector-attention-node/npm/linux-x64-musl/package.json
vendored
Normal file
22
vendor/ruvector/crates/ruvector-attention-node/npm/linux-x64-musl/package.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"name": "@ruvector/attention-linux-x64-musl",
|
||||
"version": "0.1.0",
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"main": "attention.linux-x64-musl.node",
|
||||
"files": [
|
||||
"attention.linux-x64-musl.node"
|
||||
],
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"libc": [
|
||||
"musl"
|
||||
],
|
||||
"repository": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
19
vendor/ruvector/crates/ruvector-attention-node/npm/win32-arm64-msvc/package.json
vendored
Normal file
19
vendor/ruvector/crates/ruvector-attention-node/npm/win32-arm64-msvc/package.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"name": "@ruvector/attention-win32-arm64-msvc",
|
||||
"version": "0.1.0",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"main": "attention.win32-arm64-msvc.node",
|
||||
"files": [
|
||||
"attention.win32-arm64-msvc.node"
|
||||
],
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"repository": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/win32-x64-msvc/attention.win32-x64-msvc.node
vendored
Normal file
BIN
vendor/ruvector/crates/ruvector-attention-node/npm/win32-x64-msvc/attention.win32-x64-msvc.node
vendored
Normal file
Binary file not shown.
19
vendor/ruvector/crates/ruvector-attention-node/npm/win32-x64-msvc/package.json
vendored
Normal file
19
vendor/ruvector/crates/ruvector-attention-node/npm/win32-x64-msvc/package.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"name": "@ruvector/attention-win32-x64-msvc",
|
||||
"version": "0.1.4",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"main": "attention.win32-x64-msvc.node",
|
||||
"files": [
|
||||
"attention.win32-x64-msvc.node"
|
||||
],
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"repository": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
65
vendor/ruvector/crates/ruvector-attention-node/package.json
vendored
Normal file
65
vendor/ruvector/crates/ruvector-attention-node/package.json
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
{
|
||||
"name": "@ruvector/attention",
|
||||
"version": "0.1.4",
|
||||
"description": "High-performance attention mechanisms for Node.js",
|
||||
"main": "index.js",
|
||||
"types": "index.d.ts",
|
||||
"napi": {
|
||||
"binaryName": "attention",
|
||||
"targets": [
|
||||
"x86_64-pc-windows-msvc",
|
||||
"x86_64-apple-darwin",
|
||||
"x86_64-unknown-linux-gnu",
|
||||
"x86_64-unknown-linux-musl",
|
||||
"aarch64-apple-darwin",
|
||||
"aarch64-unknown-linux-gnu",
|
||||
"aarch64-unknown-linux-musl",
|
||||
"aarch64-pc-windows-msvc"
|
||||
]
|
||||
},
|
||||
"scripts": {
|
||||
"artifacts": "napi artifacts",
|
||||
"build": "napi build --platform --release",
|
||||
"build:debug": "napi build --platform",
|
||||
"prepublishOnly": "napi prepublish -t npm",
|
||||
"test": "node --test",
|
||||
"universal": "napi universal",
|
||||
"version": "napi version"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/ruvnet/ruvector",
|
||||
"directory": "crates/ruvector-attention-node"
|
||||
},
|
||||
"author": "rUv <ruv@ruv.io>",
|
||||
"license": "MIT OR Apache-2.0",
|
||||
"keywords": [
|
||||
"attention",
|
||||
"transformer",
|
||||
"machine-learning",
|
||||
"neural-network",
|
||||
"napi-rs",
|
||||
"rust",
|
||||
"multi-head-attention",
|
||||
"flash-attention",
|
||||
"hyperbolic",
|
||||
"mixture-of-experts"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"publishConfig": {
|
||||
"registry": "https://registry.npmjs.org/",
|
||||
"access": "public"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@ruvector/attention-win32-x64-msvc": "0.1.4",
|
||||
"@ruvector/attention-darwin-x64": "0.1.4",
|
||||
"@ruvector/attention-darwin-arm64": "0.1.4",
|
||||
"@ruvector/attention-linux-x64-gnu": "0.1.4",
|
||||
"@ruvector/attention-linux-arm64-gnu": "0.1.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@napi-rs/cli": "^2.18.0"
|
||||
}
|
||||
}
|
||||
536
vendor/ruvector/crates/ruvector-attention-node/src/async_ops.rs
vendored
Normal file
536
vendor/ruvector/crates/ruvector-attention-node/src/async_ops.rs
vendored
Normal file
@@ -0,0 +1,536 @@
|
||||
//! NAPI-RS bindings for async and batch operations
|
||||
//!
|
||||
//! Provides Node.js bindings for:
|
||||
//! - Async attention computation with tokio
|
||||
//! - Batch processing utilities
|
||||
//! - Parallel attention computation
|
||||
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
use ruvector_attention::{
|
||||
attention::ScaledDotProductAttention,
|
||||
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
|
||||
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
|
||||
traits::Attention,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
// ============================================================================
|
||||
// Batch Processing Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Batch processing configuration
|
||||
#[napi(object)]
|
||||
pub struct BatchConfig {
|
||||
pub batch_size: u32,
|
||||
pub num_workers: Option<u32>,
|
||||
pub prefetch: Option<bool>,
|
||||
}
|
||||
|
||||
/// Batch processing result
|
||||
#[napi(object)]
|
||||
pub struct BatchResult {
|
||||
pub outputs: Vec<Float32Array>,
|
||||
pub elapsed_ms: f64,
|
||||
pub throughput: f64,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Async Attention Operations
|
||||
// ============================================================================
|
||||
|
||||
/// Async scaled dot-product attention computation
|
||||
#[napi]
|
||||
pub async fn compute_attention_async(
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
dim: u32,
|
||||
) -> Result<Float32Array> {
|
||||
let query_vec = query.to_vec();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
let attention = ScaledDotProductAttention::new(dim as usize);
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
attention.compute(&query_vec, &keys_refs, &values_refs)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Async flash attention computation
|
||||
#[napi]
|
||||
pub async fn compute_flash_attention_async(
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
dim: u32,
|
||||
block_size: u32,
|
||||
) -> Result<Float32Array> {
|
||||
let query_vec = query.to_vec();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
let attention = FlashAttention::new(dim as usize, block_size as usize);
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
attention.compute(&query_vec, &keys_refs, &values_refs)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Async hyperbolic attention computation
|
||||
#[napi]
|
||||
pub async fn compute_hyperbolic_attention_async(
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
dim: u32,
|
||||
curvature: f64,
|
||||
) -> Result<Float32Array> {
|
||||
let query_vec = query.to_vec();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim: dim as usize,
|
||||
curvature: curvature as f32,
|
||||
..Default::default()
|
||||
};
|
||||
let attention = HyperbolicAttention::new(config);
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
attention.compute(&query_vec, &keys_refs, &values_refs)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Batch Processing
|
||||
// ============================================================================
|
||||
|
||||
/// Process a batch of attention computations
|
||||
#[napi]
|
||||
pub async fn batch_attention_compute(
|
||||
queries: Vec<Float32Array>,
|
||||
keys: Vec<Vec<Float32Array>>,
|
||||
values: Vec<Vec<Float32Array>>,
|
||||
dim: u32,
|
||||
) -> Result<BatchResult> {
|
||||
let start = std::time::Instant::now();
|
||||
let batch_size = queries.len();
|
||||
|
||||
// Convert to owned vectors for thread safety
|
||||
let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
|
||||
let keys_vec: Vec<Vec<Vec<f32>>> = keys
|
||||
.into_iter()
|
||||
.map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
|
||||
.collect();
|
||||
let values_vec: Vec<Vec<Vec<f32>>> = values
|
||||
.into_iter()
|
||||
.map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
|
||||
.collect();
|
||||
|
||||
let dim_usize = dim as usize;
|
||||
|
||||
let results = tokio::task::spawn_blocking(move || {
|
||||
let attention = ScaledDotProductAttention::new(dim_usize);
|
||||
let mut outputs = Vec::with_capacity(batch_size);
|
||||
|
||||
for i in 0..batch_size {
|
||||
let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
match attention.compute(&queries_vec[i], &keys_refs, &values_refs) {
|
||||
Ok(output) => outputs.push(output),
|
||||
Err(e) => return Err(e.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(outputs)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?
|
||||
.map_err(|e| Error::from_reason(e))?;
|
||||
|
||||
let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||
let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
|
||||
|
||||
Ok(BatchResult {
|
||||
outputs: results.into_iter().map(Float32Array::new).collect(),
|
||||
elapsed_ms,
|
||||
throughput,
|
||||
})
|
||||
}
|
||||
|
||||
/// Process a batch with flash attention
|
||||
#[napi]
|
||||
pub async fn batch_flash_attention_compute(
|
||||
queries: Vec<Float32Array>,
|
||||
keys: Vec<Vec<Float32Array>>,
|
||||
values: Vec<Vec<Float32Array>>,
|
||||
dim: u32,
|
||||
block_size: u32,
|
||||
) -> Result<BatchResult> {
|
||||
let start = std::time::Instant::now();
|
||||
let batch_size = queries.len();
|
||||
|
||||
let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
|
||||
let keys_vec: Vec<Vec<Vec<f32>>> = keys
|
||||
.into_iter()
|
||||
.map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
|
||||
.collect();
|
||||
let values_vec: Vec<Vec<Vec<f32>>> = values
|
||||
.into_iter()
|
||||
.map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
|
||||
.collect();
|
||||
|
||||
let dim_usize = dim as usize;
|
||||
let block_usize = block_size as usize;
|
||||
|
||||
let results = tokio::task::spawn_blocking(move || {
|
||||
let attention = FlashAttention::new(dim_usize, block_usize);
|
||||
let mut outputs = Vec::with_capacity(batch_size);
|
||||
|
||||
for i in 0..batch_size {
|
||||
let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
match attention.compute(&queries_vec[i], &keys_refs, &values_refs) {
|
||||
Ok(output) => outputs.push(output),
|
||||
Err(e) => return Err(e.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(outputs)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?
|
||||
.map_err(|e| Error::from_reason(e))?;
|
||||
|
||||
let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||
let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
|
||||
|
||||
Ok(BatchResult {
|
||||
outputs: results.into_iter().map(Float32Array::new).collect(),
|
||||
elapsed_ms,
|
||||
throughput,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Parallel Attention Computation
|
||||
// ============================================================================
|
||||
|
||||
/// Attention type for parallel computation
|
||||
#[napi(string_enum)]
|
||||
pub enum AttentionType {
|
||||
ScaledDotProduct,
|
||||
Flash,
|
||||
Linear,
|
||||
LocalGlobal,
|
||||
Hyperbolic,
|
||||
}
|
||||
|
||||
/// Configuration for parallel attention
|
||||
#[napi(object)]
|
||||
pub struct ParallelConfig {
|
||||
pub attention_type: AttentionType,
|
||||
pub dim: u32,
|
||||
pub block_size: Option<u32>,
|
||||
pub num_features: Option<u32>,
|
||||
pub local_window: Option<u32>,
|
||||
pub global_tokens: Option<u32>,
|
||||
pub curvature: Option<f64>,
|
||||
}
|
||||
|
||||
/// Parallel attention computation across multiple queries
|
||||
#[napi]
|
||||
pub async fn parallel_attention_compute(
|
||||
config: ParallelConfig,
|
||||
queries: Vec<Float32Array>,
|
||||
keys: Vec<Vec<Float32Array>>,
|
||||
values: Vec<Vec<Float32Array>>,
|
||||
) -> Result<BatchResult> {
|
||||
let start = std::time::Instant::now();
|
||||
let batch_size = queries.len();
|
||||
|
||||
let queries_vec: Vec<Vec<f32>> = queries.into_iter().map(|q| q.to_vec()).collect();
|
||||
let keys_vec: Vec<Vec<Vec<f32>>> = keys
|
||||
.into_iter()
|
||||
.map(|k| k.into_iter().map(|arr| arr.to_vec()).collect())
|
||||
.collect();
|
||||
let values_vec: Vec<Vec<Vec<f32>>> = values
|
||||
.into_iter()
|
||||
.map(|v| v.into_iter().map(|arr| arr.to_vec()).collect())
|
||||
.collect();
|
||||
|
||||
let dim = config.dim as usize;
|
||||
let attention_type = config.attention_type;
|
||||
let block_size = config.block_size.unwrap_or(64) as usize;
|
||||
let num_features = config.num_features.unwrap_or(64) as usize;
|
||||
let local_window = config.local_window.unwrap_or(128) as usize;
|
||||
let global_tokens = config.global_tokens.unwrap_or(8) as usize;
|
||||
let curvature = config.curvature.unwrap_or(1.0) as f32;
|
||||
|
||||
let results = tokio::task::spawn_blocking(move || {
|
||||
let mut outputs = Vec::with_capacity(batch_size);
|
||||
|
||||
for i in 0..batch_size {
|
||||
let keys_refs: Vec<&[f32]> = keys_vec[i].iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec[i].iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = match attention_type {
|
||||
AttentionType::ScaledDotProduct => {
|
||||
let attention = ScaledDotProductAttention::new(dim);
|
||||
attention.compute(&queries_vec[i], &keys_refs, &values_refs)
|
||||
}
|
||||
AttentionType::Flash => {
|
||||
let attention = FlashAttention::new(dim, block_size);
|
||||
attention.compute(&queries_vec[i], &keys_refs, &values_refs)
|
||||
}
|
||||
AttentionType::Linear => {
|
||||
let attention = LinearAttention::new(dim, num_features);
|
||||
attention.compute(&queries_vec[i], &keys_refs, &values_refs)
|
||||
}
|
||||
AttentionType::LocalGlobal => {
|
||||
let attention = LocalGlobalAttention::new(dim, local_window, global_tokens);
|
||||
attention.compute(&queries_vec[i], &keys_refs, &values_refs)
|
||||
}
|
||||
AttentionType::Hyperbolic => {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim,
|
||||
curvature,
|
||||
..Default::default()
|
||||
};
|
||||
let attention = HyperbolicAttention::new(config);
|
||||
attention.compute(&queries_vec[i], &keys_refs, &values_refs)
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(output) => outputs.push(output),
|
||||
Err(e) => return Err(e.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(outputs)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?
|
||||
.map_err(|e| Error::from_reason(e))?;
|
||||
|
||||
let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||
let throughput = batch_size as f64 / start.elapsed().as_secs_f64();
|
||||
|
||||
Ok(BatchResult {
|
||||
outputs: results.into_iter().map(Float32Array::new).collect(),
|
||||
elapsed_ms,
|
||||
throughput,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Processing
|
||||
// ============================================================================
|
||||
|
||||
/// Stream processor for handling attention in chunks
|
||||
#[napi]
|
||||
pub struct StreamProcessor {
|
||||
dim: usize,
|
||||
buffer: Vec<Vec<f32>>,
|
||||
max_buffer_size: usize,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl StreamProcessor {
|
||||
/// Create a new stream processor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `max_buffer_size` - Maximum number of items to buffer
|
||||
#[napi(constructor)]
|
||||
pub fn new(dim: u32, max_buffer_size: u32) -> Self {
|
||||
Self {
|
||||
dim: dim as usize,
|
||||
buffer: Vec::new(),
|
||||
max_buffer_size: max_buffer_size as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a vector to the buffer
|
||||
#[napi]
|
||||
pub fn push(&mut self, vector: Float32Array) -> bool {
|
||||
if self.buffer.len() >= self.max_buffer_size {
|
||||
return false;
|
||||
}
|
||||
self.buffer.push(vector.to_vec());
|
||||
true
|
||||
}
|
||||
|
||||
/// Process buffered vectors with attention against a query
|
||||
#[napi]
|
||||
pub fn process(&self, query: Float32Array) -> Result<Float32Array> {
|
||||
if self.buffer.is_empty() {
|
||||
return Err(Error::from_reason("Buffer is empty"));
|
||||
}
|
||||
|
||||
let attention = ScaledDotProductAttention::new(self.dim);
|
||||
let query_slice = query.as_ref();
|
||||
let keys_refs: Vec<&[f32]> = self.buffer.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = self.buffer.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = attention
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Clear the buffer
|
||||
#[napi]
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
/// Get current buffer size
|
||||
#[napi(getter)]
|
||||
pub fn size(&self) -> u32 {
|
||||
self.buffer.len() as u32
|
||||
}
|
||||
|
||||
/// Check if buffer is full
|
||||
#[napi(getter)]
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.buffer.len() >= self.max_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Benchmark Utilities
|
||||
// ============================================================================
|
||||
|
||||
/// Benchmark result
///
/// Timing statistics returned by `benchmark_attention`.
#[napi(object)]
pub struct BenchmarkResult {
    // Human-readable name of the benchmarked attention variant.
    pub name: String,
    // Number of compute iterations that were timed.
    pub iterations: u32,
    // Total wall-clock time across all iterations, in milliseconds.
    pub total_ms: f64,
    // Mean time per iteration, in milliseconds.
    pub avg_ms: f64,
    // Iterations per second, derived from `avg_ms`.
    pub ops_per_sec: f64,
    // Fastest single iteration, in milliseconds.
    pub min_ms: f64,
    // Slowest single iteration, in milliseconds.
    pub max_ms: f64,
}
|
||||
|
||||
/// Run attention benchmark
|
||||
#[napi]
|
||||
pub async fn benchmark_attention(
|
||||
attention_type: AttentionType,
|
||||
dim: u32,
|
||||
seq_length: u32,
|
||||
iterations: u32,
|
||||
) -> Result<BenchmarkResult> {
|
||||
let dim_usize = dim as usize;
|
||||
let seq_usize = seq_length as usize;
|
||||
let iter_usize = iterations as usize;
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
// Generate test data
|
||||
let query: Vec<f32> = (0..dim_usize).map(|i| (i as f32 * 0.01).sin()).collect();
|
||||
let keys: Vec<Vec<f32>> = (0..seq_usize)
|
||||
.map(|j| {
|
||||
(0..dim_usize)
|
||||
.map(|i| ((i + j) as f32 * 0.01).cos())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let values: Vec<Vec<f32>> = keys.clone();
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let name = match attention_type {
|
||||
AttentionType::ScaledDotProduct => "ScaledDotProduct",
|
||||
AttentionType::Flash => "Flash",
|
||||
AttentionType::Linear => "Linear",
|
||||
AttentionType::LocalGlobal => "LocalGlobal",
|
||||
AttentionType::Hyperbolic => "Hyperbolic",
|
||||
}
|
||||
.to_string();
|
||||
|
||||
let mut times: Vec<f64> = Vec::with_capacity(iter_usize);
|
||||
|
||||
for _ in 0..iter_usize {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
match attention_type {
|
||||
AttentionType::ScaledDotProduct => {
|
||||
let attention = ScaledDotProductAttention::new(dim_usize);
|
||||
let _ = attention.compute(&query, &keys_refs, &values_refs);
|
||||
}
|
||||
AttentionType::Flash => {
|
||||
let attention = FlashAttention::new(dim_usize, 64);
|
||||
let _ = attention.compute(&query, &keys_refs, &values_refs);
|
||||
}
|
||||
AttentionType::Linear => {
|
||||
let attention = LinearAttention::new(dim_usize, 64);
|
||||
let _ = attention.compute(&query, &keys_refs, &values_refs);
|
||||
}
|
||||
AttentionType::LocalGlobal => {
|
||||
let attention = LocalGlobalAttention::new(dim_usize, 128, 8);
|
||||
let _ = attention.compute(&query, &keys_refs, &values_refs);
|
||||
}
|
||||
AttentionType::Hyperbolic => {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim: dim_usize,
|
||||
curvature: 1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let attention = HyperbolicAttention::new(config);
|
||||
let _ = attention.compute(&query, &keys_refs, &values_refs);
|
||||
}
|
||||
}
|
||||
|
||||
times.push(start.elapsed().as_secs_f64() * 1000.0);
|
||||
}
|
||||
|
||||
let total_ms: f64 = times.iter().sum();
|
||||
let avg_ms = total_ms / iter_usize as f64;
|
||||
let min_ms = times.iter().copied().fold(f64::INFINITY, f64::min);
|
||||
let max_ms = times.iter().copied().fold(f64::NEG_INFINITY, f64::max);
|
||||
let ops_per_sec = 1000.0 / avg_ms;
|
||||
|
||||
BenchmarkResult {
|
||||
name,
|
||||
iterations: iterations,
|
||||
total_ms,
|
||||
avg_ms,
|
||||
ops_per_sec,
|
||||
min_ms,
|
||||
max_ms,
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
614
vendor/ruvector/crates/ruvector-attention-node/src/attention.rs
vendored
Normal file
614
vendor/ruvector/crates/ruvector-attention-node/src/attention.rs
vendored
Normal file
@@ -0,0 +1,614 @@
|
||||
//! NAPI-RS bindings for attention mechanisms
|
||||
//!
|
||||
//! Provides Node.js bindings for all attention variants:
|
||||
//! - Scaled dot-product attention
|
||||
//! - Multi-head attention
|
||||
//! - Hyperbolic attention
|
||||
//! - Flash attention
|
||||
//! - Linear attention
|
||||
//! - Local-global attention
|
||||
//! - Mixture of Experts attention
|
||||
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
use ruvector_attention::{
|
||||
attention::{MultiHeadAttention as RustMultiHead, ScaledDotProductAttention},
|
||||
hyperbolic::{HyperbolicAttention as RustHyperbolic, HyperbolicAttentionConfig},
|
||||
moe::{MoEAttention as RustMoE, MoEConfig as RustMoEConfig},
|
||||
sparse::{
|
||||
FlashAttention as RustFlash, LinearAttention as RustLinear,
|
||||
LocalGlobalAttention as RustLocalGlobal,
|
||||
},
|
||||
traits::Attention,
|
||||
};
|
||||
|
||||
/// Attention configuration object
///
/// Plain data object passed across the N-API boundary. Optional fields fall
/// back to implementation defaults when omitted on the JS side.
#[napi(object)]
pub struct AttentionConfig {
    // Embedding dimension of queries/keys/values.
    pub dim: u32,
    // Number of attention heads (multi-head variants only).
    pub num_heads: Option<u32>,
    // Dropout probability — NOTE(review): not visibly consumed by the
    // bindings in this file; confirm against callers.
    pub dropout: Option<f64>,
    // Softmax scale override — presumably defaults to 1/sqrt(dim); confirm
    // against the core crate.
    pub scale: Option<f64>,
    // Whether to apply a causal (autoregressive) mask.
    pub causal: Option<bool>,
}
|
||||
|
||||
/// Scaled dot-product attention
|
||||
#[napi]
|
||||
pub struct DotProductAttention {
|
||||
inner: ScaledDotProductAttention,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl DotProductAttention {
|
||||
/// Create a new scaled dot-product attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
#[napi(constructor)]
|
||||
pub fn new(dim: u32) -> Result<Self> {
|
||||
Ok(Self {
|
||||
inner: ScaledDotProductAttention::new(dim as usize),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute attention output
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Compute attention with mask
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
/// * `mask` - Boolean mask array (true = attend, false = mask)
|
||||
#[napi]
|
||||
pub fn compute_with_mask(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
mask: Vec<bool>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute_with_mask(query_slice, &keys_refs, &values_refs, Some(mask.as_slice()))
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.inner.dim() as u32
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-head attention mechanism
|
||||
#[napi]
|
||||
pub struct MultiHeadAttention {
|
||||
inner: RustMultiHead,
|
||||
dim_value: usize,
|
||||
num_heads_value: usize,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl MultiHeadAttention {
|
||||
/// Create a new multi-head attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension (must be divisible by num_heads)
|
||||
/// * `num_heads` - Number of attention heads
|
||||
#[napi(constructor)]
|
||||
pub fn new(dim: u32, num_heads: u32) -> Result<Self> {
|
||||
let d = dim as usize;
|
||||
let h = num_heads as usize;
|
||||
|
||||
if d % h != 0 {
|
||||
return Err(Error::from_reason(format!(
|
||||
"Dimension {} must be divisible by number of heads {}",
|
||||
d, h
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
inner: RustMultiHead::new(d, h),
|
||||
dim_value: d,
|
||||
num_heads_value: h,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute multi-head attention
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the number of heads
|
||||
#[napi(getter)]
|
||||
pub fn num_heads(&self) -> u32 {
|
||||
self.num_heads_value as u32
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.dim_value as u32
|
||||
}
|
||||
|
||||
/// Get the head dimension
|
||||
#[napi(getter)]
|
||||
pub fn head_dim(&self) -> u32 {
|
||||
(self.dim_value / self.num_heads_value) as u32
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperbolic attention in Poincaré ball model
|
||||
#[napi]
|
||||
pub struct HyperbolicAttention {
|
||||
inner: RustHyperbolic,
|
||||
curvature_value: f32,
|
||||
dim_value: usize,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl HyperbolicAttention {
|
||||
/// Create a new hyperbolic attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `curvature` - Hyperbolic curvature (typically 1.0)
|
||||
#[napi(constructor)]
|
||||
pub fn new(dim: u32, curvature: f64) -> Self {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim: dim as usize,
|
||||
curvature: curvature as f32,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
inner: RustHyperbolic::new(config),
|
||||
curvature_value: curvature as f32,
|
||||
dim_value: dim as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with full configuration
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `curvature` - Hyperbolic curvature
|
||||
/// * `adaptive_curvature` - Whether to use adaptive curvature
|
||||
/// * `temperature` - Temperature for softmax
|
||||
#[napi(factory)]
|
||||
pub fn with_config(
|
||||
dim: u32,
|
||||
curvature: f64,
|
||||
adaptive_curvature: bool,
|
||||
temperature: f64,
|
||||
) -> Self {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim: dim as usize,
|
||||
curvature: curvature as f32,
|
||||
adaptive_curvature,
|
||||
temperature: temperature as f32,
|
||||
frechet_max_iter: 100,
|
||||
frechet_tol: 1e-6,
|
||||
};
|
||||
Self {
|
||||
inner: RustHyperbolic::new(config),
|
||||
curvature_value: curvature as f32,
|
||||
dim_value: dim as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hyperbolic attention
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the curvature
|
||||
#[napi(getter)]
|
||||
pub fn curvature(&self) -> f64 {
|
||||
self.curvature_value as f64
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.dim_value as u32
|
||||
}
|
||||
}
|
||||
|
||||
/// Flash attention with tiled computation
|
||||
#[napi]
|
||||
pub struct FlashAttention {
|
||||
inner: RustFlash,
|
||||
dim_value: usize,
|
||||
block_size_value: usize,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl FlashAttention {
|
||||
/// Create a new flash attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `block_size` - Block size for tiled computation
|
||||
#[napi(constructor)]
|
||||
pub fn new(dim: u32, block_size: u32) -> Self {
|
||||
Self {
|
||||
inner: RustFlash::new(dim as usize, block_size as usize),
|
||||
dim_value: dim as usize,
|
||||
block_size_value: block_size as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute flash attention
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.dim_value as u32
|
||||
}
|
||||
|
||||
/// Get the block size
|
||||
#[napi(getter)]
|
||||
pub fn block_size(&self) -> u32 {
|
||||
self.block_size_value as u32
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear attention (Performer-style) with O(n) complexity
|
||||
#[napi]
|
||||
pub struct LinearAttention {
|
||||
inner: RustLinear,
|
||||
dim_value: usize,
|
||||
num_features_value: usize,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl LinearAttention {
|
||||
/// Create a new linear attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_features` - Number of random features
|
||||
#[napi(constructor)]
|
||||
pub fn new(dim: u32, num_features: u32) -> Self {
|
||||
Self {
|
||||
inner: RustLinear::new(dim as usize, num_features as usize),
|
||||
dim_value: dim as usize,
|
||||
num_features_value: num_features as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute linear attention
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.dim_value as u32
|
||||
}
|
||||
|
||||
/// Get the number of random features
|
||||
#[napi(getter)]
|
||||
pub fn num_features(&self) -> u32 {
|
||||
self.num_features_value as u32
|
||||
}
|
||||
}
|
||||
|
||||
/// Local-global attention (Longformer-style)
|
||||
#[napi]
|
||||
pub struct LocalGlobalAttention {
|
||||
inner: RustLocalGlobal,
|
||||
dim_value: usize,
|
||||
local_window_value: usize,
|
||||
global_tokens_value: usize,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl LocalGlobalAttention {
|
||||
/// Create a new local-global attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `local_window` - Size of local attention window
|
||||
/// * `global_tokens` - Number of global attention tokens
|
||||
#[napi(constructor)]
|
||||
pub fn new(dim: u32, local_window: u32, global_tokens: u32) -> Self {
|
||||
Self {
|
||||
inner: RustLocalGlobal::new(
|
||||
dim as usize,
|
||||
local_window as usize,
|
||||
global_tokens as usize,
|
||||
),
|
||||
dim_value: dim as usize,
|
||||
local_window_value: local_window as usize,
|
||||
global_tokens_value: global_tokens as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute local-global attention
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.dim_value as u32
|
||||
}
|
||||
|
||||
/// Get the local window size
|
||||
#[napi(getter)]
|
||||
pub fn local_window(&self) -> u32 {
|
||||
self.local_window_value as u32
|
||||
}
|
||||
|
||||
/// Get the number of global tokens
|
||||
#[napi(getter)]
|
||||
pub fn global_tokens(&self) -> u32 {
|
||||
self.global_tokens_value as u32
|
||||
}
|
||||
}
|
||||
|
||||
/// MoE attention configuration
#[napi(object)]
pub struct MoEConfig {
    // Embedding dimension.
    pub dim: u32,
    // Number of expert networks.
    pub num_experts: u32,
    // Number of experts each input is routed to.
    pub top_k: u32,
    // Per-expert capacity factor; `MoEAttention::new` substitutes 1.25 when
    // omitted.
    pub expert_capacity: Option<f64>,
}
|
||||
|
||||
/// Mixture of Experts attention
|
||||
#[napi]
|
||||
pub struct MoEAttention {
|
||||
inner: RustMoE,
|
||||
config: MoEConfig,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl MoEAttention {
|
||||
/// Create a new MoE attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - MoE configuration object
|
||||
#[napi(constructor)]
|
||||
pub fn new(config: MoEConfig) -> Self {
|
||||
let rust_config = RustMoEConfig::builder()
|
||||
.dim(config.dim as usize)
|
||||
.num_experts(config.num_experts as usize)
|
||||
.top_k(config.top_k as usize)
|
||||
.expert_capacity(config.expert_capacity.unwrap_or(1.25) as f32)
|
||||
.build();
|
||||
|
||||
Self {
|
||||
inner: RustMoE::new(rust_config),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with simple parameters
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_experts` - Number of expert networks
|
||||
/// * `top_k` - Number of experts to route to
|
||||
#[napi(factory)]
|
||||
pub fn simple(dim: u32, num_experts: u32, top_k: u32) -> Self {
|
||||
let config = MoEConfig {
|
||||
dim,
|
||||
num_experts,
|
||||
top_k,
|
||||
expert_capacity: Some(1.25),
|
||||
};
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
/// Compute MoE attention
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.config.dim
|
||||
}
|
||||
|
||||
/// Get the number of experts
|
||||
#[napi(getter)]
|
||||
pub fn num_experts(&self) -> u32 {
|
||||
self.config.num_experts
|
||||
}
|
||||
|
||||
/// Get the top-k value
|
||||
#[napi(getter)]
|
||||
pub fn top_k(&self) -> u32 {
|
||||
self.config.top_k
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
|
||||
/// Project a vector into the Poincaré ball
|
||||
#[napi]
|
||||
pub fn project_to_poincare_ball(vector: Float32Array, curvature: f64) -> Float32Array {
|
||||
let v = vector.to_vec();
|
||||
let projected = ruvector_attention::hyperbolic::project_to_ball(&v, curvature as f32, 1e-5);
|
||||
Float32Array::new(projected)
|
||||
}
|
||||
|
||||
/// Compute hyperbolic (Poincaré) distance between two points
|
||||
#[napi]
|
||||
pub fn poincare_distance(a: Float32Array, b: Float32Array, curvature: f64) -> f64 {
|
||||
let a_slice = a.as_ref();
|
||||
let b_slice = b.as_ref();
|
||||
ruvector_attention::hyperbolic::poincare_distance(a_slice, b_slice, curvature as f32) as f64
|
||||
}
|
||||
|
||||
/// Möbius addition in hyperbolic space
|
||||
#[napi]
|
||||
pub fn mobius_addition(a: Float32Array, b: Float32Array, curvature: f64) -> Float32Array {
|
||||
let a_slice = a.as_ref();
|
||||
let b_slice = b.as_ref();
|
||||
let result = ruvector_attention::hyperbolic::mobius_add(a_slice, b_slice, curvature as f32);
|
||||
Float32Array::new(result)
|
||||
}
|
||||
|
||||
/// Exponential map from tangent space to hyperbolic space
|
||||
#[napi]
|
||||
pub fn exp_map(base: Float32Array, tangent: Float32Array, curvature: f64) -> Float32Array {
|
||||
let base_slice = base.as_ref();
|
||||
let tangent_slice = tangent.as_ref();
|
||||
let result =
|
||||
ruvector_attention::hyperbolic::exp_map(base_slice, tangent_slice, curvature as f32);
|
||||
Float32Array::new(result)
|
||||
}
|
||||
|
||||
/// Logarithmic map from hyperbolic space to tangent space
|
||||
#[napi]
|
||||
pub fn log_map(base: Float32Array, point: Float32Array, curvature: f64) -> Float32Array {
|
||||
let base_slice = base.as_ref();
|
||||
let point_slice = point.as_ref();
|
||||
let result = ruvector_attention::hyperbolic::log_map(base_slice, point_slice, curvature as f32);
|
||||
Float32Array::new(result)
|
||||
}
|
||||
430
vendor/ruvector/crates/ruvector-attention-node/src/graph.rs
vendored
Normal file
430
vendor/ruvector/crates/ruvector-attention-node/src/graph.rs
vendored
Normal file
@@ -0,0 +1,430 @@
|
||||
//! NAPI-RS bindings for graph attention mechanisms
|
||||
//!
|
||||
//! Provides Node.js bindings for:
|
||||
//! - Edge-featured attention (GATv2-style)
|
||||
//! - Graph RoPE (Rotary Position Embeddings for graphs)
|
||||
//! - Dual-space attention (Euclidean + Hyperbolic)
|
||||
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
use ruvector_attention::graph::{
|
||||
DualSpaceAttention as RustDualSpace, DualSpaceConfig as RustDualConfig,
|
||||
EdgeFeaturedAttention as RustEdgeFeatured, EdgeFeaturedConfig as RustEdgeConfig,
|
||||
GraphRoPE as RustGraphRoPE, RoPEConfig as RustRoPEConfig,
|
||||
};
|
||||
use ruvector_attention::traits::Attention;
|
||||
|
||||
// ============================================================================
|
||||
// Edge-Featured Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for edge-featured attention
#[napi(object)]
pub struct EdgeFeaturedConfig {
    // Dimension of node feature vectors.
    pub node_dim: u32,
    // Dimension of edge feature vectors.
    pub edge_dim: u32,
    // Number of attention heads.
    pub num_heads: u32,
    // Concatenate head outputs when true (the default) — presumably averaged
    // otherwise; confirm against the core crate.
    pub concat_heads: Option<bool>,
    // Add self-loop edges before attention; defaults to true.
    pub add_self_loops: Option<bool>,
    // LeakyReLU negative slope for attention scores; defaults to 0.2.
    pub negative_slope: Option<f64>,
}
|
||||
|
||||
/// Edge-featured attention (GATv2-style)
|
||||
#[napi]
|
||||
pub struct EdgeFeaturedAttention {
|
||||
inner: RustEdgeFeatured,
|
||||
config: EdgeFeaturedConfig,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl EdgeFeaturedAttention {
|
||||
/// Create a new edge-featured attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - Edge-featured attention configuration
|
||||
#[napi(constructor)]
|
||||
pub fn new(config: EdgeFeaturedConfig) -> Self {
|
||||
let rust_config = RustEdgeConfig {
|
||||
node_dim: config.node_dim as usize,
|
||||
edge_dim: config.edge_dim as usize,
|
||||
num_heads: config.num_heads as usize,
|
||||
concat_heads: config.concat_heads.unwrap_or(true),
|
||||
add_self_loops: config.add_self_loops.unwrap_or(true),
|
||||
negative_slope: config.negative_slope.unwrap_or(0.2) as f32,
|
||||
dropout: 0.0,
|
||||
};
|
||||
Self {
|
||||
inner: RustEdgeFeatured::new(rust_config),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with simple parameters
|
||||
#[napi(factory)]
|
||||
pub fn simple(node_dim: u32, edge_dim: u32, num_heads: u32) -> Self {
|
||||
Self::new(EdgeFeaturedConfig {
|
||||
node_dim,
|
||||
edge_dim,
|
||||
num_heads,
|
||||
concat_heads: Some(true),
|
||||
add_self_loops: Some(true),
|
||||
negative_slope: Some(0.2),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute attention without edge features (standard attention)
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Compute attention with edge features
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
/// * `edge_features` - Array of edge feature vectors (same length as keys)
|
||||
#[napi]
|
||||
pub fn compute_with_edges(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
edge_features: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let edge_features_vec: Vec<Vec<f32>> =
|
||||
edge_features.into_iter().map(|e| e.to_vec()).collect();
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
let edges_refs: Vec<&[f32]> = edge_features_vec.iter().map(|e| e.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute_with_edges(query_slice, &keys_refs, &values_refs, &edges_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get the node dimension
|
||||
#[napi(getter)]
|
||||
pub fn node_dim(&self) -> u32 {
|
||||
self.config.node_dim
|
||||
}
|
||||
|
||||
/// Get the edge dimension
|
||||
#[napi(getter)]
|
||||
pub fn edge_dim(&self) -> u32 {
|
||||
self.config.edge_dim
|
||||
}
|
||||
|
||||
/// Get the number of heads
|
||||
#[napi(getter)]
|
||||
pub fn num_heads(&self) -> u32 {
|
||||
self.config.num_heads
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Graph RoPE Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for Graph RoPE attention
#[napi(object)]
pub struct RoPEConfig {
    // Embedding dimension the rotary pairs are formed over.
    pub dim: u32,
    // Maximum position index supported.
    pub max_position: u32,
    // Rotary frequency base; defaults to 10000.0 (the standard RoPE base).
    pub base: Option<f64>,
    // Position scaling factor; defaults to 1.0.
    pub scaling_factor: Option<f64>,
}
|
||||
|
||||
/// Graph RoPE attention (Rotary Position Embeddings for graphs)
#[napi]
pub struct GraphRoPEAttention {
    // Core Rust implementation.
    inner: RustGraphRoPE,
    // JS-facing config — presumably retained for getter access, matching the
    // other wrappers in this module; confirm against the full impl.
    config: RoPEConfig,
}
|
||||
|
||||
#[napi]
|
||||
impl GraphRoPEAttention {
|
||||
/// Create a new Graph RoPE attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - RoPE configuration
|
||||
#[napi(constructor)]
|
||||
pub fn new(config: RoPEConfig) -> Self {
|
||||
let rust_config = RustRoPEConfig {
|
||||
dim: config.dim as usize,
|
||||
max_position: config.max_position as usize,
|
||||
base: config.base.unwrap_or(10000.0) as f32,
|
||||
scaling_factor: config.scaling_factor.unwrap_or(1.0) as f32,
|
||||
};
|
||||
Self {
|
||||
inner: RustGraphRoPE::new(rust_config),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with simple parameters
|
||||
#[napi(factory)]
|
||||
pub fn simple(dim: u32, max_position: u32) -> Self {
|
||||
Self::new(RoPEConfig {
|
||||
dim,
|
||||
max_position,
|
||||
base: Some(10000.0),
|
||||
scaling_factor: Some(1.0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute attention without positional encoding
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Compute attention with graph positions
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
/// * `query_position` - Position of query node
|
||||
/// * `key_positions` - Positions of key nodes (e.g., hop distances)
|
||||
#[napi]
|
||||
pub fn compute_with_positions(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
query_position: u32,
|
||||
key_positions: Vec<u32>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
let positions_usize: Vec<usize> = key_positions.into_iter().map(|p| p as usize).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute_with_positions(
|
||||
query_slice,
|
||||
&keys_refs,
|
||||
&values_refs,
|
||||
query_position as usize,
|
||||
&positions_usize,
|
||||
)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Apply rotary embedding to a vector
|
||||
#[napi]
|
||||
pub fn apply_rotary(&self, vector: Float32Array, position: u32) -> Float32Array {
|
||||
let v = vector.as_ref();
|
||||
let result = self.inner.apply_rotary(v, position as usize);
|
||||
Float32Array::new(result)
|
||||
}
|
||||
|
||||
/// Convert graph distance to position bucket
|
||||
#[napi]
|
||||
pub fn distance_to_position(distance: u32, max_distance: u32) -> u32 {
|
||||
RustGraphRoPE::distance_to_position(distance as usize, max_distance as usize) as u32
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.config.dim
|
||||
}
|
||||
|
||||
/// Get the max position
|
||||
#[napi(getter)]
|
||||
pub fn max_position(&self) -> u32 {
|
||||
self.config.max_position
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Dual-Space Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for dual-space attention
#[napi(object)]
pub struct DualSpaceConfig {
    /// Embedding dimension.
    pub dim: u32,
    /// Curvature parameter of the hyperbolic space.
    pub curvature: f64,
    /// Mixing weight applied to the Euclidean score.
    pub euclidean_weight: f64,
    /// Mixing weight applied to the hyperbolic score.
    pub hyperbolic_weight: f64,
    /// Softmax temperature; defaults to 1.0 when omitted (see `DualSpaceAttention::new`).
    pub temperature: Option<f64>,
}
|
||||
|
||||
/// Dual-space attention (Euclidean + Hyperbolic)
#[napi]
pub struct DualSpaceAttention {
    // Underlying Rust implementation that does the actual computation.
    inner: RustDualSpace,
    // Original config, retained so the getter properties can expose it.
    config: DualSpaceConfig,
}
|
||||
|
||||
#[napi]
|
||||
impl DualSpaceAttention {
|
||||
/// Create a new dual-space attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - Dual-space configuration
|
||||
#[napi(constructor)]
|
||||
pub fn new(config: DualSpaceConfig) -> Self {
|
||||
let rust_config = RustDualConfig {
|
||||
dim: config.dim as usize,
|
||||
curvature: config.curvature as f32,
|
||||
euclidean_weight: config.euclidean_weight as f32,
|
||||
hyperbolic_weight: config.hyperbolic_weight as f32,
|
||||
learn_weights: false,
|
||||
temperature: config.temperature.unwrap_or(1.0) as f32,
|
||||
};
|
||||
Self {
|
||||
inner: RustDualSpace::new(rust_config),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with simple parameters (equal weights)
|
||||
#[napi(factory)]
|
||||
pub fn simple(dim: u32, curvature: f64) -> Self {
|
||||
Self::new(DualSpaceConfig {
|
||||
dim,
|
||||
curvature,
|
||||
euclidean_weight: 0.5,
|
||||
hyperbolic_weight: 0.5,
|
||||
temperature: Some(1.0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom weights
|
||||
#[napi(factory)]
|
||||
pub fn with_weights(
|
||||
dim: u32,
|
||||
curvature: f64,
|
||||
euclidean_weight: f64,
|
||||
hyperbolic_weight: f64,
|
||||
) -> Self {
|
||||
Self::new(DualSpaceConfig {
|
||||
dim,
|
||||
curvature,
|
||||
euclidean_weight,
|
||||
hyperbolic_weight,
|
||||
temperature: Some(1.0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute dual-space attention
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
values: Vec<Float32Array>,
|
||||
) -> Result<Float32Array> {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.compute(query_slice, &keys_refs, &values_refs)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
Ok(Float32Array::new(result))
|
||||
}
|
||||
|
||||
/// Get space contributions (Euclidean and Hyperbolic scores separately)
|
||||
#[napi]
|
||||
pub fn get_space_contributions(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
keys: Vec<Float32Array>,
|
||||
) -> SpaceContributions {
|
||||
let query_slice = query.as_ref();
|
||||
let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
|
||||
let (euc_scores, hyp_scores) = self.inner.get_space_contributions(query_slice, &keys_refs);
|
||||
|
||||
SpaceContributions {
|
||||
euclidean_scores: Float32Array::new(euc_scores),
|
||||
hyperbolic_scores: Float32Array::new(hyp_scores),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the dimension
|
||||
#[napi(getter)]
|
||||
pub fn dim(&self) -> u32 {
|
||||
self.config.dim
|
||||
}
|
||||
|
||||
/// Get the curvature
|
||||
#[napi(getter)]
|
||||
pub fn curvature(&self) -> f64 {
|
||||
self.config.curvature
|
||||
}
|
||||
|
||||
/// Get the Euclidean weight
|
||||
#[napi(getter)]
|
||||
pub fn euclidean_weight(&self) -> f64 {
|
||||
self.config.euclidean_weight
|
||||
}
|
||||
|
||||
/// Get the Hyperbolic weight
|
||||
#[napi(getter)]
|
||||
pub fn hyperbolic_weight(&self) -> f64 {
|
||||
self.config.hyperbolic_weight
|
||||
}
|
||||
}
|
||||
|
||||
/// Space contribution scores
#[napi(object)]
pub struct SpaceContributions {
    /// Attention scores computed in Euclidean space.
    pub euclidean_scores: Float32Array,
    /// Attention scores computed in hyperbolic space.
    pub hyperbolic_scores: Float32Array,
}
|
||||
84
vendor/ruvector/crates/ruvector-attention-node/src/lib.rs
vendored
Normal file
84
vendor/ruvector/crates/ruvector-attention-node/src/lib.rs
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
//! ruvector-attention-node
|
||||
//!
|
||||
//! Node.js bindings for ruvector-attention via NAPI-RS
|
||||
//!
|
||||
//! This crate provides comprehensive Node.js bindings for:
|
||||
//! - Attention mechanisms (dot-product, multi-head, hyperbolic, flash, linear, local-global, MoE)
|
||||
//! - Training utilities (loss functions, optimizers, schedulers)
|
||||
//! - Async/batch processing
|
||||
//! - Graph attention mechanisms
|
||||
//! - Benchmarking utilities
|
||||
|
||||
#![deny(clippy::all)]
|
||||
|
||||
use napi_derive::napi;
|
||||
|
||||
pub mod async_ops;
|
||||
pub mod attention;
|
||||
pub mod graph;
|
||||
pub mod training;
|
||||
|
||||
// Re-export main attention types
|
||||
pub use attention::{
|
||||
AttentionConfig, DotProductAttention, FlashAttention, HyperbolicAttention, LinearAttention,
|
||||
LocalGlobalAttention, MoEAttention, MoEConfig, MultiHeadAttention,
|
||||
};
|
||||
|
||||
// Re-export training types
|
||||
pub use training::{
|
||||
AdamOptimizer, AdamWOptimizer, CurriculumScheduler, CurriculumStageConfig, DecayType,
|
||||
HardNegativeMiner, InBatchMiner, InfoNCELoss, LearningRateScheduler, LocalContrastiveLoss,
|
||||
LossWithGradients, MiningStrategy, SGDOptimizer, SpectralRegularization, TemperatureAnnealing,
|
||||
};
|
||||
|
||||
// Re-export async/batch types
|
||||
pub use async_ops::{
|
||||
AttentionType, BatchConfig, BatchResult, BenchmarkResult, ParallelConfig, StreamProcessor,
|
||||
};
|
||||
|
||||
// Re-export graph attention types
|
||||
pub use graph::{
|
||||
DualSpaceAttention, DualSpaceConfig, EdgeFeaturedAttention, EdgeFeaturedConfig,
|
||||
GraphRoPEAttention, RoPEConfig,
|
||||
};
|
||||
|
||||
/// Get library version
|
||||
#[napi]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Get library info
|
||||
#[napi]
|
||||
pub fn info() -> LibraryInfo {
|
||||
LibraryInfo {
|
||||
name: "ruvector-attention-node".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
description: "Node.js bindings for ruvector-attention".to_string(),
|
||||
features: vec![
|
||||
"scaled-dot-product".to_string(),
|
||||
"multi-head".to_string(),
|
||||
"hyperbolic".to_string(),
|
||||
"flash".to_string(),
|
||||
"linear".to_string(),
|
||||
"local-global".to_string(),
|
||||
"moe".to_string(),
|
||||
"edge-featured".to_string(),
|
||||
"graph-rope".to_string(),
|
||||
"dual-space".to_string(),
|
||||
"training".to_string(),
|
||||
"async".to_string(),
|
||||
"batch".to_string(),
|
||||
"benchmark".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Library information
#[napi(object)]
pub struct LibraryInfo {
    /// Crate name.
    pub name: String,
    /// Crate version (CARGO_PKG_VERSION at build time).
    pub version: String,
    /// Short human-readable description.
    pub description: String,
    /// Feature tags exposed by these bindings (see `info()`).
    pub features: Vec<String>,
}
|
||||
851
vendor/ruvector/crates/ruvector-attention-node/src/training.rs
vendored
Normal file
851
vendor/ruvector/crates/ruvector-attention-node/src/training.rs
vendored
Normal file
@@ -0,0 +1,851 @@
|
||||
//! NAPI-RS bindings for training utilities
|
||||
//!
|
||||
//! Provides Node.js bindings for:
|
||||
//! - Loss functions (InfoNCE, LocalContrastive, SpectralRegularization)
|
||||
//! - Optimizers (SGD, Adam, AdamW)
|
||||
//! - Learning rate schedulers
|
||||
//! - Curriculum learning
|
||||
//! - Negative mining
|
||||
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
use ruvector_attention::training::{
|
||||
Adam as RustAdam, AdamW as RustAdamW, CurriculumScheduler as RustCurriculum,
|
||||
CurriculumStage as RustStage, DecayType as RustDecayType, HardNegativeMiner as RustHardMiner,
|
||||
InfoNCELoss as RustInfoNCE, LocalContrastiveLoss as RustLocalContrastive, Loss,
|
||||
MiningStrategy as RustMiningStrategy, NegativeMiner, Optimizer,
|
||||
SpectralRegularization as RustSpectralReg, TemperatureAnnealing as RustTempAnnealing,
|
||||
SGD as RustSGD,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Loss Functions
|
||||
// ============================================================================
|
||||
|
||||
/// InfoNCE contrastive loss for representation learning
#[napi]
pub struct InfoNCELoss {
    // Underlying Rust loss implementation.
    inner: RustInfoNCE,
    // Cached constructor temperature, exposed via the `temperature` getter.
    temperature_value: f32,
}
|
||||
|
||||
#[napi]
|
||||
impl InfoNCELoss {
|
||||
/// Create a new InfoNCE loss instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `temperature` - Temperature parameter for softmax (typically 0.07-0.1)
|
||||
#[napi(constructor)]
|
||||
pub fn new(temperature: f64) -> Self {
|
||||
Self {
|
||||
inner: RustInfoNCE::new(temperature as f32),
|
||||
temperature_value: temperature as f32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute InfoNCE loss
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor` - Anchor embedding
|
||||
/// * `positive` - Positive example embedding
|
||||
/// * `negatives` - Array of negative example embeddings
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
anchor: Float32Array,
|
||||
positive: Float32Array,
|
||||
negatives: Vec<Float32Array>,
|
||||
) -> f64 {
|
||||
let anchor_slice = anchor.as_ref();
|
||||
let positive_slice = positive.as_ref();
|
||||
let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
|
||||
let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(anchor_slice, positive_slice, &negatives_refs) as f64
|
||||
}
|
||||
|
||||
/// Compute InfoNCE loss with gradients
|
||||
///
|
||||
/// Returns an object with `loss` and `gradients` fields
|
||||
#[napi]
|
||||
pub fn compute_with_gradients(
|
||||
&self,
|
||||
anchor: Float32Array,
|
||||
positive: Float32Array,
|
||||
negatives: Vec<Float32Array>,
|
||||
) -> LossWithGradients {
|
||||
let anchor_slice = anchor.as_ref();
|
||||
let positive_slice = positive.as_ref();
|
||||
let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
|
||||
let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
|
||||
|
||||
let (loss, gradients) =
|
||||
self.inner
|
||||
.compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
|
||||
|
||||
LossWithGradients {
|
||||
loss: loss as f64,
|
||||
gradients: Float32Array::new(gradients),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the temperature
|
||||
#[napi(getter)]
|
||||
pub fn temperature(&self) -> f64 {
|
||||
self.temperature_value as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Loss computation result with gradients
#[napi(object)]
pub struct LossWithGradients {
    /// Scalar loss value.
    pub loss: f64,
    /// Gradient vector returned by the loss implementation.
    pub gradients: Float32Array,
}
|
||||
|
||||
/// Local contrastive loss for neighborhood preservation
#[napi]
pub struct LocalContrastiveLoss {
    // Underlying Rust loss implementation.
    inner: RustLocalContrastive,
    // Cached constructor margin, exposed via the `margin` getter.
    margin_value: f32,
}
|
||||
|
||||
#[napi]
|
||||
impl LocalContrastiveLoss {
|
||||
/// Create a new local contrastive loss instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `margin` - Margin for triplet loss
|
||||
#[napi(constructor)]
|
||||
pub fn new(margin: f64) -> Self {
|
||||
Self {
|
||||
inner: RustLocalContrastive::new(margin as f32),
|
||||
margin_value: margin as f32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute local contrastive loss
|
||||
#[napi]
|
||||
pub fn compute(
|
||||
&self,
|
||||
anchor: Float32Array,
|
||||
positive: Float32Array,
|
||||
negatives: Vec<Float32Array>,
|
||||
) -> f64 {
|
||||
let anchor_slice = anchor.as_ref();
|
||||
let positive_slice = positive.as_ref();
|
||||
let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
|
||||
let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(anchor_slice, positive_slice, &negatives_refs) as f64
|
||||
}
|
||||
|
||||
/// Compute with gradients
|
||||
#[napi]
|
||||
pub fn compute_with_gradients(
|
||||
&self,
|
||||
anchor: Float32Array,
|
||||
positive: Float32Array,
|
||||
negatives: Vec<Float32Array>,
|
||||
) -> LossWithGradients {
|
||||
let anchor_slice = anchor.as_ref();
|
||||
let positive_slice = positive.as_ref();
|
||||
let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
|
||||
let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
|
||||
|
||||
let (loss, gradients) =
|
||||
self.inner
|
||||
.compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
|
||||
|
||||
LossWithGradients {
|
||||
loss: loss as f64,
|
||||
gradients: Float32Array::new(gradients),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the margin
|
||||
#[napi(getter)]
|
||||
pub fn margin(&self) -> f64 {
|
||||
self.margin_value as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Spectral regularization for smooth representations
#[napi]
pub struct SpectralRegularization {
    // Underlying Rust regularizer.
    inner: RustSpectralReg,
    // Cached constructor weight, exposed via the `weight` getter.
    weight_value: f32,
}
|
||||
|
||||
#[napi]
|
||||
impl SpectralRegularization {
|
||||
/// Create a new spectral regularization instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `weight` - Regularization weight
|
||||
#[napi(constructor)]
|
||||
pub fn new(weight: f64) -> Self {
|
||||
Self {
|
||||
inner: RustSpectralReg::new(weight as f32),
|
||||
weight_value: weight as f32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute spectral regularization for a batch of embeddings
|
||||
#[napi]
|
||||
pub fn compute_batch(&self, embeddings: Vec<Float32Array>) -> f64 {
|
||||
let embeddings_vec: Vec<Vec<f32>> = embeddings.into_iter().map(|e| e.to_vec()).collect();
|
||||
let embeddings_refs: Vec<&[f32]> = embeddings_vec.iter().map(|e| e.as_slice()).collect();
|
||||
|
||||
self.inner.compute_batch(&embeddings_refs) as f64
|
||||
}
|
||||
|
||||
/// Get the weight
|
||||
#[napi(getter)]
|
||||
pub fn weight(&self) -> f64 {
|
||||
self.weight_value as f64
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Optimizers
|
||||
// ============================================================================
|
||||
|
||||
/// SGD optimizer with optional momentum and weight decay
#[napi]
pub struct SGDOptimizer {
    // Underlying Rust SGD implementation; all optimizer state lives here.
    inner: RustSGD,
}
|
||||
|
||||
#[napi]
|
||||
impl SGDOptimizer {
|
||||
/// Create a new SGD optimizer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_count` - Number of parameters
|
||||
/// * `learning_rate` - Learning rate
|
||||
#[napi(constructor)]
|
||||
pub fn new(param_count: u32, learning_rate: f64) -> Self {
|
||||
Self {
|
||||
inner: RustSGD::new(param_count as usize, learning_rate as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with momentum
|
||||
#[napi(factory)]
|
||||
pub fn with_momentum(param_count: u32, learning_rate: f64, momentum: f64) -> Self {
|
||||
Self {
|
||||
inner: RustSGD::new(param_count as usize, learning_rate as f32)
|
||||
.with_momentum(momentum as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with momentum and weight decay
|
||||
#[napi(factory)]
|
||||
pub fn with_weight_decay(
|
||||
param_count: u32,
|
||||
learning_rate: f64,
|
||||
momentum: f64,
|
||||
weight_decay: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: RustSGD::new(param_count as usize, learning_rate as f32)
|
||||
.with_momentum(momentum as f32)
|
||||
.with_weight_decay(weight_decay as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform an optimization step
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `params` - Parameter array
|
||||
/// * `gradients` - Gradient array
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated parameter array
|
||||
#[napi]
|
||||
pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
|
||||
let mut params_vec = params.to_vec();
|
||||
let gradients_slice = gradients.as_ref();
|
||||
self.inner.step(&mut params_vec, gradients_slice);
|
||||
Float32Array::new(params_vec)
|
||||
}
|
||||
|
||||
/// Reset optimizer state
|
||||
#[napi]
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
/// Get current learning rate
|
||||
#[napi(getter)]
|
||||
pub fn learning_rate(&self) -> f64 {
|
||||
self.inner.learning_rate() as f64
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[napi(setter)]
|
||||
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.inner.set_learning_rate(lr as f32);
|
||||
}
|
||||
}
|
||||
|
||||
/// Adam optimizer with bias correction
#[napi]
pub struct AdamOptimizer {
    // Underlying Rust Adam implementation; all optimizer state lives here.
    inner: RustAdam,
}
|
||||
|
||||
#[napi]
|
||||
impl AdamOptimizer {
|
||||
/// Create a new Adam optimizer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_count` - Number of parameters
|
||||
/// * `learning_rate` - Learning rate
|
||||
#[napi(constructor)]
|
||||
pub fn new(param_count: u32, learning_rate: f64) -> Self {
|
||||
Self {
|
||||
inner: RustAdam::new(param_count as usize, learning_rate as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom betas
|
||||
#[napi(factory)]
|
||||
pub fn with_betas(param_count: u32, learning_rate: f64, beta1: f64, beta2: f64) -> Self {
|
||||
Self {
|
||||
inner: RustAdam::new(param_count as usize, learning_rate as f32)
|
||||
.with_betas(beta1 as f32, beta2 as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with full configuration
|
||||
#[napi(factory)]
|
||||
pub fn with_config(
|
||||
param_count: u32,
|
||||
learning_rate: f64,
|
||||
beta1: f64,
|
||||
beta2: f64,
|
||||
epsilon: f64,
|
||||
weight_decay: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: RustAdam::new(param_count as usize, learning_rate as f32)
|
||||
.with_betas(beta1 as f32, beta2 as f32)
|
||||
.with_epsilon(epsilon as f32)
|
||||
.with_weight_decay(weight_decay as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform an optimization step
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated parameter array
|
||||
#[napi]
|
||||
pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
|
||||
let mut params_vec = params.to_vec();
|
||||
let gradients_slice = gradients.as_ref();
|
||||
self.inner.step(&mut params_vec, gradients_slice);
|
||||
Float32Array::new(params_vec)
|
||||
}
|
||||
|
||||
/// Reset optimizer state (momentum terms)
|
||||
#[napi]
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
/// Get current learning rate
|
||||
#[napi(getter)]
|
||||
pub fn learning_rate(&self) -> f64 {
|
||||
self.inner.learning_rate() as f64
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[napi(setter)]
|
||||
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.inner.set_learning_rate(lr as f32);
|
||||
}
|
||||
}
|
||||
|
||||
/// AdamW optimizer (Adam with decoupled weight decay)
#[napi]
pub struct AdamWOptimizer {
    // Underlying Rust AdamW implementation.
    inner: RustAdamW,
    // Cached weight decay, exposed via the `weight_decay` getter.
    wd: f32,
}
|
||||
|
||||
#[napi]
|
||||
impl AdamWOptimizer {
|
||||
/// Create a new AdamW optimizer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_count` - Number of parameters
|
||||
/// * `learning_rate` - Learning rate
|
||||
/// * `weight_decay` - Weight decay coefficient
|
||||
#[napi(constructor)]
|
||||
pub fn new(param_count: u32, learning_rate: f64, weight_decay: f64) -> Self {
|
||||
Self {
|
||||
inner: RustAdamW::new(param_count as usize, learning_rate as f32)
|
||||
.with_weight_decay(weight_decay as f32),
|
||||
wd: weight_decay as f32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom betas
|
||||
#[napi(factory)]
|
||||
pub fn with_betas(
|
||||
param_count: u32,
|
||||
learning_rate: f64,
|
||||
weight_decay: f64,
|
||||
beta1: f64,
|
||||
beta2: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: RustAdamW::new(param_count as usize, learning_rate as f32)
|
||||
.with_weight_decay(weight_decay as f32)
|
||||
.with_betas(beta1 as f32, beta2 as f32),
|
||||
wd: weight_decay as f32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform an optimization step
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated parameter array
|
||||
#[napi]
|
||||
pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
|
||||
let mut params_vec = params.to_vec();
|
||||
let gradients_slice = gradients.as_ref();
|
||||
self.inner.step(&mut params_vec, gradients_slice);
|
||||
Float32Array::new(params_vec)
|
||||
}
|
||||
|
||||
/// Reset optimizer state
|
||||
#[napi]
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
/// Get current learning rate
|
||||
#[napi(getter)]
|
||||
pub fn learning_rate(&self) -> f64 {
|
||||
self.inner.learning_rate() as f64
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[napi(setter)]
|
||||
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.inner.set_learning_rate(lr as f32);
|
||||
}
|
||||
|
||||
/// Get weight decay
|
||||
#[napi(getter)]
|
||||
pub fn weight_decay(&self) -> f64 {
|
||||
self.wd as f64
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Learning Rate Scheduling
|
||||
// ============================================================================
|
||||
|
||||
/// Learning rate scheduler with warmup and cosine decay
#[napi]
pub struct LearningRateScheduler {
    // Peak learning rate, reached at the end of warmup.
    initial_lr: f32,
    // Number of `step()` calls made so far.
    current_step: usize,
    // Steps of linear warmup before cosine decay begins.
    warmup_steps: usize,
    // Total steps the schedule spans (warmup + decay).
    total_steps: usize,
    // Floor that the cosine decay approaches.
    min_lr: f32,
}
|
||||
|
||||
#[napi]
|
||||
impl LearningRateScheduler {
|
||||
/// Create a new learning rate scheduler
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `initial_lr` - Initial/peak learning rate
|
||||
/// * `warmup_steps` - Number of warmup steps
|
||||
/// * `total_steps` - Total training steps
|
||||
#[napi(constructor)]
|
||||
pub fn new(initial_lr: f64, warmup_steps: u32, total_steps: u32) -> Self {
|
||||
Self {
|
||||
initial_lr: initial_lr as f32,
|
||||
current_step: 0,
|
||||
warmup_steps: warmup_steps as usize,
|
||||
total_steps: total_steps as usize,
|
||||
min_lr: 1e-7,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with minimum learning rate
|
||||
#[napi(factory)]
|
||||
pub fn with_min_lr(initial_lr: f64, warmup_steps: u32, total_steps: u32, min_lr: f64) -> Self {
|
||||
Self {
|
||||
initial_lr: initial_lr as f32,
|
||||
current_step: 0,
|
||||
warmup_steps: warmup_steps as usize,
|
||||
total_steps: total_steps as usize,
|
||||
min_lr: min_lr as f32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get learning rate for current step
|
||||
#[napi]
|
||||
pub fn get_lr(&self) -> f64 {
|
||||
if self.current_step < self.warmup_steps {
|
||||
// Linear warmup
|
||||
(self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32) as f64
|
||||
} else {
|
||||
// Cosine decay
|
||||
let progress = (self.current_step - self.warmup_steps) as f32
|
||||
/ (self.total_steps - self.warmup_steps).max(1) as f32;
|
||||
let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
|
||||
(self.min_lr + (self.initial_lr - self.min_lr) * decay) as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Step the scheduler and return current learning rate
|
||||
#[napi]
|
||||
pub fn step(&mut self) -> f64 {
|
||||
let lr = self.get_lr();
|
||||
self.current_step += 1;
|
||||
lr
|
||||
}
|
||||
|
||||
/// Reset scheduler to initial state
|
||||
#[napi]
|
||||
pub fn reset(&mut self) {
|
||||
self.current_step = 0;
|
||||
}
|
||||
|
||||
/// Get current step
|
||||
#[napi(getter)]
|
||||
pub fn current_step(&self) -> u32 {
|
||||
self.current_step as u32
|
||||
}
|
||||
|
||||
/// Get progress (0.0 to 1.0)
|
||||
#[napi(getter)]
|
||||
pub fn progress(&self) -> f64 {
|
||||
(self.current_step as f64 / self.total_steps.max(1) as f64).min(1.0)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Temperature Annealing
|
||||
// ============================================================================
|
||||
|
||||
/// Decay type for temperature annealing
///
/// Variants mirror the Rust-side `DecayType` one-to-one; the conversion is
/// the `From<DecayType>` impl that follows this enum.
#[napi(string_enum)]
pub enum DecayType {
    Linear,
    Exponential,
    Cosine,
    Step,
}
|
||||
|
||||
impl From<DecayType> for RustDecayType {
|
||||
fn from(dt: DecayType) -> Self {
|
||||
match dt {
|
||||
DecayType::Linear => RustDecayType::Linear,
|
||||
DecayType::Exponential => RustDecayType::Exponential,
|
||||
DecayType::Cosine => RustDecayType::Cosine,
|
||||
DecayType::Step => RustDecayType::Step,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Temperature annealing scheduler
#[napi]
pub struct TemperatureAnnealing {
    // Underlying Rust annealing schedule; all state lives here.
    inner: RustTempAnnealing,
}
|
||||
|
||||
#[napi]
|
||||
impl TemperatureAnnealing {
|
||||
/// Create a new temperature annealing scheduler
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `initial_temp` - Starting temperature
|
||||
/// * `final_temp` - Final temperature
|
||||
/// * `steps` - Number of annealing steps
|
||||
#[napi(constructor)]
|
||||
pub fn new(initial_temp: f64, final_temp: f64, steps: u32) -> Self {
|
||||
Self {
|
||||
inner: RustTempAnnealing::new(initial_temp as f32, final_temp as f32, steps as usize),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with specific decay type
|
||||
#[napi(factory)]
|
||||
pub fn with_decay(
|
||||
initial_temp: f64,
|
||||
final_temp: f64,
|
||||
steps: u32,
|
||||
decay_type: DecayType,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: RustTempAnnealing::new(initial_temp as f32, final_temp as f32, steps as usize)
|
||||
.with_decay(decay_type.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current temperature
|
||||
#[napi]
|
||||
pub fn get_temp(&self) -> f64 {
|
||||
self.inner.get_temp() as f64
|
||||
}
|
||||
|
||||
/// Step the scheduler and return current temperature
|
||||
#[napi]
|
||||
pub fn step(&mut self) -> f64 {
|
||||
self.inner.step() as f64
|
||||
}
|
||||
|
||||
/// Reset scheduler
|
||||
#[napi]
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Curriculum Learning
|
||||
// ============================================================================
|
||||
|
||||
/// Curriculum stage configuration
#[napi(object)]
pub struct CurriculumStageConfig {
    /// Stage name.
    pub name: String,
    /// Stage difficulty (0.0 to 1.0).
    pub difficulty: f64,
    /// Stage duration in steps.
    pub duration: u32,
    /// Temperature used while this stage is active.
    pub temperature: f64,
    /// Number of negative samples used while this stage is active.
    pub negative_count: u32,
}
|
||||
|
||||
/// Curriculum scheduler for progressive training
#[napi]
pub struct CurriculumScheduler {
    // Underlying Rust curriculum; all stage state lives here.
    inner: RustCurriculum,
}
|
||||
|
||||
#[napi]
|
||||
impl CurriculumScheduler {
|
||||
/// Create an empty curriculum scheduler
|
||||
#[napi(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: RustCurriculum::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a default easy-to-hard curriculum
|
||||
#[napi(factory)]
|
||||
pub fn default_curriculum(total_steps: u32) -> Self {
|
||||
Self {
|
||||
inner: RustCurriculum::default_curriculum(total_steps as usize),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a stage to the curriculum
|
||||
#[napi]
|
||||
pub fn add_stage(&mut self, config: CurriculumStageConfig) {
|
||||
let stage = RustStage::new(&config.name)
|
||||
.difficulty(config.difficulty as f32)
|
||||
.duration(config.duration as usize)
|
||||
.temperature(config.temperature as f32)
|
||||
.negative_count(config.negative_count as usize);
|
||||
|
||||
// Rebuild with added stage
|
||||
let new_inner = std::mem::take(&mut self.inner).add_stage(stage);
|
||||
self.inner = new_inner;
|
||||
}
|
||||
|
||||
/// Step the curriculum and return current stage info
|
||||
#[napi]
|
||||
pub fn step(&mut self) -> Option<CurriculumStageConfig> {
|
||||
self.inner.step().map(|s| CurriculumStageConfig {
|
||||
name: s.name.clone(),
|
||||
difficulty: s.difficulty as f64,
|
||||
duration: s.duration as u32,
|
||||
temperature: s.temperature as f64,
|
||||
negative_count: s.negative_count as u32,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current difficulty (0.0 to 1.0)
|
||||
#[napi(getter)]
|
||||
pub fn difficulty(&self) -> f64 {
|
||||
self.inner.difficulty() as f64
|
||||
}
|
||||
|
||||
/// Get current temperature
|
||||
#[napi(getter)]
|
||||
pub fn temperature(&self) -> f64 {
|
||||
self.inner.temperature() as f64
|
||||
}
|
||||
|
||||
/// Get current negative count
|
||||
#[napi(getter)]
|
||||
pub fn negative_count(&self) -> u32 {
|
||||
self.inner.negative_count() as u32
|
||||
}
|
||||
|
||||
/// Check if curriculum is complete
|
||||
#[napi(getter)]
|
||||
pub fn is_complete(&self) -> bool {
|
||||
self.inner.is_complete()
|
||||
}
|
||||
|
||||
/// Get overall progress (0.0 to 1.0)
|
||||
#[napi(getter)]
|
||||
pub fn progress(&self) -> f64 {
|
||||
self.inner.progress() as f64
|
||||
}
|
||||
|
||||
/// Reset curriculum
|
||||
#[napi]
|
||||
pub fn reset(&mut self) {
|
||||
self.inner.reset();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
// Negative Mining
// ============================================================================

/// Mining strategy for negative selection
///
/// Exposed to JS as a string enum; the actual selection semantics are
/// implemented by the corresponding Rust-side strategy.
#[napi(string_enum)]
pub enum MiningStrategy {
    /// Random candidate selection.
    Random,
    /// Hard-negative selection.
    HardNegative,
    /// Semi-hard selection; tuned via `HardNegativeMiner::with_margin`.
    SemiHard,
    /// Distance-weighted selection; tuned via
    /// `HardNegativeMiner::with_temperature`.
    DistanceWeighted,
}
|
||||
|
||||
impl From<MiningStrategy> for RustMiningStrategy {
|
||||
fn from(ms: MiningStrategy) -> Self {
|
||||
match ms {
|
||||
MiningStrategy::Random => RustMiningStrategy::Random,
|
||||
MiningStrategy::HardNegative => RustMiningStrategy::HardNegative,
|
||||
MiningStrategy::SemiHard => RustMiningStrategy::SemiHard,
|
||||
MiningStrategy::DistanceWeighted => RustMiningStrategy::DistanceWeighted,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hard negative miner for selecting informative negatives
///
/// Thin N-API wrapper; the mining algorithm itself lives in the Rust-side
/// miner configured at construction time.
#[napi]
pub struct HardNegativeMiner {
    // Rust-side miner holding the strategy and its tuning parameters.
    inner: RustHardMiner,
}
|
||||
|
||||
#[napi]
|
||||
impl HardNegativeMiner {
|
||||
/// Create a new hard negative miner
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `strategy` - Mining strategy to use
|
||||
#[napi(constructor)]
|
||||
pub fn new(strategy: MiningStrategy) -> Self {
|
||||
Self {
|
||||
inner: RustHardMiner::new(strategy.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with margin (for semi-hard mining)
|
||||
#[napi(factory)]
|
||||
pub fn with_margin(strategy: MiningStrategy, margin: f64) -> Self {
|
||||
Self {
|
||||
inner: RustHardMiner::new(strategy.into()).with_margin(margin as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with temperature (for distance-weighted mining)
|
||||
#[napi(factory)]
|
||||
pub fn with_temperature(strategy: MiningStrategy, temperature: f64) -> Self {
|
||||
Self {
|
||||
inner: RustHardMiner::new(strategy.into()).with_temperature(temperature as f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Mine negative indices from candidates
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor` - Anchor embedding
|
||||
/// * `positive` - Positive example embedding
|
||||
/// * `candidates` - Array of candidate embeddings
|
||||
/// * `num_negatives` - Number of negatives to select
|
||||
///
|
||||
/// # Returns
|
||||
/// Array of indices into the candidates array
|
||||
#[napi]
|
||||
pub fn mine(
|
||||
&self,
|
||||
anchor: Float32Array,
|
||||
positive: Float32Array,
|
||||
candidates: Vec<Float32Array>,
|
||||
num_negatives: u32,
|
||||
) -> Vec<u32> {
|
||||
let anchor_slice = anchor.as_ref();
|
||||
let positive_slice = positive.as_ref();
|
||||
let candidates_vec: Vec<Vec<f32>> = candidates.into_iter().map(|c| c.to_vec()).collect();
|
||||
let candidates_refs: Vec<&[f32]> = candidates_vec.iter().map(|c| c.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.mine(
|
||||
anchor_slice,
|
||||
positive_slice,
|
||||
&candidates_refs,
|
||||
num_negatives as usize,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|i| i as u32)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// In-batch negative mining utility
#[napi]
pub struct InBatchMiner {
    // When true (the constructor default), the positive's batch index is
    // excluded from the returned negative indices; the `include_positive`
    // factory sets this to false.
    exclude_positive: bool,
}
|
||||
|
||||
#[napi]
|
||||
impl InBatchMiner {
|
||||
/// Create a new in-batch miner
|
||||
#[napi(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
exclude_positive: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create without excluding positive
|
||||
#[napi(factory)]
|
||||
pub fn include_positive() -> Self {
|
||||
Self {
|
||||
exclude_positive: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get negative indices for a given anchor in a batch
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `anchor_idx` - Index of the anchor in the batch
|
||||
/// * `positive_idx` - Index of the positive in the batch
|
||||
/// * `batch_size` - Total batch size
|
||||
///
|
||||
/// # Returns
|
||||
/// Array of indices that can be used as negatives
|
||||
#[napi]
|
||||
pub fn get_negatives(&self, anchor_idx: u32, positive_idx: u32, batch_size: u32) -> Vec<u32> {
|
||||
(0..batch_size)
|
||||
.filter(|&i| i != anchor_idx && (!self.exclude_positive || i != positive_idx))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user