Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,220 @@
// AST types for SQL statements
use serde::{Deserialize, Serialize};
/// SQL statement types
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SqlStatement {
/// CREATE TABLE name (columns)
CreateTable { name: String, columns: Vec<Column> },
/// INSERT INTO table (columns) VALUES (values)
Insert {
table: String,
columns: Vec<String>,
values: Vec<Value>,
},
/// SELECT columns FROM table WHERE condition ORDER BY ... LIMIT k
Select {
columns: Vec<SelectColumn>,
from: String,
where_clause: Option<Expression>,
order_by: Option<OrderBy>,
limit: Option<usize>,
},
/// DROP TABLE name
Drop { table: String },
}
/// Column definition for CREATE TABLE
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Column {
pub name: String,
pub data_type: DataType,
}
/// Data types supported in SQL
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DataType {
/// TEXT type for strings
Text,
/// INTEGER type
Integer,
/// REAL/FLOAT type
Real,
/// VECTOR(dimensions) type for vector data
Vector(usize),
}
/// Column selector in SELECT
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SelectColumn {
/// SELECT *
Wildcard,
/// SELECT column_name
Name(String),
/// SELECT expression AS alias
Expression {
expr: Expression,
alias: Option<String>,
},
}
/// SQL expressions
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expression {
/// Column reference
Column(String),
/// Literal value
Literal(Value),
/// Binary operation (e.g., a = b, a > b)
BinaryOp {
left: Box<Expression>,
op: BinaryOperator,
right: Box<Expression>,
},
/// Logical AND
And(Box<Expression>, Box<Expression>),
/// Logical OR
Or(Box<Expression>, Box<Expression>),
/// NOT expression
Not(Box<Expression>),
/// Function call
Function { name: String, args: Vec<Expression> },
/// Vector literal [1.0, 2.0, 3.0]
VectorLiteral(Vec<f32>),
/// Distance operation: column <-> vector
/// Used for ORDER BY embedding <-> $vector
Distance {
column: String,
metric: DistanceMetric,
vector: Vec<f32>,
},
}
/// Binary operators
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum BinaryOperator {
/// =
Eq,
/// !=
NotEq,
/// >
Gt,
/// >=
GtEq,
/// <
Lt,
/// <=
LtEq,
/// LIKE
Like,
}
/// Distance metrics for vector similarity
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DistanceMetric {
/// L2 distance: <->
L2,
/// Cosine distance: <=>
Cosine,
/// Dot product: <#>
DotProduct,
}
/// ORDER BY clause
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderBy {
pub expression: Expression,
pub direction: OrderDirection,
}
/// Sort direction
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OrderDirection {
Asc,
Desc,
}
/// SQL values
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Value {
Null,
Text(String),
Integer(i64),
Real(f64),
Vector(Vec<f32>),
Boolean(bool),
}
impl Value {
/// Convert to JSON value for metadata storage
pub fn to_json(&self) -> serde_json::Value {
match self {
Value::Null => serde_json::Value::Null,
Value::Text(s) => serde_json::Value::String(s.clone()),
Value::Integer(i) => serde_json::Value::Number((*i).into()),
Value::Real(f) => {
serde_json::Value::Number(serde_json::Number::from_f64(*f).unwrap_or(0.into()))
}
Value::Vector(v) => serde_json::Value::Array(
v.iter()
.map(|f| {
serde_json::Value::Number(
serde_json::Number::from_f64(*f as f64).unwrap_or(0.into()),
)
})
.collect(),
),
Value::Boolean(b) => serde_json::Value::Bool(*b),
}
}
/// Parse from JSON value
pub fn from_json(json: &serde_json::Value) -> Self {
match json {
serde_json::Value::Null => Value::Null,
serde_json::Value::Bool(b) => Value::Boolean(*b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Value::Integer(i)
} else if let Some(f) = n.as_f64() {
Value::Real(f)
} else {
Value::Null
}
}
serde_json::Value::String(s) => Value::Text(s.clone()),
serde_json::Value::Array(arr) => {
// Try to parse as vector
let floats: Option<Vec<f32>> =
arr.iter().map(|v| v.as_f64().map(|f| f as f32)).collect();
if let Some(vec) = floats {
Value::Vector(vec)
} else {
Value::Null
}
}
serde_json::Value::Object(_) => Value::Null,
}
}
}
impl std::fmt::Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Value::Null => write!(f, "NULL"),
Value::Text(s) => write!(f, "'{}'", s),
Value::Integer(i) => write!(f, "{}", i),
Value::Real(r) => write!(f, "{}", r),
Value::Vector(v) => write!(
f,
"[{}]",
v.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", ")
),
Value::Boolean(b) => write!(f, "{}", b),
}
}
}

View File

