Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
52
vendor/ruvector/crates/ruvector-gnn-wasm/Cargo.toml
vendored
Normal file
52
vendor/ruvector/crates/ruvector-gnn-wasm/Cargo.toml
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
[package]
|
||||
name = "ruvector-gnn-wasm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
repository.workspace = true
|
||||
readme = "README.md"
|
||||
description = "WebAssembly bindings for RuVector GNN with tensor compression and differentiable search"
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = false
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
ruvector-gnn = { version = "2.0", path = "../ruvector-gnn", default-features = false, features = ["wasm"] }
|
||||
|
||||
# WASM
|
||||
wasm-bindgen = { workspace = true }
|
||||
js-sys = { workspace = true }
|
||||
getrandom = { workspace = true }
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde-wasm-bindgen = "0.6"
|
||||
|
||||
# Utils
|
||||
console_error_panic_hook = { version = "0.1", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
console_error_panic_hook = ["dep:console_error_panic_hook"]
|
||||
|
||||
# Ensure getrandom uses wasm_js/js feature for WASM
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { workspace = true, features = ["wasm_js"] }
|
||||
getrandom02 = { package = "getrandom", version = "0.2", features = ["js"] }
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z"
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
|
||||
[profile.release.package."*"]
|
||||
opt-level = "z"
|
||||
190
vendor/ruvector/crates/ruvector-gnn-wasm/README.md
vendored
Normal file
190
vendor/ruvector/crates/ruvector-gnn-wasm/README.md
vendored
Normal file
@@ -0,0 +1,190 @@
|
||||
# RuVector GNN WASM
|
||||
|
||||
WebAssembly bindings for RuVector Graph Neural Network operations.
|
||||
|
||||
## Features
|
||||
|
||||
- **GNN Layer Operations**: Multi-head attention, GRU updates, layer normalization
|
||||
- **Tensor Compression**: Adaptive compression based on access frequency
|
||||
- **Differentiable Search**: Soft attention-based similarity search
|
||||
- **Hierarchical Forward**: Multi-layer GNN processing
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-gnn-wasm
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Initialize
|
||||
|
||||
```typescript
|
||||
import init, {
|
||||
JsRuvectorLayer,
|
||||
JsTensorCompress,
|
||||
differentiableSearch,
|
||||
SearchConfig
|
||||
} from 'ruvector-gnn-wasm';
|
||||
|
||||
await init();
|
||||
```
|
||||
|
||||
### GNN Layer
|
||||
|
||||
```typescript
|
||||
// Create a GNN layer
|
||||
const layer = new JsRuvectorLayer(
|
||||
4, // input dimension
|
||||
8, // hidden dimension
|
||||
2, // number of attention heads
|
||||
0.1 // dropout rate
|
||||
);
|
||||
|
||||
// Forward pass
|
||||
const nodeEmbedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
|
||||
const neighbors = [
|
||||
new Float32Array([0.5, 1.0, 1.5, 2.0]),
|
||||
new Float32Array([2.0, 3.0, 4.0, 5.0])
|
||||
];
|
||||
const edgeWeights = new Float32Array([0.3, 0.7]);
|
||||
|
||||
const output = layer.forward(nodeEmbedding, neighbors, edgeWeights);
|
||||
console.log('Output dimension:', layer.outputDim);
|
||||
```
|
||||
|
||||
### Tensor Compression
|
||||
|
||||
```typescript
|
||||
const compressor = new JsTensorCompress();
|
||||
|
||||
// Compress based on access frequency
|
||||
const embedding = new Float32Array(128).fill(0.5);
|
||||
const compressed = compressor.compress(embedding, 0.5); // 50% access frequency
|
||||
|
||||
// Decompress
|
||||
const decompressed = compressor.decompress(compressed);
|
||||
|
||||
// Or specify compression level explicitly
|
||||
const compressedPQ8 = compressor.compressWithLevel(embedding, "pq8");
|
||||
|
||||
// Get compression ratio
|
||||
const ratio = compressor.getCompressionRatio(0.5); // Returns ~2.0 for half precision
|
||||
```
|
||||
|
||||
### Compression Levels
|
||||
|
||||
Access frequency determines compression:
|
||||
- `f > 0.8`: **Full precision** (no compression) - hot data
|
||||
- `f > 0.4`: **Half precision** (2x compression) - warm data
|
||||
- `f > 0.1`: **8-bit PQ** (4x compression) - cool data
|
||||
- `f > 0.01`: **4-bit PQ** (8x compression) - cold data
|
||||
- `f <= 0.01`: **Binary** (32x compression) - archive data
|
||||
|
||||
### Differentiable Search
|
||||
|
||||
```typescript
|
||||
const query = new Float32Array([1.0, 0.0, 0.0]);
|
||||
const candidates = [
|
||||
new Float32Array([1.0, 0.0, 0.0]), // Perfect match
|
||||
new Float32Array([0.9, 0.1, 0.0]), // Close match
|
||||
new Float32Array([0.0, 1.0, 0.0]) // Orthogonal
|
||||
];
|
||||
|
||||
const config = new SearchConfig(2, 1.0); // k=2, temperature=1.0
|
||||
const result = differentiableSearch(query, candidates, config);
|
||||
|
||||
console.log('Top indices:', result.indices);
|
||||
console.log('Weights:', result.weights);
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### `JsRuvectorLayer`
|
||||
|
||||
```typescript
|
||||
class JsRuvectorLayer {
|
||||
constructor(
|
||||
inputDim: number,
|
||||
hiddenDim: number,
|
||||
heads: number,
|
||||
dropout: number
|
||||
);
|
||||
|
||||
forward(
|
||||
nodeEmbedding: Float32Array,
|
||||
neighborEmbeddings: Float32Array[],
|
||||
edgeWeights: Float32Array
|
||||
): Float32Array;
|
||||
|
||||
readonly outputDim: number;
|
||||
}
|
||||
```
|
||||
|
||||
### `JsTensorCompress`
|
||||
|
||||
```typescript
|
||||
class JsTensorCompress {
|
||||
constructor();
|
||||
|
||||
compress(embedding: Float32Array, accessFreq: number): object;
|
||||
compressWithLevel(embedding: Float32Array, level: string): object;
|
||||
decompress(compressed: object): Float32Array;
|
||||
getCompressionRatio(accessFreq: number): number;
|
||||
}
|
||||
```
|
||||
|
||||
Compression levels: `"none"`, `"half"`, `"pq8"`, `"pq4"`, `"binary"`
|
||||
|
||||
### `differentiableSearch`
|
||||
|
||||
```typescript
|
||||
function differentiableSearch(
|
||||
query: Float32Array,
|
||||
candidateEmbeddings: Float32Array[],
|
||||
config: SearchConfig
|
||||
): { indices: number[], weights: number[] };
|
||||
```
|
||||
|
||||
### `SearchConfig`
|
||||
|
||||
```typescript
|
||||
class SearchConfig {
|
||||
constructor(k: number, temperature: number);
|
||||
k: number; // Number of results
|
||||
temperature: number; // Softmax temperature (lower = sharper)
|
||||
}
|
||||
```
|
||||
|
||||
### `cosineSimilarity`
|
||||
|
||||
```typescript
|
||||
function cosineSimilarity(a: Float32Array, b: Float32Array): number;
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
|
||||
```bash
|
||||
# Install wasm-pack
|
||||
curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
|
||||
# Build for Node.js
|
||||
wasm-pack build --target nodejs
|
||||
|
||||
# Build for browser
|
||||
wasm-pack build --target web
|
||||
|
||||
# Build for bundler (webpack, etc.)
|
||||
wasm-pack build --target bundler
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
- GNN layers use efficient attention mechanisms
|
||||
- Compression reduces memory usage by 2-32x
|
||||
- All operations are optimized for WASM
|
||||
- No garbage collection during forward passes
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
35
vendor/ruvector/crates/ruvector-gnn-wasm/package.json
vendored
Normal file
35
vendor/ruvector/crates/ruvector-gnn-wasm/package.json
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"name": "@ruvector/gnn-wasm",
|
||||
"version": "0.1.0",
|
||||
"description": "WebAssembly bindings for ruvector-gnn - Graph Neural Network layers for browsers",
|
||||
"main": "pkg/ruvector_gnn_wasm.js",
|
||||
"types": "pkg/ruvector_gnn_wasm.d.ts",
|
||||
"files": [
|
||||
"pkg/"
|
||||
],
|
||||
"scripts": {
|
||||
"build": "wasm-pack build --target web --out-dir pkg",
|
||||
"build:node": "wasm-pack build --target nodejs --out-dir pkg-node",
|
||||
"build:bundler": "wasm-pack build --target bundler --out-dir pkg-bundler",
|
||||
"test": "wasm-pack test --headless --firefox"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/ruvnet/ruvector"
|
||||
},
|
||||
"keywords": [
|
||||
"wasm",
|
||||
"webassembly",
|
||||
"gnn",
|
||||
"graph-neural-network",
|
||||
"machine-learning",
|
||||
"neural-networks",
|
||||
"browser"
|
||||
],
|
||||
"author": "rUv",
|
||||
"license": "MIT",
|
||||
"bugs": {
|
||||
"url": "https://github.com/ruvnet/ruvector/issues"
|
||||
},
|
||||
"homepage": "https://github.com/ruvnet/ruvector"
|
||||
}
|
||||
410
vendor/ruvector/crates/ruvector-gnn-wasm/src/lib.rs
vendored
Normal file
410
vendor/ruvector/crates/ruvector-gnn-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,410 @@
|
||||
//! WebAssembly bindings for RuVector GNN
|
||||
//!
|
||||
//! This module provides high-performance browser bindings for Graph Neural Network
|
||||
//! operations on HNSW topology, including:
|
||||
//! - GNN layer forward passes
|
||||
//! - Tensor compression with adaptive level selection
|
||||
//! - Differentiable search with soft attention
|
||||
//! - Hierarchical forward propagation
|
||||
|
||||
use ruvector_gnn::{
|
||||
differentiable_search as core_differentiable_search,
|
||||
hierarchical_forward as core_hierarchical_forward, CompressedTensor, CompressionLevel,
|
||||
RuvectorLayer, TensorCompress,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Initialize panic hook for better error messages
|
||||
#[wasm_bindgen(start)]
|
||||
pub fn init() {
|
||||
#[cfg(feature = "console_error_panic_hook")]
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Type Definitions for WASM
|
||||
// ============================================================================
|
||||
|
||||
/// Query configuration for differentiable search
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[wasm_bindgen]
|
||||
pub struct SearchConfig {
|
||||
/// Number of top results to return
|
||||
pub k: usize,
|
||||
/// Temperature for softmax (lower = sharper, higher = smoother)
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SearchConfig {
|
||||
/// Create a new search configuration
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(k: usize, temperature: f32) -> Self {
|
||||
Self { k, temperature }
|
||||
}
|
||||
}
|
||||
|
||||
/// Search results with indices and weights (internal)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct SearchResultInternal {
|
||||
/// Indices of top-k candidates
|
||||
indices: Vec<usize>,
|
||||
/// Soft weights for each result
|
||||
weights: Vec<f32>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// JsRuvectorLayer - GNN Layer Wrapper
|
||||
// ============================================================================
|
||||
|
||||
/// Graph Neural Network layer for HNSW topology
|
||||
#[wasm_bindgen]
|
||||
pub struct JsRuvectorLayer {
|
||||
inner: RuvectorLayer,
|
||||
hidden_dim: usize,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl JsRuvectorLayer {
|
||||
/// Create a new GNN layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_dim` - Dimension of input node embeddings
|
||||
/// * `hidden_dim` - Dimension of hidden representations
|
||||
/// * `heads` - Number of attention heads
|
||||
/// * `dropout` - Dropout rate (0.0 to 1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
heads: usize,
|
||||
dropout: f32,
|
||||
) -> Result<JsRuvectorLayer, JsValue> {
|
||||
let inner = RuvectorLayer::new(input_dim, hidden_dim, heads, dropout)
|
||||
.map_err(|e| JsValue::from_str(&e.to_string()))?;
|
||||
|
||||
Ok(JsRuvectorLayer { inner, hidden_dim })
|
||||
}
|
||||
|
||||
/// Forward pass through the GNN layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `node_embedding` - Current node's embedding (Float32Array)
|
||||
/// * `neighbor_embeddings` - Embeddings of neighbor nodes (array of Float32Arrays)
|
||||
/// * `edge_weights` - Weights of edges to neighbors (Float32Array)
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated node embedding (Float32Array)
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(
|
||||
&self,
|
||||
node_embedding: Vec<f32>,
|
||||
neighbor_embeddings: JsValue,
|
||||
edge_weights: Vec<f32>,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
// Convert neighbor embeddings from JS value
|
||||
let neighbors: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(neighbor_embeddings)
|
||||
.map_err(|e| {
|
||||
JsValue::from_str(&format!("Failed to parse neighbor embeddings: {}", e))
|
||||
})?;
|
||||
|
||||
// Validate inputs
|
||||
if neighbors.len() != edge_weights.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Number of neighbors ({}) must match number of edge weights ({})",
|
||||
neighbors.len(),
|
||||
edge_weights.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Call core forward
|
||||
let result = self
|
||||
.inner
|
||||
.forward(&node_embedding, &neighbors, &edge_weights);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get the output dimension of this layer
|
||||
#[wasm_bindgen(getter, js_name = outputDim)]
|
||||
pub fn output_dim(&self) -> usize {
|
||||
self.hidden_dim
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// JsTensorCompress - Tensor Compression Wrapper
|
||||
// ============================================================================
|
||||
|
||||
/// Tensor compressor with adaptive level selection
|
||||
#[wasm_bindgen]
|
||||
pub struct JsTensorCompress {
|
||||
inner: TensorCompress,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl JsTensorCompress {
|
||||
/// Create a new tensor compressor
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: TensorCompress::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compress an embedding based on access frequency
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - The input embedding vector (Float32Array)
|
||||
/// * `access_freq` - Access frequency in range [0.0, 1.0]
|
||||
/// - f > 0.8: Full precision (hot data)
|
||||
/// - f > 0.4: Half precision (warm data)
|
||||
/// - f > 0.1: 8-bit PQ (cool data)
|
||||
/// - f > 0.01: 4-bit PQ (cold data)
|
||||
/// - f <= 0.01: Binary (archive)
|
||||
///
|
||||
/// # Returns
|
||||
/// Compressed tensor as JsValue
|
||||
#[wasm_bindgen]
|
||||
pub fn compress(&self, embedding: Vec<f32>, access_freq: f32) -> Result<JsValue, JsValue> {
|
||||
let compressed = self
|
||||
.inner
|
||||
.compress(&embedding, access_freq)
|
||||
.map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?;
|
||||
|
||||
// Serialize using serde_wasm_bindgen
|
||||
serde_wasm_bindgen::to_value(&compressed)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Compress with explicit compression level
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - The input embedding vector
|
||||
/// * `level` - Compression level ("none", "half", "pq8", "pq4", "binary")
|
||||
///
|
||||
/// # Returns
|
||||
/// Compressed tensor as JsValue
|
||||
#[wasm_bindgen(js_name = compressWithLevel)]
|
||||
pub fn compress_with_level(
|
||||
&self,
|
||||
embedding: Vec<f32>,
|
||||
level: &str,
|
||||
) -> Result<JsValue, JsValue> {
|
||||
let compression_level = match level {
|
||||
"none" => CompressionLevel::None,
|
||||
"half" => CompressionLevel::Half { scale: 1.0 },
|
||||
"pq8" => CompressionLevel::PQ8 {
|
||||
subvectors: 8,
|
||||
centroids: 16,
|
||||
},
|
||||
"pq4" => CompressionLevel::PQ4 {
|
||||
subvectors: 8,
|
||||
outlier_threshold: 3.0,
|
||||
},
|
||||
"binary" => CompressionLevel::Binary { threshold: 0.0 },
|
||||
_ => {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Unknown compression level: {}",
|
||||
level
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let compressed = self
|
||||
.inner
|
||||
.compress_with_level(&embedding, &compression_level)
|
||||
.map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?;
|
||||
|
||||
// Serialize using serde_wasm_bindgen
|
||||
serde_wasm_bindgen::to_value(&compressed)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Decompress a compressed tensor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `compressed` - Serialized compressed tensor (JsValue)
|
||||
///
|
||||
/// # Returns
|
||||
/// Decompressed embedding vector (Float32Array)
|
||||
#[wasm_bindgen]
|
||||
pub fn decompress(&self, compressed: JsValue) -> Result<Vec<f32>, JsValue> {
|
||||
let compressed_tensor: CompressedTensor = serde_wasm_bindgen::from_value(compressed)
|
||||
.map_err(|e| JsValue::from_str(&format!("Deserialization failed: {}", e)))?;
|
||||
|
||||
let decompressed = self
|
||||
.inner
|
||||
.decompress(&compressed_tensor)
|
||||
.map_err(|e| JsValue::from_str(&format!("Decompression failed: {}", e)))?;
|
||||
|
||||
Ok(decompressed)
|
||||
}
|
||||
|
||||
/// Get compression ratio estimate for a given access frequency
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `access_freq` - Access frequency in range [0.0, 1.0]
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated compression ratio (original_size / compressed_size)
|
||||
#[wasm_bindgen(js_name = getCompressionRatio)]
|
||||
pub fn get_compression_ratio(&self, access_freq: f32) -> f32 {
|
||||
if access_freq > 0.8 {
|
||||
1.0 // No compression
|
||||
} else if access_freq > 0.4 {
|
||||
2.0 // Half precision
|
||||
} else if access_freq > 0.1 {
|
||||
4.0 // 8-bit PQ
|
||||
} else if access_freq > 0.01 {
|
||||
8.0 // 4-bit PQ
|
||||
} else {
|
||||
32.0 // Binary
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Standalone Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Differentiable search using soft attention mechanism
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector (Float32Array)
|
||||
/// * `candidate_embeddings` - List of candidate embedding vectors (array of Float32Arrays)
|
||||
/// * `config` - Search configuration (k and temperature)
|
||||
///
|
||||
/// # Returns
|
||||
/// Object with indices and weights for top-k candidates
|
||||
#[wasm_bindgen(js_name = differentiableSearch)]
|
||||
pub fn differentiable_search(
|
||||
query: Vec<f32>,
|
||||
candidate_embeddings: JsValue,
|
||||
config: &SearchConfig,
|
||||
) -> Result<JsValue, JsValue> {
|
||||
// Convert candidate embeddings from JS value
|
||||
let candidates: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(candidate_embeddings)
|
||||
.map_err(|e| JsValue::from_str(&format!("Failed to parse candidate embeddings: {}", e)))?;
|
||||
|
||||
// Call core search function
|
||||
let (indices, weights) =
|
||||
core_differentiable_search(&query, &candidates, config.k, config.temperature);
|
||||
|
||||
let result = SearchResultInternal { indices, weights };
|
||||
serde_wasm_bindgen::to_value(&result)
|
||||
.map_err(|e| JsValue::from_str(&format!("Failed to serialize result: {}", e)))
|
||||
}
|
||||
|
||||
/// Hierarchical forward pass through multiple GNN layers
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector (Float32Array)
|
||||
/// * `layer_embeddings` - Embeddings organized by layer (array of arrays of Float32Arrays)
|
||||
/// * `gnn_layers` - Array of GNN layers to process through
|
||||
///
|
||||
/// # Returns
|
||||
/// Final embedding after hierarchical processing (Float32Array)
|
||||
#[wasm_bindgen(js_name = hierarchicalForward)]
|
||||
pub fn hierarchical_forward(
|
||||
query: Vec<f32>,
|
||||
layer_embeddings: JsValue,
|
||||
gnn_layers: Vec<JsRuvectorLayer>,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
// Convert layer embeddings from JS value
|
||||
let embeddings: Vec<Vec<Vec<f32>>> = serde_wasm_bindgen::from_value(layer_embeddings)
|
||||
.map_err(|e| JsValue::from_str(&format!("Failed to parse layer embeddings: {}", e)))?;
|
||||
|
||||
// Extract inner layers
|
||||
let core_layers: Vec<RuvectorLayer> = gnn_layers.iter().map(|l| l.inner.clone()).collect();
|
||||
|
||||
// Call core function
|
||||
let result = core_hierarchical_forward(&query, &embeddings, &core_layers);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Utility Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Get version information
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `a` - First vector (Float32Array)
|
||||
/// * `b` - Second vector (Float32Array)
|
||||
///
|
||||
/// # Returns
|
||||
/// Cosine similarity score [-1.0, 1.0]
|
||||
#[wasm_bindgen(js_name = cosineSimilarity)]
|
||||
pub fn cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> Result<f32, JsValue> {
|
||||
if a.len() != b.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Vector dimensions must match: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
Ok(dot_product / (norm_a * norm_b))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_version() {
|
||||
assert!(!version().is_empty());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_ruvector_layer_creation() {
|
||||
let layer = JsRuvectorLayer::new(4, 8, 2, 0.1);
|
||||
assert!(layer.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_tensor_compress_creation() {
|
||||
let compressor = JsTensorCompress::new();
|
||||
assert_eq!(compressor.get_compression_ratio(1.0), 1.0);
|
||||
assert_eq!(compressor.get_compression_ratio(0.5), 2.0);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_cosine_similarity() {
|
||||
let a = vec![1.0, 0.0, 0.0];
|
||||
let b = vec![1.0, 0.0, 0.0];
|
||||
let sim = cosine_similarity(a, b).unwrap();
|
||||
assert!((sim - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_search_config() {
|
||||
let config = SearchConfig::new(5, 1.0);
|
||||
assert_eq!(config.k, 5);
|
||||
assert_eq!(config.temperature, 1.0);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user