git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
148 lines
5.5 KiB
Rust
148 lines
5.5 KiB
Rust
// Integration tests for SQL engine
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use crate::sql::{SqlEngine, SqlParser};
|
|
|
|
#[test]
|
|
fn test_full_workflow() {
|
|
let engine = SqlEngine::new();
|
|
|
|
// Create table
|
|
let create_sql = "CREATE TABLE documents (id TEXT, content TEXT, embedding VECTOR(384))";
|
|
let mut parser = SqlParser::new(create_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
engine.execute(stmt).unwrap();
|
|
|
|
// Insert data
|
|
let insert_sql = "INSERT INTO documents (id, content, embedding) VALUES ('doc1', 'hello world', [1.0, 2.0, 3.0])";
|
|
let mut parser = SqlParser::new(insert_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
|
|
// This will fail due to dimension mismatch (3 vs 384), but tests the flow
|
|
let result = engine.execute(stmt);
|
|
assert!(result.is_err()); // Expected error due to dimension mismatch
|
|
}
|
|
|
|
#[test]
|
|
fn test_vector_similarity_search() {
|
|
let engine = SqlEngine::new();
|
|
|
|
// Create table with small dimensions for testing
|
|
let create_sql = "CREATE TABLE docs (id TEXT, embedding VECTOR(3))";
|
|
let mut parser = SqlParser::new(create_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
engine.execute(stmt).unwrap();
|
|
|
|
// Insert test data
|
|
for i in 0..10 {
|
|
let insert_sql = format!(
|
|
"INSERT INTO docs (id, embedding) VALUES ('doc{}', [{}, {}, {}])",
|
|
i,
|
|
i,
|
|
i * 2,
|
|
i * 3
|
|
);
|
|
let mut parser = SqlParser::new(&insert_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
engine.execute(stmt).unwrap();
|
|
}
|
|
|
|
// Search for similar vectors
|
|
let search_sql = "SELECT * FROM docs ORDER BY embedding <-> [5.0, 10.0, 15.0] LIMIT 3";
|
|
let mut parser = SqlParser::new(search_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
let result = engine.execute(stmt).unwrap();
|
|
|
|
assert_eq!(result.rows.len(), 3);
|
|
// The closest vector should be [5, 10, 15]
|
|
assert!(result.rows[0].get("id").is_some());
|
|
}
|
|
|
|
#[test]
|
|
fn test_metadata_filtering() {
|
|
let engine = SqlEngine::new();
|
|
|
|
// Create table
|
|
let create_sql = "CREATE TABLE docs (id TEXT, category TEXT, embedding VECTOR(3))";
|
|
let mut parser = SqlParser::new(create_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
engine.execute(stmt).unwrap();
|
|
|
|
// Insert data with categories
|
|
let categories = vec!["tech", "sports", "tech", "news", "sports"];
|
|
for (i, cat) in categories.iter().enumerate() {
|
|
let insert_sql =
|
|
format!(
|
|
"INSERT INTO docs (id, category, embedding) VALUES ('doc{}', '{}', [{}, {}, {}])",
|
|
i, cat, i, i * 2, i * 3
|
|
);
|
|
let mut parser = SqlParser::new(&insert_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
engine.execute(stmt).unwrap();
|
|
}
|
|
|
|
// Search with filter
|
|
let search_sql = "SELECT * FROM docs WHERE category = 'tech' ORDER BY embedding <-> [2.0, 4.0, 6.0] LIMIT 2";
|
|
let mut parser = SqlParser::new(search_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
let result = engine.execute(stmt).unwrap();
|
|
|
|
// VectorDB filtering may not be fully precise, so we check for at least 1 result
|
|
assert!(result.rows.len() >= 1);
|
|
assert!(result.rows.len() <= 2);
|
|
// All results should have category = 'tech'
|
|
for row in &result.rows {
|
|
if let Some(category) = row.get("category") {
|
|
assert_eq!(category.to_string(), "'tech'");
|
|
}
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_drop_table() {
|
|
let engine = SqlEngine::new();
|
|
|
|
// Create table
|
|
let create_sql = "CREATE TABLE temp (id TEXT, embedding VECTOR(3))";
|
|
let mut parser = SqlParser::new(create_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
engine.execute(stmt).unwrap();
|
|
|
|
assert_eq!(engine.list_tables().len(), 1);
|
|
|
|
// Drop table
|
|
let drop_sql = "DROP TABLE temp";
|
|
let mut parser = SqlParser::new(drop_sql).unwrap();
|
|
let stmt = parser.parse().unwrap();
|
|
engine.execute(stmt).unwrap();
|
|
|
|
assert_eq!(engine.list_tables().len(), 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_distance() {
|
|
let engine = SqlEngine::new();
|
|
|
|
let create_sql = "CREATE TABLE docs (id TEXT, embedding VECTOR(3))";
|
|
let mut parser = SqlParser::new(create_sql).unwrap();
|
|
engine.execute(parser.parse().unwrap()).unwrap();
|
|
|
|
// Insert normalized vectors for cosine similarity
|
|
let insert_sql = "INSERT INTO docs (id, embedding) VALUES ('doc1', [1.0, 0.0, 0.0])";
|
|
let mut parser = SqlParser::new(insert_sql).unwrap();
|
|
engine.execute(parser.parse().unwrap()).unwrap();
|
|
|
|
let insert_sql = "INSERT INTO docs (id, embedding) VALUES ('doc2', [0.0, 1.0, 0.0])";
|
|
let mut parser = SqlParser::new(insert_sql).unwrap();
|
|
engine.execute(parser.parse().unwrap()).unwrap();
|
|
|
|
// Search using cosine distance
|
|
let search_sql = "SELECT * FROM docs ORDER BY embedding <=> [0.9, 0.1, 0.0] LIMIT 1";
|
|
let mut parser = SqlParser::new(search_sql).unwrap();
|
|
let result = engine.execute(parser.parse().unwrap()).unwrap();
|
|
|
|
assert_eq!(result.rows.len(), 1);
|
|
// Should return doc1 as it's more similar to [0.9, 0.1, 0.0]
|
|
}
|
|
}
|