@@ -0,0 +1,561 @@
// SQL executor that integrates with ruvector-core VectorDB
use super::ast::*;
use crate::{ErrorKind, RvLiteError};
use parking_lot::RwLock;
use ruvector_core::{SearchQuery, VectorDB, VectorEntry};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Table schema definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableSchema {
pub name: String,
pub columns: Vec<Column>,
pub vector_column: Option<String>,
pub vector_dimensions: Option<usize>,
}
impl TableSchema {
/// Find the vector column in the schema
fn find_vector_column(&self) -> Option<(String, usize)> {
for col in &self.columns {
if let DataType::Vector(dims) = col.data_type {
return Some((col.name.clone(), dims));
}
}
None
}
/// Validate that columns match the schema
fn validate_columns(&self, columns: &[String]) -> Result<(), RvLiteError> {
for col in columns {
if !self.columns.iter().any(|c| &c.name == col) {
return Err(RvLiteError {
message: format!("Column '{}' not found in table '{}'", col, self.name),
kind: ErrorKind::SqlError,
});
}
}
Ok(())
}
/// Get column data type
fn get_column_type(&self, name: &str) -> Option<&DataType> {
self.columns
.iter()
.find(|c| c.name == name)
.map(|c| &c.data_type)
}
}
/// SQL execution result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
pub rows: Vec<HashMap<String, Value>>,
pub rows_affected: usize,
}
/// SQL Engine that manages tables and executes queries
pub struct SqlEngine {
/// Table schemas
schemas: RwLock<HashMap<String, TableSchema>>,
/// Vector databases (one per table)
databases: RwLock<HashMap<String, VectorDB>>,
}
impl SqlEngine {
/// Create a new SQL engine
pub fn new() -> Self {
SqlEngine {
schemas: RwLock::new(HashMap::new()),
databases: RwLock::new(HashMap::new()),
}
}
/// Execute a SQL statement
pub fn execute(&self, statement: SqlStatement) -> Result<ExecutionResult, RvLiteError> {
match statement {
SqlStatement::CreateTable { name, columns } => self.create_table(name, columns),
SqlStatement::Insert {
table,
columns,
values,
} => self.insert(table, columns, values),
SqlStatement::Select {
columns,
from,
where_clause,
order_by,
limit,
} => self.select(columns, from, where_clause, order_by, limit),
SqlStatement::Drop { table } => self.drop_table(table),
}
}
fn create_table(
&self,
name: String,
columns: Vec<Column>,
) -> Result<ExecutionResult, RvLiteError> {
let mut schemas = self.schemas.write();
if schemas.contains_key(&name) {
return Err(RvLiteError {
message: format!("Table '{}' already exists", name),
kind: ErrorKind::SqlError,
});
}
// Find vector column
let (vector_column, vector_dimensions) = columns
.iter()
.find_map(|col| {
if let DataType::Vector(dims) = col.data_type {
Some((col.name.clone(), dims))
} else {
None
}
})
.ok_or_else(|| RvLiteError {
message: "Table must have at least one VECTOR column".to_string(),
kind: ErrorKind::SqlError,
})?;
let schema = TableSchema {
name: name.clone(),
columns,
vector_column: Some(vector_column),
vector_dimensions: Some(vector_dimensions),
};
// Create vector database for this table
let db_options = ruvector_core::types::DbOptions {
dimensions: vector_dimensions,
distance_metric: ruvector_core::DistanceMetric::Cosine,
storage_path: "memory://".to_string(),
hnsw_config: None,
quantization: None,
};
let db = VectorDB::new(db_options).map_err(|e| RvLiteError {
message: format!("Failed to create vector database: {}", e),
kind: ErrorKind::VectorError,
})?;
let mut databases = self.databases.write();
databases.insert(name.clone(), db);
schemas.insert(name, schema);
Ok(ExecutionResult {
rows: Vec::new(),
rows_affected: 0,
})
}
fn insert(
&self,
table: String,
columns: Vec<String>,
values: Vec<Value>,
) -> Result<ExecutionResult, RvLiteError> {
let schemas = self.schemas.read();
let schema = schemas.get(&table).ok_or_else(|| RvLiteError {
message: format!("Table '{}' not found", table),
kind: ErrorKind::SqlError,
})?;
// Validate columns
schema.validate_columns(&columns)?;
if columns.len() != values.len() {
return Err(RvLiteError {
message: format!(
"Column count ({}) does not match value count ({})",
columns.len(),
values.len()
),
kind: ErrorKind::SqlError,
});
}
// Extract vector and metadata
let mut vector: Option<Vec<f32>> = None;
let mut metadata = HashMap::new();
let mut id: Option<String> = None;
for (col, val) in columns.iter().zip(values.iter()) {
if let Some(DataType::Vector(_)) = schema.get_column_type(col) {
if let Value::Vector(v) = val {
vector = Some(v.clone());
} else {
return Err(RvLiteError {
message: format!("Expected vector value for column '{}'", col),
kind: ErrorKind::SqlError,
});
}
} else {
// Store as metadata
metadata.insert(col.clone(), val.to_json());
// Use 'id' column as vector ID if present
if col == "id" {
if let Value::Text(s) = val {
id = Some(s.clone());
}
}
}
}
let vector = vector.ok_or_else(|| RvLiteError {
message: "No vector value provided".to_string(),
kind: ErrorKind::SqlError,
})?;
// Validate vector dimensions
if let Some(expected_dims) = schema.vector_dimensions {
if vector.len() != expected_dims {
return Err(RvLiteError {
message: format!(
"Vector dimension mismatch: expected {}, got {}",
expected_dims,
vector.len()
),
kind: ErrorKind::SqlError,
});
}
}
// Insert into vector database
let entry = VectorEntry {
id,
vector,
metadata: Some(metadata),
};
let databases = self.databases.read();
let db = databases.get(&table).ok_or_else(|| RvLiteError {
message: format!("Database for table '{}' not found", table),
kind: ErrorKind::SqlError,
})?;
db.insert(entry).map_err(|e| RvLiteError {
message: format!("Failed to insert: {}", e),
kind: ErrorKind::VectorError,
})?;
Ok(ExecutionResult {
rows: Vec::new(),
rows_affected: 1,
})
}
fn select(
&self,
_columns: Vec<SelectColumn>,
from: String,
where_clause: Option<Expression>,
order_by: Option<OrderBy>,
limit: Option<usize>,
) -> Result<ExecutionResult, RvLiteError> {
let schemas = self.schemas.read();
let schema = schemas.get(&from).ok_or_else(|| RvLiteError {
message: format!("Table '{}' not found", from),
kind: ErrorKind::SqlError,
})?;
let databases = self.databases.read();
let db = databases.get(&from).ok_or_else(|| RvLiteError {
message: format!("Database for table '{}' not found", from),
kind: ErrorKind::SqlError,
})?;
// Handle vector similarity search
if let Some(order_by) = order_by {
if let Expression::Distance {
column: _,
metric: _,
vector,
} = order_by.expression
{
let k = limit.unwrap_or(10);
// Build filter from WHERE clause
let filter = if let Some(where_expr) = where_clause {
Some(self.build_filter(where_expr)?)
} else {
None
};
let query = SearchQuery {
vector,
k,
filter,
ef_search: None,
};
let results = db.search(query).map_err(|e| RvLiteError {
message: format!("Search failed: {}", e),
kind: ErrorKind::VectorError,
})?;
// Convert results to rows
let rows: Vec<HashMap<String, Value>> = results
.into_iter()
.map(|result| {
let mut row = HashMap::new();
// Add vector if present
if let Some(vec_col) = &schema.vector_column {
if let Some(vector) = result.vector {
row.insert(vec_col.clone(), Value::Vector(vector));
}
}
// Add metadata
if let Some(metadata) = result.metadata {
for (key, val) in metadata {
row.insert(key, Value::from_json(&val));
}
}
// Add distance score
row.insert("_distance".to_string(), Value::Real(result.score as f64));
row
})
.collect();
return Ok(ExecutionResult {
rows,
rows_affected: 0,
});
}
}
// Non-vector query - return all rows (scan all vectors)
// This is essentially a table scan through the vector database
let k = limit.unwrap_or(1000); // Default to 1000 rows max
// Create a zero vector for exhaustive search
let dims = schema.vector_dimensions.unwrap_or(3);
let query_vector = vec![0.0f32; dims];
// Build filter from WHERE clause
let filter = if let Some(where_expr) = where_clause {
Some(self.build_filter(where_expr)?)
} else {
None
};
let query = SearchQuery {
vector: query_vector,
k,
filter,
ef_search: None,
};
let results = db.search(query).map_err(|e| RvLiteError {
message: format!("Search failed: {}", e),
kind: ErrorKind::VectorError,
})?;
// Convert results to rows
let rows: Vec<HashMap<String, Value>> = results
.into_iter()
.map(|result| {
let mut row = HashMap::new();
// Add vector if present
if let Some(vec_col) = &schema.vector_column {
if let Some(vector) = result.vector {
row.insert(vec_col.clone(), Value::Vector(vector));
}
}
// Add metadata
if let Some(metadata) = result.metadata {
for (key, val) in metadata {
row.insert(key, Value::from_json(&val));
}
}
row
})
.collect();
Ok(ExecutionResult {
rows,
rows_affected: 0,
})
}
fn drop_table(&self, table: String) -> Result<ExecutionResult, RvLiteError> {
let mut schemas = self.schemas.write();
let mut databases = self.databases.write();
schemas.remove(&table).ok_or_else(|| RvLiteError {
message: format!("Table '{}' not found", table),
kind: ErrorKind::SqlError,
})?;
databases.remove(&table);
Ok(ExecutionResult {
rows: Vec::new(),
rows_affected: 0,
})
}
/// Build metadata filter from WHERE expression
fn build_filter(
&self,
expr: Expression,
) -> Result<HashMap<String, serde_json::Value>, RvLiteError> {
let mut filter = HashMap::new();
match expr {
Expression::BinaryOp { left, op, right } => {
if let (Expression::Column(col), Expression::Literal(val)) = (*left, *right) {
if op == BinaryOperator::Eq {
filter.insert(col, val.to_json());
} else {
return Err(RvLiteError {
message: "Only equality filters supported in WHERE clause".to_string(),
kind: ErrorKind::NotImplemented,
});
}
}
}
Expression::And(left, right) => {
let left_filter = self.build_filter(*left)?;
let right_filter = self.build_filter(*right)?;
filter.extend(left_filter);
filter.extend(right_filter);
}
_ => {
return Err(RvLiteError {
message: "Unsupported WHERE clause expression".to_string(),
kind: ErrorKind::NotImplemented,
});
}
}
Ok(filter)
}
/// List all tables
pub fn list_tables(&self) -> Vec<String> {
self.schemas.read().keys().cloned().collect()
}
/// Get table schema
pub fn get_schema(&self, table: &str) -> Option<TableSchema> {
self.schemas.read().get(table).cloned()
}
}
impl Default for SqlEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_and_insert() {
let engine = SqlEngine::new();
// Create table
let create = SqlStatement::CreateTable {
name: "docs".to_string(),
columns: vec![
Column {
name: "id".to_string(),
data_type: DataType::Text,
},
Column {
name: "content".to_string(),
data_type: DataType::Text,
},
Column {
name: "embedding".to_string(),
data_type: DataType::Vector(3),
},
],
};
engine.execute(create).unwrap();
// Insert row
let insert = SqlStatement::Insert {
table: "docs".to_string(),
columns: vec![
"id".to_string(),
"content".to_string(),
"embedding".to_string(),
],
values: vec![
Value::Text("1".to_string()),
Value::Text("hello".to_string()),
Value::Vector(vec![1.0, 2.0, 3.0]),
],
};
let result = engine.execute(insert).unwrap();
assert_eq!(result.rows_affected, 1);
}
#[test]
fn test_vector_search() {
let engine = SqlEngine::new();
// Create table
let create = SqlStatement::CreateTable {
name: "docs".to_string(),
columns: vec![
Column {
name: "id".to_string(),
data_type: DataType::Text,
},
Column {
name: "embedding".to_string(),
data_type: DataType::Vector(3),
},
],
};
engine.execute(create).unwrap();
// Insert rows
for i in 0..5 {
let insert = SqlStatement::Insert {
table: "docs".to_string(),
columns: vec!["id".to_string(), "embedding".to_string()],
values: vec![
Value::Text(format!("{}", i)),
Value::Vector(vec![i as f32, i as f32 * 2.0, i as f32 * 3.0]),
],
};
engine.execute(insert).unwrap();
}
// Search
let select = SqlStatement::Select {
columns: vec![SelectColumn::Wildcard],
from: "docs".to_string(),
where_clause: None,
order_by: Some(OrderBy {
expression: Expression::Distance {
column: "embedding".to_string(),
metric: DistanceMetric::L2,
vector: vec![2.0, 4.0, 6.0],
},
direction: OrderDirection::Asc,
}),
limit: Some(3),
};
let result = engine.execute(select).unwrap();
assert_eq!(result.rows.len(), 3);
}
}

