Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
220
vendor/ruvector/crates/rvlite/src/sql/ast.rs
vendored
Normal file
220
vendor/ruvector/crates/rvlite/src/sql/ast.rs
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
// AST types for SQL statements
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// SQL statement types
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum SqlStatement {
|
||||
/// CREATE TABLE name (columns)
|
||||
CreateTable { name: String, columns: Vec<Column> },
|
||||
/// INSERT INTO table (columns) VALUES (values)
|
||||
Insert {
|
||||
table: String,
|
||||
columns: Vec<String>,
|
||||
values: Vec<Value>,
|
||||
},
|
||||
/// SELECT columns FROM table WHERE condition ORDER BY ... LIMIT k
|
||||
Select {
|
||||
columns: Vec<SelectColumn>,
|
||||
from: String,
|
||||
where_clause: Option<Expression>,
|
||||
order_by: Option<OrderBy>,
|
||||
limit: Option<usize>,
|
||||
},
|
||||
/// DROP TABLE name
|
||||
Drop { table: String },
|
||||
}
|
||||
|
||||
/// Column definition for CREATE TABLE
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Column {
|
||||
pub name: String,
|
||||
pub data_type: DataType,
|
||||
}
|
||||
|
||||
/// Data types supported in SQL
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum DataType {
|
||||
/// TEXT type for strings
|
||||
Text,
|
||||
/// INTEGER type
|
||||
Integer,
|
||||
/// REAL/FLOAT type
|
||||
Real,
|
||||
/// VECTOR(dimensions) type for vector data
|
||||
Vector(usize),
|
||||
}
|
||||
|
||||
/// Column selector in SELECT
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum SelectColumn {
|
||||
/// SELECT *
|
||||
Wildcard,
|
||||
/// SELECT column_name
|
||||
Name(String),
|
||||
/// SELECT expression AS alias
|
||||
Expression {
|
||||
expr: Expression,
|
||||
alias: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// SQL expressions
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Expression {
|
||||
/// Column reference
|
||||
Column(String),
|
||||
/// Literal value
|
||||
Literal(Value),
|
||||
/// Binary operation (e.g., a = b, a > b)
|
||||
BinaryOp {
|
||||
left: Box<Expression>,
|
||||
op: BinaryOperator,
|
||||
right: Box<Expression>,
|
||||
},
|
||||
/// Logical AND
|
||||
And(Box<Expression>, Box<Expression>),
|
||||
/// Logical OR
|
||||
Or(Box<Expression>, Box<Expression>),
|
||||
/// NOT expression
|
||||
Not(Box<Expression>),
|
||||
/// Function call
|
||||
Function { name: String, args: Vec<Expression> },
|
||||
/// Vector literal [1.0, 2.0, 3.0]
|
||||
VectorLiteral(Vec<f32>),
|
||||
/// Distance operation: column <-> vector
|
||||
/// Used for ORDER BY embedding <-> $vector
|
||||
Distance {
|
||||
column: String,
|
||||
metric: DistanceMetric,
|
||||
vector: Vec<f32>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Binary operators
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum BinaryOperator {
|
||||
/// =
|
||||
Eq,
|
||||
/// !=
|
||||
NotEq,
|
||||
/// >
|
||||
Gt,
|
||||
/// >=
|
||||
GtEq,
|
||||
/// <
|
||||
Lt,
|
||||
/// <=
|
||||
LtEq,
|
||||
/// LIKE
|
||||
Like,
|
||||
}
|
||||
|
||||
/// Distance metrics for vector similarity
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum DistanceMetric {
|
||||
/// L2 distance: <->
|
||||
L2,
|
||||
/// Cosine distance: <=>
|
||||
Cosine,
|
||||
/// Dot product: <#>
|
||||
DotProduct,
|
||||
}
|
||||
|
||||
/// ORDER BY clause
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct OrderBy {
|
||||
pub expression: Expression,
|
||||
pub direction: OrderDirection,
|
||||
}
|
||||
|
||||
/// Sort direction
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum OrderDirection {
|
||||
Asc,
|
||||
Desc,
|
||||
}
|
||||
|
||||
/// SQL values
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Value {
|
||||
Null,
|
||||
Text(String),
|
||||
Integer(i64),
|
||||
Real(f64),
|
||||
Vector(Vec<f32>),
|
||||
Boolean(bool),
|
||||
}
|
||||
|
||||
impl Value {
|
||||
/// Convert to JSON value for metadata storage
|
||||
pub fn to_json(&self) -> serde_json::Value {
|
||||
match self {
|
||||
Value::Null => serde_json::Value::Null,
|
||||
Value::Text(s) => serde_json::Value::String(s.clone()),
|
||||
Value::Integer(i) => serde_json::Value::Number((*i).into()),
|
||||
Value::Real(f) => {
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(*f).unwrap_or(0.into()))
|
||||
}
|
||||
Value::Vector(v) => serde_json::Value::Array(
|
||||
v.iter()
|
||||
.map(|f| {
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(*f as f64).unwrap_or(0.into()),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
Value::Boolean(b) => serde_json::Value::Bool(*b),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from JSON value
|
||||
pub fn from_json(json: &serde_json::Value) -> Self {
|
||||
match json {
|
||||
serde_json::Value::Null => Value::Null,
|
||||
serde_json::Value::Bool(b) => Value::Boolean(*b),
|
||||
serde_json::Value::Number(n) => {
|
||||
if let Some(i) = n.as_i64() {
|
||||
Value::Integer(i)
|
||||
} else if let Some(f) = n.as_f64() {
|
||||
Value::Real(f)
|
||||
} else {
|
||||
Value::Null
|
||||
}
|
||||
}
|
||||
serde_json::Value::String(s) => Value::Text(s.clone()),
|
||||
serde_json::Value::Array(arr) => {
|
||||
// Try to parse as vector
|
||||
let floats: Option<Vec<f32>> =
|
||||
arr.iter().map(|v| v.as_f64().map(|f| f as f32)).collect();
|
||||
|
||||
if let Some(vec) = floats {
|
||||
Value::Vector(vec)
|
||||
} else {
|
||||
Value::Null
|
||||
}
|
||||
}
|
||||
serde_json::Value::Object(_) => Value::Null,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Value {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Value::Null => write!(f, "NULL"),
|
||||
Value::Text(s) => write!(f, "'{}'", s),
|
||||
Value::Integer(i) => write!(f, "{}", i),
|
||||
Value::Real(r) => write!(f, "{}", r),
|
||||
Value::Vector(v) => write!(
|
||||
f,
|
||||
"[{}]",
|
||||
v.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
),
|
||||
Value::Boolean(b) => write!(f, "{}", b),
|
||||
}
|
||||
}
|
||||
}
|
||||
561
vendor/ruvector/crates/rvlite/src/sql/executor.rs
vendored
Normal file
561
vendor/ruvector/crates/rvlite/src/sql/executor.rs
vendored
Normal file
@@ -0,0 +1,561 @@
|
||||
// SQL executor that integrates with ruvector-core VectorDB
|
||||
use super::ast::*;
|
||||
use crate::{ErrorKind, RvLiteError};
|
||||
use parking_lot::RwLock;
|
||||
use ruvector_core::{SearchQuery, VectorDB, VectorEntry};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Table schema definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TableSchema {
|
||||
pub name: String,
|
||||
pub columns: Vec<Column>,
|
||||
pub vector_column: Option<String>,
|
||||
pub vector_dimensions: Option<usize>,
|
||||
}
|
||||
|
||||
impl TableSchema {
|
||||
/// Find the vector column in the schema
|
||||
fn find_vector_column(&self) -> Option<(String, usize)> {
|
||||
for col in &self.columns {
|
||||
if let DataType::Vector(dims) = col.data_type {
|
||||
return Some((col.name.clone(), dims));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Validate that columns match the schema
|
||||
fn validate_columns(&self, columns: &[String]) -> Result<(), RvLiteError> {
|
||||
for col in columns {
|
||||
if !self.columns.iter().any(|c| &c.name == col) {
|
||||
return Err(RvLiteError {
|
||||
message: format!("Column '{}' not found in table '{}'", col, self.name),
|
||||
kind: ErrorKind::SqlError,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get column data type
|
||||
fn get_column_type(&self, name: &str) -> Option<&DataType> {
|
||||
self.columns
|
||||
.iter()
|
||||
.find(|c| c.name == name)
|
||||
.map(|c| &c.data_type)
|
||||
}
|
||||
}
|
||||
|
||||
/// SQL execution result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecutionResult {
|
||||
pub rows: Vec<HashMap<String, Value>>,
|
||||
pub rows_affected: usize,
|
||||
}
|
||||
|
||||
/// SQL Engine that manages tables and executes queries
|
||||
pub struct SqlEngine {
|
||||
/// Table schemas
|
||||
schemas: RwLock<HashMap<String, TableSchema>>,
|
||||
/// Vector databases (one per table)
|
||||
databases: RwLock<HashMap<String, VectorDB>>,
|
||||
}
|
||||
|
||||
impl SqlEngine {
|
||||
/// Create a new SQL engine
|
||||
pub fn new() -> Self {
|
||||
SqlEngine {
|
||||
schemas: RwLock::new(HashMap::new()),
|
||||
databases: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a SQL statement
|
||||
pub fn execute(&self, statement: SqlStatement) -> Result<ExecutionResult, RvLiteError> {
|
||||
match statement {
|
||||
SqlStatement::CreateTable { name, columns } => self.create_table(name, columns),
|
||||
SqlStatement::Insert {
|
||||
table,
|
||||
columns,
|
||||
values,
|
||||
} => self.insert(table, columns, values),
|
||||
SqlStatement::Select {
|
||||
columns,
|
||||
from,
|
||||
where_clause,
|
||||
order_by,
|
||||
limit,
|
||||
} => self.select(columns, from, where_clause, order_by, limit),
|
||||
SqlStatement::Drop { table } => self.drop_table(table),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_table(
|
||||
&self,
|
||||
name: String,
|
||||
columns: Vec<Column>,
|
||||
) -> Result<ExecutionResult, RvLiteError> {
|
||||
let mut schemas = self.schemas.write();
|
||||
|
||||
if schemas.contains_key(&name) {
|
||||
return Err(RvLiteError {
|
||||
message: format!("Table '{}' already exists", name),
|
||||
kind: ErrorKind::SqlError,
|
||||
});
|
||||
}
|
||||
|
||||
// Find vector column
|
||||
let (vector_column, vector_dimensions) = columns
|
||||
.iter()
|
||||
.find_map(|col| {
|
||||
if let DataType::Vector(dims) = col.data_type {
|
||||
Some((col.name.clone(), dims))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| RvLiteError {
|
||||
message: "Table must have at least one VECTOR column".to_string(),
|
||||
kind: ErrorKind::SqlError,
|
||||
})?;
|
||||
|
||||
let schema = TableSchema {
|
||||
name: name.clone(),
|
||||
columns,
|
||||
vector_column: Some(vector_column),
|
||||
vector_dimensions: Some(vector_dimensions),
|
||||
};
|
||||
|
||||
// Create vector database for this table
|
||||
let db_options = ruvector_core::types::DbOptions {
|
||||
dimensions: vector_dimensions,
|
||||
distance_metric: ruvector_core::DistanceMetric::Cosine,
|
||||
storage_path: "memory://".to_string(),
|
||||
hnsw_config: None,
|
||||
quantization: None,
|
||||
};
|
||||
|
||||
let db = VectorDB::new(db_options).map_err(|e| RvLiteError {
|
||||
message: format!("Failed to create vector database: {}", e),
|
||||
kind: ErrorKind::VectorError,
|
||||
})?;
|
||||
|
||||
let mut databases = self.databases.write();
|
||||
databases.insert(name.clone(), db);
|
||||
schemas.insert(name, schema);
|
||||
|
||||
Ok(ExecutionResult {
|
||||
rows: Vec::new(),
|
||||
rows_affected: 0,
|
||||
})
|
||||
}
|
||||
|
||||
fn insert(
|
||||
&self,
|
||||
table: String,
|
||||
columns: Vec<String>,
|
||||
values: Vec<Value>,
|
||||
) -> Result<ExecutionResult, RvLiteError> {
|
||||
let schemas = self.schemas.read();
|
||||
let schema = schemas.get(&table).ok_or_else(|| RvLiteError {
|
||||
message: format!("Table '{}' not found", table),
|
||||
kind: ErrorKind::SqlError,
|
||||
})?;
|
||||
|
||||
// Validate columns
|
||||
schema.validate_columns(&columns)?;
|
||||
|
||||
if columns.len() != values.len() {
|
||||
return Err(RvLiteError {
|
||||
message: format!(
|
||||
"Column count ({}) does not match value count ({})",
|
||||
columns.len(),
|
||||
values.len()
|
||||
),
|
||||
kind: ErrorKind::SqlError,
|
||||
});
|
||||
}
|
||||
|
||||
// Extract vector and metadata
|
||||
let mut vector: Option<Vec<f32>> = None;
|
||||
let mut metadata = HashMap::new();
|
||||
let mut id: Option<String> = None;
|
||||
|
||||
for (col, val) in columns.iter().zip(values.iter()) {
|
||||
if let Some(DataType::Vector(_)) = schema.get_column_type(col) {
|
||||
if let Value::Vector(v) = val {
|
||||
vector = Some(v.clone());
|
||||
} else {
|
||||
return Err(RvLiteError {
|
||||
message: format!("Expected vector value for column '{}'", col),
|
||||
kind: ErrorKind::SqlError,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Store as metadata
|
||||
metadata.insert(col.clone(), val.to_json());
|
||||
|
||||
// Use 'id' column as vector ID if present
|
||||
if col == "id" {
|
||||
if let Value::Text(s) = val {
|
||||
id = Some(s.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let vector = vector.ok_or_else(|| RvLiteError {
|
||||
message: "No vector value provided".to_string(),
|
||||
kind: ErrorKind::SqlError,
|
||||
})?;
|
||||
|
||||
// Validate vector dimensions
|
||||
if let Some(expected_dims) = schema.vector_dimensions {
|
||||
if vector.len() != expected_dims {
|
||||
return Err(RvLiteError {
|
||||
message: format!(
|
||||
"Vector dimension mismatch: expected {}, got {}",
|
||||
expected_dims,
|
||||
vector.len()
|
||||
),
|
||||
kind: ErrorKind::SqlError,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Insert into vector database
|
||||
let entry = VectorEntry {
|
||||
id,
|
||||
vector,
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
let databases = self.databases.read();
|
||||
let db = databases.get(&table).ok_or_else(|| RvLiteError {
|
||||
message: format!("Database for table '{}' not found", table),
|
||||
kind: ErrorKind::SqlError,
|
||||
})?;
|
||||
|
||||
db.insert(entry).map_err(|e| RvLiteError {
|
||||
message: format!("Failed to insert: {}", e),
|
||||
kind: ErrorKind::VectorError,
|
||||
})?;
|
||||
|
||||
Ok(ExecutionResult {
|
||||
rows: Vec::new(),
|
||||
rows_affected: 1,
|
||||
})
|
||||
}
|
||||
|
||||
fn select(
|
||||
&self,
|
||||
_columns: Vec<SelectColumn>,
|
||||
from: String,
|
||||
where_clause: Option<Expression>,
|
||||
order_by: Option<OrderBy>,
|
||||
limit: Option<usize>,
|
||||
) -> Result<ExecutionResult, RvLiteError> {
|
||||
let schemas = self.schemas.read();
|
||||
let schema = schemas.get(&from).ok_or_else(|| RvLiteError {
|
||||
message: format!("Table '{}' not found", from),
|
||||
kind: ErrorKind::SqlError,
|
||||
})?;
|
||||
|
||||
let databases = self.databases.read();
|
||||
let db = databases.get(&from).ok_or_else(|| RvLiteError {
|
||||
message: format!("Database for table '{}' not found", from),
|
||||
kind: ErrorKind::SqlError,
|
||||
})?;
|
||||
|
||||
// Handle vector similarity search
|
||||
if let Some(order_by) = order_by {
|
||||
if let Expression::Distance {
|
||||
column: _,
|
||||
metric: _,
|
||||
vector,
|
||||
} = order_by.expression
|
||||
{
|
||||
let k = limit.unwrap_or(10);
|
||||
|
||||
// Build filter from WHERE clause
|
||||
let filter = if let Some(where_expr) = where_clause {
|
||||
Some(self.build_filter(where_expr)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let query = SearchQuery {
|
||||
vector,
|
||||
k,
|
||||
filter,
|
||||
ef_search: None,
|
||||
};
|
||||
|
||||
let results = db.search(query).map_err(|e| RvLiteError {
|
||||
message: format!("Search failed: {}", e),
|
||||
kind: ErrorKind::VectorError,
|
||||
})?;
|
||||
|
||||
// Convert results to rows
|
||||
let rows: Vec<HashMap<String, Value>> = results
|
||||
.into_iter()
|
||||
.map(|result| {
|
||||
let mut row = HashMap::new();
|
||||
|
||||
// Add vector if present
|
||||
if let Some(vec_col) = &schema.vector_column {
|
||||
if let Some(vector) = result.vector {
|
||||
row.insert(vec_col.clone(), Value::Vector(vector));
|
||||
}
|
||||
}
|
||||
|
||||
// Add metadata
|
||||
if let Some(metadata) = result.metadata {
|
||||
for (key, val) in metadata {
|
||||
row.insert(key, Value::from_json(&val));
|
||||
}
|
||||
}
|
||||
|
||||
// Add distance score
|
||||
row.insert("_distance".to_string(), Value::Real(result.score as f64));
|
||||
|
||||
row
|
||||
})
|
||||
.collect();
|
||||
|
||||
return Ok(ExecutionResult {
|
||||
rows,
|
||||
rows_affected: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Non-vector query - return all rows (scan all vectors)
|
||||
// This is essentially a table scan through the vector database
|
||||
let k = limit.unwrap_or(1000); // Default to 1000 rows max
|
||||
|
||||
// Create a zero vector for exhaustive search
|
||||
let dims = schema.vector_dimensions.unwrap_or(3);
|
||||
let query_vector = vec![0.0f32; dims];
|
||||
|
||||
// Build filter from WHERE clause
|
||||
let filter = if let Some(where_expr) = where_clause {
|
||||
Some(self.build_filter(where_expr)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let query = SearchQuery {
|
||||
vector: query_vector,
|
||||
k,
|
||||
filter,
|
||||
ef_search: None,
|
||||
};
|
||||
|
||||
let results = db.search(query).map_err(|e| RvLiteError {
|
||||
message: format!("Search failed: {}", e),
|
||||
kind: ErrorKind::VectorError,
|
||||
})?;
|
||||
|
||||
// Convert results to rows
|
||||
let rows: Vec<HashMap<String, Value>> = results
|
||||
.into_iter()
|
||||
.map(|result| {
|
||||
let mut row = HashMap::new();
|
||||
|
||||
// Add vector if present
|
||||
if let Some(vec_col) = &schema.vector_column {
|
||||
if let Some(vector) = result.vector {
|
||||
row.insert(vec_col.clone(), Value::Vector(vector));
|
||||
}
|
||||
}
|
||||
|
||||
// Add metadata
|
||||
if let Some(metadata) = result.metadata {
|
||||
for (key, val) in metadata {
|
||||
row.insert(key, Value::from_json(&val));
|
||||
}
|
||||
}
|
||||
|
||||
row
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ExecutionResult {
|
||||
rows,
|
||||
rows_affected: 0,
|
||||
})
|
||||
}
|
||||
|
||||
fn drop_table(&self, table: String) -> Result<ExecutionResult, RvLiteError> {
|
||||
let mut schemas = self.schemas.write();
|
||||
let mut databases = self.databases.write();
|
||||
|
||||
schemas.remove(&table).ok_or_else(|| RvLiteError {
|
||||
message: format!("Table '{}' not found", table),
|
||||
kind: ErrorKind::SqlError,
|
||||
})?;
|
||||
|
||||
databases.remove(&table);
|
||||
|
||||
Ok(ExecutionResult {
|
||||
rows: Vec::new(),
|
||||
rows_affected: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build metadata filter from WHERE expression
|
||||
fn build_filter(
|
||||
&self,
|
||||
expr: Expression,
|
||||
) -> Result<HashMap<String, serde_json::Value>, RvLiteError> {
|
||||
let mut filter = HashMap::new();
|
||||
|
||||
match expr {
|
||||
Expression::BinaryOp { left, op, right } => {
|
||||
if let (Expression::Column(col), Expression::Literal(val)) = (*left, *right) {
|
||||
if op == BinaryOperator::Eq {
|
||||
filter.insert(col, val.to_json());
|
||||
} else {
|
||||
return Err(RvLiteError {
|
||||
message: "Only equality filters supported in WHERE clause".to_string(),
|
||||
kind: ErrorKind::NotImplemented,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Expression::And(left, right) => {
|
||||
let left_filter = self.build_filter(*left)?;
|
||||
let right_filter = self.build_filter(*right)?;
|
||||
filter.extend(left_filter);
|
||||
filter.extend(right_filter);
|
||||
}
|
||||
_ => {
|
||||
return Err(RvLiteError {
|
||||
message: "Unsupported WHERE clause expression".to_string(),
|
||||
kind: ErrorKind::NotImplemented,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(filter)
|
||||
}
|
||||
|
||||
/// List all tables
|
||||
pub fn list_tables(&self) -> Vec<String> {
|
||||
self.schemas.read().keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get table schema
|
||||
pub fn get_schema(&self, table: &str) -> Option<TableSchema> {
|
||||
self.schemas.read().get(table).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SqlEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a column definition.
    fn column(name: &str, data_type: DataType) -> Column {
        Column {
            name: name.to_string(),
            data_type,
        }
    }

    /// Creating a table and inserting one row reports one affected row.
    #[test]
    fn test_create_and_insert() {
        let engine = SqlEngine::new();

        engine
            .execute(SqlStatement::CreateTable {
                name: "docs".to_string(),
                columns: vec![
                    column("id", DataType::Text),
                    column("content", DataType::Text),
                    column("embedding", DataType::Vector(3)),
                ],
            })
            .unwrap();

        let column_names = ["id", "content", "embedding"];
        let result = engine
            .execute(SqlStatement::Insert {
                table: "docs".to_string(),
                columns: column_names.iter().map(|name| name.to_string()).collect(),
                values: vec![
                    Value::Text("1".to_string()),
                    Value::Text("hello".to_string()),
                    Value::Vector(vec![1.0, 2.0, 3.0]),
                ],
            })
            .unwrap();

        assert_eq!(result.rows_affected, 1);
    }

    /// A distance-ordered search with `LIMIT 3` over five rows returns three rows.
    #[test]
    fn test_vector_search() {
        let engine = SqlEngine::new();

        engine
            .execute(SqlStatement::CreateTable {
                name: "docs".to_string(),
                columns: vec![
                    column("id", DataType::Text),
                    column("embedding", DataType::Vector(3)),
                ],
            })
            .unwrap();

        for i in 0..5 {
            let scale = i as f32;
            engine
                .execute(SqlStatement::Insert {
                    table: "docs".to_string(),
                    columns: vec!["id".to_string(), "embedding".to_string()],
                    values: vec![
                        Value::Text(format!("{}", i)),
                        Value::Vector(vec![scale, scale * 2.0, scale * 3.0]),
                    ],
                })
                .unwrap();
        }

        let ordering = OrderBy {
            expression: Expression::Distance {
                column: "embedding".to_string(),
                metric: DistanceMetric::L2,
                vector: vec![2.0, 4.0, 6.0],
            },
            direction: OrderDirection::Asc,
        };
        let result = engine
            .execute(SqlStatement::Select {
                columns: vec![SelectColumn::Wildcard],
                from: "docs".to_string(),
                where_clause: None,
                order_by: Some(ordering),
                limit: Some(3),
            })
            .unwrap();

        assert_eq!(result.rows.len(), 3);
    }
}
|
||||
13
vendor/ruvector/crates/rvlite/src/sql/mod.rs
vendored
Normal file
13
vendor/ruvector/crates/rvlite/src/sql/mod.rs
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
// SQL query engine module for rvlite
|
||||
// Provides SQL interface for vector database operations with WASM compatibility
|
||||
|
||||
mod ast;
|
||||
mod executor;
|
||||
mod parser;
|
||||
|
||||
pub use ast::*;
|
||||
pub use executor::{ExecutionResult, SqlEngine};
|
||||
pub use parser::{ParseError, SqlParser};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
823
vendor/ruvector/crates/rvlite/src/sql/parser.rs
vendored
Normal file
823
vendor/ruvector/crates/rvlite/src/sql/parser.rs
vendored
Normal file
@@ -0,0 +1,823 @@
|
||||
// Hand-rolled SQL parser for WASM compatibility
|
||||
// Implements recursive descent parsing for vector-specific SQL
|
||||
|
||||
use super::ast::*;
|
||||
use std::fmt;
|
||||
|
||||
/// Error produced while tokenizing or parsing SQL text.
#[derive(Debug, Clone, PartialEq)]
pub struct ParseError {
    /// Human-readable description of the problem.
    pub message: String,
    /// Position recorded by the tokenizer or parser when the error was raised.
    pub position: usize,
}
|
||||
|
||||
impl fmt::Display for ParseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Parse error at position {}: {}",
|
||||
self.position, self.message
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl: lets `ParseError` be used wherever a `std::error::Error` is expected.
impl std::error::Error for ParseError {}
|
||||
|
||||
/// Lexical tokens produced by the tokenizer.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    // ----- Keywords -----
    Select,
    From,
    Where,
    Insert,
    Into,
    Values,
    Create,
    Table,
    Drop,
    /// The two-word keyword `ORDER BY`, emitted as a single token.
    OrderBy,
    Limit,
    And,
    Or,
    Not,
    As,

    // ----- Data type names -----
    Text,
    Integer,
    Real,
    Vector,

    // ----- Comparison operators -----
    /// `=`
    Eq,
    /// `!=`
    NotEq,
    /// `>`
    Gt,
    /// `>=`
    GtEq,
    /// `<`
    Lt,
    /// `<=`
    LtEq,
    /// `LIKE`
    Like,

    // ----- Distance operators -----
    /// `<->`
    L2Distance,
    /// `<=>`
    CosineDistance,
    /// `<#>`
    DotProduct,

    // ----- Delimiters -----
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `[`
    LeftBracket,
    /// `]`
    RightBracket,
    /// `,`
    Comma,
    /// `;`
    Semicolon,
    /// `*`
    Asterisk,

    // ----- Values -----
    /// A name that is not a keyword.
    Identifier(String),
    /// Contents of a single-quoted string, without the quotes.
    StringLiteral(String),
    /// Source text of a numeric literal, not yet converted to a number.
    NumberLiteral(String),

    /// End of input.
    Eof,
}
|
||||
|
||||
/// Lexer state: the input as characters plus a cursor into it.
struct Tokenizer {
    /// The full input text, one `char` per element.
    input: Vec<char>,
    /// Index into `input` of the next character to read.
    position: usize,
}
|
||||
|
||||
impl Tokenizer {
|
||||
fn new(input: &str) -> Self {
|
||||
Tokenizer {
|
||||
input: input.chars().collect(),
|
||||
position: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn current(&self) -> Option<char> {
|
||||
self.input.get(self.position).copied()
|
||||
}
|
||||
|
||||
fn advance(&mut self) {
|
||||
self.position += 1;
|
||||
}
|
||||
|
||||
fn skip_whitespace(&mut self) {
|
||||
while let Some(ch) = self.current() {
|
||||
if ch.is_whitespace() {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_identifier(&mut self) -> String {
|
||||
let mut result = String::new();
|
||||
while let Some(ch) = self.current() {
|
||||
if ch.is_alphanumeric() || ch == '_' {
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn read_string(&mut self) -> Result<String, ParseError> {
|
||||
let mut result = String::new();
|
||||
self.advance(); // Skip opening quote
|
||||
|
||||
while let Some(ch) = self.current() {
|
||||
if ch == '\'' {
|
||||
self.advance();
|
||||
return Ok(result);
|
||||
} else {
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
}
|
||||
}
|
||||
|
||||
Err(ParseError {
|
||||
message: "Unterminated string literal".to_string(),
|
||||
position: self.position,
|
||||
})
|
||||
}
|
||||
|
||||
fn read_number(&mut self) -> String {
|
||||
let mut result = String::new();
|
||||
let mut has_dot = false;
|
||||
|
||||
while let Some(ch) = self.current() {
|
||||
if ch.is_numeric() {
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
} else if ch == '.' && !has_dot {
|
||||
has_dot = true;
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn next_token(&mut self) -> Result<Token, ParseError> {
|
||||
self.skip_whitespace();
|
||||
|
||||
let ch = match self.current() {
|
||||
Some(c) => c,
|
||||
None => return Ok(Token::Eof),
|
||||
};
|
||||
|
||||
match ch {
|
||||
'(' => {
|
||||
self.advance();
|
||||
Ok(Token::LeftParen)
|
||||
}
|
||||
')' => {
|
||||
self.advance();
|
||||
Ok(Token::RightParen)
|
||||
}
|
||||
'[' => {
|
||||
self.advance();
|
||||
Ok(Token::LeftBracket)
|
||||
}
|
||||
']' => {
|
||||
self.advance();
|
||||
Ok(Token::RightBracket)
|
||||
}
|
||||
',' => {
|
||||
self.advance();
|
||||
Ok(Token::Comma)
|
||||
}
|
||||
';' => {
|
||||
self.advance();
|
||||
Ok(Token::Semicolon)
|
||||
}
|
||||
'*' => {
|
||||
self.advance();
|
||||
Ok(Token::Asterisk)
|
||||
}
|
||||
'=' => {
|
||||
self.advance();
|
||||
Ok(Token::Eq)
|
||||
}
|
||||
'!' => {
|
||||
self.advance();
|
||||
if self.current() == Some('=') {
|
||||
self.advance();
|
||||
Ok(Token::NotEq)
|
||||
} else {
|
||||
Err(ParseError {
|
||||
message: "Expected '=' after '!'".to_string(),
|
||||
position: self.position,
|
||||
})
|
||||
}
|
||||
}
|
||||
'>' => {
|
||||
self.advance();
|
||||
if self.current() == Some('=') {
|
||||
self.advance();
|
||||
Ok(Token::GtEq)
|
||||
} else {
|
||||
Ok(Token::Gt)
|
||||
}
|
||||
}
|
||||
'<' => {
|
||||
self.advance();
|
||||
match self.current() {
|
||||
Some('=') => {
|
||||
self.advance();
|
||||
if self.current() == Some('>') {
|
||||
self.advance();
|
||||
Ok(Token::CosineDistance)
|
||||
} else {
|
||||
Ok(Token::LtEq)
|
||||
}
|
||||
}
|
||||
Some('-') => {
|
||||
self.advance();
|
||||
if self.current() == Some('>') {
|
||||
self.advance();
|
||||
Ok(Token::L2Distance)
|
||||
} else {
|
||||
Err(ParseError {
|
||||
message: "Expected '>' after '<-'".to_string(),
|
||||
position: self.position,
|
||||
})
|
||||
}
|
||||
}
|
||||
Some('#') => {
|
||||
self.advance();
|
||||
if self.current() == Some('>') {
|
||||
self.advance();
|
||||
Ok(Token::DotProduct)
|
||||
} else {
|
||||
Err(ParseError {
|
||||
message: "Expected '>' after '<#'".to_string(),
|
||||
position: self.position,
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => Ok(Token::Lt),
|
||||
}
|
||||
}
|
||||
'\'' => Ok(Token::StringLiteral(self.read_string()?)),
|
||||
_ if ch.is_numeric() => Ok(Token::NumberLiteral(self.read_number())),
|
||||
_ if ch.is_alphabetic() || ch == '_' => {
|
||||
let ident = self.read_identifier();
|
||||
Ok(match ident.to_uppercase().as_str() {
|
||||
"SELECT" => Token::Select,
|
||||
"FROM" => Token::From,
|
||||
"WHERE" => Token::Where,
|
||||
"INSERT" => Token::Insert,
|
||||
"INTO" => Token::Into,
|
||||
"VALUES" => Token::Values,
|
||||
"CREATE" => Token::Create,
|
||||
"TABLE" => Token::Table,
|
||||
"DROP" => Token::Drop,
|
||||
"ORDER" => {
|
||||
self.skip_whitespace();
|
||||
if self.read_identifier().to_uppercase() == "BY" {
|
||||
Token::OrderBy
|
||||
} else {
|
||||
Token::Identifier(ident)
|
||||
}
|
||||
}
|
||||
"LIMIT" => Token::Limit,
|
||||
"AND" => Token::And,
|
||||
"OR" => Token::Or,
|
||||
"NOT" => Token::Not,
|
||||
"AS" => Token::As,
|
||||
"TEXT" => Token::Text,
|
||||
"INTEGER" => Token::Integer,
|
||||
"REAL" => Token::Real,
|
||||
"VECTOR" => Token::Vector,
|
||||
"LIKE" => Token::Like,
|
||||
_ => Token::Identifier(ident),
|
||||
})
|
||||
}
|
||||
_ => Err(ParseError {
|
||||
message: format!("Unexpected character: {}", ch),
|
||||
position: self.position,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SQL Parser
|
||||
pub struct SqlParser {
|
||||
tokens: Vec<Token>,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
impl SqlParser {
|
||||
/// Create a new parser from SQL string
|
||||
pub fn new(input: &str) -> Result<Self, ParseError> {
|
||||
let mut tokenizer = Tokenizer::new(input);
|
||||
let mut tokens = Vec::new();
|
||||
|
||||
loop {
|
||||
let token = tokenizer.next_token()?;
|
||||
if token == Token::Eof {
|
||||
tokens.push(token);
|
||||
break;
|
||||
}
|
||||
tokens.push(token);
|
||||
}
|
||||
|
||||
Ok(SqlParser {
|
||||
tokens,
|
||||
position: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse SQL statement
|
||||
pub fn parse(&mut self) -> Result<SqlStatement, ParseError> {
|
||||
let token = self.current().clone();
|
||||
|
||||
match token {
|
||||
Token::Select => self.parse_select(),
|
||||
Token::Insert => self.parse_insert(),
|
||||
Token::Create => self.parse_create(),
|
||||
Token::Drop => self.parse_drop(),
|
||||
_ => Err(ParseError {
|
||||
message: format!("Expected SELECT, INSERT, CREATE, or DROP, got {:?}", token),
|
||||
position: self.position,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn current(&self) -> &Token {
|
||||
self.tokens.get(self.position).unwrap_or(&Token::Eof)
|
||||
}
|
||||
|
||||
fn advance(&mut self) {
|
||||
if self.position < self.tokens.len() {
|
||||
self.position += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn expect(&mut self, expected: Token) -> Result<(), ParseError> {
|
||||
let current = self.current().clone();
|
||||
if current == expected {
|
||||
self.advance();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ParseError {
|
||||
message: format!("Expected {:?}, got {:?}", expected, current),
|
||||
position: self.position,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_select(&mut self) -> Result<SqlStatement, ParseError> {
|
||||
self.expect(Token::Select)?;
|
||||
|
||||
let columns = self.parse_select_columns()?;
|
||||
|
||||
self.expect(Token::From)?;
|
||||
let from = self.parse_identifier()?;
|
||||
|
||||
let where_clause = if matches!(self.current(), Token::Where) {
|
||||
self.advance();
|
||||
Some(self.parse_expression()?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let order_by = if matches!(self.current(), Token::OrderBy) {
|
||||
self.advance();
|
||||
Some(self.parse_order_by()?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let limit = if matches!(self.current(), Token::Limit) {
|
||||
self.advance();
|
||||
Some(self.parse_number()? as usize)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(SqlStatement::Select {
|
||||
columns,
|
||||
from,
|
||||
where_clause,
|
||||
order_by,
|
||||
limit,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_select_columns(&mut self) -> Result<Vec<SelectColumn>, ParseError> {
|
||||
if matches!(self.current(), Token::Asterisk) {
|
||||
self.advance();
|
||||
return Ok(vec![SelectColumn::Wildcard]);
|
||||
}
|
||||
|
||||
let mut columns = Vec::new();
|
||||
loop {
|
||||
let name = self.parse_identifier()?;
|
||||
columns.push(SelectColumn::Name(name));
|
||||
|
||||
if !matches!(self.current(), Token::Comma) {
|
||||
break;
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
fn parse_insert(&mut self) -> Result<SqlStatement, ParseError> {
|
||||
self.expect(Token::Insert)?;
|
||||
self.expect(Token::Into)?;
|
||||
|
||||
let table = self.parse_identifier()?;
|
||||
|
||||
self.expect(Token::LeftParen)?;
|
||||
let columns = self.parse_identifier_list()?;
|
||||
self.expect(Token::RightParen)?;
|
||||
|
||||
self.expect(Token::Values)?;
|
||||
self.expect(Token::LeftParen)?;
|
||||
let values = self.parse_value_list()?;
|
||||
self.expect(Token::RightParen)?;
|
||||
|
||||
Ok(SqlStatement::Insert {
|
||||
table,
|
||||
columns,
|
||||
values,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_create(&mut self) -> Result<SqlStatement, ParseError> {
|
||||
self.expect(Token::Create)?;
|
||||
self.expect(Token::Table)?;
|
||||
|
||||
let name = self.parse_identifier()?;
|
||||
|
||||
self.expect(Token::LeftParen)?;
|
||||
let columns = self.parse_column_definitions()?;
|
||||
self.expect(Token::RightParen)?;
|
||||
|
||||
Ok(SqlStatement::CreateTable { name, columns })
|
||||
}
|
||||
|
||||
fn parse_drop(&mut self) -> Result<SqlStatement, ParseError> {
|
||||
self.expect(Token::Drop)?;
|
||||
self.expect(Token::Table)?;
|
||||
|
||||
let table = self.parse_identifier()?;
|
||||
|
||||
Ok(SqlStatement::Drop { table })
|
||||
}
|
||||
|
||||
fn parse_column_definitions(&mut self) -> Result<Vec<Column>, ParseError> {
|
||||
let mut columns = Vec::new();
|
||||
|
||||
loop {
|
||||
let name = self.parse_identifier()?;
|
||||
let data_type = self.parse_data_type()?;
|
||||
|
||||
columns.push(Column { name, data_type });
|
||||
|
||||
if !matches!(self.current(), Token::Comma) {
|
||||
break;
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
fn parse_data_type(&mut self) -> Result<DataType, ParseError> {
|
||||
match self.current().clone() {
|
||||
Token::Text => {
|
||||
self.advance();
|
||||
Ok(DataType::Text)
|
||||
}
|
||||
Token::Integer => {
|
||||
self.advance();
|
||||
Ok(DataType::Integer)
|
||||
}
|
||||
Token::Real => {
|
||||
self.advance();
|
||||
Ok(DataType::Real)
|
||||
}
|
||||
Token::Vector => {
|
||||
self.advance();
|
||||
self.expect(Token::LeftParen)?;
|
||||
let dims = self.parse_number()? as usize;
|
||||
self.expect(Token::RightParen)?;
|
||||
Ok(DataType::Vector(dims))
|
||||
}
|
||||
_ => Err(ParseError {
|
||||
message: "Expected data type (TEXT, INTEGER, REAL, or VECTOR)".to_string(),
|
||||
position: self.position,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_expression(&mut self) -> Result<Expression, ParseError> {
|
||||
self.parse_or_expression()
|
||||
}
|
||||
|
||||
fn parse_or_expression(&mut self) -> Result<Expression, ParseError> {
|
||||
let mut left = self.parse_and_expression()?;
|
||||
|
||||
while matches!(self.current(), Token::Or) {
|
||||
self.advance();
|
||||
let right = self.parse_and_expression()?;
|
||||
left = Expression::Or(Box::new(left), Box::new(right));
|
||||
}
|
||||
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
fn parse_and_expression(&mut self) -> Result<Expression, ParseError> {
|
||||
let mut left = self.parse_comparison_expression()?;
|
||||
|
||||
while matches!(self.current(), Token::And) {
|
||||
self.advance();
|
||||
let right = self.parse_comparison_expression()?;
|
||||
left = Expression::And(Box::new(left), Box::new(right));
|
||||
}
|
||||
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
fn parse_comparison_expression(&mut self) -> Result<Expression, ParseError> {
|
||||
let left = self.parse_primary_expression()?;
|
||||
|
||||
let op = match self.current() {
|
||||
Token::Eq => BinaryOperator::Eq,
|
||||
Token::NotEq => BinaryOperator::NotEq,
|
||||
Token::Gt => BinaryOperator::Gt,
|
||||
Token::GtEq => BinaryOperator::GtEq,
|
||||
Token::Lt => BinaryOperator::Lt,
|
||||
Token::LtEq => BinaryOperator::LtEq,
|
||||
Token::Like => BinaryOperator::Like,
|
||||
_ => return Ok(left),
|
||||
};
|
||||
|
||||
self.advance();
|
||||
let right = self.parse_primary_expression()?;
|
||||
|
||||
Ok(Expression::BinaryOp {
|
||||
left: Box::new(left),
|
||||
op,
|
||||
right: Box::new(right),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_primary_expression(&mut self) -> Result<Expression, ParseError> {
|
||||
match self.current().clone() {
|
||||
Token::Identifier(name) => {
|
||||
self.advance();
|
||||
Ok(Expression::Column(name))
|
||||
}
|
||||
Token::StringLiteral(s) => {
|
||||
self.advance();
|
||||
Ok(Expression::Literal(Value::Text(s)))
|
||||
}
|
||||
Token::NumberLiteral(n) => {
|
||||
self.advance();
|
||||
let value = if n.contains('.') {
|
||||
Value::Real(n.parse().unwrap())
|
||||
} else {
|
||||
Value::Integer(n.parse().unwrap())
|
||||
};
|
||||
Ok(Expression::Literal(value))
|
||||
}
|
||||
Token::LeftBracket => {
|
||||
self.advance();
|
||||
let vec = self.parse_vector_literal()?;
|
||||
self.expect(Token::RightBracket)?;
|
||||
Ok(Expression::VectorLiteral(vec))
|
||||
}
|
||||
Token::Not => {
|
||||
self.advance();
|
||||
let expr = self.parse_primary_expression()?;
|
||||
Ok(Expression::Not(Box::new(expr)))
|
||||
}
|
||||
_ => Err(ParseError {
|
||||
message: format!("Unexpected token in expression: {:?}", self.current()),
|
||||
position: self.position,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_order_by(&mut self) -> Result<OrderBy, ParseError> {
|
||||
// Parse column <-> vector or column <=> vector
|
||||
let column = self.parse_identifier()?;
|
||||
|
||||
let metric = match self.current() {
|
||||
Token::L2Distance => {
|
||||
self.advance();
|
||||
DistanceMetric::L2
|
||||
}
|
||||
Token::CosineDistance => {
|
||||
self.advance();
|
||||
DistanceMetric::Cosine
|
||||
}
|
||||
Token::DotProduct => {
|
||||
self.advance();
|
||||
DistanceMetric::DotProduct
|
||||
}
|
||||
_ => {
|
||||
return Err(ParseError {
|
||||
message: "Expected distance operator (<->, <=>, or <#>)".to_string(),
|
||||
position: self.position,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let vector = if matches!(self.current(), Token::LeftBracket) {
|
||||
self.advance();
|
||||
let vec = self.parse_vector_literal()?;
|
||||
self.expect(Token::RightBracket)?;
|
||||
vec
|
||||
} else {
|
||||
return Err(ParseError {
|
||||
message: "Expected vector literal after distance operator".to_string(),
|
||||
position: self.position,
|
||||
});
|
||||
};
|
||||
|
||||
Ok(OrderBy {
|
||||
expression: Expression::Distance {
|
||||
column,
|
||||
metric,
|
||||
vector,
|
||||
},
|
||||
direction: OrderDirection::Asc,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_identifier(&mut self) -> Result<String, ParseError> {
|
||||
match self.current().clone() {
|
||||
Token::Identifier(name) => {
|
||||
self.advance();
|
||||
Ok(name)
|
||||
}
|
||||
_ => Err(ParseError {
|
||||
message: "Expected identifier".to_string(),
|
||||
position: self.position,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_identifier_list(&mut self) -> Result<Vec<String>, ParseError> {
|
||||
let mut identifiers = Vec::new();
|
||||
|
||||
loop {
|
||||
identifiers.push(self.parse_identifier()?);
|
||||
|
||||
if !matches!(self.current(), Token::Comma) {
|
||||
break;
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
|
||||
Ok(identifiers)
|
||||
}
|
||||
|
||||
fn parse_value_list(&mut self) -> Result<Vec<Value>, ParseError> {
|
||||
let mut values = Vec::new();
|
||||
|
||||
loop {
|
||||
values.push(self.parse_value()?);
|
||||
|
||||
if !matches!(self.current(), Token::Comma) {
|
||||
break;
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
fn parse_value(&mut self) -> Result<Value, ParseError> {
|
||||
match self.current().clone() {
|
||||
Token::StringLiteral(s) => {
|
||||
self.advance();
|
||||
Ok(Value::Text(s))
|
||||
}
|
||||
Token::NumberLiteral(n) => {
|
||||
self.advance();
|
||||
if n.contains('.') {
|
||||
Ok(Value::Real(n.parse().unwrap()))
|
||||
} else {
|
||||
Ok(Value::Integer(n.parse().unwrap()))
|
||||
}
|
||||
}
|
||||
Token::LeftBracket => {
|
||||
self.advance();
|
||||
let vec = self.parse_vector_literal()?;
|
||||
self.expect(Token::RightBracket)?;
|
||||
Ok(Value::Vector(vec))
|
||||
}
|
||||
_ => Err(ParseError {
|
||||
message: format!("Expected value, got {:?}", self.current()),
|
||||
position: self.position,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_vector_literal(&mut self) -> Result<Vec<f32>, ParseError> {
|
||||
let mut values = Vec::new();
|
||||
|
||||
loop {
|
||||
let n = self.parse_number()?;
|
||||
values.push(n as f32);
|
||||
|
||||
if !matches!(self.current(), Token::Comma) {
|
||||
break;
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
fn parse_number(&mut self) -> Result<f64, ParseError> {
|
||||
match self.current().clone() {
|
||||
Token::NumberLiteral(n) => {
|
||||
self.advance();
|
||||
n.parse().map_err(|_| ParseError {
|
||||
message: format!("Invalid number: {}", n),
|
||||
position: self.position,
|
||||
})
|
||||
}
|
||||
_ => Err(ParseError {
|
||||
message: "Expected number".to_string(),
|
||||
position: self.position,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenize and parse `sql`, panicking if either stage fails.
    fn parse_sql(sql: &str) -> SqlStatement {
        SqlParser::new(sql).unwrap().parse().unwrap()
    }

    /// CREATE TABLE yields the table name and all three typed columns.
    #[test]
    fn test_parse_create_table() {
        let parsed =
            parse_sql("CREATE TABLE documents (id TEXT, content TEXT, embedding VECTOR(384))");

        if let SqlStatement::CreateTable { name, columns } = parsed {
            assert_eq!(name, "documents");
            assert_eq!(columns.len(), 3);
            assert_eq!(columns[2].data_type, DataType::Vector(384));
        } else {
            panic!("Expected CreateTable");
        }
    }

    /// INSERT yields the target table plus matching column and value counts.
    #[test]
    fn test_parse_insert() {
        let parsed = parse_sql(
            "INSERT INTO documents (id, content, embedding) VALUES ('1', 'hello', [1.0, 2.0, 3.0])",
        );

        if let SqlStatement::Insert {
            table,
            columns,
            values,
        } = parsed
        {
            assert_eq!(table, "documents");
            assert_eq!(columns.len(), 3);
            assert_eq!(values.len(), 3);
        } else {
            panic!("Expected Insert");
        }
    }

    /// SELECT with a distance ORDER BY and LIMIT keeps both clauses.
    #[test]
    fn test_parse_select_with_vector_search() {
        let parsed =
            parse_sql("SELECT * FROM documents ORDER BY embedding <-> [1.0, 2.0, 3.0] LIMIT 5");

        if let SqlStatement::Select {
            order_by, limit, ..
        } = parsed
        {
            assert!(order_by.is_some());
            assert_eq!(limit, Some(5));
        } else {
            panic!("Expected Select");
        }
    }
}
|
||||
147
vendor/ruvector/crates/rvlite/src/sql/tests.rs
vendored
Normal file
147
vendor/ruvector/crates/rvlite/src/sql/tests.rs
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
// Integration tests for SQL engine
|
||||
#[cfg(test)]
mod tests {
    use crate::sql::{SqlEngine, SqlParser};

    /// A well-formed INSERT whose vector has the wrong dimension is rejected
    /// by the engine after the table has been created.
    #[test]
    fn test_full_workflow() {
        let engine = SqlEngine::new();
        // Parse `sql` (panicking on syntax errors) and hand it to the engine.
        let run = |sql: &str| engine.execute(SqlParser::new(sql).unwrap().parse().unwrap());

        // Create table
        run("CREATE TABLE documents (id TEXT, content TEXT, embedding VECTOR(384))").unwrap();

        // Insert data: this fails due to dimension mismatch (3 vs 384), but tests the flow
        let outcome = run(
            "INSERT INTO documents (id, content, embedding) VALUES ('doc1', 'hello world', [1.0, 2.0, 3.0])",
        );
        assert!(outcome.is_err()); // Expected error due to dimension mismatch
    }

    /// An L2 nearest-neighbour query returns exactly LIMIT rows.
    #[test]
    fn test_vector_similarity_search() {
        let engine = SqlEngine::new();
        let run = |sql: &str| engine.execute(SqlParser::new(sql).unwrap().parse().unwrap());

        // Create table with small dimensions for testing
        run("CREATE TABLE docs (id TEXT, embedding VECTOR(3))").unwrap();

        // Insert test data
        for i in 0..10 {
            let statement = format!(
                "INSERT INTO docs (id, embedding) VALUES ('doc{}', [{}, {}, {}])",
                i,
                i,
                i * 2,
                i * 3
            );
            run(statement.as_str()).unwrap();
        }

        // Search for similar vectors
        let found =
            run("SELECT * FROM docs ORDER BY embedding <-> [5.0, 10.0, 15.0] LIMIT 3").unwrap();

        assert_eq!(found.rows.len(), 3);
        // The closest vector should be [5, 10, 15]
        assert!(found.rows[0].get("id").is_some());
    }

    /// A WHERE filter combined with vector ordering only yields matching rows.
    #[test]
    fn test_metadata_filtering() {
        let engine = SqlEngine::new();
        let run = |sql: &str| engine.execute(SqlParser::new(sql).unwrap().parse().unwrap());

        // Create table
        run("CREATE TABLE docs (id TEXT, category TEXT, embedding VECTOR(3))").unwrap();

        // Insert data with categories
        let categories = vec!["tech", "sports", "tech", "news", "sports"];
        for (i, cat) in categories.iter().enumerate() {
            let statement = format!(
                "INSERT INTO docs (id, category, embedding) VALUES ('doc{}', '{}', [{}, {}, {}])",
                i,
                cat,
                i,
                i * 2,
                i * 3
            );
            run(statement.as_str()).unwrap();
        }

        // Search with filter
        let found = run(
            "SELECT * FROM docs WHERE category = 'tech' ORDER BY embedding <-> [2.0, 4.0, 6.0] LIMIT 2",
        )
        .unwrap();

        // VectorDB filtering may not be fully precise, so we check for at least 1 result
        let hits = found.rows.len();
        assert!((1..=2).contains(&hits));

        // All results should have category = 'tech'
        for category in found.rows.iter().filter_map(|row| row.get("category")) {
            assert_eq!(category.to_string(), "'tech'");
        }
    }

    /// DROP TABLE removes the table from the engine's table list.
    #[test]
    fn test_drop_table() {
        let engine = SqlEngine::new();
        let run = |sql: &str| engine.execute(SqlParser::new(sql).unwrap().parse().unwrap());

        // Create table
        run("CREATE TABLE temp (id TEXT, embedding VECTOR(3))").unwrap();
        assert_eq!(engine.list_tables().len(), 1);

        // Drop table
        run("DROP TABLE temp").unwrap();
        assert_eq!(engine.list_tables().len(), 0);
    }

    /// A cosine-distance query with LIMIT 1 returns a single row.
    #[test]
    fn test_cosine_distance() {
        let engine = SqlEngine::new();
        let run = |sql: &str| engine.execute(SqlParser::new(sql).unwrap().parse().unwrap());

        run("CREATE TABLE docs (id TEXT, embedding VECTOR(3))").unwrap();

        // Insert normalized vectors for cosine similarity
        run("INSERT INTO docs (id, embedding) VALUES ('doc1', [1.0, 0.0, 0.0])").unwrap();
        run("INSERT INTO docs (id, embedding) VALUES ('doc2', [0.0, 1.0, 0.0])").unwrap();

        // Search using cosine distance
        let found =
            run("SELECT * FROM docs ORDER BY embedding <=> [0.9, 0.1, 0.0] LIMIT 1").unwrap();

        assert_eq!(found.rows.len(), 1);
        // Should return doc1 as it's more similar to [0.9, 0.1, 0.0]
    }
}
|
||||
Reference in New Issue
Block a user