Files
wifi-densepose/crates/ruvector-core/tests/embeddings_test.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

308 lines
10 KiB
Rust

//! Integration tests for embedding providers
use ruvector_core::embeddings::{ApiEmbedding, EmbeddingProvider, HashEmbedding};
use ruvector_core::{types::DbOptions, AgenticDB};
use std::sync::Arc;
use tempfile::tempdir;
#[test]
fn test_hash_embedding_provider() {
let provider = HashEmbedding::new(128);
// Test basic embedding
let emb1 = provider.embed("hello world").unwrap();
assert_eq!(emb1.len(), 128);
// Test consistency
let emb2 = provider.embed("hello world").unwrap();
assert_eq!(emb1, emb2, "Same text should produce same embedding");
// Test different text produces different embeddings
let emb3 = provider.embed("goodbye world").unwrap();
assert_ne!(
emb1, emb3,
"Different text should produce different embeddings"
);
// Test normalization
let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(norm - 1.0).abs() < 1e-5,
"Embedding should be normalized to unit length"
);
// Test provider info
assert_eq!(provider.dimensions(), 128);
assert!(provider.name().contains("Hash"));
}
#[test]
fn test_agenticdb_with_hash_embeddings() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 128;
// Create AgenticDB with default hash embeddings
let db = AgenticDB::new(options).unwrap();
assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");
// Test storing a reflexion episode
let episode_id = db
.store_episode(
"Solve a math problem".to_string(),
vec!["read problem".to_string(), "calculate".to_string()],
vec!["got answer 42".to_string()],
"Should have shown intermediate steps".to_string(),
)
.unwrap();
// Test retrieving similar episodes
let episodes = db
.retrieve_similar_episodes("math problem solving", 5)
.unwrap();
assert!(!episodes.is_empty());
assert_eq!(episodes[0].id, episode_id);
}
#[test]
fn test_agenticdb_with_custom_hash_provider() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 256;
// Create custom hash provider
let provider = Arc::new(HashEmbedding::new(256));
// Create AgenticDB with custom provider
let db = AgenticDB::with_embedding_provider(options, provider).unwrap();
assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");
// Test creating a skill
let mut params = std::collections::HashMap::new();
params.insert("input".to_string(), "string".to_string());
let skill_id = db
.create_skill(
"Parse JSON".to_string(),
"Parse JSON from string".to_string(),
params,
vec!["json.parse()".to_string()],
)
.unwrap();
// Search for skills
let skills = db.search_skills("parse json data", 5).unwrap();
assert!(!skills.is_empty());
assert_eq!(skills[0].id, skill_id);
}
#[test]
fn test_dimension_mismatch_validation() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 128;
// Try to create with mismatched dimensions
let provider = Arc::new(HashEmbedding::new(256)); // Different from options
let result = AgenticDB::with_embedding_provider(options, provider);
assert!(result.is_err(), "Should fail when dimensions don't match");
if let Err(err) = result {
assert!(
err.to_string().contains("do not match"),
"Error should mention dimension mismatch"
);
}
}
#[test]
fn test_api_embedding_provider_construction() {
// Test OpenAI provider construction
let openai_small = ApiEmbedding::openai("sk-test", "text-embedding-3-small");
assert_eq!(openai_small.dimensions(), 1536);
assert_eq!(openai_small.name(), "ApiEmbedding");
let openai_large = ApiEmbedding::openai("sk-test", "text-embedding-3-large");
assert_eq!(openai_large.dimensions(), 3072);
// Test Cohere provider construction
let cohere = ApiEmbedding::cohere("co-test", "embed-english-v3.0");
assert_eq!(cohere.dimensions(), 1024);
// Test Voyage provider construction
let voyage = ApiEmbedding::voyage("vo-test", "voyage-2");
assert_eq!(voyage.dimensions(), 1024);
let voyage_large = ApiEmbedding::voyage("vo-test", "voyage-large-2");
assert_eq!(voyage_large.dimensions(), 1536);
}
#[test]
#[ignore] // Requires API key and network access
fn test_api_embedding_openai() {
let api_key = std::env::var("OPENAI_API_KEY")
.expect("OPENAI_API_KEY environment variable required for this test");
let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 1536);
// Check that embeddings are different for different texts
let embedding2 = provider.embed("goodbye world").unwrap();
assert_ne!(embedding, embedding2);
}
#[test]
#[ignore] // Requires API key and network access
fn test_agenticdb_with_openai_embeddings() {
let api_key = std::env::var("OPENAI_API_KEY")
.expect("OPENAI_API_KEY environment variable required for this test");
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 1536; // OpenAI text-embedding-3-small dimensions
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
let db = AgenticDB::with_embedding_provider(options, provider).unwrap();
assert_eq!(db.embedding_provider_name(), "ApiEmbedding");
// Test with real semantic embeddings
let _episode1_id = db
.store_episode(
"Solve calculus problem".to_string(),
vec![
"identify function".to_string(),
"take derivative".to_string(),
],
vec!["computed derivative".to_string()],
"Should explain chain rule application".to_string(),
)
.unwrap();
let _episode2_id = db
.store_episode(
"Solve algebra problem".to_string(),
vec!["simplify equation".to_string(), "solve for x".to_string()],
vec!["found x = 5".to_string()],
"Should show all steps".to_string(),
)
.unwrap();
// Search with semantic query - should find calculus episode first
let episodes = db
.retrieve_similar_episodes("derivative calculation", 2)
.unwrap();
assert!(!episodes.is_empty());
// With real embeddings, "derivative" should match calculus better than algebra
println!(
"Found episodes: {:?}",
episodes.iter().map(|e| &e.task).collect::<Vec<_>>()
);
}
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_candle_embedding_provider() {
use ruvector_core::CandleEmbedding;
let provider =
CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap();
assert_eq!(provider.dimensions(), 384);
assert_eq!(provider.name(), "CandleEmbedding (transformer)");
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 384);
// Check normalization
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-3, "Embedding should be normalized");
// Test semantic similarity
let emb_dog = provider.embed("dog").unwrap();
let emb_cat = provider.embed("cat").unwrap();
let emb_car = provider.embed("car").unwrap();
// Cosine similarity
let similarity_dog_cat: f32 = emb_dog.iter().zip(emb_cat.iter()).map(|(a, b)| a * b).sum();
let similarity_dog_car: f32 = emb_dog.iter().zip(emb_car.iter()).map(|(a, b)| a * b).sum();
// "dog" and "cat" should be more similar than "dog" and "car"
assert!(
similarity_dog_cat > similarity_dog_car,
"Semantic embeddings should show dog-cat more similar than dog-car"
);
}
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_agenticdb_with_candle_embeddings() {
use ruvector_core::CandleEmbedding;
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 384;
let provider = Arc::new(
CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap(),
);
let db = AgenticDB::with_embedding_provider(options, provider).unwrap();
assert_eq!(
db.embedding_provider_name(),
"CandleEmbedding (transformer)"
);
// Test with real semantic embeddings
let skill1_id = db
.create_skill(
"File I/O".to_string(),
"Read and write files to disk".to_string(),
std::collections::HashMap::new(),
vec![
"open()".to_string(),
"read()".to_string(),
"write()".to_string(),
],
)
.unwrap();
let skill2_id = db
.create_skill(
"Network I/O".to_string(),
"Send and receive data over network".to_string(),
std::collections::HashMap::new(),
vec![
"connect()".to_string(),
"send()".to_string(),
"recv()".to_string(),
],
)
.unwrap();
// Search with semantic query
let skills = db.search_skills("reading files from storage", 2).unwrap();
assert!(!skills.is_empty());
// With real embeddings, file I/O should match better
println!(
"Found skills: {:?}",
skills.iter().map(|s| &s.name).collect::<Vec<_>>()
);
}