/**
 * Model Loader for RuVector ONNX Embeddings WASM
 *
 * Provides easy loading of pre-trained models from HuggingFace Hub
 */

/**
 * Pre-configured models with their HuggingFace URLs
 */
export const MODELS = {
  // Sentence Transformers - Small & Fast
  'all-MiniLM-L6-v2': {
    name: 'all-MiniLM-L6-v2',
    dimension: 384,
    maxLength: 256,
    size: '23MB',
    description: 'Fast, general-purpose embeddings',
    model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx',
    tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json',
  },
  'all-MiniLM-L12-v2': {
    name: 'all-MiniLM-L12-v2',
    dimension: 384,
    maxLength: 256,
    size: '33MB',
    description: 'Better quality, balanced speed',
    model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/onnx/model.onnx',
    tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer.json',
  },

  // BGE Models - State of the art
  'bge-small-en-v1.5': {
    name: 'bge-small-en-v1.5',
    dimension: 384,
    maxLength: 512,
    size: '33MB',
    description: 'State-of-the-art small model',
    model: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx',
    tokenizer: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/tokenizer.json',
  },
  'bge-base-en-v1.5': {
    name: 'bge-base-en-v1.5',
    dimension: 768,
    maxLength: 512,
    size: '110MB',
    description: 'Best overall quality',
    model: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx',
    tokenizer: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json',
  },

  // E5 Models - Microsoft
  'e5-small-v2': {
    name: 'e5-small-v2',
    dimension: 384,
    maxLength: 512,
    size: '33MB',
    description: 'Excellent for search & retrieval',
    model: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/onnx/model.onnx',
    tokenizer: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/tokenizer.json',
  },

  // GTE Models - Alibaba
  'gte-small': {
    name: 'gte-small',
    dimension: 384,
    maxLength: 512,
    size: '33MB',
    description: 'Good multilingual support',
    model: 'https://huggingface.co/thenlper/gte-small/resolve/main/onnx/model.onnx',
    tokenizer: 'https://huggingface.co/thenlper/gte-small/resolve/main/tokenizer.json',
  },
};
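
// Each registry entry pairs an ONNX model URL with its tokenizer.json.
// Illustrative lookup (all names below are defined in this file):
//
//   import { MODELS } from './loader.js';
//   const cfg = MODELS['bge-small-en-v1.5'];
//   console.log(cfg.dimension); // 384
//   console.log(cfg.maxLength); // 512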

/**
 * Default model for quick start
 */
export const DEFAULT_MODEL = 'all-MiniLM-L6-v2';

/**
 * Model loader with caching support
 */
export class ModelLoader {
  /**
   * @param {object} [options]
   * @param {boolean} [options.cache=true] - Cache downloads via the browser Cache API
   * @param {string} [options.cacheStorage='ruvector-models'] - Name of the Cache storage
   * @param {function} [options.onProgress] - Called with {loaded, total, percent} during downloads
   */
  constructor(options = {}) {
    this.cache = options.cache ?? true;
    this.cacheStorage = options.cacheStorage ?? 'ruvector-models';
    this.onProgress = options.onProgress ?? null;
  }
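
  // Illustrative construction (defaults shown match the constructor above):
  //
  //   const loader = new ModelLoader({
  //     cache: true,                      // use the browser Cache API
  //     cacheStorage: 'ruvector-models',  // cache name
  //     onProgress: ({ percent }) => console.log(`download: ${percent}%`),
  //   });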

  /**
   * Load a pre-configured model by name
   * @param {string} modelName - Model name from MODELS
   * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string, config: object}>}
   */
  async loadModel(modelName = DEFAULT_MODEL) {
    const modelConfig = MODELS[modelName];
    if (!modelConfig) {
      throw new Error(`Unknown model: ${modelName}. Available: ${Object.keys(MODELS).join(', ')}`);
    }

    console.log(`Loading model: ${modelConfig.name} (${modelConfig.size})`);

    const [modelBytes, tokenizerJson] = await Promise.all([
      this.fetchWithCache(modelConfig.model, `${modelName}-model.onnx`, 'arraybuffer'),
      this.fetchWithCache(modelConfig.tokenizer, `${modelName}-tokenizer.json`, 'text'),
    ]);

    return {
      modelBytes: new Uint8Array(modelBytes),
      tokenizerJson,
      config: modelConfig,
    };
  }
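
  // Illustrative usage (model name comes from the MODELS registry above):
  //
  //   const { modelBytes, tokenizerJson, config } = await loader.loadModel('bge-small-en-v1.5');
  //   console.log(config.dimension); // 384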

  /**
   * Load model from custom URLs.
   * Note: custom URLs are not cached (no cache key is derived for them).
   * @param {string} modelUrl - URL to ONNX model
   * @param {string} tokenizerUrl - URL to tokenizer.json
   * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
   */
  async loadFromUrls(modelUrl, tokenizerUrl) {
    const [modelBytes, tokenizerJson] = await Promise.all([
      this.fetchWithCache(modelUrl, null, 'arraybuffer'),
      this.fetchWithCache(tokenizerUrl, null, 'text'),
    ]);

    return {
      modelBytes: new Uint8Array(modelBytes),
      tokenizerJson,
    };
  }
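
  // Sketch with placeholder URLs (substitute real ONNX/tokenizer locations):
  //
  //   const { modelBytes, tokenizerJson } = await loader.loadFromUrls(
  //     'https://example.com/model.onnx',
  //     'https://example.com/tokenizer.json'
  //   );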

  /**
   * Load model from local files (Node.js)
   * @param {string} modelPath - Path to ONNX model
   * @param {string} tokenizerPath - Path to tokenizer.json
   * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
   */
  async loadFromFiles(modelPath, tokenizerPath) {
    // Node.js environment
    if (typeof process !== 'undefined' && process.versions?.node) {
      const fs = await import('fs/promises');
      const [modelBytes, tokenizerJson] = await Promise.all([
        fs.readFile(modelPath),
        fs.readFile(tokenizerPath, 'utf8'),
      ]);
      return {
        modelBytes: new Uint8Array(modelBytes),
        tokenizerJson,
      };
    }
    throw new Error('loadFromFiles is only available in Node.js');
  }
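
  // Node.js sketch (the paths below are placeholders):
  //
  //   const loader = new ModelLoader();
  //   const { modelBytes, tokenizerJson } = await loader.loadFromFiles(
  //     './models/model.onnx',
  //     './models/tokenizer.json'
  //   );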

  /**
   * Fetch with optional caching (uses the Cache API in browsers).
   * String cache keys are resolved by the Cache API as URLs relative to the page.
   */
  async fetchWithCache(url, cacheKey, responseType) {
    // Try cache first (browser only)
    if (this.cache && typeof caches !== 'undefined' && cacheKey) {
      try {
        const cache = await caches.open(this.cacheStorage);
        const cached = await cache.match(cacheKey);
        if (cached) {
          console.log(` Cache hit: ${cacheKey}`);
          return responseType === 'arraybuffer'
            ? await cached.arrayBuffer()
            : await cached.text();
        }
      } catch (e) {
        // Cache API not available, continue with fetch
      }
    }

    // Fetch from network
    console.log(` Downloading: ${url}`);
    const response = await this.fetchWithProgress(url);

    if (!response.ok) {
      throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`);
    }

    // Cache the response (browser only). Clone before reading the body,
    // since a Response body can only be consumed once; the clone is only
    // created when it will actually be stored.
    if (this.cache && typeof caches !== 'undefined' && cacheKey) {
      try {
        const cache = await caches.open(this.cacheStorage);
        await cache.put(cacheKey, response.clone());
      } catch (e) {
        // Cache write failed (quota, opaque response, ...), continue
      }
    }

    return responseType === 'arraybuffer'
      ? await response.arrayBuffer()
      : await response.text();
  }
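
  // Cached entries can be inspected with the standard Cache API, e.g.:
  //
  //   const cache = await caches.open('ruvector-models');
  //   console.log(await cache.keys()); // cached Requests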

  /**
   * Fetch with progress reporting
   */
  async fetchWithProgress(url) {
    const response = await fetch(url);

    if (!this.onProgress || !response.body) {
      return response;
    }

    const contentLength = response.headers.get('content-length');
    if (!contentLength) {
      return response;
    }

    // Note: content-length counts encoded (possibly compressed) bytes while
    // the reader yields decoded bytes, so percent is clamped to 100.
    const total = parseInt(contentLength, 10);
    let loaded = 0;

    const reader = response.body.getReader();
    const chunks = [];

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;

      chunks.push(value);
      loaded += value.length;

      this.onProgress({
        loaded,
        total,
        percent: Math.min(100, Math.round((loaded / total) * 100)),
      });
    }

    // Reassemble the streamed chunks into a single body
    const body = new Uint8Array(loaded);
    let position = 0;
    for (const chunk of chunks) {
      body.set(chunk, position);
      position += chunk.length;
    }

    return new Response(body, {
      headers: response.headers,
      status: response.status,
      statusText: response.statusText,
    });
  }
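
  // The onProgress callback receives { loaded, total, percent }, e.g.:
  //
  //   new ModelLoader({
  //     onProgress: (p) => console.log(`${p.loaded}/${p.total} (${p.percent}%)`),
  //   });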

  /**
   * Clear cached models
   */
  async clearCache() {
    if (typeof caches !== 'undefined') {
      await caches.delete(this.cacheStorage);
      console.log('Model cache cleared');
    }
  }

  /**
   * List available models
   */
  static listModels() {
    return Object.entries(MODELS).map(([key, config]) => ({
      id: key,
      ...config,
    }));
  }
}
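
// Illustrative listing of the registry via the static helper:
//
//   for (const m of ModelLoader.listModels()) {
//     console.log(`${m.id}: ${m.dimension}d, ${m.size} - ${m.description}`);
//   }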

/**
 * Quick helper to create an embedder with a pre-configured model
 *
 * @example
 * ```javascript
 * import { createEmbedder } from './loader.js';
 *
 * const embedder = await createEmbedder('all-MiniLM-L6-v2');
 * const embedding = embedder.embedOne("Hello world");
 * ```
 */
export async function createEmbedder(modelName = DEFAULT_MODEL, wasmModule = null) {
  // Import WASM module if not provided
  if (!wasmModule) {
    wasmModule = await import('./pkg/ruvector_onnx_embeddings_wasm.js');
    // Only call default() for web/bundler targets (not nodejs)
    if (typeof wasmModule.default === 'function') {
      await wasmModule.default();
    }
  }

  const loader = new ModelLoader();
  const { modelBytes, tokenizerJson, config } = await loader.loadModel(modelName);

  const embedderConfig = new wasmModule.WasmEmbedderConfig()
    .setMaxLength(config.maxLength)
    .setNormalize(true)
    .setPooling(0); // Mean pooling

  const embedder = wasmModule.WasmEmbedder.withConfig(
    modelBytes,
    tokenizerJson,
    embedderConfig
  );

  return embedder;
}
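
// The second parameter accepts an already-imported WASM module; in that case
// default() is not called here, so the caller is responsible for whatever
// initialization the build target requires. Sketch:
//
//   import * as wasm from './pkg/ruvector_onnx_embeddings_wasm.js';
//   // (initialize wasm first if your target requires it)
//   const embedder = await createEmbedder('bge-small-en-v1.5', wasm);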

/**
 * Quick helper for one-off embedding (loads model, embeds, returns)
 *
 * Note: this loads the model on every call; for repeated work, create an
 * embedder once with createEmbedder() and reuse it.
 *
 * @example
 * ```javascript
 * import { embed } from './loader.js';
 *
 * const embedding = await embed("Hello world");
 * const embeddings = await embed(["Hello", "World"]);
 * ```
 */
export async function embed(text, modelName = DEFAULT_MODEL) {
  const embedder = await createEmbedder(modelName);

  if (Array.isArray(text)) {
    return embedder.embedBatch(text);
  }
  return embedder.embedOne(text);
}

/**
 * Quick helper for similarity comparison
 *
 * Note: like embed(), this loads the model on every call.
 *
 * @example
 * ```javascript
 * import { similarity } from './loader.js';
 *
 * const score = await similarity("I love dogs", "I adore puppies");
 * console.log(score); // ~0.85
 * ```
 */
export async function similarity(text1, text2, modelName = DEFAULT_MODEL) {
  const embedder = await createEmbedder(modelName);
  return embedder.similarity(text1, text2);
}

export default {
  MODELS,
  DEFAULT_MODEL,
  ModelLoader,
  createEmbedder,
  embed,
  similarity,
};
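
// Everything above is also reachable via the default export:
//
//   import loader from './loader.js';
//   const vec = await loader.embed('Hello world');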