View File

@@ -0,0 +1,13 @@
// SQL query engine module for rvlite
// Provides SQL interface for vector database operations with WASM compatibility
mod ast;
mod executor;
mod parser;
pub use ast::*;
pub use executor::{ExecutionResult, SqlEngine};
pub use parser::{ParseError, SqlParser};
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,823 @@
// Hand-rolled SQL parser for WASM compatibility
// Implements recursive descent parsing for vector-specific SQL
use super::ast::*;
use std::fmt;
/// Parse error type
#[derive(Debug, Clone, PartialEq)]
pub struct ParseError {
pub message: String,
pub position: usize,
}
impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Parse error at position {}: {}",
self.position, self.message
)
}
}
impl std::error::Error for ParseError {}
/// Token types
#[derive(Debug, Clone, PartialEq)]
enum Token {
// Keywords
Select,
From,
Where,
Insert,
Into,
Values,
Create,
Table,
Drop,
OrderBy,
Limit,
And,
Or,
Not,
As,
// Data types
Text,
Integer,
Real,
Vector,
// Operators
Eq,
NotEq,
Gt,
GtEq,
Lt,
LtEq,
Like,
// Distance operators
L2Distance, // <->
CosineDistance, // <=>
DotProduct, // <#>
// Delimiters
LeftParen,
RightParen,
LeftBracket,
RightBracket,
Comma,
Semicolon,
Asterisk,
// Values
Identifier(String),
StringLiteral(String),
NumberLiteral(String),
// End
Eof,
}
/// Tokenizer (lexer)
struct Tokenizer {
input: Vec<char>,
position: usize,
}
impl Tokenizer {
fn new(input: &str) -> Self {
Tokenizer {
input: input.chars().collect(),
position: 0,
}
}
fn current(&self) -> Option<char> {
self.input.get(self.position).copied()
}
fn advance(&mut self) {
self.position += 1;
}
fn skip_whitespace(&mut self) {
while let Some(ch) = self.current() {
if ch.is_whitespace() {
self.advance();
} else {
break;
}
}
}
fn read_identifier(&mut self) -> String {
let mut result = String::new();
while let Some(ch) = self.current() {
if ch.is_alphanumeric() || ch == '_' {
result.push(ch);
self.advance();
} else {
break;
}
}
result
}
fn read_string(&mut self) -> Result<String, ParseError> {
let mut result = String::new();
self.advance(); // Skip opening quote
while let Some(ch) = self.current() {
if ch == '\'' {
self.advance();
return Ok(result);
} else {
result.push(ch);
self.advance();
}
}
Err(ParseError {
message: "Unterminated string literal".to_string(),
position: self.position,
})
}
fn read_number(&mut self) -> String {
let mut result = String::new();
let mut has_dot = false;
while let Some(ch) = self.current() {
if ch.is_numeric() {
result.push(ch);
self.advance();
} else if ch == '.' && !has_dot {
has_dot = true;
result.push(ch);
self.advance();
} else {
break;
}
}
result
}
fn next_token(&mut self) -> Result<Token, ParseError> {
self.skip_whitespace();
let ch = match self.current() {
Some(c) => c,
None => return Ok(Token::Eof),
};
match ch {
'(' => {
self.advance();
Ok(Token::LeftParen)
}
')' => {
self.advance();
Ok(Token::RightParen)
}
'[' => {
self.advance();
Ok(Token::LeftBracket)
}
']' => {
self.advance();
Ok(Token::RightBracket)
}
',' => {
self.advance();
Ok(Token::Comma)
}
';' => {
self.advance();
Ok(Token::Semicolon)
}
'*' => {
self.advance();
Ok(Token::Asterisk)
}
'=' => {
self.advance();
Ok(Token::Eq)
}
'!' => {
self.advance();
if self.current() == Some('=') {
self.advance();
Ok(Token::NotEq)
} else {
Err(ParseError {
message: "Expected '=' after '!'".to_string(),
position: self.position,
})
}
}
'>' => {
self.advance();
if self.current() == Some('=') {
self.advance();
Ok(Token::GtEq)
} else {
Ok(Token::Gt)
}
}
'<' => {
self.advance();
match self.current() {
Some('=') => {
self.advance();
if self.current() == Some('>') {
self.advance();
Ok(Token::CosineDistance)
} else {
Ok(Token::LtEq)
}
}
Some('-') => {
self.advance();
if self.current() == Some('>') {
self.advance();
Ok(Token::L2Distance)
} else {
Err(ParseError {
message: "Expected '>' after '<-'".to_string(),
position: self.position,
})
}
}
Some('#') => {
self.advance();
if self.current() == Some('>') {
self.advance();
Ok(Token::DotProduct)
} else {
Err(ParseError {
message: "Expected '>' after '<#'".to_string(),
position: self.position,
})
}
}
_ => Ok(Token::Lt),
}
}
'\'' => Ok(Token::StringLiteral(self.read_string()?)),
_ if ch.is_numeric() => Ok(Token::NumberLiteral(self.read_number())),
_ if ch.is_alphabetic() || ch == '_' => {
let ident = self.read_identifier();
Ok(match ident.to_uppercase().as_str() {
"SELECT" => Token::Select,
"FROM" => Token::From,
"WHERE" => Token::Where,
"INSERT" => Token::Insert,
"INTO" => Token::Into,
"VALUES" => Token::Values,
"CREATE" => Token::Create,
"TABLE" => Token::Table,
"DROP" => Token::Drop,
"ORDER" => {
self.skip_whitespace();
if self.read_identifier().to_uppercase() == "BY" {
Token::OrderBy
} else {
Token::Identifier(ident)
}
}
"LIMIT" => Token::Limit,
"AND" => Token::And,
"OR" => Token::Or,
"NOT" => Token::Not,
"AS" => Token::As,
"TEXT" => Token::Text,
"INTEGER" => Token::Integer,
"REAL" => Token::Real,
"VECTOR" => Token::Vector,
"LIKE" => Token::Like,
_ => Token::Identifier(ident),
})
}
_ => Err(ParseError {
message: format!("Unexpected character: {}", ch),
position: self.position,
}),
}
}
}
/// SQL Parser
pub struct SqlParser {
tokens: Vec<Token>,
position: usize,
}
impl SqlParser {
/// Create a new parser from SQL string
pub fn new(input: &str) -> Result<Self, ParseError> {
let mut tokenizer = Tokenizer::new(input);
let mut tokens = Vec::new();
loop {
let token = tokenizer.next_token()?;
if token == Token::Eof {
tokens.push(token);
break;
}
tokens.push(token);
}
Ok(SqlParser {
tokens,
position: 0,
})
}
/// Parse SQL statement
pub fn parse(&mut self) -> Result<SqlStatement, ParseError> {
let token = self.current().clone();
match token {
Token::Select => self.parse_select(),
Token::Insert => self.parse_insert(),
Token::Create => self.parse_create(),
Token::Drop => self.parse_drop(),
_ => Err(ParseError {
message: format!("Expected SELECT, INSERT, CREATE, or DROP, got {:?}", token),
position: self.position,
}),
}
}
fn current(&self) -> &Token {
self.tokens.get(self.position).unwrap_or(&Token::Eof)
}
fn advance(&mut self) {
if self.position < self.tokens.len() {
self.position += 1;
}
}
fn expect(&mut self, expected: Token) -> Result<(), ParseError> {
let current = self.current().clone();
if current == expected {
self.advance();
Ok(())
} else {
Err(ParseError {
message: format!("Expected {:?}, got {:?}", expected, current),
position: self.position,
})
}
}
fn parse_select(&mut self) -> Result<SqlStatement, ParseError> {
self.expect(Token::Select)?;
let columns = self.parse_select_columns()?;
self.expect(Token::From)?;
let from = self.parse_identifier()?;
let where_clause = if matches!(self.current(), Token::Where) {
self.advance();
Some(self.parse_expression()?)
} else {
None
};
let order_by = if matches!(self.current(), Token::OrderBy) {
self.advance();
Some(self.parse_order_by()?)
} else {
None
};
let limit = if matches!(self.current(), Token::Limit) {
self.advance();
Some(self.parse_number()? as usize)
} else {
None
};
Ok(SqlStatement::Select {
columns,
from,
where_clause,
order_by,
limit,
})
}
fn parse_select_columns(&mut self) -> Result<Vec<SelectColumn>, ParseError> {
if matches!(self.current(), Token::Asterisk) {
self.advance();
return Ok(vec![SelectColumn::Wildcard]);
}
let mut columns = Vec::new();
loop {
let name = self.parse_identifier()?;
columns.push(SelectColumn::Name(name));
if !matches!(self.current(), Token::Comma) {
break;
}
self.advance();
}
Ok(columns)
}
fn parse_insert(&mut self) -> Result<SqlStatement, ParseError> {
self.expect(Token::Insert)?;
self.expect(Token::Into)?;
let table = self.parse_identifier()?;
self.expect(Token::LeftParen)?;
let columns = self.parse_identifier_list()?;
self.expect(Token::RightParen)?;
self.expect(Token::Values)?;
self.expect(Token::LeftParen)?;
let values = self.parse_value_list()?;
self.expect(Token::RightParen)?;
Ok(SqlStatement::Insert {
table,
columns,
values,
})
}
fn parse_create(&mut self) -> Result<SqlStatement, ParseError> {
self.expect(Token::Create)?;
self.expect(Token::Table)?;
let name = self.parse_identifier()?;
self.expect(Token::LeftParen)?;
let columns = self.parse_column_definitions()?;
self.expect(Token::RightParen)?;
Ok(SqlStatement::CreateTable { name, columns })
}
fn parse_drop(&mut self) -> Result<SqlStatement, ParseError> {
self.expect(Token::Drop)?;
self.expect(Token::Table)?;
let table = self.parse_identifier()?;
Ok(SqlStatement::Drop { table })
}
fn parse_column_definitions(&mut self) -> Result<Vec<Column>, ParseError> {
let mut columns = Vec::new();
loop {
let name = self.parse_identifier()?;
let data_type = self.parse_data_type()?;
columns.push(Column { name, data_type });
if !matches!(self.current(), Token::Comma) {
break;
}
self.advance();
}
Ok(columns)
}
fn parse_data_type(&mut self) -> Result<DataType, ParseError> {
match self.current().clone() {
Token::Text => {
self.advance();
Ok(DataType::Text)
}
Token::Integer => {
self.advance();
Ok(DataType::Integer)
}
Token::Real => {
self.advance();
Ok(DataType::Real)
}
Token::Vector => {
self.advance();
self.expect(Token::LeftParen)?;
let dims = self.parse_number()? as usize;
self.expect(Token::RightParen)?;
Ok(DataType::Vector(dims))
}
_ => Err(ParseError {
message: "Expected data type (TEXT, INTEGER, REAL, or VECTOR)".to_string(),
position: self.position,
}),
}
}
fn parse_expression(&mut self) -> Result<Expression, ParseError> {
self.parse_or_expression()
}
fn parse_or_expression(&mut self) -> Result<Expression, ParseError> {
let mut left = self.parse_and_expression()?;
while matches!(self.current(), Token::Or) {
self.advance();
let right = self.parse_and_expression()?;
left = Expression::Or(Box::new(left), Box::new(right));
}
Ok(left)
}
fn parse_and_expression(&mut self) -> Result<Expression, ParseError> {
let mut left = self.parse_comparison_expression()?;
while matches!(self.current(), Token::And) {
self.advance();
let right = self.parse_comparison_expression()?;
left = Expression::And(Box::new(left), Box::new(right));
}
Ok(left)
}
fn parse_comparison_expression(&mut self) -> Result<Expression, ParseError> {
let left = self.parse_primary_expression()?;
let op = match self.current() {
Token::Eq => BinaryOperator::Eq,
Token::NotEq => BinaryOperator::NotEq,
Token::Gt => BinaryOperator::Gt,
Token::GtEq => BinaryOperator::GtEq,
Token::Lt => BinaryOperator::Lt,
Token::LtEq => BinaryOperator::LtEq,
Token::Like => BinaryOperator::Like,
_ => return Ok(left),
};
self.advance();
let right = self.parse_primary_expression()?;
Ok(Expression::BinaryOp {
left: Box::new(left),
op,
right: Box::new(right),
})
}
fn parse_primary_expression(&mut self) -> Result<Expression, ParseError> {
match self.current().clone() {
Token::Identifier(name) => {
self.advance();
Ok(Expression::Column(name))
}
Token::StringLiteral(s) => {
self.advance();
Ok(Expression::Literal(Value::Text(s)))
}
Token::NumberLiteral(n) => {
self.advance();
let value = if n.contains('.') {
Value::Real(n.parse().unwrap())
} else {
Value::Integer(n.parse().unwrap())
};
Ok(Expression::Literal(value))
}
Token::LeftBracket => {
self.advance();
let vec = self.parse_vector_literal()?;
self.expect(Token::RightBracket)?;
Ok(Expression::VectorLiteral(vec))
}
Token::Not => {
self.advance();
let expr = self.parse_primary_expression()?;
Ok(Expression::Not(Box::new(expr)))
}
_ => Err(ParseError {
message: format!("Unexpected token in expression: {:?}", self.current()),
position: self.position,
}),
}
}
fn parse_order_by(&mut self) -> Result<OrderBy, ParseError> {
// Parse column <-> vector or column <=> vector
let column = self.parse_identifier()?;
let metric = match self.current() {
Token::L2Distance => {
self.advance();
DistanceMetric::L2
}
Token::CosineDistance => {
self.advance();
DistanceMetric::Cosine
}
Token::DotProduct => {
self.advance();
DistanceMetric::DotProduct
}
_ => {
return Err(ParseError {
message: "Expected distance operator (<->, <=>, or <#>)".to_string(),
position: self.position,
});
}
};
let vector = if matches!(self.current(), Token::LeftBracket) {
self.advance();
let vec = self.parse_vector_literal()?;
self.expect(Token::RightBracket)?;
vec
} else {
return Err(ParseError {
message: "Expected vector literal after distance operator".to_string(),
position: self.position,
});
};
Ok(OrderBy {
expression: Expression::Distance {
column,
metric,
vector,
},
direction: OrderDirection::Asc,
})
}
fn parse_identifier(&mut self) -> Result<String, ParseError> {
match self.current().clone() {
Token::Identifier(name) => {
self.advance();
Ok(name)
}
_ => Err(ParseError {
message: "Expected identifier".to_string(),
position: self.position,
}),
}
}
fn parse_identifier_list(&mut self) -> Result<Vec<String>, ParseError> {
let mut identifiers = Vec::new();
loop {
identifiers.push(self.parse_identifier()?);
if !matches!(self.current(), Token::Comma) {
break;
}
self.advance();
}
Ok(identifiers)
}
fn parse_value_list(&mut self) -> Result<Vec<Value>, ParseError> {
let mut values = Vec::new();
loop {
values.push(self.parse_value()?);
if !matches!(self.current(), Token::Comma) {
break;
}
self.advance();
}
Ok(values)
}
fn parse_value(&mut self) -> Result<Value, ParseError> {
match self.current().clone() {
Token::StringLiteral(s) => {
self.advance();
Ok(Value::Text(s))
}
Token::NumberLiteral(n) => {
self.advance();
if n.contains('.') {
Ok(Value::Real(n.parse().unwrap()))
} else {
Ok(Value::Integer(n.parse().unwrap()))
}
}
Token::LeftBracket => {
self.advance();
let vec = self.parse_vector_literal()?;
self.expect(Token::RightBracket)?;
Ok(Value::Vector(vec))
}
_ => Err(ParseError {
message: format!("Expected value, got {:?}", self.current()),
position: self.position,
}),
}
}
fn parse_vector_literal(&mut self) -> Result<Vec<f32>, ParseError> {
let mut values = Vec::new();
loop {
let n = self.parse_number()?;
values.push(n as f32);
if !matches!(self.current(), Token::Comma) {
break;
}
self.advance();
}
Ok(values)
}
fn parse_number(&mut self) -> Result<f64, ParseError> {
match self.current().clone() {
Token::NumberLiteral(n) => {
self.advance();
n.parse().map_err(|_| ParseError {
message: format!("Invalid number: {}", n),
position: self.position,
})
}
_ => Err(ParseError {
message: "Expected number".to_string(),
position: self.position,
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_create_table() {
let sql = "CREATE TABLE documents (id TEXT, content TEXT, embedding VECTOR(384))";
let mut parser = SqlParser::new(sql).unwrap();
let stmt = parser.parse().unwrap();
match stmt {
SqlStatement::CreateTable { name, columns } => {
assert_eq!(name, "documents");
assert_eq!(columns.len(), 3);
assert_eq!(columns[2].data_type, DataType::Vector(384));
}
_ => panic!("Expected CreateTable"),
}
}
#[test]
fn test_parse_insert() {
let sql =
"INSERT INTO documents (id, content, embedding) VALUES ('1', 'hello', [1.0, 2.0, 3.0])";
let mut parser = SqlParser::new(sql).unwrap();
let stmt = parser.parse().unwrap();
match stmt {
SqlStatement::Insert {
table,
columns,
values,
} => {
assert_eq!(table, "documents");
assert_eq!(columns.len(), 3);
assert_eq!(values.len(), 3);
}
_ => panic!("Expected Insert"),
}
}
#[test]
fn test_parse_select_with_vector_search() {
let sql = "SELECT * FROM documents ORDER BY embedding <-> [1.0, 2.0, 3.0] LIMIT 5";
let mut parser = SqlParser::new(sql).unwrap();
let stmt = parser.parse().unwrap();
match stmt {
SqlStatement::Select {
order_by, limit, ..
} => {
assert!(order_by.is_some());
assert_eq!(limit, Some(5));
}
_ => panic!("Expected Select"),
}
}
}

View File

@@ -0,0 +1,147 @@
// Integration tests for SQL engine
#[cfg(test)]
mod tests {
use crate::sql::{SqlEngine, SqlParser};
#[test]
fn test_full_workflow() {
let engine = SqlEngine::new();
// Create table
let create_sql = "CREATE TABLE documents (id TEXT, content TEXT, embedding VECTOR(384))";
let mut parser = SqlParser::new(create_sql).unwrap();
let stmt = parser.parse().unwrap();
engine.execute(stmt).unwrap();
// Insert data
let insert_sql = "INSERT INTO documents (id, content, embedding) VALUES ('doc1', 'hello world', [1.0, 2.0, 3.0])";
let mut parser = SqlParser::new(insert_sql).unwrap();
let stmt = parser.parse().unwrap();
// This will fail due to dimension mismatch (3 vs 384), but tests the flow
let result = engine.execute(stmt);
assert!(result.is_err()); // Expected error due to dimension mismatch
}
#[test]
fn test_vector_similarity_search() {
let engine = SqlEngine::new();
// Create table with small dimensions for testing
let create_sql = "CREATE TABLE docs (id TEXT, embedding VECTOR(3))";
let mut parser = SqlParser::new(create_sql).unwrap();
let stmt = parser.parse().unwrap();
engine.execute(stmt).unwrap();
// Insert test data
for i in 0..10 {
let insert_sql = format!(
"INSERT INTO docs (id, embedding) VALUES ('doc{}', [{}, {}, {}])",
i,
i,
i * 2,
i * 3
);
let mut parser = SqlParser::new(&insert_sql).unwrap();
let stmt = parser.parse().unwrap();
engine.execute(stmt).unwrap();
}
// Search for similar vectors
let search_sql = "SELECT * FROM docs ORDER BY embedding <-> [5.0, 10.0, 15.0] LIMIT 3";
let mut parser = SqlParser::new(search_sql).unwrap();
let stmt = parser.parse().unwrap();
let result = engine.execute(stmt).unwrap();
assert_eq!(result.rows.len(), 3);
// The closest vector should be [5, 10, 15]
assert!(result.rows[0].get("id").is_some());
}
#[test]
fn test_metadata_filtering() {
let engine = SqlEngine::new();
// Create table
let create_sql = "CREATE TABLE docs (id TEXT, category TEXT, embedding VECTOR(3))";
let mut parser = SqlParser::new(create_sql).unwrap();
let stmt = parser.parse().unwrap();
engine.execute(stmt).unwrap();
// Insert data with categories
let categories = vec!["tech", "sports", "tech", "news", "sports"];
for (i, cat) in categories.iter().enumerate() {
let insert_sql =
format!(
"INSERT INTO docs (id, category, embedding) VALUES ('doc{}', '{}', [{}, {}, {}])",
i, cat, i, i * 2, i * 3
);
let mut parser = SqlParser::new(&insert_sql).unwrap();
let stmt = parser.parse().unwrap();
engine.execute(stmt).unwrap();
}
// Search with filter
let search_sql = "SELECT * FROM docs WHERE category = 'tech' ORDER BY embedding <-> [2.0, 4.0, 6.0] LIMIT 2";
let mut parser = SqlParser::new(search_sql).unwrap();
let stmt = parser.parse().unwrap();
let result = engine.execute(stmt).unwrap();
// VectorDB filtering may not be fully precise, so we check for at least 1 result
assert!(result.rows.len() >= 1);
assert!(result.rows.len() <= 2);
// All results should have category = 'tech'
for row in &result.rows {
if let Some(category) = row.get("category") {
assert_eq!(category.to_string(), "'tech'");
}
}
}
#[test]
fn test_drop_table() {
let engine = SqlEngine::new();
// Create table
let create_sql = "CREATE TABLE temp (id TEXT, embedding VECTOR(3))";
let mut parser = SqlParser::new(create_sql).unwrap();
let stmt = parser.parse().unwrap();
engine.execute(stmt).unwrap();
assert_eq!(engine.list_tables().len(), 1);
// Drop table
let drop_sql = "DROP TABLE temp";
let mut parser = SqlParser::new(drop_sql).unwrap();
let stmt = parser.parse().unwrap();
engine.execute(stmt).unwrap();
assert_eq!(engine.list_tables().len(), 0);
}
#[test]
fn test_cosine_distance() {
let engine = SqlEngine::new();
let create_sql = "CREATE TABLE docs (id TEXT, embedding VECTOR(3))";
let mut parser = SqlParser::new(create_sql).unwrap();
engine.execute(parser.parse().unwrap()).unwrap();
// Insert normalized vectors for cosine similarity
let insert_sql = "INSERT INTO docs (id, embedding) VALUES ('doc1', [1.0, 0.0, 0.0])";
let mut parser = SqlParser::new(insert_sql).unwrap();
engine.execute(parser.parse().unwrap()).unwrap();
let insert_sql = "INSERT INTO docs (id, embedding) VALUES ('doc2', [0.0, 1.0, 0.0])";
let mut parser = SqlParser::new(insert_sql).unwrap();
engine.execute(parser.parse().unwrap()).unwrap();
// Search using cosine distance
let search_sql = "SELECT * FROM docs ORDER BY embedding <=> [0.9, 0.1, 0.0] LIMIT 1";
let mut parser = SqlParser::new(search_sql).unwrap();
let result = engine.execute(parser.parse().unwrap()).unwrap();
assert_eq!(result.rows.len(), 1);
// Should return doc1 as it's more similar to [0.9, 0.1, 0.0]
}
}