Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
1154
vendor/ruvector/crates/ruvector-graph/ARCHITECTURE.md
vendored
Normal file
1154
vendor/ruvector/crates/ruvector-graph/ARCHITECTURE.md
vendored
Normal file
File diff suppressed because it is too large
Load Diff
166
vendor/ruvector/crates/ruvector-graph/Cargo.toml
vendored
Normal file
166
vendor/ruvector/crates/ruvector-graph/Cargo.toml
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
[package]
|
||||
name = "ruvector-graph"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
repository.workspace = true
|
||||
readme = "README.md"
|
||||
description = "Distributed Neo4j-compatible hypergraph database with SIMD optimization"
|
||||
|
||||
[dependencies]
|
||||
# RuVector dependencies
|
||||
ruvector-core = { version = "2.0.1", path = "../ruvector-core", default-features = false, features = ["simd", "parallel"] }
|
||||
ruvector-raft = { version = "2.0.1", path = "../ruvector-raft", optional = true }
|
||||
ruvector-cluster = { version = "2.0.1", path = "../ruvector-cluster", optional = true }
|
||||
ruvector-replication = { version = "2.0.1", path = "../ruvector-replication", optional = true }
|
||||
|
||||
# Storage and indexing (optional for WASM)
|
||||
redb = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true, optional = true }
|
||||
hnsw_rs = { workspace = true, optional = true }
|
||||
|
||||
# SIMD and performance
|
||||
simsimd = { workspace = true, optional = true }
|
||||
rayon = { workspace = true }
|
||||
crossbeam = { workspace = true }
|
||||
num_cpus = "1.16"
|
||||
|
||||
# Serialization
|
||||
rkyv = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Async runtime (optional for WASM)
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "sync", "macros", "time", "net"], optional = true }
|
||||
futures = { workspace = true, optional = true }
|
||||
|
||||
# Error handling and logging
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
# Data structures
|
||||
dashmap = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
|
||||
# Math and numerics
|
||||
ndarray = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_distr = { workspace = true }
|
||||
ordered-float = "4.2"
|
||||
|
||||
# Time and UUID
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4", "serde"] }
|
||||
|
||||
# Graph algorithms and partitioning
|
||||
petgraph = "0.6"
|
||||
roaring = "0.10" # Roaring bitmaps for label indexes
|
||||
|
||||
# Query parsing (Cypher)
|
||||
nom = "7.1"
|
||||
nom_locate = "4.2"
|
||||
pest = { version = "2.7", optional = true }
|
||||
pest_derive = { version = "2.7", optional = true }
|
||||
lalrpop-util = { version = "0.21", optional = true }
|
||||
|
||||
# Cache
|
||||
lru = "0.16"
|
||||
moka = { version = "0.12", features = ["future"], optional = true }
|
||||
|
||||
# Compression (for storage optimization, optional for WASM)
|
||||
zstd = { version = "0.13", optional = true }
|
||||
lz4 = { version = "1.24", optional = true }
|
||||
|
||||
# Networking (for federation)
|
||||
tonic = { version = "0.12", features = ["transport"], optional = true }
|
||||
prost = { version = "0.13", optional = true }
|
||||
tower = { version = "0.4", optional = true }
|
||||
hyper = { version = "1.4", optional = true }
|
||||
|
||||
# Hashing for sharding
|
||||
blake3 = { version = "1.5", optional = true }
|
||||
xxhash-rust = { version = "0.8", features = ["xxh3"], optional = true }
|
||||
|
||||
# Metrics
|
||||
prometheus = { version = "0.13", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { workspace = true }
|
||||
proptest = { workspace = true }
|
||||
mockall = { workspace = true }
|
||||
tempfile = "3.13"
|
||||
tracing-subscriber = { workspace = true }
|
||||
tokio-test = "0.4"
|
||||
|
||||
# Benchmark datasets
|
||||
csv = "1.3"
|
||||
|
||||
[build-dependencies]
|
||||
pest_generator = "2.7"
|
||||
|
||||
[features]
|
||||
default = ["full"]
|
||||
|
||||
# Full feature set (non-WASM)
|
||||
full = ["simd", "storage", "async-runtime", "compression", "hnsw_rs", "ruvector-core/hnsw"]
|
||||
|
||||
# SIMD optimizations
|
||||
simd = ["ruvector-core/simd", "simsimd"]
|
||||
|
||||
# Storage backends
|
||||
storage = ["redb", "memmap2"]
|
||||
|
||||
# Async runtime support
|
||||
async-runtime = ["tokio", "futures", "moka"]
|
||||
|
||||
# Compression support
|
||||
compression = ["zstd", "lz4"]
|
||||
|
||||
# WASM-compatible minimal build (parser + core graph operations)
|
||||
wasm = []
|
||||
|
||||
# Distributed deployment with RAFT
|
||||
distributed = ["ruvector-raft", "ruvector-cluster", "ruvector-replication", "blake3", "xxhash-rust", "full"]
|
||||
|
||||
# Cross-cluster federation
|
||||
federation = ["tonic", "prost", "tower", "hyper", "distributed"]
|
||||
|
||||
# Advanced query optimization
|
||||
jit = [] # JIT compilation for hot paths (future)
|
||||
|
||||
# Monitoring and metrics
|
||||
metrics = ["prometheus"]
|
||||
|
||||
# Full-text search support
|
||||
fulltext = []
|
||||
|
||||
# Geospatial indexing
|
||||
geospatial = []
|
||||
|
||||
# Temporal graph support (time-varying graphs)
|
||||
temporal = []
|
||||
|
||||
# Query parser implementations
|
||||
cypher-pest = ["pest", "pest_derive"]
|
||||
cypher-lalrpop = ["lalrpop-util"]
|
||||
|
||||
[[example]]
|
||||
name = "test_cypher_parser"
|
||||
path = "examples/test_cypher_parser.rs"
|
||||
|
||||
[[bench]]
|
||||
name = "new_capabilities_bench"
|
||||
harness = false
|
||||
|
||||
[lib]
|
||||
crate-type = ["rlib"]
|
||||
bench = false
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
285
vendor/ruvector/crates/ruvector-graph/README.md
vendored
Normal file
285
vendor/ruvector/crates/ruvector-graph/README.md
vendored
Normal file
@@ -0,0 +1,285 @@
|
||||
# Ruvector Graph
|
||||
|
||||
[](https://crates.io/crates/ruvector-graph)
|
||||
[](https://docs.rs/ruvector-graph)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.rust-lang.org)
|
||||
|
||||
**A graph database with Cypher queries, hyperedges, and vector search -- all in one crate.**
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
ruvector-graph = "0.1.1"
|
||||
```
|
||||
|
||||
Most graph databases make you choose: you can have relationships *or* vector search, a query language *or* raw traversals, pairwise edges *or* nothing. `ruvector-graph` gives you all of them together. Write familiar Cypher queries like Neo4j, attach vector embeddings to any node for semantic search, and model complex group relationships with hyperedges that connect three or more nodes at once. It runs on servers, in browsers via WASM, and across clusters with built-in RAFT consensus. Part of the [RuVector](https://github.com/ruvnet/ruvector) ecosystem.
|
||||
|
||||
| | ruvector-graph | Neo4j / Typical Graph DB | Vector DB + Custom Glue |
|
||||
|---|---|---|---|
|
||||
| **Query language** | Full Cypher parser built-in | Cypher (Neo4j) or proprietary | No graph queries |
|
||||
| **Hyperedges** | Native -- one edge connects N nodes | Pairwise only -- workarounds needed | Not applicable |
|
||||
| **Vector search** | HNSW on every node, semantic similarity | Separate plugin or not available | Vectors only, no graph structure |
|
||||
| **SIMD acceleration** | SimSIMD hardware-optimized ops | JVM-based | Varies |
|
||||
| **Browser / WASM** | `default-features = false, features = ["wasm"]` | Server only | Server only |
|
||||
| **Distributed** | Built-in RAFT consensus + federation | Enterprise tier (paid) | Varies |
|
||||
| **Cost** | Free, open source (MIT) | Community or paid license | Varies |
|
||||
|
||||
## Key Features
|
||||
|
||||
| Feature | What It Does | Why It Matters |
|
||||
|---------|-------------|----------------|
|
||||
| **Cypher Engine** | Parse and execute Cypher queries -- `MATCH (a)-[:KNOWS]->(b)` | Use a query language you already know instead of raw traversal code |
|
||||
| **Hypergraph Model** | Edges connect any number of nodes, not just pairs | Model meetings, co-authorships, reactions -- any group relationship -- natively |
|
||||
| **Vector Embeddings** | Attach embeddings to nodes, run HNSW similarity search | Combine "who is connected to whom" with "what is semantically similar" |
|
||||
| **Property Graph** | Rich JSON properties on every node and edge | Store real data on your graph elements, not just IDs |
|
||||
| **Label Indexes** | Roaring bitmap indexes for fast label lookups | Filter millions of nodes by label in microseconds |
|
||||
| **SIMD Optimized** | Hardware-accelerated distance calculations via SimSIMD | Faster vector operations without changing your code |
|
||||
| **Distributed Mode** | RAFT consensus for multi-node deployments | Scale out without bolting on a separate coordination layer |
|
||||
| **Federation** | Cross-cluster graph queries | Query across data centers as if they were one graph |
|
||||
| **Compression** | ZSTD and LZ4 for storage | Smaller on disk without sacrificing read speed |
|
||||
| **WASM Compatible** | Run in browsers with WebAssembly | Same graph engine on server and client |
|
||||
|
||||
## Installation
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
ruvector-graph = "0.1.1"
|
||||
```
|
||||
|
||||
### Feature Flags
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
# Full feature set
|
||||
ruvector-graph = { version = "0.1.1", features = ["full"] }
|
||||
|
||||
# Minimal WASM-compatible build
|
||||
ruvector-graph = { version = "0.1.1", default-features = false, features = ["wasm"] }
|
||||
|
||||
# Distributed deployment
|
||||
ruvector-graph = { version = "0.1.1", features = ["distributed"] }
|
||||
```
|
||||
|
||||
Available features:
|
||||
- `full` (default): Complete feature set with all optimizations
|
||||
- `simd`: SIMD-optimized operations
|
||||
- `storage`: Persistent storage with redb
|
||||
- `async-runtime`: Tokio async support
|
||||
- `compression`: ZSTD/LZ4 compression
|
||||
- `distributed`: RAFT consensus support
|
||||
- `federation`: Cross-cluster federation
|
||||
- `wasm`: WebAssembly-compatible minimal build
|
||||
- `metrics`: Prometheus monitoring
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Create a Graph
|
||||
|
||||
```rust
|
||||
use ruvector_graph::{Graph, Node, Edge, GraphConfig};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Create a new graph
|
||||
let config = GraphConfig::default();
|
||||
let graph = Graph::new(config)?;
|
||||
|
||||
// Create nodes
|
||||
let alice = graph.create_node(Node {
|
||||
labels: vec!["Person".to_string()],
|
||||
properties: serde_json::json!({
|
||||
"name": "Alice",
|
||||
"age": 30
|
||||
}),
|
||||
..Default::default()
|
||||
})?;
|
||||
|
||||
let bob = graph.create_node(Node {
|
||||
labels: vec!["Person".to_string()],
|
||||
properties: serde_json::json!({
|
||||
"name": "Bob",
|
||||
"age": 25
|
||||
}),
|
||||
..Default::default()
|
||||
})?;
|
||||
|
||||
// Create relationship
|
||||
graph.create_edge(Edge {
|
||||
label: "KNOWS".to_string(),
|
||||
source: alice.id,
|
||||
target: bob.id,
|
||||
properties: serde_json::json!({
|
||||
"since": 2020
|
||||
}),
|
||||
..Default::default()
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
### Cypher Queries
|
||||
|
||||
```rust
|
||||
use ruvector_graph::{Graph, CypherExecutor};
|
||||
|
||||
// Execute Cypher query
|
||||
let executor = CypherExecutor::new(&graph);
|
||||
let results = executor.execute("
|
||||
MATCH (p:Person)-[:KNOWS]->(friend:Person)
|
||||
WHERE p.name = 'Alice'
|
||||
RETURN friend.name AS name, friend.age AS age
|
||||
")?;
|
||||
|
||||
for row in results {
|
||||
println!("Friend: {} (age {})", row["name"], row["age"]);
|
||||
}
|
||||
```
|
||||
|
||||
### Vector-Enhanced Graph
|
||||
|
||||
```rust
|
||||
use ruvector_graph::{Graph, VectorConfig};
|
||||
|
||||
// Enable vector embeddings on nodes
|
||||
let config = GraphConfig {
|
||||
vector_config: Some(VectorConfig {
|
||||
dimensions: 384,
|
||||
distance_metric: DistanceMetric::Cosine,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let graph = Graph::new(config)?;
|
||||
|
||||
// Create node with embedding
|
||||
let node = graph.create_node(Node {
|
||||
labels: vec!["Document".to_string()],
|
||||
properties: serde_json::json!({"title": "Introduction to Graphs"}),
|
||||
embedding: Some(vec![0.1, 0.2, 0.3, /* ... 384 dims */]),
|
||||
..Default::default()
|
||||
})?;
|
||||
|
||||
// Semantic similarity search
|
||||
let similar = graph.search_similar_nodes(
|
||||
vec![0.1, 0.2, 0.3, /* query vector */],
|
||||
10, // top-k
|
||||
Some(vec!["Document".to_string()]), // filter by labels
|
||||
)?;
|
||||
```
|
||||
|
||||
### Hyperedges
|
||||
|
||||
```rust
|
||||
use ruvector_graph::{Graph, Hyperedge};
|
||||
|
||||
// Create a hyperedge connecting multiple nodes
|
||||
let meeting = graph.create_hyperedge(Hyperedge {
|
||||
label: "PARTICIPATED_IN".to_string(),
|
||||
nodes: vec![alice.id, bob.id, charlie.id],
|
||||
properties: serde_json::json!({
|
||||
"event": "Team Meeting",
|
||||
"date": "2024-01-15"
|
||||
}),
|
||||
..Default::default()
|
||||
})?;
|
||||
```
|
||||
|
||||
## API Overview
|
||||
|
||||
### Core Types
|
||||
|
||||
```rust
|
||||
// Node in the graph
|
||||
pub struct Node {
|
||||
pub id: NodeId,
|
||||
pub labels: Vec<String>,
|
||||
pub properties: serde_json::Value,
|
||||
pub embedding: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
// Edge connecting two nodes
|
||||
pub struct Edge {
|
||||
pub id: EdgeId,
|
||||
pub label: String,
|
||||
pub source: NodeId,
|
||||
pub target: NodeId,
|
||||
pub properties: serde_json::Value,
|
||||
}
|
||||
|
||||
// Hyperedge connecting multiple nodes
|
||||
pub struct Hyperedge {
|
||||
pub id: HyperedgeId,
|
||||
pub label: String,
|
||||
pub nodes: Vec<NodeId>,
|
||||
pub properties: serde_json::Value,
|
||||
}
|
||||
```
|
||||
|
||||
### Graph Operations
|
||||
|
||||
```rust
|
||||
impl Graph {
|
||||
// Node operations
|
||||
pub fn create_node(&self, node: Node) -> Result<Node>;
|
||||
pub fn get_node(&self, id: &NodeId) -> Result<Option<Node>>;
|
||||
pub fn update_node(&self, node: Node) -> Result<Node>;
|
||||
pub fn delete_node(&self, id: &NodeId) -> Result<bool>;
|
||||
|
||||
// Edge operations
|
||||
pub fn create_edge(&self, edge: Edge) -> Result<Edge>;
|
||||
pub fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>>;
|
||||
pub fn delete_edge(&self, id: &EdgeId) -> Result<bool>;
|
||||
|
||||
// Traversal
|
||||
pub fn neighbors(&self, id: &NodeId, direction: Direction) -> Result<Vec<Node>>;
|
||||
pub fn traverse(&self, start: &NodeId, config: TraversalConfig) -> Result<Vec<Path>>;
|
||||
|
||||
// Vector search
|
||||
pub fn search_similar_nodes(&self, query: Vec<f32>, k: usize, labels: Option<Vec<String>>) -> Result<Vec<Node>>;
|
||||
}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
### Benchmarks (1M Nodes, 10M Edges)
|
||||
|
||||
```
|
||||
Operation Latency (p50) Throughput
|
||||
-----------------------------------------------------
|
||||
Node lookup ~0.1ms 100K ops/s
|
||||
Edge traversal ~0.5ms 50K ops/s
|
||||
1-hop neighbors ~1ms 20K ops/s
|
||||
Cypher simple query ~5ms 5K ops/s
|
||||
Vector similarity ~2ms 10K ops/s
|
||||
```
|
||||
|
||||
## Related Crates
|
||||
|
||||
- **[ruvector-core](../ruvector-core/)** - Core vector database engine
|
||||
- **[ruvector-graph-node](../ruvector-graph-node/)** - Node.js bindings
|
||||
- **[ruvector-graph-wasm](../ruvector-graph-wasm/)** - WebAssembly bindings
|
||||
- **[ruvector-raft](../ruvector-raft/)** - RAFT consensus for distributed mode
|
||||
- **[ruvector-cluster](../ruvector-cluster/)** - Clustering and sharding
|
||||
|
||||
## Documentation
|
||||
|
||||
- **[RuVector README](../../README.md)** - Complete project overview
|
||||
- **[API Documentation](https://docs.rs/ruvector-graph)** - Full API reference
|
||||
- **[GitHub Repository](https://github.com/ruvnet/ruvector)** - Source code
|
||||
|
||||
## License
|
||||
|
||||
**MIT License** - see [LICENSE](../../LICENSE) for details.
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**Part of [RuVector](https://github.com/ruvnet/ruvector) - Built by [rUv](https://ruv.io)**
|
||||
|
||||
[](https://github.com/ruvnet/ruvector)
|
||||
|
||||
[Documentation](https://docs.rs/ruvector-graph) | [Crates.io](https://crates.io/crates/ruvector-graph) | [GitHub](https://github.com/ruvnet/ruvector)
|
||||
|
||||
</div>
|
||||
58
vendor/ruvector/crates/ruvector-graph/benches/cypher_parser.rs
vendored
Normal file
58
vendor/ruvector/crates/ruvector-graph/benches/cypher_parser.rs
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use ruvector_graph::cypher::parser::parse_cypher;
|
||||
|
||||
fn bench_simple_match(c: &mut Criterion) {
|
||||
c.bench_function("parse simple MATCH", |b| {
|
||||
b.iter(|| parse_cypher(black_box("MATCH (n:Person) RETURN n")))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_complex_match(c: &mut Criterion) {
|
||||
c.bench_function("parse complex MATCH with WHERE", |b| {
|
||||
b.iter(|| {
|
||||
parse_cypher(black_box(
|
||||
"MATCH (a:Person)-[r:KNOWS]->(b:Person) WHERE a.age > 30 AND b.name = 'Alice' RETURN a.name, b.name, r.since ORDER BY r.since DESC LIMIT 10"
|
||||
))
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_create_query(c: &mut Criterion) {
|
||||
c.bench_function("parse CREATE query", |b| {
|
||||
b.iter(|| {
|
||||
parse_cypher(black_box(
|
||||
"CREATE (n:Person {name: 'Bob', age: 30, email: 'bob@example.com'})",
|
||||
))
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_hyperedge_query(c: &mut Criterion) {
|
||||
c.bench_function("parse hyperedge query", |b| {
|
||||
b.iter(|| {
|
||||
parse_cypher(black_box(
|
||||
"MATCH (person)-[r:TRANSACTION]->(acc1:Account, acc2:Account, merchant:Merchant) WHERE r.amount > 1000 RETURN person, r, acc1, acc2, merchant"
|
||||
))
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_aggregation_query(c: &mut Criterion) {
|
||||
c.bench_function("parse aggregation query", |b| {
|
||||
b.iter(|| {
|
||||
parse_cypher(black_box(
|
||||
"MATCH (n:Person) RETURN COUNT(n), AVG(n.age), MAX(n.salary), COLLECT(n.name)",
|
||||
))
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_simple_match,
|
||||
bench_complex_match,
|
||||
bench_create_query,
|
||||
bench_hyperedge_query,
|
||||
bench_aggregation_query
|
||||
);
|
||||
criterion_main!(benches);
|
||||
11
vendor/ruvector/crates/ruvector-graph/benches/distributed_query.rs
vendored
Normal file
11
vendor/ruvector/crates/ruvector-graph/benches/distributed_query.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
// Placeholder benchmark for distributed query
|
||||
// TODO: Implement comprehensive benchmarks
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
fn distributed_query_benchmark(c: &mut Criterion) {
|
||||
c.bench_function("placeholder", |b| b.iter(|| black_box(42)));
|
||||
}
|
||||
|
||||
criterion_group!(benches, distributed_query_benchmark);
|
||||
criterion_main!(benches);
|
||||
324
vendor/ruvector/crates/ruvector-graph/benches/graph_bench.rs
vendored
Normal file
324
vendor/ruvector/crates/ruvector-graph/benches/graph_bench.rs
vendored
Normal file
@@ -0,0 +1,324 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ruvector_graph::types::{EdgeId, NodeId, Properties, PropertyValue};
|
||||
use ruvector_graph::{Edge, GraphDB, Node};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Helper to create test graph
|
||||
fn create_test_graph() -> GraphDB {
|
||||
GraphDB::new()
|
||||
}
|
||||
|
||||
/// Benchmark: Single node insertion
|
||||
fn bench_node_insertion_single(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("node_insertion_single");
|
||||
|
||||
for size in [1, 10, 100, 1000].iter() {
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| {
|
||||
b.iter(|| {
|
||||
let graph = create_test_graph();
|
||||
for i in 0..size {
|
||||
let mut props = Properties::new();
|
||||
props.insert(
|
||||
"name".to_string(),
|
||||
PropertyValue::String(format!("node_{}", i)),
|
||||
);
|
||||
props.insert("value".to_string(), PropertyValue::Integer(i as i64));
|
||||
|
||||
let node_id = NodeId(format!("node_{}", i));
|
||||
let node = Node::new(node_id, vec!["Person".to_string()], props);
|
||||
black_box(graph.create_node(node).unwrap());
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Batch node insertion
|
||||
fn bench_node_insertion_batch(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("node_insertion_batch");
|
||||
|
||||
for batch_size in [100, 1000, 10000].iter() {
|
||||
group.throughput(Throughput::Elements(*batch_size as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(batch_size),
|
||||
batch_size,
|
||||
|b, &batch_size| {
|
||||
b.iter(|| {
|
||||
let graph = create_test_graph();
|
||||
for i in 0..batch_size {
|
||||
let mut props = Properties::new();
|
||||
props.insert(
|
||||
"name".to_string(),
|
||||
PropertyValue::String(format!("node_{}", i)),
|
||||
);
|
||||
props.insert("value".to_string(), PropertyValue::Integer(i as i64));
|
||||
|
||||
let node_id = NodeId(format!("batch_node_{}", i));
|
||||
let node = Node::new(node_id, vec!["Person".to_string()], props);
|
||||
black_box(graph.create_node(node).unwrap());
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Bulk node insertion (optimized path)
|
||||
fn bench_node_insertion_bulk(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("node_insertion_bulk");
|
||||
group.sample_size(10); // Reduce samples for large operations
|
||||
|
||||
for bulk_size in [10000, 100000].iter() {
|
||||
group.throughput(Throughput::Elements(*bulk_size as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(bulk_size),
|
||||
bulk_size,
|
||||
|b, &bulk_size| {
|
||||
b.iter(|| {
|
||||
let graph = create_test_graph();
|
||||
for i in 0..bulk_size {
|
||||
let mut props = Properties::new();
|
||||
props.insert("id".to_string(), PropertyValue::Integer(i as i64));
|
||||
props.insert(
|
||||
"name".to_string(),
|
||||
PropertyValue::String(format!("user_{}", i)),
|
||||
);
|
||||
|
||||
let node_id = NodeId(format!("bulk_user_{}", i));
|
||||
let node = Node::new(node_id, vec!["User".to_string()], props);
|
||||
black_box(graph.create_node(node).unwrap());
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Edge creation
|
||||
fn bench_edge_creation(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("edge_creation");
|
||||
|
||||
// Setup: Create nodes once
|
||||
let graph = Arc::new(create_test_graph());
|
||||
let mut node_ids = Vec::new();
|
||||
for i in 0..1000 {
|
||||
let mut props = Properties::new();
|
||||
props.insert("id".to_string(), PropertyValue::Integer(i as i64));
|
||||
let node_id = NodeId(format!("edge_test_node_{}", i));
|
||||
let node = Node::new(node_id.clone(), vec!["Person".to_string()], props);
|
||||
graph.create_node(node).unwrap();
|
||||
node_ids.push(node_id);
|
||||
}
|
||||
|
||||
for num_edges in [100, 1000].iter() {
|
||||
group.throughput(Throughput::Elements(*num_edges as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(num_edges),
|
||||
num_edges,
|
||||
|b, &num_edges| {
|
||||
let graph = graph.clone();
|
||||
let node_ids = node_ids.clone();
|
||||
b.iter(|| {
|
||||
for i in 0..num_edges {
|
||||
let from = &node_ids[i % node_ids.len()];
|
||||
let to = &node_ids[(i + 1) % node_ids.len()];
|
||||
|
||||
let mut props = Properties::new();
|
||||
props.insert("weight".to_string(), PropertyValue::Float(i as f64));
|
||||
|
||||
let edge_id = EdgeId(format!("edge_{}", i));
|
||||
let edge = Edge::new(
|
||||
edge_id,
|
||||
from.clone(),
|
||||
to.clone(),
|
||||
"KNOWS".to_string(),
|
||||
props,
|
||||
);
|
||||
black_box(graph.create_edge(edge).unwrap());
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Simple node lookup by ID
|
||||
fn bench_query_node_lookup(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("query_node_lookup");
|
||||
|
||||
// Setup: Create 10k nodes (reduced for faster benchmark)
|
||||
let graph = Arc::new(create_test_graph());
|
||||
let mut node_ids = Vec::new();
|
||||
for i in 0..10000 {
|
||||
let mut props = Properties::new();
|
||||
props.insert("id".to_string(), PropertyValue::Integer(i as i64));
|
||||
let node_id = NodeId(format!("lookup_node_{}", i));
|
||||
let node = Node::new(node_id.clone(), vec!["Person".to_string()], props);
|
||||
graph.create_node(node).unwrap();
|
||||
node_ids.push(node_id);
|
||||
}
|
||||
|
||||
group.bench_function("lookup_by_id", |b| {
|
||||
let graph = graph.clone();
|
||||
let node_ids = node_ids.clone();
|
||||
b.iter(|| {
|
||||
let id = &node_ids[black_box(1234 % node_ids.len())];
|
||||
black_box(graph.get_node(id).unwrap());
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Edge lookup
|
||||
fn bench_query_edge_lookup(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("query_edge_lookup");
|
||||
|
||||
// Setup: Create nodes and edges
|
||||
let graph = Arc::new(create_test_graph());
|
||||
let mut node_ids = Vec::new();
|
||||
let mut edge_ids = Vec::new();
|
||||
|
||||
// Create 100 nodes
|
||||
for i in 0..100 {
|
||||
let mut props = Properties::new();
|
||||
props.insert("id".to_string(), PropertyValue::Integer(i as i64));
|
||||
let node_id = NodeId(format!("trav_node_{}", i));
|
||||
let node = Node::new(node_id.clone(), vec!["Person".to_string()], props);
|
||||
graph.create_node(node).unwrap();
|
||||
node_ids.push(node_id);
|
||||
}
|
||||
|
||||
// Create edges (each node has ~5 outgoing edges)
|
||||
for i in 0..node_ids.len() {
|
||||
for j in 0..5 {
|
||||
let to_idx = (i + j + 1) % node_ids.len();
|
||||
let edge_id = EdgeId(format!("trav_edge_{}_{}", i, j));
|
||||
let edge = Edge::new(
|
||||
edge_id.clone(),
|
||||
node_ids[i].clone(),
|
||||
node_ids[to_idx].clone(),
|
||||
"KNOWS".to_string(),
|
||||
Properties::new(),
|
||||
);
|
||||
graph.create_edge(edge).unwrap();
|
||||
edge_ids.push(edge_id);
|
||||
}
|
||||
}
|
||||
|
||||
group.bench_function("edge_by_id", |b| {
|
||||
let graph = graph.clone();
|
||||
let edge_ids = edge_ids.clone();
|
||||
b.iter(|| {
|
||||
let id = &edge_ids[black_box(10 % edge_ids.len())];
|
||||
black_box(graph.get_edge(id).unwrap());
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Get nodes by label
|
||||
fn bench_query_get_by_label(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("query_get_by_label");
|
||||
|
||||
let graph = Arc::new(create_test_graph());
|
||||
|
||||
// Create diverse nodes with different labels
|
||||
for i in 0..1000 {
|
||||
let mut props = Properties::new();
|
||||
props.insert("id".to_string(), PropertyValue::Integer(i as i64));
|
||||
let node_id = NodeId(format!("label_node_{}", i));
|
||||
|
||||
let label = if i % 3 == 0 {
|
||||
"Person"
|
||||
} else if i % 3 == 1 {
|
||||
"Organization"
|
||||
} else {
|
||||
"Location"
|
||||
};
|
||||
|
||||
let node = Node::new(node_id, vec![label.to_string()], props);
|
||||
graph.create_node(node).unwrap();
|
||||
}
|
||||
|
||||
group.bench_function("get_persons", |b| {
|
||||
let graph = graph.clone();
|
||||
b.iter(|| {
|
||||
let nodes = graph.get_nodes_by_label("Person");
|
||||
black_box(nodes.len());
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Memory usage tracking
|
||||
fn bench_memory_usage(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("memory_usage");
|
||||
group.sample_size(10);
|
||||
|
||||
for num_nodes in [1000, 10000].iter() {
|
||||
group.throughput(Throughput::Elements(*num_nodes as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(num_nodes),
|
||||
num_nodes,
|
||||
|b, &num_nodes| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut total_duration = Duration::ZERO;
|
||||
|
||||
for _ in 0..iters {
|
||||
let graph = create_test_graph();
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
for i in 0..*num_nodes {
|
||||
let mut props = Properties::new();
|
||||
props.insert("id".to_string(), PropertyValue::Integer(i as i64));
|
||||
props.insert(
|
||||
"name".to_string(),
|
||||
PropertyValue::String(format!("node_{}", i)),
|
||||
);
|
||||
|
||||
let node_id = NodeId(format!("mem_node_{}", i));
|
||||
let node = Node::new(node_id, vec!["TestNode".to_string()], props);
|
||||
graph.create_node(node).unwrap();
|
||||
}
|
||||
total_duration += start.elapsed();
|
||||
|
||||
// Force drop to measure cleanup
|
||||
drop(graph);
|
||||
}
|
||||
|
||||
total_duration
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_node_insertion_single,
|
||||
bench_node_insertion_batch,
|
||||
bench_node_insertion_bulk,
|
||||
bench_edge_creation,
|
||||
bench_query_node_lookup,
|
||||
bench_query_edge_lookup,
|
||||
bench_query_get_by_label,
|
||||
bench_memory_usage
|
||||
);
|
||||
|
||||
criterion_main!(benches);
|
||||
11
vendor/ruvector/crates/ruvector-graph/benches/graph_traversal.rs
vendored
Normal file
11
vendor/ruvector/crates/ruvector-graph/benches/graph_traversal.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
// Placeholder benchmark for graph traversal
|
||||
// TODO: Implement comprehensive benchmarks
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
fn graph_traversal_benchmark(c: &mut Criterion) {
|
||||
c.bench_function("placeholder", |b| b.iter(|| black_box(42)));
|
||||
}
|
||||
|
||||
criterion_group!(benches, graph_traversal_benchmark);
|
||||
criterion_main!(benches);
|
||||
11
vendor/ruvector/crates/ruvector-graph/benches/hybrid_vector_graph.rs
vendored
Normal file
11
vendor/ruvector/crates/ruvector-graph/benches/hybrid_vector_graph.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
// Placeholder benchmark for hybrid vector graph
|
||||
// TODO: Implement comprehensive benchmarks
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
fn hybrid_vector_graph_benchmark(c: &mut Criterion) {
|
||||
c.bench_function("placeholder", |b| b.iter(|| black_box(42)));
|
||||
}
|
||||
|
||||
criterion_group!(benches, hybrid_vector_graph_benchmark);
|
||||
criterion_main!(benches);
|
||||
251
vendor/ruvector/crates/ruvector-graph/benches/new_capabilities_bench.rs
vendored
Normal file
251
vendor/ruvector/crates/ruvector-graph/benches/new_capabilities_bench.rs
vendored
Normal file
@@ -0,0 +1,251 @@
|
||||
//! Benchmarks for new capabilities
|
||||
//!
|
||||
//! Run with: cargo bench --package ruvector-graph --bench new_capabilities_bench
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use ruvector_graph::cypher::parser::parse_cypher;
|
||||
use ruvector_graph::hybrid::semantic_search::{SemanticSearch, SemanticSearchConfig};
|
||||
use ruvector_graph::hybrid::vector_index::{EmbeddingConfig, HybridIndex, VectorIndexType};
|
||||
|
||||
// ============================================================================
|
||||
// Parser Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_simple_match(c: &mut Criterion) {
|
||||
let query = "MATCH (n:Person) RETURN n";
|
||||
|
||||
c.bench_function("parser/simple_match", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query)))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_relationship_match(c: &mut Criterion) {
|
||||
let query = "MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a, r, b";
|
||||
|
||||
c.bench_function("parser/relationship_match", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query)))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_chained_relationship(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("parser/chained_relationships");
|
||||
|
||||
// 2-hop chain
|
||||
let query_2hop = "MATCH (a)-[r]->(b)-[s]->(c) RETURN a, c";
|
||||
group.bench_function("2_hop", |b| b.iter(|| parse_cypher(black_box(query_2hop))));
|
||||
|
||||
// 3-hop chain
|
||||
let query_3hop = "MATCH (a)-[r]->(b)-[s]->(c)-[t]->(d) RETURN a, d";
|
||||
group.bench_function("3_hop", |b| b.iter(|| parse_cypher(black_box(query_3hop))));
|
||||
|
||||
// 4-hop chain
|
||||
let query_4hop = "MATCH (a)-[r]->(b)-[s]->(c)-[t]->(d)-[u]->(e) RETURN a, e";
|
||||
group.bench_function("4_hop", |b| b.iter(|| parse_cypher(black_box(query_4hop))));
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_mixed_direction_chain(c: &mut Criterion) {
|
||||
let query = "MATCH (a:Person)-[r:KNOWS]->(b:Person)<-[s:MANAGES]-(c:Manager) RETURN a, b, c";
|
||||
|
||||
c.bench_function("parser/mixed_direction_chain", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query)))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_map_literal(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("parser/map_literal");
|
||||
|
||||
// Empty map
|
||||
let query_empty = "MATCH (n) RETURN {}";
|
||||
group.bench_function("empty", |b| b.iter(|| parse_cypher(black_box(query_empty))));
|
||||
|
||||
// Small map (2 keys)
|
||||
let query_small = "MATCH (n) RETURN {name: n.name, age: n.age}";
|
||||
group.bench_function("2_keys", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query_small)))
|
||||
});
|
||||
|
||||
// Medium map (5 keys)
|
||||
let query_medium = "MATCH (n) RETURN {a: n.a, b: n.b, c: n.c, d: n.d, e: n.e}";
|
||||
group.bench_function("5_keys", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query_medium)))
|
||||
});
|
||||
|
||||
// Large map (10 keys)
|
||||
let query_large = "MATCH (n) RETURN {a: n.a, b: n.b, c: n.c, d: n.d, e: n.e, f: n.f, g: n.g, h: n.h, i: n.i, j: n.j}";
|
||||
group.bench_function("10_keys", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query_large)))
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_remove_statement(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("parser/remove");
|
||||
|
||||
// Remove property
|
||||
let query_prop = "MATCH (n:Person) REMOVE n.age RETURN n";
|
||||
group.bench_function("property", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query_prop)))
|
||||
});
|
||||
|
||||
// Remove single label
|
||||
let query_label = "MATCH (n:Person:Employee) REMOVE n:Employee RETURN n";
|
||||
group.bench_function("single_label", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query_label)))
|
||||
});
|
||||
|
||||
// Remove multiple labels
|
||||
let query_multi = "MATCH (n:A:B:C:D) REMOVE n:B:C:D RETURN n";
|
||||
group.bench_function("multi_label", |b| {
|
||||
b.iter(|| parse_cypher(black_box(query_multi)))
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Parse throughput for a realistic multi-clause query: a chained relationship
/// pattern, a compound WHERE predicate, a map-literal projection, ORDER BY,
/// and LIMIT. The raw string's whitespace is part of the benchmark input, so
/// the query text is left exactly as-is.
fn bench_complex_query(c: &mut Criterion) {
    let query = r#"
        MATCH (p:Person)-[r:WORKS_AT]->(c:Company)<-[h:HEADQUARTERED]-(l:Location)
        WHERE p.age > 30 AND c.revenue > 1000000
        RETURN {
            person: p.name,
            company: c.name,
            location: l.city
        }
        ORDER BY p.age DESC
        LIMIT 10
    "#;

    c.bench_function("parser/complex_query", |b| {
        b.iter(|| parse_cypher(black_box(query)))
    });
}
|
||||
|
||||
// ============================================================================
|
||||
// Semantic Search Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn setup_semantic_search(num_vectors: usize, dimensions: usize) -> SemanticSearch {
|
||||
let config = EmbeddingConfig {
|
||||
dimensions,
|
||||
..Default::default()
|
||||
};
|
||||
let index = HybridIndex::new(config).unwrap();
|
||||
index.initialize_index(VectorIndexType::Node).unwrap();
|
||||
|
||||
// Add test embeddings
|
||||
for i in 0..num_vectors {
|
||||
let mut embedding = vec![0.0f32; dimensions];
|
||||
// Create varied embeddings
|
||||
embedding[i % dimensions] = 1.0;
|
||||
embedding[(i + 1) % dimensions] = 0.5;
|
||||
|
||||
index
|
||||
.add_node_embedding(format!("node_{}", i), embedding)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
SemanticSearch::new(index, SemanticSearchConfig::default())
|
||||
}
|
||||
|
||||
fn bench_semantic_search_small(c: &mut Criterion) {
|
||||
let search = setup_semantic_search(100, 128);
|
||||
let query: Vec<f32> = (0..128).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
|
||||
|
||||
c.bench_function("semantic_search/100_vectors_128d", |b| {
|
||||
b.iter(|| search.find_similar_nodes(black_box(&query), 10))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_semantic_search_medium(c: &mut Criterion) {
|
||||
let search = setup_semantic_search(1000, 128);
|
||||
let query: Vec<f32> = (0..128).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
|
||||
|
||||
c.bench_function("semantic_search/1000_vectors_128d", |b| {
|
||||
b.iter(|| search.find_similar_nodes(black_box(&query), 10))
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_semantic_search_dimensions(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("semantic_search/dimensions");
|
||||
|
||||
for dim in [64, 128, 256, 384, 512].iter() {
|
||||
let search = setup_semantic_search(500, *dim);
|
||||
let query: Vec<f32> = (0..*dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(dim), dim, |b, _| {
|
||||
b.iter(|| search.find_similar_nodes(black_box(&query), 10))
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_semantic_search_top_k(c: &mut Criterion) {
|
||||
let search = setup_semantic_search(1000, 128);
|
||||
let query: Vec<f32> = (0..128).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
|
||||
|
||||
let mut group = c.benchmark_group("semantic_search/top_k");
|
||||
|
||||
for k in [1, 5, 10, 25, 50, 100].iter() {
|
||||
group.bench_with_input(BenchmarkId::from_parameter(k), k, |b, &k| {
|
||||
b.iter(|| search.find_similar_nodes(black_box(&query), k))
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Distance Conversion Benchmark (the fix we made)
|
||||
// ============================================================================
|
||||
|
||||
fn bench_distance_conversion(c: &mut Criterion) {
|
||||
let distances: Vec<f32> = (0..10000).map(|i| (i as f32) / 10000.0).collect();
|
||||
|
||||
c.bench_function("semantic_search/distance_conversion_10k", |b| {
|
||||
b.iter(|| {
|
||||
let _: Vec<f32> = distances.iter().map(|d| 1.0 - d).collect();
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_similarity_filtering(c: &mut Criterion) {
|
||||
let distances: Vec<f32> = (0..10000).map(|i| (i as f32) / 10000.0).collect();
|
||||
let min_similarity = 0.7f32;
|
||||
|
||||
c.bench_function("semantic_search/similarity_filter_10k", |b| {
|
||||
b.iter(|| {
|
||||
let _: Vec<f32> = distances
|
||||
.iter()
|
||||
.map(|d| 1.0 - d)
|
||||
.filter(|s| *s >= min_similarity)
|
||||
.collect();
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// Parser micro-benchmarks: basic patterns, chained relationships, map
// literals, REMOVE statements, and one realistic multi-clause query.
criterion_group!(
    parser_benches,
    bench_simple_match,
    bench_relationship_match,
    bench_chained_relationship,
    bench_mixed_direction_chain,
    bench_map_literal,
    bench_remove_statement,
    bench_complex_query,
);
|
||||
|
||||
// Semantic-search benchmarks: index size, dimensionality, top-k scaling, and
// the distance-to-similarity post-processing steps.
criterion_group!(
    semantic_search_benches,
    bench_semantic_search_small,
    bench_semantic_search_medium,
    bench_semantic_search_dimensions,
    bench_semantic_search_top_k,
    bench_distance_conversion,
    bench_similarity_filtering,
);

// Criterion-generated `main`: runs both benchmark groups.
criterion_main!(parser_benches, semantic_search_benches);
|
||||
11
vendor/ruvector/crates/ruvector-graph/benches/query_execution.rs
vendored
Normal file
11
vendor/ruvector/crates/ruvector-graph/benches/query_execution.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
// Placeholder benchmark for query execution
|
||||
// TODO: Implement comprehensive benchmarks
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
fn query_execution_benchmark(c: &mut Criterion) {
|
||||
c.bench_function("placeholder", |b| b.iter(|| black_box(42)));
|
||||
}
|
||||
|
||||
criterion_group!(benches, query_execution_benchmark);
|
||||
criterion_main!(benches);
|
||||
11
vendor/ruvector/crates/ruvector-graph/benches/simd_operations.rs
vendored
Normal file
11
vendor/ruvector/crates/ruvector-graph/benches/simd_operations.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
// Placeholder benchmark for SIMD operations
|
||||
// TODO: Implement comprehensive benchmarks
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
fn simd_operations_benchmark(c: &mut Criterion) {
|
||||
c.bench_function("placeholder", |b| b.iter(|| black_box(42)));
|
||||
}
|
||||
|
||||
criterion_group!(benches, simd_operations_benchmark);
|
||||
criterion_main!(benches);
|
||||
121
vendor/ruvector/crates/ruvector-graph/examples/test_cypher_parser.rs
vendored
Normal file
121
vendor/ruvector/crates/ruvector-graph/examples/test_cypher_parser.rs
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
//! Standalone example demonstrating the Cypher parser
|
||||
//! Run with: cargo run --example test_cypher_parser
|
||||
|
||||
use ruvector_graph::cypher::{ast::*, parse_cypher};
|
||||
|
||||
fn main() {
|
||||
println!("=== Cypher Parser Test Suite ===\n");
|
||||
|
||||
// Test 1: Simple MATCH
|
||||
println!("Test 1: Simple MATCH query");
|
||||
let query1 = "MATCH (n:Person) RETURN n";
|
||||
match parse_cypher(query1) {
|
||||
Ok(ast) => {
|
||||
println!("✓ Parsed successfully");
|
||||
println!(" Statements: {}", ast.statements.len());
|
||||
println!(" Read-only: {}", ast.is_read_only());
|
||||
}
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 2: MATCH with WHERE
|
||||
println!("Test 2: MATCH with WHERE clause");
|
||||
let query2 = "MATCH (n:Person) WHERE n.age > 30 RETURN n.name";
|
||||
match parse_cypher(query2) {
|
||||
Ok(_) => println!("✓ Parsed successfully"),
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 3: Relationship pattern
|
||||
println!("Test 3: Relationship pattern");
|
||||
let query3 = "MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a, r, b";
|
||||
match parse_cypher(query3) {
|
||||
Ok(_) => println!("✓ Parsed successfully"),
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 4: CREATE node
|
||||
println!("Test 4: CREATE node");
|
||||
let query4 = "CREATE (n:Person {name: 'Alice', age: 30})";
|
||||
match parse_cypher(query4) {
|
||||
Ok(ast) => {
|
||||
println!("✓ Parsed successfully");
|
||||
println!(" Read-only: {}", ast.is_read_only());
|
||||
}
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 5: Hyperedge (N-ary relationship)
|
||||
println!("Test 5: Hyperedge pattern");
|
||||
let query5 = "MATCH (a)-[r:TRANSACTION]->(b, c, d) RETURN a, r, b, c, d";
|
||||
match parse_cypher(query5) {
|
||||
Ok(ast) => {
|
||||
println!("✓ Parsed successfully");
|
||||
println!(" Has hyperedges: {}", ast.has_hyperedges());
|
||||
}
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 6: Aggregation functions
|
||||
println!("Test 6: Aggregation functions");
|
||||
let query6 = "MATCH (n:Person) RETURN COUNT(n), AVG(n.age), MAX(n.salary)";
|
||||
match parse_cypher(query6) {
|
||||
Ok(_) => println!("✓ Parsed successfully"),
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 7: Complex query with ORDER BY and LIMIT
|
||||
println!("Test 7: Complex query with ORDER BY and LIMIT");
|
||||
let query7 = r#"
|
||||
MATCH (a:Person)-[r:KNOWS]->(b:Person)
|
||||
WHERE a.age > 30 AND b.name = 'Alice'
|
||||
RETURN a.name, b.name, r.since
|
||||
ORDER BY r.since DESC
|
||||
LIMIT 10
|
||||
"#;
|
||||
match parse_cypher(query7) {
|
||||
Ok(_) => println!("✓ Parsed successfully"),
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 8: MERGE with ON CREATE
|
||||
println!("Test 8: MERGE with ON CREATE");
|
||||
let query8 = "MERGE (n:Person {name: 'Bob'}) ON CREATE SET n.created = 2024";
|
||||
match parse_cypher(query8) {
|
||||
Ok(_) => println!("✓ Parsed successfully"),
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 9: WITH clause (query chaining)
|
||||
println!("Test 9: WITH clause");
|
||||
let query9 = r#"
|
||||
MATCH (n:Person)
|
||||
WITH n, n.age AS age
|
||||
WHERE age > 30
|
||||
RETURN n.name, age
|
||||
"#;
|
||||
match parse_cypher(query9) {
|
||||
Ok(_) => println!("✓ Parsed successfully"),
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
// Test 10: Variable-length path
|
||||
println!("Test 10: Variable-length path");
|
||||
let query10 = "MATCH p = (a:Person)-[*1..3]->(b:Person) RETURN p";
|
||||
match parse_cypher(query10) {
|
||||
Ok(_) => println!("✓ Parsed successfully"),
|
||||
Err(e) => println!("✗ Parse error: {}", e),
|
||||
}
|
||||
println!();
|
||||
|
||||
println!("=== All tests completed ===");
|
||||
}
|
||||
431
vendor/ruvector/crates/ruvector-graph/src/cypher/README.md
vendored
Normal file
431
vendor/ruvector/crates/ruvector-graph/src/cypher/README.md
vendored
Normal file
@@ -0,0 +1,431 @@
|
||||
# Cypher Query Language Parser for RuVector
|
||||
|
||||
A complete Cypher-compatible query language parser implementation for the RuVector graph database, built using the nom parser combinator library.
|
||||
|
||||
## Overview
|
||||
|
||||
This module provides a full-featured Cypher query parser that converts Cypher query text into an Abstract Syntax Tree (AST) suitable for execution. It includes:
|
||||
|
||||
- **Lexical Analysis** (`lexer.rs`): Tokenizes Cypher query strings
|
||||
- **Syntax Parsing** (`parser.rs`): Recursive descent parser using nom
|
||||
- **AST Definitions** (`ast.rs`): Complete type system for Cypher queries
|
||||
- **Semantic Analysis** (`semantic.rs`): Type checking and validation
|
||||
- **Query Optimization** (`optimizer.rs`): Query plan optimization
|
||||
|
||||
## Supported Cypher Features
|
||||
|
||||
### Pattern Matching
|
||||
```cypher
|
||||
MATCH (n:Person)
|
||||
MATCH (a:Person)-[r:KNOWS]->(b:Person)
|
||||
OPTIONAL MATCH (n)-[r]->()
|
||||
```
|
||||
|
||||
### Hyperedges (N-ary Relationships)
|
||||
```cypher
|
||||
// Transaction involving multiple parties
|
||||
MATCH (person)-[r:TRANSACTION]->(acc1:Account, acc2:Account, merchant:Merchant)
|
||||
WHERE r.amount > 1000
|
||||
RETURN person, r, acc1, acc2, merchant
|
||||
```
|
||||
|
||||
### Filtering
|
||||
```cypher
|
||||
WHERE n.age > 30 AND n.name = 'Alice'
|
||||
WHERE n.age >= 18 OR n.verified = true
|
||||
```
|
||||
|
||||
### Projections and Aggregations
|
||||
```cypher
|
||||
RETURN n.name, n.age
|
||||
RETURN COUNT(n), AVG(n.age), MAX(n.salary), COLLECT(n.name)
|
||||
RETURN DISTINCT n.department
|
||||
```
|
||||
|
||||
### Mutations
|
||||
```cypher
|
||||
CREATE (n:Person {name: 'Bob', age: 30})
|
||||
MERGE (n:Person {email: 'alice@example.com'})
|
||||
ON CREATE SET n.created = timestamp()
|
||||
ON MATCH SET n.accessed = timestamp()
|
||||
DELETE n
|
||||
DETACH DELETE n
|
||||
SET n.age = 31, n.updated = timestamp()
|
||||
```
|
||||
|
||||
### Query Chaining
|
||||
```cypher
|
||||
MATCH (n:Person)
|
||||
WITH n, n.age AS age
|
||||
WHERE age > 30
|
||||
RETURN n.name, age
|
||||
ORDER BY age DESC
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
### Path Patterns
|
||||
```cypher
|
||||
MATCH p = (a:Person)-[*1..5]->(b:Person)
|
||||
RETURN p
|
||||
```
|
||||
|
||||
### Advanced Expressions
|
||||
```cypher
|
||||
CASE
|
||||
WHEN n.age < 18 THEN 'minor'
|
||||
WHEN n.age < 65 THEN 'adult'
|
||||
ELSE 'senior'
|
||||
END
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### 1. Lexer (`lexer.rs`)
|
||||
|
||||
The lexer converts raw text into a stream of tokens:
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::lexer::tokenize;
|
||||
|
||||
let tokens = tokenize("MATCH (n:Person) RETURN n")?;
|
||||
// Returns: [MATCH, (, Identifier("n"), :, Identifier("Person"), ), RETURN, Identifier("n")]
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Full Cypher keyword support
|
||||
- String literals (single and double quoted)
|
||||
- Numeric literals (integers and floats with scientific notation)
|
||||
- Operators and delimiters
|
||||
- Position tracking for error reporting
|
||||
|
||||
### 2. Parser (`parser.rs`)
|
||||
|
||||
Recursive descent parser using nom combinators:
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::parse_cypher;
|
||||
|
||||
let query = "MATCH (n:Person) WHERE n.age > 30 RETURN n.name";
|
||||
let ast = parse_cypher(query)?;
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Error recovery and detailed error messages
|
||||
- Support for all Cypher clauses
|
||||
- Hyperedge pattern recognition
|
||||
- Operator precedence handling
|
||||
- Property map parsing
|
||||
|
||||
### 3. AST (`ast.rs`)
|
||||
|
||||
Complete Abstract Syntax Tree representation:
|
||||
|
||||
```rust
|
||||
pub struct Query {
|
||||
pub statements: Vec<Statement>,
|
||||
}
|
||||
|
||||
pub enum Statement {
|
||||
Match(MatchClause),
|
||||
Create(CreateClause),
|
||||
Merge(MergeClause),
|
||||
Delete(DeleteClause),
|
||||
Set(SetClause),
    Remove(RemoveClause),
    Return(ReturnClause),
|
||||
With(WithClause),
|
||||
}
|
||||
|
||||
// Hyperedge support for N-ary relationships
|
||||
pub struct HyperedgePattern {
|
||||
pub variable: Option<String>,
|
||||
pub rel_type: String,
|
||||
pub properties: Option<PropertyMap>,
|
||||
pub from: Box<NodePattern>,
|
||||
pub to: Vec<NodePattern>, // Multiple targets
|
||||
pub arity: usize, // N-ary degree
|
||||
}
|
||||
```
|
||||
|
||||
**Key Types:**
|
||||
- `Pattern`: Node, Relationship, Path, and Hyperedge patterns
|
||||
- `Expression`: Full expression tree with operators and functions
|
||||
- `AggregationFunction`: COUNT, SUM, AVG, MIN, MAX, COLLECT
|
||||
- `BinaryOperator`: Arithmetic, comparison, logical, string operations
|
||||
|
||||
### 4. Semantic Analyzer (`semantic.rs`)
|
||||
|
||||
Type checking and validation:
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::semantic::SemanticAnalyzer;
|
||||
|
||||
let mut analyzer = SemanticAnalyzer::new();
|
||||
analyzer.analyze_query(&ast)?;
|
||||
```
|
||||
|
||||
**Checks:**
|
||||
- Variable scope and lifetime
|
||||
- Type compatibility
|
||||
- Aggregation context validation
|
||||
- Hyperedge validity (minimum 2 target nodes)
|
||||
- Pattern correctness
|
||||
|
||||
### 5. Query Optimizer (`optimizer.rs`)
|
||||
|
||||
Query plan optimization:
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::optimizer::QueryOptimizer;
|
||||
|
||||
let optimizer = QueryOptimizer::new();
|
||||
let plan = optimizer.optimize(query);
|
||||
|
||||
println!("Optimizations: {:?}", plan.optimizations_applied);
|
||||
println!("Estimated cost: {}", plan.estimated_cost);
|
||||
```
|
||||
|
||||
**Optimizations:**
|
||||
- **Constant Folding**: Evaluate constant expressions at parse time
|
||||
- **Predicate Pushdown**: Move filters closer to data access
|
||||
- **Join Reordering**: Minimize intermediate result sizes
|
||||
- **Selectivity Estimation**: Optimize pattern matching order
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Query Parsing
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::{parse_cypher, Query};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let query = r#"
|
||||
MATCH (person:Person)-[knows:KNOWS]->(friend:Person)
|
||||
WHERE person.age > 25 AND friend.city = 'NYC'
|
||||
RETURN person.name, friend.name, knows.since
|
||||
ORDER BY knows.since DESC
|
||||
LIMIT 10
|
||||
"#;
|
||||
|
||||
let ast = parse_cypher(query)?;
|
||||
|
||||
println!("Parsed {} statements", ast.statements.len());
|
||||
println!("Read-only query: {}", ast.is_read_only());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
### Hyperedge Queries
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::parse_cypher;
|
||||
|
||||
// Parse a hyperedge pattern (N-ary relationship)
|
||||
let query = r#"
|
||||
MATCH (buyer:Person)-[txn:PURCHASE]->(
|
||||
product:Product,
|
||||
seller:Person,
|
||||
warehouse:Location
|
||||
)
|
||||
WHERE txn.amount > 100
|
||||
RETURN buyer, product, seller, warehouse, txn.timestamp
|
||||
"#;
|
||||
|
||||
let ast = parse_cypher(query)?;
|
||||
assert!(ast.has_hyperedges());
|
||||
```
|
||||
|
||||
### Semantic Analysis
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::{parse_cypher, semantic::SemanticAnalyzer};
|
||||
|
||||
let query = "MATCH (n:Person) RETURN COUNT(n), AVG(n.age)";
|
||||
let ast = parse_cypher(query)?;
|
||||
|
||||
let mut analyzer = SemanticAnalyzer::new();
|
||||
match analyzer.analyze_query(&ast) {
|
||||
Ok(()) => println!("Query is semantically valid"),
|
||||
Err(e) => eprintln!("Semantic error: {}", e),
|
||||
}
|
||||
```
|
||||
|
||||
### Query Optimization
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::{parse_cypher, optimizer::QueryOptimizer};
|
||||
|
||||
let query = r#"
|
||||
MATCH (a:Person), (b:Person)
|
||||
WHERE a.age > 30 AND b.name = 'Alice' AND 2 + 2 = 4
|
||||
RETURN a, b
|
||||
"#;
|
||||
|
||||
let ast = parse_cypher(query)?;
|
||||
let optimizer = QueryOptimizer::new();
|
||||
let plan = optimizer.optimize(ast);
|
||||
|
||||
println!("Applied optimizations: {:?}", plan.optimizations_applied);
|
||||
println!("Estimated execution cost: {:.2}", plan.estimated_cost);
|
||||
```
|
||||
|
||||
## Hyperedge Support
|
||||
|
||||
Traditional graph databases represent relationships as binary edges (one source, one target). RuVector's Cypher parser supports **hyperedges** - relationships connecting multiple nodes simultaneously.
|
||||
|
||||
### Why Hyperedges?
|
||||
|
||||
- **Multi-party Transactions**: Model transfers involving multiple accounts
|
||||
- **Complex Events**: Represent events with multiple participants
|
||||
- **N-way Relationships**: Natural representation of real-world scenarios
|
||||
|
||||
### Hyperedge Syntax
|
||||
|
||||
```cypher
|
||||
// Create a 3-way transaction
|
||||
CREATE (alice:Person)-[t:TRANSFER {amount: 100}]->(
|
||||
bob:Person,
|
||||
carol:Person
|
||||
)
|
||||
|
||||
// Match complex patterns
|
||||
MATCH (author:Person)-[collab:AUTHORED]->(
|
||||
paper:Paper,
|
||||
coauthor1:Person,
|
||||
coauthor2:Person
|
||||
)
|
||||
RETURN author, paper, coauthor1, coauthor2
|
||||
|
||||
// Hyperedge with properties
|
||||
MATCH (teacher)-[class:TEACHES {semester: 'Fall2024'}]->(
|
||||
student1, student2, student3, course:Course
|
||||
)
|
||||
WHERE course.level = 'Graduate'
|
||||
RETURN teacher, course, student1, student2, student3
|
||||
```
|
||||
|
||||
### Hyperedge AST
|
||||
|
||||
```rust
|
||||
pub struct HyperedgePattern {
|
||||
pub variable: Option<String>, // Optional variable binding
|
||||
pub rel_type: String, // Relationship type (required)
|
||||
pub properties: Option<PropertyMap>, // Optional properties
|
||||
pub from: Box<NodePattern>, // Source node
|
||||
pub to: Vec<NodePattern>, // Multiple target nodes (>= 2)
|
||||
pub arity: usize, // Total nodes (source + targets)
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The parser provides detailed error messages with position information:
|
||||
|
||||
```rust
|
||||
use ruvector_graph::cypher::parse_cypher;
|
||||
|
||||
match parse_cypher("MATCH (n:Person WHERE n.age > 30") {
|
||||
Ok(ast) => { /* ... */ },
|
||||
Err(e) => {
|
||||
eprintln!("Parse error: {}", e);
|
||||
// Output: "Unexpected token: expected ), found WHERE at line 1, column 17"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
- **Lexer**: ~500ns per token on average
|
||||
- **Parser**: ~50-200μs for typical queries
|
||||
- **Optimization**: ~10-50μs for plan generation
|
||||
|
||||
Benchmarks available in `benches/cypher_parser.rs`:
|
||||
|
||||
```bash
|
||||
cargo bench --package ruvector-graph --bench cypher_parser
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Comprehensive test coverage across all modules:
|
||||
|
||||
```bash
|
||||
# Run all Cypher tests
|
||||
cargo test --package ruvector-graph --lib cypher
|
||||
|
||||
# Run parser integration tests
|
||||
cargo test --package ruvector-graph --test cypher_parser_integration
|
||||
|
||||
# Run specific test
|
||||
cargo test --package ruvector-graph test_hyperedge_pattern
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Nom Parser Combinators
|
||||
|
||||
The parser uses [nom](https://github.com/rust-bakery/nom), a Rust parser combinator library:
|
||||
|
||||
```rust
|
||||
fn parse_node_pattern(input: &str) -> IResult<&str, NodePattern> {
|
||||
preceded(
|
||||
char('('),
|
||||
terminated(
|
||||
parse_node_content,
|
||||
char(')')
|
||||
)
|
||||
)(input)
|
||||
}
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Zero-copy parsing
|
||||
- Composable parsers
|
||||
- Excellent error handling
|
||||
- Type-safe combinators
|
||||
|
||||
### Type System
|
||||
|
||||
The semantic analyzer implements a simple type system:
|
||||
|
||||
```rust
|
||||
pub enum ValueType {
|
||||
Integer, Float, String, Boolean, Null,
|
||||
Node, Relationship, Path,
|
||||
List(Box<ValueType>),
|
||||
Map,
|
||||
Any,
|
||||
}
|
||||
```
|
||||
|
||||
Type compatibility checks ensure query correctness before execution.
|
||||
|
||||
### Cost-Based Optimization
|
||||
|
||||
The optimizer estimates query cost based on:
|
||||
|
||||
1. **Pattern Selectivity**: More specific patterns are cheaper
|
||||
2. **Index Availability**: Indexed properties reduce scan cost
|
||||
3. **Cardinality Estimates**: Smaller intermediate results are better
|
||||
4. **Operation Cost**: Aggregations, sorts, and joins have inherent costs
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] Subqueries (CALL {...})
|
||||
- [ ] User-defined functions
|
||||
- [ ] Graph projections
|
||||
- [ ] Pattern comprehensions
|
||||
- [ ] JIT compilation for hot paths
|
||||
- [ ] Parallel query execution
|
||||
- [ ] Advanced cost-based optimization
|
||||
- [ ] Query result caching
|
||||
|
||||
## References
|
||||
|
||||
- [Cypher Query Language Reference](https://neo4j.com/docs/cypher-manual/current/)
|
||||
- [openCypher](http://www.opencypher.org/) - Open specification
|
||||
- [GQL Standard](https://www.gqlstandards.org/) - ISO graph query language
|
||||
|
||||
## License
|
||||
|
||||
MIT License - See LICENSE file for details
|
||||
472
vendor/ruvector/crates/ruvector-graph/src/cypher/ast.rs
vendored
Normal file
472
vendor/ruvector/crates/ruvector-graph/src/cypher/ast.rs
vendored
Normal file
@@ -0,0 +1,472 @@
|
||||
//! Abstract Syntax Tree definitions for Cypher query language
|
||||
//!
|
||||
//! Represents the parsed structure of Cypher queries including:
|
||||
//! - Pattern matching (MATCH, OPTIONAL MATCH)
|
||||
//! - Filtering (WHERE)
|
||||
//! - Projections (RETURN, WITH)
|
||||
//! - Mutations (CREATE, MERGE, DELETE, SET)
|
||||
//! - Aggregations and ordering
|
||||
//! - Hyperedge support for N-ary relationships
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Top-level query representation: the ordered sequence of clauses that make
/// up a single Cypher query.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Query {
    /// Clauses in the order they appeared in the query text.
    pub statements: Vec<Statement>,
}
|
||||
|
||||
/// Individual query statement — one Cypher clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Statement {
    /// MATCH / OPTIONAL MATCH pattern search.
    Match(MatchClause),
    /// CREATE nodes and relationships.
    Create(CreateClause),
    /// MERGE (match-or-create) with optional ON CREATE / ON MATCH actions.
    Merge(MergeClause),
    /// DELETE / DETACH DELETE.
    Delete(DeleteClause),
    /// SET properties, variables, or labels.
    Set(SetClause),
    /// REMOVE properties or labels.
    Remove(RemoveClause),
    /// RETURN projection.
    Return(ReturnClause),
    /// WITH projection for chaining query parts.
    With(WithClause),
}
|
||||
|
||||
/// MATCH clause for pattern matching.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatchClause {
    /// True for OPTIONAL MATCH.
    pub optional: bool,
    /// Comma-separated patterns to match.
    pub patterns: Vec<Pattern>,
    /// WHERE predicate attached directly to this MATCH, if any.
    pub where_clause: Option<WhereClause>,
}
|
||||
|
||||
/// Pattern matching expressions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Pattern {
    /// Simple node pattern: `(n:Label {props})`
    Node(NodePattern),
    /// Relationship pattern: `(a)-[r:TYPE]->(b)`
    Relationship(RelationshipPattern),
    /// Path pattern bound to a variable: `p = (a)-[*1..5]->(b)`
    Path(PathPattern),
    /// Hyperedge pattern for N-ary relationships: `(a)-[r:TYPE]->(b,c,d)`
    Hyperedge(HyperedgePattern),
}
|
||||
|
||||
/// Node pattern: `(variable:Label {property: value})`
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NodePattern {
    /// Binding name, if the node was given a variable (the `n` in `(n:X)`).
    pub variable: Option<String>,
    /// Zero or more labels, e.g. `:Person:Employee`.
    pub labels: Vec<String>,
    /// Inline property map, if present.
    pub properties: Option<PropertyMap>,
}
|
||||
|
||||
/// Relationship pattern: `[variable:Type {properties}]`
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RelationshipPattern {
    /// Binding name for the relationship, if given.
    pub variable: Option<String>,
    /// Relationship type; `None` matches any type.
    pub rel_type: Option<String>,
    /// Inline property map, if present.
    pub properties: Option<PropertyMap>,
    /// Arrow direction relative to the source node.
    pub direction: Direction,
    /// Variable-length range (`[*min..max]`), if present.
    pub range: Option<RelationshipRange>,
    /// Source node pattern
    pub from: Box<NodePattern>,
    /// Target - can be a NodePattern or another Pattern for chained relationships
    /// For simple relationships like (a)-[r]->(b), this is just the node
    /// For chained patterns like (a)-[r]->(b)<-[s]-(c), the target is nested
    pub to: Box<Pattern>,
}
|
||||
|
||||
/// Hyperedge pattern for N-ary relationships.
/// Example: `(person)-[r:TRANSACTION]->(account1, account2, merchant)`
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HyperedgePattern {
    /// Binding name for the hyperedge, if given.
    pub variable: Option<String>,
    /// Relationship type (required for hyperedges, unlike binary edges).
    pub rel_type: String,
    /// Inline property map, if present.
    pub properties: Option<PropertyMap>,
    /// Source node pattern.
    pub from: Box<NodePattern>,
    /// Multiple target nodes for N-ary relationships.
    pub to: Vec<NodePattern>,
    /// Number of participating nodes (including source).
    pub arity: usize,
}
|
||||
|
||||
/// Relationship direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Direction {
    /// `->`
    Outgoing,
    /// `<-`
    Incoming,
    /// `-` (no arrow)
    Undirected,
}
|
||||
|
||||
/// Relationship range for path queries: [*min..max]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RelationshipRange {
|
||||
pub min: Option<usize>,
|
||||
pub max: Option<usize>,
|
||||
}
|
||||
|
||||
/// Path pattern: p = (a)-[*]->(b)
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct PathPattern {
|
||||
pub variable: String,
|
||||
pub pattern: Box<Pattern>,
|
||||
}
|
||||
|
||||
/// Property map: {key: value, ...}
|
||||
pub type PropertyMap = HashMap<String, Expression>;
|
||||
|
||||
/// WHERE clause for filtering
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WhereClause {
    /// Boolean predicate applied to matched rows.
    pub condition: Expression,
}

/// CREATE clause for creating nodes and relationships
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateClause {
    /// Patterns to instantiate.
    pub patterns: Vec<Pattern>,
}

/// MERGE clause for create-or-match
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MergeClause {
    /// Pattern to match or create.
    pub pattern: Pattern,
    /// SET actions for the ON CREATE branch.
    pub on_create: Option<SetClause>,
    /// SET actions for the ON MATCH branch.
    pub on_match: Option<SetClause>,
}

/// DELETE clause
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeleteClause {
    /// `true` for DETACH DELETE.
    pub detach: bool,
    /// Expressions identifying the entities to delete.
    pub expressions: Vec<Expression>,
}

/// SET clause for updating properties
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetClause {
    pub items: Vec<SetItem>,
}

/// A single assignment inside a SET clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SetItem {
    /// `SET n.prop = value`
    Property {
        variable: String,
        property: String,
        value: Expression,
    },
    /// `SET n = value` — NOTE(review): presumably whole-entity assignment;
    /// confirm semantics in the executor.
    Variable {
        variable: String,
        value: Expression,
    },
    /// `SET n:Label1:Label2`
    Labels {
        variable: String,
        labels: Vec<String>,
    },
}
|
||||
|
||||
/// REMOVE clause for removing properties or labels
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RemoveClause {
    pub items: Vec<RemoveItem>,
}

/// A single removal inside a REMOVE clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RemoveItem {
    /// Remove a property: REMOVE n.property
    Property { variable: String, property: String },
    /// Remove labels: REMOVE n:Label1:Label2
    Labels {
        variable: String,
        labels: Vec<String>,
    },
}

/// RETURN clause for projection
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReturnClause {
    /// `true` for RETURN DISTINCT.
    pub distinct: bool,
    /// Projected expressions (with optional aliases).
    pub items: Vec<ReturnItem>,
    pub order_by: Option<OrderBy>,
    /// SKIP expression, if present.
    pub skip: Option<Expression>,
    /// LIMIT expression, if present.
    pub limit: Option<Expression>,
}

/// WITH clause for chaining queries
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WithClause {
    pub distinct: bool,
    pub items: Vec<ReturnItem>,
    /// Optional filter applied after the projection.
    pub where_clause: Option<WhereClause>,
    pub order_by: Option<OrderBy>,
    pub skip: Option<Expression>,
    pub limit: Option<Expression>,
}

/// Return item: expression AS alias
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReturnItem {
    pub expression: Expression,
    /// Alias introduced via AS, if any.
    pub alias: Option<String>,
}

/// ORDER BY clause
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderBy {
    pub items: Vec<OrderByItem>,
}

/// One sort key inside ORDER BY.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
    pub expression: Expression,
    /// `true` for ASC, `false` for DESC.
    pub ascending: bool,
}
|
||||
|
||||
/// Expression tree
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expression {
    // Literals
    Integer(i64),
    Float(f64),
    String(String),
    Boolean(bool),
    Null,

    // Variables and properties
    Variable(String),
    /// Property access: `object.property`.
    Property {
        object: Box<Expression>,
        property: String,
    },

    // Collections
    List(Vec<Expression>),
    Map(HashMap<String, Expression>),

    // Operators
    BinaryOp {
        left: Box<Expression>,
        op: BinaryOperator,
        right: Box<Expression>,
    },
    UnaryOp {
        op: UnaryOperator,
        operand: Box<Expression>,
    },

    // Functions and aggregations
    FunctionCall {
        name: String,
        args: Vec<Expression>,
    },
    /// Aggregation such as `count(DISTINCT x)`.
    Aggregation {
        function: AggregationFunction,
        expression: Box<Expression>,
        distinct: bool,
    },

    // Pattern predicates
    PatternPredicate(Box<Pattern>),

    // Case expressions
    /// `CASE [expr] WHEN cond THEN val ... [ELSE default] END`.
    Case {
        expression: Option<Box<Expression>>,
        alternatives: Vec<(Expression, Expression)>,
        default: Option<Box<Expression>>,
    },
}
|
||||
|
||||
/// Binary operators
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOperator {
    // Arithmetic
    Add,
    Subtract,
    Multiply,
    Divide,
    Modulo,
    Power,

    // Comparison
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,

    // Logical
    And,
    Or,
    Xor,

    // String
    // NOTE(review): the lexer in this crate has no tokens for
    // CONTAINS / STARTS WITH / ENDS WITH / =~ — confirm how the parser
    // produces these operators.
    Contains,
    StartsWith,
    EndsWith,
    Matches, // Regex

    // Collection
    In,

    // Null checking
    Is,
    IsNot,
}

/// Unary operators
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOperator {
    Not,
    Minus,
    Plus,
    IsNull,
    IsNotNull,
}

/// Aggregation functions
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationFunction {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    Collect,
    StdDev,
    StdDevP,
    Percentile,
}
|
||||
|
||||
impl Query {
|
||||
pub fn new(statements: Vec<Statement>) -> Self {
|
||||
Self { statements }
|
||||
}
|
||||
|
||||
/// Check if query contains only read operations
|
||||
pub fn is_read_only(&self) -> bool {
|
||||
self.statements.iter().all(|stmt| {
|
||||
matches!(
|
||||
stmt,
|
||||
Statement::Match(_) | Statement::Return(_) | Statement::With(_)
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if query contains hyperedges
|
||||
pub fn has_hyperedges(&self) -> bool {
|
||||
self.statements.iter().any(|stmt| match stmt {
|
||||
Statement::Match(m) => m
|
||||
.patterns
|
||||
.iter()
|
||||
.any(|p| matches!(p, Pattern::Hyperedge(_))),
|
||||
Statement::Create(c) => c
|
||||
.patterns
|
||||
.iter()
|
||||
.any(|p| matches!(p, Pattern::Hyperedge(_))),
|
||||
Statement::Merge(m) => matches!(&m.pattern, Pattern::Hyperedge(_)),
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Pattern {
|
||||
/// Get the arity of the pattern (number of nodes involved)
|
||||
pub fn arity(&self) -> usize {
|
||||
match self {
|
||||
Pattern::Node(_) => 1,
|
||||
Pattern::Relationship(_) => 2,
|
||||
Pattern::Path(_) => 2, // Simplified, could be variable
|
||||
Pattern::Hyperedge(h) => h.arity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Expression {
|
||||
/// Check if expression is constant (no variables)
|
||||
pub fn is_constant(&self) -> bool {
|
||||
match self {
|
||||
Expression::Integer(_)
|
||||
| Expression::Float(_)
|
||||
| Expression::String(_)
|
||||
| Expression::Boolean(_)
|
||||
| Expression::Null => true,
|
||||
Expression::List(items) => items.iter().all(|e| e.is_constant()),
|
||||
Expression::Map(map) => map.values().all(|e| e.is_constant()),
|
||||
Expression::BinaryOp { left, right, .. } => left.is_constant() && right.is_constant(),
|
||||
Expression::UnaryOp { operand, .. } => operand.is_constant(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression contains aggregation
|
||||
pub fn has_aggregation(&self) -> bool {
|
||||
match self {
|
||||
Expression::Aggregation { .. } => true,
|
||||
Expression::BinaryOp { left, right, .. } => {
|
||||
left.has_aggregation() || right.has_aggregation()
|
||||
}
|
||||
Expression::UnaryOp { operand, .. } => operand.has_aggregation(),
|
||||
Expression::FunctionCall { args, .. } => args.iter().any(|e| e.has_aggregation()),
|
||||
Expression::List(items) => items.iter().any(|e| e.has_aggregation()),
|
||||
Expression::Property { object, .. } => object.has_aggregation(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A MATCH + RETURN pair contains no mutating statements.
    #[test]
    fn test_query_is_read_only() {
        let query = Query::new(vec![
            Statement::Match(MatchClause {
                optional: false,
                patterns: vec![],
                where_clause: None,
            }),
            Statement::Return(ReturnClause {
                distinct: false,
                items: vec![],
                order_by: None,
                skip: None,
                limit: None,
            }),
        ]);
        assert!(query.is_read_only());
    }

    // Literals are constant; bare variables are not.
    #[test]
    fn test_expression_is_constant() {
        assert!(Expression::Integer(42).is_constant());
        assert!(Expression::String("test".to_string()).is_constant());
        assert!(!Expression::Variable("x".to_string()).is_constant());
    }

    // Pattern::arity delegates to the hyperedge's own `arity` field
    // (source node + two targets = 3).
    #[test]
    fn test_hyperedge_arity() {
        let hyperedge = Pattern::Hyperedge(HyperedgePattern {
            variable: Some("r".to_string()),
            rel_type: "TRANSACTION".to_string(),
            properties: None,
            from: Box::new(NodePattern {
                variable: Some("a".to_string()),
                labels: vec![],
                properties: None,
            }),
            to: vec![
                NodePattern {
                    variable: Some("b".to_string()),
                    labels: vec![],
                    properties: None,
                },
                NodePattern {
                    variable: Some("c".to_string()),
                    labels: vec![],
                    properties: None,
                },
            ],
            arity: 3,
        });
        assert_eq!(hyperedge.arity(), 3);
    }
}
|
||||
430
vendor/ruvector/crates/ruvector-graph/src/cypher/lexer.rs
vendored
Normal file
430
vendor/ruvector/crates/ruvector-graph/src/cypher/lexer.rs
vendored
Normal file
@@ -0,0 +1,430 @@
|
||||
//! Lexical analyzer (tokenizer) for Cypher query language
|
||||
//!
|
||||
//! Converts raw Cypher text into a stream of tokens for parsing.
|
||||
|
||||
use nom::{
|
||||
branch::alt,
|
||||
bytes::complete::{tag, tag_no_case, take_while, take_while1},
|
||||
character::complete::{char, multispace0, multispace1, one_of},
|
||||
combinator::{map, opt, recognize},
|
||||
multi::many0,
|
||||
sequence::{delimited, pair, preceded, tuple},
|
||||
IResult,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Token with kind and location information
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
    /// What kind of token this is (keyword, literal, operator, ...).
    pub kind: TokenKind,
    /// The source text the token was produced from.
    pub lexeme: String,
    /// Where the token starts in the source.
    pub position: Position,
}

/// Source position for error reporting
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Position {
    /// 1-based line number (see the initial value in `tokenize`).
    pub line: usize,
    /// 1-based column number.
    pub column: usize,
    /// 0-based byte offset from the start of the input
    /// (advanced by `len_utf8` in `update_position`).
    pub offset: usize,
}
|
||||
|
||||
/// Token kinds
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TokenKind {
    // Keywords
    Match,
    OptionalMatch,
    Where,
    Return,
    Create,
    Merge,
    Delete,
    DetachDelete,
    Set,
    Remove,
    With,
    OrderBy,
    Limit,
    Skip,
    Distinct,
    As,
    Asc,
    Desc,
    Case,
    When,
    Then,
    Else,
    End,
    And,
    Or,
    Xor,
    Not,
    In,
    Is,
    Null,
    True,
    False,
    OnCreate,
    OnMatch,

    // Identifiers and literals
    Identifier(String),
    Integer(i64),
    Float(f64),
    String(String),

    // Operators
    Plus,
    // NOTE(review): `parse_operator` emits `Dash` for '-'; `Minus` is never
    // produced by this lexer — confirm whether the parser re-tags it.
    Minus,
    Star,
    Slash,
    Percent,
    Caret,
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    Arrow,     // ->
    LeftArrow, // <-
    Dash,      // -

    // Delimiters
    LeftParen,
    RightParen,
    LeftBracket,
    RightBracket,
    LeftBrace,
    RightBrace,
    Comma,
    Dot,
    Colon,
    Semicolon,
    Pipe,

    // Special
    DotDot, // ..
    Eof,
}
|
||||
|
||||
impl fmt::Display for TokenKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
TokenKind::Identifier(s) => write!(f, "identifier '{}'", s),
|
||||
TokenKind::Integer(n) => write!(f, "integer {}", n),
|
||||
TokenKind::Float(n) => write!(f, "float {}", n),
|
||||
TokenKind::String(s) => write!(f, "string \"{}\"", s),
|
||||
_ => write!(f, "{:?}", self),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tokenize a Cypher query string
///
/// Produces the full token stream terminated by an `Eof` token.
/// Whitespace between tokens is skipped; the first character that no
/// token parser accepts yields `LexerError::UnexpectedCharacter` with
/// its source position.
pub fn tokenize(input: &str) -> Result<Vec<Token>, LexerError> {
    let mut tokens = Vec::new();
    let mut remaining = input;
    let mut position = Position {
        line: 1,
        column: 1,
        offset: 0,
    };

    while !remaining.is_empty() {
        // Skip whitespace
        if let Ok((rest, _)) = multispace1::<_, nom::error::Error<_>>(remaining) {
            let consumed = remaining.len() - rest.len();
            update_position(&mut position, &remaining[..consumed]);
            remaining = rest;
            continue;
        }

        // Try to parse a token
        match parse_token(remaining) {
            Ok((rest, (kind, lexeme))) => {
                tokens.push(Token {
                    kind,
                    lexeme: lexeme.to_string(),
                    position,
                });
                // Advance by the lexeme text. This assumes the lexeme covers
                // everything consumed from `remaining` — position tracking
                // drifts for any parser whose lexeme is shorter than what it
                // actually consumed.
                update_position(&mut position, lexeme);
                remaining = rest;
            }
            Err(_) => {
                // `remaining` is non-empty here, so `next()` cannot fail.
                return Err(LexerError::UnexpectedCharacter {
                    character: remaining.chars().next().unwrap(),
                    position,
                });
            }
        }
    }

    tokens.push(Token {
        kind: TokenKind::Eof,
        lexeme: String::new(),
        position,
    });

    Ok(tokens)
}
|
||||
|
||||
fn update_position(pos: &mut Position, text: &str) {
|
||||
for ch in text.chars() {
|
||||
pos.offset += ch.len_utf8();
|
||||
if ch == '\n' {
|
||||
pos.line += 1;
|
||||
pos.column = 1;
|
||||
} else {
|
||||
pos.column += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single token from the front of `input`.
///
/// Order matters: keywords are tried before identifiers so reserved
/// words are not lexed as plain identifiers, and numbers before
/// operators so a leading '-' can be taken as a numeric sign.
fn parse_token(input: &str) -> IResult<&str, (TokenKind, &str)> {
    alt((
        parse_keyword,
        parse_number,
        parse_string,
        parse_identifier,
        parse_operator,
        parse_delimiter,
    ))(input)
}
|
||||
|
||||
fn parse_keyword(input: &str) -> IResult<&str, (TokenKind, &str)> {
|
||||
let (input, _) = multispace0(input)?;
|
||||
|
||||
// Split into nested alt() calls since nom's alt() supports max 21 alternatives
|
||||
alt((
|
||||
alt((
|
||||
map(tag_no_case("OPTIONAL MATCH"), |s: &str| {
|
||||
(TokenKind::OptionalMatch, s)
|
||||
}),
|
||||
map(tag_no_case("DETACH DELETE"), |s: &str| {
|
||||
(TokenKind::DetachDelete, s)
|
||||
}),
|
||||
map(tag_no_case("ORDER BY"), |s: &str| (TokenKind::OrderBy, s)),
|
||||
map(tag_no_case("ON CREATE"), |s: &str| (TokenKind::OnCreate, s)),
|
||||
map(tag_no_case("ON MATCH"), |s: &str| (TokenKind::OnMatch, s)),
|
||||
map(tag_no_case("MATCH"), |s: &str| (TokenKind::Match, s)),
|
||||
map(tag_no_case("WHERE"), |s: &str| (TokenKind::Where, s)),
|
||||
map(tag_no_case("RETURN"), |s: &str| (TokenKind::Return, s)),
|
||||
map(tag_no_case("CREATE"), |s: &str| (TokenKind::Create, s)),
|
||||
map(tag_no_case("MERGE"), |s: &str| (TokenKind::Merge, s)),
|
||||
map(tag_no_case("DELETE"), |s: &str| (TokenKind::Delete, s)),
|
||||
map(tag_no_case("SET"), |s: &str| (TokenKind::Set, s)),
|
||||
map(tag_no_case("REMOVE"), |s: &str| (TokenKind::Remove, s)),
|
||||
map(tag_no_case("WITH"), |s: &str| (TokenKind::With, s)),
|
||||
map(tag_no_case("LIMIT"), |s: &str| (TokenKind::Limit, s)),
|
||||
map(tag_no_case("SKIP"), |s: &str| (TokenKind::Skip, s)),
|
||||
map(tag_no_case("DISTINCT"), |s: &str| (TokenKind::Distinct, s)),
|
||||
)),
|
||||
alt((
|
||||
map(tag_no_case("ASC"), |s: &str| (TokenKind::Asc, s)),
|
||||
map(tag_no_case("DESC"), |s: &str| (TokenKind::Desc, s)),
|
||||
map(tag_no_case("CASE"), |s: &str| (TokenKind::Case, s)),
|
||||
map(tag_no_case("WHEN"), |s: &str| (TokenKind::When, s)),
|
||||
map(tag_no_case("THEN"), |s: &str| (TokenKind::Then, s)),
|
||||
map(tag_no_case("ELSE"), |s: &str| (TokenKind::Else, s)),
|
||||
map(tag_no_case("END"), |s: &str| (TokenKind::End, s)),
|
||||
map(tag_no_case("AND"), |s: &str| (TokenKind::And, s)),
|
||||
map(tag_no_case("OR"), |s: &str| (TokenKind::Or, s)),
|
||||
map(tag_no_case("XOR"), |s: &str| (TokenKind::Xor, s)),
|
||||
map(tag_no_case("NOT"), |s: &str| (TokenKind::Not, s)),
|
||||
map(tag_no_case("IN"), |s: &str| (TokenKind::In, s)),
|
||||
map(tag_no_case("IS"), |s: &str| (TokenKind::Is, s)),
|
||||
map(tag_no_case("NULL"), |s: &str| (TokenKind::Null, s)),
|
||||
map(tag_no_case("TRUE"), |s: &str| (TokenKind::True, s)),
|
||||
map(tag_no_case("FALSE"), |s: &str| (TokenKind::False, s)),
|
||||
map(tag_no_case("AS"), |s: &str| (TokenKind::As, s)),
|
||||
)),
|
||||
))(input)
|
||||
}
|
||||
|
||||
/// Parse a numeric literal: floats (with optional exponent) are tried
/// before integers so `45.67` is not split at the dot.
///
/// NOTE(review): the optional leading '-' is consumed as part of the
/// number, so `a-1` lexes as identifier `a` followed by `Integer(-1)`
/// rather than `a`, `Dash`, `1` — confirm the parser compensates for
/// binary minus.
fn parse_number(input: &str) -> IResult<&str, (TokenKind, &str)> {
    let (input, _) = multispace0(input)?;

    // Try to parse float first
    if let Ok((rest, num_str)) = recognize::<_, _, nom::error::Error<_>, _>(tuple((
        opt(char('-')),
        take_while1(|c: char| c.is_ascii_digit()),
        char('.'),
        take_while1(|c: char| c.is_ascii_digit()),
        opt(tuple((
            one_of("eE"),
            opt(one_of("+-")),
            take_while1(|c: char| c.is_ascii_digit()),
        ))),
    )))(input)
    {
        if let Ok(n) = num_str.parse::<f64>() {
            return Ok((rest, (TokenKind::Float(n), num_str)));
        }
    }

    // Parse integer; overflowing literals fail the parse (ErrorKind::Digit)
    // rather than wrapping.
    let (rest, num_str) = recognize(tuple((
        opt(char('-')),
        take_while1(|c: char| c.is_ascii_digit()),
    )))(input)?;

    let n = num_str.parse::<i64>().map_err(|_| {
        nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Digit))
    })?;

    Ok((rest, (TokenKind::Integer(n), num_str)))
}
|
||||
|
||||
fn parse_string(input: &str) -> IResult<&str, (TokenKind, &str)> {
|
||||
let (input, _) = multispace0(input)?;
|
||||
|
||||
let (rest, s) = alt((
|
||||
delimited(
|
||||
char('\''),
|
||||
recognize(many0(alt((
|
||||
tag("\\'"),
|
||||
tag("\\\\"),
|
||||
take_while1(|c| c != '\'' && c != '\\'),
|
||||
)))),
|
||||
char('\''),
|
||||
),
|
||||
delimited(
|
||||
char('"'),
|
||||
recognize(many0(alt((
|
||||
tag("\\\""),
|
||||
tag("\\\\"),
|
||||
take_while1(|c| c != '"' && c != '\\'),
|
||||
)))),
|
||||
char('"'),
|
||||
),
|
||||
))(input)?;
|
||||
|
||||
// Unescape string
|
||||
let unescaped = s
|
||||
.replace("\\'", "'")
|
||||
.replace("\\\"", "\"")
|
||||
.replace("\\\\", "\\");
|
||||
|
||||
Ok((rest, (TokenKind::String(unescaped), s)))
|
||||
}
|
||||
|
||||
fn parse_identifier(input: &str) -> IResult<&str, (TokenKind, &str)> {
|
||||
let (input, _) = multispace0(input)?;
|
||||
|
||||
// Backtick-quoted identifier
|
||||
let backtick_result: IResult<&str, &str> =
|
||||
delimited(char('`'), take_while1(|c| c != '`'), char('`'))(input);
|
||||
if let Ok((rest, id)) = backtick_result {
|
||||
return Ok((rest, (TokenKind::Identifier(id.to_string()), id)));
|
||||
}
|
||||
|
||||
// Regular identifier
|
||||
let (rest, id) = recognize(pair(
|
||||
alt((
|
||||
take_while1(|c: char| c.is_ascii_alphabetic() || c == '_'),
|
||||
tag("$"),
|
||||
)),
|
||||
take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'),
|
||||
))(input)?;
|
||||
|
||||
Ok((rest, (TokenKind::Identifier(id.to_string()), id)))
|
||||
}
|
||||
|
||||
/// Parse an operator token.
///
/// Multi-character operators are listed before their single-character
/// prefixes (`<=`/`<>`/`<-` before `<`, `->` before `-`, `..` before `.`
/// in `parse_delimiter`) so the longest match wins. Note that '-' always
/// lexes as `Dash`; `TokenKind::Minus` is never produced here.
fn parse_operator(input: &str) -> IResult<&str, (TokenKind, &str)> {
    let (input, _) = multispace0(input)?;

    alt((
        map(tag("<="), |s| (TokenKind::LessThanOrEqual, s)),
        map(tag(">="), |s| (TokenKind::GreaterThanOrEqual, s)),
        map(tag("<>"), |s| (TokenKind::NotEqual, s)),
        map(tag("!="), |s| (TokenKind::NotEqual, s)),
        map(tag("->"), |s| (TokenKind::Arrow, s)),
        map(tag("<-"), |s| (TokenKind::LeftArrow, s)),
        map(tag(".."), |s| (TokenKind::DotDot, s)),
        map(char('+'), |_| (TokenKind::Plus, "+")),
        map(char('-'), |_| (TokenKind::Dash, "-")),
        map(char('*'), |_| (TokenKind::Star, "*")),
        map(char('/'), |_| (TokenKind::Slash, "/")),
        map(char('%'), |_| (TokenKind::Percent, "%")),
        map(char('^'), |_| (TokenKind::Caret, "^")),
        map(char('='), |_| (TokenKind::Equal, "=")),
        map(char('<'), |_| (TokenKind::LessThan, "<")),
        map(char('>'), |_| (TokenKind::GreaterThan, ">")),
    ))(input)
}
|
||||
|
||||
/// Parse a single-character delimiter token (parentheses, brackets,
/// braces, punctuation). `..` is handled by `parse_operator`, which runs
/// first, so a lone '.' here is always `Dot`.
fn parse_delimiter(input: &str) -> IResult<&str, (TokenKind, &str)> {
    let (input, _) = multispace0(input)?;

    alt((
        map(char('('), |_| (TokenKind::LeftParen, "(")),
        map(char(')'), |_| (TokenKind::RightParen, ")")),
        map(char('['), |_| (TokenKind::LeftBracket, "[")),
        map(char(']'), |_| (TokenKind::RightBracket, "]")),
        map(char('{'), |_| (TokenKind::LeftBrace, "{")),
        map(char('}'), |_| (TokenKind::RightBrace, "}")),
        map(char(','), |_| (TokenKind::Comma, ",")),
        map(char('.'), |_| (TokenKind::Dot, ".")),
        map(char(':'), |_| (TokenKind::Colon, ":")),
        map(char(';'), |_| (TokenKind::Semicolon, ";")),
        map(char('|'), |_| (TokenKind::Pipe, "|")),
    ))(input)
}
|
||||
|
||||
/// Errors produced by `tokenize`.
#[derive(Debug, thiserror::Error)]
pub enum LexerError {
    /// No token parser matched at the given source position.
    #[error("Unexpected character '{character}' at line {}, column {}", position.line, position.column)]
    UnexpectedCharacter { character: char, position: Position },
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Keywords, delimiters and identifiers from a minimal query.
    #[test]
    fn test_tokenize_simple_match() {
        let input = "MATCH (n:Person) RETURN n";
        let tokens = tokenize(input).unwrap();

        assert_eq!(tokens[0].kind, TokenKind::Match);
        assert_eq!(tokens[1].kind, TokenKind::LeftParen);
        assert_eq!(tokens[2].kind, TokenKind::Identifier("n".to_string()));
        assert_eq!(tokens[3].kind, TokenKind::Colon);
        assert_eq!(tokens[4].kind, TokenKind::Identifier("Person".to_string()));
        assert_eq!(tokens[5].kind, TokenKind::RightParen);
        assert_eq!(tokens[6].kind, TokenKind::Return);
        assert_eq!(tokens[7].kind, TokenKind::Identifier("n".to_string()));
    }

    // Integers, floats, signed numbers and exponent notation.
    #[test]
    fn test_tokenize_numbers() {
        let tokens = tokenize("123 45.67 -89 3.14e-2").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Integer(123));
        assert_eq!(tokens[1].kind, TokenKind::Float(45.67));
        assert_eq!(tokens[2].kind, TokenKind::Integer(-89));
        assert!(matches!(tokens[3].kind, TokenKind::Float(_)));
    }

    // Single- and double-quoted strings; quotes are stripped from the value.
    #[test]
    fn test_tokenize_strings() {
        let tokens = tokenize(r#"'Alice' "Bob's friend""#).unwrap();
        assert_eq!(tokens[0].kind, TokenKind::String("Alice".to_string()));
        assert_eq!(
            tokens[1].kind,
            TokenKind::String("Bob's friend".to_string())
        );
    }

    // Multi-character operators win over their single-character prefixes.
    #[test]
    fn test_tokenize_operators() {
        let tokens = tokenize("-> <- = <> >= <=").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Arrow);
        assert_eq!(tokens[1].kind, TokenKind::LeftArrow);
        assert_eq!(tokens[2].kind, TokenKind::Equal);
        assert_eq!(tokens[3].kind, TokenKind::NotEqual);
        assert_eq!(tokens[4].kind, TokenKind::GreaterThanOrEqual);
        assert_eq!(tokens[5].kind, TokenKind::LessThanOrEqual);
    }
}
|
||||
20
vendor/ruvector/crates/ruvector-graph/src/cypher/mod.rs
vendored
Normal file
20
vendor/ruvector/crates/ruvector-graph/src/cypher/mod.rs
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
//! Cypher query language parser and execution engine
|
||||
//!
|
||||
//! This module provides a complete Cypher query language implementation including:
|
||||
//! - Lexical analysis (tokenization)
|
||||
//! - Syntax parsing (AST generation)
|
||||
//! - Semantic analysis and type checking
|
||||
//! - Query optimization
|
||||
//! - Support for hyperedges (N-ary relationships)
|
||||
|
||||
pub mod ast;
|
||||
pub mod lexer;
|
||||
pub mod optimizer;
|
||||
pub mod parser;
|
||||
pub mod semantic;
|
||||
|
||||
pub use ast::{Query, Statement};
|
||||
pub use lexer::{Token, TokenKind};
|
||||
pub use optimizer::{OptimizationPlan, QueryOptimizer};
|
||||
pub use parser::{parse_cypher, ParseError};
|
||||
pub use semantic::{SemanticAnalyzer, SemanticError};
|
||||
582
vendor/ruvector/crates/ruvector-graph/src/cypher/optimizer.rs
vendored
Normal file
582
vendor/ruvector/crates/ruvector-graph/src/cypher/optimizer.rs
vendored
Normal file
@@ -0,0 +1,582 @@
|
||||
//! Query optimizer for Cypher queries
|
||||
//!
|
||||
//! Optimizes query execution plans through:
|
||||
//! - Predicate pushdown (filter as early as possible)
|
||||
//! - Join reordering (minimize intermediate results)
|
||||
//! - Index utilization
|
||||
//! - Constant folding
|
||||
//! - Dead code elimination
|
||||
|
||||
use super::ast::*;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Query optimization plan
#[derive(Debug, Clone)]
pub struct OptimizationPlan {
    /// The rewritten query after all enabled passes.
    pub optimized_query: Query,
    /// Which passes reported a change while running.
    pub optimizations_applied: Vec<OptimizationType>,
    /// Cost estimate of the optimized query, as computed by
    /// `QueryOptimizer::estimate_cost` (definition not shown here).
    pub estimated_cost: f64,
}

/// The optimization passes this module knows about.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OptimizationType {
    PredicatePushdown,
    JoinReordering,
    ConstantFolding,
    IndexHint,
    EarlyFiltering,
    PatternSimplification,
    DeadCodeElimination,
}

/// Rule-based Cypher query optimizer with individually switchable passes.
pub struct QueryOptimizer {
    enable_predicate_pushdown: bool,
    enable_join_reordering: bool,
    enable_constant_folding: bool,
}
|
||||
|
||||
impl QueryOptimizer {
|
||||
/// Create an optimizer with every pass enabled.
pub fn new() -> Self {
    Self {
        enable_predicate_pushdown: true,
        enable_join_reordering: true,
        enable_constant_folding: true,
    }
}
||||
|
||||
/// Optimize a query and return an execution plan
///
/// Passes run in a fixed order (constant folding, predicate pushdown,
/// join reordering); a pass is recorded in `optimizations_applied` only
/// when it returns `Some` (i.e. reports a change).
///
/// NOTE(review): each pass receives a full clone of the current query;
/// acceptable at plan time, but the `apply_*` helpers could take
/// ownership to avoid the copies.
pub fn optimize(&self, query: Query) -> OptimizationPlan {
    let mut optimized = query;
    let mut optimizations = Vec::new();

    // Apply optimizations in order
    if self.enable_constant_folding {
        if let Some(q) = self.apply_constant_folding(optimized.clone()) {
            optimized = q;
            optimizations.push(OptimizationType::ConstantFolding);
        }
    }

    if self.enable_predicate_pushdown {
        if let Some(q) = self.apply_predicate_pushdown(optimized.clone()) {
            optimized = q;
            optimizations.push(OptimizationType::PredicatePushdown);
        }
    }

    if self.enable_join_reordering {
        if let Some(q) = self.apply_join_reordering(optimized.clone()) {
            optimized = q;
            optimizations.push(OptimizationType::JoinReordering);
        }
    }

    let cost = self.estimate_cost(&optimized);

    OptimizationPlan {
        optimized_query: optimized,
        optimizations_applied: optimizations,
        estimated_cost: cost,
    }
}
|
||||
|
||||
/// Apply constant folding to simplify expressions
///
/// Returns `Some(query)` when at least one statement was simplified,
/// `None` when nothing changed.
fn apply_constant_folding(&self, mut query: Query) -> Option<Query> {
    let mut changed = false;

    for statement in &mut query.statements {
        if self.fold_statement(statement) {
            changed = true;
        }
    }

    if changed {
        Some(query)
    } else {
        None
    }
}
|
||||
|
||||
/// Fold constants inside a single statement; returns `true` if it changed.
///
/// Only MATCH `WHERE` conditions and RETURN items are folded here; all
/// other statement kinds (WITH, MERGE, SET, ...) are left untouched.
fn fold_statement(&self, statement: &mut Statement) -> bool {
    match statement {
        Statement::Match(clause) => {
            let mut changed = false;
            if let Some(where_clause) = &mut clause.where_clause {
                if let Some(folded) = self.fold_expression(&where_clause.condition) {
                    where_clause.condition = folded;
                    changed = true;
                }
            }
            changed
        }
        Statement::Return(clause) => {
            let mut changed = false;
            for item in &mut clause.items {
                if let Some(folded) = self.fold_expression(&item.expression) {
                    item.expression = folded;
                    changed = true;
                }
            }
            changed
        }
        _ => false,
    }
}
|
||||
|
||||
fn fold_expression(&self, expr: &Expression) -> Option<Expression> {
|
||||
match expr {
|
||||
Expression::BinaryOp { left, op, right } => {
|
||||
// Fold operands first
|
||||
let left = self
|
||||
.fold_expression(left)
|
||||
.unwrap_or_else(|| (**left).clone());
|
||||
let right = self
|
||||
.fold_expression(right)
|
||||
.unwrap_or_else(|| (**right).clone());
|
||||
|
||||
// Try to evaluate constant expressions
|
||||
if left.is_constant() && right.is_constant() {
|
||||
return self.evaluate_constant_binary_op(&left, *op, &right);
|
||||
}
|
||||
|
||||
// Return simplified expression
|
||||
Some(Expression::BinaryOp {
|
||||
left: Box::new(left),
|
||||
op: *op,
|
||||
right: Box::new(right),
|
||||
})
|
||||
}
|
||||
Expression::UnaryOp { op, operand } => {
|
||||
let operand = self
|
||||
.fold_expression(operand)
|
||||
.unwrap_or_else(|| (**operand).clone());
|
||||
|
||||
if operand.is_constant() {
|
||||
return self.evaluate_constant_unary_op(*op, &operand);
|
||||
}
|
||||
|
||||
Some(Expression::UnaryOp {
|
||||
op: *op,
|
||||
operand: Box::new(operand),
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_constant_binary_op(
|
||||
&self,
|
||||
left: &Expression,
|
||||
op: BinaryOperator,
|
||||
right: &Expression,
|
||||
) -> Option<Expression> {
|
||||
match (left, op, right) {
|
||||
// Arithmetic operations
|
||||
(Expression::Integer(a), BinaryOperator::Add, Expression::Integer(b)) => {
|
||||
Some(Expression::Integer(a + b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::Subtract, Expression::Integer(b)) => {
|
||||
Some(Expression::Integer(a - b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::Multiply, Expression::Integer(b)) => {
|
||||
Some(Expression::Integer(a * b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::Divide, Expression::Integer(b)) if *b != 0 => {
|
||||
Some(Expression::Integer(a / b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::Modulo, Expression::Integer(b)) if *b != 0 => {
|
||||
Some(Expression::Integer(a % b))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::Add, Expression::Float(b)) => {
|
||||
Some(Expression::Float(a + b))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::Subtract, Expression::Float(b)) => {
|
||||
Some(Expression::Float(a - b))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::Multiply, Expression::Float(b)) => {
|
||||
Some(Expression::Float(a * b))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::Divide, Expression::Float(b)) if *b != 0.0 => {
|
||||
Some(Expression::Float(a / b))
|
||||
}
|
||||
// Comparison operations for integers
|
||||
(Expression::Integer(a), BinaryOperator::Equal, Expression::Integer(b)) => {
|
||||
Some(Expression::Boolean(a == b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::NotEqual, Expression::Integer(b)) => {
|
||||
Some(Expression::Boolean(a != b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::LessThan, Expression::Integer(b)) => {
|
||||
Some(Expression::Boolean(a < b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::LessThanOrEqual, Expression::Integer(b)) => {
|
||||
Some(Expression::Boolean(a <= b))
|
||||
}
|
||||
(Expression::Integer(a), BinaryOperator::GreaterThan, Expression::Integer(b)) => {
|
||||
Some(Expression::Boolean(a > b))
|
||||
}
|
||||
(
|
||||
Expression::Integer(a),
|
||||
BinaryOperator::GreaterThanOrEqual,
|
||||
Expression::Integer(b),
|
||||
) => Some(Expression::Boolean(a >= b)),
|
||||
// Comparison operations for floats
|
||||
(Expression::Float(a), BinaryOperator::Equal, Expression::Float(b)) => {
|
||||
Some(Expression::Boolean((a - b).abs() < f64::EPSILON))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::NotEqual, Expression::Float(b)) => {
|
||||
Some(Expression::Boolean((a - b).abs() >= f64::EPSILON))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::LessThan, Expression::Float(b)) => {
|
||||
Some(Expression::Boolean(a < b))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::LessThanOrEqual, Expression::Float(b)) => {
|
||||
Some(Expression::Boolean(a <= b))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::GreaterThan, Expression::Float(b)) => {
|
||||
Some(Expression::Boolean(a > b))
|
||||
}
|
||||
(Expression::Float(a), BinaryOperator::GreaterThanOrEqual, Expression::Float(b)) => {
|
||||
Some(Expression::Boolean(a >= b))
|
||||
}
|
||||
// String comparison
|
||||
(Expression::String(a), BinaryOperator::Equal, Expression::String(b)) => {
|
||||
Some(Expression::Boolean(a == b))
|
||||
}
|
||||
(Expression::String(a), BinaryOperator::NotEqual, Expression::String(b)) => {
|
||||
Some(Expression::Boolean(a != b))
|
||||
}
|
||||
// Boolean operations
|
||||
(Expression::Boolean(a), BinaryOperator::And, Expression::Boolean(b)) => {
|
||||
Some(Expression::Boolean(*a && *b))
|
||||
}
|
||||
(Expression::Boolean(a), BinaryOperator::Or, Expression::Boolean(b)) => {
|
||||
Some(Expression::Boolean(*a || *b))
|
||||
}
|
||||
(Expression::Boolean(a), BinaryOperator::Equal, Expression::Boolean(b)) => {
|
||||
Some(Expression::Boolean(a == b))
|
||||
}
|
||||
(Expression::Boolean(a), BinaryOperator::NotEqual, Expression::Boolean(b)) => {
|
||||
Some(Expression::Boolean(a != b))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_constant_unary_op(
|
||||
&self,
|
||||
op: UnaryOperator,
|
||||
operand: &Expression,
|
||||
) -> Option<Expression> {
|
||||
match (op, operand) {
|
||||
(UnaryOperator::Not, Expression::Boolean(b)) => Some(Expression::Boolean(!b)),
|
||||
(UnaryOperator::Minus, Expression::Integer(n)) => Some(Expression::Integer(-n)),
|
||||
(UnaryOperator::Minus, Expression::Float(n)) => Some(Expression::Float(-n)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply predicate pushdown optimization
|
||||
/// Move WHERE clauses as close to data access as possible
|
||||
fn apply_predicate_pushdown(&self, query: Query) -> Option<Query> {
|
||||
// In a real implementation, this would analyze the query graph
|
||||
// and push predicates down to the earliest possible point
|
||||
// For now, we'll do a simple transformation
|
||||
|
||||
// This is a placeholder - real implementation would be more complex
|
||||
None
|
||||
}
|
||||
|
||||
/// Reorder joins to minimize intermediate result sizes
|
||||
fn apply_join_reordering(&self, query: Query) -> Option<Query> {
|
||||
// Analyze pattern complexity and reorder based on selectivity
|
||||
// Patterns with more constraints should be evaluated first
|
||||
|
||||
let mut optimized = query.clone();
|
||||
let mut changed = false;
|
||||
|
||||
for statement in &mut optimized.statements {
|
||||
if let Statement::Match(clause) = statement {
|
||||
let mut patterns = clause.patterns.clone();
|
||||
|
||||
// Sort patterns by estimated selectivity (more selective first)
|
||||
patterns.sort_by_key(|p| {
|
||||
let selectivity = self.estimate_pattern_selectivity(p);
|
||||
// Use negative to sort in descending order (most selective first)
|
||||
-(selectivity * 1000.0) as i64
|
||||
});
|
||||
|
||||
if patterns != clause.patterns {
|
||||
clause.patterns = patterns;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if changed {
|
||||
Some(optimized)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate the selectivity of a pattern (0.0 = least selective, 1.0 = most selective)
|
||||
fn estimate_pattern_selectivity(&self, pattern: &Pattern) -> f64 {
|
||||
match pattern {
|
||||
Pattern::Node(node) => {
|
||||
let mut selectivity = 0.3; // Base selectivity for node
|
||||
|
||||
// More labels = more selective
|
||||
selectivity += node.labels.len() as f64 * 0.1;
|
||||
|
||||
// Properties = more selective
|
||||
if let Some(props) = &node.properties {
|
||||
selectivity += props.len() as f64 * 0.15;
|
||||
}
|
||||
|
||||
selectivity.min(1.0)
|
||||
}
|
||||
Pattern::Relationship(rel) => {
|
||||
let mut selectivity = 0.2; // Base selectivity for relationship
|
||||
|
||||
// Specific type = more selective
|
||||
if rel.rel_type.is_some() {
|
||||
selectivity += 0.2;
|
||||
}
|
||||
|
||||
// Properties = more selective
|
||||
if let Some(props) = &rel.properties {
|
||||
selectivity += props.len() as f64 * 0.15;
|
||||
}
|
||||
|
||||
// Add selectivity from connected nodes
|
||||
selectivity +=
|
||||
self.estimate_pattern_selectivity(&Pattern::Node(*rel.from.clone())) * 0.3;
|
||||
// rel.to is now a Pattern (can be NodePattern or chained RelationshipPattern)
|
||||
selectivity += self.estimate_pattern_selectivity(&*rel.to) * 0.3;
|
||||
|
||||
selectivity.min(1.0)
|
||||
}
|
||||
Pattern::Hyperedge(hyperedge) => {
|
||||
let mut selectivity = 0.5; // Hyperedges are typically more selective
|
||||
|
||||
// More nodes involved = more selective
|
||||
selectivity += hyperedge.arity as f64 * 0.1;
|
||||
|
||||
if let Some(props) = &hyperedge.properties {
|
||||
selectivity += props.len() as f64 * 0.15;
|
||||
}
|
||||
|
||||
selectivity.min(1.0)
|
||||
}
|
||||
Pattern::Path(_) => 0.1, // Paths are typically less selective
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate the cost of executing a query
|
||||
fn estimate_cost(&self, query: &Query) -> f64 {
|
||||
let mut cost = 0.0;
|
||||
|
||||
for statement in &query.statements {
|
||||
cost += self.estimate_statement_cost(statement);
|
||||
}
|
||||
|
||||
cost
|
||||
}
|
||||
|
||||
fn estimate_statement_cost(&self, statement: &Statement) -> f64 {
|
||||
match statement {
|
||||
Statement::Match(clause) => {
|
||||
let mut cost = 0.0;
|
||||
|
||||
for pattern in &clause.patterns {
|
||||
cost += self.estimate_pattern_cost(pattern);
|
||||
}
|
||||
|
||||
// WHERE clause adds filtering cost
|
||||
if clause.where_clause.is_some() {
|
||||
cost *= 1.2;
|
||||
}
|
||||
|
||||
cost
|
||||
}
|
||||
Statement::Create(clause) => {
|
||||
// Create operations are expensive
|
||||
clause.patterns.len() as f64 * 50.0
|
||||
}
|
||||
Statement::Merge(clause) => {
|
||||
// Merge is more expensive than match or create alone
|
||||
self.estimate_pattern_cost(&clause.pattern) * 2.0
|
||||
}
|
||||
Statement::Delete(_) => 30.0,
|
||||
Statement::Set(_) => 20.0,
|
||||
Statement::Remove(clause) => clause.items.len() as f64 * 15.0,
|
||||
Statement::Return(clause) => {
|
||||
let mut cost = 10.0;
|
||||
|
||||
// Aggregations are expensive
|
||||
for item in &clause.items {
|
||||
if item.expression.has_aggregation() {
|
||||
cost += 50.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Sorting adds cost
|
||||
if clause.order_by.is_some() {
|
||||
cost += 100.0;
|
||||
}
|
||||
|
||||
cost
|
||||
}
|
||||
Statement::With(_) => 15.0,
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_pattern_cost(&self, pattern: &Pattern) -> f64 {
|
||||
match pattern {
|
||||
Pattern::Node(node) => {
|
||||
let mut cost = 100.0;
|
||||
|
||||
// Labels reduce cost (more selective)
|
||||
cost /= (1.0 + node.labels.len() as f64 * 0.5);
|
||||
|
||||
// Properties reduce cost
|
||||
if let Some(props) = &node.properties {
|
||||
cost /= (1.0 + props.len() as f64 * 0.3);
|
||||
}
|
||||
|
||||
cost
|
||||
}
|
||||
Pattern::Relationship(rel) => {
|
||||
let mut cost = 200.0; // Relationships are more expensive
|
||||
|
||||
// Specific type reduces cost
|
||||
if rel.rel_type.is_some() {
|
||||
cost *= 0.7;
|
||||
}
|
||||
|
||||
// Variable length paths are very expensive
|
||||
if let Some(range) = &rel.range {
|
||||
let max = range.max.unwrap_or(10);
|
||||
cost *= max as f64;
|
||||
}
|
||||
|
||||
cost
|
||||
}
|
||||
Pattern::Hyperedge(hyperedge) => {
|
||||
// Hyperedges are more expensive due to N-ary nature
|
||||
150.0 * hyperedge.arity as f64
|
||||
}
|
||||
Pattern::Path(_) => 300.0, // Paths can be expensive
|
||||
}
|
||||
}
|
||||
|
||||
/// Get variables used in an expression
|
||||
fn get_variables_in_expression(&self, expr: &Expression) -> HashSet<String> {
|
||||
let mut vars = HashSet::new();
|
||||
self.collect_variables(expr, &mut vars);
|
||||
vars
|
||||
}
|
||||
|
||||
fn collect_variables(&self, expr: &Expression, vars: &mut HashSet<String>) {
|
||||
match expr {
|
||||
Expression::Variable(name) => {
|
||||
vars.insert(name.clone());
|
||||
}
|
||||
Expression::Property { object, .. } => {
|
||||
self.collect_variables(object, vars);
|
||||
}
|
||||
Expression::BinaryOp { left, right, .. } => {
|
||||
self.collect_variables(left, vars);
|
||||
self.collect_variables(right, vars);
|
||||
}
|
||||
Expression::UnaryOp { operand, .. } => {
|
||||
self.collect_variables(operand, vars);
|
||||
}
|
||||
Expression::FunctionCall { args, .. } => {
|
||||
for arg in args {
|
||||
self.collect_variables(arg, vars);
|
||||
}
|
||||
}
|
||||
Expression::Aggregation { expression, .. } => {
|
||||
self.collect_variables(expression, vars);
|
||||
}
|
||||
Expression::List(items) => {
|
||||
for item in items {
|
||||
self.collect_variables(item, vars);
|
||||
}
|
||||
}
|
||||
Expression::Case {
|
||||
expression,
|
||||
alternatives,
|
||||
default,
|
||||
} => {
|
||||
if let Some(expr) = expression {
|
||||
self.collect_variables(expr, vars);
|
||||
}
|
||||
for (cond, result) in alternatives {
|
||||
self.collect_variables(cond, vars);
|
||||
self.collect_variables(result, vars);
|
||||
}
|
||||
if let Some(default_expr) = default {
|
||||
self.collect_variables(default_expr, vars);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lets callers construct the optimizer via `QueryOptimizer::default()`;
// simply delegates to `new`.
impl Default for QueryOptimizer {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cypher::parser::parse_cypher;

    /// Constant sub-expressions in WHERE (`2 + 3 = 5`) should be folded, and
    /// the plan should record that the constant-folding pass ran.
    #[test]
    fn test_constant_folding() {
        let query = parse_cypher("MATCH (n) WHERE 2 + 3 = 5 RETURN n").unwrap();
        let optimizer = QueryOptimizer::new();
        let plan = optimizer.optimize(query);

        assert!(plan
            .optimizations_applied
            .contains(&OptimizationType::ConstantFolding));
    }

    /// Any non-trivial query must have a strictly positive estimated cost.
    #[test]
    fn test_cost_estimation() {
        let query = parse_cypher("MATCH (n:Person {age: 30}) RETURN n").unwrap();
        let optimizer = QueryOptimizer::new();
        let cost = optimizer.estimate_cost(&query);

        assert!(cost > 0.0);
    }

    /// A labeled node must be estimated as more selective than an otherwise
    /// identical unlabeled node.
    #[test]
    fn test_pattern_selectivity() {
        let optimizer = QueryOptimizer::new();

        let node_with_label = Pattern::Node(NodePattern {
            variable: Some("n".to_string()),
            labels: vec!["Person".to_string()],
            properties: None,
        });

        let node_without_label = Pattern::Node(NodePattern {
            variable: Some("n".to_string()),
            labels: vec![],
            properties: None,
        });

        let sel_with = optimizer.estimate_pattern_selectivity(&node_with_label);
        let sel_without = optimizer.estimate_pattern_selectivity(&node_without_label);

        assert!(sel_with > sel_without);
    }
}
|
||||
1295
vendor/ruvector/crates/ruvector-graph/src/cypher/parser.rs
vendored
Normal file
1295
vendor/ruvector/crates/ruvector-graph/src/cypher/parser.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
616
vendor/ruvector/crates/ruvector-graph/src/cypher/semantic.rs
vendored
Normal file
616
vendor/ruvector/crates/ruvector-graph/src/cypher/semantic.rs
vendored
Normal file
@@ -0,0 +1,616 @@
|
||||
//! Semantic analysis and type checking for Cypher queries
|
||||
//!
|
||||
//! Validates the semantic correctness of parsed Cypher queries including:
|
||||
//! - Variable scope checking
|
||||
//! - Type compatibility validation
|
||||
//! - Aggregation context verification
|
||||
//! - Pattern validity
|
||||
|
||||
use super::ast::*;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors reported by the semantic analyzer while validating a parsed query.
#[derive(Debug, Error)]
pub enum SemanticError {
    /// A variable was referenced before any pattern or alias introduced it.
    #[error("Undefined variable: {0}")]
    UndefinedVariable(String),

    /// A pattern tried to re-bind a name already bound in the same scope.
    #[error("Variable already defined: {0}")]
    VariableAlreadyDefined(String),

    /// An expression's inferred type is incompatible with what its context
    /// requires (e.g. a non-boolean WHERE condition).
    #[error("Type mismatch: expected {expected}, found {found}")]
    TypeMismatch { expected: String, found: String },

    /// An aggregation appeared in a clause where it is not allowed.
    #[error("Aggregation not allowed in {0}")]
    InvalidAggregation(String),

    /// A projection mixes aggregated and non-aggregated expressions.
    #[error("Cannot mix aggregated and non-aggregated expressions")]
    MixedAggregation,

    /// A graph pattern is structurally invalid (e.g. inverted range bounds).
    #[error("Invalid pattern: {0}")]
    InvalidPattern(String),

    /// A hyperedge pattern violates arity or participant-count rules.
    #[error("Invalid hyperedge: {0}")]
    InvalidHyperedge(String),

    /// Property access (`x.prop`) on a value that cannot carry properties.
    #[error("Property access on non-object type")]
    InvalidPropertyAccess,

    /// A function call supplied the wrong number of arguments.
    #[error(
        "Invalid number of arguments for function {function}: expected {expected}, found {found}"
    )]
    InvalidArgumentCount {
        function: String,
        expected: usize,
        found: usize,
    },
}
|
||||
|
||||
/// Shorthand for results whose error type is `SemanticError`.
type SemanticResult<T> = Result<T, SemanticError>;
|
||||
|
||||
/// Semantic analyzer for Cypher queries
pub struct SemanticAnalyzer {
    // Lexical scopes; the last element is the innermost. Variable lookups
    // walk the stack back-to-front (see `lookup_variable`).
    scope_stack: Vec<Scope>,
    // True while analyzing the argument of an aggregation expression.
    in_aggregation: bool,
}
|
||||
|
||||
/// A single lexical scope mapping variable names to their inferred types.
#[derive(Debug, Clone)]
struct Scope {
    // Name → inferred type for every variable bound in this scope.
    variables: HashMap<String, ValueType>,
}
|
||||
|
||||
/// Type system for Cypher values
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValueType {
    Integer,
    Float,
    String,
    Boolean,
    Null,
    Node,
    Relationship,
    Path,
    List(Box<ValueType>),
    Map,
    Any,
}

impl ValueType {
    /// Check if this type is compatible with another type.
    ///
    /// `Any` and `Null` unify with everything, the two numeric types coerce
    /// into each other, and lists are compatible when their element types
    /// are; every other pairing requires exact equality.
    pub fn is_compatible_with(&self, other: &ValueType) -> bool {
        match (self, other) {
            (ValueType::Any, _)
            | (_, ValueType::Any)
            | (ValueType::Null, _)
            | (_, ValueType::Null)
            | (ValueType::Integer, ValueType::Float)
            | (ValueType::Float, ValueType::Integer) => true,
            (ValueType::List(lhs), ValueType::List(rhs)) => lhs.is_compatible_with(rhs),
            (lhs, rhs) => lhs == rhs,
        }
    }

    /// Check if this is a numeric type (`Integer`, `Float`, or the
    /// wildcard `Any`).
    pub fn is_numeric(&self) -> bool {
        matches!(self, ValueType::Integer | ValueType::Float | ValueType::Any)
    }

    /// Check if this is a graph element (node, relationship, path) or the
    /// wildcard `Any`.
    pub fn is_graph_element(&self) -> bool {
        matches!(
            self,
            ValueType::Node | ValueType::Relationship | ValueType::Path | ValueType::Any
        )
    }
}
|
||||
|
||||
impl Scope {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
variables: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn define(&mut self, name: String, value_type: ValueType) -> SemanticResult<()> {
|
||||
if self.variables.contains_key(&name) {
|
||||
Err(SemanticError::VariableAlreadyDefined(name))
|
||||
} else {
|
||||
self.variables.insert(name, value_type);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, name: &str) -> Option<&ValueType> {
|
||||
self.variables.get(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl SemanticAnalyzer {
    /// Create an analyzer with a single (global) scope and no active
    /// aggregation context.
    pub fn new() -> Self {
        Self {
            scope_stack: vec![Scope::new()],
            in_aggregation: false,
        }
    }

    /// Innermost scope. The `unwrap` is safe because the stack always
    /// retains the global scope created in `new`.
    fn current_scope(&self) -> &Scope {
        self.scope_stack.last().unwrap()
    }

    /// Mutable view of the innermost scope (same non-empty invariant).
    fn current_scope_mut(&mut self) -> &mut Scope {
        self.scope_stack.last_mut().unwrap()
    }

    /// Enter a nested scope. NOTE(review): not called by any pass visible
    /// in this file — presumably reserved for subquery support; confirm.
    fn push_scope(&mut self) {
        self.scope_stack.push(Scope::new());
    }

    /// Leave the innermost scope.
    fn pop_scope(&mut self) {
        self.scope_stack.pop();
    }

    /// Resolve a variable by searching scopes innermost-first.
    fn lookup_variable(&self, name: &str) -> SemanticResult<&ValueType> {
        for scope in self.scope_stack.iter().rev() {
            if let Some(value_type) = scope.get(name) {
                return Ok(value_type);
            }
        }
        Err(SemanticError::UndefinedVariable(name.to_string()))
    }

    /// Bind a new variable in the innermost scope.
    fn define_variable(&mut self, name: String, value_type: ValueType) -> SemanticResult<()> {
        self.current_scope_mut().define(name, value_type)
    }

    /// Analyze a complete query
    ///
    /// Statements are checked in order, so variables introduced by an
    /// earlier MATCH/CREATE are visible to later statements.
    pub fn analyze_query(&mut self, query: &Query) -> SemanticResult<()> {
        for statement in &query.statements {
            self.analyze_statement(statement)?;
        }
        Ok(())
    }

    /// Dispatch one statement to its clause-specific checker.
    fn analyze_statement(&mut self, statement: &Statement) -> SemanticResult<()> {
        match statement {
            Statement::Match(clause) => self.analyze_match(clause),
            Statement::Create(clause) => self.analyze_create(clause),
            Statement::Merge(clause) => self.analyze_merge(clause),
            Statement::Delete(clause) => self.analyze_delete(clause),
            Statement::Set(clause) => self.analyze_set(clause),
            Statement::Remove(clause) => self.analyze_remove(clause),
            Statement::Return(clause) => self.analyze_return(clause),
            Statement::With(clause) => self.analyze_with(clause),
        }
    }

    /// Check a REMOVE clause: every targeted variable must already be bound.
    fn analyze_remove(&mut self, clause: &RemoveClause) -> SemanticResult<()> {
        for item in &clause.items {
            match item {
                RemoveItem::Property { variable, .. } => {
                    // Verify variable is defined
                    self.lookup_variable(variable)?;
                }
                RemoveItem::Labels { variable, .. } => {
                    // Verify variable is defined
                    self.lookup_variable(variable)?;
                }
            }
        }
        Ok(())
    }

    /// Check a MATCH clause: bind pattern variables, then verify any WHERE
    /// condition is boolean-compatible.
    fn analyze_match(&mut self, clause: &MatchClause) -> SemanticResult<()> {
        // Analyze patterns and define variables
        for pattern in &clause.patterns {
            self.analyze_pattern(pattern)?;
        }

        // Analyze WHERE clause
        if let Some(where_clause) = &clause.where_clause {
            let expr_type = self.analyze_expression(&where_clause.condition)?;
            if !expr_type.is_compatible_with(&ValueType::Boolean) {
                return Err(SemanticError::TypeMismatch {
                    expected: "Boolean".to_string(),
                    found: format!("{:?}", expr_type),
                });
            }
        }

        Ok(())
    }

    /// Dispatch a pattern to its shape-specific checker.
    fn analyze_pattern(&mut self, pattern: &Pattern) -> SemanticResult<()> {
        match pattern {
            Pattern::Node(node) => self.analyze_node_pattern(node),
            Pattern::Relationship(rel) => self.analyze_relationship_pattern(rel),
            Pattern::Path(path) => self.analyze_path_pattern(path),
            Pattern::Hyperedge(hyperedge) => self.analyze_hyperedge_pattern(hyperedge),
        }
    }

    /// Bind the node's variable (if named) as `Node` and type-check its
    /// property-value expressions.
    fn analyze_node_pattern(&mut self, node: &NodePattern) -> SemanticResult<()> {
        if let Some(variable) = &node.variable {
            self.define_variable(variable.clone(), ValueType::Node)?;
        }

        if let Some(properties) = &node.properties {
            for expr in properties.values() {
                self.analyze_expression(expr)?;
            }
        }

        Ok(())
    }

    /// Check a relationship pattern: both endpoints, the relationship's own
    /// variable/properties, and the sanity of any variable-length range.
    fn analyze_relationship_pattern(&mut self, rel: &RelationshipPattern) -> SemanticResult<()> {
        self.analyze_node_pattern(&rel.from)?;
        // rel.to is now a Pattern (can be NodePattern or chained RelationshipPattern)
        self.analyze_pattern(&*rel.to)?;

        if let Some(variable) = &rel.variable {
            self.define_variable(variable.clone(), ValueType::Relationship)?;
        }

        if let Some(properties) = &rel.properties {
            for expr in properties.values() {
                self.analyze_expression(expr)?;
            }
        }

        // Validate range if present
        if let Some(range) = &rel.range {
            if let (Some(min), Some(max)) = (range.min, range.max) {
                if min > max {
                    return Err(SemanticError::InvalidPattern(
                        "Minimum range cannot be greater than maximum".to_string(),
                    ));
                }
            }
        }

        Ok(())
    }

    /// Bind the path variable as `Path`, then check the inner pattern.
    fn analyze_path_pattern(&mut self, path: &PathPattern) -> SemanticResult<()> {
        self.define_variable(path.variable.clone(), ValueType::Path)?;
        self.analyze_pattern(&path.pattern)
    }

    /// Check a hyperedge pattern: shape rules (≥ 2 targets, arity =
    /// targets + source), then every participant and property.
    fn analyze_hyperedge_pattern(&mut self, hyperedge: &HyperedgePattern) -> SemanticResult<()> {
        // Validate hyperedge has at least 2 target nodes
        if hyperedge.to.len() < 2 {
            return Err(SemanticError::InvalidHyperedge(
                "Hyperedge must have at least 2 target nodes".to_string(),
            ));
        }

        // Validate arity matches
        if hyperedge.arity != hyperedge.to.len() + 1 {
            return Err(SemanticError::InvalidHyperedge(
                "Hyperedge arity doesn't match number of participating nodes".to_string(),
            ));
        }

        self.analyze_node_pattern(&hyperedge.from)?;

        for target in &hyperedge.to {
            self.analyze_node_pattern(target)?;
        }

        // The hyperedge variable is typed as Relationship (no dedicated
        // hyperedge type in ValueType).
        if let Some(variable) = &hyperedge.variable {
            self.define_variable(variable.clone(), ValueType::Relationship)?;
        }

        if let Some(properties) = &hyperedge.properties {
            for expr in properties.values() {
                self.analyze_expression(expr)?;
            }
        }

        Ok(())
    }

    /// Check a CREATE clause: each pattern is analyzed, binding its
    /// variables for later statements.
    fn analyze_create(&mut self, clause: &CreateClause) -> SemanticResult<()> {
        for pattern in &clause.patterns {
            self.analyze_pattern(pattern)?;
        }
        Ok(())
    }

    /// Check a MERGE clause and its optional ON CREATE / ON MATCH SET blocks.
    fn analyze_merge(&mut self, clause: &MergeClause) -> SemanticResult<()> {
        self.analyze_pattern(&clause.pattern)?;

        if let Some(on_create) = &clause.on_create {
            self.analyze_set(on_create)?;
        }

        if let Some(on_match) = &clause.on_match {
            self.analyze_set(on_match)?;
        }

        Ok(())
    }

    /// Check a DELETE clause: every deleted expression must evaluate to a
    /// graph element (node, relationship, or path).
    fn analyze_delete(&mut self, clause: &DeleteClause) -> SemanticResult<()> {
        for expr in &clause.expressions {
            let expr_type = self.analyze_expression(expr)?;
            if !expr_type.is_graph_element() {
                return Err(SemanticError::TypeMismatch {
                    expected: "graph element (node, relationship, path)".to_string(),
                    found: format!("{:?}", expr_type),
                });
            }
        }
        Ok(())
    }

    /// Check a SET clause: every target variable must be bound, and every
    /// assigned value expression must be well-typed.
    fn analyze_set(&mut self, clause: &SetClause) -> SemanticResult<()> {
        for item in &clause.items {
            match item {
                SetItem::Property {
                    variable, value, ..
                } => {
                    self.lookup_variable(variable)?;
                    self.analyze_expression(value)?;
                }
                SetItem::Variable { variable, value } => {
                    self.lookup_variable(variable)?;
                    self.analyze_expression(value)?;
                }
                SetItem::Labels { variable, .. } => {
                    self.lookup_variable(variable)?;
                }
            }
        }
        Ok(())
    }

    /// Check a RETURN clause: projection items (incl. the aggregation
    /// mixing rule), ORDER BY expressions, and integer-typed SKIP/LIMIT.
    fn analyze_return(&mut self, clause: &ReturnClause) -> SemanticResult<()> {
        self.analyze_return_items(&clause.items)?;

        if let Some(order_by) = &clause.order_by {
            for item in &order_by.items {
                self.analyze_expression(&item.expression)?;
            }
        }

        if let Some(skip) = &clause.skip {
            let skip_type = self.analyze_expression(skip)?;
            if !skip_type.is_compatible_with(&ValueType::Integer) {
                return Err(SemanticError::TypeMismatch {
                    expected: "Integer".to_string(),
                    found: format!("{:?}", skip_type),
                });
            }
        }

        if let Some(limit) = &clause.limit {
            let limit_type = self.analyze_expression(limit)?;
            if !limit_type.is_compatible_with(&ValueType::Integer) {
                return Err(SemanticError::TypeMismatch {
                    expected: "Integer".to_string(),
                    found: format!("{:?}", limit_type),
                });
            }
        }

        Ok(())
    }

    /// Check a WITH clause: projection items plus an optional boolean WHERE.
    fn analyze_with(&mut self, clause: &WithClause) -> SemanticResult<()> {
        self.analyze_return_items(&clause.items)?;

        if let Some(where_clause) = &clause.where_clause {
            let expr_type = self.analyze_expression(&where_clause.condition)?;
            if !expr_type.is_compatible_with(&ValueType::Boolean) {
                return Err(SemanticError::TypeMismatch {
                    expected: "Boolean".to_string(),
                    found: format!("{:?}", expr_type),
                });
            }
        }

        Ok(())
    }

    /// Enforce the "no mixing aggregated and non-aggregated projections"
    /// rule (constant expressions are exempt), then type-check each item.
    fn analyze_return_items(&mut self, items: &[ReturnItem]) -> SemanticResult<()> {
        let mut has_aggregation = false;
        let mut has_non_aggregation = false;

        for item in items {
            let item_has_agg = item.expression.has_aggregation();
            has_aggregation |= item_has_agg;
            has_non_aggregation |= !item_has_agg && !item.expression.is_constant();
        }

        if has_aggregation && has_non_aggregation {
            return Err(SemanticError::MixedAggregation);
        }

        for item in items {
            self.analyze_expression(&item.expression)?;
        }

        Ok(())
    }

    /// Infer the type of an expression, recursively validating its parts.
    ///
    /// Several constructs intentionally widen to `Any` (variables, property
    /// access, function calls, aggregations, CASE) because their precise
    /// type is only known at runtime.
    fn analyze_expression(&mut self, expr: &Expression) -> SemanticResult<ValueType> {
        match expr {
            Expression::Integer(_) => Ok(ValueType::Integer),
            Expression::Float(_) => Ok(ValueType::Float),
            Expression::String(_) => Ok(ValueType::String),
            Expression::Boolean(_) => Ok(ValueType::Boolean),
            Expression::Null => Ok(ValueType::Null),

            // Existence is checked, but the static type is widened to Any.
            Expression::Variable(name) => {
                self.lookup_variable(name)?;
                Ok(ValueType::Any)
            }

            // Property access is only valid on graph elements, maps, or Any.
            Expression::Property { object, .. } => {
                let obj_type = self.analyze_expression(object)?;
                if !obj_type.is_graph_element()
                    && obj_type != ValueType::Map
                    && obj_type != ValueType::Any
                {
                    return Err(SemanticError::InvalidPropertyAccess);
                }
                Ok(ValueType::Any)
            }

            // A list's element type is the common type of its items, or Any
            // when items are heterogeneous (or the list is empty).
            Expression::List(items) => {
                if items.is_empty() {
                    Ok(ValueType::List(Box::new(ValueType::Any)))
                } else {
                    let first_type = self.analyze_expression(&items[0])?;
                    for item in items.iter().skip(1) {
                        let item_type = self.analyze_expression(item)?;
                        if !item_type.is_compatible_with(&first_type) {
                            return Ok(ValueType::List(Box::new(ValueType::Any)));
                        }
                    }
                    Ok(ValueType::List(Box::new(first_type)))
                }
            }

            Expression::Map(map) => {
                for expr in map.values() {
                    self.analyze_expression(expr)?;
                }
                Ok(ValueType::Map)
            }

            Expression::BinaryOp { left, op, right } => {
                let left_type = self.analyze_expression(left)?;
                let right_type = self.analyze_expression(right)?;

                match op {
                    // Arithmetic requires numeric operands; the result is
                    // Float if either side is Float, otherwise Integer.
                    BinaryOperator::Add
                    | BinaryOperator::Subtract
                    | BinaryOperator::Multiply
                    | BinaryOperator::Divide
                    | BinaryOperator::Modulo
                    | BinaryOperator::Power => {
                        if !left_type.is_numeric() || !right_type.is_numeric() {
                            return Err(SemanticError::TypeMismatch {
                                expected: "numeric".to_string(),
                                found: format!("{:?} and {:?}", left_type, right_type),
                            });
                        }
                        if left_type == ValueType::Float || right_type == ValueType::Float {
                            Ok(ValueType::Float)
                        } else {
                            Ok(ValueType::Integer)
                        }
                    }
                    // Comparisons and logical connectives yield Boolean.
                    BinaryOperator::Equal
                    | BinaryOperator::NotEqual
                    | BinaryOperator::LessThan
                    | BinaryOperator::LessThanOrEqual
                    | BinaryOperator::GreaterThan
                    | BinaryOperator::GreaterThanOrEqual => Ok(ValueType::Boolean),
                    BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
                        Ok(ValueType::Boolean)
                    }
                    // Other operators (string/list ops etc.) are untyped here.
                    _ => Ok(ValueType::Any),
                }
            }

            Expression::UnaryOp { op, operand } => {
                let operand_type = self.analyze_expression(operand)?;
                match op {
                    UnaryOperator::Not | UnaryOperator::IsNull | UnaryOperator::IsNotNull => {
                        Ok(ValueType::Boolean)
                    }
                    // Sign operators preserve the operand's type.
                    UnaryOperator::Minus | UnaryOperator::Plus => Ok(operand_type),
                }
            }

            // Arguments are checked; the return type is unknown statically.
            Expression::FunctionCall { args, .. } => {
                for arg in args {
                    self.analyze_expression(arg)?;
                }
                Ok(ValueType::Any)
            }

            // NOTE(review): `in_aggregation` is set/restored around the
            // argument but nested aggregates are not rejected anywhere
            // visible — confirm whether that check is intended elsewhere.
            Expression::Aggregation { expression, .. } => {
                let old_in_agg = self.in_aggregation;
                self.in_aggregation = true;
                let result = self.analyze_expression(expression);
                self.in_aggregation = old_in_agg;
                result?;
                Ok(ValueType::Any)
            }

            // A pattern used as a predicate is boolean ("does it match?").
            Expression::PatternPredicate(pattern) => {
                self.analyze_pattern(pattern)?;
                Ok(ValueType::Boolean)
            }

            Expression::Case {
                expression,
                alternatives,
                default,
            } => {
                if let Some(expr) = expression {
                    self.analyze_expression(expr)?;
                }

                for (condition, result) in alternatives {
                    self.analyze_expression(condition)?;
                    self.analyze_expression(result)?;
                }

                if let Some(default_expr) = default {
                    self.analyze_expression(default_expr)?;
                }

                Ok(ValueType::Any)
            }
        }
    }
}
|
||||
|
||||
// Lets callers construct the analyzer via `SemanticAnalyzer::default()`;
// simply delegates to `new`.
impl Default for SemanticAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cypher::parser::parse_cypher;

    /// A well-formed MATCH/RETURN over a bound variable passes analysis.
    #[test]
    fn test_analyze_simple_match() {
        let query = parse_cypher("MATCH (n:Person) RETURN n").unwrap();
        let mut analyzer = SemanticAnalyzer::new();
        assert!(analyzer.analyze_query(&query).is_ok());
    }

    /// Returning a variable never bound by any pattern is rejected.
    #[test]
    fn test_undefined_variable() {
        let query = parse_cypher("MATCH (n:Person) RETURN m").unwrap();
        let mut analyzer = SemanticAnalyzer::new();
        assert!(matches!(
            analyzer.analyze_query(&query),
            Err(SemanticError::UndefinedVariable(_))
        ));
    }

    /// Mixing aggregated and plain projections in one RETURN is rejected.
    #[test]
    fn test_mixed_aggregation() {
        let query = parse_cypher("MATCH (n:Person) RETURN n.name, COUNT(n)").unwrap();
        let mut analyzer = SemanticAnalyzer::new();
        assert!(matches!(
            analyzer.analyze_query(&query),
            Err(SemanticError::MixedAggregation)
        ));
    }

    /// N-ary hyperedge patterns should validate once the parser grows the
    /// `(b, c)` target-list syntax; ignored until then.
    #[test]
    #[ignore = "Hyperedge syntax not yet implemented in parser"]
    fn test_hyperedge_validation() {
        let query = parse_cypher("MATCH (a)-[r:REL]->(b, c) RETURN a, r, b, c").unwrap();
        let mut analyzer = SemanticAnalyzer::new();
        assert!(analyzer.analyze_query(&query).is_ok());
    }
}
|
||||
535
vendor/ruvector/crates/ruvector-graph/src/distributed/coordinator.rs
vendored
Normal file
535
vendor/ruvector/crates/ruvector-graph/src/distributed/coordinator.rs
vendored
Normal file
@@ -0,0 +1,535 @@
|
||||
//! Query coordinator for distributed graph execution
|
||||
//!
|
||||
//! Coordinates distributed query execution across multiple shards:
|
||||
//! - Query planning and optimization
|
||||
//! - Query routing to relevant shards
|
||||
//! - Result aggregation and merging
|
||||
//! - Transaction coordination across shards
|
||||
//! - Query caching and optimization
|
||||
|
||||
use crate::distributed::shard::{EdgeData, GraphShard, NodeData, NodeId, ShardId};
|
||||
use crate::{GraphError, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Query execution plan produced by [`ShardCoordinator::plan_query`] and
/// consumed by [`ShardCoordinator::execute_query`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
    /// Unique query ID (a freshly generated UUID string).
    pub query_id: String,
    /// Original query text (Cypher-like syntax); also used as the cache key.
    pub query: String,
    /// Shards involved in this query.
    pub target_shards: Vec<ShardId>,
    /// Execution steps, run in order.
    pub steps: Vec<QueryStep>,
    /// Estimated cost (fixed per-step weights times shard count).
    pub estimated_cost: f64,
    /// Whether this is a distributed query.
    pub is_distributed: bool,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
}
|
||||
|
||||
/// Individual step in query execution.
///
/// Steps are executed sequentially; scan steps accumulate nodes/edges and
/// later steps (aggregate, limit, ...) operate on the accumulated results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryStep {
    /// Scan nodes on one shard with optional label and predicate filter.
    NodeScan {
        shard_id: ShardId,
        label: Option<String>,
        /// Free-form predicate string; not yet applied by the executor.
        filter: Option<String>,
    },
    /// Scan edges on one shard, optionally restricted to an edge type.
    EdgeScan {
        shard_id: ShardId,
        edge_type: Option<String>,
    },
    /// Join results from two shards on a key.
    Join {
        left_shard: ShardId,
        right_shard: ShardId,
        join_key: String,
    },
    /// Aggregate results, optionally grouped by a key.
    Aggregate {
        operation: AggregateOp,
        group_by: Option<String>,
    },
    /// Filter results by a predicate expression.
    Filter { predicate: String },
    /// Sort results by a key.
    Sort { key: String, ascending: bool },
    /// Truncate the node result set to at most `count` entries.
    Limit { count: usize },
}
|
||||
|
||||
/// Aggregate operations; variants other than `Count` carry the name of the
/// property to aggregate over.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregateOp {
    Count,
    Sum(String),
    Avg(String),
    Min(String),
    Max(String),
}
|
||||
|
||||
/// Query result returned by [`ShardCoordinator::execute_query`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID this result belongs to.
    pub query_id: String,
    /// Result nodes.
    pub nodes: Vec<NodeData>,
    /// Result edges.
    pub edges: Vec<EdgeData>,
    /// Aggregate results keyed by aggregate name (e.g. "count").
    pub aggregates: HashMap<String, serde_json::Value>,
    /// Execution statistics.
    pub stats: QueryStats,
}
|
||||
|
||||
/// Query execution statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryStats {
    /// Execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Number of shards queried.
    pub shards_queried: usize,
    /// Total nodes scanned (before filtering).
    pub nodes_scanned: usize,
    /// Total edges scanned (before filtering).
    pub edges_scanned: usize,
    /// Whether this result was served from the query cache.
    pub cached: bool,
}
|
||||
|
||||
/// Shard coordinator for managing distributed queries.
///
/// All fields are concurrent maps behind `Arc`, so the coordinator can be
/// shared and used from multiple tasks without external locking.
pub struct ShardCoordinator {
    /// Map of shard_id to GraphShard.
    shards: Arc<DashMap<ShardId, Arc<GraphShard>>>,
    /// Query cache, keyed by the raw query text.
    query_cache: Arc<DashMap<String, QueryResult>>,
    /// Active transactions, keyed by transaction ID.
    transactions: Arc<DashMap<String, Transaction>>,
}
|
||||
|
||||
impl ShardCoordinator {
|
||||
/// Create a new shard coordinator
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
shards: Arc::new(DashMap::new()),
|
||||
query_cache: Arc::new(DashMap::new()),
|
||||
transactions: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a shard with the coordinator
|
||||
pub fn register_shard(&self, shard_id: ShardId, shard: Arc<GraphShard>) {
|
||||
info!("Registering shard {} with coordinator", shard_id);
|
||||
self.shards.insert(shard_id, shard);
|
||||
}
|
||||
|
||||
/// Unregister a shard
|
||||
pub fn unregister_shard(&self, shard_id: ShardId) -> Result<()> {
|
||||
info!("Unregistering shard {}", shard_id);
|
||||
self.shards
|
||||
.remove(&shard_id)
|
||||
.ok_or_else(|| GraphError::ShardError(format!("Shard {} not found", shard_id)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a shard by ID
|
||||
pub fn get_shard(&self, shard_id: ShardId) -> Option<Arc<GraphShard>> {
|
||||
self.shards.get(&shard_id).map(|s| Arc::clone(s.value()))
|
||||
}
|
||||
|
||||
/// List all registered shards
|
||||
pub fn list_shards(&self) -> Vec<ShardId> {
|
||||
self.shards.iter().map(|e| *e.key()).collect()
|
||||
}
|
||||
|
||||
/// Create a query plan from a Cypher-like query
|
||||
pub fn plan_query(&self, query: &str) -> Result<QueryPlan> {
|
||||
let query_id = Uuid::new_v4().to_string();
|
||||
|
||||
// Parse query and determine target shards
|
||||
// For now, simple heuristic: query all shards for distributed queries
|
||||
let target_shards: Vec<ShardId> = self.list_shards();
|
||||
|
||||
let steps = self.parse_query_steps(query)?;
|
||||
|
||||
let estimated_cost = self.estimate_cost(&steps, &target_shards);
|
||||
|
||||
Ok(QueryPlan {
|
||||
query_id,
|
||||
query: query.to_string(),
|
||||
target_shards,
|
||||
steps,
|
||||
estimated_cost,
|
||||
is_distributed: true,
|
||||
created_at: Utc::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse query into execution steps
|
||||
fn parse_query_steps(&self, query: &str) -> Result<Vec<QueryStep>> {
|
||||
// Simplified query parsing
|
||||
// In production, use a proper Cypher parser
|
||||
let mut steps = Vec::new();
|
||||
|
||||
// Example: "MATCH (n:Person) RETURN n"
|
||||
if query.to_lowercase().contains("match") {
|
||||
// Add node scan for each shard
|
||||
for shard_id in self.list_shards() {
|
||||
steps.push(QueryStep::NodeScan {
|
||||
shard_id,
|
||||
label: None,
|
||||
filter: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add aggregation if needed
|
||||
if query.to_lowercase().contains("count") {
|
||||
steps.push(QueryStep::Aggregate {
|
||||
operation: AggregateOp::Count,
|
||||
group_by: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Add limit if specified
|
||||
if let Some(limit_pos) = query.to_lowercase().find("limit") {
|
||||
if let Some(count_str) = query[limit_pos..].split_whitespace().nth(1) {
|
||||
if let Ok(count) = count_str.parse::<usize>() {
|
||||
steps.push(QueryStep::Limit { count });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(steps)
|
||||
}
|
||||
|
||||
/// Estimate query execution cost
|
||||
fn estimate_cost(&self, steps: &[QueryStep], target_shards: &[ShardId]) -> f64 {
|
||||
let mut cost = 0.0;
|
||||
|
||||
for step in steps {
|
||||
match step {
|
||||
QueryStep::NodeScan { .. } => cost += 10.0,
|
||||
QueryStep::EdgeScan { .. } => cost += 15.0,
|
||||
QueryStep::Join { .. } => cost += 50.0,
|
||||
QueryStep::Aggregate { .. } => cost += 20.0,
|
||||
QueryStep::Filter { .. } => cost += 5.0,
|
||||
QueryStep::Sort { .. } => cost += 30.0,
|
||||
QueryStep::Limit { .. } => cost += 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply by number of shards for distributed queries
|
||||
cost * target_shards.len() as f64
|
||||
}
|
||||
|
||||
/// Execute a query plan
|
||||
pub async fn execute_query(&self, plan: QueryPlan) -> Result<QueryResult> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
info!(
|
||||
"Executing query {} across {} shards",
|
||||
plan.query_id,
|
||||
plan.target_shards.len()
|
||||
);
|
||||
|
||||
// Check cache first
|
||||
if let Some(cached) = self.query_cache.get(&plan.query) {
|
||||
debug!("Query cache hit for: {}", plan.query);
|
||||
return Ok(cached.value().clone());
|
||||
}
|
||||
|
||||
let mut nodes = Vec::new();
|
||||
let mut edges = Vec::new();
|
||||
let mut aggregates = HashMap::new();
|
||||
let mut nodes_scanned = 0;
|
||||
let mut edges_scanned = 0;
|
||||
|
||||
// Execute steps
|
||||
for step in &plan.steps {
|
||||
match step {
|
||||
QueryStep::NodeScan {
|
||||
shard_id,
|
||||
label,
|
||||
filter,
|
||||
} => {
|
||||
if let Some(shard) = self.get_shard(*shard_id) {
|
||||
let shard_nodes = shard.list_nodes();
|
||||
nodes_scanned += shard_nodes.len();
|
||||
|
||||
// Apply label filter
|
||||
let filtered: Vec<_> = if let Some(label_filter) = label {
|
||||
shard_nodes
|
||||
.into_iter()
|
||||
.filter(|n| n.labels.contains(label_filter))
|
||||
.collect()
|
||||
} else {
|
||||
shard_nodes
|
||||
};
|
||||
|
||||
nodes.extend(filtered);
|
||||
}
|
||||
}
|
||||
QueryStep::EdgeScan {
|
||||
shard_id,
|
||||
edge_type,
|
||||
} => {
|
||||
if let Some(shard) = self.get_shard(*shard_id) {
|
||||
let shard_edges = shard.list_edges();
|
||||
edges_scanned += shard_edges.len();
|
||||
|
||||
// Apply edge type filter
|
||||
let filtered: Vec<_> = if let Some(type_filter) = edge_type {
|
||||
shard_edges
|
||||
.into_iter()
|
||||
.filter(|e| &e.edge_type == type_filter)
|
||||
.collect()
|
||||
} else {
|
||||
shard_edges
|
||||
};
|
||||
|
||||
edges.extend(filtered);
|
||||
}
|
||||
}
|
||||
QueryStep::Aggregate {
|
||||
operation,
|
||||
group_by,
|
||||
} => {
|
||||
match operation {
|
||||
AggregateOp::Count => {
|
||||
aggregates.insert(
|
||||
"count".to_string(),
|
||||
serde_json::Value::Number(nodes.len().into()),
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
// Implement other aggregations
|
||||
}
|
||||
}
|
||||
}
|
||||
QueryStep::Limit { count } => {
|
||||
nodes.truncate(*count);
|
||||
}
|
||||
_ => {
|
||||
// Implement other steps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let execution_time_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
let result = QueryResult {
|
||||
query_id: plan.query_id.clone(),
|
||||
nodes,
|
||||
edges,
|
||||
aggregates,
|
||||
stats: QueryStats {
|
||||
execution_time_ms,
|
||||
shards_queried: plan.target_shards.len(),
|
||||
nodes_scanned,
|
||||
edges_scanned,
|
||||
cached: false,
|
||||
},
|
||||
};
|
||||
|
||||
// Cache the result
|
||||
self.query_cache.insert(plan.query.clone(), result.clone());
|
||||
|
||||
info!(
|
||||
"Query {} completed in {}ms",
|
||||
plan.query_id, execution_time_ms
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Begin a distributed transaction
|
||||
pub fn begin_transaction(&self) -> String {
|
||||
let tx_id = Uuid::new_v4().to_string();
|
||||
let transaction = Transaction::new(tx_id.clone());
|
||||
self.transactions.insert(tx_id.clone(), transaction);
|
||||
info!("Started transaction: {}", tx_id);
|
||||
tx_id
|
||||
}
|
||||
|
||||
/// Commit a transaction
|
||||
pub async fn commit_transaction(&self, tx_id: &str) -> Result<()> {
|
||||
if let Some((_, tx)) = self.transactions.remove(tx_id) {
|
||||
// In production, implement 2PC (Two-Phase Commit)
|
||||
info!("Committing transaction: {}", tx_id);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(GraphError::CoordinatorError(format!(
|
||||
"Transaction not found: {}",
|
||||
tx_id
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback a transaction
|
||||
pub async fn rollback_transaction(&self, tx_id: &str) -> Result<()> {
|
||||
if let Some((_, tx)) = self.transactions.remove(tx_id) {
|
||||
warn!("Rolling back transaction: {}", tx_id);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(GraphError::CoordinatorError(format!(
|
||||
"Transaction not found: {}",
|
||||
tx_id
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear query cache
|
||||
pub fn clear_cache(&self) {
|
||||
self.query_cache.clear();
|
||||
info!("Query cache cleared");
|
||||
}
|
||||
}
|
||||
|
||||
/// Distributed transaction tracked by the [`ShardCoordinator`].
///
/// NOTE(review): `id`, `shards` and `state` are populated but never read
/// yet — they exist for the future 2PC implementation.
#[derive(Debug, Clone)]
struct Transaction {
    /// Transaction ID.
    id: String,
    /// Participating shards.
    shards: HashSet<ShardId>,
    /// Transaction state.
    state: TransactionState,
    /// Created timestamp.
    created_at: DateTime<Utc>,
}
|
||||
|
||||
impl Transaction {
    /// Start an empty transaction in the `Active` state with no
    /// participating shards yet.
    fn new(id: String) -> Self {
        Self {
            id,
            shards: HashSet::new(),
            state: TransactionState::Active,
            created_at: Utc::now(),
        }
    }
}
|
||||
|
||||
/// Transaction state for the (planned) two-phase commit protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransactionState {
    /// Transaction is open and accepting operations.
    Active,
    /// Prepare phase of 2PC in progress.
    Preparing,
    /// Successfully committed.
    Committed,
    /// Rolled back.
    Aborted,
}
|
||||
|
||||
/// Main coordinator for the entire distributed graph system; wraps a
/// [`ShardCoordinator`] together with its configuration.
pub struct Coordinator {
    /// Shard coordinator.
    shard_coordinator: Arc<ShardCoordinator>,
    /// Coordinator configuration.
    /// NOTE(review): currently only exposed via `config()`; the execution
    /// path does not consult it yet (e.g. `enable_cache` is not honored).
    config: CoordinatorConfig,
}
|
||||
|
||||
impl Coordinator {
|
||||
/// Create a new coordinator
|
||||
pub fn new(config: CoordinatorConfig) -> Self {
|
||||
Self {
|
||||
shard_coordinator: Arc::new(ShardCoordinator::new()),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the shard coordinator
|
||||
pub fn shard_coordinator(&self) -> Arc<ShardCoordinator> {
|
||||
Arc::clone(&self.shard_coordinator)
|
||||
}
|
||||
|
||||
/// Execute a query
|
||||
pub async fn execute(&self, query: &str) -> Result<QueryResult> {
|
||||
let plan = self.shard_coordinator.plan_query(query)?;
|
||||
self.shard_coordinator.execute_query(plan).await
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &CoordinatorConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Coordinator configuration.
///
/// NOTE(review): these knobs are stored but not yet enforced by
/// `ShardCoordinator::execute_query` (caching is always on, no timeout is
/// applied) — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorConfig {
    /// Enable query caching.
    pub enable_cache: bool,
    /// Cache TTL in seconds.
    pub cache_ttl_seconds: u64,
    /// Maximum query execution time.
    pub max_query_time_seconds: u64,
    /// Enable query optimization.
    pub enable_optimization: bool,
}
|
||||
|
||||
/// Defaults: caching on with a 5-minute TTL, 60 s query budget,
/// optimization enabled.
impl Default for CoordinatorConfig {
    fn default() -> Self {
        Self {
            enable_cache: true,
            cache_ttl_seconds: 300,
            max_query_time_seconds: 60,
            enable_optimization: true,
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::shard::ShardMetadata;
    use crate::distributed::shard::ShardStrategy;

    /// Registering a shard makes it visible via `list_shards`/`get_shard`.
    #[tokio::test]
    async fn test_shard_coordinator() {
        let coordinator = ShardCoordinator::new();

        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        let shard = Arc::new(GraphShard::new(metadata));

        coordinator.register_shard(0, shard);

        assert_eq!(coordinator.list_shards().len(), 1);
        assert!(coordinator.get_shard(0).is_some());
    }

    /// Planning a MATCH query with one registered shard yields a non-empty
    /// plan with a generated query ID.
    #[tokio::test]
    async fn test_query_planning() {
        let coordinator = ShardCoordinator::new();

        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        let shard = Arc::new(GraphShard::new(metadata));
        coordinator.register_shard(0, shard);

        let plan = coordinator.plan_query("MATCH (n:Person) RETURN n").unwrap();

        assert!(!plan.query_id.is_empty());
        assert!(!plan.steps.is_empty());
    }

    /// A freshly begun transaction can be committed exactly once.
    #[tokio::test]
    async fn test_transaction() {
        let coordinator = ShardCoordinator::new();

        let tx_id = coordinator.begin_transaction();
        assert!(!tx_id.is_empty());

        coordinator.commit_transaction(&tx_id).await.unwrap();
    }
}
|
||||
582
vendor/ruvector/crates/ruvector-graph/src/distributed/federation.rs
vendored
Normal file
582
vendor/ruvector/crates/ruvector-graph/src/distributed/federation.rs
vendored
Normal file
@@ -0,0 +1,582 @@
|
||||
//! Cross-cluster federation for distributed graph queries
|
||||
//!
|
||||
//! Enables querying across independent RuVector graph clusters:
|
||||
//! - Cluster discovery and registration
|
||||
//! - Remote query execution
|
||||
//! - Result merging from multiple clusters
|
||||
//! - Cross-cluster authentication and authorization
|
||||
|
||||
use crate::distributed::coordinator::{QueryPlan, QueryResult};
|
||||
use crate::distributed::shard::ShardId;
|
||||
use crate::{GraphError, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for a cluster (free-form string, e.g. "cluster-1").
pub type ClusterId = String;
|
||||
|
||||
/// Remote cluster information held by the [`ClusterRegistry`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoteCluster {
    /// Unique cluster ID.
    pub cluster_id: ClusterId,
    /// Human-readable cluster name.
    pub name: String,
    /// Cluster endpoint URL.
    pub endpoint: String,
    /// Last known cluster status (starts as `Unknown`).
    pub status: ClusterStatus,
    /// Authentication token, if the cluster requires one.
    pub auth_token: Option<String>,
    /// Last health check timestamp.
    pub last_health_check: DateTime<Utc>,
    /// Free-form cluster metadata.
    pub metadata: HashMap<String, String>,
    /// Number of shards in this cluster.
    pub shard_count: u32,
    /// Cluster region/datacenter, if known.
    pub region: Option<String>,
}
|
||||
|
||||
impl RemoteCluster {
|
||||
/// Create a new remote cluster
|
||||
pub fn new(cluster_id: ClusterId, name: String, endpoint: String) -> Self {
|
||||
Self {
|
||||
cluster_id,
|
||||
name,
|
||||
endpoint,
|
||||
status: ClusterStatus::Unknown,
|
||||
auth_token: None,
|
||||
last_health_check: Utc::now(),
|
||||
metadata: HashMap::new(),
|
||||
shard_count: 0,
|
||||
region: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if cluster is healthy
|
||||
pub fn is_healthy(&self) -> bool {
|
||||
matches!(self.status, ClusterStatus::Healthy)
|
||||
}
|
||||
}
|
||||
|
||||
/// Cluster status as recorded by health checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClusterStatus {
    /// Cluster is healthy and available.
    Healthy,
    /// Cluster is degraded but operational.
    Degraded,
    /// Cluster is unreachable.
    Unreachable,
    /// Cluster status unknown (initial state before any health check).
    Unknown,
}
|
||||
|
||||
/// Cluster registry for managing federated clusters.
///
/// Backed by a concurrent map, so a shared registry can be used from
/// multiple tasks without external locking.
pub struct ClusterRegistry {
    /// Registered clusters, keyed by cluster ID.
    clusters: Arc<DashMap<ClusterId, RemoteCluster>>,
    /// Cluster discovery configuration.
    discovery_config: DiscoveryConfig,
}
|
||||
|
||||
impl ClusterRegistry {
|
||||
/// Create a new cluster registry
|
||||
pub fn new(discovery_config: DiscoveryConfig) -> Self {
|
||||
Self {
|
||||
clusters: Arc::new(DashMap::new()),
|
||||
discovery_config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a remote cluster
|
||||
pub fn register_cluster(&self, cluster: RemoteCluster) -> Result<()> {
|
||||
info!(
|
||||
"Registering cluster: {} ({})",
|
||||
cluster.name, cluster.cluster_id
|
||||
);
|
||||
self.clusters.insert(cluster.cluster_id.clone(), cluster);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Unregister a cluster
|
||||
pub fn unregister_cluster(&self, cluster_id: &ClusterId) -> Result<()> {
|
||||
info!("Unregistering cluster: {}", cluster_id);
|
||||
self.clusters.remove(cluster_id).ok_or_else(|| {
|
||||
GraphError::FederationError(format!("Cluster not found: {}", cluster_id))
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a cluster by ID
|
||||
pub fn get_cluster(&self, cluster_id: &ClusterId) -> Option<RemoteCluster> {
|
||||
self.clusters.get(cluster_id).map(|c| c.value().clone())
|
||||
}
|
||||
|
||||
/// List all registered clusters
|
||||
pub fn list_clusters(&self) -> Vec<RemoteCluster> {
|
||||
self.clusters.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
|
||||
/// List healthy clusters only
|
||||
pub fn healthy_clusters(&self) -> Vec<RemoteCluster> {
|
||||
self.clusters
|
||||
.iter()
|
||||
.filter(|e| e.value().is_healthy())
|
||||
.map(|e| e.value().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Perform health check on a cluster
|
||||
pub async fn health_check(&self, cluster_id: &ClusterId) -> Result<ClusterStatus> {
|
||||
let cluster = self.get_cluster(cluster_id).ok_or_else(|| {
|
||||
GraphError::FederationError(format!("Cluster not found: {}", cluster_id))
|
||||
})?;
|
||||
|
||||
// In production, make actual HTTP/gRPC health check request
|
||||
// For now, simulate health check
|
||||
let status = ClusterStatus::Healthy;
|
||||
|
||||
// Update cluster status
|
||||
if let Some(mut entry) = self.clusters.get_mut(cluster_id) {
|
||||
entry.status = status;
|
||||
entry.last_health_check = Utc::now();
|
||||
}
|
||||
|
||||
debug!("Health check for cluster {}: {:?}", cluster_id, status);
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
/// Perform health checks on all clusters
|
||||
pub async fn health_check_all(&self) -> HashMap<ClusterId, ClusterStatus> {
|
||||
let mut results = HashMap::new();
|
||||
|
||||
for cluster in self.list_clusters() {
|
||||
match self.health_check(&cluster.cluster_id).await {
|
||||
Ok(status) => {
|
||||
results.insert(cluster.cluster_id, status);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Health check failed for cluster {}: {}",
|
||||
cluster.cluster_id, e
|
||||
);
|
||||
results.insert(cluster.cluster_id, ClusterStatus::Unreachable);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Discover clusters automatically (if enabled)
|
||||
pub async fn discover_clusters(&self) -> Result<Vec<RemoteCluster>> {
|
||||
if !self.discovery_config.auto_discovery {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
info!("Discovering clusters...");
|
||||
|
||||
// In production, implement actual cluster discovery:
|
||||
// - mDNS/DNS-SD for local network
|
||||
// - Consul/etcd for service discovery
|
||||
// - Static configuration file
|
||||
// - Cloud provider APIs (AWS, GCP, Azure)
|
||||
|
||||
// For now, return empty list
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Cluster discovery configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryConfig {
    /// Enable automatic cluster discovery.
    pub auto_discovery: bool,
    /// Discovery method.
    pub discovery_method: DiscoveryMethod,
    /// Discovery interval in seconds.
    /// NOTE(review): intervals are stored but no background task consumes
    /// them yet.
    pub discovery_interval_seconds: u64,
    /// Health check interval in seconds.
    pub health_check_interval_seconds: u64,
}
|
||||
|
||||
/// Defaults: discovery disabled, static configuration, 60 s discovery and
/// 30 s health-check intervals.
impl Default for DiscoveryConfig {
    fn default() -> Self {
        Self {
            auto_discovery: false,
            discovery_method: DiscoveryMethod::Static,
            discovery_interval_seconds: 60,
            health_check_interval_seconds: 30,
        }
    }
}
|
||||
|
||||
/// Cluster discovery method.
///
/// Only `Static` is effectively implemented today; the other variants are
/// placeholders for future discovery backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DiscoveryMethod {
    /// Static configuration.
    Static,
    /// DNS-based discovery.
    Dns,
    /// Consul service discovery.
    Consul,
    /// etcd service discovery.
    Etcd,
    /// Kubernetes service discovery.
    Kubernetes,
}
|
||||
|
||||
/// Federated query spanning multiple clusters; tracked in
/// `Federation::active_queries` for the duration of execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedQuery {
    /// Query ID.
    pub query_id: String,
    /// Original query text.
    pub query: String,
    /// Target clusters.
    pub target_clusters: Vec<ClusterId>,
    /// Query execution strategy.
    pub strategy: FederationStrategy,
    /// Created timestamp.
    pub created_at: DateTime<Utc>,
}
|
||||
|
||||
/// Federation strategy controlling how target clusters are queried.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FederationStrategy {
    /// Execute on all clusters in parallel.
    Parallel,
    /// Execute on clusters sequentially.
    Sequential,
    /// Execute on primary cluster, fallback to others.
    PrimaryWithFallback,
    /// Execute on nearest/fastest cluster only.
    Nearest,
}
|
||||
|
||||
/// Federation engine for cross-cluster queries.
pub struct Federation {
    /// Cluster registry.
    registry: Arc<ClusterRegistry>,
    /// Federation configuration.
    config: FederationConfig,
    /// Federated queries currently in flight, keyed by query ID.
    active_queries: Arc<DashMap<String, FederatedQuery>>,
}
|
||||
|
||||
impl Federation {
|
||||
/// Create a new federation engine
|
||||
pub fn new(config: FederationConfig) -> Self {
|
||||
let discovery_config = DiscoveryConfig::default();
|
||||
Self {
|
||||
registry: Arc::new(ClusterRegistry::new(discovery_config)),
|
||||
config,
|
||||
active_queries: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the cluster registry
|
||||
pub fn registry(&self) -> Arc<ClusterRegistry> {
|
||||
Arc::clone(&self.registry)
|
||||
}
|
||||
|
||||
/// Execute a federated query across multiple clusters
|
||||
pub async fn execute_federated(
|
||||
&self,
|
||||
query: &str,
|
||||
target_clusters: Option<Vec<ClusterId>>,
|
||||
) -> Result<FederatedQueryResult> {
|
||||
let query_id = Uuid::new_v4().to_string();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// Determine target clusters
|
||||
let clusters = if let Some(targets) = target_clusters {
|
||||
targets
|
||||
.into_iter()
|
||||
.filter_map(|id| self.registry.get_cluster(&id))
|
||||
.collect()
|
||||
} else {
|
||||
self.registry.healthy_clusters()
|
||||
};
|
||||
|
||||
if clusters.is_empty() {
|
||||
return Err(GraphError::FederationError(
|
||||
"No healthy clusters available".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
info!(
|
||||
"Executing federated query {} across {} clusters",
|
||||
query_id,
|
||||
clusters.len()
|
||||
);
|
||||
|
||||
let federated_query = FederatedQuery {
|
||||
query_id: query_id.clone(),
|
||||
query: query.to_string(),
|
||||
target_clusters: clusters.iter().map(|c| c.cluster_id.clone()).collect(),
|
||||
strategy: self.config.default_strategy,
|
||||
created_at: Utc::now(),
|
||||
};
|
||||
|
||||
self.active_queries
|
||||
.insert(query_id.clone(), federated_query.clone());
|
||||
|
||||
// Execute query on each cluster based on strategy
|
||||
let mut cluster_results = HashMap::new();
|
||||
|
||||
match self.config.default_strategy {
|
||||
FederationStrategy::Parallel => {
|
||||
// Execute on all clusters in parallel
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for cluster in &clusters {
|
||||
let cluster_id = cluster.cluster_id.clone();
|
||||
let query_str = query.to_string();
|
||||
let cluster_clone = cluster.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
Self::execute_on_cluster(&cluster_clone, &query_str).await
|
||||
});
|
||||
|
||||
handles.push((cluster_id, handle));
|
||||
}
|
||||
|
||||
// Collect results
|
||||
for (cluster_id, handle) in handles {
|
||||
match handle.await {
|
||||
Ok(Ok(result)) => {
|
||||
cluster_results.insert(cluster_id, result);
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
warn!("Query failed on cluster {}: {}", cluster_id, e);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Task failed for cluster {}: {}", cluster_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
FederationStrategy::Sequential => {
|
||||
// Execute on clusters sequentially
|
||||
for cluster in &clusters {
|
||||
match Self::execute_on_cluster(cluster, query).await {
|
||||
Ok(result) => {
|
||||
cluster_results.insert(cluster.cluster_id.clone(), result);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Query failed on cluster {}: {}", cluster.cluster_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
FederationStrategy::Nearest | FederationStrategy::PrimaryWithFallback => {
|
||||
// Execute on first healthy cluster
|
||||
if let Some(cluster) = clusters.first() {
|
||||
match Self::execute_on_cluster(cluster, query).await {
|
||||
Ok(result) => {
|
||||
cluster_results.insert(cluster.cluster_id.clone(), result);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Query failed on cluster {}: {}", cluster.cluster_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge results from all clusters
|
||||
let merged_result = self.merge_results(cluster_results)?;
|
||||
|
||||
let execution_time_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
// Remove from active queries
|
||||
self.active_queries.remove(&query_id);
|
||||
|
||||
Ok(FederatedQueryResult {
|
||||
query_id,
|
||||
merged_result,
|
||||
clusters_queried: clusters.len(),
|
||||
execution_time_ms,
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute query on a single remote cluster
|
||||
async fn execute_on_cluster(cluster: &RemoteCluster, query: &str) -> Result<QueryResult> {
|
||||
debug!("Executing query on cluster: {}", cluster.cluster_id);
|
||||
|
||||
// In production, make actual HTTP/gRPC call to remote cluster
|
||||
// For now, return empty result
|
||||
Ok(QueryResult {
|
||||
query_id: Uuid::new_v4().to_string(),
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
aggregates: HashMap::new(),
|
||||
stats: crate::distributed::coordinator::QueryStats {
|
||||
execution_time_ms: 0,
|
||||
shards_queried: 0,
|
||||
nodes_scanned: 0,
|
||||
edges_scanned: 0,
|
||||
cached: false,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Merge results from multiple clusters
|
||||
fn merge_results(&self, results: HashMap<ClusterId, QueryResult>) -> Result<QueryResult> {
|
||||
if results.is_empty() {
|
||||
return Err(GraphError::FederationError(
|
||||
"No results to merge".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut merged = QueryResult {
|
||||
query_id: Uuid::new_v4().to_string(),
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
aggregates: HashMap::new(),
|
||||
stats: crate::distributed::coordinator::QueryStats {
|
||||
execution_time_ms: 0,
|
||||
shards_queried: 0,
|
||||
nodes_scanned: 0,
|
||||
edges_scanned: 0,
|
||||
cached: false,
|
||||
},
|
||||
};
|
||||
|
||||
for (cluster_id, result) in results {
|
||||
debug!("Merging results from cluster: {}", cluster_id);
|
||||
|
||||
// Merge nodes (deduplicating by ID)
|
||||
for node in result.nodes {
|
||||
if !merged.nodes.iter().any(|n| n.id == node.id) {
|
||||
merged.nodes.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
// Merge edges (deduplicating by ID)
|
||||
for edge in result.edges {
|
||||
if !merged.edges.iter().any(|e| e.id == edge.id) {
|
||||
merged.edges.push(edge);
|
||||
}
|
||||
}
|
||||
|
||||
// Merge aggregates
|
||||
for (key, value) in result.aggregates {
|
||||
merged
|
||||
.aggregates
|
||||
.insert(format!("{}_{}", cluster_id, key), value);
|
||||
}
|
||||
|
||||
// Aggregate stats
|
||||
merged.stats.execution_time_ms = merged
|
||||
.stats
|
||||
.execution_time_ms
|
||||
.max(result.stats.execution_time_ms);
|
||||
merged.stats.shards_queried += result.stats.shards_queried;
|
||||
merged.stats.nodes_scanned += result.stats.nodes_scanned;
|
||||
merged.stats.edges_scanned += result.stats.edges_scanned;
|
||||
}
|
||||
|
||||
Ok(merged)
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &FederationConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Federation configuration.
///
/// NOTE(review): `max_clusters`, `query_timeout_seconds` and
/// `enable_caching` are stored but not yet enforced by
/// `Federation::execute_federated` — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationConfig {
    /// Default federation strategy.
    pub default_strategy: FederationStrategy,
    /// Maximum number of clusters to query.
    pub max_clusters: usize,
    /// Query timeout in seconds.
    pub query_timeout_seconds: u64,
    /// Enable result caching.
    pub enable_caching: bool,
}
|
||||
|
||||
impl Default for FederationConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_strategy: FederationStrategy::Parallel,
|
||||
max_clusters: 10,
|
||||
query_timeout_seconds: 30,
|
||||
enable_caching: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Federated query result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FederatedQueryResult {
|
||||
/// Query ID
|
||||
pub query_id: String,
|
||||
/// Merged result from all clusters
|
||||
pub merged_result: QueryResult,
|
||||
/// Number of clusters queried
|
||||
pub clusters_queried: usize,
|
||||
/// Total execution time
|
||||
pub execution_time_ms: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cluster_registry() {
|
||||
let config = DiscoveryConfig::default();
|
||||
let registry = ClusterRegistry::new(config);
|
||||
|
||||
let cluster = RemoteCluster::new(
|
||||
"cluster-1".to_string(),
|
||||
"Test Cluster".to_string(),
|
||||
"http://localhost:8080".to_string(),
|
||||
);
|
||||
|
||||
registry.register_cluster(cluster.clone()).unwrap();
|
||||
|
||||
assert_eq!(registry.list_clusters().len(), 1);
|
||||
assert!(registry.get_cluster(&"cluster-1".to_string()).is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_federation() {
|
||||
let config = FederationConfig::default();
|
||||
let federation = Federation::new(config);
|
||||
|
||||
let cluster = RemoteCluster::new(
|
||||
"cluster-1".to_string(),
|
||||
"Test Cluster".to_string(),
|
||||
"http://localhost:8080".to_string(),
|
||||
);
|
||||
|
||||
federation.registry().register_cluster(cluster).unwrap();
|
||||
|
||||
// Test would execute federated query in production
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remote_cluster() {
|
||||
let cluster = RemoteCluster::new(
|
||||
"test".to_string(),
|
||||
"Test".to_string(),
|
||||
"http://localhost".to_string(),
|
||||
);
|
||||
|
||||
assert!(!cluster.is_healthy());
|
||||
}
|
||||
}
|
||||
623
vendor/ruvector/crates/ruvector-graph/src/distributed/gossip.rs
vendored
Normal file
623
vendor/ruvector/crates/ruvector-graph/src/distributed/gossip.rs
vendored
Normal file
@@ -0,0 +1,623 @@
|
||||
//! Gossip protocol for cluster membership and health monitoring
|
||||
//!
|
||||
//! Implements SWIM (Scalable Weakly-consistent Infection-style Membership) protocol:
|
||||
//! - Fast failure detection
|
||||
//! - Efficient membership propagation
|
||||
//! - Low network overhead
|
||||
//! - Automatic node discovery
|
||||
|
||||
use crate::{GraphError, Result};
|
||||
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Node identifier in the cluster
|
||||
pub type NodeId = String;
|
||||
|
||||
/// Gossip message types
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum GossipMessage {
|
||||
/// Ping message for health check
|
||||
Ping {
|
||||
from: NodeId,
|
||||
sequence: u64,
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
/// Ack response to ping
|
||||
Ack {
|
||||
from: NodeId,
|
||||
to: NodeId,
|
||||
sequence: u64,
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
/// Indirect ping through intermediary
|
||||
IndirectPing {
|
||||
from: NodeId,
|
||||
target: NodeId,
|
||||
intermediary: NodeId,
|
||||
sequence: u64,
|
||||
},
|
||||
/// Membership update
|
||||
MembershipUpdate {
|
||||
from: NodeId,
|
||||
updates: Vec<MembershipEvent>,
|
||||
version: u64,
|
||||
},
|
||||
/// Join request
|
||||
Join {
|
||||
node_id: NodeId,
|
||||
address: SocketAddr,
|
||||
metadata: HashMap<String, String>,
|
||||
},
|
||||
/// Leave notification
|
||||
Leave { node_id: NodeId },
|
||||
}
|
||||
|
||||
/// Membership event types
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum MembershipEvent {
|
||||
/// Node joined the cluster
|
||||
Join {
|
||||
node_id: NodeId,
|
||||
address: SocketAddr,
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
/// Node left the cluster
|
||||
Leave {
|
||||
node_id: NodeId,
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
/// Node suspected to be failed
|
||||
Suspect {
|
||||
node_id: NodeId,
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
/// Node confirmed alive
|
||||
Alive {
|
||||
node_id: NodeId,
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
/// Node confirmed dead
|
||||
Dead {
|
||||
node_id: NodeId,
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Node health status
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum NodeHealth {
|
||||
/// Node is healthy and responsive
|
||||
Alive,
|
||||
/// Node is suspected to be failed
|
||||
Suspect,
|
||||
/// Node is confirmed dead
|
||||
Dead,
|
||||
/// Node explicitly left
|
||||
Left,
|
||||
}
|
||||
|
||||
/// Member information in the gossip protocol
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Member {
|
||||
/// Node identifier
|
||||
pub node_id: NodeId,
|
||||
/// Network address
|
||||
pub address: SocketAddr,
|
||||
/// Current health status
|
||||
pub health: NodeHealth,
|
||||
/// Last time we heard from this node
|
||||
pub last_seen: DateTime<Utc>,
|
||||
/// Incarnation number (for conflict resolution)
|
||||
pub incarnation: u64,
|
||||
/// Node metadata
|
||||
pub metadata: HashMap<String, String>,
|
||||
/// Number of consecutive ping failures
|
||||
pub failure_count: u32,
|
||||
}
|
||||
|
||||
impl Member {
|
||||
/// Create a new member
|
||||
pub fn new(node_id: NodeId, address: SocketAddr) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
address,
|
||||
health: NodeHealth::Alive,
|
||||
last_seen: Utc::now(),
|
||||
incarnation: 0,
|
||||
metadata: HashMap::new(),
|
||||
failure_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if member is healthy
|
||||
pub fn is_healthy(&self) -> bool {
|
||||
matches!(self.health, NodeHealth::Alive)
|
||||
}
|
||||
|
||||
/// Mark as seen
|
||||
pub fn mark_seen(&mut self) {
|
||||
self.last_seen = Utc::now();
|
||||
self.failure_count = 0;
|
||||
if self.health != NodeHealth::Left {
|
||||
self.health = NodeHealth::Alive;
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment failure count
|
||||
pub fn increment_failures(&mut self) {
|
||||
self.failure_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Gossip configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GossipConfig {
|
||||
/// Gossip interval in milliseconds
|
||||
pub gossip_interval_ms: u64,
|
||||
/// Number of nodes to gossip with per interval
|
||||
pub gossip_fanout: usize,
|
||||
/// Ping timeout in milliseconds
|
||||
pub ping_timeout_ms: u64,
|
||||
/// Number of ping failures before suspecting node
|
||||
pub suspect_threshold: u32,
|
||||
/// Number of indirect ping nodes
|
||||
pub indirect_ping_nodes: usize,
|
||||
/// Suspicion timeout in seconds
|
||||
pub suspicion_timeout_seconds: u64,
|
||||
}
|
||||
|
||||
impl Default for GossipConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gossip_interval_ms: 1000,
|
||||
gossip_fanout: 3,
|
||||
ping_timeout_ms: 500,
|
||||
suspect_threshold: 3,
|
||||
indirect_ping_nodes: 3,
|
||||
suspicion_timeout_seconds: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gossip-based membership protocol
|
||||
pub struct GossipMembership {
|
||||
/// Local node ID
|
||||
local_node_id: NodeId,
|
||||
/// Local node address
|
||||
local_address: SocketAddr,
|
||||
/// Configuration
|
||||
config: GossipConfig,
|
||||
/// Cluster members
|
||||
members: Arc<DashMap<NodeId, Member>>,
|
||||
/// Membership version (incremented on changes)
|
||||
version: Arc<RwLock<u64>>,
|
||||
/// Pending acks
|
||||
pending_acks: Arc<DashMap<u64, PendingAck>>,
|
||||
/// Sequence number for messages
|
||||
sequence: Arc<RwLock<u64>>,
|
||||
/// Event listeners
|
||||
event_listeners: Arc<RwLock<Vec<Box<dyn Fn(MembershipEvent) + Send + Sync>>>>,
|
||||
}
|
||||
|
||||
/// Pending acknowledgment
|
||||
struct PendingAck {
|
||||
target: NodeId,
|
||||
sent_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl GossipMembership {
    /// Create a new gossip membership
    ///
    /// The local node is inserted into the member table immediately, so the
    /// membership view always contains at least the local node.
    pub fn new(node_id: NodeId, address: SocketAddr, config: GossipConfig) -> Self {
        let members = Arc::new(DashMap::new());

        // Add self to members
        let local_member = Member::new(node_id.clone(), address);
        members.insert(node_id.clone(), local_member);

        Self {
            local_node_id: node_id,
            local_address: address,
            config,
            members,
            version: Arc::new(RwLock::new(0)),
            pending_acks: Arc::new(DashMap::new()),
            sequence: Arc::new(RwLock::new(0)),
            event_listeners: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Start the gossip protocol
    ///
    /// Spawns two detached, never-ending background tasks: the periodic
    /// gossip loop and the failure detector. Cloning `self` is cheap — all
    /// shared state is behind `Arc`s, so the tasks observe the same tables.
    pub async fn start(&self) -> Result<()> {
        info!("Starting gossip protocol for node: {}", self.local_node_id);

        // Start periodic gossip
        let gossip_self = self.clone();
        tokio::spawn(async move {
            gossip_self.run_gossip_loop().await;
        });

        // Start failure detection
        let detection_self = self.clone();
        tokio::spawn(async move {
            detection_self.run_failure_detection().await;
        });

        Ok(())
    }

    /// Add a seed node to join cluster
    ///
    /// NOTE(review): the join message is constructed but never transmitted —
    /// the network layer is stubbed; wire this up before production use.
    pub async fn join(&self, seed_address: SocketAddr) -> Result<()> {
        info!("Joining cluster via seed: {}", seed_address);

        // Send join message
        let join_msg = GossipMessage::Join {
            node_id: self.local_node_id.clone(),
            address: self.local_address,
            metadata: HashMap::new(),
        };

        // In production, send actual network message
        // For now, just log
        debug!("Would send join message to {}", seed_address);

        Ok(())
    }

    /// Leave the cluster gracefully
    ///
    /// Marks the local member `Left` (terminal — `mark_seen` never overrides
    /// it) and broadcasts a leave event.
    /// NOTE(review): `leave_msg` is built but never sent — network stub.
    pub async fn leave(&self) -> Result<()> {
        info!("Leaving cluster: {}", self.local_node_id);

        // Update own status
        if let Some(mut member) = self.members.get_mut(&self.local_node_id) {
            member.health = NodeHealth::Left;
        }

        // Broadcast leave message
        let leave_msg = GossipMessage::Leave {
            node_id: self.local_node_id.clone(),
        };

        self.broadcast_event(MembershipEvent::Leave {
            node_id: self.local_node_id.clone(),
            timestamp: Utc::now(),
        })
        .await;

        Ok(())
    }

    /// Get all cluster members
    ///
    /// Snapshot by clone; includes the local node and non-healthy members.
    pub fn get_members(&self) -> Vec<Member> {
        self.members.iter().map(|e| e.value().clone()).collect()
    }

    /// Get healthy members only
    ///
    /// Filters to `NodeHealth::Alive`; still includes the local node.
    pub fn get_healthy_members(&self) -> Vec<Member> {
        self.members
            .iter()
            .filter(|e| e.value().is_healthy())
            .map(|e| e.value().clone())
            .collect()
    }

    /// Get a specific member
    pub fn get_member(&self, node_id: &NodeId) -> Option<Member> {
        self.members.get(node_id).map(|m| m.value().clone())
    }

    /// Handle incoming gossip message
    ///
    /// Dispatches to per-message handlers. `IndirectPing` is currently
    /// swallowed by the catch-all arm (returns Ok without action).
    pub async fn handle_message(&self, message: GossipMessage) -> Result<()> {
        match message {
            GossipMessage::Ping { from, sequence, .. } => self.handle_ping(from, sequence).await,
            GossipMessage::Ack { from, sequence, .. } => self.handle_ack(from, sequence).await,
            GossipMessage::MembershipUpdate { updates, .. } => {
                self.handle_membership_update(updates).await
            }
            GossipMessage::Join {
                node_id,
                address,
                metadata,
            } => self.handle_join(node_id, address, metadata).await,
            GossipMessage::Leave { node_id } => self.handle_leave(node_id).await,
            _ => Ok(()),
        }
    }

    /// Run the gossip loop
    ///
    /// Every `gossip_interval_ms`, pings up to `gossip_fanout` healthy peers.
    /// NOTE(review): despite the comment below, targets are NOT randomized —
    /// this takes the first `gossip_fanout` members in map-iteration order.
    async fn run_gossip_loop(&self) {
        let interval = std::time::Duration::from_millis(self.config.gossip_interval_ms);

        loop {
            tokio::time::sleep(interval).await;

            // Select random members to gossip with
            let members = self.get_healthy_members();
            let targets: Vec<_> = members
                .into_iter()
                .filter(|m| m.node_id != self.local_node_id)
                .take(self.config.gossip_fanout)
                .collect();

            for target in targets {
                self.send_ping(target.node_id).await;
            }
        }
    }

    /// Run failure detection
    ///
    /// Every 5 seconds, promotes `Suspect` members whose `last_seen` is older
    /// than `suspicion_timeout_seconds` to `Dead` and emits a `Dead` event.
    /// NOTE(review): nothing in this file ever sets `Suspect` (pending-ack
    /// timeouts are not checked), so this transition only fires for suspicion
    /// received via membership updates — confirm that is intended.
    async fn run_failure_detection(&self) {
        let interval = std::time::Duration::from_secs(5);

        loop {
            tokio::time::sleep(interval).await;

            let now = Utc::now();
            let timeout = ChronoDuration::seconds(self.config.suspicion_timeout_seconds as i64);

            for mut entry in self.members.iter_mut() {
                let member = entry.value_mut();

                // Never declare ourselves dead.
                if member.node_id == self.local_node_id {
                    continue;
                }

                // Check if node has timed out
                if member.health == NodeHealth::Suspect {
                    let elapsed = now.signed_duration_since(member.last_seen);
                    if elapsed > timeout {
                        debug!("Marking node as dead: {}", member.node_id);
                        member.health = NodeHealth::Dead;

                        let event = MembershipEvent::Dead {
                            node_id: member.node_id.clone(),
                            timestamp: now,
                        };

                        self.emit_event(event);
                    }
                }
            }
        }
    }

    /// Send ping to a node
    ///
    /// Allocates the next sequence number (write lock dropped immediately)
    /// and records a pending ack keyed by that sequence.
    /// NOTE(review): pending acks are only removed in `handle_ack`; entries
    /// for lost pings accumulate — confirm a reaper exists elsewhere. The
    /// `ping` value itself is a network-layer stub and is never sent here.
    async fn send_ping(&self, target: NodeId) {
        let mut seq = self.sequence.write().await;
        *seq += 1;
        let sequence = *seq;
        drop(seq);

        let ping = GossipMessage::Ping {
            from: self.local_node_id.clone(),
            sequence,
            timestamp: Utc::now(),
        };

        // Track pending ack
        self.pending_acks.insert(
            sequence,
            PendingAck {
                target: target.clone(),
                sent_at: Utc::now(),
            },
        );

        debug!("Sending ping to {}", target);
        // In production, send actual network message
    }

    /// Handle ping message
    ///
    /// Refreshes the sender's liveness and constructs an ack.
    /// NOTE(review): the `ack` is built but never transmitted — network stub.
    async fn handle_ping(&self, from: NodeId, sequence: u64) -> Result<()> {
        debug!("Received ping from {}", from);

        // Update member status
        if let Some(mut member) = self.members.get_mut(&from) {
            member.mark_seen();
        }

        // Send ack
        let ack = GossipMessage::Ack {
            from: self.local_node_id.clone(),
            to: from,
            sequence,
            timestamp: Utc::now(),
        };

        // In production, send actual network message
        Ok(())
    }

    /// Handle ack message
    ///
    /// Clears the matching pending-ack entry and refreshes the sender's
    /// liveness.
    async fn handle_ack(&self, from: NodeId, sequence: u64) -> Result<()> {
        debug!("Received ack from {}", from);

        // Remove from pending
        self.pending_acks.remove(&sequence);

        // Update member status
        if let Some(mut member) = self.members.get_mut(&from) {
            member.mark_seen();
        }

        Ok(())
    }

    /// Handle membership update
    ///
    /// Applies each event to the local member table (Join inserts if absent;
    /// Suspect/Dead downgrade health; Leave/Alive fall through the catch-all)
    /// and re-emits every event to listeners.
    async fn handle_membership_update(&self, updates: Vec<MembershipEvent>) -> Result<()> {
        for event in updates {
            match &event {
                MembershipEvent::Join {
                    node_id, address, ..
                } => {
                    // Only insert unknown nodes; existing entries keep state.
                    if !self.members.contains_key(node_id) {
                        let member = Member::new(node_id.clone(), *address);
                        self.members.insert(node_id.clone(), member);
                    }
                }
                MembershipEvent::Suspect { node_id, .. } => {
                    if let Some(mut member) = self.members.get_mut(node_id) {
                        member.health = NodeHealth::Suspect;
                    }
                }
                MembershipEvent::Dead { node_id, .. } => {
                    if let Some(mut member) = self.members.get_mut(node_id) {
                        member.health = NodeHealth::Dead;
                    }
                }
                _ => {}
            }

            self.emit_event(event);
        }

        Ok(())
    }

    /// Handle join request
    ///
    /// Inserts (or overwrites) the joining member with the provided metadata
    /// and broadcasts a Join event, bumping the membership version.
    async fn handle_join(
        &self,
        node_id: NodeId,
        address: SocketAddr,
        metadata: HashMap<String, String>,
    ) -> Result<()> {
        info!("Node joining: {}", node_id);

        let mut member = Member::new(node_id.clone(), address);
        member.metadata = metadata;

        self.members.insert(node_id.clone(), member);

        let event = MembershipEvent::Join {
            node_id,
            address,
            timestamp: Utc::now(),
        };

        self.broadcast_event(event).await;

        Ok(())
    }

    /// Handle leave notification
    ///
    /// Marks the member `Left` (entry is retained, not removed) and emits a
    /// Leave event. Note: unlike `handle_join`, this does not bump the
    /// membership version (uses `emit_event`, not `broadcast_event`).
    async fn handle_leave(&self, node_id: NodeId) -> Result<()> {
        info!("Node leaving: {}", node_id);

        if let Some(mut member) = self.members.get_mut(&node_id) {
            member.health = NodeHealth::Left;
        }

        let event = MembershipEvent::Leave {
            node_id,
            timestamp: Utc::now(),
        };

        self.emit_event(event);

        Ok(())
    }

    /// Broadcast event to all members
    ///
    /// Bumps the membership version, then delegates to `emit_event`.
    /// NOTE(review): no network broadcast happens here yet — stub.
    async fn broadcast_event(&self, event: MembershipEvent) {
        let mut version = self.version.write().await;
        *version += 1;
        drop(version);

        self.emit_event(event);
    }

    /// Emit event to listeners
    ///
    /// NOTE(review): registered listeners are never invoked here — the event
    /// is only logged. Confirm before relying on `add_listener`.
    fn emit_event(&self, event: MembershipEvent) {
        // In production, call event listeners
        debug!("Membership event: {:?}", event);
    }

    /// Add event listener
    ///
    /// Stores the callback; see the `emit_event` caveat about invocation.
    pub async fn add_listener<F>(&self, listener: F)
    where
        F: Fn(MembershipEvent) + Send + Sync + 'static,
    {
        let mut listeners = self.event_listeners.write().await;
        listeners.push(Box::new(listener));
    }

    /// Get membership version
    pub async fn get_version(&self) -> u64 {
        *self.version.read().await
    }
}
|
||||
|
||||
/// Manual `Clone`: copies the small owned fields and `Arc`-clones the shared
/// state, so every clone is a handle onto the same membership table (this is
/// what the spawned tasks in `start()` depend on).
// NOTE(review): a derived Clone would produce the same impl; kept manual,
// presumably for explicitness.
impl Clone for GossipMembership {
    fn clone(&self) -> Self {
        Self {
            local_node_id: self.local_node_id.clone(),
            local_address: self.local_address,
            config: self.config.clone(),
            members: Arc::clone(&self.members),
            version: Arc::clone(&self.version),
            pending_acks: Arc::clone(&self.pending_acks),
            sequence: Arc::clone(&self.sequence),
            event_listeners: Arc::clone(&self.event_listeners),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    // Loopback address helper for test fixtures.
    fn create_test_address(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port)
    }

    // A fresh membership contains exactly the local node.
    #[tokio::test]
    async fn test_gossip_membership() {
        let config = GossipConfig::default();
        let address = create_test_address(8000);
        let gossip = GossipMembership::new("node-1".to_string(), address, config);

        assert_eq!(gossip.get_members().len(), 1);
    }

    // Join adds a member; leave keeps the entry but marks it Left.
    #[tokio::test]
    async fn test_join_leave() {
        let config = GossipConfig::default();
        let address1 = create_test_address(8000);
        let address2 = create_test_address(8001);

        let gossip = GossipMembership::new("node-1".to_string(), address1, config);

        gossip
            .handle_join("node-2".to_string(), address2, HashMap::new())
            .await
            .unwrap();

        assert_eq!(gossip.get_members().len(), 2);

        gossip.handle_leave("node-2".to_string()).await.unwrap();

        let member = gossip.get_member(&"node-2".to_string()).unwrap();
        assert_eq!(member.health, NodeHealth::Left);
    }

    // Health transitions: new members are Alive, Suspect is unhealthy,
    // and mark_seen restores Alive.
    #[test]
    fn test_member() {
        let address = create_test_address(8000);
        let mut member = Member::new("test".to_string(), address);

        assert!(member.is_healthy());

        member.health = NodeHealth::Suspect;
        assert!(!member.is_healthy());

        member.mark_seen();
        assert!(member.is_healthy());
    }
}
|
||||
25
vendor/ruvector/crates/ruvector-graph/src/distributed/mod.rs
vendored
Normal file
25
vendor/ruvector/crates/ruvector-graph/src/distributed/mod.rs
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
//! Distributed graph query capabilities
|
||||
//!
|
||||
//! This module provides comprehensive distributed and federated graph operations:
|
||||
//! - Graph sharding with multiple partitioning strategies
|
||||
//! - Distributed query coordination and execution
|
||||
//! - Cross-cluster federation for multi-cluster queries
|
||||
//! - Graph-aware replication extending ruvector-replication
|
||||
//! - Gossip-based cluster membership and health monitoring
|
||||
//! - High-performance gRPC communication layer
|
||||
|
||||
pub mod coordinator;
|
||||
pub mod federation;
|
||||
pub mod gossip;
|
||||
pub mod replication;
|
||||
pub mod rpc;
|
||||
pub mod shard;
|
||||
|
||||
pub use coordinator::{Coordinator, QueryPlan, ShardCoordinator};
|
||||
pub use federation::{ClusterRegistry, FederatedQuery, Federation, RemoteCluster};
|
||||
pub use gossip::{GossipConfig, GossipMembership, MembershipEvent, NodeHealth};
|
||||
pub use replication::{GraphReplication, GraphReplicationConfig, ReplicationStrategy};
|
||||
pub use rpc::{GraphRpcService, RpcClient, RpcServer};
|
||||
pub use shard::{
|
||||
EdgeCutMinimizer, GraphShard, HashPartitioner, RangePartitioner, ShardMetadata, ShardStrategy,
|
||||
};
|
||||
407
vendor/ruvector/crates/ruvector-graph/src/distributed/replication.rs
vendored
Normal file
407
vendor/ruvector/crates/ruvector-graph/src/distributed/replication.rs
vendored
Normal file
@@ -0,0 +1,407 @@
|
||||
//! Graph-aware data replication extending ruvector-replication
|
||||
//!
|
||||
//! Provides graph-specific replication strategies:
|
||||
//! - Vertex-cut replication for high-degree nodes
|
||||
//! - Edge replication with consistency guarantees
|
||||
//! - Subgraph replication for locality
|
||||
//! - Conflict-free replicated graphs (CRG)
|
||||
|
||||
use crate::distributed::shard::{EdgeData, GraphShard, NodeData, NodeId, ShardId};
|
||||
use crate::{GraphError, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use ruvector_replication::{
|
||||
Replica, ReplicaRole, ReplicaSet, ReplicationLog, SyncManager, SyncMode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Graph replication strategy
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum ReplicationStrategy {
|
||||
/// Replicate entire shards
|
||||
FullShard,
|
||||
/// Replicate high-degree nodes (vertex-cut)
|
||||
VertexCut,
|
||||
/// Replicate based on subgraph locality
|
||||
Subgraph,
|
||||
/// Hybrid approach
|
||||
Hybrid,
|
||||
}
|
||||
|
||||
/// Graph replication configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GraphReplicationConfig {
|
||||
/// Replication factor (number of copies)
|
||||
pub replication_factor: usize,
|
||||
/// Replication strategy
|
||||
pub strategy: ReplicationStrategy,
|
||||
/// High-degree threshold for vertex-cut
|
||||
pub high_degree_threshold: usize,
|
||||
/// Synchronization mode
|
||||
pub sync_mode: SyncMode,
|
||||
/// Enable conflict resolution
|
||||
pub enable_conflict_resolution: bool,
|
||||
/// Replication timeout in seconds
|
||||
pub timeout_seconds: u64,
|
||||
}
|
||||
|
||||
impl Default for GraphReplicationConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
replication_factor: 3,
|
||||
strategy: ReplicationStrategy::FullShard,
|
||||
high_degree_threshold: 100,
|
||||
sync_mode: SyncMode::Async,
|
||||
enable_conflict_resolution: true,
|
||||
timeout_seconds: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Graph replication manager
|
||||
pub struct GraphReplication {
|
||||
/// Configuration
|
||||
config: GraphReplicationConfig,
|
||||
/// Replica sets per shard
|
||||
replica_sets: Arc<DashMap<ShardId, Arc<ReplicaSet>>>,
|
||||
/// Sync managers per shard
|
||||
sync_managers: Arc<DashMap<ShardId, Arc<SyncManager>>>,
|
||||
/// High-degree nodes (for vertex-cut replication)
|
||||
high_degree_nodes: Arc<DashMap<NodeId, usize>>,
|
||||
/// Node replication metadata
|
||||
node_replicas: Arc<DashMap<NodeId, Vec<String>>>,
|
||||
}
|
||||
|
||||
impl GraphReplication {
|
||||
/// Create a new graph replication manager
|
||||
pub fn new(config: GraphReplicationConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
replica_sets: Arc::new(DashMap::new()),
|
||||
sync_managers: Arc::new(DashMap::new()),
|
||||
high_degree_nodes: Arc::new(DashMap::new()),
|
||||
node_replicas: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize replication for a shard
|
||||
pub fn initialize_shard_replication(
|
||||
&self,
|
||||
shard_id: ShardId,
|
||||
primary_node: String,
|
||||
replica_nodes: Vec<String>,
|
||||
) -> Result<()> {
|
||||
info!(
|
||||
"Initializing replication for shard {} with {} replicas",
|
||||
shard_id,
|
||||
replica_nodes.len()
|
||||
);
|
||||
|
||||
// Create replica set
|
||||
let mut replica_set = ReplicaSet::new(format!("shard-{}", shard_id));
|
||||
|
||||
// Add primary replica
|
||||
replica_set
|
||||
.add_replica(
|
||||
&primary_node,
|
||||
&format!("{}:9001", primary_node),
|
||||
ReplicaRole::Primary,
|
||||
)
|
||||
.map_err(|e| GraphError::ReplicationError(e))?;
|
||||
|
||||
// Add secondary replicas
|
||||
for (idx, node) in replica_nodes.iter().enumerate() {
|
||||
replica_set
|
||||
.add_replica(
|
||||
&format!("{}-replica-{}", node, idx),
|
||||
&format!("{}:9001", node),
|
||||
ReplicaRole::Secondary,
|
||||
)
|
||||
.map_err(|e| GraphError::ReplicationError(e))?;
|
||||
}
|
||||
|
||||
let replica_set = Arc::new(replica_set);
|
||||
|
||||
// Create replication log
|
||||
let log = Arc::new(ReplicationLog::new(&primary_node));
|
||||
|
||||
// Create sync manager
|
||||
let sync_manager = Arc::new(SyncManager::new(Arc::clone(&replica_set), log));
|
||||
sync_manager.set_sync_mode(self.config.sync_mode.clone());
|
||||
|
||||
self.replica_sets.insert(shard_id, replica_set);
|
||||
self.sync_managers.insert(shard_id, sync_manager);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Replicate a node addition
|
||||
pub async fn replicate_node_add(&self, shard_id: ShardId, node: NodeData) -> Result<()> {
|
||||
debug!(
|
||||
"Replicating node addition: {} to shard {}",
|
||||
node.id, shard_id
|
||||
);
|
||||
|
||||
// Determine replication strategy
|
||||
match self.config.strategy {
|
||||
ReplicationStrategy::FullShard => {
|
||||
self.replicate_to_shard(shard_id, ReplicationOp::AddNode(node))
|
||||
.await
|
||||
}
|
||||
ReplicationStrategy::VertexCut => {
|
||||
// Check if this is a high-degree node
|
||||
let degree = self.get_node_degree(&node.id);
|
||||
if degree >= self.config.high_degree_threshold {
|
||||
// Replicate to multiple shards
|
||||
self.replicate_high_degree_node(node).await
|
||||
} else {
|
||||
self.replicate_to_shard(shard_id, ReplicationOp::AddNode(node))
|
||||
.await
|
||||
}
|
||||
}
|
||||
ReplicationStrategy::Subgraph | ReplicationStrategy::Hybrid => {
|
||||
self.replicate_to_shard(shard_id, ReplicationOp::AddNode(node))
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Replicate an edge addition
|
||||
pub async fn replicate_edge_add(&self, shard_id: ShardId, edge: EdgeData) -> Result<()> {
|
||||
debug!(
|
||||
"Replicating edge addition: {} to shard {}",
|
||||
edge.id, shard_id
|
||||
);
|
||||
|
||||
// Update degree information
|
||||
self.increment_node_degree(&edge.from);
|
||||
self.increment_node_degree(&edge.to);
|
||||
|
||||
self.replicate_to_shard(shard_id, ReplicationOp::AddEdge(edge))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Replicate a node deletion
|
||||
pub async fn replicate_node_delete(&self, shard_id: ShardId, node_id: NodeId) -> Result<()> {
|
||||
debug!(
|
||||
"Replicating node deletion: {} from shard {}",
|
||||
node_id, shard_id
|
||||
);
|
||||
|
||||
self.replicate_to_shard(shard_id, ReplicationOp::DeleteNode(node_id))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Replicate an edge deletion
|
||||
pub async fn replicate_edge_delete(&self, shard_id: ShardId, edge_id: String) -> Result<()> {
|
||||
debug!(
|
||||
"Replicating edge deletion: {} from shard {}",
|
||||
edge_id, shard_id
|
||||
);
|
||||
|
||||
self.replicate_to_shard(shard_id, ReplicationOp::DeleteEdge(edge_id))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Replicate operation to all replicas of a shard
|
||||
async fn replicate_to_shard(&self, shard_id: ShardId, op: ReplicationOp) -> Result<()> {
|
||||
let sync_manager = self
|
||||
.sync_managers
|
||||
.get(&shard_id)
|
||||
.ok_or_else(|| GraphError::ShardError(format!("Shard {} not initialized", shard_id)))?;
|
||||
|
||||
// Serialize operation
|
||||
let data = bincode::encode_to_vec(&op, bincode::config::standard())
|
||||
.map_err(|e| GraphError::SerializationError(e.to_string()))?;
|
||||
|
||||
// Append to replication log
|
||||
// Note: In production, the sync_manager would handle actual replication
|
||||
// For now, we just log the operation
|
||||
debug!("Replicating operation for shard {}", shard_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Replicate high-degree node to multiple shards
|
||||
async fn replicate_high_degree_node(&self, node: NodeData) -> Result<()> {
|
||||
info!(
|
||||
"Replicating high-degree node {} to multiple shards",
|
||||
node.id
|
||||
);
|
||||
|
||||
// Replicate to additional shards based on degree
|
||||
let degree = self.get_node_degree(&node.id);
|
||||
let replica_count =
|
||||
(degree / self.config.high_degree_threshold).min(self.config.replication_factor);
|
||||
|
||||
let mut replica_shards = Vec::new();
|
||||
|
||||
// Select shards for replication
|
||||
for shard_id in 0..replica_count {
|
||||
replica_shards.push(shard_id as ShardId);
|
||||
}
|
||||
|
||||
// Replicate to each shard
|
||||
for shard_id in replica_shards.clone() {
|
||||
self.replicate_to_shard(shard_id, ReplicationOp::AddNode(node.clone()))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Store replica locations
|
||||
self.node_replicas.insert(
|
||||
node.id.clone(),
|
||||
replica_shards.iter().map(|s| s.to_string()).collect(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get node degree
|
||||
fn get_node_degree(&self, node_id: &NodeId) -> usize {
|
||||
self.high_degree_nodes
|
||||
.get(node_id)
|
||||
.map(|d| *d.value())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Increment node degree
|
||||
fn increment_node_degree(&self, node_id: &NodeId) {
|
||||
self.high_degree_nodes
|
||||
.entry(node_id.clone())
|
||||
.and_modify(|d| *d += 1)
|
||||
.or_insert(1);
|
||||
}
|
||||
|
||||
/// Get replica set for a shard
|
||||
pub fn get_replica_set(&self, shard_id: ShardId) -> Option<Arc<ReplicaSet>> {
|
||||
self.replica_sets
|
||||
.get(&shard_id)
|
||||
.map(|r| Arc::clone(r.value()))
|
||||
}
|
||||
|
||||
/// Get sync manager for a shard
|
||||
pub fn get_sync_manager(&self, shard_id: ShardId) -> Option<Arc<SyncManager>> {
|
||||
self.sync_managers
|
||||
.get(&shard_id)
|
||||
.map(|s| Arc::clone(s.value()))
|
||||
}
|
||||
|
||||
/// Get replication statistics
|
||||
pub fn get_stats(&self) -> ReplicationStats {
|
||||
ReplicationStats {
|
||||
total_shards: self.replica_sets.len(),
|
||||
high_degree_nodes: self.high_degree_nodes.len(),
|
||||
replicated_nodes: self.node_replicas.len(),
|
||||
strategy: self.config.strategy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform health check on all replicas
|
||||
pub async fn health_check(&self) -> HashMap<ShardId, ReplicaHealth> {
|
||||
let mut health = HashMap::new();
|
||||
|
||||
for entry in self.replica_sets.iter() {
|
||||
let shard_id = *entry.key();
|
||||
let replica_set = entry.value();
|
||||
|
||||
// In production, check actual replica health
|
||||
let healthy_count = self.config.replication_factor;
|
||||
|
||||
health.insert(
|
||||
shard_id,
|
||||
ReplicaHealth {
|
||||
total_replicas: self.config.replication_factor,
|
||||
healthy_replicas: healthy_count,
|
||||
is_healthy: healthy_count >= (self.config.replication_factor / 2 + 1),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
health
|
||||
}
|
||||
|
||||
    /// Get configuration
    ///
    /// Returns a shared reference to the replication configuration this
    /// instance was constructed with; no method on this type mutates it.
    pub fn config(&self) -> &GraphReplicationConfig {
        &self.config
    }
|
||||
}
|
||||
|
||||
/// Replication operation
///
/// A single mutation shipped to replica shards. Variants mirror the public
/// `ReplicationOperation` enum in the RPC layer, but this one is private to
/// the replication module.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ReplicationOp {
    /// Insert a new node (full payload).
    AddNode(NodeData),
    /// Insert a new edge (full payload).
    AddEdge(EdgeData),
    /// Remove a node by id.
    DeleteNode(NodeId),
    /// Remove an edge by its string id.
    DeleteEdge(String),
    /// Overwrite an existing node's data.
    UpdateNode(NodeData),
    /// Overwrite an existing edge's data.
    UpdateEdge(EdgeData),
}
|
||||
|
||||
/// Replication statistics
///
/// Point-in-time snapshot returned by `GraphReplication::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicationStats {
    /// Number of shards with an initialized replica set.
    pub total_shards: usize,
    /// Number of nodes with a tracked degree counter.
    pub high_degree_nodes: usize,
    /// Number of nodes that have recorded replica locations.
    pub replicated_nodes: usize,
    /// The configured replication strategy.
    pub strategy: ReplicationStrategy,
}
|
||||
|
||||
/// Replica health information
///
/// Produced per shard by `GraphReplication::health_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicaHealth {
    /// Total configured replicas for the shard.
    pub total_replicas: usize,
    /// Replicas currently considered healthy.
    pub healthy_replicas: usize,
    /// True when a majority quorum (total/2 + 1) of replicas is healthy.
    pub is_healthy: bool,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a replication manager with shard 0 already initialized.
    fn replication_with_shard_zero() -> GraphReplication {
        let replication = GraphReplication::new(GraphReplicationConfig::default());
        replication
            .initialize_shard_replication(0, "node-1".to_string(), vec!["node-2".to_string()])
            .unwrap();
        replication
    }

    #[tokio::test]
    async fn test_graph_replication() {
        let replication = replication_with_shard_zero();

        assert!(replication.get_replica_set(0).is_some());
        assert!(replication.get_sync_manager(0).is_some());
    }

    #[tokio::test]
    async fn test_node_replication() {
        let replication = replication_with_shard_zero();

        let node = NodeData {
            id: "test-node".to_string(),
            properties: HashMap::new(),
            labels: vec!["Test".to_string()],
        };

        assert!(replication.replicate_node_add(0, node).await.is_ok());
    }

    #[test]
    fn test_replication_stats() {
        let replication = GraphReplication::new(GraphReplicationConfig::default());

        let stats = replication.get_stats();
        assert_eq!(stats.total_shards, 0);
        assert_eq!(stats.strategy, ReplicationStrategy::FullShard);
    }
}
|
||||
515
vendor/ruvector/crates/ruvector-graph/src/distributed/rpc.rs
vendored
Normal file
515
vendor/ruvector/crates/ruvector-graph/src/distributed/rpc.rs
vendored
Normal file
@@ -0,0 +1,515 @@
|
||||
//! gRPC-based inter-node communication for distributed graph queries
|
||||
//!
|
||||
//! Provides high-performance RPC communication layer:
|
||||
//! - Query execution RPC
|
||||
//! - Data replication RPC
|
||||
//! - Cluster coordination RPC
|
||||
//! - Streaming results for large queries
|
||||
|
||||
use crate::distributed::coordinator::{QueryPlan, QueryResult};
|
||||
use crate::distributed::shard::{EdgeData, NodeData, NodeId, ShardId};
|
||||
use crate::{GraphError, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
#[cfg(feature = "federation")]
use tonic::{Request, Response, Status};

/// Stand-in for `tonic::Status` when the `federation` feature is disabled,
/// so RPC signatures can still name an error type without pulling in tonic.
/// NOTE(review): a unit struct carries no error detail — confirm callers
/// never inspect it in non-federation builds.
#[cfg(not(feature = "federation"))]
pub struct Status;
use tracing::{debug, info, warn};
|
||||
|
||||
/// RPC request for executing a query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteQueryRequest {
    /// Query to execute (Cypher syntax)
    pub query: String,
    /// Optional parameters
    pub parameters: std::collections::HashMap<String, serde_json::Value>,
    /// Transaction ID (if part of a transaction)
    pub transaction_id: Option<String>,
}

/// RPC response for query execution
///
/// `error` is expected to be `Some` only when `success` is `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteQueryResponse {
    /// Query result
    pub result: QueryResult,
    /// Success indicator
    pub success: bool,
    /// Error message if failed
    pub error: Option<String>,
}

/// RPC request for replicating data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicateDataRequest {
    /// Shard ID to replicate to
    pub shard_id: ShardId,
    /// Operation type
    pub operation: ReplicationOperation,
}

/// Replication operation types
///
/// Public wire form of a single replicated mutation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReplicationOperation {
    /// Insert a new node (full payload).
    AddNode(NodeData),
    /// Insert a new edge (full payload).
    AddEdge(EdgeData),
    /// Remove a node by id.
    DeleteNode(NodeId),
    /// Remove an edge by its string id.
    DeleteEdge(String),
    /// Overwrite an existing node's data.
    UpdateNode(NodeData),
    /// Overwrite an existing edge's data.
    UpdateEdge(EdgeData),
}

/// RPC response for replication
///
/// `error` is expected to be `Some` only when `success` is `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicateDataResponse {
    /// Success indicator
    pub success: bool,
    /// Error message if failed
    pub error: Option<String>,
}
|
||||
|
||||
/// RPC request for health check
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckRequest {
    /// Node ID performing the check
    pub node_id: String,
}

/// RPC response for health check
///
/// NOTE(review): `load` is documented as 0.0-1.0 but not range-checked
/// anywhere in this file — confirm producers clamp it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckResponse {
    /// Node is healthy
    pub healthy: bool,
    /// Current load (0.0 - 1.0)
    pub load: f64,
    /// Number of active queries
    pub active_queries: usize,
    /// Uptime in seconds
    pub uptime_seconds: u64,
}

/// RPC request for shard info
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetShardInfoRequest {
    /// Shard ID
    pub shard_id: ShardId,
}

/// RPC response for shard info
///
/// Echoes the requested `shard_id` so responses can be correlated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetShardInfoResponse {
    /// Shard ID
    pub shard_id: ShardId,
    /// Number of nodes
    pub node_count: usize,
    /// Number of edges
    pub edge_count: usize,
    /// Shard size in bytes
    pub size_bytes: u64,
}
|
||||
|
||||
/// Graph RPC service trait (would be implemented via tonic in production)
///
/// Only compiled with the `federation` feature. Every method returns
/// `Err(Status)` for transport or handler failure; payload semantics are on
/// the corresponding request/response types.
#[cfg(feature = "federation")]
#[tonic::async_trait]
pub trait GraphRpcService: Send + Sync {
    /// Execute a query on this node
    async fn execute_query(
        &self,
        request: ExecuteQueryRequest,
    ) -> std::result::Result<ExecuteQueryResponse, Status>;

    /// Replicate data to this node
    async fn replicate_data(
        &self,
        request: ReplicateDataRequest,
    ) -> std::result::Result<ReplicateDataResponse, Status>;

    /// Health check
    async fn health_check(
        &self,
        request: HealthCheckRequest,
    ) -> std::result::Result<HealthCheckResponse, Status>;

    /// Get shard information
    async fn get_shard_info(
        &self,
        request: GetShardInfoRequest,
    ) -> std::result::Result<GetShardInfoResponse, Status>;
}
|
||||
|
||||
/// RPC client for communicating with remote nodes
///
/// Currently a stub: methods log and return simulated responses without
/// opening a connection.
pub struct RpcClient {
    /// Target node address
    target_address: String,
    /// Connection timeout in seconds
    // NOTE(review): set by `new`/`with_timeout` but not read by any method
    // in this file — presumably reserved for the real tonic transport.
    timeout_seconds: u64,
}
|
||||
|
||||
impl RpcClient {
|
||||
/// Create a new RPC client
|
||||
pub fn new(target_address: String) -> Self {
|
||||
Self {
|
||||
target_address,
|
||||
timeout_seconds: 30,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set connection timeout
|
||||
pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
|
||||
self.timeout_seconds = timeout_seconds;
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute a query on the remote node
|
||||
pub async fn execute_query(
|
||||
&self,
|
||||
request: ExecuteQueryRequest,
|
||||
) -> Result<ExecuteQueryResponse> {
|
||||
debug!(
|
||||
"Executing remote query on {}: {}",
|
||||
self.target_address, request.query
|
||||
);
|
||||
|
||||
// In production, make actual gRPC call using tonic
|
||||
// For now, simulate response
|
||||
Ok(ExecuteQueryResponse {
|
||||
result: QueryResult {
|
||||
query_id: uuid::Uuid::new_v4().to_string(),
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
aggregates: std::collections::HashMap::new(),
|
||||
stats: crate::distributed::coordinator::QueryStats {
|
||||
execution_time_ms: 0,
|
||||
shards_queried: 0,
|
||||
nodes_scanned: 0,
|
||||
edges_scanned: 0,
|
||||
cached: false,
|
||||
},
|
||||
},
|
||||
success: true,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Replicate data to the remote node
|
||||
pub async fn replicate_data(
|
||||
&self,
|
||||
request: ReplicateDataRequest,
|
||||
) -> Result<ReplicateDataResponse> {
|
||||
debug!(
|
||||
"Replicating data to {} for shard {}",
|
||||
self.target_address, request.shard_id
|
||||
);
|
||||
|
||||
// In production, make actual gRPC call
|
||||
Ok(ReplicateDataResponse {
|
||||
success: true,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform health check on remote node
|
||||
pub async fn health_check(&self, node_id: String) -> Result<HealthCheckResponse> {
|
||||
debug!("Health check on {}", self.target_address);
|
||||
|
||||
// In production, make actual gRPC call
|
||||
Ok(HealthCheckResponse {
|
||||
healthy: true,
|
||||
load: 0.5,
|
||||
active_queries: 0,
|
||||
uptime_seconds: 3600,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get shard information from remote node
|
||||
pub async fn get_shard_info(&self, shard_id: ShardId) -> Result<GetShardInfoResponse> {
|
||||
debug!(
|
||||
"Getting shard info for {} from {}",
|
||||
shard_id, self.target_address
|
||||
);
|
||||
|
||||
// In production, make actual gRPC call
|
||||
Ok(GetShardInfoResponse {
|
||||
shard_id,
|
||||
node_count: 0,
|
||||
edge_count: 0,
|
||||
size_bytes: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// RPC server for handling incoming requests
///
/// Federation build: owns the service implementation that handles calls.
#[cfg(feature = "federation")]
pub struct RpcServer {
    /// Server address to bind to
    bind_address: String,
    /// Service implementation
    service: Arc<dyn GraphRpcService>,
}

/// RPC server variant for non-federation builds: same construction and
/// start/stop surface, but no service field and no networking.
#[cfg(not(feature = "federation"))]
pub struct RpcServer {
    /// Server address to bind to
    bind_address: String,
}
|
||||
|
||||
#[cfg(feature = "federation")]
impl RpcServer {
    /// Create a server that will bind to `bind_address` and dispatch to
    /// `service`.
    pub fn new(bind_address: String, service: Arc<dyn GraphRpcService>) -> Self {
        Self {
            bind_address,
            service,
        }
    }

    /// Start the RPC server.
    ///
    /// Stub: logs the intended bind address without opening a socket.
    pub async fn start(&self) -> Result<()> {
        info!("Starting RPC server on {}", self.bind_address);
        // In production, start actual gRPC server using tonic
        debug!("RPC server would start on {}", self.bind_address);
        Ok(())
    }

    /// Stop the RPC server (logging no-op in the stub).
    pub async fn stop(&self) -> Result<()> {
        info!("Stopping RPC server");
        Ok(())
    }
}
|
||||
|
||||
#[cfg(not(feature = "federation"))]
|
||||
impl RpcServer {
|
||||
/// Create a new RPC server
|
||||
pub fn new(bind_address: String) -> Self {
|
||||
Self { bind_address }
|
||||
}
|
||||
|
||||
/// Start the RPC server
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
info!("Starting RPC server on {}", self.bind_address);
|
||||
|
||||
// In production, start actual gRPC server using tonic
|
||||
// For now, just log
|
||||
debug!("RPC server would start on {}", self.bind_address);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the RPC server
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
info!("Stopping RPC server");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Default implementation of GraphRpcService
///
/// Tracks uptime and an in-flight query count for health reporting.
#[cfg(feature = "federation")]
pub struct DefaultGraphRpcService {
    /// Node ID
    // NOTE(review): stored but not read by any handler in this file —
    // confirm it is needed, or wire it into logging/health responses.
    node_id: String,
    /// Start time for uptime calculation
    start_time: std::time::Instant,
    /// Active queries counter
    // Arc<RwLock<..>> so async handlers can mutate it behind `&self`.
    active_queries: Arc<RwLock<usize>>,
}

#[cfg(feature = "federation")]
impl DefaultGraphRpcService {
    /// Create a new default service
    ///
    /// Uptime reported by `health_check` is measured from this call.
    pub fn new(node_id: String) -> Self {
        Self {
            node_id,
            start_time: std::time::Instant::now(),
            active_queries: Arc::new(RwLock::new(0)),
        }
    }
}
|
||||
|
||||
#[cfg(feature = "federation")]
#[tonic::async_trait]
impl GraphRpcService for DefaultGraphRpcService {
    /// Execute a query, tracking it in the active-query counter for the
    /// duration of the call.
    ///
    /// NOTE(review): if this future is dropped between the increment and the
    /// decrement (task cancellation), the counter leaks — confirm whether
    /// cancellation is possible on this path before shipping.
    async fn execute_query(
        &self,
        request: ExecuteQueryRequest,
    ) -> std::result::Result<ExecuteQueryResponse, Status> {
        // Increment active queries
        *self.active_queries.write().await += 1;

        debug!("Executing query: {}", request.query);

        // In production, execute actual query; stub returns an empty result.
        let result = QueryResult {
            query_id: uuid::Uuid::new_v4().to_string(),
            nodes: Vec::new(),
            edges: Vec::new(),
            aggregates: std::collections::HashMap::new(),
            stats: crate::distributed::coordinator::QueryStats {
                execution_time_ms: 0,
                shards_queried: 0,
                nodes_scanned: 0,
                edges_scanned: 0,
                cached: false,
            },
        };

        // Decrement active queries
        *self.active_queries.write().await -= 1;

        Ok(ExecuteQueryResponse {
            result,
            success: true,
            error: None,
        })
    }

    /// Accept a replication request (stub: always succeeds).
    async fn replicate_data(
        &self,
        request: ReplicateDataRequest,
    ) -> std::result::Result<ReplicateDataResponse, Status> {
        debug!("Replicating data for shard {}", request.shard_id);

        // In production, perform actual replication
        Ok(ReplicateDataResponse {
            success: true,
            error: None,
        })
    }

    /// Report health: real uptime and active-query count, placeholder load.
    async fn health_check(
        &self,
        _request: HealthCheckRequest,
    ) -> std::result::Result<HealthCheckResponse, Status> {
        let uptime = self.start_time.elapsed().as_secs();
        let active = *self.active_queries.read().await;

        Ok(HealthCheckResponse {
            healthy: true,
            load: 0.5, // Would calculate actual load
            active_queries: active,
            uptime_seconds: uptime,
        })
    }

    /// Report shard info (stub: echoes the id with zeroed counts).
    async fn get_shard_info(
        &self,
        request: GetShardInfoRequest,
    ) -> std::result::Result<GetShardInfoResponse, Status> {
        // In production, get actual shard info
        Ok(GetShardInfoResponse {
            shard_id: request.shard_id,
            node_count: 0,
            edge_count: 0,
            size_bytes: 0,
        })
    }
}
|
||||
|
||||
/// RPC connection pool for managing connections to multiple nodes
///
/// Clients are created lazily per node id and shared via `Arc`.
pub struct RpcConnectionPool {
    /// Map of node_id to RPC client
    clients: Arc<dashmap::DashMap<String, Arc<RpcClient>>>,
}
|
||||
|
||||
impl RpcConnectionPool {
|
||||
/// Create a new connection pool
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
clients: Arc::new(dashmap::DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create a client for a node
|
||||
pub fn get_client(&self, node_id: &str, address: &str) -> Arc<RpcClient> {
|
||||
self.clients
|
||||
.entry(node_id.to_string())
|
||||
.or_insert_with(|| Arc::new(RpcClient::new(address.to_string())))
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Remove a client from the pool
|
||||
pub fn remove_client(&self, node_id: &str) {
|
||||
self.clients.remove(node_id);
|
||||
}
|
||||
|
||||
/// Get number of active connections
|
||||
pub fn connection_count(&self) -> usize {
|
||||
self.clients.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RpcConnectionPool {
    /// Equivalent to [`RpcConnectionPool::new`]: an empty pool.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rpc_client() {
        let client = RpcClient::new("localhost:9000".to_string());

        let request = ExecuteQueryRequest {
            query: "MATCH (n) RETURN n".to_string(),
            parameters: std::collections::HashMap::new(),
            transaction_id: None,
        };

        let response = client.execute_query(request).await.unwrap();
        assert!(response.success);
    }

    // Fix: `DefaultGraphRpcService` (and the `GraphRpcService` trait) only
    // exist when the `federation` feature is enabled; without this cfg gate
    // `cargo test` fails to compile for non-federation builds.
    #[cfg(feature = "federation")]
    #[tokio::test]
    async fn test_default_service() {
        let service = DefaultGraphRpcService::new("test-node".to_string());

        let request = ExecuteQueryRequest {
            query: "MATCH (n) RETURN n".to_string(),
            parameters: std::collections::HashMap::new(),
            transaction_id: None,
        };

        let response = service.execute_query(request).await.unwrap();
        assert!(response.success);
    }

    #[tokio::test]
    async fn test_connection_pool() {
        let pool = RpcConnectionPool::new();

        // Underscore-bind so the clients stay alive without unused warnings.
        let _client1 = pool.get_client("node-1", "localhost:9000");
        let _client2 = pool.get_client("node-2", "localhost:9001");

        assert_eq!(pool.connection_count(), 2);

        pool.remove_client("node-1");
        assert_eq!(pool.connection_count(), 1);
    }

    #[cfg(feature = "federation")]
    #[tokio::test]
    async fn test_health_check() {
        let service = DefaultGraphRpcService::new("test-node".to_string());

        let request = HealthCheckRequest {
            node_id: "test".to_string(),
        };

        let response = service.health_check(request).await.unwrap();
        assert!(response.healthy);
        assert_eq!(response.active_queries, 0);
    }
}
|
||||
595
vendor/ruvector/crates/ruvector-graph/src/distributed/shard.rs
vendored
Normal file
595
vendor/ruvector/crates/ruvector-graph/src/distributed/shard.rs
vendored
Normal file
@@ -0,0 +1,595 @@
|
||||
//! Graph sharding strategies for distributed hypergraphs
|
||||
//!
|
||||
//! Provides multiple partitioning strategies optimized for graph workloads:
|
||||
//! - Hash-based node partitioning for uniform distribution
|
||||
//! - Range-based partitioning for locality-aware queries
|
||||
//! - Edge-cut minimization for reducing cross-shard communication
|
||||
|
||||
use crate::{GraphError, Result};
|
||||
use blake3::Hasher;
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
use uuid::Uuid;
|
||||
use xxhash_rust::xxh3::xxh3_64;
|
||||
|
||||
/// Unique identifier for a graph node
pub type NodeId = String;

/// Unique identifier for a graph edge
pub type EdgeId = String;

/// Shard identifier
///
/// Dense small integers in `0..shard_count`, as produced by the
/// partitioners in this module.
pub type ShardId = u32;
|
||||
|
||||
/// Graph sharding strategy
///
/// Selects which partitioner decides node placement; recorded in
/// `ShardMetadata` for each shard.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardStrategy {
    /// Hash-based partitioning using consistent hashing
    Hash,
    /// Range-based partitioning for ordered node IDs
    Range,
    /// Edge-cut minimization for graph partitioning
    EdgeCut,
    /// Custom partitioning strategy
    Custom,
}
|
||||
|
||||
/// Metadata about a graph shard
///
/// Counters are plain fields; they are expected to be maintained by the
/// shard owner (this struct does no counting itself).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMetadata {
    /// Shard identifier
    pub shard_id: ShardId,
    /// Number of nodes in this shard
    pub node_count: usize,
    /// Number of edges in this shard
    pub edge_count: usize,
    /// Number of edges crossing to other shards
    pub cross_shard_edges: usize,
    /// Primary node responsible for this shard
    pub primary_node: String,
    /// Replica nodes
    pub replicas: Vec<String>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last modification timestamp
    pub modified_at: DateTime<Utc>,
    /// Partitioning strategy used
    pub strategy: ShardStrategy,
}
|
||||
|
||||
impl ShardMetadata {
|
||||
/// Create new shard metadata
|
||||
pub fn new(shard_id: ShardId, primary_node: String, strategy: ShardStrategy) -> Self {
|
||||
Self {
|
||||
shard_id,
|
||||
node_count: 0,
|
||||
edge_count: 0,
|
||||
cross_shard_edges: 0,
|
||||
primary_node,
|
||||
replicas: Vec::new(),
|
||||
created_at: Utc::now(),
|
||||
modified_at: Utc::now(),
|
||||
strategy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate edge cut ratio (cross-shard edges / total edges)
|
||||
pub fn edge_cut_ratio(&self) -> f64 {
|
||||
if self.edge_count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.cross_shard_edges as f64 / self.edge_count as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash-based node partitioner
pub struct HashPartitioner {
    /// Total number of shards
    shard_count: u32,
    /// Virtual nodes per physical shard for better distribution
    // NOTE(review): set in `new` but never read by any method in this file —
    // confirm whether a consistent-hash ring is still planned, otherwise
    // this field is dead weight.
    virtual_nodes: u32,
}
|
||||
|
||||
impl HashPartitioner {
|
||||
/// Create a new hash partitioner
|
||||
pub fn new(shard_count: u32) -> Self {
|
||||
assert!(shard_count > 0, "shard_count must be greater than zero");
|
||||
Self {
|
||||
shard_count,
|
||||
virtual_nodes: 150, // Similar to consistent hashing best practices
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the shard ID for a given node ID using xxHash
|
||||
pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
|
||||
let hash = xxh3_64(node_id.as_bytes());
|
||||
(hash % self.shard_count as u64) as ShardId
|
||||
}
|
||||
|
||||
/// Get the shard ID using BLAKE3 for cryptographic strength (alternative)
|
||||
pub fn get_shard_secure(&self, node_id: &NodeId) -> ShardId {
|
||||
let mut hasher = Hasher::new();
|
||||
hasher.update(node_id.as_bytes());
|
||||
let hash = hasher.finalize();
|
||||
let hash_bytes = hash.as_bytes();
|
||||
let hash_u64 = u64::from_le_bytes([
|
||||
hash_bytes[0],
|
||||
hash_bytes[1],
|
||||
hash_bytes[2],
|
||||
hash_bytes[3],
|
||||
hash_bytes[4],
|
||||
hash_bytes[5],
|
||||
hash_bytes[6],
|
||||
hash_bytes[7],
|
||||
]);
|
||||
(hash_u64 % self.shard_count as u64) as ShardId
|
||||
}
|
||||
|
||||
/// Get multiple candidate shards for replication
|
||||
pub fn get_replica_shards(&self, node_id: &NodeId, replica_count: usize) -> Vec<ShardId> {
|
||||
let mut shards = Vec::with_capacity(replica_count);
|
||||
let primary = self.get_shard(node_id);
|
||||
shards.push(primary);
|
||||
|
||||
// Generate additional shards using salted hashing
|
||||
for i in 1..replica_count {
|
||||
let salted_id = format!("{}-replica-{}", node_id, i);
|
||||
let shard = self.get_shard(&salted_id);
|
||||
if !shards.contains(&shard) {
|
||||
shards.push(shard);
|
||||
}
|
||||
}
|
||||
|
||||
shards
|
||||
}
|
||||
}
|
||||
|
||||
/// Range-based node partitioner for ordered node IDs
pub struct RangePartitioner {
    /// Total number of shards
    shard_count: u32,
    /// Range boundaries (shard_id -> max_value in range)
    // Boundary i is the inclusive upper bound (lexicographic, by String
    // ordering) for shard i; an empty vector means "no ranges yet" and
    // lookups fall back to hash-modulo.
    ranges: Vec<String>,
}
|
||||
|
||||
impl RangePartitioner {
|
||||
/// Create a new range partitioner with automatic range distribution
|
||||
pub fn new(shard_count: u32) -> Self {
|
||||
Self {
|
||||
shard_count,
|
||||
ranges: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a range partitioner with explicit boundaries
|
||||
pub fn with_boundaries(boundaries: Vec<String>) -> Self {
|
||||
Self {
|
||||
shard_count: boundaries.len() as u32,
|
||||
ranges: boundaries,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the shard ID for a node based on range boundaries
|
||||
pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
|
||||
if self.ranges.is_empty() {
|
||||
// Fallback to simple modulo if no ranges defined
|
||||
let hash = xxh3_64(node_id.as_bytes());
|
||||
return (hash % self.shard_count as u64) as ShardId;
|
||||
}
|
||||
|
||||
// Binary search through sorted ranges
|
||||
for (idx, boundary) in self.ranges.iter().enumerate() {
|
||||
if node_id <= boundary {
|
||||
return idx as ShardId;
|
||||
}
|
||||
}
|
||||
|
||||
// Last shard for values beyond all boundaries
|
||||
(self.shard_count - 1) as ShardId
|
||||
}
|
||||
|
||||
/// Update range boundaries based on data distribution
|
||||
pub fn update_boundaries(&mut self, new_boundaries: Vec<String>) {
|
||||
info!(
|
||||
"Updating range boundaries: old={}, new={}",
|
||||
self.ranges.len(),
|
||||
new_boundaries.len()
|
||||
);
|
||||
self.ranges = new_boundaries;
|
||||
self.shard_count = self.ranges.len() as u32;
|
||||
}
|
||||
}
|
||||
|
||||
/// Edge-cut minimization using METIS-like graph partitioning
///
/// Holds the working graph (adjacency + weights) and the computed
/// node-to-shard assignments; all maps are concurrent so edges can be added
/// while assignments are read.
pub struct EdgeCutMinimizer {
    /// Total number of shards
    shard_count: u32,
    /// Node to shard assignments
    node_assignments: Arc<DashMap<NodeId, ShardId>>,
    /// Edge information for partitioning decisions
    // Keyed on the (from, to) order given to `add_edge`.
    edge_weights: Arc<DashMap<(NodeId, NodeId), f64>>,
    /// Adjacency list representation
    // Bidirectional: both endpoints list each other.
    adjacency: Arc<DashMap<NodeId, HashSet<NodeId>>>,
}
|
||||
|
||||
impl EdgeCutMinimizer {
|
||||
/// Create a new edge-cut minimizer
|
||||
pub fn new(shard_count: u32) -> Self {
|
||||
Self {
|
||||
shard_count,
|
||||
node_assignments: Arc::new(DashMap::new()),
|
||||
edge_weights: Arc::new(DashMap::new()),
|
||||
adjacency: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an edge to the graph for partitioning consideration
|
||||
pub fn add_edge(&self, from: NodeId, to: NodeId, weight: f64) {
|
||||
self.edge_weights.insert((from.clone(), to.clone()), weight);
|
||||
|
||||
// Update adjacency list
|
||||
self.adjacency
|
||||
.entry(from.clone())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(to.clone());
|
||||
|
||||
self.adjacency
|
||||
.entry(to)
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(from);
|
||||
}
|
||||
|
||||
/// Get the shard assignment for a node
|
||||
pub fn get_shard(&self, node_id: &NodeId) -> Option<ShardId> {
|
||||
self.node_assignments.get(node_id).map(|r| *r.value())
|
||||
}
|
||||
|
||||
/// Compute initial partitioning using multilevel k-way partitioning
|
||||
pub fn compute_partitioning(&self) -> Result<HashMap<NodeId, ShardId>> {
|
||||
info!("Computing edge-cut minimized partitioning");
|
||||
|
||||
let nodes: Vec<_> = self.adjacency.iter().map(|e| e.key().clone()).collect();
|
||||
|
||||
if nodes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
// Phase 1: Coarsening - merge highly connected nodes
|
||||
let coarse_graph = self.coarsen_graph(&nodes);
|
||||
|
||||
// Phase 2: Initial partitioning using greedy approach
|
||||
let mut assignments = self.initial_partition(&coarse_graph);
|
||||
|
||||
// Phase 3: Refinement using Kernighan-Lin algorithm
|
||||
self.refine_partition(&mut assignments);
|
||||
|
||||
// Store assignments
|
||||
for (node, shard) in &assignments {
|
||||
self.node_assignments.insert(node.clone(), *shard);
|
||||
}
|
||||
|
||||
info!(
|
||||
"Partitioning complete: {} nodes across {} shards",
|
||||
assignments.len(),
|
||||
self.shard_count
|
||||
);
|
||||
|
||||
Ok(assignments)
|
||||
}
|
||||
|
||||
/// Coarsen the graph by merging highly connected nodes
|
||||
fn coarsen_graph(&self, nodes: &[NodeId]) -> HashMap<NodeId, Vec<NodeId>> {
|
||||
let mut coarse: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
|
||||
let mut visited = HashSet::new();
|
||||
|
||||
for node in nodes {
|
||||
if visited.contains(node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut group = vec![node.clone()];
|
||||
visited.insert(node.clone());
|
||||
|
||||
// Find best matching neighbor based on edge weight
|
||||
if let Some(neighbors) = self.adjacency.get(node) {
|
||||
let mut best_neighbor: Option<(NodeId, f64)> = None;
|
||||
|
||||
for neighbor in neighbors.iter() {
|
||||
if visited.contains(neighbor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let weight = self
|
||||
.edge_weights
|
||||
.get(&(node.clone(), neighbor.clone()))
|
||||
.map(|w| *w.value())
|
||||
.unwrap_or(1.0);
|
||||
|
||||
if let Some((_, best_weight)) = best_neighbor {
|
||||
if weight > best_weight {
|
||||
best_neighbor = Some((neighbor.clone(), weight));
|
||||
}
|
||||
} else {
|
||||
best_neighbor = Some((neighbor.clone(), weight));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((neighbor, _)) = best_neighbor {
|
||||
group.push(neighbor.clone());
|
||||
visited.insert(neighbor);
|
||||
}
|
||||
}
|
||||
|
||||
let representative = node.clone();
|
||||
coarse.insert(representative, group);
|
||||
}
|
||||
|
||||
coarse
|
||||
}
|
||||
|
||||
/// Initial partition using greedy approach
|
||||
fn initial_partition(
|
||||
&self,
|
||||
coarse_graph: &HashMap<NodeId, Vec<NodeId>>,
|
||||
) -> HashMap<NodeId, ShardId> {
|
||||
let mut assignments = HashMap::new();
|
||||
let mut shard_sizes: Vec<usize> = vec![0; self.shard_count as usize];
|
||||
|
||||
for (representative, group) in coarse_graph {
|
||||
// Assign to least-loaded shard
|
||||
let shard = shard_sizes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by_key(|(_, size)| *size)
|
||||
.map(|(idx, _)| idx as ShardId)
|
||||
.unwrap_or(0);
|
||||
|
||||
for node in group {
|
||||
assignments.insert(node.clone(), shard);
|
||||
shard_sizes[shard as usize] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assignments
|
||||
}
|
||||
|
||||
/// Refine partition using simplified Kernighan-Lin algorithm
|
||||
fn refine_partition(&self, assignments: &mut HashMap<NodeId, ShardId>) {
|
||||
const MAX_ITERATIONS: usize = 10;
|
||||
let mut improved = true;
|
||||
let mut iteration = 0;
|
||||
|
||||
while improved && iteration < MAX_ITERATIONS {
|
||||
improved = false;
|
||||
iteration += 1;
|
||||
|
||||
for (node, current_shard) in assignments.clone().iter() {
|
||||
let current_cost = self.compute_node_cost(node, *current_shard, assignments);
|
||||
|
||||
// Try moving to each other shard
|
||||
for target_shard in 0..self.shard_count {
|
||||
if target_shard == *current_shard {
|
||||
continue;
|
||||
}
|
||||
|
||||
let new_cost = self.compute_node_cost(node, target_shard, assignments);
|
||||
|
||||
if new_cost < current_cost {
|
||||
assignments.insert(node.clone(), target_shard);
|
||||
improved = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Refinement iteration {}: improved={}", iteration, improved);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the cost (number of cross-shard edges) for a node in a given shard
|
||||
fn compute_node_cost(
|
||||
&self,
|
||||
node: &NodeId,
|
||||
shard: ShardId,
|
||||
assignments: &HashMap<NodeId, ShardId>,
|
||||
) -> usize {
|
||||
let mut cross_shard_edges = 0;
|
||||
|
||||
if let Some(neighbors) = self.adjacency.get(node) {
|
||||
for neighbor in neighbors.iter() {
|
||||
if let Some(neighbor_shard) = assignments.get(neighbor) {
|
||||
if *neighbor_shard != shard {
|
||||
cross_shard_edges += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cross_shard_edges
|
||||
}
|
||||
|
||||
/// Calculate total edge cut across all shards
|
||||
pub fn calculate_edge_cut(&self, assignments: &HashMap<NodeId, ShardId>) -> usize {
|
||||
let mut cut = 0;
|
||||
|
||||
for entry in self.edge_weights.iter() {
|
||||
let ((from, to), _) = entry.pair();
|
||||
let from_shard = assignments.get(from);
|
||||
let to_shard = assignments.get(to);
|
||||
|
||||
if from_shard.is_some() && to_shard.is_some() && from_shard != to_shard {
|
||||
cut += 1;
|
||||
}
|
||||
}
|
||||
|
||||
cut
|
||||
}
|
||||
}
|
||||
|
||||
/// Graph shard containing partitioned data.
///
/// A shard holds the nodes assigned to it by the partitioner plus the edges
/// it stores (which may reference nodes living on other shards). Node and
/// edge maps are concurrent (`DashMap`) and shared via `Arc`, so a shard
/// handle can be cloned/shared across threads.
pub struct GraphShard {
    /// Shard metadata (id, owner, strategy, counters).
    metadata: ShardMetadata,
    /// Nodes in this shard, keyed by node id.
    nodes: Arc<DashMap<NodeId, NodeData>>,
    /// Edges in this shard (including cross-shard edges), keyed by edge id.
    edges: Arc<DashMap<EdgeId, EdgeData>>,
    /// Partitioning strategy copied from the metadata at construction time.
    strategy: ShardStrategy,
}
|
||||
|
||||
/// Node data in the graph.
///
/// Properties are free-form JSON values keyed by property name; labels follow
/// the Neo4j convention of zero or more type tags per node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeData {
    // Unique node identifier; used as the key in shard storage.
    pub id: NodeId,
    // Arbitrary JSON-valued properties keyed by property name.
    pub properties: HashMap<String, serde_json::Value>,
    // Zero or more labels (node type tags).
    pub labels: Vec<String>,
}
|
||||
|
||||
/// Edge data in the graph.
///
/// Directed: `from` is the source node id, `to` the target. Either endpoint
/// may live on a different shard than the edge itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeData {
    // Unique edge identifier; used as the key in shard storage.
    pub id: EdgeId,
    // Source node id.
    pub from: NodeId,
    // Target node id.
    pub to: NodeId,
    // Relationship type name (e.g. "KNOWS").
    pub edge_type: String,
    // Arbitrary JSON-valued properties keyed by property name.
    pub properties: HashMap<String, serde_json::Value>,
}
|
||||
|
||||
impl GraphShard {
|
||||
/// Create a new graph shard
|
||||
pub fn new(metadata: ShardMetadata) -> Self {
|
||||
let strategy = metadata.strategy;
|
||||
Self {
|
||||
metadata,
|
||||
nodes: Arc::new(DashMap::new()),
|
||||
edges: Arc::new(DashMap::new()),
|
||||
strategy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to this shard
|
||||
pub fn add_node(&self, node: NodeData) -> Result<()> {
|
||||
self.nodes.insert(node.id.clone(), node);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add an edge to this shard
|
||||
pub fn add_edge(&self, edge: EdgeData) -> Result<()> {
|
||||
self.edges.insert(edge.id.clone(), edge);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a node by ID
|
||||
pub fn get_node(&self, node_id: &NodeId) -> Option<NodeData> {
|
||||
self.nodes.get(node_id).map(|n| n.value().clone())
|
||||
}
|
||||
|
||||
/// Get an edge by ID
|
||||
pub fn get_edge(&self, edge_id: &EdgeId) -> Option<EdgeData> {
|
||||
self.edges.get(edge_id).map(|e| e.value().clone())
|
||||
}
|
||||
|
||||
/// Get shard metadata
|
||||
pub fn metadata(&self) -> &ShardMetadata {
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
/// Get node count
|
||||
pub fn node_count(&self) -> usize {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// Get edge count
|
||||
pub fn edge_count(&self) -> usize {
|
||||
self.edges.len()
|
||||
}
|
||||
|
||||
/// List all nodes in this shard
|
||||
pub fn list_nodes(&self) -> Vec<NodeData> {
|
||||
self.nodes.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
|
||||
/// List all edges in this shard
|
||||
pub fn list_edges(&self) -> Vec<EdgeData> {
|
||||
self.edges.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_partitioner() {
        let partitioner = HashPartitioner::new(16);

        let first = "node-1".to_string();
        let second = "node-2".to_string();

        let first_shard = partitioner.get_shard(&first);
        let second_shard = partitioner.get_shard(&second);

        // Every shard id must fall inside the configured shard count.
        assert!(first_shard < 16);
        assert!(second_shard < 16);

        // Hash partitioning must be deterministic per key.
        assert_eq!(first_shard, partitioner.get_shard(&first));
    }

    #[test]
    fn test_range_partitioner() {
        let partitioner =
            RangePartitioner::with_boundaries(vec!["m".to_string(), "z".to_string()]);

        // Keys below "m" land in shard 0; between "m" and "z" in shard 1.
        assert_eq!(partitioner.get_shard(&"apple".to_string()), 0);
        assert_eq!(partitioner.get_shard(&"orange".to_string()), 1);
        assert_eq!(partitioner.get_shard(&"zebra".to_string()), 1);
    }

    #[test]
    fn test_edge_cut_minimizer() {
        let minimizer = EdgeCutMinimizer::new(2);

        // Path graph A-B-C-D: an optimal 2-way split cuts at most two edges.
        for (from, to) in [("A", "B"), ("B", "C"), ("C", "D")] {
            minimizer.add_edge(from.to_string(), to.to_string(), 1.0);
        }

        let assignments = minimizer.compute_partitioning().unwrap();
        assert!(minimizer.calculate_edge_cut(&assignments) <= 2);
    }

    #[test]
    fn test_shard_metadata() {
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);

        assert_eq!(metadata.shard_id, 0);
        // Fresh metadata has seen no edges, so the cut ratio starts at zero.
        assert_eq!(metadata.edge_cut_ratio(), 0.0);
    }

    #[test]
    fn test_graph_shard() {
        let shard = GraphShard::new(ShardMetadata::new(
            0,
            "node-1".to_string(),
            ShardStrategy::Hash,
        ));

        let node = NodeData {
            id: "test-node".to_string(),
            properties: HashMap::new(),
            labels: vec!["TestLabel".to_string()],
        };
        shard.add_node(node).unwrap();

        assert_eq!(shard.node_count(), 1);
        assert!(shard.get_node(&"test-node".to_string()).is_some());
    }
}
|
||||
150
vendor/ruvector/crates/ruvector-graph/src/edge.rs
vendored
Normal file
150
vendor/ruvector/crates/ruvector-graph/src/edge.rs
vendored
Normal file
@@ -0,0 +1,150 @@
|
||||
//! Edge (relationship) implementation
|
||||
|
||||
use crate::types::{EdgeId, NodeId, Properties, PropertyValue};
|
||||
use bincode::{Decode, Encode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// A directed, typed relationship between two nodes.
///
/// Serializable with both serde and bincode (`Encode`/`Decode`).
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct Edge {
    // Unique edge identifier (UUID string when auto-generated).
    pub id: EdgeId,
    // Source node id.
    pub from: NodeId,
    // Target node id.
    pub to: NodeId,
    // Relationship type name (e.g. "KNOWS").
    pub edge_type: String,
    // Key/value property map.
    pub properties: Properties,
}
|
||||
|
||||
impl Edge {
|
||||
/// Create a new edge with all fields
|
||||
pub fn new(
|
||||
id: EdgeId,
|
||||
from: NodeId,
|
||||
to: NodeId,
|
||||
edge_type: String,
|
||||
properties: Properties,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
from,
|
||||
to,
|
||||
edge_type,
|
||||
properties,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new edge with auto-generated ID and empty properties
|
||||
pub fn create(from: NodeId, to: NodeId, edge_type: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
from,
|
||||
to,
|
||||
edge_type: edge_type.into(),
|
||||
properties: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a property value by key
|
||||
pub fn get_property(&self, key: &str) -> Option<&PropertyValue> {
|
||||
self.properties.get(key)
|
||||
}
|
||||
|
||||
/// Set a property value
|
||||
pub fn set_property(&mut self, key: impl Into<String>, value: PropertyValue) {
|
||||
self.properties.insert(key.into(), value);
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for constructing [`Edge`] instances.
///
/// Endpoints and type are mandatory (supplied to `new`); the id and
/// properties are optional and filled in by the chained setters. `build`
/// generates a UUID id when none was set explicitly.
#[derive(Debug, Clone)]
pub struct EdgeBuilder {
    // Explicit id, or None to auto-generate a UUID at build time.
    id: Option<EdgeId>,
    // Source node id.
    from: NodeId,
    // Target node id.
    to: NodeId,
    // Relationship type name.
    edge_type: String,
    // Accumulated properties.
    properties: Properties,
}
|
||||
|
||||
impl EdgeBuilder {
|
||||
/// Create a new edge builder with required fields
|
||||
pub fn new(from: NodeId, to: NodeId, edge_type: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
from,
|
||||
to,
|
||||
edge_type: edge_type.into(),
|
||||
properties: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a custom edge ID
|
||||
pub fn id(mut self, id: impl Into<String>) -> Self {
|
||||
self.id = Some(id.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a property to the edge
|
||||
pub fn property<V: Into<PropertyValue>>(mut self, key: impl Into<String>, value: V) -> Self {
|
||||
self.properties.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple properties to the edge
|
||||
pub fn properties(mut self, props: Properties) -> Self {
|
||||
self.properties.extend(props);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the edge
|
||||
pub fn build(self) -> Edge {
|
||||
Edge {
|
||||
id: self.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
from: self.from,
|
||||
to: self.to,
|
||||
edge_type: self.edge_type,
|
||||
properties: self.properties,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_edge_builder() {
        let edge = EdgeBuilder::new("node1".to_string(), "node2".to_string(), "KNOWS")
            .property("since", 2020i64)
            .build();

        assert_eq!(edge.from, "node1");
        assert_eq!(edge.to, "node2");
        assert_eq!(edge.edge_type, "KNOWS");
        assert_eq!(
            edge.get_property("since"),
            Some(&PropertyValue::Integer(2020))
        );
    }

    #[test]
    fn test_edge_create() {
        let edge = Edge::create("a".to_string(), "b".to_string(), "FOLLOWS");

        assert_eq!(edge.from, "a");
        assert_eq!(edge.to, "b");
        assert_eq!(edge.edge_type, "FOLLOWS");
        // `create` starts with an empty property map.
        assert!(edge.properties.is_empty());
    }

    #[test]
    fn test_edge_new() {
        let edge = Edge::new(
            "e1".to_string(),
            "n1".to_string(),
            "n2".to_string(),
            "LIKES".to_string(),
            HashMap::new(),
        );

        assert_eq!(edge.id, "e1");
        assert_eq!(edge.edge_type, "LIKES");
    }
}
|
||||
101
vendor/ruvector/crates/ruvector-graph/src/error.rs
vendored
Normal file
101
vendor/ruvector/crates/ruvector-graph/src/error.rs
vendored
Normal file
@@ -0,0 +1,101 @@
|
||||
//! Error types for graph database operations
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Unified error type for all graph database operations.
///
/// Most variants carry a human-readable message; `IoError` preserves the
/// underlying `std::io::Error` and converts automatically via `#[from]`.
#[derive(Error, Debug)]
pub enum GraphError {
    // --- Entity lookups ---
    #[error("Node not found: {0}")]
    NodeNotFound(String),

    #[error("Edge not found: {0}")]
    EdgeNotFound(String),

    #[error("Hyperedge not found: {0}")]
    HyperedgeNotFound(String),

    // --- Query parsing / execution ---
    #[error("Invalid query: {0}")]
    InvalidQuery(String),

    #[error("Transaction error: {0}")]
    TransactionError(String),

    #[error("Constraint violation: {0}")]
    ConstraintViolation(String),

    #[error("Cypher parse error: {0}")]
    CypherParseError(String),

    #[error("Cypher execution error: {0}")]
    CypherExecutionError(String),

    // --- Distributed operation ---
    #[error("Distributed operation failed: {0}")]
    DistributedError(String),

    #[error("Invalid input: {0}")]
    InvalidInput(String),

    #[error("Shard error: {0}")]
    ShardError(String),

    #[error("Coordinator error: {0}")]
    CoordinatorError(String),

    #[error("Federation error: {0}")]
    FederationError(String),

    #[error("RPC error: {0}")]
    RpcError(String),

    #[error("Query error: {0}")]
    QueryError(String),

    #[error("Network error: {0}")]
    NetworkError(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),

    #[error("Replication error: {0}")]
    ReplicationError(String),

    #[error("Cluster error: {0}")]
    ClusterError(String),

    // --- Storage / execution backends ---
    #[error("Index error: {0}")]
    IndexError(String),

    #[error("Invalid embedding: {0}")]
    InvalidEmbedding(String),

    #[error("Storage error: {0}")]
    StorageError(String),

    #[error("Execution error: {0}")]
    ExecutionError(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    // I/O failures convert automatically via `?` thanks to #[from].
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
|
||||
|
||||
impl From<anyhow::Error> for GraphError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
GraphError::StorageError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bincode::error::EncodeError> for GraphError {
|
||||
fn from(err: bincode::error::EncodeError) -> Self {
|
||||
GraphError::SerializationError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bincode::error::DecodeError> for GraphError {
|
||||
fn from(err: bincode::error::DecodeError) -> Self {
|
||||
GraphError::SerializationError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Crate-wide result alias defaulting the error type to [`GraphError`].
pub type Result<T> = std::result::Result<T, GraphError>;
|
||||
365
vendor/ruvector/crates/ruvector-graph/src/executor/cache.rs
vendored
Normal file
365
vendor/ruvector/crates/ruvector-graph/src/executor/cache.rs
vendored
Normal file
@@ -0,0 +1,365 @@
|
||||
//! Query result caching for performance optimization
|
||||
//!
|
||||
//! Implements LRU cache with TTL support
|
||||
|
||||
use crate::executor::pipeline::RowBatch;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Cache configuration.
///
/// An entry is admitted only while both the entry-count and memory budgets
/// hold (older entries are evicted to make room), and is considered stale
/// once it is older than `ttl_seconds`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of cached entries.
    pub max_entries: usize,
    /// Maximum memory usage in bytes (estimated, not exact).
    pub max_memory_bytes: usize,
    /// Time-to-live for cache entries in seconds.
    pub ttl_seconds: u64,
}
|
||||
|
||||
impl CacheConfig {
|
||||
/// Create new cache config
|
||||
pub fn new(max_entries: usize, max_memory_bytes: usize, ttl_seconds: u64) -> Self {
|
||||
Self {
|
||||
max_entries,
|
||||
max_memory_bytes,
|
||||
ttl_seconds,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_entries: 1000,
|
||||
max_memory_bytes: 100 * 1024 * 1024, // 100MB
|
||||
ttl_seconds: 300, // 5 minutes
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache entry with metadata.
///
/// Wraps a set of result batches together with the bookkeeping the cache
/// needs for TTL expiry, memory accounting, and LRU/usage tracking.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Cached query results.
    pub results: Vec<RowBatch>,
    /// Entry creation time (drives TTL expiry).
    pub created_at: Instant,
    /// Last access time (updated on every cache hit).
    pub last_accessed: Instant,
    /// Estimated memory size in bytes (heuristic, see `estimate_size`).
    pub size_bytes: usize,
    /// Number of times this entry has been returned from the cache.
    pub access_count: u64,
}
|
||||
|
||||
impl CacheEntry {
|
||||
/// Create new cache entry
|
||||
pub fn new(results: Vec<RowBatch>) -> Self {
|
||||
let size_bytes = Self::estimate_size(&results);
|
||||
let now = Instant::now();
|
||||
|
||||
Self {
|
||||
results,
|
||||
created_at: now,
|
||||
last_accessed: now,
|
||||
size_bytes,
|
||||
access_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate memory size of results
|
||||
fn estimate_size(results: &[RowBatch]) -> usize {
|
||||
results
|
||||
.iter()
|
||||
.map(|batch| {
|
||||
// Rough estimate: 8 bytes per value + overhead
|
||||
batch.len() * batch.schema.columns.len() * 8 + 1024
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Check if entry is expired
|
||||
pub fn is_expired(&self, ttl: Duration) -> bool {
|
||||
self.created_at.elapsed() > ttl
|
||||
}
|
||||
|
||||
/// Update access metadata
|
||||
pub fn mark_accessed(&mut self) {
|
||||
self.last_accessed = Instant::now();
|
||||
self.access_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// LRU cache for query results.
///
/// Thread-safe via coarse `RwLock`s; mutating operations take the locks in
/// a fixed order (entries → lru_order → memory_used → stats) across all
/// methods, which avoids lock-ordering deadlocks.
pub struct QueryCache {
    /// Cache storage keyed by the query's cache key.
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// LRU tracking: keys ordered least- to most-recently used.
    lru_order: Arc<RwLock<Vec<String>>>,
    /// Configuration (capacity, memory budget, TTL).
    config: CacheConfig,
    /// Current estimated memory usage in bytes.
    memory_used: Arc<RwLock<usize>>,
    /// Cache statistics (hits/misses/inserts/evictions).
    stats: Arc<RwLock<CacheStats>>,
}
|
||||
|
||||
impl QueryCache {
|
||||
/// Create a new query cache
|
||||
pub fn new(config: CacheConfig) -> Self {
|
||||
Self {
|
||||
entries: Arc::new(RwLock::new(HashMap::new())),
|
||||
lru_order: Arc::new(RwLock::new(Vec::new())),
|
||||
config,
|
||||
memory_used: Arc::new(RwLock::new(0)),
|
||||
stats: Arc::new(RwLock::new(CacheStats::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cached results
|
||||
pub fn get(&self, key: &str) -> Option<CacheEntry> {
|
||||
let mut entries = self.entries.write().ok()?;
|
||||
let mut lru = self.lru_order.write().ok()?;
|
||||
let mut stats = self.stats.write().ok()?;
|
||||
|
||||
if let Some(entry) = entries.get_mut(key) {
|
||||
// Check if expired
|
||||
if entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
|
||||
stats.misses += 1;
|
||||
return None;
|
||||
}
|
||||
|
||||
// Update LRU order
|
||||
if let Some(pos) = lru.iter().position(|k| k == key) {
|
||||
lru.remove(pos);
|
||||
}
|
||||
lru.push(key.to_string());
|
||||
|
||||
// Update access metadata
|
||||
entry.mark_accessed();
|
||||
stats.hits += 1;
|
||||
|
||||
Some(entry.clone())
|
||||
} else {
|
||||
stats.misses += 1;
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert results into cache
|
||||
pub fn insert(&self, key: String, results: Vec<RowBatch>) {
|
||||
let entry = CacheEntry::new(results);
|
||||
let entry_size = entry.size_bytes;
|
||||
|
||||
let mut entries = self.entries.write().unwrap();
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
let mut memory = self.memory_used.write().unwrap();
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
|
||||
// Evict if necessary
|
||||
while (entries.len() >= self.config.max_entries
|
||||
|| *memory + entry_size > self.config.max_memory_bytes)
|
||||
&& !lru.is_empty()
|
||||
{
|
||||
if let Some(old_key) = lru.first().cloned() {
|
||||
if let Some(old_entry) = entries.remove(&old_key) {
|
||||
*memory = memory.saturating_sub(old_entry.size_bytes);
|
||||
stats.evictions += 1;
|
||||
}
|
||||
lru.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Insert new entry
|
||||
entries.insert(key.clone(), entry);
|
||||
lru.push(key);
|
||||
*memory += entry_size;
|
||||
stats.inserts += 1;
|
||||
}
|
||||
|
||||
/// Remove entry from cache
|
||||
pub fn remove(&self, key: &str) -> bool {
|
||||
let mut entries = self.entries.write().unwrap();
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
let mut memory = self.memory_used.write().unwrap();
|
||||
|
||||
if let Some(entry) = entries.remove(key) {
|
||||
*memory = memory.saturating_sub(entry.size_bytes);
|
||||
if let Some(pos) = lru.iter().position(|k| k == key) {
|
||||
lru.remove(pos);
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear all cache entries
|
||||
pub fn clear(&self) {
|
||||
let mut entries = self.entries.write().unwrap();
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
let mut memory = self.memory_used.write().unwrap();
|
||||
|
||||
entries.clear();
|
||||
lru.clear();
|
||||
*memory = 0;
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub fn stats(&self) -> CacheStats {
|
||||
self.stats.read().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Get current memory usage
|
||||
pub fn memory_used(&self) -> usize {
|
||||
*self.memory_used.read().unwrap()
|
||||
}
|
||||
|
||||
/// Get number of cached entries
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.read().unwrap().len()
|
||||
}
|
||||
|
||||
/// Check if cache is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.read().unwrap().is_empty()
|
||||
}
|
||||
|
||||
/// Clean expired entries
|
||||
pub fn clean_expired(&self) {
|
||||
let ttl = Duration::from_secs(self.config.ttl_seconds);
|
||||
let mut entries = self.entries.write().unwrap();
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
let mut memory = self.memory_used.write().unwrap();
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
|
||||
let expired_keys: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|(_, entry)| entry.is_expired(ttl))
|
||||
.map(|(key, _)| key.clone())
|
||||
.collect();
|
||||
|
||||
for key in expired_keys {
|
||||
if let Some(entry) = entries.remove(&key) {
|
||||
*memory = memory.saturating_sub(entry.size_bytes);
|
||||
if let Some(pos) = lru.iter().position(|k| k == &key) {
|
||||
lru.remove(pos);
|
||||
}
|
||||
stats.evictions += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache statistics.
///
/// Monotonic counters since construction (or the last `reset`).
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,
    /// Number of cache misses.
    pub misses: u64,
    /// Number of insertions.
    pub inserts: u64,
    /// Number of evictions (capacity, memory budget, or TTL).
    pub evictions: u64,
}
|
||||
|
||||
impl CacheStats {
|
||||
/// Calculate hit rate
|
||||
pub fn hit_rate(&self) -> f64 {
|
||||
let total = self.hits + self.misses;
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.hits as f64 / total as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset statistics
|
||||
pub fn reset(&mut self) {
|
||||
self.hits = 0;
|
||||
self.misses = 0;
|
||||
self.inserts = 0;
|
||||
self.evictions = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::plan::{ColumnDef, DataType, QuerySchema};

    /// Single-column, empty batch shared by all cache tests.
    fn create_test_batch() -> RowBatch {
        let column = ColumnDef {
            name: "id".to_string(),
            data_type: DataType::Int64,
            nullable: false,
        };
        RowBatch::new(QuerySchema::new(vec![column]))
    }

    #[test]
    fn test_cache_insert_and_get() {
        let cache = QueryCache::new(CacheConfig::default());

        cache.insert("test_key".to_string(), vec![create_test_batch()]);
        assert_eq!(cache.len(), 1);
        assert!(cache.get("test_key").is_some());
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::new(CacheConfig::default());

        assert!(cache.get("nonexistent").is_none());
        // A failed lookup must be recorded as a miss.
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = QueryCache::new(CacheConfig {
            max_entries: 2,
            max_memory_bytes: 1024 * 1024,
            ttl_seconds: 300,
        });

        for key in ["key1", "key2", "key3"] {
            cache.insert(key.to_string(), vec![create_test_batch()]);
        }

        // Capacity is two, so the least-recently-used entry ("key1") is gone.
        assert_eq!(cache.len(), 2);
        assert!(cache.get("key1").is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryCache::new(CacheConfig::default());
        cache.insert("key1".to_string(), vec![create_test_batch()]);
        cache.insert("key2".to_string(), vec![create_test_batch()]);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.memory_used(), 0);
    }

    #[test]
    fn test_hit_rate() {
        let stats = CacheStats {
            hits: 7,
            misses: 3,
            ..Default::default()
        };

        assert!((stats.hit_rate() - 0.7).abs() < 0.001);
    }
}
|
||||
183
vendor/ruvector/crates/ruvector-graph/src/executor/mod.rs
vendored
Normal file
183
vendor/ruvector/crates/ruvector-graph/src/executor/mod.rs
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
//! High-performance query execution engine for RuVector graph database
|
||||
//!
|
||||
//! This module provides a complete query execution system with:
|
||||
//! - Logical and physical query plans
|
||||
//! - Vectorized operators (scan, filter, join, aggregate)
|
||||
//! - Pipeline execution with iterator model
|
||||
//! - Parallel execution using rayon
|
||||
//! - Query result caching
|
||||
//! - Cost-based optimization statistics
|
||||
//!
|
||||
//! Performance targets:
|
||||
//! - 100K+ traversals/second per core
|
||||
//! - Sub-millisecond simple lookups
|
||||
//! - SIMD-optimized predicate evaluation
|
||||
|
||||
pub mod cache;
|
||||
pub mod operators;
|
||||
pub mod parallel;
|
||||
pub mod pipeline;
|
||||
pub mod plan;
|
||||
pub mod stats;
|
||||
|
||||
pub use cache::{CacheConfig, CacheEntry, QueryCache};
|
||||
pub use operators::{
|
||||
Aggregate, AggregateFunction, EdgeScan, Filter, HyperedgeScan, Join, JoinType, Limit, NodeScan,
|
||||
Operator, Project, ScanMode, Sort,
|
||||
};
|
||||
pub use parallel::{ParallelConfig, ParallelExecutor};
|
||||
pub use pipeline::{ExecutionContext, Pipeline, RowBatch};
|
||||
pub use plan::{LogicalPlan, PhysicalPlan, PlanNode};
|
||||
pub use stats::{ColumnStats, Histogram, Statistics, TableStats};
|
||||
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Query execution error types.
///
/// A lightweight, `Clone`-able error enum for the executor; each variant
/// carries a human-readable message. Rendered via the `Display` impl below.
#[derive(Debug, Clone)]
pub enum ExecutionError {
    /// Invalid query plan
    InvalidPlan(String),
    /// Operator execution failed
    OperatorError(String),
    /// Type mismatch in expression evaluation
    TypeMismatch(String),
    /// Resource exhausted (memory, disk, etc.)
    ResourceExhausted(String),
    /// Internal error
    Internal(String),
}
|
||||
|
||||
impl fmt::Display for ExecutionError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ExecutionError::InvalidPlan(msg) => write!(f, "Invalid plan: {}", msg),
|
||||
ExecutionError::OperatorError(msg) => write!(f, "Operator error: {}", msg),
|
||||
ExecutionError::TypeMismatch(msg) => write!(f, "Type mismatch: {}", msg),
|
||||
ExecutionError::ResourceExhausted(msg) => write!(f, "Resource exhausted: {}", msg),
|
||||
ExecutionError::Internal(msg) => write!(f, "Internal error: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl so ExecutionError works with `Box<dyn Error>` consumers.
impl Error for ExecutionError {}
|
||||
|
||||
/// Executor-wide result alias defaulting the error type to [`ExecutionError`].
pub type Result<T> = std::result::Result<T, ExecutionError>;
|
||||
|
||||
/// Query execution engine.
///
/// Ties together the result cache, cost statistics, and the parallel
/// execution settings used when a plan is parallelizable.
pub struct QueryExecutor {
    /// Query result cache (keyed by the plan's cache key).
    cache: Arc<QueryCache>,
    /// Execution statistics used for cost-based optimization.
    stats: Arc<Statistics>,
    /// Parallel execution configuration.
    parallel_config: ParallelConfig,
}
|
||||
|
||||
impl QueryExecutor {
|
||||
/// Create a new query executor
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cache: Arc::new(QueryCache::new(CacheConfig::default())),
|
||||
stats: Arc::new(Statistics::new()),
|
||||
parallel_config: ParallelConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create executor with custom configuration
|
||||
pub fn with_config(cache_config: CacheConfig, parallel_config: ParallelConfig) -> Self {
|
||||
Self {
|
||||
cache: Arc::new(QueryCache::new(cache_config)),
|
||||
stats: Arc::new(Statistics::new()),
|
||||
parallel_config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a logical plan
|
||||
pub fn execute(&self, plan: &LogicalPlan) -> Result<Vec<RowBatch>> {
|
||||
// Check cache first
|
||||
let cache_key = plan.cache_key();
|
||||
if let Some(cached) = self.cache.get(&cache_key) {
|
||||
return Ok(cached.results.clone());
|
||||
}
|
||||
|
||||
// Optimize logical plan to physical plan
|
||||
let physical_plan = self.optimize(plan)?;
|
||||
|
||||
// Execute physical plan
|
||||
let results = if self.parallel_config.enabled && plan.is_parallelizable() {
|
||||
self.execute_parallel(&physical_plan)?
|
||||
} else {
|
||||
self.execute_sequential(&physical_plan)?
|
||||
};
|
||||
|
||||
// Cache results
|
||||
self.cache.insert(cache_key, results.clone());
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Optimize logical plan to physical plan
|
||||
fn optimize(&self, plan: &LogicalPlan) -> Result<PhysicalPlan> {
|
||||
// Cost-based optimization using statistics
|
||||
let physical = PhysicalPlan::from_logical(plan, &self.stats)?;
|
||||
Ok(physical)
|
||||
}
|
||||
|
||||
/// Execute plan sequentially
|
||||
fn execute_sequential(&self, _plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
|
||||
// Note: In a real implementation, we would need to reconstruct operators
|
||||
// For now, return empty results as placeholder
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Execute plan in parallel
|
||||
fn execute_parallel(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
|
||||
let executor = ParallelExecutor::new(self.parallel_config.clone());
|
||||
executor.execute(plan)
|
||||
}
|
||||
|
||||
/// Get execution statistics
|
||||
pub fn stats(&self) -> Arc<Statistics> {
|
||||
Arc::clone(&self.stats)
|
||||
}
|
||||
|
||||
/// Clear query cache
|
||||
pub fn clear_cache(&self) {
|
||||
self.cache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QueryExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_executor_creation() {
        // A fresh executor starts with empty statistics.
        assert!(QueryExecutor::new().stats().is_empty());
    }

    #[test]
    fn test_executor_with_config() {
        let executor = QueryExecutor::with_config(
            CacheConfig {
                max_entries: 100,
                max_memory_bytes: 1024 * 1024,
                ttl_seconds: 300,
            },
            ParallelConfig {
                enabled: true,
                num_threads: 4,
                batch_size: 1000,
            },
        );

        assert!(executor.stats().is_empty());
    }
}
|
||||
521
vendor/ruvector/crates/ruvector-graph/src/executor/operators.rs
vendored
Normal file
521
vendor/ruvector/crates/ruvector-graph/src/executor/operators.rs
vendored
Normal file
@@ -0,0 +1,521 @@
|
||||
//! Query operators for graph traversal and data processing
|
||||
//!
|
||||
//! High-performance implementations with SIMD optimization
|
||||
|
||||
use crate::executor::pipeline::RowBatch;
|
||||
use crate::executor::plan::{Predicate, Value};
|
||||
use crate::executor::{ExecutionError, Result};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
/// Base trait for all query operators
///
/// Operators are driven by repeated `execute` calls: each call receives the
/// upstream batch (if any) and may yield an output batch; `Ok(None)` signals
/// that this operator produced no output for the call.
pub trait Operator: Send + Sync {
    /// Execute operator and produce output batch
    fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>>;

    /// Get operator name for debugging
    fn name(&self) -> &str;

    /// Check if operator is pipeline breaker
    ///
    /// Defaults to `false` (streaming). NOTE(review): breakers presumably
    /// buffer all input before emitting — confirm against Sort/Aggregate.
    fn is_pipeline_breaker(&self) -> bool {
        false
    }
}
|
||||
|
||||
/// Scan mode for data access.
///
/// Chosen by the planner; shared by node, edge, and hyperedge scans.
#[derive(Debug, Clone)]
pub enum ScanMode {
    /// Sequential scan over all stored items.
    Sequential,
    /// Index-based scan using the named secondary index.
    Index { index_name: String },
    /// Range scan bounded by `start`/`end` values.
    Range { start: Value, end: Value },
}
|
||||
|
||||
/// Node scan operator.
pub struct NodeScan {
    // Access strategy (sequential, index, or range).
    mode: ScanMode,
    // Optional pushed-down predicate applied during the scan.
    filter: Option<Predicate>,
    // Scan cursor; presumably advanced across execute() calls once the
    // storage-backed implementation lands (execute is currently a stub).
    position: usize,
}
|
||||
|
||||
impl NodeScan {
|
||||
pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
|
||||
Self {
|
||||
mode,
|
||||
filter,
|
||||
position: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for NodeScan {
    fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
        // Placeholder implementation: always yields no output.
        // In the real implementation this will scan graph storage.
        Ok(None)
    }

    fn name(&self) -> &str {
        "NodeScan"
    }
}
|
||||
|
||||
/// Edge scan operator.
pub struct EdgeScan {
    // Access strategy (sequential, index, or range).
    mode: ScanMode,
    // Optional pushed-down predicate applied during the scan.
    filter: Option<Predicate>,
    // Scan cursor; presumably advanced across execute() calls once the
    // storage-backed implementation lands (execute is currently a stub).
    position: usize,
}
|
||||
|
||||
impl EdgeScan {
|
||||
pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
|
||||
Self {
|
||||
mode,
|
||||
filter,
|
||||
position: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for EdgeScan {
    fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
        // Placeholder implementation: always yields no output.
        Ok(None)
    }

    fn name(&self) -> &str {
        "EdgeScan"
    }
}
|
||||
|
||||
/// Hyperedge scan operator
|
||||
pub struct HyperedgeScan {
|
||||
mode: ScanMode,
|
||||
filter: Option<Predicate>,
|
||||
}
|
||||
|
||||
impl HyperedgeScan {
|
||||
pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
|
||||
Self { mode, filter }
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for HyperedgeScan {
|
||||
fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"HyperedgeScan"
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter operator with SIMD-optimized predicate evaluation
|
||||
pub struct Filter {
|
||||
predicate: Predicate,
|
||||
}
|
||||
|
||||
impl Filter {
|
||||
pub fn new(predicate: Predicate) -> Self {
|
||||
Self { predicate }
|
||||
}
|
||||
|
||||
/// Evaluate predicate on a row
|
||||
fn evaluate(&self, row: &HashMap<String, Value>) -> bool {
|
||||
self.evaluate_predicate(&self.predicate, row)
|
||||
}
|
||||
|
||||
fn evaluate_predicate(&self, pred: &Predicate, row: &HashMap<String, Value>) -> bool {
|
||||
match pred {
|
||||
Predicate::Equals(col, val) => row.get(col).map(|v| v == val).unwrap_or(false),
|
||||
Predicate::NotEquals(col, val) => row.get(col).map(|v| v != val).unwrap_or(false),
|
||||
Predicate::GreaterThan(col, val) => row
|
||||
.get(col)
|
||||
.and_then(|v| v.compare(val))
|
||||
.map(|ord| ord == std::cmp::Ordering::Greater)
|
||||
.unwrap_or(false),
|
||||
Predicate::GreaterThanOrEqual(col, val) => row
|
||||
.get(col)
|
||||
.and_then(|v| v.compare(val))
|
||||
.map(|ord| ord != std::cmp::Ordering::Less)
|
||||
.unwrap_or(false),
|
||||
Predicate::LessThan(col, val) => row
|
||||
.get(col)
|
||||
.and_then(|v| v.compare(val))
|
||||
.map(|ord| ord == std::cmp::Ordering::Less)
|
||||
.unwrap_or(false),
|
||||
Predicate::LessThanOrEqual(col, val) => row
|
||||
.get(col)
|
||||
.and_then(|v| v.compare(val))
|
||||
.map(|ord| ord != std::cmp::Ordering::Greater)
|
||||
.unwrap_or(false),
|
||||
Predicate::In(col, values) => row.get(col).map(|v| values.contains(v)).unwrap_or(false),
|
||||
Predicate::Like(col, pattern) => {
|
||||
if let Some(Value::String(s)) = row.get(col) {
|
||||
self.pattern_match(s, pattern)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Predicate::And(preds) => preds.iter().all(|p| self.evaluate_predicate(p, row)),
|
||||
Predicate::Or(preds) => preds.iter().any(|p| self.evaluate_predicate(p, row)),
|
||||
Predicate::Not(pred) => !self.evaluate_predicate(pred, row),
|
||||
}
|
||||
}
|
||||
|
||||
fn pattern_match(&self, s: &str, pattern: &str) -> bool {
|
||||
// Simple LIKE pattern matching (% = wildcard)
|
||||
if pattern.starts_with('%') && pattern.ends_with('%') {
|
||||
let p = &pattern[1..pattern.len() - 1];
|
||||
s.contains(p)
|
||||
} else if pattern.starts_with('%') {
|
||||
let p = &pattern[1..];
|
||||
s.ends_with(p)
|
||||
} else if pattern.ends_with('%') {
|
||||
let p = &pattern[..pattern.len() - 1];
|
||||
s.starts_with(p)
|
||||
} else {
|
||||
s == pattern
|
||||
}
|
||||
}
|
||||
|
||||
/// SIMD-optimized batch filtering for numeric predicates
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
fn filter_batch_simd(&self, values: &[f32], threshold: f32) -> Vec<bool> {
|
||||
if is_x86_feature_detected!("avx2") {
|
||||
unsafe { self.filter_batch_avx2(values, threshold) }
|
||||
} else {
|
||||
self.filter_batch_scalar(values, threshold)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2")]
|
||||
unsafe fn filter_batch_avx2(&self, values: &[f32], threshold: f32) -> Vec<bool> {
|
||||
let mut result = vec![false; values.len()];
|
||||
let threshold_vec = _mm256_set1_ps(threshold);
|
||||
|
||||
let chunks = values.len() / 8;
|
||||
for i in 0..chunks {
|
||||
let idx = i * 8;
|
||||
let vals = _mm256_loadu_ps(values.as_ptr().add(idx));
|
||||
let cmp = _mm256_cmp_ps(vals, threshold_vec, _CMP_GT_OQ);
|
||||
|
||||
let mask: [f32; 8] = std::mem::transmute(cmp);
|
||||
for j in 0..8 {
|
||||
result[idx + j] = mask[j] != 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle remaining elements
|
||||
for i in (chunks * 8)..values.len() {
|
||||
result[i] = values[i] > threshold;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "x86_64"))]
|
||||
fn filter_batch_simd(&self, values: &[f32], threshold: f32) -> Vec<bool> {
|
||||
self.filter_batch_scalar(values, threshold)
|
||||
}
|
||||
|
||||
fn filter_batch_scalar(&self, values: &[f32], threshold: f32) -> Vec<bool> {
|
||||
values.iter().map(|&v| v > threshold).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for Filter {
|
||||
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
|
||||
if let Some(batch) = input {
|
||||
let filtered_rows: Vec<_> = batch
|
||||
.rows
|
||||
.into_iter()
|
||||
.filter(|row| self.evaluate(row))
|
||||
.collect();
|
||||
|
||||
Ok(Some(RowBatch {
|
||||
rows: filtered_rows,
|
||||
schema: batch.schema,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Filter"
|
||||
}
|
||||
}
|
||||
|
||||
/// Join type
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    /// Emit only matching row pairs.
    Inner,
    /// Preserve all left-side rows.
    LeftOuter,
    /// Preserve all right-side rows.
    RightOuter,
    /// Preserve all rows from both sides.
    FullOuter,
}
|
||||
|
||||
/// Join operator with hash join implementation
pub struct Join {
    /// Join semantics (inner/left/right/full outer).
    join_type: JoinType,
    /// Equi-join keys as (probe-side column, build-side column) pairs.
    on: Vec<(String, String)>,
    /// Build-side rows keyed by their join-key values.
    hash_table: HashMap<Vec<Value>, Vec<HashMap<String, Value>>>,
    /// Whether the build phase has completed.
    built: bool,
}

impl Join {
    /// Create a hash join over the given key-column pairs.
    pub fn new(join_type: JoinType, on: Vec<(String, String)>) -> Self {
        Self {
            join_type,
            on,
            hash_table: HashMap::new(),
            built: false,
        }
    }

    /// Build phase: index every build-side row by its join-key values.
    ///
    /// NOTE(review): `filter_map` silently drops missing join columns, so a
    /// row lacking a key column gets a *shorter* key vector and can collide
    /// with rows missing a different column — confirm whether such rows
    /// should be excluded from the join instead.
    fn build_hash_table(&mut self, build_side: RowBatch) {
        for row in build_side.rows {
            let key: Vec<Value> = self
                .on
                .iter()
                .filter_map(|(_, right_col)| row.get(right_col).cloned())
                .collect();

            self.hash_table
                .entry(key)
                .or_insert_with(Vec::new)
                .push(row);
        }
        self.built = true;
    }

    /// Probe phase: find build-side matches for one probe row and merge
    /// their columns.
    ///
    /// On column-name clashes the build-side value overwrites the
    /// probe-side value (`HashMap::extend` semantics).
    fn probe(&self, probe_row: &HashMap<String, Value>) -> Vec<HashMap<String, Value>> {
        let key: Vec<Value> = self
            .on
            .iter()
            .filter_map(|(left_col, _)| probe_row.get(left_col).cloned())
            .collect();

        if let Some(matches) = self.hash_table.get(&key) {
            matches
                .iter()
                .map(|right_row| {
                    let mut joined = probe_row.clone();
                    joined.extend(right_row.clone());
                    joined
                })
                .collect()
        } else {
            Vec::new()
        }
    }
}

impl Operator for Join {
    /// Placeholder: a full implementation would consume the build side via
    /// `build_hash_table`, then stream probe batches through `probe`.
    /// Currently always reports end-of-stream.
    fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
        // Simplified: assumes build side comes first, then probe side
        Ok(None)
    }

    fn name(&self) -> &str {
        "Join"
    }

    fn is_pipeline_breaker(&self) -> bool {
        true // Hash join needs to build hash table first
    }
}
|
||||
|
||||
/// Aggregate function
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregateFunction {
    /// Number of rows in the group.
    Count,
    /// Sum of the aggregated column.
    Sum,
    /// Arithmetic mean of the aggregated column.
    Avg,
    /// Minimum of the aggregated column.
    Min,
    /// Maximum of the aggregated column.
    Max,
}
|
||||
|
||||
/// Aggregate operator
|
||||
pub struct Aggregate {
|
||||
group_by: Vec<String>,
|
||||
aggregates: Vec<(AggregateFunction, String)>,
|
||||
state: HashMap<Vec<Value>, Vec<f64>>,
|
||||
}
|
||||
|
||||
impl Aggregate {
|
||||
pub fn new(group_by: Vec<String>, aggregates: Vec<(AggregateFunction, String)>) -> Self {
|
||||
Self {
|
||||
group_by,
|
||||
aggregates,
|
||||
state: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for Aggregate {
|
||||
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Aggregate"
|
||||
}
|
||||
|
||||
fn is_pipeline_breaker(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Project operator (column selection)
|
||||
pub struct Project {
|
||||
columns: Vec<String>,
|
||||
}
|
||||
|
||||
impl Project {
|
||||
pub fn new(columns: Vec<String>) -> Self {
|
||||
Self { columns }
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for Project {
|
||||
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
|
||||
if let Some(batch) = input {
|
||||
let projected: Vec<_> = batch
|
||||
.rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
self.columns
|
||||
.iter()
|
||||
.filter_map(|col| row.get(col).map(|v| (col.clone(), v.clone())))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Some(RowBatch {
|
||||
rows: projected,
|
||||
schema: batch.schema,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Project"
|
||||
}
|
||||
}
|
||||
|
||||
/// Sort operator with external sort for large datasets
|
||||
pub struct Sort {
|
||||
order_by: Vec<(String, crate::executor::plan::SortOrder)>,
|
||||
buffer: Vec<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
impl Sort {
|
||||
pub fn new(order_by: Vec<(String, crate::executor::plan::SortOrder)>) -> Self {
|
||||
Self {
|
||||
order_by,
|
||||
buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for Sort {
|
||||
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Sort"
|
||||
}
|
||||
|
||||
fn is_pipeline_breaker(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Limit operator
|
||||
pub struct Limit {
|
||||
limit: usize,
|
||||
offset: usize,
|
||||
current: usize,
|
||||
}
|
||||
|
||||
impl Limit {
|
||||
pub fn new(limit: usize, offset: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
offset,
|
||||
current: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for Limit {
|
||||
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
|
||||
if let Some(batch) = input {
|
||||
let start = self.offset.saturating_sub(self.current);
|
||||
let end = start + self.limit;
|
||||
|
||||
let limited: Vec<_> = batch
|
||||
.rows
|
||||
.into_iter()
|
||||
.skip(start)
|
||||
.take(end - start)
|
||||
.collect();
|
||||
|
||||
self.current += limited.len();
|
||||
|
||||
Ok(Some(RowBatch {
|
||||
rows: limited,
|
||||
schema: batch.schema,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Limit"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// An equality predicate matches a row whose column holds the value.
    #[test]
    fn test_filter_operator() {
        let mut filter = Filter::new(Predicate::Equals("id".to_string(), Value::Int64(42)));

        let mut row = HashMap::new();
        row.insert("id".to_string(), Value::Int64(42));
        assert!(filter.evaluate(&row));
    }

    /// `%x%` LIKE patterns perform substring matching.
    #[test]
    fn test_pattern_matching() {
        let filter = Filter::new(Predicate::Like("name".to_string(), "%test%".to_string()));
        assert!(filter.pattern_match("this is a test", "%test%"));
    }

    /// The SIMD (or scalar fallback) batch filter agrees with a plain
    /// per-element `> threshold` comparison.
    #[test]
    fn test_simd_filtering() {
        let filter = Filter::new(Predicate::GreaterThan(
            "value".to_string(),
            Value::Float64(5.0),
        ));
        let values = vec![1.0, 6.0, 3.0, 8.0, 4.0, 9.0, 2.0, 7.0];
        let result = filter.filter_batch_simd(&values, 5.0);
        assert_eq!(
            result,
            vec![false, true, false, true, false, true, false, true]
        );
    }
}
|
||||
361
vendor/ruvector/crates/ruvector-graph/src/executor/parallel.rs
vendored
Normal file
361
vendor/ruvector/crates/ruvector-graph/src/executor/parallel.rs
vendored
Normal file
@@ -0,0 +1,361 @@
|
||||
//! Parallel query execution using rayon
|
||||
//!
|
||||
//! Implements data parallelism for graph queries
|
||||
|
||||
use crate::executor::operators::Operator;
|
||||
use crate::executor::pipeline::{ExecutionContext, RowBatch};
|
||||
use crate::executor::plan::PhysicalPlan;
|
||||
use crate::executor::{ExecutionError, Result};
|
||||
use rayon::prelude::*;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// Parallel execution configuration
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Enable parallel execution
    pub enabled: bool,
    /// Number of worker threads (0 = auto-detect)
    pub num_threads: usize,
    /// Batch size for parallel processing
    pub batch_size: usize,
}

impl ParallelConfig {
    /// Default configuration: parallelism enabled, thread count auto-detected.
    pub fn new() -> Self {
        ParallelConfig {
            enabled: true,
            num_threads: 0, // 0 => pick the CPU core count at executor creation
            batch_size: 1024,
        }
    }

    /// Single-threaded configuration with parallelism disabled.
    pub fn sequential() -> Self {
        ParallelConfig {
            enabled: false,
            num_threads: 1,
            ..Self::new()
        }
    }

    /// Configuration pinned to an explicit worker-thread count.
    pub fn with_threads(num_threads: usize) -> Self {
        ParallelConfig {
            num_threads,
            ..Self::new()
        }
    }
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Parallel query executor
pub struct ParallelExecutor {
    config: ParallelConfig,
    /// Dedicated rayon pool so query work is isolated from the global pool.
    thread_pool: rayon::ThreadPool,
}

impl ParallelExecutor {
    /// Create a new parallel executor
    ///
    /// `config.num_threads == 0` auto-detects the CPU count.
    ///
    /// # Panics
    /// Panics if the rayon thread pool cannot be built.
    pub fn new(config: ParallelConfig) -> Self {
        let num_threads = if config.num_threads == 0 {
            num_cpus::get()
        } else {
            config.num_threads
        };

        let thread_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()
            .expect("Failed to create thread pool");

        Self {
            config,
            thread_pool,
        }
    }

    /// Execute a physical plan in parallel
    ///
    /// Falls back to sequential execution when parallelism is disabled;
    /// otherwise picks a strategy based on the plan's pipeline breakers.
    pub fn execute(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        if !self.config.enabled {
            return self.execute_sequential(plan);
        }

        // Determine parallelization strategy based on plan structure
        if plan.pipeline_breakers.is_empty() {
            // No pipeline breakers - can parallelize entire pipeline
            self.execute_parallel_scan(plan)
        } else {
            // Has pipeline breakers - need to materialize intermediate results
            self.execute_parallel_staged(plan)
        }
    }

    /// Execute plan sequentially (fallback)
    ///
    /// NOTE(review): stub — currently always returns an empty result set.
    fn execute_sequential(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        let mut results = Vec::new();
        // Simplified sequential execution
        Ok(results)
    }

    /// Parallel scan execution (for scan-heavy queries)
    ///
    /// Spawns one task per partition on the dedicated pool; each task
    /// appends its output batch to a shared, mutex-guarded vector.
    fn execute_parallel_scan(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        let results = Arc::new(Mutex::new(Vec::new()));
        let num_partitions = self.config.num_threads.max(1);

        // Partition the scan and execute in parallel
        self.thread_pool.scope(|s| {
            for partition_id in 0..num_partitions {
                let results = Arc::clone(&results);
                s.spawn(move |_| {
                    // Execute partition
                    let batch = self.execute_partition(plan, partition_id, num_partitions);
                    // NOTE(review): partition errors are silently dropped here.
                    if let Ok(Some(b)) = batch {
                        results.lock().unwrap().push(b);
                    }
                });
            }
        });

        // All tasks have joined (scope), so the Arc should be unique now.
        let final_results = Arc::try_unwrap(results)
            .map_err(|_| ExecutionError::Internal("Failed to unwrap results".to_string()))?
            .into_inner()
            .map_err(|_| ExecutionError::Internal("Failed to acquire lock".to_string()))?;

        Ok(final_results)
    }

    /// Execute a partition of the data
    ///
    /// NOTE(review): stub — ignores its arguments and returns no rows.
    fn execute_partition(
        &self,
        plan: &PhysicalPlan,
        partition_id: usize,
        num_partitions: usize,
    ) -> Result<Option<RowBatch>> {
        // Simplified partition execution
        Ok(None)
    }

    /// Staged parallel execution (for complex queries with pipeline breakers)
    ///
    /// Splits the operator chain at each pipeline-breaker index and runs
    /// the stages one after another, materializing between them.
    fn execute_parallel_staged(&self, plan: &PhysicalPlan) -> Result<Vec<RowBatch>> {
        let mut intermediate_results = Vec::new();

        // Execute each stage between pipeline breakers
        let mut start = 0;
        for &breaker in &plan.pipeline_breakers {
            let stage_results = self.execute_stage(plan, start, breaker)?;
            // NOTE(review): intermediate results are overwritten, not fed
            // into the next stage — revisit once stages are implemented.
            intermediate_results = stage_results;
            start = breaker + 1;
        }

        // Execute final stage
        let final_results = self.execute_stage(plan, start, plan.operators.len())?;
        Ok(final_results)
    }

    /// Execute a stage of operators
    ///
    /// NOTE(review): stub — ignores its arguments and returns no rows.
    fn execute_stage(
        &self,
        plan: &PhysicalPlan,
        start: usize,
        end: usize,
    ) -> Result<Vec<RowBatch>> {
        // Simplified stage execution
        Ok(Vec::new())
    }

    /// Parallel batch processing
    ///
    /// Runs `processor` over every batch on the dedicated pool and
    /// short-circuits on the first error when collecting.
    pub fn process_batches_parallel<F>(
        &self,
        batches: Vec<RowBatch>,
        processor: F,
    ) -> Result<Vec<RowBatch>>
    where
        F: Fn(RowBatch) -> Result<RowBatch> + Send + Sync,
    {
        let results: Vec<_> = self.thread_pool.install(|| {
            batches
                .into_par_iter()
                .map(|batch| processor(batch))
                .collect()
        });

        // Collect results and check for errors
        results.into_iter().collect()
    }

    /// Parallel aggregation
    ///
    /// Groups batches by `key_fn` (sequentially), then applies `agg_fn`
    /// to each group in parallel on the dedicated pool.
    pub fn aggregate_parallel<K, V, F, G>(
        &self,
        batches: Vec<RowBatch>,
        key_fn: F,
        agg_fn: G,
    ) -> Result<Vec<(K, V)>>
    where
        K: Send + Sync + Eq + std::hash::Hash,
        V: Send + Sync,
        F: Fn(&RowBatch) -> K + Send + Sync,
        G: Fn(Vec<RowBatch>) -> V + Send + Sync,
    {
        use std::collections::HashMap;

        // Group batches by key
        let mut groups: HashMap<K, Vec<RowBatch>> = HashMap::new();
        for batch in batches {
            let key = key_fn(&batch);
            groups.entry(key).or_insert_with(Vec::new).push(batch);
        }

        // Aggregate each group in parallel
        let results: Vec<_> = self.thread_pool.install(|| {
            groups
                .into_par_iter()
                .map(|(key, batches)| (key, agg_fn(batches)))
                .collect()
        });

        Ok(results)
    }

    /// Get number of worker threads
    pub fn num_threads(&self) -> usize {
        self.thread_pool.current_num_threads()
    }
}
|
||||
|
||||
/// Parallel scan partitioner
///
/// Splits `total_rows` rows into `num_partitions` contiguous,
/// near-equal `[start, end)` ranges for parallel scanning.
pub struct ScanPartitioner {
    total_rows: usize,
    num_partitions: usize,
}

impl ScanPartitioner {
    /// Create a new partitioner.
    ///
    /// Fix: `num_partitions` is clamped to at least 1 — a zero partition
    /// count previously caused a divide-by-zero panic in `partition_range`.
    pub fn new(total_rows: usize, num_partitions: usize) -> Self {
        Self {
            total_rows,
            num_partitions: num_partitions.max(1),
        }
    }

    /// Get the `[start, end)` row range for a given partition ID.
    ///
    /// Uses ceiling division so earlier partitions absorb the remainder;
    /// trailing partitions may be empty (`start == end`) when rows don't
    /// divide evenly.
    pub fn partition_range(&self, partition_id: usize) -> (usize, usize) {
        // Ceiling division: ceil(total_rows / num_partitions).
        let rows_per_partition = (self.total_rows + self.num_partitions - 1) / self.num_partitions;
        // Fix: clamp `start` too — on uneven splits the last partition's
        // raw start could exceed total_rows, yielding an inverted range.
        let start = (partition_id * rows_per_partition).min(self.total_rows);
        let end = (start + rows_per_partition).min(self.total_rows);
        (start, end)
    }

    /// Check if partition ID refers to an existing partition.
    pub fn is_valid_partition(&self, partition_id: usize) -> bool {
        partition_id < self.num_partitions
    }
}
|
||||
|
||||
/// Parallel join strategies
///
/// Selects how [`ParallelJoin`] distributes work across workers.
pub enum ParallelJoinStrategy {
    /// Broadcast small table to all workers
    Broadcast,
    /// Partition both tables by join key
    PartitionedHash,
    /// Sort-merge join with parallel sort
    SortMerge,
}
|
||||
|
||||
/// Parallel join executor
pub struct ParallelJoin {
    /// Strategy used by `execute` to dispatch the join.
    strategy: ParallelJoinStrategy,
    /// Executor supplying the worker pool (not yet used by the stubs below).
    executor: Arc<ParallelExecutor>,
}

impl ParallelJoin {
    /// Create new parallel join
    pub fn new(strategy: ParallelJoinStrategy, executor: Arc<ParallelExecutor>) -> Self {
        Self { strategy, executor }
    }

    /// Execute parallel join
    ///
    /// Dispatches to the strategy-specific implementation.
    pub fn execute(&self, left: Vec<RowBatch>, right: Vec<RowBatch>) -> Result<Vec<RowBatch>> {
        match self.strategy {
            ParallelJoinStrategy::Broadcast => self.broadcast_join(left, right),
            ParallelJoinStrategy::PartitionedHash => self.partitioned_hash_join(left, right),
            ParallelJoinStrategy::SortMerge => self.sort_merge_join(left, right),
        }
    }

    /// Broadcast join: replicate the smaller side to every worker.
    ///
    /// NOTE(review): stub — picks build/probe sides by batch *count*, not
    /// row count, and returns no rows.
    fn broadcast_join(&self, left: Vec<RowBatch>, right: Vec<RowBatch>) -> Result<Vec<RowBatch>> {
        // Broadcast smaller side to all workers
        let (build_side, probe_side) = if left.len() < right.len() {
            (left, right)
        } else {
            (right, left)
        };

        // Simplified implementation
        Ok(Vec::new())
    }

    /// Partitioned hash join: hash-partition both inputs by join key.
    ///
    /// NOTE(review): stub — returns no rows.
    fn partitioned_hash_join(
        &self,
        left: Vec<RowBatch>,
        right: Vec<RowBatch>,
    ) -> Result<Vec<RowBatch>> {
        // Partition both sides by join key
        // Each partition is processed independently
        Ok(Vec::new())
    }

    /// Sort-merge join: sort both sides in parallel, then merge.
    ///
    /// NOTE(review): stub — returns no rows.
    fn sort_merge_join(&self, left: Vec<RowBatch>, right: Vec<RowBatch>) -> Result<Vec<RowBatch>> {
        // Sort both sides in parallel, then merge
        Ok(Vec::new())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults enable parallelism with auto-detected threads;
    /// `sequential()` disables it.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::new();
        assert!(config.enabled);
        assert_eq!(config.num_threads, 0);

        let seq_config = ParallelConfig::sequential();
        assert!(!seq_config.enabled);
    }

    /// An explicit thread count is honored by the dedicated pool.
    #[test]
    fn test_parallel_executor_creation() {
        let config = ParallelConfig::with_threads(4);
        let executor = ParallelExecutor::new(config);
        assert_eq!(executor.num_threads(), 4);
    }

    /// An even split produces equal, contiguous ranges.
    #[test]
    fn test_scan_partitioner() {
        let partitioner = ScanPartitioner::new(100, 4);

        let (start, end) = partitioner.partition_range(0);
        assert_eq!(start, 0);
        assert_eq!(end, 25);

        let (start, end) = partitioner.partition_range(3);
        assert_eq!(start, 75);
        assert_eq!(end, 100);
    }

    /// IDs at or past the partition count are invalid.
    #[test]
    fn test_partition_validity() {
        let partitioner = ScanPartitioner::new(100, 4);
        assert!(partitioner.is_valid_partition(0));
        assert!(partitioner.is_valid_partition(3));
        assert!(!partitioner.is_valid_partition(4));
    }
}
|
||||
336
vendor/ruvector/crates/ruvector-graph/src/executor/pipeline.rs
vendored
Normal file
336
vendor/ruvector/crates/ruvector-graph/src/executor/pipeline.rs
vendored
Normal file
@@ -0,0 +1,336 @@
|
||||
//! Pipeline execution model with Volcano-style iterators
|
||||
//!
|
||||
//! Implements pull-based query execution with row batching
|
||||
|
||||
use crate::executor::operators::Operator;
|
||||
use crate::executor::plan::Value;
|
||||
use crate::executor::plan::{PhysicalPlan, QuerySchema};
|
||||
use crate::executor::{ExecutionError, Result};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Batch size for vectorized execution
|
||||
const DEFAULT_BATCH_SIZE: usize = 1024;
|
||||
|
||||
/// Row batch for vectorized processing
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RowBatch {
|
||||
pub rows: Vec<HashMap<String, Value>>,
|
||||
pub schema: QuerySchema,
|
||||
}
|
||||
|
||||
impl RowBatch {
|
||||
/// Create a new row batch
|
||||
pub fn new(schema: QuerySchema) -> Self {
|
||||
Self {
|
||||
rows: Vec::with_capacity(DEFAULT_BATCH_SIZE),
|
||||
schema,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create batch with rows
|
||||
pub fn with_rows(rows: Vec<HashMap<String, Value>>, schema: QuerySchema) -> Self {
|
||||
Self { rows, schema }
|
||||
}
|
||||
|
||||
/// Add a row to the batch
|
||||
pub fn add_row(&mut self, row: HashMap<String, Value>) {
|
||||
self.rows.push(row);
|
||||
}
|
||||
|
||||
/// Check if batch is full
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.rows.len() >= DEFAULT_BATCH_SIZE
|
||||
}
|
||||
|
||||
/// Get number of rows
|
||||
pub fn len(&self) -> usize {
|
||||
self.rows.len()
|
||||
}
|
||||
|
||||
/// Check if batch is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.rows.is_empty()
|
||||
}
|
||||
|
||||
/// Clear the batch
|
||||
pub fn clear(&mut self) {
|
||||
self.rows.clear();
|
||||
}
|
||||
|
||||
/// Merge another batch into this one
|
||||
pub fn merge(&mut self, other: RowBatch) {
|
||||
self.rows.extend(other.rows);
|
||||
}
|
||||
}
|
||||
|
||||
/// Execution context for query pipeline
|
||||
pub struct ExecutionContext {
|
||||
/// Memory limit for execution
|
||||
pub memory_limit: usize,
|
||||
/// Current memory usage
|
||||
pub memory_used: usize,
|
||||
/// Batch size
|
||||
pub batch_size: usize,
|
||||
/// Enable query profiling
|
||||
pub enable_profiling: bool,
|
||||
}
|
||||
|
||||
impl ExecutionContext {
|
||||
/// Create new execution context
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
memory_limit: 1024 * 1024 * 1024, // 1GB default
|
||||
memory_used: 0,
|
||||
batch_size: DEFAULT_BATCH_SIZE,
|
||||
enable_profiling: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom memory limit
|
||||
pub fn with_memory_limit(memory_limit: usize) -> Self {
|
||||
Self {
|
||||
memory_limit,
|
||||
memory_used: 0,
|
||||
batch_size: DEFAULT_BATCH_SIZE,
|
||||
enable_profiling: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if memory limit exceeded
|
||||
pub fn check_memory(&self) -> Result<()> {
|
||||
if self.memory_used > self.memory_limit {
|
||||
Err(ExecutionError::ResourceExhausted(format!(
|
||||
"Memory limit exceeded: {} > {}",
|
||||
self.memory_used, self.memory_limit
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate memory
|
||||
pub fn allocate(&mut self, bytes: usize) -> Result<()> {
|
||||
self.memory_used += bytes;
|
||||
self.check_memory()
|
||||
}
|
||||
|
||||
/// Free memory
|
||||
pub fn free(&mut self, bytes: usize) {
|
||||
self.memory_used = self.memory_used.saturating_sub(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ExecutionContext {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Pipeline executor using Volcano iterator model
pub struct Pipeline {
    /// Remaining plan metadata (its operators were moved into `operators`).
    plan: PhysicalPlan,
    /// Operator chain in execution order; index 0 is the source (scan).
    operators: Vec<Box<dyn Operator>>,
    /// Operator cursor (currently unused by `execute_pipeline`).
    current_operator: usize,
    context: ExecutionContext,
    /// Set once any stage reports end-of-stream.
    finished: bool,
}

impl Pipeline {
    /// Create a new pipeline from physical plan (takes ownership of operators)
    pub fn new(mut plan: PhysicalPlan) -> Self {
        // Move the operators out of the plan so the pipeline can call
        // `execute(&mut ...)` on them without borrowing `plan` mutably.
        let operators = std::mem::take(&mut plan.operators);
        Self {
            operators,
            plan,
            current_operator: 0,
            context: ExecutionContext::new(),
            finished: false,
        }
    }

    /// Create pipeline with custom context (takes ownership of operators)
    pub fn with_context(mut plan: PhysicalPlan, context: ExecutionContext) -> Self {
        let operators = std::mem::take(&mut plan.operators);
        Self {
            operators,
            plan,
            current_operator: 0,
            context,
            finished: false,
        }
    }

    /// Get next batch from pipeline
    ///
    /// Returns `Ok(None)` permanently once any stage has reported
    /// end-of-stream.
    pub fn next(&mut self) -> Result<Option<RowBatch>> {
        if self.finished {
            return Ok(None);
        }

        // Execute pipeline in pull-based fashion
        let result = self.execute_pipeline()?;

        if result.is_none() {
            self.finished = true;
        }

        Ok(result)
    }

    /// Execute the full pipeline
    ///
    /// Pulls one batch from the source operator and pushes it through
    /// every downstream operator in order.
    fn execute_pipeline(&mut self) -> Result<Option<RowBatch>> {
        if self.operators.is_empty() {
            return Ok(None);
        }

        // Start with the first operator (scan)
        let mut current_batch = self.operators[0].execute(None)?;

        // Pipeline the batch through remaining operators
        for operator in &mut self.operators[1..] {
            if let Some(batch) = current_batch {
                current_batch = operator.execute(Some(batch))?;
            } else {
                // Any stage producing no batch ends this pull.
                return Ok(None);
            }
        }

        Ok(current_batch)
    }

    /// Reset pipeline for re-execution
    ///
    /// NOTE(review): this resets only the pipeline's cursor/flags and
    /// context — operators keep their internal state (e.g. a scan's
    /// position), so re-execution is only correct for stateless operators.
    pub fn reset(&mut self) {
        self.current_operator = 0;
        self.finished = false;
        self.context = ExecutionContext::new();
    }

    /// Get execution context
    pub fn context(&self) -> &ExecutionContext {
        &self.context
    }

    /// Get mutable execution context
    pub fn context_mut(&mut self) -> &mut ExecutionContext {
        &mut self.context
    }
}
|
||||
|
||||
/// Pipeline builder for constructing execution pipelines
|
||||
pub struct PipelineBuilder {
|
||||
operators: Vec<Box<dyn Operator>>,
|
||||
context: ExecutionContext,
|
||||
}
|
||||
|
||||
impl PipelineBuilder {
|
||||
/// Create a new pipeline builder
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
operators: Vec::new(),
|
||||
context: ExecutionContext::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an operator to the pipeline
|
||||
pub fn add_operator(mut self, operator: Box<dyn Operator>) -> Self {
|
||||
self.operators.push(operator);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set execution context
|
||||
pub fn with_context(mut self, context: ExecutionContext) -> Self {
|
||||
self.context = context;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the pipeline
|
||||
pub fn build(self) -> Pipeline {
|
||||
let plan = PhysicalPlan {
|
||||
operators: self.operators,
|
||||
pipeline_breakers: Vec::new(),
|
||||
parallelism: 1,
|
||||
};
|
||||
|
||||
Pipeline::with_context(plan, self.context)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PipelineBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator adapter for pipeline
|
||||
pub struct PipelineIterator {
|
||||
pipeline: Pipeline,
|
||||
}
|
||||
|
||||
impl PipelineIterator {
|
||||
pub fn new(pipeline: Pipeline) -> Self {
|
||||
Self { pipeline }
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for PipelineIterator {
|
||||
type Item = Result<RowBatch>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self.pipeline.next() {
|
||||
Ok(Some(batch)) => Some(Ok(batch)),
|
||||
Ok(None) => None,
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::plan::ColumnDef;
    use crate::executor::plan::DataType;

    /// Adding a row grows the batch and flips `is_empty`.
    #[test]
    fn test_row_batch() {
        let schema = QuerySchema::new(vec![ColumnDef {
            name: "id".to_string(),
            data_type: DataType::Int64,
            nullable: false,
        }]);

        let mut batch = RowBatch::new(schema);
        assert!(batch.is_empty());

        let mut row = HashMap::new();
        row.insert("id".to_string(), Value::Int64(1));
        batch.add_row(row);

        assert_eq!(batch.len(), 1);
        assert!(!batch.is_empty());
    }

    /// allocate/free adjust the usage counter symmetrically.
    #[test]
    fn test_execution_context() {
        let mut ctx = ExecutionContext::new();
        assert_eq!(ctx.memory_used, 0);

        ctx.allocate(1024).unwrap();
        assert_eq!(ctx.memory_used, 1024);

        ctx.free(512);
        assert_eq!(ctx.memory_used, 512);
    }

    /// Allocations past the configured limit are rejected.
    #[test]
    fn test_memory_limit() {
        let mut ctx = ExecutionContext::with_memory_limit(1000);
        assert!(ctx.allocate(500).is_ok());
        assert!(ctx.allocate(600).is_err());
    }

    /// A fresh builder yields an empty operator chain.
    #[test]
    fn test_pipeline_builder() {
        let builder = PipelineBuilder::new();
        let pipeline = builder.build();
        assert_eq!(pipeline.operators.len(), 0);
    }
}
|
||||
391
vendor/ruvector/crates/ruvector-graph/src/executor/plan.rs
vendored
Normal file
391
vendor/ruvector/crates/ruvector-graph/src/executor/plan.rs
vendored
Normal file
@@ -0,0 +1,391 @@
|
||||
//! Query execution plan representation
|
||||
//!
|
||||
//! Provides logical and physical query plan structures for graph queries
|
||||
|
||||
use crate::executor::operators::{AggregateFunction, JoinType, Operator, ScanMode};
|
||||
use crate::executor::stats::Statistics;
|
||||
use crate::executor::{ExecutionError, Result};
|
||||
use ordered_float::OrderedFloat;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Logical query plan (high-level, optimizer input)
#[derive(Debug, Clone)]
pub struct LogicalPlan {
    /// Root of the plan-node tree; child nodes are nested inside each variant.
    pub root: PlanNode,
    /// Output schema the plan is expected to produce.
    pub schema: QuerySchema,
}
|
||||
|
||||
impl LogicalPlan {
|
||||
/// Create a new logical plan
|
||||
pub fn new(root: PlanNode, schema: QuerySchema) -> Self {
|
||||
Self { root, schema }
|
||||
}
|
||||
|
||||
/// Generate cache key for this plan
|
||||
pub fn cache_key(&self) -> String {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
format!("{:?}", self).hash(&mut hasher);
|
||||
format!("plan_{:x}", hasher.finish())
|
||||
}
|
||||
|
||||
/// Check if plan can be parallelized
|
||||
pub fn is_parallelizable(&self) -> bool {
|
||||
self.root.is_parallelizable()
|
||||
}
|
||||
|
||||
/// Estimate output cardinality
|
||||
pub fn estimate_cardinality(&self) -> usize {
|
||||
self.root.estimate_cardinality()
|
||||
}
|
||||
}
|
||||
|
||||
/// Physical query plan (low-level, executor input)
pub struct PhysicalPlan {
    /// Executable operators, appended in post-order during compilation
    /// (children before their parent operator).
    pub operators: Vec<Box<dyn Operator>>,
    /// Indices into `operators` recorded at pipeline-breaking points
    /// (captured before the blocking operator is appended).
    pub pipeline_breakers: Vec<usize>,
    /// Worker count: `num_cpus::get()` when the logical plan is
    /// parallelizable, otherwise 1.
    pub parallelism: usize,
}
|
||||
|
||||
impl fmt::Debug for PhysicalPlan {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("PhysicalPlan")
|
||||
.field("operator_count", &self.operators.len())
|
||||
.field("pipeline_breakers", &self.pipeline_breakers)
|
||||
.field("parallelism", &self.parallelism)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PhysicalPlan {
    /// Create physical plan from logical plan
    ///
    /// Walks the logical tree in post-order, appending one executable
    /// operator per node, and records pipeline-breaker indices around
    /// blocking operators (join build side, aggregate, sort).
    pub fn from_logical(logical: &LogicalPlan, stats: &Statistics) -> Result<Self> {
        let mut operators = Vec::new();
        let mut pipeline_breakers = Vec::new();

        Self::compile_node(&logical.root, stats, &mut operators, &mut pipeline_breakers)?;

        // One worker per CPU only when every node in the tree allows it.
        let parallelism = if logical.is_parallelizable() {
            num_cpus::get()
        } else {
            1
        };

        Ok(Self {
            operators,
            pipeline_breakers,
            parallelism,
        })
    }

    /// Recursively compile `node` (children first) into `operators`.
    ///
    /// For blocking nodes, the current `operators.len()` is pushed to
    /// `pipeline_breakers` *before* the blocking operator (or the join's
    /// probe side) is appended, marking where the pipeline must materialize.
    /// `stats` is threaded through but not read here yet — presumably
    /// reserved for cost-based operator selection; confirm before removing.
    fn compile_node(
        node: &PlanNode,
        stats: &Statistics,
        operators: &mut Vec<Box<dyn Operator>>,
        pipeline_breakers: &mut Vec<usize>,
    ) -> Result<()> {
        match node {
            // Leaf scans simply append their operator; no recursion.
            PlanNode::NodeScan { mode, filter } => {
                // Add scan operator
                operators.push(Box::new(crate::executor::operators::NodeScan::new(
                    mode.clone(),
                    filter.clone(),
                )));
            }
            PlanNode::EdgeScan { mode, filter } => {
                operators.push(Box::new(crate::executor::operators::EdgeScan::new(
                    mode.clone(),
                    filter.clone(),
                )));
            }
            // Streaming operator: compiled input followed by the filter.
            PlanNode::Filter { input, predicate } => {
                Self::compile_node(input, stats, operators, pipeline_breakers)?;
                operators.push(Box::new(crate::executor::operators::Filter::new(
                    predicate.clone(),
                )));
            }
            PlanNode::Join {
                left,
                right,
                join_type,
                on,
            } => {
                // Left (build) side first; the breaker index recorded here
                // points at the first operator of the right (probe) side.
                Self::compile_node(left, stats, operators, pipeline_breakers)?;
                pipeline_breakers.push(operators.len());
                Self::compile_node(right, stats, operators, pipeline_breakers)?;
                operators.push(Box::new(crate::executor::operators::Join::new(
                    *join_type,
                    on.clone(),
                )));
            }
            PlanNode::Aggregate {
                input,
                group_by,
                aggregates,
            } => {
                Self::compile_node(input, stats, operators, pipeline_breakers)?;
                // Aggregation consumes its whole input: breaker index is the
                // position the Aggregate operator itself will occupy.
                pipeline_breakers.push(operators.len());
                operators.push(Box::new(crate::executor::operators::Aggregate::new(
                    group_by.clone(),
                    aggregates.clone(),
                )));
            }
            PlanNode::Sort { input, order_by } => {
                Self::compile_node(input, stats, operators, pipeline_breakers)?;
                // Sort is blocking for the same reason as Aggregate.
                pipeline_breakers.push(operators.len());
                operators.push(Box::new(crate::executor::operators::Sort::new(
                    order_by.clone(),
                )));
            }
            PlanNode::Limit {
                input,
                limit,
                offset,
            } => {
                Self::compile_node(input, stats, operators, pipeline_breakers)?;
                operators.push(Box::new(crate::executor::operators::Limit::new(
                    *limit, *offset,
                )));
            }
            PlanNode::Project { input, columns } => {
                Self::compile_node(input, stats, operators, pipeline_breakers)?;
                operators.push(Box::new(crate::executor::operators::Project::new(
                    columns.clone(),
                )));
            }
            PlanNode::HyperedgeScan { mode, filter } => {
                operators.push(Box::new(crate::executor::operators::HyperedgeScan::new(
                    mode.clone(),
                    filter.clone(),
                )));
            }
        }
        Ok(())
    }
}
|
||||
|
||||
/// Plan node types
///
/// Recursive logical-plan tree: unary/binary nodes own their children via
/// `Box<PlanNode>`, and the three scan variants are the leaves.
#[derive(Debug, Clone)]
pub enum PlanNode {
    /// Sequential or index-based node scan
    NodeScan {
        mode: ScanMode,
        filter: Option<Predicate>,
    },
    /// Edge scan
    EdgeScan {
        mode: ScanMode,
        filter: Option<Predicate>,
    },
    /// Hyperedge scan
    HyperedgeScan {
        mode: ScanMode,
        filter: Option<Predicate>,
    },
    /// Filter rows by predicate
    Filter {
        input: Box<PlanNode>,
        predicate: Predicate,
    },
    /// Join two inputs
    Join {
        left: Box<PlanNode>,
        right: Box<PlanNode>,
        join_type: JoinType,
        // Presumably (left column, right column) equi-join key pairs —
        // confirm against operators::Join.
        on: Vec<(String, String)>,
    },
    /// Aggregate with grouping
    Aggregate {
        input: Box<PlanNode>,
        group_by: Vec<String>,
        // (aggregate function, target column) pairs computed per group.
        aggregates: Vec<(AggregateFunction, String)>,
    },
    /// Sort results
    Sort {
        input: Box<PlanNode>,
        order_by: Vec<(String, SortOrder)>,
    },
    /// Limit and offset
    Limit {
        input: Box<PlanNode>,
        limit: usize,
        offset: usize,
    },
    /// Project columns
    Project {
        input: Box<PlanNode>,
        columns: Vec<String>,
    },
}
|
||||
|
||||
impl PlanNode {
|
||||
/// Check if node can be parallelized
|
||||
pub fn is_parallelizable(&self) -> bool {
|
||||
match self {
|
||||
PlanNode::NodeScan { .. } => true,
|
||||
PlanNode::EdgeScan { .. } => true,
|
||||
PlanNode::HyperedgeScan { .. } => true,
|
||||
PlanNode::Filter { input, .. } => input.is_parallelizable(),
|
||||
PlanNode::Join { .. } => true,
|
||||
PlanNode::Aggregate { .. } => true,
|
||||
PlanNode::Sort { .. } => true,
|
||||
PlanNode::Limit { .. } => false,
|
||||
PlanNode::Project { input, .. } => input.is_parallelizable(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate output cardinality
|
||||
pub fn estimate_cardinality(&self) -> usize {
|
||||
match self {
|
||||
PlanNode::NodeScan { .. } => 1000, // Placeholder
|
||||
PlanNode::EdgeScan { .. } => 5000,
|
||||
PlanNode::HyperedgeScan { .. } => 500,
|
||||
PlanNode::Filter { input, .. } => input.estimate_cardinality() / 10,
|
||||
PlanNode::Join { left, right, .. } => {
|
||||
left.estimate_cardinality() * right.estimate_cardinality() / 100
|
||||
}
|
||||
PlanNode::Aggregate { input, .. } => input.estimate_cardinality() / 20,
|
||||
PlanNode::Sort { input, .. } => input.estimate_cardinality(),
|
||||
PlanNode::Limit { limit, .. } => *limit,
|
||||
PlanNode::Project { input, .. } => input.estimate_cardinality(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query schema definition
///
/// Ordered list of the columns a query or operator produces.
#[derive(Debug, Clone)]
pub struct QuerySchema {
    /// Output columns, in positional order.
    pub columns: Vec<ColumnDef>,
}

impl QuerySchema {
    /// Build a schema from an ordered list of column definitions.
    pub fn new(columns: Vec<ColumnDef>) -> Self {
        Self { columns }
    }
}
|
||||
|
||||
/// Column definition
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name as referenced by predicates and projections.
    pub name: String,
    /// Runtime type of the column's values.
    pub data_type: DataType,
    /// Whether the column may contain nulls.
    pub nullable: bool,
}

/// Data types supported in query execution
#[derive(Debug, Clone, PartialEq)]
pub enum DataType {
    Int64,
    Float64,
    String,
    Boolean,
    Bytes,
    /// Homogeneous list of the boxed element type.
    List(Box<DataType>),
}
|
||||
|
||||
/// Sort order
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SortOrder {
    Ascending,
    Descending,
}

/// Query predicate for filtering
///
/// Recursive boolean expression tree over (column, value) comparisons;
/// `And`/`Or` take arbitrary-length operand lists, `Not` negates a single
/// sub-predicate.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// column = value
    Equals(String, Value),
    /// column != value
    NotEquals(String, Value),
    /// column > value
    GreaterThan(String, Value),
    /// column >= value
    GreaterThanOrEqual(String, Value),
    /// column < value
    LessThan(String, Value),
    /// column <= value
    LessThanOrEqual(String, Value),
    /// column IN (values)
    In(String, Vec<Value>),
    /// column LIKE pattern
    Like(String, String),
    /// AND predicates
    And(Vec<Predicate>),
    /// OR predicates
    Or(Vec<Predicate>),
    /// NOT predicate
    Not(Box<Predicate>),
}
|
||||
|
||||
/// Runtime value
///
/// Dynamically-typed cell value flowing through the executor.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    Int64(i64),
    Float64(f64),
    String(String),
    Boolean(bool),
    Bytes(Vec<u8>),
    Null,
}

// NOTE(review): `Eq` is asserted even though `Float64` wraps an `f64`; a NaN
// payload would violate `Eq`'s reflexivity contract (NaN != NaN). Presumably
// NaN never enters the engine — confirm, or reject NaN at ingestion.
impl Eq for Value {}
|
||||
|
||||
impl Hash for Value {
    /// Hash the variant discriminant followed by the payload.
    ///
    /// The discriminant is mixed in first so that equal payload bytes in
    /// different variants do not collide. `Float64` is routed through
    /// `OrderedFloat`, which supplies a total `Hash` for `f64`. `Null`
    /// contributes only its discriminant.
    fn hash<H: Hasher>(&self, state: &mut H) {
        std::mem::discriminant(self).hash(state);
        match self {
            Value::Int64(v) => v.hash(state),
            Value::Float64(v) => OrderedFloat(*v).hash(state),
            Value::String(v) => v.hash(state),
            Value::Boolean(v) => v.hash(state),
            Value::Bytes(v) => v.hash(state),
            Value::Null => {}
        }
    }
}
|
||||
|
||||
impl Value {
|
||||
/// Compare values for predicate evaluation
|
||||
pub fn compare(&self, other: &Value) -> Option<std::cmp::Ordering> {
|
||||
match (self, other) {
|
||||
(Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
|
||||
(Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
|
||||
(Value::String(a), Value::String(b)) => Some(a.cmp(b)),
|
||||
(Value::Boolean(a), Value::Boolean(b)) => Some(a.cmp(b)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A single NodeScan plan is parallelizable because scans are.
    #[test]
    fn test_logical_plan_creation() {
        let schema = QuerySchema::new(vec![ColumnDef {
            name: "id".to_string(),
            data_type: DataType::Int64,
            nullable: false,
        }]);

        let plan = LogicalPlan::new(
            PlanNode::NodeScan {
                mode: ScanMode::Sequential,
                filter: None,
            },
            schema,
        );

        assert!(plan.is_parallelizable());
    }

    // Same-variant comparison yields a concrete ordering.
    #[test]
    fn test_value_comparison() {
        let v1 = Value::Int64(42);
        let v2 = Value::Int64(100);
        assert_eq!(v1.compare(&v2), Some(std::cmp::Ordering::Less));
    }
}
|
||||
400
vendor/ruvector/crates/ruvector-graph/src/executor/stats.rs
vendored
Normal file
400
vendor/ruvector/crates/ruvector-graph/src/executor/stats.rs
vendored
Normal file
@@ -0,0 +1,400 @@
|
||||
//! Statistics collection for cost-based query optimization
|
||||
//!
|
||||
//! Maintains table and column statistics for query planning
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
/// Statistics manager for query optimization
///
/// Interior-mutable (`RwLock`-guarded) maps so a shared `&Statistics` can
/// be updated by collection paths while planners read concurrently.
pub struct Statistics {
    /// Table-level statistics
    tables: RwLock<HashMap<String, TableStats>>,
    /// Column-level statistics
    columns: RwLock<HashMap<String, ColumnStats>>,
}
|
||||
|
||||
impl Statistics {
    /// Create a new statistics manager
    pub fn new() -> Self {
        Self {
            tables: RwLock::new(HashMap::new()),
            columns: RwLock::new(HashMap::new()),
        }
    }

    /// Update table statistics
    ///
    /// Replaces any previous entry for `table_name`. Panics only if the
    /// lock is poisoned (a writer panicked while holding it).
    pub fn update_table_stats(&self, table_name: String, stats: TableStats) {
        self.tables.write().unwrap().insert(table_name, stats);
    }

    /// Get table statistics
    ///
    /// Returns a clone so the read lock is released before the caller
    /// inspects the stats.
    pub fn get_table_stats(&self, table_name: &str) -> Option<TableStats> {
        self.tables.read().unwrap().get(table_name).cloned()
    }

    /// Update column statistics
    ///
    /// `column_key` is an opaque lookup key; the tests in this file use
    /// "table.column" strings.
    pub fn update_column_stats(&self, column_key: String, stats: ColumnStats) {
        self.columns.write().unwrap().insert(column_key, stats);
    }

    /// Get column statistics
    pub fn get_column_stats(&self, column_key: &str) -> Option<ColumnStats> {
        self.columns.read().unwrap().get(column_key).cloned()
    }

    /// Check if statistics are empty
    pub fn is_empty(&self) -> bool {
        self.tables.read().unwrap().is_empty() && self.columns.read().unwrap().is_empty()
    }

    /// Clear all statistics
    pub fn clear(&self) {
        self.tables.write().unwrap().clear();
        self.columns.write().unwrap().clear();
    }

    /// Estimate join selectivity
    ///
    /// Returns 1 / max(left rows, right rows) when both tables have stats,
    /// otherwise a flat default of 0.1.
    ///
    /// NOTE(review): `join_column` is currently unused, and the `*_ndv`
    /// locals actually hold row counts, not distinct-value counts — the
    /// estimate implicitly assumes a unique join key. Confirm before
    /// relying on this for plan costing.
    pub fn estimate_join_selectivity(
        &self,
        left_table: &str,
        right_table: &str,
        join_column: &str,
    ) -> f64 {
        let left_stats = self.get_table_stats(left_table);
        let right_stats = self.get_table_stats(right_table);

        if let (Some(left), Some(right)) = (left_stats, right_stats) {
            // Simple selectivity estimate based on cardinalities
            let left_ndv = left.row_count as f64;
            let right_ndv = right.row_count as f64;

            if left_ndv > 0.0 && right_ndv > 0.0 {
                1.0 / left_ndv.max(right_ndv)
            } else {
                0.1 // Default selectivity
            }
        } else {
            0.1 // Default selectivity when stats not available
        }
    }

    /// Estimate filter selectivity
    ///
    /// Maps a comparison operator to a heuristic fraction of rows kept.
    /// Equality uses 1/NDV, guarded against NDV == 0 via `max(1)`; falls
    /// back to 0.1 when no column stats exist for `column_key`.
    pub fn estimate_filter_selectivity(&self, column_key: &str, operator: &str) -> f64 {
        if let Some(stats) = self.get_column_stats(column_key) {
            match operator {
                "=" => 1.0 / stats.ndv.max(1) as f64,
                ">" | "<" => 0.33,
                ">=" | "<=" => 0.33,
                "!=" => 1.0 - (1.0 / stats.ndv.max(1) as f64),
                "LIKE" => 0.1,
                "IN" => 0.2,
                _ => 0.1,
            }
        } else {
            0.1 // Default selectivity
        }
    }
}
|
||||
|
||||
impl Default for Statistics {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Table-level statistics
///
/// Snapshot of per-table cardinality and size used by the cost model.
#[derive(Debug, Clone)]
pub struct TableStats {
    /// Total number of rows
    pub row_count: usize,
    /// Average row size in bytes
    pub avg_row_size: usize,
    /// Total table size in bytes (row_count x avg_row_size)
    pub total_size: usize,
    /// Number of distinct values (for single-column tables)
    pub ndv: usize,
    /// Last update timestamp
    pub last_updated: std::time::SystemTime,
}
|
||||
|
||||
impl TableStats {
|
||||
/// Create new table statistics
|
||||
pub fn new(row_count: usize, avg_row_size: usize) -> Self {
|
||||
Self {
|
||||
row_count,
|
||||
avg_row_size,
|
||||
total_size: row_count * avg_row_size,
|
||||
ndv: row_count, // Conservative estimate
|
||||
last_updated: std::time::SystemTime::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update row count
|
||||
pub fn update_row_count(&mut self, row_count: usize) {
|
||||
self.row_count = row_count;
|
||||
self.total_size = row_count * self.avg_row_size;
|
||||
self.last_updated = std::time::SystemTime::now();
|
||||
}
|
||||
|
||||
/// Estimate scan cost (relative units)
|
||||
pub fn estimate_scan_cost(&self) -> f64 {
|
||||
self.row_count as f64 * 0.001 // Simplified cost model
|
||||
}
|
||||
}
|
||||
|
||||
/// Column-level statistics
///
/// Distribution summary for one column: distinct/null counts, optional
/// value range, optional histogram, and the most-common-value list.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Number of distinct values
    pub ndv: usize,
    /// Number of null values
    pub null_count: usize,
    /// Minimum value (for ordered types)
    pub min_value: Option<ColumnValue>,
    /// Maximum value (for ordered types)
    pub max_value: Option<ColumnValue>,
    /// Histogram for distribution
    pub histogram: Option<Histogram>,
    /// Most common values and their frequencies
    pub mcv: Vec<(ColumnValue, usize)>,
}
|
||||
|
||||
impl ColumnStats {
|
||||
/// Create new column statistics
|
||||
pub fn new(ndv: usize, null_count: usize) -> Self {
|
||||
Self {
|
||||
ndv,
|
||||
null_count,
|
||||
min_value: None,
|
||||
max_value: None,
|
||||
histogram: None,
|
||||
mcv: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set min/max values
|
||||
pub fn with_range(mut self, min: ColumnValue, max: ColumnValue) -> Self {
|
||||
self.min_value = Some(min);
|
||||
self.max_value = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set histogram
|
||||
pub fn with_histogram(mut self, histogram: Histogram) -> Self {
|
||||
self.histogram = Some(histogram);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set most common values
|
||||
pub fn with_mcv(mut self, mcv: Vec<(ColumnValue, usize)>) -> Self {
|
||||
self.mcv = mcv;
|
||||
self
|
||||
}
|
||||
|
||||
/// Estimate selectivity for equality predicate
|
||||
pub fn estimate_equality_selectivity(&self, value: &ColumnValue) -> f64 {
|
||||
// Check if value is in MCV
|
||||
for (mcv_val, freq) in &self.mcv {
|
||||
if mcv_val == value {
|
||||
return *freq as f64 / self.ndv as f64;
|
||||
}
|
||||
}
|
||||
|
||||
// Default: uniform distribution assumption
|
||||
if self.ndv > 0 {
|
||||
1.0 / self.ndv as f64
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate selectivity for range predicate
|
||||
pub fn estimate_range_selectivity(&self, start: &ColumnValue, end: &ColumnValue) -> f64 {
|
||||
if let Some(histogram) = &self.histogram {
|
||||
histogram.estimate_range_selectivity(start, end)
|
||||
} else {
|
||||
0.33 // Default for range queries
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Column value for statistics
///
/// Subset of runtime values usable as histogram bounds and MCV entries.
/// The derived `PartialOrd` yields no ordering across different variants
/// (mixed-variant comparisons are simply "not less and not greater").
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum ColumnValue {
    Int64(i64),
    Float64(f64),
    String(String),
    Boolean(bool),
}
|
||||
|
||||
/// Histogram for data distribution
///
/// Bucketized value-distribution summary; `total_count` is the number of
/// sampled values the buckets were built from.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Histogram buckets
    pub buckets: Vec<HistogramBucket>,
    /// Total number of values
    pub total_count: usize,
}
|
||||
|
||||
impl Histogram {
|
||||
/// Create new histogram
|
||||
pub fn new(buckets: Vec<HistogramBucket>, total_count: usize) -> Self {
|
||||
Self {
|
||||
buckets,
|
||||
total_count,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create equi-width histogram
|
||||
pub fn equi_width(min: f64, max: f64, num_buckets: usize, values: &[f64]) -> Self {
|
||||
let width = (max - min) / num_buckets as f64;
|
||||
let mut buckets = Vec::with_capacity(num_buckets);
|
||||
|
||||
for i in 0..num_buckets {
|
||||
let lower = min + i as f64 * width;
|
||||
let upper = if i == num_buckets - 1 {
|
||||
max
|
||||
} else {
|
||||
min + (i + 1) as f64 * width
|
||||
};
|
||||
|
||||
let count = values.iter().filter(|&&v| v >= lower && v < upper).count();
|
||||
// Estimate NDV by counting unique values (using BTreeSet to avoid Hash requirement)
|
||||
let ndv = values
|
||||
.iter()
|
||||
.filter(|&&v| v >= lower && v < upper)
|
||||
.map(|&v| ordered_float::OrderedFloat(v))
|
||||
.collect::<std::collections::BTreeSet<_>>()
|
||||
.len();
|
||||
|
||||
buckets.push(HistogramBucket {
|
||||
lower_bound: ColumnValue::Float64(lower),
|
||||
upper_bound: ColumnValue::Float64(upper),
|
||||
count,
|
||||
ndv,
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
buckets,
|
||||
total_count: values.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate selectivity for range query
|
||||
pub fn estimate_range_selectivity(&self, start: &ColumnValue, end: &ColumnValue) -> f64 {
|
||||
if self.total_count == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut matching_count = 0;
|
||||
for bucket in &self.buckets {
|
||||
if bucket.overlaps(start, end) {
|
||||
matching_count += bucket.count;
|
||||
}
|
||||
}
|
||||
|
||||
matching_count as f64 / self.total_count as f64
|
||||
}
|
||||
|
||||
/// Get number of buckets
|
||||
pub fn num_buckets(&self) -> usize {
|
||||
self.buckets.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Histogram bucket
///
/// One interval of an equi-width histogram, with its row and
/// distinct-value counts.
#[derive(Debug, Clone)]
pub struct HistogramBucket {
    /// Lower bound (inclusive)
    pub lower_bound: ColumnValue,
    /// Upper bound (exclusive, except for last bucket)
    pub upper_bound: ColumnValue,
    /// Number of values in bucket
    pub count: usize,
    /// Number of distinct values in bucket
    pub ndv: usize,
}
|
||||
|
||||
impl HistogramBucket {
    /// Check if bucket overlaps with range
    ///
    /// Treats both bucket bounds as inclusive, so a range that merely
    /// touches this bucket's (normally exclusive) upper bound still counts
    /// as an overlap — a deliberate over-approximation. Mixed-variant
    /// bound comparisons have no ordering (derived `PartialOrd`), which
    /// makes `<=`/`>=` false and reports no overlap.
    pub fn overlaps(&self, start: &ColumnValue, end: &ColumnValue) -> bool {
        // Simplified overlap check
        self.lower_bound <= *end && self.upper_bound >= *start
    }

    /// Get bucket width (for numeric types)
    ///
    /// Returns `None` when the bounds are non-numeric or of differing
    /// variants.
    pub fn width(&self) -> Option<f64> {
        match (&self.lower_bound, &self.upper_bound) {
            (ColumnValue::Float64(lower), ColumnValue::Float64(upper)) => Some(upper - lower),
            (ColumnValue::Int64(lower), ColumnValue::Int64(upper)) => Some((upper - lower) as f64),
            _ => None,
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_statistics_creation() {
        let stats = Statistics::new();
        assert!(stats.is_empty());
    }

    // Round-trip a TableStats entry through the manager.
    #[test]
    fn test_table_stats() {
        let stats = Statistics::new();
        let table_stats = TableStats::new(1000, 128);

        stats.update_table_stats("nodes".to_string(), table_stats.clone());

        let retrieved = stats.get_table_stats("nodes");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().row_count, 1000);
    }

    // Round-trip a ColumnStats entry keyed by a "table.column" string.
    #[test]
    fn test_column_stats() {
        let stats = Statistics::new();
        let col_stats = ColumnStats::new(500, 10);

        stats.update_column_stats("nodes.id".to_string(), col_stats);

        let retrieved = stats.get_column_stats("nodes.id");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().ndv, 500);
    }

    #[test]
    fn test_histogram_creation() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let histogram = Histogram::equi_width(1.0, 10.0, 5, &values);

        assert_eq!(histogram.num_buckets(), 5);
        assert_eq!(histogram.total_count, 10);
    }

    // Join selectivity with stats on only one side still yields a sane
    // fraction (the 0.1 default path).
    #[test]
    fn test_selectivity_estimation() {
        let stats = Statistics::new();
        let table_stats = TableStats::new(1000, 128);

        stats.update_table_stats("nodes".to_string(), table_stats);

        let selectivity = stats.estimate_join_selectivity("nodes", "edges", "id");
        assert!(selectivity > 0.0 && selectivity <= 1.0);
    }

    // Equality selectivity is 1/NDV when column stats exist.
    #[test]
    fn test_filter_selectivity() {
        let stats = Statistics::new();
        let col_stats = ColumnStats::new(100, 5);

        stats.update_column_stats("nodes.age".to_string(), col_stats);

        let selectivity = stats.estimate_filter_selectivity("nodes.age", "=");
        assert_eq!(selectivity, 0.01); // 1/100
    }
}
|
||||
414
vendor/ruvector/crates/ruvector-graph/src/graph.rs
vendored
Normal file
414
vendor/ruvector/crates/ruvector-graph/src/graph.rs
vendored
Normal file
@@ -0,0 +1,414 @@
|
||||
//! Graph database implementation with concurrent access and indexing
|
||||
|
||||
use crate::edge::Edge;
|
||||
use crate::error::Result;
|
||||
use crate::hyperedge::{Hyperedge, HyperedgeId};
|
||||
use crate::index::{AdjacencyIndex, EdgeTypeIndex, HyperedgeNodeIndex, LabelIndex, PropertyIndex};
|
||||
use crate::node::Node;
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::storage::GraphStorage;
|
||||
use crate::types::{EdgeId, NodeId, PropertyValue};
|
||||
use dashmap::DashMap;
|
||||
#[cfg(feature = "storage")]
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// High-performance graph database with concurrent access
///
/// Primary data lives in sharded `DashMap`s; the secondary indexes below
/// must be kept in sync by every mutation path in `impl GraphDB`.
pub struct GraphDB {
    /// In-memory node storage (DashMap for lock-free concurrent reads)
    nodes: Arc<DashMap<NodeId, Node>>,
    /// In-memory edge storage
    edges: Arc<DashMap<EdgeId, Edge>>,
    /// In-memory hyperedge storage
    hyperedges: Arc<DashMap<HyperedgeId, Hyperedge>>,
    /// Label index for fast label-based lookups
    label_index: LabelIndex,
    /// Property index for fast property-based lookups
    property_index: PropertyIndex,
    /// Edge type index
    edge_type_index: EdgeTypeIndex,
    /// Adjacency index for neighbor lookups
    adjacency_index: AdjacencyIndex,
    /// Hyperedge node index
    hyperedge_node_index: HyperedgeNodeIndex,
    /// Optional persistent storage
    #[cfg(feature = "storage")]
    storage: Option<GraphStorage>,
}
|
||||
|
||||
impl GraphDB {
    /// Create a new in-memory graph database
    ///
    /// All maps and indexes start empty; no persistent storage is attached.
    pub fn new() -> Self {
        Self {
            nodes: Arc::new(DashMap::new()),
            edges: Arc::new(DashMap::new()),
            hyperedges: Arc::new(DashMap::new()),
            label_index: LabelIndex::new(),
            property_index: PropertyIndex::new(),
            edge_type_index: EdgeTypeIndex::new(),
            adjacency_index: AdjacencyIndex::new(),
            hyperedge_node_index: HyperedgeNodeIndex::new(),
            #[cfg(feature = "storage")]
            storage: None,
        }
    }

    /// Create a new graph database with persistent storage
    ///
    /// Opens (or creates) the storage at `path` and hydrates all nodes,
    /// edges and hyperedges into memory before returning.
    #[cfg(feature = "storage")]
    pub fn with_storage<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
        let storage = GraphStorage::new(path)?;

        let mut db = Self::new();
        db.storage = Some(storage);

        // Load existing data from storage
        db.load_from_storage()?;

        Ok(db)
    }

    /// Load all data from storage into memory
    ///
    /// Rebuilds every secondary index as items are inserted. Nodes are
    /// loaded first, then edges, then hyperedges.
    #[cfg(feature = "storage")]
    fn load_from_storage(&mut self) -> anyhow::Result<()> {
        if let Some(storage) = &self.storage {
            // Load nodes
            for node_id in storage.all_node_ids()? {
                if let Some(node) = storage.get_node(&node_id)? {
                    self.nodes.insert(node_id.clone(), node.clone());
                    self.label_index.add_node(&node);
                    self.property_index.add_node(&node);
                }
            }

            // Load edges
            for edge_id in storage.all_edge_ids()? {
                if let Some(edge) = storage.get_edge(&edge_id)? {
                    self.edges.insert(edge_id.clone(), edge.clone());
                    self.edge_type_index.add_edge(&edge);
                    self.adjacency_index.add_edge(&edge);
                }
            }

            // Load hyperedges
            for hyperedge_id in storage.all_hyperedge_ids()? {
                if let Some(hyperedge) = storage.get_hyperedge(&hyperedge_id)? {
                    self.hyperedges
                        .insert(hyperedge_id.clone(), hyperedge.clone());
                    self.hyperedge_node_index.add_hyperedge(&hyperedge);
                }
            }
        }
        Ok(())
    }
|
||||
|
||||
    // Node operations

    /// Create a node
    ///
    /// Indexes the node by label and property, stores it in memory, and
    /// persists it when storage is configured. Returns the node's id.
    ///
    /// NOTE(review): inserting an id that already exists overwrites the
    /// stored node, but the old node's index entries are not removed —
    /// confirm whether upsert is intended here.
    pub fn create_node(&self, node: Node) -> Result<NodeId> {
        let id = node.id.clone();

        // Update indexes
        self.label_index.add_node(&node);
        self.property_index.add_node(&node);

        // Insert into memory
        self.nodes.insert(id.clone(), node.clone());

        // Persist to storage if available
        #[cfg(feature = "storage")]
        if let Some(storage) = &self.storage {
            storage.insert_node(&node)?;
        }

        Ok(id)
    }

    /// Get a node by ID
    ///
    /// Returns a clone so no DashMap shard lock is held by the caller.
    pub fn get_node(&self, id: impl AsRef<str>) -> Option<Node> {
        self.nodes.get(id.as_ref()).map(|entry| entry.clone())
    }

    /// Delete a node
    ///
    /// Returns `Ok(true)` if the node existed; the label/property indexes
    /// and storage are updated accordingly.
    ///
    /// NOTE(review): edges incident to the node are left in place (no
    /// cascade delete) — confirm that dangling edges are acceptable.
    pub fn delete_node(&self, id: impl AsRef<str>) -> Result<bool> {
        if let Some((_, node)) = self.nodes.remove(id.as_ref()) {
            // Update indexes
            self.label_index.remove_node(&node);
            self.property_index.remove_node(&node);

            // Delete from storage if available
            #[cfg(feature = "storage")]
            if let Some(storage) = &self.storage {
                storage.delete_node(id.as_ref())?;
            }

            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Get nodes by label
    ///
    /// Ids whose node was removed concurrently are silently skipped.
    pub fn get_nodes_by_label(&self, label: &str) -> Vec<Node> {
        self.label_index
            .get_nodes_by_label(label)
            .into_iter()
            .filter_map(|id| self.get_node(&id))
            .collect()
    }

    /// Get nodes by property
    pub fn get_nodes_by_property(&self, key: &str, value: &PropertyValue) -> Vec<Node> {
        self.property_index
            .get_nodes_by_property(key, value)
            .into_iter()
            .filter_map(|id| self.get_node(&id))
            .collect()
    }
|
||||
|
||||
    // Edge operations

    /// Create an edge
    ///
    /// Fails with `GraphError::NodeNotFound` unless both endpoints already
    /// exist. Updates the type and adjacency indexes, stores the edge, and
    /// persists it when storage is configured.
    pub fn create_edge(&self, edge: Edge) -> Result<EdgeId> {
        let id = edge.id.clone();

        // Verify nodes exist
        if !self.nodes.contains_key(&edge.from) || !self.nodes.contains_key(&edge.to) {
            return Err(crate::error::GraphError::NodeNotFound(
                "Source or target node not found".to_string(),
            ));
        }

        // Update indexes
        self.edge_type_index.add_edge(&edge);
        self.adjacency_index.add_edge(&edge);

        // Insert into memory
        self.edges.insert(id.clone(), edge.clone());

        // Persist to storage if available
        #[cfg(feature = "storage")]
        if let Some(storage) = &self.storage {
            storage.insert_edge(&edge)?;
        }

        Ok(id)
    }

    /// Get an edge by ID
    pub fn get_edge(&self, id: impl AsRef<str>) -> Option<Edge> {
        self.edges.get(id.as_ref()).map(|entry| entry.clone())
    }

    /// Delete an edge
    ///
    /// Returns `Ok(true)` if the edge existed; indexes and storage are
    /// updated accordingly.
    pub fn delete_edge(&self, id: impl AsRef<str>) -> Result<bool> {
        if let Some((_, edge)) = self.edges.remove(id.as_ref()) {
            // Update indexes
            self.edge_type_index.remove_edge(&edge);
            self.adjacency_index.remove_edge(&edge);

            // Delete from storage if available
            #[cfg(feature = "storage")]
            if let Some(storage) = &self.storage {
                storage.delete_edge(id.as_ref())?;
            }

            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Get edges by type
    ///
    /// Ids whose edge was removed concurrently are silently skipped.
    pub fn get_edges_by_type(&self, edge_type: &str) -> Vec<Edge> {
        self.edge_type_index
            .get_edges_by_type(edge_type)
            .into_iter()
            .filter_map(|id| self.get_edge(&id))
            .collect()
    }

    /// Get outgoing edges from a node
    pub fn get_outgoing_edges(&self, node_id: &NodeId) -> Vec<Edge> {
        self.adjacency_index
            .get_outgoing_edges(node_id)
            .into_iter()
            .filter_map(|id| self.get_edge(&id))
            .collect()
    }

    /// Get incoming edges to a node
    pub fn get_incoming_edges(&self, node_id: &NodeId) -> Vec<Edge> {
        self.adjacency_index
            .get_incoming_edges(node_id)
            .into_iter()
            .filter_map(|id| self.get_edge(&id))
            .collect()
    }
|
||||
|
||||
// Hyperedge operations
|
||||
|
||||
/// Create a hyperedge
|
||||
pub fn create_hyperedge(&self, hyperedge: Hyperedge) -> Result<HyperedgeId> {
|
||||
let id = hyperedge.id.clone();
|
||||
|
||||
// Verify all nodes exist
|
||||
for node_id in &hyperedge.nodes {
|
||||
if !self.nodes.contains_key(node_id) {
|
||||
return Err(crate::error::GraphError::NodeNotFound(format!(
|
||||
"Node {} not found",
|
||||
node_id
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Update index
|
||||
self.hyperedge_node_index.add_hyperedge(&hyperedge);
|
||||
|
||||
// Insert into memory
|
||||
self.hyperedges.insert(id.clone(), hyperedge.clone());
|
||||
|
||||
// Persist to storage if available
|
||||
#[cfg(feature = "storage")]
|
||||
if let Some(storage) = &self.storage {
|
||||
storage.insert_hyperedge(&hyperedge)?;
|
||||
}
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Get a hyperedge by ID
|
||||
pub fn get_hyperedge(&self, id: &HyperedgeId) -> Option<Hyperedge> {
|
||||
self.hyperedges.get(id).map(|entry| entry.clone())
|
||||
}
|
||||
|
||||
/// Get hyperedges containing a node
|
||||
pub fn get_hyperedges_by_node(&self, node_id: &NodeId) -> Vec<Hyperedge> {
|
||||
self.hyperedge_node_index
|
||||
.get_hyperedges_by_node(node_id)
|
||||
.into_iter()
|
||||
.filter_map(|id| self.get_hyperedge(&id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Statistics
|
||||
|
||||
    /// Get the number of nodes
    ///
    /// Counts entries in the in-memory node map only; persisted-but-unloaded
    /// state (if any) is not reflected.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
|
||||
|
||||
    /// Get the number of edges
    ///
    /// Counts entries in the in-memory edge map only.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
|
||||
|
||||
    /// Get the number of hyperedges
    ///
    /// Counts entries in the in-memory hyperedge map only.
    pub fn hyperedge_count(&self) -> usize {
        self.hyperedges.len()
    }
|
||||
}
|
||||
|
||||
impl Default for GraphDB {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::EdgeBuilder;
    use crate::hyperedge::HyperedgeBuilder;
    use crate::node::NodeBuilder;

    // A freshly constructed database is empty.
    #[test]
    fn test_graph_creation() {
        let db = GraphDB::new();
        assert_eq!(db.node_count(), 0);
        assert_eq!(db.edge_count(), 0);
    }

    // Create / get / delete round-trip for a single node.
    #[test]
    fn test_node_operations() {
        let db = GraphDB::new();

        let node = NodeBuilder::new()
            .label("Person")
            .property("name", "Alice")
            .build();

        let id = db.create_node(node.clone()).unwrap();
        assert_eq!(db.node_count(), 1);

        let retrieved = db.get_node(&id);
        assert!(retrieved.is_some());

        let deleted = db.delete_node(&id).unwrap();
        assert!(deleted);
        assert_eq!(db.node_count(), 0);
    }

    // Edges can only be created between existing nodes; verify create + get.
    #[test]
    fn test_edge_operations() {
        let db = GraphDB::new();

        let node1 = NodeBuilder::new().build();
        let node2 = NodeBuilder::new().build();

        let id1 = db.create_node(node1.clone()).unwrap();
        let id2 = db.create_node(node2.clone()).unwrap();

        let edge = EdgeBuilder::new(id1.clone(), id2.clone(), "KNOWS")
            .property("since", 2020i64)
            .build();

        let edge_id = db.create_edge(edge).unwrap();
        assert_eq!(db.edge_count(), 1);

        let retrieved = db.get_edge(&edge_id);
        assert!(retrieved.is_some());
    }

    // The label index partitions nodes by label.
    #[test]
    fn test_label_index() {
        let db = GraphDB::new();

        let node1 = NodeBuilder::new().label("Person").build();
        let node2 = NodeBuilder::new().label("Person").build();
        let node3 = NodeBuilder::new().label("Organization").build();

        db.create_node(node1).unwrap();
        db.create_node(node2).unwrap();
        db.create_node(node3).unwrap();

        let people = db.get_nodes_by_label("Person");
        assert_eq!(people.len(), 2);

        let orgs = db.get_nodes_by_label("Organization");
        assert_eq!(orgs.len(), 1);
    }

    // A hyperedge over three nodes is retrievable via any member node.
    #[test]
    fn test_hyperedge_operations() {
        let db = GraphDB::new();

        let node1 = NodeBuilder::new().build();
        let node2 = NodeBuilder::new().build();
        let node3 = NodeBuilder::new().build();

        let id1 = db.create_node(node1).unwrap();
        let id2 = db.create_node(node2).unwrap();
        let id3 = db.create_node(node3).unwrap();

        let hyperedge =
            HyperedgeBuilder::new(vec![id1.clone(), id2.clone(), id3.clone()], "MEETING")
                .description("Team meeting")
                .build();

        let hedge_id = db.create_hyperedge(hyperedge).unwrap();
        assert_eq!(db.hyperedge_count(), 1);

        let hedges = db.get_hyperedges_by_node(&id1);
        assert_eq!(hedges.len(), 1);
    }
}
|
||||
324
vendor/ruvector/crates/ruvector-graph/src/hybrid/cypher_extensions.rs
vendored
Normal file
324
vendor/ruvector/crates/ruvector-graph/src/hybrid/cypher_extensions.rs
vendored
Normal file
@@ -0,0 +1,324 @@
|
||||
//! Cypher query extensions for vector similarity
|
||||
//!
|
||||
//! Extends Cypher syntax to support vector operations like SIMILAR TO.
|
||||
|
||||
use crate::error::{GraphError, Result};
|
||||
use crate::types::NodeId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Extended Cypher parser with vector support
///
/// Recognizes the `SIMILAR TO` and `SEMANTIC PATH` extension keywords on top
/// of plain Cypher text. Construct via [`VectorCypherParser::new`].
pub struct VectorCypherParser {
    /// Parse options controlling which extensions are recognized
    options: ParserOptions,
}
|
||||
|
||||
/// Options controlling which Cypher vector extensions the parser accepts.
/// Both flags default to `true` (see the `Default` impl).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParserOptions {
    /// Enable vector similarity syntax (`... SIMILAR TO $vector`)
    pub enable_vector_similarity: bool,
    /// Enable semantic path queries (`SEMANTIC PATH ...`)
    pub enable_semantic_paths: bool,
}
|
||||
|
||||
impl Default for ParserOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enable_vector_similarity: true,
|
||||
enable_semantic_paths: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VectorCypherParser {
|
||||
/// Create a new vector-aware Cypher parser
|
||||
pub fn new(options: ParserOptions) -> Self {
|
||||
Self { options }
|
||||
}
|
||||
|
||||
/// Parse a Cypher query with vector extensions
|
||||
pub fn parse(&self, query: &str) -> Result<VectorCypherQuery> {
|
||||
// This is a simplified parser for demonstration
|
||||
// Real implementation would use proper parser combinators or generated parser
|
||||
|
||||
if query.contains("SIMILAR TO") {
|
||||
self.parse_similarity_query(query)
|
||||
} else if query.contains("SEMANTIC PATH") {
|
||||
self.parse_semantic_path_query(query)
|
||||
} else {
|
||||
Ok(VectorCypherQuery {
|
||||
match_clause: query.to_string(),
|
||||
similarity_predicate: None,
|
||||
return_clause: "RETURN *".to_string(),
|
||||
limit: None,
|
||||
order_by: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse similarity query
|
||||
fn parse_similarity_query(&self, query: &str) -> Result<VectorCypherQuery> {
|
||||
// Example: MATCH (n:Document) WHERE n.embedding SIMILAR TO $query_vector LIMIT 10 RETURN n
|
||||
|
||||
// Extract components (simplified parsing)
|
||||
let match_clause = query
|
||||
.split("WHERE")
|
||||
.next()
|
||||
.ok_or_else(|| GraphError::QueryError("Invalid MATCH clause".to_string()))?
|
||||
.to_string();
|
||||
|
||||
let similarity_predicate = Some(SimilarityPredicate {
|
||||
property: "embedding".to_string(),
|
||||
query_vector: Vec::new(), // Would be populated from parameters
|
||||
top_k: 10,
|
||||
min_score: 0.0,
|
||||
});
|
||||
|
||||
Ok(VectorCypherQuery {
|
||||
match_clause,
|
||||
similarity_predicate,
|
||||
return_clause: "RETURN n".to_string(),
|
||||
limit: Some(10),
|
||||
order_by: Some("semanticScore DESC".to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse semantic path query
|
||||
fn parse_semantic_path_query(&self, query: &str) -> Result<VectorCypherQuery> {
|
||||
// Example: MATCH path = (start)-[*1..3]-(end)
|
||||
// WHERE start.embedding SIMILAR TO $query
|
||||
// RETURN path ORDER BY semanticScore(path) DESC
|
||||
|
||||
Ok(VectorCypherQuery {
|
||||
match_clause: query.to_string(),
|
||||
similarity_predicate: None,
|
||||
return_clause: "RETURN path".to_string(),
|
||||
limit: None,
|
||||
order_by: Some("semanticScore(path) DESC".to_string()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parsed vector-aware Cypher query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorCypherQuery {
    /// Raw MATCH portion of the query text
    pub match_clause: String,
    /// Vector similarity constraint, if the query used `SIMILAR TO`
    pub similarity_predicate: Option<SimilarityPredicate>,
    /// Raw RETURN portion of the query text
    pub return_clause: String,
    /// Maximum number of results, if a LIMIT was specified
    pub limit: Option<usize>,
    /// ORDER BY expression, if any
    pub order_by: Option<String>,
}
|
||||
|
||||
/// Similarity predicate in WHERE clause
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityPredicate {
    /// Property containing embedding
    pub property: String,
    /// Query vector for comparison (may be empty until bound from parameters)
    pub query_vector: Vec<f32>,
    /// Number of results
    pub top_k: usize,
    /// Minimum similarity score
    pub min_score: f32,
}
|
||||
|
||||
/// Executor for vector-aware Cypher queries
///
/// Currently stateless; a real implementation would hold references to the
/// graph storage, the vector index, and a query planner.
pub struct VectorCypherExecutor {
    // In real implementation, this would have access to:
    // - Graph storage
    // - Vector index
    // - Query planner
}
|
||||
|
||||
impl VectorCypherExecutor {
|
||||
/// Create a new executor
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
/// Execute a vector-aware Cypher query
|
||||
pub fn execute(&self, _query: &VectorCypherQuery) -> Result<QueryResult> {
|
||||
// This is a placeholder for actual execution
|
||||
// Real implementation would:
|
||||
// 1. Plan query execution (optimize with vector indices)
|
||||
// 2. Execute vector similarity search
|
||||
// 3. Apply graph pattern matching
|
||||
// 4. Combine results
|
||||
// 5. Apply ordering and limits
|
||||
|
||||
Ok(QueryResult {
|
||||
rows: Vec::new(),
|
||||
execution_time_ms: 0,
|
||||
stats: ExecutionStats {
|
||||
nodes_scanned: 0,
|
||||
vectors_compared: 0,
|
||||
index_hits: 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute similarity search
|
||||
pub fn execute_similarity_search(
|
||||
&self,
|
||||
_predicate: &SimilarityPredicate,
|
||||
) -> Result<Vec<NodeId>> {
|
||||
// Placeholder for vector similarity search
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Compute semantic score for a path
|
||||
pub fn semantic_score(&self, _path: &[NodeId]) -> f32 {
|
||||
// Placeholder for path scoring
|
||||
// Real implementation would:
|
||||
// 1. Retrieve embeddings for all nodes in path
|
||||
// 2. Compute pairwise similarities
|
||||
// 3. Aggregate scores (e.g., average, min, product)
|
||||
|
||||
0.85 // Dummy score
|
||||
}
|
||||
}
|
||||
|
||||
/// `Default` delegates to [`VectorCypherExecutor::new`].
impl Default for VectorCypherExecutor {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Query execution result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Result rows, each a column-name -> JSON value map
    pub rows: Vec<HashMap<String, serde_json::Value>>,
    /// Wall-clock execution time in milliseconds
    pub execution_time_ms: u64,
    /// Counters gathered during execution
    pub stats: ExecutionStats,
}
|
||||
|
||||
/// Execution statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionStats {
    /// Number of graph nodes examined
    pub nodes_scanned: usize,
    /// Number of vector comparisons performed
    pub vectors_compared: usize,
    /// Number of lookups satisfied from an index
    pub index_hits: usize,
}
|
||||
|
||||
/// Extended Cypher functions for vectors
|
||||
pub mod functions {
|
||||
use super::*;
|
||||
|
||||
/// Compute cosine similarity between two embeddings
|
||||
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
|
||||
use ruvector_core::distance::cosine_distance;
|
||||
|
||||
if a.len() != b.len() {
|
||||
return Err(GraphError::InvalidEmbedding(
|
||||
"Embedding dimensions must match".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Convert distance to similarity
|
||||
let distance = cosine_distance(a, b);
|
||||
Ok(1.0 - distance)
|
||||
}
|
||||
|
||||
/// Compute semantic score for a path
|
||||
pub fn semantic_score(embeddings: &[Vec<f32>]) -> Result<f32> {
|
||||
if embeddings.is_empty() {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
if embeddings.len() == 1 {
|
||||
return Ok(1.0);
|
||||
}
|
||||
|
||||
// Compute average pairwise similarity
|
||||
let mut total_score = 0.0;
|
||||
let mut count = 0;
|
||||
|
||||
for i in 0..embeddings.len() - 1 {
|
||||
let sim = cosine_similarity(&embeddings[i], &embeddings[i + 1])?;
|
||||
total_score += sim;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
Ok(total_score / count as f32)
|
||||
}
|
||||
|
||||
/// Vector aggregation (average of embeddings)
|
||||
pub fn avg_embedding(embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
|
||||
if embeddings.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let dim = embeddings[0].len();
|
||||
let mut result = vec![0.0; dim];
|
||||
|
||||
for emb in embeddings {
|
||||
if emb.len() != dim {
|
||||
return Err(GraphError::InvalidEmbedding(
|
||||
"All embeddings must have same dimensions".to_string(),
|
||||
));
|
||||
}
|
||||
for (i, &val) in emb.iter().enumerate() {
|
||||
result[i] += val;
|
||||
}
|
||||
}
|
||||
|
||||
let n = embeddings.len() as f32;
|
||||
for val in &mut result {
|
||||
*val /= n;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Default options enable the similarity extension.
    #[test]
    fn test_parser_creation() {
        let parser = VectorCypherParser::new(ParserOptions::default());
        assert!(parser.options.enable_vector_similarity);
    }

    // A SIMILAR TO query yields a predicate and the LIMIT value.
    #[test]
    fn test_similarity_query_parsing() -> Result<()> {
        let parser = VectorCypherParser::new(ParserOptions::default());
        let query =
            "MATCH (n:Document) WHERE n.embedding SIMILAR TO $query_vector LIMIT 10 RETURN n";

        let parsed = parser.parse(query)?;
        assert!(parsed.similarity_predicate.is_some());
        assert_eq!(parsed.limit, Some(10));

        Ok(())
    }

    // Identical vectors have similarity ~1.0.
    #[test]
    fn test_cosine_similarity() -> Result<()> {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let sim = functions::cosine_similarity(&a, &b)?;
        assert!(sim > 0.99); // Should be very close to 1.0

        Ok(())
    }

    // Averaging two orthogonal unit vectors gives the midpoint.
    #[test]
    fn test_avg_embedding() -> Result<()> {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let avg = functions::avg_embedding(&embeddings)?;
        assert_eq!(avg, vec![0.5, 0.5]);

        Ok(())
    }

    // The placeholder semantic score is a fixed positive value.
    #[test]
    fn test_executor_creation() {
        let executor = VectorCypherExecutor::new();
        let score = executor.semantic_score(&vec!["n1".to_string()]);
        assert!(score > 0.0);
    }
}
|
||||
319
vendor/ruvector/crates/ruvector-graph/src/hybrid/graph_neural.rs
vendored
Normal file
319
vendor/ruvector/crates/ruvector-graph/src/hybrid/graph_neural.rs
vendored
Normal file
@@ -0,0 +1,319 @@
|
||||
//! Graph Neural Network inference capabilities
|
||||
//!
|
||||
//! Provides GNN-based predictions: node classification, link prediction, graph embeddings.
|
||||
|
||||
use crate::error::{GraphError, Result};
|
||||
use crate::types::{EdgeId, NodeId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for GNN engine
///
/// Defaults (via [`Default`]) give a 2-layer network with a 128-dimensional
/// hidden state, mean aggregation, ReLU activation, and 0.1 dropout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnConfig {
    /// Number of GNN layers
    pub num_layers: usize,
    /// Hidden dimension size (also the length of produced embeddings)
    pub hidden_dim: usize,
    /// Aggregation method
    pub aggregation: AggregationType,
    /// Activation function
    pub activation: ActivationType,
    /// Dropout rate
    pub dropout: f32,
}
|
||||
|
||||
/// Sensible defaults: 2 layers, 128-dim hidden state, mean aggregation,
/// ReLU, 0.1 dropout.
impl Default for GnnConfig {
    fn default() -> Self {
        Self {
            num_layers: 2,
            hidden_dim: 128,
            aggregation: AggregationType::Mean,
            activation: ActivationType::ReLU,
            dropout: 0.1,
        }
    }
}
|
||||
|
||||
/// Aggregation type for message passing
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AggregationType {
    /// Element-wise mean of neighbor messages
    Mean,
    /// Element-wise sum of neighbor messages
    Sum,
    /// Element-wise max of neighbor messages
    Max,
    /// Learned attention-weighted combination
    Attention,
}
|
||||
|
||||
/// Activation function type (applied element-wise; see
/// `GraphNeuralEngine::activate` for the implementations)
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ActivationType {
    ReLU,
    Sigmoid,
    Tanh,
    GELU,
}
|
||||
|
||||
/// Graph Neural Network engine
///
/// Holds the configuration and a per-node embedding cache populated by
/// `update_embeddings`. Inference methods are currently placeholders.
pub struct GraphNeuralEngine {
    config: GnnConfig,
    // In real implementation, would store model weights
    node_embeddings: HashMap<NodeId, Vec<f32>>,
}
|
||||
|
||||
impl GraphNeuralEngine {
    /// Create a new GNN engine with an empty embedding cache.
    pub fn new(config: GnnConfig) -> Self {
        Self {
            config,
            node_embeddings: HashMap::new(),
        }
    }

    /// Load pre-trained model weights
    ///
    /// Placeholder: currently a no-op that always succeeds.
    pub fn load_model(&mut self, _model_path: &str) -> Result<()> {
        // Placeholder for model loading
        // Real implementation would:
        // 1. Load weights from file
        // 2. Initialize neural network layers
        // 3. Set up computation graph
        Ok(())
    }

    /// Classify a node based on its features and neighbors
    ///
    /// Placeholder: ignores `_features` and returns fixed probabilities
    /// `[0.7, 0.2, 0.1]` with class 0 predicted at confidence 0.7.
    pub fn classify_node(&self, node_id: &NodeId, _features: &[f32]) -> Result<NodeClassification> {
        // Placeholder for GNN inference
        // Real implementation would:
        // 1. Gather neighbor features
        // 2. Apply message passing layers
        // 3. Aggregate neighbor information
        // 4. Compute final classification

        let class_probabilities = vec![0.7, 0.2, 0.1]; // Dummy probabilities
        let predicted_class = 0;

        Ok(NodeClassification {
            node_id: node_id.clone(),
            predicted_class,
            class_probabilities,
            confidence: 0.7,
        })
    }

    /// Predict likelihood of a link between two nodes
    ///
    /// Placeholder: always returns score 0.85 and `exists == true`
    /// (threshold 0.5).
    pub fn predict_link(&self, node1: &NodeId, node2: &NodeId) -> Result<LinkPrediction> {
        // Placeholder for link prediction
        // Real implementation would:
        // 1. Get embeddings for both nodes
        // 2. Compute compatibility score (dot product, concat+MLP, etc.)
        // 3. Apply sigmoid for probability

        let score = 0.85; // Dummy score
        let exists = score > 0.5;

        Ok(LinkPrediction {
            node1: node1.clone(),
            node2: node2.clone(),
            score,
            exists,
        })
    }

    /// Generate embedding for entire graph or subgraph
    ///
    /// Placeholder: returns a zero vector of length `config.hidden_dim`,
    /// labeled as `"mean_pooling"`.
    pub fn embed_graph(&self, node_ids: &[NodeId]) -> Result<GraphEmbedding> {
        // Placeholder for graph-level embedding
        // Real implementation would use graph pooling:
        // 1. Get node embeddings
        // 2. Apply pooling (mean, max, attention-based)
        // 3. Optionally apply final MLP

        let embedding = vec![0.0; self.config.hidden_dim];

        Ok(GraphEmbedding {
            embedding,
            node_count: node_ids.len(),
            method: "mean_pooling".to_string(),
        })
    }

    /// Update node embeddings using message passing
    ///
    /// Placeholder: stores a zero vector of length `config.hidden_dim` for
    /// every node in `graph_structure` (edges and features are ignored).
    pub fn update_embeddings(&mut self, graph_structure: &GraphStructure) -> Result<()> {
        // Placeholder for embedding update
        // Real implementation would:
        // 1. For each layer:
        //    - Aggregate neighbor features
        //    - Apply linear transformation
        //    - Apply activation
        // 2. Store final embeddings

        for node_id in &graph_structure.nodes {
            let embedding = vec![0.0; self.config.hidden_dim];
            self.node_embeddings.insert(node_id.clone(), embedding);
        }

        Ok(())
    }

    /// Get embedding for a specific node
    ///
    /// Returns `None` until `update_embeddings` has cached one for this node.
    pub fn get_node_embedding(&self, node_id: &NodeId) -> Option<&Vec<f32>> {
        self.node_embeddings.get(node_id)
    }

    /// Batch node classification
    ///
    /// Applies `classify_node` to each `(id, features)` pair, stopping at
    /// the first error.
    pub fn classify_nodes_batch(
        &self,
        nodes: &[(NodeId, Vec<f32>)],
    ) -> Result<Vec<NodeClassification>> {
        nodes
            .iter()
            .map(|(id, features)| self.classify_node(id, features))
            .collect()
    }

    /// Batch link prediction
    ///
    /// Applies `predict_link` to each pair, stopping at the first error.
    pub fn predict_links_batch(&self, pairs: &[(NodeId, NodeId)]) -> Result<Vec<LinkPrediction>> {
        pairs
            .iter()
            .map(|(n1, n2)| self.predict_link(n1, n2))
            .collect()
    }

    /// Apply attention mechanism for neighbor aggregation
    ///
    /// Placeholder: returns a zero vector; not yet called by any public
    /// method.
    fn aggregate_with_attention(
        &self,
        _node_embedding: &[f32],
        _neighbor_embeddings: &[Vec<f32>],
    ) -> Vec<f32> {
        // Placeholder for attention-based aggregation
        // Real implementation would compute attention weights
        vec![0.0; self.config.hidden_dim]
    }

    /// Apply activation function with numerical stability
    ///
    /// Sigmoid is branched on the sign of `x` so that `exp` is only ever
    /// evaluated on a non-positive argument, avoiding overflow for large
    /// |x|. GELU uses the standard tanh approximation.
    fn activate(&self, x: f32) -> f32 {
        match self.config.activation {
            ActivationType::ReLU => x.max(0.0),
            ActivationType::Sigmoid => {
                if x > 0.0 {
                    1.0 / (1.0 + (-x).exp())
                } else {
                    let ex = x.exp();
                    ex / (1.0 + ex)
                }
            }
            ActivationType::Tanh => x.tanh(),
            ActivationType::GELU => {
                // Approximate GELU; 0.7978845608 = sqrt(2/pi)
                0.5 * x * (1.0 + (0.7978845608 * (x + 0.044715 * x.powi(3))).tanh())
            }
        }
    }
}
|
||||
|
||||
/// Result of node classification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeClassification {
    /// Node that was classified
    pub node_id: NodeId,
    /// Index of the most likely class
    pub predicted_class: usize,
    /// Probability per class; `predicted_class` indexes into this
    pub class_probabilities: Vec<f32>,
    /// Probability of the predicted class
    pub confidence: f32,
}
|
||||
|
||||
/// Result of link prediction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkPrediction {
    /// First endpoint
    pub node1: NodeId,
    /// Second endpoint
    pub node2: NodeId,
    /// Predicted link likelihood
    pub score: f32,
    /// Whether `score` exceeds the decision threshold
    pub exists: bool,
}
|
||||
|
||||
/// Graph-level embedding
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEmbedding {
    /// Pooled embedding vector (length = configured hidden dimension)
    pub embedding: Vec<f32>,
    /// Number of nodes pooled into this embedding
    pub node_count: usize,
    /// Name of the pooling method used
    pub method: String,
}
|
||||
|
||||
/// Graph structure for GNN processing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphStructure {
    /// All node IDs in the (sub)graph
    pub nodes: Vec<NodeId>,
    /// Directed edges as (from, to) pairs
    pub edges: Vec<(NodeId, NodeId)>,
    /// Per-node input feature vectors
    pub node_features: HashMap<NodeId, Vec<f32>>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Engine construction with default config succeeds.
    #[test]
    fn test_gnn_engine_creation() {
        let config = GnnConfig::default();
        let _engine = GraphNeuralEngine::new(config);
    }

    // Placeholder classification echoes the node ID with positive confidence.
    #[test]
    fn test_node_classification() -> Result<()> {
        let engine = GraphNeuralEngine::new(GnnConfig::default());
        let features = vec![1.0, 0.5, 0.3];

        let result = engine.classify_node(&"node1".to_string(), &features)?;

        assert_eq!(result.node_id, "node1");
        assert!(result.confidence > 0.0);
        assert!(!result.class_probabilities.is_empty());

        Ok(())
    }

    // Placeholder link prediction returns the endpoints and a score in [0, 1].
    #[test]
    fn test_link_prediction() -> Result<()> {
        let engine = GraphNeuralEngine::new(GnnConfig::default());

        let result = engine.predict_link(&"node1".to_string(), &"node2".to_string())?;

        assert_eq!(result.node1, "node1");
        assert_eq!(result.node2, "node2");
        assert!(result.score >= 0.0 && result.score <= 1.0);

        Ok(())
    }

    // Graph embedding length matches the default hidden_dim (128).
    #[test]
    fn test_graph_embedding() -> Result<()> {
        let engine = GraphNeuralEngine::new(GnnConfig::default());
        let nodes = vec!["n1".to_string(), "n2".to_string(), "n3".to_string()];

        let embedding = engine.embed_graph(&nodes)?;

        assert_eq!(embedding.node_count, 3);
        assert_eq!(embedding.embedding.len(), 128);

        Ok(())
    }

    // Batch classification returns one result per input node.
    #[test]
    fn test_batch_classification() -> Result<()> {
        let engine = GraphNeuralEngine::new(GnnConfig::default());
        let nodes = vec![
            ("n1".to_string(), vec![1.0, 0.0]),
            ("n2".to_string(), vec![0.0, 1.0]),
        ];

        let results = engine.classify_nodes_batch(&nodes)?;
        assert_eq!(results.len(), 2);

        Ok(())
    }

    // ReLU clamps negatives to zero and passes positives through.
    #[test]
    fn test_activation_functions() {
        let engine = GraphNeuralEngine::new(GnnConfig {
            activation: ActivationType::ReLU,
            ..Default::default()
        });

        assert_eq!(engine.activate(-1.0), 0.0);
        assert_eq!(engine.activate(1.0), 1.0);
    }
}
|
||||
73
vendor/ruvector/crates/ruvector-graph/src/hybrid/mod.rs
vendored
Normal file
73
vendor/ruvector/crates/ruvector-graph/src/hybrid/mod.rs
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
//! Vector-Graph Hybrid Query System
|
||||
//!
|
||||
//! Combines vector similarity search with graph traversal for AI workloads.
|
||||
//! Supports semantic search, RAG (Retrieval Augmented Generation), and GNN inference.
|
||||
|
||||
pub mod cypher_extensions;
|
||||
pub mod graph_neural;
|
||||
pub mod rag_integration;
|
||||
pub mod semantic_search;
|
||||
pub mod vector_index;
|
||||
|
||||
// Re-export main types
|
||||
pub use cypher_extensions::{SimilarityPredicate, VectorCypherExecutor, VectorCypherParser};
|
||||
pub use graph_neural::{
|
||||
GnnConfig, GraphEmbedding, GraphNeuralEngine, LinkPrediction, NodeClassification,
|
||||
};
|
||||
pub use rag_integration::{Context, Evidence, RagConfig, RagEngine, ReasoningPath};
|
||||
pub use semantic_search::{ClusterResult, SemanticPath, SemanticSearch, SemanticSearchConfig};
|
||||
pub use vector_index::{EmbeddingConfig, HybridIndex, VectorIndexType};
|
||||
|
||||
use crate::error::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Hybrid query combining graph patterns and vector similarity
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridQuery {
    /// Cypher pattern to match graph structure
    pub graph_pattern: String,
    /// Vector similarity constraint; `None` makes this a pure graph query
    pub vector_constraint: Option<VectorConstraint>,
    /// Maximum results to return
    pub limit: usize,
    /// Minimum similarity score threshold
    pub min_score: f32,
}
|
||||
|
||||
/// Vector similarity constraint for hybrid queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorConstraint {
    /// Query embedding vector
    pub query_vector: Vec<f32>,
    /// Property name containing the embedding
    pub embedding_property: String,
    /// Top-k similar items
    pub top_k: usize,
}
|
||||
|
||||
/// Result from a hybrid query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridResult {
    /// Matched graph elements, serialized as JSON
    pub graph_match: serde_json::Value,
    /// Similarity score
    pub score: f32,
    /// Explanation of match, when the executor provides one
    pub explanation: Option<String>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A HybridQuery can be constructed without a vector constraint.
    #[test]
    fn test_hybrid_query_creation() {
        let query = HybridQuery {
            graph_pattern: "MATCH (n:Document) RETURN n".to_string(),
            vector_constraint: None,
            limit: 10,
            min_score: 0.8,
        };
        assert_eq!(query.limit, 10);
    }
}
|
||||
324
vendor/ruvector/crates/ruvector-graph/src/hybrid/rag_integration.rs
vendored
Normal file
324
vendor/ruvector/crates/ruvector-graph/src/hybrid/rag_integration.rs
vendored
Normal file
@@ -0,0 +1,324 @@
|
||||
//! RAG (Retrieval Augmented Generation) integration
|
||||
//!
|
||||
//! Provides graph-based context retrieval and multi-hop reasoning for LLMs.
|
||||
|
||||
use crate::error::{GraphError, Result};
|
||||
use crate::hybrid::semantic_search::{SemanticPath, SemanticSearch};
|
||||
use crate::types::{EdgeId, NodeId, Properties};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for RAG engine
///
/// Defaults (via [`Default`]): 4096-token context, top-5 documents, depth-3
/// multi-hop reasoning, 0.7 minimum relevance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Maximum context size (in tokens)
    pub max_context_tokens: usize,
    /// Number of top documents to retrieve
    pub top_k_docs: usize,
    /// Maximum reasoning depth (hops in graph)
    pub max_reasoning_depth: usize,
    /// Minimum relevance score a document needs to enter the context
    pub min_relevance: f32,
    /// Enable multi-hop reasoning
    pub multi_hop_reasoning: bool,
}
|
||||
|
||||
/// Default RAG settings: 4096-token context, top-5 docs, depth 3, relevance
/// cutoff 0.7, multi-hop reasoning enabled.
impl Default for RagConfig {
    fn default() -> Self {
        Self {
            max_context_tokens: 4096,
            top_k_docs: 5,
            max_reasoning_depth: 3,
            min_relevance: 0.7,
            multi_hop_reasoning: true,
        }
    }
}
|
||||
|
||||
/// RAG engine for graph-based retrieval
///
/// Wraps a [`SemanticSearch`] instance and a [`RagConfig`]; see the impl for
/// context retrieval, multi-hop reasoning, and prompt generation.
pub struct RagEngine {
    /// Semantic search engine
    semantic_search: SemanticSearch,
    /// Configuration
    config: RagConfig,
}
|
||||
|
||||
impl RagEngine {
|
||||
/// Create a new RAG engine
|
||||
pub fn new(semantic_search: SemanticSearch, config: RagConfig) -> Self {
|
||||
Self {
|
||||
semantic_search,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve relevant context for a query
|
||||
pub fn retrieve_context(&self, query: &[f32]) -> Result<Context> {
|
||||
// Find top-k most relevant documents
|
||||
let matches = self
|
||||
.semantic_search
|
||||
.find_similar_nodes(query, self.config.top_k_docs)?;
|
||||
|
||||
let mut documents = Vec::new();
|
||||
for match_result in matches {
|
||||
if match_result.score >= self.config.min_relevance {
|
||||
documents.push(Document {
|
||||
node_id: match_result.node_id.clone(),
|
||||
content: format!("Document {}", match_result.node_id),
|
||||
metadata: HashMap::new(),
|
||||
relevance_score: match_result.score,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let total_tokens = self.estimate_tokens(&documents);
|
||||
|
||||
Ok(Context {
|
||||
documents,
|
||||
total_tokens,
|
||||
query_embedding: query.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Build multi-hop reasoning paths
|
||||
pub fn build_reasoning_paths(
|
||||
&self,
|
||||
start_node: &NodeId,
|
||||
query: &[f32],
|
||||
) -> Result<Vec<ReasoningPath>> {
|
||||
if !self.config.multi_hop_reasoning {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Find semantic paths through the graph
|
||||
let semantic_paths =
|
||||
self.semantic_search
|
||||
.find_semantic_paths(start_node, query, self.config.top_k_docs)?;
|
||||
|
||||
// Convert semantic paths to reasoning paths
|
||||
let reasoning_paths = semantic_paths
|
||||
.into_iter()
|
||||
.map(|path| self.convert_to_reasoning_path(path))
|
||||
.collect();
|
||||
|
||||
Ok(reasoning_paths)
|
||||
}
|
||||
|
||||
/// Aggregate evidence from multiple sources
|
||||
pub fn aggregate_evidence(&self, paths: &[ReasoningPath]) -> Result<Vec<Evidence>> {
|
||||
let mut evidence_map: HashMap<NodeId, Evidence> = HashMap::new();
|
||||
|
||||
for path in paths {
|
||||
for step in &path.steps {
|
||||
evidence_map
|
||||
.entry(step.node_id.clone())
|
||||
.and_modify(|e| {
|
||||
e.support_count += 1;
|
||||
e.confidence = e.confidence.max(step.confidence);
|
||||
})
|
||||
.or_insert_with(|| Evidence {
|
||||
node_id: step.node_id.clone(),
|
||||
content: step.content.clone(),
|
||||
support_count: 1,
|
||||
confidence: step.confidence,
|
||||
sources: vec![step.node_id.clone()],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut evidence: Vec<_> = evidence_map.into_values().collect();
|
||||
evidence.sort_by(|a, b| {
|
||||
b.confidence
|
||||
.partial_cmp(&a.confidence)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
Ok(evidence)
|
||||
}
|
||||
|
||||
/// Generate context-aware prompt
|
||||
pub fn generate_prompt(&self, query: &str, context: &Context) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
prompt.push_str("Based on the following context, answer the question.\n\n");
|
||||
prompt.push_str("Context:\n");
|
||||
|
||||
for (i, doc) in context.documents.iter().enumerate() {
|
||||
prompt.push_str(&format!(
|
||||
"{}. {} (relevance: {:.2})\n",
|
||||
i + 1,
|
||||
doc.content,
|
||||
doc.relevance_score
|
||||
));
|
||||
}
|
||||
|
||||
prompt.push_str("\nQuestion: ");
|
||||
prompt.push_str(query);
|
||||
prompt.push_str("\n\nAnswer:");
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Rerank results based on graph structure
|
||||
pub fn rerank_results(
|
||||
&self,
|
||||
initial_results: Vec<Document>,
|
||||
_query: &[f32],
|
||||
) -> Result<Vec<Document>> {
|
||||
// Simple reranking based on score
|
||||
// Real implementation would consider:
|
||||
// - Graph centrality
|
||||
// - Cross-document connections
|
||||
// - Temporal relevance
|
||||
// - User preferences
|
||||
|
||||
let mut results = initial_results;
|
||||
results.sort_by(|a, b| {
|
||||
b.relevance_score
|
||||
.partial_cmp(&a.relevance_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Convert semantic path to reasoning path
|
||||
fn convert_to_reasoning_path(&self, semantic_path: SemanticPath) -> ReasoningPath {
|
||||
let steps = semantic_path
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|node_id| ReasoningStep {
|
||||
node_id: node_id.clone(),
|
||||
content: format!("Step at node {}", node_id),
|
||||
relationship: "RELATED_TO".to_string(),
|
||||
confidence: semantic_path.semantic_score,
|
||||
})
|
||||
.collect();
|
||||
|
||||
ReasoningPath {
|
||||
steps,
|
||||
total_confidence: semantic_path.combined_score,
|
||||
explanation: format!("Reasoning path with {} steps", semantic_path.nodes.len()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate token count for documents
|
||||
fn estimate_tokens(&self, documents: &[Document]) -> usize {
|
||||
// Rough estimation: ~4 characters per token
|
||||
documents.iter().map(|doc| doc.content.len() / 4).sum()
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieved context for generation.
///
/// Bundles the documents selected for a query together with a rough token
/// estimate (computed by `estimate_tokens`, ~4 chars per token) and the
/// query embedding they were retrieved for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Retrieved documents
    pub documents: Vec<Document>,
    /// Total estimated tokens
    pub total_tokens: usize,
    /// Original query embedding
    pub query_embedding: Vec<f32>,
}
|
||||
|
||||
/// A retrieved document
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Graph node this document was retrieved from.
    pub node_id: NodeId,
    /// Document text. Currently synthesized as `"Document {node_id}"` by
    /// `retrieve_context`; real content loading is not implemented yet.
    pub content: String,
    /// Arbitrary string metadata attached to the document.
    pub metadata: HashMap<String, String>,
    /// Similarity-derived relevance; higher is more relevant.
    pub relevance_score: f32,
}
|
||||
|
||||
/// A multi-hop reasoning path
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningPath {
    /// Steps in the reasoning chain
    pub steps: Vec<ReasoningStep>,
    /// Overall confidence in this path (taken from the semantic path's
    /// combined score in `convert_to_reasoning_path`)
    pub total_confidence: f32,
    /// Human-readable explanation
    pub explanation: String,
}
|
||||
|
||||
/// A single step in reasoning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningStep {
    /// Node visited at this step.
    pub node_id: NodeId,
    /// Textual description of the step.
    pub content: String,
    /// Relationship label leading into this step
    /// (currently always "RELATED_TO" from `convert_to_reasoning_path`).
    pub relationship: String,
    /// Confidence in this step.
    pub confidence: f32,
}
|
||||
|
||||
/// Aggregated evidence from multiple paths
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    /// Node the evidence refers to.
    pub node_id: NodeId,
    /// Content of the first step that introduced this node.
    pub content: String,
    /// Number of reasoning-path steps that touched this node.
    pub support_count: usize,
    /// Maximum confidence observed across supporting steps.
    pub confidence: f32,
    /// Source node IDs. NOTE(review): `aggregate_evidence` seeds this with
    /// the node itself and never extends it on repeat support — confirm
    /// whether supporting paths were meant to be recorded here.
    pub sources: Vec<NodeId>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hybrid::semantic_search::SemanticSearchConfig;
    use crate::hybrid::vector_index::{EmbeddingConfig, HybridIndex};

    // Smoke test: the engine can be assembled from default configurations.
    #[test]
    fn test_rag_engine_creation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let _rag = RagEngine::new(semantic_search, RagConfig::default());
    }

    // End-to-end retrieval: seed two close embeddings, query with one of
    // them, and expect at least one document plus the query echoed back.
    #[test]
    fn test_context_retrieval() -> Result<()> {
        use crate::hybrid::vector_index::VectorIndexType;

        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        // Initialize the node index
        index.initialize_index(VectorIndexType::Node)?;

        // Add test embeddings so search returns results
        index.add_node_embedding("doc1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("doc2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;

        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let context = rag.retrieve_context(&query)?;

        assert_eq!(context.query_embedding, query);
        // Should find at least one document
        assert!(!context.documents.is_empty());
        Ok(())
    }

    // The generated prompt must contain both the document content and the
    // original question verbatim.
    #[test]
    fn test_prompt_generation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let context = Context {
            documents: vec![Document {
                node_id: "doc1".to_string(),
                content: "Test content".to_string(),
                metadata: HashMap::new(),
                relevance_score: 0.9,
            }],
            total_tokens: 100,
            query_embedding: vec![1.0; 4],
        };

        let prompt = rag.generate_prompt("What is the answer?", &context);
        assert!(prompt.contains("Test content"));
        assert!(prompt.contains("What is the answer?"));
    }
}
|
||||
333
vendor/ruvector/crates/ruvector-graph/src/hybrid/semantic_search.rs
vendored
Normal file
333
vendor/ruvector/crates/ruvector-graph/src/hybrid/semantic_search.rs
vendored
Normal file
@@ -0,0 +1,333 @@
|
||||
//! Semantic search capabilities for graph queries
|
||||
//!
|
||||
//! Combines vector similarity with graph traversal for semantic queries.
|
||||
|
||||
use crate::error::{GraphError, Result};
|
||||
use crate::hybrid::vector_index::HybridIndex;
|
||||
use crate::types::{EdgeId, NodeId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Configuration for semantic search
///
/// Defaults (see the `Default` impl): 3 hops, 0.7 similarity cutoff,
/// 10 results per hop, 0.6 semantic weight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchConfig {
    /// Maximum path length for traversal
    pub max_path_length: usize,
    /// Minimum similarity threshold (similarity = 1 - distance; matches
    /// below this are dropped by `find_similar_nodes` / `find_related_edges`)
    pub min_similarity: f32,
    /// Top-k results per hop
    pub top_k: usize,
    /// Weight for semantic similarity vs. graph distance, in [0, 1];
    /// used by `compute_path_score`
    pub semantic_weight: f32,
}
|
||||
|
||||
impl Default for SemanticSearchConfig {
    // Defaults: short traversals (3 hops), fairly strict similarity cutoff
    // (0.7), 10 candidates per hop, 60/40 semantic-vs-distance weighting.
    fn default() -> Self {
        Self {
            max_path_length: 3,
            min_similarity: 0.7,
            top_k: 10,
            semantic_weight: 0.6,
        }
    }
}
|
||||
|
||||
/// Semantic search engine for graph queries
///
/// Thin layer over [`HybridIndex`] that converts raw index distances into
/// similarity scores, applies the configured thresholds, and (eventually)
/// combines them with graph traversal.
pub struct SemanticSearch {
    /// Vector index for similarity search
    index: HybridIndex,
    /// Configuration
    config: SemanticSearchConfig,
}
|
||||
|
||||
impl SemanticSearch {
|
||||
/// Create a new semantic search engine
|
||||
pub fn new(index: HybridIndex, config: SemanticSearchConfig) -> Self {
|
||||
Self { index, config }
|
||||
}
|
||||
|
||||
/// Find nodes semantically similar to query embedding
|
||||
pub fn find_similar_nodes(&self, query: &[f32], k: usize) -> Result<Vec<SemanticMatch>> {
|
||||
let results = self.index.search_similar_nodes(query, k)?;
|
||||
|
||||
// Pre-compute max distance threshold for faster comparison
|
||||
// If min_similarity = 0.7, then max_distance = 0.3
|
||||
let max_distance = 1.0 - self.config.min_similarity;
|
||||
|
||||
// HNSW returns distance (0 = identical, 1 = orthogonal for cosine)
|
||||
// Convert to similarity (1 = identical, 0 = orthogonal)
|
||||
// Use filter_map for single-pass optimization and pre-allocate result
|
||||
let mut matches = Vec::with_capacity(results.len());
|
||||
for (node_id, distance) in results {
|
||||
// Filter by distance threshold (faster than converting and comparing)
|
||||
if distance <= max_distance {
|
||||
matches.push(SemanticMatch {
|
||||
node_id,
|
||||
score: 1.0 - distance,
|
||||
path_length: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(matches)
|
||||
}
|
||||
|
||||
/// Find semantic paths through the graph
|
||||
pub fn find_semantic_paths(
|
||||
&self,
|
||||
start_node: &NodeId,
|
||||
query: &[f32],
|
||||
max_paths: usize,
|
||||
) -> Result<Vec<SemanticPath>> {
|
||||
// This is a placeholder for the actual graph traversal logic
|
||||
// In a real implementation, this would:
|
||||
// 1. Start from the given node
|
||||
// 2. At each hop, find semantically similar neighbors
|
||||
// 3. Continue traversal while similarity > threshold
|
||||
// 4. Track paths and score them
|
||||
|
||||
let mut paths = Vec::new();
|
||||
|
||||
// Find similar nodes as potential path endpoints
|
||||
let similar = self.find_similar_nodes(query, max_paths)?;
|
||||
|
||||
for match_result in similar {
|
||||
paths.push(SemanticPath {
|
||||
nodes: vec![start_node.clone(), match_result.node_id],
|
||||
edges: vec![],
|
||||
semantic_score: match_result.score,
|
||||
graph_distance: 1,
|
||||
combined_score: self.compute_path_score(match_result.score, 1),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(paths)
|
||||
}
|
||||
|
||||
/// Detect clusters using embeddings
|
||||
pub fn detect_clusters(
|
||||
&self,
|
||||
nodes: &[NodeId],
|
||||
min_cluster_size: usize,
|
||||
) -> Result<Vec<ClusterResult>> {
|
||||
// This is a placeholder for clustering logic
|
||||
// Real implementation would use algorithms like:
|
||||
// - DBSCAN on embedding space
|
||||
// - Community detection on similarity graph
|
||||
// - Hierarchical clustering
|
||||
|
||||
let mut clusters = Vec::new();
|
||||
|
||||
// Simple example: group all nodes as one cluster
|
||||
if nodes.len() >= min_cluster_size {
|
||||
clusters.push(ClusterResult {
|
||||
cluster_id: 0,
|
||||
nodes: nodes.to_vec(),
|
||||
centroid: None,
|
||||
coherence_score: 0.85,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(clusters)
|
||||
}
|
||||
|
||||
/// Find semantically related edges
|
||||
pub fn find_related_edges(&self, query: &[f32], k: usize) -> Result<Vec<EdgeMatch>> {
|
||||
let results = self.index.search_similar_edges(query, k)?;
|
||||
|
||||
// Pre-compute max distance threshold for faster comparison
|
||||
let max_distance = 1.0 - self.config.min_similarity;
|
||||
|
||||
// Convert distance to similarity with single-pass optimization
|
||||
let mut matches = Vec::with_capacity(results.len());
|
||||
for (edge_id, distance) in results {
|
||||
if distance <= max_distance {
|
||||
matches.push(EdgeMatch {
|
||||
edge_id,
|
||||
score: 1.0 - distance,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(matches)
|
||||
}
|
||||
|
||||
/// Compute combined score for a path
|
||||
fn compute_path_score(&self, semantic_score: f32, graph_distance: usize) -> f32 {
|
||||
let w = self.config.semantic_weight;
|
||||
let distance_penalty = 1.0 / (graph_distance as f32 + 1.0);
|
||||
|
||||
w * semantic_score + (1.0 - w) * distance_penalty
|
||||
}
|
||||
|
||||
/// Expand query using similar terms
|
||||
pub fn expand_query(&self, query: &[f32], expansion_factor: usize) -> Result<Vec<Vec<f32>>> {
|
||||
// Find similar embeddings to expand the query
|
||||
let similar = self.index.search_similar_nodes(query, expansion_factor)?;
|
||||
|
||||
// In a real implementation, we would retrieve the actual embeddings
|
||||
// For now, return the original query
|
||||
Ok(vec![query.to_vec()])
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a semantic match
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticMatch {
    /// Matched node.
    pub node_id: NodeId,
    /// Similarity score (1 = identical, 0 = unrelated); converted from the
    /// index's distance in `find_similar_nodes`.
    pub score: f32,
    /// Hops from the query anchor; 0 for direct index hits.
    pub path_length: usize,
}
|
||||
|
||||
/// A semantic path through the graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticPath {
    /// Nodes in the path
    pub nodes: Vec<NodeId>,
    /// Edges connecting the nodes
    pub edges: Vec<EdgeId>,
    /// Semantic similarity score
    pub semantic_score: f32,
    /// Graph distance (number of hops)
    pub graph_distance: usize,
    /// Combined score (semantic + distance), produced by
    /// `SemanticSearch::compute_path_score`
    pub combined_score: f32,
}
|
||||
|
||||
/// Result of clustering analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterResult {
    /// Identifier of the cluster within one `detect_clusters` call.
    pub cluster_id: usize,
    /// Member nodes of the cluster.
    pub nodes: Vec<NodeId>,
    /// Optional centroid embedding; the placeholder `detect_clusters`
    /// implementation leaves this `None`.
    pub centroid: Option<Vec<f32>>,
    /// Cluster cohesion measure; higher means tighter.
    pub coherence_score: f32,
}
|
||||
|
||||
/// Match result for edges
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeMatch {
    /// Matched edge.
    pub edge_id: EdgeId,
    /// Similarity score (1 = identical); converted from the index distance
    /// in `find_related_edges`.
    pub score: f32,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hybrid::vector_index::{EmbeddingConfig, VectorIndexType};

    // Smoke test: engine construction with default configs.
    #[test]
    fn test_semantic_search_creation() {
        let config = EmbeddingConfig::default();
        let index = HybridIndex::new(config).unwrap();
        let search_config = SemanticSearchConfig::default();

        let _search = SemanticSearch::new(index, search_config);
    }

    // Seeded similarity search must return at least one match for a vector
    // identical to an indexed one.
    #[test]
    fn test_find_similar_nodes() -> Result<()> {
        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;

        // Add test embeddings
        index.add_node_embedding("doc1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("doc2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;

        let search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 5)?;

        assert!(!results.is_empty());
        Ok(())
    }

    // Placeholder clustering groups all nodes into one cluster once the
    // minimum size is met.
    #[test]
    fn test_cluster_detection() -> Result<()> {
        let config = EmbeddingConfig::default();
        let index = HybridIndex::new(config)?;
        let search = SemanticSearch::new(index, SemanticSearchConfig::default());

        let nodes = vec!["n1".to_string(), "n2".to_string(), "n3".to_string()];
        let clusters = search.detect_clusters(&nodes, 2)?;

        assert_eq!(clusters.len(), 1);
        Ok(())
    }

    // Distance-to-similarity conversion must always land in [0, 1], with an
    // identical vector scoring close to 1.0.
    #[test]
    fn test_similarity_score_range() -> Result<()> {
        // Verify similarity scores are in [0, 1] range after conversion
        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;

        // Add embeddings with varying similarity
        index.add_node_embedding("identical".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("similar".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;
        index.add_node_embedding("different".to_string(), vec![0.0, 1.0, 0.0, 0.0])?;

        let search_config = SemanticSearchConfig {
            min_similarity: 0.0, // Accept all results for this test
            ..Default::default()
        };
        let search = SemanticSearch::new(index, search_config);
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 10)?;

        // All scores should be in [0, 1]
        for result in &results {
            assert!(
                result.score >= 0.0 && result.score <= 1.0,
                "Score {} out of valid range [0, 1]",
                result.score
            );
        }

        // Identical vector should have highest similarity (close to 1.0)
        if !results.is_empty() {
            let top_result = &results[0];
            assert!(
                top_result.score > 0.9,
                "Identical vector should have score > 0.9"
            );
        }

        Ok(())
    }

    // With a high min_similarity threshold, dissimilar embeddings must not
    // appear in the result set.
    #[test]
    fn test_min_similarity_filtering() -> Result<()> {
        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;

        // Add embeddings
        index.add_node_embedding("high_sim".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("low_sim".to_string(), vec![0.0, 0.0, 0.0, 1.0])?;

        // Set high minimum similarity threshold
        let search_config = SemanticSearchConfig {
            min_similarity: 0.9,
            ..Default::default()
        };
        let search = SemanticSearch::new(index, search_config);
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 10)?;

        // Low similarity result should be filtered out
        for result in &results {
            assert!(
                result.score >= 0.9,
                "Result with score {} should be filtered out (min: 0.9)",
                result.score
            );
        }

        Ok(())
    }
}
|
||||
382
vendor/ruvector/crates/ruvector-graph/src/hybrid/vector_index.rs
vendored
Normal file
382
vendor/ruvector/crates/ruvector-graph/src/hybrid/vector_index.rs
vendored
Normal file
@@ -0,0 +1,382 @@
|
||||
//! Vector indexing for graph elements
|
||||
//!
|
||||
//! Integrates RuVector's index (HNSW or Flat) with graph nodes, edges, and hyperedges.
|
||||
|
||||
use crate::error::{GraphError, Result};
|
||||
use crate::types::{EdgeId, NodeId, Properties, PropertyValue};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use ruvector_core::index::flat::FlatIndex;
|
||||
#[cfg(feature = "hnsw_rs")]
|
||||
use ruvector_core::index::hnsw::HnswIndex;
|
||||
use ruvector_core::index::VectorIndex;
|
||||
#[cfg(feature = "hnsw_rs")]
|
||||
use ruvector_core::types::HnswConfig;
|
||||
use ruvector_core::types::{DistanceMetric, SearchResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Type of graph element that can be indexed
///
/// Selects which of the three independent vector indexes inside
/// `HybridIndex` an operation targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VectorIndexType {
    /// Node embeddings
    Node,
    /// Edge embeddings
    Edge,
    /// Hyperedge embeddings
    Hyperedge,
}
|
||||
|
||||
/// Configuration for embedding storage
///
/// Shared by all three element indexes of a `HybridIndex`; every embedding
/// added must match `dimensions` exactly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Dimension of embeddings
    pub dimensions: usize,
    /// Distance metric for similarity
    pub metric: DistanceMetric,
    /// HNSW index configuration (only used when hnsw_rs feature is enabled)
    #[cfg(feature = "hnsw_rs")]
    pub hnsw_config: HnswConfig,
    /// Property name where embeddings are stored
    /// (read by `HybridIndex::extract_embedding`)
    pub embedding_property: String,
}
|
||||
|
||||
impl Default for EmbeddingConfig {
    // Defaults: 384-dim cosine space with embeddings read from the
    // "embedding" property.
    fn default() -> Self {
        Self {
            dimensions: 384, // Common for small models like MiniLM
            metric: DistanceMetric::Cosine,
            #[cfg(feature = "hnsw_rs")]
            hnsw_config: HnswConfig::default(),
            embedding_property: "embedding".to_string(),
        }
    }
}
|
||||
|
||||
// Index type alias based on feature flags: native builds use HNSW for
// sublinear search; without `hnsw_rs` (e.g. WASM targets) we fall back to
// a brute-force flat index. The two aliases are mutually exclusive.
#[cfg(feature = "hnsw_rs")]
type IndexImpl = HnswIndex;
#[cfg(not(feature = "hnsw_rs"))]
type IndexImpl = FlatIndex;
|
||||
|
||||
/// Hybrid index combining graph structure with vector search
///
/// Each element kind (node / edge / hyperedge) gets its own lazily
/// initialized index (`None` until `initialize_index` is called) plus a
/// concurrent map from graph IDs to the prefixed internal vector IDs.
pub struct HybridIndex {
    /// Node embeddings index
    node_index: Arc<RwLock<Option<IndexImpl>>>,
    /// Edge embeddings index
    edge_index: Arc<RwLock<Option<IndexImpl>>>,
    /// Hyperedge embeddings index
    hyperedge_index: Arc<RwLock<Option<IndexImpl>>>,

    /// Mapping from node IDs to internal vector IDs
    node_id_map: Arc<DashMap<NodeId, String>>,
    /// Mapping from edge IDs to internal vector IDs
    edge_id_map: Arc<DashMap<EdgeId, String>>,
    /// Mapping from hyperedge IDs to internal vector IDs
    hyperedge_id_map: Arc<DashMap<String, String>>,

    /// Configuration
    config: EmbeddingConfig,
}
|
||||
|
||||
impl HybridIndex {
    /// Create a new hybrid index.
    ///
    /// All three element indexes start uninitialized; call
    /// `initialize_index` before adding or searching embeddings. Always
    /// returns `Ok`; the `Result` is kept for API stability.
    pub fn new(config: EmbeddingConfig) -> Result<Self> {
        Ok(Self {
            node_index: Arc::new(RwLock::new(None)),
            edge_index: Arc::new(RwLock::new(None)),
            hyperedge_index: Arc::new(RwLock::new(None)),
            node_id_map: Arc::new(DashMap::new()),
            edge_id_map: Arc::new(DashMap::new()),
            hyperedge_id_map: Arc::new(DashMap::new()),
            config,
        })
    }

    /// Initialize index for a specific element type
    ///
    /// HNSW-backed variant. Replaces any previously initialized index of
    /// that type, discarding its contents.
    #[cfg(feature = "hnsw_rs")]
    pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
        let index = HnswIndex::new(
            self.config.dimensions,
            self.config.metric,
            self.config.hnsw_config.clone(),
        )
        .map_err(|e| GraphError::IndexError(format!("Failed to create HNSW index: {}", e)))?;

        match index_type {
            VectorIndexType::Node => {
                *self.node_index.write() = Some(index);
            }
            VectorIndexType::Edge => {
                *self.edge_index.write() = Some(index);
            }
            VectorIndexType::Hyperedge => {
                *self.hyperedge_index.write() = Some(index);
            }
        }

        Ok(())
    }

    /// Initialize index for a specific element type (Flat index for WASM)
    ///
    /// Mirrors the HNSW variant above; the two are kept in sync manually,
    /// so a change in one must be mirrored in the other.
    #[cfg(not(feature = "hnsw_rs"))]
    pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
        let index = FlatIndex::new(self.config.dimensions, self.config.metric);

        match index_type {
            VectorIndexType::Node => {
                *self.node_index.write() = Some(index);
            }
            VectorIndexType::Edge => {
                *self.edge_index.write() = Some(index);
            }
            VectorIndexType::Hyperedge => {
                *self.hyperedge_index.write() = Some(index);
            }
        }

        Ok(())
    }

    /// Add node embedding to index
    ///
    /// # Errors
    /// `InvalidEmbedding` when the vector's length differs from
    /// `config.dimensions`; `IndexError` when the node index was never
    /// initialized or the underlying insert fails.
    pub fn add_node_embedding(&self, node_id: NodeId, embedding: Vec<f32>) -> Result<()> {
        if embedding.len() != self.config.dimensions {
            return Err(GraphError::InvalidEmbedding(format!(
                "Expected {} dimensions, got {}",
                self.config.dimensions,
                embedding.len()
            )));
        }

        let mut index_guard = self.node_index.write();
        let index = index_guard
            .as_mut()
            .ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;

        // The "node_" prefix namespaces vector IDs per element kind; the
        // search methods strip it again.
        let vector_id = format!("node_{}", node_id);
        index
            .add(vector_id.clone(), embedding)
            .map_err(|e| GraphError::IndexError(format!("Failed to add node embedding: {}", e)))?;

        self.node_id_map.insert(node_id, vector_id);
        Ok(())
    }

    /// Add edge embedding to index
    ///
    /// Same contract as `add_node_embedding`, targeting the edge index.
    pub fn add_edge_embedding(&self, edge_id: EdgeId, embedding: Vec<f32>) -> Result<()> {
        if embedding.len() != self.config.dimensions {
            return Err(GraphError::InvalidEmbedding(format!(
                "Expected {} dimensions, got {}",
                self.config.dimensions,
                embedding.len()
            )));
        }

        let mut index_guard = self.edge_index.write();
        let index = index_guard
            .as_mut()
            .ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;

        let vector_id = format!("edge_{}", edge_id);
        index
            .add(vector_id.clone(), embedding)
            .map_err(|e| GraphError::IndexError(format!("Failed to add edge embedding: {}", e)))?;

        self.edge_id_map.insert(edge_id, vector_id);
        Ok(())
    }

    /// Add hyperedge embedding to index
    ///
    /// Same contract as `add_node_embedding`, targeting the hyperedge index.
    pub fn add_hyperedge_embedding(&self, hyperedge_id: String, embedding: Vec<f32>) -> Result<()> {
        if embedding.len() != self.config.dimensions {
            return Err(GraphError::InvalidEmbedding(format!(
                "Expected {} dimensions, got {}",
                self.config.dimensions,
                embedding.len()
            )));
        }

        let mut index_guard = self.hyperedge_index.write();
        let index = index_guard
            .as_mut()
            .ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;

        let vector_id = format!("hyperedge_{}", hyperedge_id);
        index.add(vector_id.clone(), embedding).map_err(|e| {
            GraphError::IndexError(format!("Failed to add hyperedge embedding: {}", e))
        })?;

        self.hyperedge_id_map.insert(hyperedge_id, vector_id);
        Ok(())
    }

    /// Search for similar nodes
    ///
    /// Returns `(node_id, score)` pairs. NOTE(review): `result.score` is
    /// forwarded unchanged, and callers in `semantic_search` interpret it
    /// as a distance (0 = identical) — confirm against ruvector-core's
    /// `SearchResult` semantics.
    pub fn search_similar_nodes(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
        let index_guard = self.node_index.read();
        let index = index_guard
            .as_ref()
            .ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;

        let results = index
            .search(query, k)
            .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;

        Ok(results
            .into_iter()
            .filter_map(|result| {
                // Remove "node_" prefix to get original ID
                let node_id = result.id.strip_prefix("node_")?.to_string();
                Some((node_id, result.score))
            })
            .collect())
    }

    /// Search for similar edges
    ///
    /// Edge-index counterpart of `search_similar_nodes`; entries without the
    /// expected "edge_" prefix are silently skipped.
    pub fn search_similar_edges(&self, query: &[f32], k: usize) -> Result<Vec<(EdgeId, f32)>> {
        let index_guard = self.edge_index.read();
        let index = index_guard
            .as_ref()
            .ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;

        let results = index
            .search(query, k)
            .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;

        Ok(results
            .into_iter()
            .filter_map(|result| {
                let edge_id = result.id.strip_prefix("edge_")?.to_string();
                Some((edge_id, result.score))
            })
            .collect())
    }

    /// Search for similar hyperedges
    ///
    /// Hyperedge-index counterpart of `search_similar_nodes`.
    pub fn search_similar_hyperedges(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
        let index_guard = self.hyperedge_index.read();
        let index = index_guard
            .as_ref()
            .ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;

        let results = index
            .search(query, k)
            .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;

        Ok(results
            .into_iter()
            .filter_map(|result| {
                let hyperedge_id = result.id.strip_prefix("hyperedge_")?.to_string();
                Some((hyperedge_id, result.score))
            })
            .collect())
    }

    /// Extract embedding from properties
    ///
    /// Reads `config.embedding_property` from the property map.
    /// Returns `Ok(None)` when the property is absent, `Ok(Some(vec))` for
    /// a numeric array (integers are widened to f32), and
    /// `Err(InvalidEmbedding)` for any other shape.
    pub fn extract_embedding(&self, properties: &Properties) -> Result<Option<Vec<f32>>> {
        let prop_value = match properties.get(&self.config.embedding_property) {
            Some(v) => v,
            None => return Ok(None),
        };

        match prop_value {
            PropertyValue::Array(arr) => {
                let embedding: Result<Vec<f32>> = arr
                    .iter()
                    .map(|v| match v {
                        PropertyValue::Float(f) => Ok(*f as f32),
                        PropertyValue::Integer(i) => Ok(*i as f32),
                        _ => Err(GraphError::InvalidEmbedding(
                            "Embedding array must contain numbers".to_string(),
                        )),
                    })
                    .collect();
                embedding.map(Some)
            }
            _ => Err(GraphError::InvalidEmbedding(
                "Embedding property must be an array".to_string(),
            )),
        }
    }

    /// Get index statistics
    ///
    /// Counts are taken from the ID maps, i.e. embeddings registered via the
    /// `add_*_embedding` methods.
    pub fn stats(&self) -> HybridIndexStats {
        let node_count = self.node_id_map.len();
        let edge_count = self.edge_id_map.len();
        let hyperedge_count = self.hyperedge_id_map.len();

        HybridIndexStats {
            node_count,
            edge_count,
            hyperedge_count,
            total_embeddings: node_count + edge_count + hyperedge_count,
        }
    }
}
|
||||
|
||||
/// Statistics about the hybrid index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridIndexStats {
    /// Number of indexed node embeddings.
    pub node_count: usize,
    /// Number of indexed edge embeddings.
    pub edge_count: usize,
    /// Number of indexed hyperedge embeddings.
    pub hyperedge_count: usize,
    /// Sum of the three counts above.
    pub total_embeddings: usize,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh index reports zero embeddings.
    #[test]
    fn test_hybrid_index_creation() -> Result<()> {
        let config = EmbeddingConfig::default();
        let index = HybridIndex::new(config)?;

        let stats = index.stats();
        assert_eq!(stats.total_embeddings, 0);

        Ok(())
    }

    // Adding one node embedding is reflected in the stats.
    #[test]
    fn test_node_embedding_indexing() -> Result<()> {
        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;

        let embedding = vec![1.0, 2.0, 3.0, 4.0];
        index.add_node_embedding("node1".to_string(), embedding)?;

        let stats = index.stats();
        assert_eq!(stats.node_count, 1);

        Ok(())
    }

    // Searching with a vector identical to an indexed one should rank that
    // node first.
    #[test]
    fn test_similarity_search() -> Result<()> {
        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;

        // Add some embeddings
        index.add_node_embedding("node1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("node2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;
        index.add_node_embedding("node3".to_string(), vec![0.0, 1.0, 0.0, 0.0])?;

        // Search for similar to node1
        let results = index.search_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 2)?;

        assert!(results.len() <= 2);
        if !results.is_empty() {
            assert_eq!(results[0].0, "node1");
        }

        Ok(())
    }
}
|
||||
314
vendor/ruvector/crates/ruvector-graph/src/hyperedge.rs
vendored
Normal file
314
vendor/ruvector/crates/ruvector-graph/src/hyperedge.rs
vendored
Normal file
@@ -0,0 +1,314 @@
|
||||
//! N-ary relationship support (hyperedges)
|
||||
//!
|
||||
//! Extends the basic edge model to support relationships connecting multiple nodes
|
||||
|
||||
use crate::types::{NodeId, Properties, PropertyValue};
|
||||
use bincode::{Decode, Encode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for a hyperedge
/// (a UUID string when generated by `Hyperedge::new`)
pub type HyperedgeId = String;
|
||||
|
||||
/// Hyperedge connecting multiple nodes (N-ary relationship)
///
/// NOTE: `nodes` may contain duplicates; use `unique_nodes` to deduplicate.
/// Field order matters for the bincode Encode/Decode derives.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct Hyperedge {
    /// Unique identifier
    pub id: HyperedgeId,
    /// Node IDs connected by this hyperedge
    pub nodes: Vec<NodeId>,
    /// Hyperedge type/label (e.g., "MEETING", "COLLABORATION")
    pub edge_type: String,
    /// Natural language description of the relationship
    pub description: Option<String>,
    /// Property key-value pairs
    pub properties: Properties,
    /// Confidence/weight (0.0-1.0); constructors default this to 1.0
    pub confidence: f32,
}
|
||||
|
||||
impl Hyperedge {
|
||||
/// Create a new hyperedge with generated UUID
|
||||
pub fn new<S: Into<String>>(nodes: Vec<NodeId>, edge_type: S) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
nodes,
|
||||
edge_type: edge_type.into(),
|
||||
description: None,
|
||||
properties: Properties::new(),
|
||||
confidence: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new hyperedge with specific ID
|
||||
pub fn with_id<S: Into<String>>(id: HyperedgeId, nodes: Vec<NodeId>, edge_type: S) -> Self {
|
||||
Self {
|
||||
id,
|
||||
nodes,
|
||||
edge_type: edge_type.into(),
|
||||
description: None,
|
||||
properties: Properties::new(),
|
||||
confidence: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the order of the hyperedge (number of nodes)
|
||||
pub fn order(&self) -> usize {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// Check if hyperedge contains a specific node
|
||||
pub fn contains_node(&self, node_id: &NodeId) -> bool {
|
||||
self.nodes.contains(node_id)
|
||||
}
|
||||
|
||||
/// Check if hyperedge contains all specified nodes
|
||||
pub fn contains_all_nodes(&self, node_ids: &[NodeId]) -> bool {
|
||||
node_ids.iter().all(|id| self.contains_node(id))
|
||||
}
|
||||
|
||||
/// Check if hyperedge contains any of the specified nodes
|
||||
pub fn contains_any_node(&self, node_ids: &[NodeId]) -> bool {
|
||||
node_ids.iter().any(|id| self.contains_node(id))
|
||||
}
|
||||
|
||||
/// Get unique nodes (removes duplicates)
|
||||
pub fn unique_nodes(&self) -> HashSet<&NodeId> {
|
||||
self.nodes.iter().collect()
|
||||
}
|
||||
|
||||
/// Set the description
|
||||
pub fn set_description<S: Into<String>>(&mut self, description: S) -> &mut Self {
|
||||
self.description = Some(description.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the confidence
|
||||
pub fn set_confidence(&mut self, confidence: f32) -> &mut Self {
|
||||
self.confidence = confidence.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a property
|
||||
pub fn set_property<K, V>(&mut self, key: K, value: V) -> &mut Self
|
||||
where
|
||||
K: Into<String>,
|
||||
V: Into<PropertyValue>,
|
||||
{
|
||||
self.properties.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Get a property
|
||||
pub fn get_property(&self, key: &str) -> Option<&PropertyValue> {
|
||||
self.properties.get(key)
|
||||
}
|
||||
|
||||
/// Remove a property
|
||||
pub fn remove_property(&mut self, key: &str) -> Option<PropertyValue> {
|
||||
self.properties.remove(key)
|
||||
}
|
||||
|
||||
/// Check if hyperedge has a property
|
||||
pub fn has_property(&self, key: &str) -> bool {
|
||||
self.properties.contains_key(key)
|
||||
}
|
||||
|
||||
/// Get all property keys
|
||||
pub fn property_keys(&self) -> Vec<&String> {
|
||||
self.properties.keys().collect()
|
||||
}
|
||||
|
||||
/// Clear all properties
|
||||
pub fn clear_properties(&mut self) {
|
||||
self.properties.clear();
|
||||
}
|
||||
|
||||
/// Get the number of properties
|
||||
pub fn property_count(&self) -> usize {
|
||||
self.properties.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating hyperedges with fluent API
|
||||
pub struct HyperedgeBuilder {
|
||||
hyperedge: Hyperedge,
|
||||
}
|
||||
|
||||
impl HyperedgeBuilder {
|
||||
/// Create a new builder
|
||||
pub fn new<S: Into<String>>(nodes: Vec<NodeId>, edge_type: S) -> Self {
|
||||
Self {
|
||||
hyperedge: Hyperedge::new(nodes, edge_type),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create builder with specific ID
|
||||
pub fn with_id<S: Into<String>>(id: HyperedgeId, nodes: Vec<NodeId>, edge_type: S) -> Self {
|
||||
Self {
|
||||
hyperedge: Hyperedge::with_id(id, nodes, edge_type),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set description
|
||||
pub fn description<S: Into<String>>(mut self, description: S) -> Self {
|
||||
self.hyperedge.set_description(description);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set confidence
|
||||
pub fn confidence(mut self, confidence: f32) -> Self {
|
||||
self.hyperedge.set_confidence(confidence);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a property
|
||||
pub fn property<K, V>(mut self, key: K, value: V) -> Self
|
||||
where
|
||||
K: Into<String>,
|
||||
V: Into<PropertyValue>,
|
||||
{
|
||||
self.hyperedge.set_property(key, value);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the hyperedge
|
||||
pub fn build(self) -> Hyperedge {
|
||||
self.hyperedge
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperedge role assignment for directed N-ary relationships
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
|
||||
pub struct HyperedgeWithRoles {
|
||||
/// Base hyperedge
|
||||
pub hyperedge: Hyperedge,
|
||||
/// Role assignments: node_id -> role
|
||||
pub roles: std::collections::HashMap<NodeId, String>,
|
||||
}
|
||||
|
||||
impl HyperedgeWithRoles {
|
||||
/// Create a new hyperedge with roles
|
||||
pub fn new(hyperedge: Hyperedge) -> Self {
|
||||
Self {
|
||||
hyperedge,
|
||||
roles: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a role to a node
|
||||
pub fn assign_role<S: Into<String>>(&mut self, node_id: NodeId, role: S) -> &mut Self {
|
||||
self.roles.insert(node_id, role.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the role of a node
|
||||
pub fn get_role(&self, node_id: &NodeId) -> Option<&String> {
|
||||
self.roles.get(node_id)
|
||||
}
|
||||
|
||||
/// Get all nodes with a specific role
|
||||
pub fn nodes_with_role(&self, role: &str) -> Vec<&NodeId> {
|
||||
self.roles
|
||||
.iter()
|
||||
.filter(|(_, r)| r.as_str() == role)
|
||||
.map(|(id, _)| id)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` should generate a non-empty UUID id and default confidence 1.0.
    #[test]
    fn test_hyperedge_creation() {
        let nodes = vec![
            "node1".to_string(),
            "node2".to_string(),
            "node3".to_string(),
        ];
        let hedge = Hyperedge::new(nodes, "MEETING");

        assert!(!hedge.id.is_empty());
        assert_eq!(hedge.order(), 3);
        assert_eq!(hedge.edge_type, "MEETING");
        assert_eq!(hedge.confidence, 1.0);
    }

    /// Membership queries: single-node, all-of, and any-of variants.
    #[test]
    fn test_hyperedge_contains() {
        let nodes = vec![
            "node1".to_string(),
            "node2".to_string(),
            "node3".to_string(),
        ];
        let hedge = Hyperedge::new(nodes, "MEETING");

        assert!(hedge.contains_node(&"node1".to_string()));
        assert!(hedge.contains_node(&"node2".to_string()));
        assert!(!hedge.contains_node(&"node4".to_string()));

        // all-of: fails as soon as one member is missing
        assert!(hedge.contains_all_nodes(&["node1".to_string(), "node2".to_string()]));
        assert!(!hedge.contains_all_nodes(&["node1".to_string(), "node4".to_string()]));

        // any-of: succeeds if at least one member is present
        assert!(hedge.contains_any_node(&["node1".to_string(), "node4".to_string()]));
        assert!(!hedge.contains_any_node(&["node4".to_string(), "node5".to_string()]));
    }

    /// Fluent builder sets type, confidence, description, and properties.
    #[test]
    fn test_hyperedge_builder() {
        let nodes = vec!["node1".to_string(), "node2".to_string()];
        let hedge = HyperedgeBuilder::new(nodes, "COLLABORATION")
            .description("Team collaboration on project X")
            .confidence(0.95)
            .property("project", "X")
            .property("duration", 30i64)
            .build();

        assert_eq!(hedge.edge_type, "COLLABORATION");
        assert_eq!(hedge.confidence, 0.95);
        assert!(hedge.description.is_some());
        assert_eq!(
            hedge.get_property("project"),
            Some(&PropertyValue::String("X".to_string()))
        );
    }

    /// Roles: per-node lookup and reverse lookup by role name.
    #[test]
    fn test_hyperedge_with_roles() {
        let nodes = vec![
            "alice".to_string(),
            "bob".to_string(),
            "charlie".to_string(),
        ];
        let hedge = Hyperedge::new(nodes, "MEETING");

        let mut hedge_with_roles = HyperedgeWithRoles::new(hedge);
        hedge_with_roles.assign_role("alice".to_string(), "organizer");
        hedge_with_roles.assign_role("bob".to_string(), "participant");
        hedge_with_roles.assign_role("charlie".to_string(), "participant");

        assert_eq!(
            hedge_with_roles.get_role(&"alice".to_string()),
            Some(&"organizer".to_string())
        );

        let participants = hedge_with_roles.nodes_with_role("participant");
        assert_eq!(participants.len(), 2);
    }

    /// `unique_nodes` collapses duplicate node entries.
    #[test]
    fn test_unique_nodes() {
        let nodes = vec![
            "node1".to_string(),
            "node2".to_string(),
            "node1".to_string(), // duplicate
        ];
        let hedge = Hyperedge::new(nodes, "TEST");

        let unique = hedge.unique_nodes();
        assert_eq!(unique.len(), 2);
    }
}
|
||||
472
vendor/ruvector/crates/ruvector-graph/src/index.rs
vendored
Normal file
472
vendor/ruvector/crates/ruvector-graph/src/index.rs
vendored
Normal file
@@ -0,0 +1,472 @@
|
||||
//! Index structures for fast node and edge lookups
|
||||
//!
|
||||
//! Provides label indexes, property indexes, and edge type indexes for efficient querying
|
||||
|
||||
use crate::edge::Edge;
|
||||
use crate::hyperedge::{Hyperedge, HyperedgeId};
|
||||
use crate::node::Node;
|
||||
use crate::types::{EdgeId, NodeId, PropertyValue};
|
||||
use dashmap::DashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Label index for nodes (maps labels to node IDs)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LabelIndex {
|
||||
/// Label -> Set of node IDs
|
||||
index: Arc<DashMap<String, HashSet<NodeId>>>,
|
||||
}
|
||||
|
||||
impl LabelIndex {
|
||||
/// Create a new label index
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
index: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to the index
|
||||
pub fn add_node(&self, node: &Node) {
|
||||
for label in &node.labels {
|
||||
self.index
|
||||
.entry(label.name.clone())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(node.id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a node from the index
|
||||
pub fn remove_node(&self, node: &Node) {
|
||||
for label in &node.labels {
|
||||
if let Some(mut set) = self.index.get_mut(&label.name) {
|
||||
set.remove(&node.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all nodes with a specific label
|
||||
pub fn get_nodes_by_label(&self, label: &str) -> Vec<NodeId> {
|
||||
self.index
|
||||
.get(label)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all labels in the index
|
||||
pub fn all_labels(&self) -> Vec<String> {
|
||||
self.index.iter().map(|entry| entry.key().clone()).collect()
|
||||
}
|
||||
|
||||
/// Count nodes with a specific label
|
||||
pub fn count_by_label(&self, label: &str) -> usize {
|
||||
self.index.get(label).map(|set| set.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Clear the index
|
||||
pub fn clear(&self) {
|
||||
self.index.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LabelIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Property index for nodes (maps property keys to values to node IDs)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PropertyIndex {
|
||||
/// Property key -> Property value -> Set of node IDs
|
||||
index: Arc<DashMap<String, DashMap<String, HashSet<NodeId>>>>,
|
||||
}
|
||||
|
||||
impl PropertyIndex {
|
||||
/// Create a new property index
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
index: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to the index
|
||||
pub fn add_node(&self, node: &Node) {
|
||||
for (key, value) in &node.properties {
|
||||
let value_str = self.property_value_to_string(value);
|
||||
self.index
|
||||
.entry(key.clone())
|
||||
.or_insert_with(DashMap::new)
|
||||
.entry(value_str)
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(node.id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a node from the index
|
||||
pub fn remove_node(&self, node: &Node) {
|
||||
for (key, value) in &node.properties {
|
||||
let value_str = self.property_value_to_string(value);
|
||||
if let Some(value_map) = self.index.get(key) {
|
||||
if let Some(mut set) = value_map.get_mut(&value_str) {
|
||||
set.remove(&node.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get nodes by property key-value pair
|
||||
pub fn get_nodes_by_property(&self, key: &str, value: &PropertyValue) -> Vec<NodeId> {
|
||||
let value_str = self.property_value_to_string(value);
|
||||
self.index
|
||||
.get(key)
|
||||
.and_then(|value_map| {
|
||||
value_map
|
||||
.get(&value_str)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all nodes that have a specific property key (regardless of value)
|
||||
pub fn get_nodes_with_property(&self, key: &str) -> Vec<NodeId> {
|
||||
self.index
|
||||
.get(key)
|
||||
.map(|value_map| {
|
||||
let mut result = HashSet::new();
|
||||
for entry in value_map.iter() {
|
||||
for id in entry.value().iter() {
|
||||
result.insert(id.clone());
|
||||
}
|
||||
}
|
||||
result.into_iter().collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all property keys in the index
|
||||
pub fn all_property_keys(&self) -> Vec<String> {
|
||||
self.index.iter().map(|entry| entry.key().clone()).collect()
|
||||
}
|
||||
|
||||
/// Clear the index
|
||||
pub fn clear(&self) {
|
||||
self.index.clear();
|
||||
}
|
||||
|
||||
/// Convert property value to string for indexing
|
||||
fn property_value_to_string(&self, value: &PropertyValue) -> String {
|
||||
match value {
|
||||
PropertyValue::Null => "null".to_string(),
|
||||
PropertyValue::Boolean(b) => b.to_string(),
|
||||
PropertyValue::Integer(i) => i.to_string(),
|
||||
PropertyValue::Float(f) => f.to_string(),
|
||||
PropertyValue::String(s) => s.clone(),
|
||||
PropertyValue::Array(_) | PropertyValue::List(_) => format!("{:?}", value),
|
||||
PropertyValue::Map(_) => format!("{:?}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PropertyIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Edge type index (maps edge types to edge IDs)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EdgeTypeIndex {
|
||||
/// Edge type -> Set of edge IDs
|
||||
index: Arc<DashMap<String, HashSet<EdgeId>>>,
|
||||
}
|
||||
|
||||
impl EdgeTypeIndex {
|
||||
/// Create a new edge type index
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
index: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an edge to the index
|
||||
pub fn add_edge(&self, edge: &Edge) {
|
||||
self.index
|
||||
.entry(edge.edge_type.clone())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(edge.id.clone());
|
||||
}
|
||||
|
||||
/// Remove an edge from the index
|
||||
pub fn remove_edge(&self, edge: &Edge) {
|
||||
if let Some(mut set) = self.index.get_mut(&edge.edge_type) {
|
||||
set.remove(&edge.id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all edges of a specific type
|
||||
pub fn get_edges_by_type(&self, edge_type: &str) -> Vec<EdgeId> {
|
||||
self.index
|
||||
.get(edge_type)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all edge types
|
||||
pub fn all_edge_types(&self) -> Vec<String> {
|
||||
self.index.iter().map(|entry| entry.key().clone()).collect()
|
||||
}
|
||||
|
||||
/// Count edges of a specific type
|
||||
pub fn count_by_type(&self, edge_type: &str) -> usize {
|
||||
self.index.get(edge_type).map(|set| set.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Clear the index
|
||||
pub fn clear(&self) {
|
||||
self.index.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EdgeTypeIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjacency index for fast neighbor lookups
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdjacencyIndex {
|
||||
/// Node ID -> Set of outgoing edge IDs
|
||||
outgoing: Arc<DashMap<NodeId, HashSet<EdgeId>>>,
|
||||
/// Node ID -> Set of incoming edge IDs
|
||||
incoming: Arc<DashMap<NodeId, HashSet<EdgeId>>>,
|
||||
}
|
||||
|
||||
impl AdjacencyIndex {
|
||||
/// Create a new adjacency index
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
outgoing: Arc::new(DashMap::new()),
|
||||
incoming: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an edge to the index
|
||||
pub fn add_edge(&self, edge: &Edge) {
|
||||
self.outgoing
|
||||
.entry(edge.from.clone())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(edge.id.clone());
|
||||
|
||||
self.incoming
|
||||
.entry(edge.to.clone())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(edge.id.clone());
|
||||
}
|
||||
|
||||
/// Remove an edge from the index
|
||||
pub fn remove_edge(&self, edge: &Edge) {
|
||||
if let Some(mut set) = self.outgoing.get_mut(&edge.from) {
|
||||
set.remove(&edge.id);
|
||||
}
|
||||
if let Some(mut set) = self.incoming.get_mut(&edge.to) {
|
||||
set.remove(&edge.id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all outgoing edges from a node
|
||||
pub fn get_outgoing_edges(&self, node_id: &NodeId) -> Vec<EdgeId> {
|
||||
self.outgoing
|
||||
.get(node_id)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all incoming edges to a node
|
||||
pub fn get_incoming_edges(&self, node_id: &NodeId) -> Vec<EdgeId> {
|
||||
self.incoming
|
||||
.get(node_id)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all edges connected to a node (both incoming and outgoing)
|
||||
pub fn get_all_edges(&self, node_id: &NodeId) -> Vec<EdgeId> {
|
||||
let mut edges = self.get_outgoing_edges(node_id);
|
||||
edges.extend(self.get_incoming_edges(node_id));
|
||||
edges
|
||||
}
|
||||
|
||||
/// Get degree (number of outgoing edges)
|
||||
pub fn out_degree(&self, node_id: &NodeId) -> usize {
|
||||
self.outgoing.get(node_id).map(|set| set.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get in-degree (number of incoming edges)
|
||||
pub fn in_degree(&self, node_id: &NodeId) -> usize {
|
||||
self.incoming.get(node_id).map(|set| set.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Clear the index
|
||||
pub fn clear(&self) {
|
||||
self.outgoing.clear();
|
||||
self.incoming.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AdjacencyIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperedge node index (maps nodes to hyperedges they participate in)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HyperedgeNodeIndex {
|
||||
/// Node ID -> Set of hyperedge IDs
|
||||
index: Arc<DashMap<NodeId, HashSet<HyperedgeId>>>,
|
||||
}
|
||||
|
||||
impl HyperedgeNodeIndex {
|
||||
/// Create a new hyperedge node index
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
index: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a hyperedge to the index
|
||||
pub fn add_hyperedge(&self, hyperedge: &Hyperedge) {
|
||||
for node_id in &hyperedge.nodes {
|
||||
self.index
|
||||
.entry(node_id.clone())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(hyperedge.id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a hyperedge from the index
|
||||
pub fn remove_hyperedge(&self, hyperedge: &Hyperedge) {
|
||||
for node_id in &hyperedge.nodes {
|
||||
if let Some(mut set) = self.index.get_mut(node_id) {
|
||||
set.remove(&hyperedge.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all hyperedges containing a node
|
||||
pub fn get_hyperedges_by_node(&self, node_id: &NodeId) -> Vec<HyperedgeId> {
|
||||
self.index
|
||||
.get(node_id)
|
||||
.map(|set| set.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Clear the index
|
||||
pub fn clear(&self) {
|
||||
self.index.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HyperedgeNodeIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::NodeBuilder;

    /// Label buckets: counts and lookups across shared and unique labels.
    #[test]
    fn test_label_index() {
        let index = LabelIndex::new();

        let node1 = NodeBuilder::new().label("Person").label("User").build();

        let node2 = NodeBuilder::new().label("Person").build();

        index.add_node(&node1);
        index.add_node(&node2);

        let people = index.get_nodes_by_label("Person");
        assert_eq!(people.len(), 2);

        let users = index.get_nodes_by_label("User");
        assert_eq!(users.len(), 1);

        assert_eq!(index.count_by_label("Person"), 2);
    }

    /// Property lookups by exact key/value and by key alone.
    #[test]
    fn test_property_index() {
        let index = PropertyIndex::new();

        let node1 = NodeBuilder::new()
            .property("name", "Alice")
            .property("age", 30i64)
            .build();

        let node2 = NodeBuilder::new()
            .property("name", "Bob")
            .property("age", 30i64)
            .build();

        index.add_node(&node1);
        index.add_node(&node2);

        let alice =
            index.get_nodes_by_property("name", &PropertyValue::String("Alice".to_string()));
        assert_eq!(alice.len(), 1);

        // Both nodes share age == 30
        let age_30 = index.get_nodes_by_property("age", &PropertyValue::Integer(30));
        assert_eq!(age_30.len(), 2);

        let with_age = index.get_nodes_with_property("age");
        assert_eq!(with_age.len(), 2);
    }

    /// Edge-type buckets: per-type lookups and distinct-type count.
    #[test]
    fn test_edge_type_index() {
        let index = EdgeTypeIndex::new();

        let edge1 = Edge::create("n1".to_string(), "n2".to_string(), "KNOWS");
        let edge2 = Edge::create("n2".to_string(), "n3".to_string(), "KNOWS");
        let edge3 = Edge::create("n1".to_string(), "n3".to_string(), "WORKS_WITH");

        index.add_edge(&edge1);
        index.add_edge(&edge2);
        index.add_edge(&edge3);

        let knows_edges = index.get_edges_by_type("KNOWS");
        assert_eq!(knows_edges.len(), 2);

        let works_with_edges = index.get_edges_by_type("WORKS_WITH");
        assert_eq!(works_with_edges.len(), 1);

        assert_eq!(index.all_edge_types().len(), 2);
    }

    /// In/out degrees and directional edge lists for a hub node.
    #[test]
    fn test_adjacency_index() {
        let index = AdjacencyIndex::new();

        // n1 -> n2, n1 -> n3, n2 -> n1
        let edge1 = Edge::create("n1".to_string(), "n2".to_string(), "KNOWS");
        let edge2 = Edge::create("n1".to_string(), "n3".to_string(), "KNOWS");
        let edge3 = Edge::create("n2".to_string(), "n1".to_string(), "KNOWS");

        index.add_edge(&edge1);
        index.add_edge(&edge2);
        index.add_edge(&edge3);

        assert_eq!(index.out_degree(&"n1".to_string()), 2);
        assert_eq!(index.in_degree(&"n1".to_string()), 1);

        let outgoing = index.get_outgoing_edges(&"n1".to_string());
        assert_eq!(outgoing.len(), 2);

        let incoming = index.get_incoming_edges(&"n1".to_string());
        assert_eq!(incoming.len(), 1);
    }
}
|
||||
61
vendor/ruvector/crates/ruvector-graph/src/lib.rs
vendored
Normal file
61
vendor/ruvector/crates/ruvector-graph/src/lib.rs
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
//! # RuVector Graph Database
|
||||
//!
|
||||
//! A high-performance graph database layer built on RuVector with Neo4j compatibility.
|
||||
//! Supports property graphs, hypergraphs, Cypher queries, ACID transactions, and distributed queries.
|
||||
|
||||
pub mod cypher;
|
||||
pub mod edge;
|
||||
pub mod error;
|
||||
pub mod executor;
|
||||
pub mod graph;
|
||||
pub mod hyperedge;
|
||||
pub mod index;
|
||||
pub mod node;
|
||||
pub mod property;
|
||||
pub mod storage;
|
||||
pub mod transaction;
|
||||
pub mod types;
|
||||
|
||||
// Performance optimization modules
|
||||
pub mod optimization;
|
||||
|
||||
// Vector-graph hybrid query capabilities
|
||||
pub mod hybrid;
|
||||
|
||||
// Distributed graph capabilities
|
||||
#[cfg(feature = "distributed")]
|
||||
pub mod distributed;
|
||||
|
||||
// Core type re-exports
|
||||
pub use edge::{Edge, EdgeBuilder};
|
||||
pub use error::{GraphError, Result};
|
||||
pub use graph::GraphDB;
|
||||
pub use hyperedge::{Hyperedge, HyperedgeBuilder, HyperedgeId};
|
||||
pub use node::{Node, NodeBuilder};
|
||||
#[cfg(feature = "storage")]
|
||||
pub use storage::GraphStorage;
|
||||
pub use transaction::{IsolationLevel, Transaction, TransactionManager};
|
||||
pub use types::{EdgeId, Label, NodeId, Properties, PropertyValue, RelationType};
|
||||
|
||||
// Re-export hybrid query types when available
|
||||
#[cfg(not(feature = "minimal"))]
|
||||
pub use hybrid::{
|
||||
EmbeddingConfig, GnnConfig, GraphNeuralEngine, HybridIndex, RagConfig, RagEngine,
|
||||
SemanticSearch, VectorCypherParser,
|
||||
};
|
||||
|
||||
// Re-export distributed types when feature is enabled
|
||||
#[cfg(feature = "distributed")]
|
||||
pub use distributed::{
|
||||
Coordinator, Federation, GossipMembership, GraphReplication, GraphShard, RpcClient, RpcServer,
|
||||
ShardCoordinator, ShardStrategy,
|
||||
};
|
||||
|
||||
#[cfg(test)]
mod tests {
    /// Smoke test: exists so `cargo test` has at least one target in this
    /// crate even when feature-gated suites are disabled.
    #[test]
    fn test_placeholder() {
        // Intentionally empty: reaching this point proves the crate compiled
        // and the test harness links. The previous `assert!(true)` was a
        // constant no-op (clippy: assertions_on_constants) and was removed.
    }
}
|
||||
149
vendor/ruvector/crates/ruvector-graph/src/node.rs
vendored
Normal file
149
vendor/ruvector/crates/ruvector-graph/src/node.rs
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
//! Node implementation
|
||||
|
||||
use crate::types::{Label, NodeId, Properties, PropertyValue};
|
||||
use bincode::{Decode, Encode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
|
||||
pub struct Node {
|
||||
pub id: NodeId,
|
||||
pub labels: Vec<Label>,
|
||||
pub properties: Properties,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
pub fn new(id: NodeId, labels: Vec<Label>, properties: Properties) -> Self {
|
||||
Self {
|
||||
id,
|
||||
labels,
|
||||
properties,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if node has a specific label
|
||||
pub fn has_label(&self, label_name: &str) -> bool {
|
||||
self.labels.iter().any(|l| l.name == label_name)
|
||||
}
|
||||
|
||||
/// Get a property value by key
|
||||
pub fn get_property(&self, key: &str) -> Option<&PropertyValue> {
|
||||
self.properties.get(key)
|
||||
}
|
||||
|
||||
/// Set a property value
|
||||
pub fn set_property(&mut self, key: impl Into<String>, value: PropertyValue) {
|
||||
self.properties.insert(key.into(), value);
|
||||
}
|
||||
|
||||
/// Add a label to the node
|
||||
pub fn add_label(&mut self, label: impl Into<String>) {
|
||||
self.labels.push(Label::new(label));
|
||||
}
|
||||
|
||||
/// Remove a label from the node
|
||||
pub fn remove_label(&mut self, label_name: &str) -> bool {
|
||||
let len_before = self.labels.len();
|
||||
self.labels.retain(|l| l.name != label_name);
|
||||
self.labels.len() < len_before
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for constructing Node instances
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct NodeBuilder {
|
||||
id: Option<NodeId>,
|
||||
labels: Vec<Label>,
|
||||
properties: Properties,
|
||||
}
|
||||
|
||||
impl NodeBuilder {
|
||||
/// Create a new node builder
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set the node ID
|
||||
pub fn id(mut self, id: impl Into<String>) -> Self {
|
||||
self.id = Some(id.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a label to the node
|
||||
pub fn label(mut self, label: impl Into<String>) -> Self {
|
||||
self.labels.push(Label::new(label));
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple labels to the node
|
||||
pub fn labels(mut self, labels: impl IntoIterator<Item = impl Into<String>>) -> Self {
|
||||
for label in labels {
|
||||
self.labels.push(Label::new(label));
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a property to the node
|
||||
pub fn property<V: Into<PropertyValue>>(mut self, key: impl Into<String>, value: V) -> Self {
|
||||
self.properties.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple properties to the node
|
||||
pub fn properties(mut self, props: Properties) -> Self {
|
||||
self.properties.extend(props);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the node
|
||||
pub fn build(self) -> Node {
|
||||
Node {
|
||||
id: self.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
labels: self.labels,
|
||||
properties: self.properties,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder should attach labels and properties to the built node.
    #[test]
    fn test_node_builder() {
        let node = NodeBuilder::new()
            .label("Person")
            .property("name", "Alice")
            .property("age", 30i64)
            .build();

        assert!(node.has_label("Person"));
        assert!(!node.has_label("Organization"));
        assert_eq!(
            node.get_property("name"),
            Some(&PropertyValue::String("Alice".to_string()))
        );
    }

    /// `has_label` matches any of several labels, not just the first.
    #[test]
    fn test_node_has_label() {
        let node = NodeBuilder::new().label("Person").label("Employee").build();

        assert!(node.has_label("Person"));
        assert!(node.has_label("Employee"));
        assert!(!node.has_label("Company"));
    }

    /// Labels can be added and removed after construction.
    #[test]
    fn test_node_modify_labels() {
        let mut node = NodeBuilder::new().label("Person").build();

        node.add_label("Employee");
        assert!(node.has_label("Employee"));

        let removed = node.remove_label("Person");
        assert!(removed);
        assert!(!node.has_label("Person"));
    }
}
|
||||
498
vendor/ruvector/crates/ruvector-graph/src/optimization/adaptive_radix.rs
vendored
Normal file
498
vendor/ruvector/crates/ruvector-graph/src/optimization/adaptive_radix.rs
vendored
Normal file
@@ -0,0 +1,498 @@
|
||||
//! Adaptive Radix Tree (ART) for property indexes
|
||||
//!
|
||||
//! ART provides space-efficient indexing with excellent cache performance
|
||||
//! through adaptive node sizes and path compression.
|
||||
|
||||
use std::cmp::Ordering;
|
||||
use std::mem;
|
||||
|
||||
/// Adaptive Radix Tree for property indexing
pub struct AdaptiveRadixTree<V: Clone> {
    // Root of the tree; `None` means the tree is empty.
    root: Option<Box<ArtNode<V>>>,
    // Count of completed insert operations — presumably meant to be the
    // number of distinct keys; TODO confirm how duplicate-key inserts
    // are counted.
    size: usize,
}
|
||||
|
||||
impl<V: Clone> AdaptiveRadixTree<V> {
|
||||
    /// Create an empty tree (no root, zero size).
    pub fn new() -> Self {
        Self {
            root: None,
            size: 0,
        }
    }
|
||||
|
||||
    /// Insert a key-value pair.
    ///
    /// NOTE(review): `size` is incremented unconditionally on every call,
    /// but the recursive insert replaces the value in place when `key`
    /// already exists — so inserting a duplicate key appears to inflate
    /// `size`. Confirm whether `size` should count distinct keys and, if
    /// so, have the recursive insert report replacements.
    pub fn insert(&mut self, key: &[u8], value: V) {
        // Empty tree: the new entry simply becomes the root leaf.
        if self.root.is_none() {
            self.root = Some(Box::new(ArtNode::Leaf {
                key: key.to_vec(),
                value,
            }));
            self.size += 1;
            return;
        }

        // Non-empty tree: rebuild the root through the recursive insert,
        // which may split a leaf into an inner node.
        let root = self.root.take().unwrap();
        self.root = Some(Self::insert_recursive(root, key, 0, value));
        self.size += 1;
    }
|
||||
|
||||
fn insert_recursive(
|
||||
mut node: Box<ArtNode<V>>,
|
||||
key: &[u8],
|
||||
depth: usize,
|
||||
value: V,
|
||||
) -> Box<ArtNode<V>> {
|
||||
match node.as_mut() {
|
||||
ArtNode::Leaf {
|
||||
key: leaf_key,
|
||||
value: leaf_value,
|
||||
} => {
|
||||
// Check if keys are identical
|
||||
if *leaf_key == key {
|
||||
// Replace value
|
||||
*leaf_value = value;
|
||||
return node;
|
||||
}
|
||||
|
||||
// Find common prefix length starting from depth
|
||||
let common_prefix_len = Self::common_prefix_len(leaf_key, key, depth);
|
||||
let prefix = if depth + common_prefix_len <= leaf_key.len()
|
||||
&& depth + common_prefix_len <= key.len()
|
||||
{
|
||||
key[depth..depth + common_prefix_len].to_vec()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Create a new Node4 to hold both leaves
|
||||
let mut children: [Option<Box<ArtNode<V>>>; 4] = [None, None, None, None];
|
||||
let mut keys_arr = [0u8; 4];
|
||||
let mut num_children = 0u8;
|
||||
|
||||
let next_depth = depth + common_prefix_len;
|
||||
|
||||
// Get the distinguishing bytes for old and new keys
|
||||
let old_byte = if next_depth < leaf_key.len() {
|
||||
Some(leaf_key[next_depth])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let new_byte = if next_depth < key.len() {
|
||||
Some(key[next_depth])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Take ownership of old leaf's data
|
||||
let old_key = std::mem::take(leaf_key);
|
||||
let old_value = unsafe { std::ptr::read(leaf_value) };
|
||||
|
||||
// Add old leaf
|
||||
if let Some(byte) = old_byte {
|
||||
keys_arr[num_children as usize] = byte;
|
||||
children[num_children as usize] = Some(Box::new(ArtNode::Leaf {
|
||||
key: old_key,
|
||||
value: old_value,
|
||||
}));
|
||||
num_children += 1;
|
||||
}
|
||||
|
||||
// Add new leaf
|
||||
if let Some(byte) = new_byte {
|
||||
// Find insertion position (keep sorted for efficiency)
|
||||
let mut insert_idx = num_children as usize;
|
||||
for i in 0..num_children as usize {
|
||||
if byte < keys_arr[i] {
|
||||
insert_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Shift existing entries if needed
|
||||
for i in (insert_idx..num_children as usize).rev() {
|
||||
keys_arr[i + 1] = keys_arr[i];
|
||||
children[i + 1] = children[i].take();
|
||||
}
|
||||
|
||||
keys_arr[insert_idx] = byte;
|
||||
children[insert_idx] = Some(Box::new(ArtNode::Leaf {
|
||||
key: key.to_vec(),
|
||||
value,
|
||||
}));
|
||||
num_children += 1;
|
||||
}
|
||||
|
||||
Box::new(ArtNode::Node4 {
|
||||
prefix,
|
||||
children,
|
||||
keys: keys_arr,
|
||||
num_children,
|
||||
})
|
||||
}
|
||||
ArtNode::Node4 {
|
||||
prefix,
|
||||
children,
|
||||
keys,
|
||||
num_children,
|
||||
} => {
|
||||
// Check prefix match
|
||||
let prefix_match = Self::check_prefix(prefix, key, depth);
|
||||
|
||||
if prefix_match < prefix.len() {
|
||||
// Prefix mismatch - need to split the node
|
||||
let common = prefix[..prefix_match].to_vec();
|
||||
let remaining = prefix[prefix_match..].to_vec();
|
||||
let old_byte = remaining[0];
|
||||
|
||||
// Create new inner node with remaining prefix
|
||||
let old_children = std::mem::replace(children, [None, None, None, None]);
|
||||
let old_keys = *keys;
|
||||
let old_num = *num_children;
|
||||
|
||||
let inner_node = Box::new(ArtNode::Node4 {
|
||||
prefix: remaining[1..].to_vec(),
|
||||
children: old_children,
|
||||
keys: old_keys,
|
||||
num_children: old_num,
|
||||
});
|
||||
|
||||
// Create new leaf for the inserted key
|
||||
let next_depth = depth + prefix_match;
|
||||
let new_byte = if next_depth < key.len() {
|
||||
key[next_depth]
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let new_leaf = Box::new(ArtNode::Leaf {
|
||||
key: key.to_vec(),
|
||||
value,
|
||||
});
|
||||
|
||||
// Set up new node
|
||||
let mut new_children: [Option<Box<ArtNode<V>>>; 4] = [None, None, None, None];
|
||||
let mut new_keys = [0u8; 4];
|
||||
|
||||
if old_byte < new_byte {
|
||||
new_keys[0] = old_byte;
|
||||
new_children[0] = Some(inner_node);
|
||||
new_keys[1] = new_byte;
|
||||
new_children[1] = Some(new_leaf);
|
||||
} else {
|
||||
new_keys[0] = new_byte;
|
||||
new_children[0] = Some(new_leaf);
|
||||
new_keys[1] = old_byte;
|
||||
new_children[1] = Some(inner_node);
|
||||
}
|
||||
|
||||
return Box::new(ArtNode::Node4 {
|
||||
prefix: common,
|
||||
children: new_children,
|
||||
keys: new_keys,
|
||||
num_children: 2,
|
||||
});
|
||||
}
|
||||
|
||||
// Full prefix match - traverse to child
|
||||
let next_depth = depth + prefix.len();
|
||||
if next_depth < key.len() {
|
||||
let key_byte = key[next_depth];
|
||||
|
||||
// Find existing child
|
||||
for i in 0..(*num_children as usize) {
|
||||
if keys[i] == key_byte {
|
||||
let child = children[i].take().unwrap();
|
||||
children[i] =
|
||||
Some(Self::insert_recursive(child, key, next_depth + 1, value));
|
||||
return node;
|
||||
}
|
||||
}
|
||||
|
||||
// No matching child - add new one
|
||||
if (*num_children as usize) < 4 {
|
||||
let idx = *num_children as usize;
|
||||
keys[idx] = key_byte;
|
||||
children[idx] = Some(Box::new(ArtNode::Leaf {
|
||||
key: key.to_vec(),
|
||||
value,
|
||||
}));
|
||||
*num_children += 1;
|
||||
}
|
||||
// TODO: Handle node growth to Node16 when full
|
||||
}
|
||||
|
||||
node
|
||||
}
|
||||
_ => {
|
||||
// Handle other node types (Node16, Node48, Node256)
|
||||
node
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Search for a value by key
|
||||
pub fn get(&self, key: &[u8]) -> Option<&V> {
|
||||
let mut current = self.root.as_ref()?;
|
||||
let mut depth = 0;
|
||||
|
||||
loop {
|
||||
match current.as_ref() {
|
||||
ArtNode::Leaf {
|
||||
key: leaf_key,
|
||||
value,
|
||||
} => {
|
||||
if leaf_key == key {
|
||||
return Some(value);
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
ArtNode::Node4 {
|
||||
prefix,
|
||||
children,
|
||||
keys,
|
||||
num_children,
|
||||
} => {
|
||||
if !Self::match_prefix(prefix, key, depth) {
|
||||
return None;
|
||||
}
|
||||
|
||||
depth += prefix.len();
|
||||
if depth >= key.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let key_byte = key[depth];
|
||||
let mut found = false;
|
||||
|
||||
for i in 0..*num_children as usize {
|
||||
if keys[i] == key_byte {
|
||||
current = children[i].as_ref()?;
|
||||
depth += 1;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if tree contains key
|
||||
pub fn contains_key(&self, key: &[u8]) -> bool {
|
||||
self.get(key).is_some()
|
||||
}
|
||||
|
||||
/// Get number of entries
|
||||
pub fn len(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
/// Check if tree is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.size == 0
|
||||
}
|
||||
|
||||
/// Find common prefix length
|
||||
fn common_prefix_len(a: &[u8], b: &[u8], start: usize) -> usize {
|
||||
let mut len = 0;
|
||||
let max = a.len().min(b.len()) - start;
|
||||
|
||||
for i in 0..max {
|
||||
if a[start + i] == b[start + i] {
|
||||
len += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
len
|
||||
}
|
||||
|
||||
/// Check prefix match
|
||||
fn check_prefix(prefix: &[u8], key: &[u8], depth: usize) -> usize {
|
||||
let max = prefix.len().min(key.len() - depth);
|
||||
let mut matched = 0;
|
||||
|
||||
for i in 0..max {
|
||||
if prefix[i] == key[depth + i] {
|
||||
matched += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
matched
|
||||
}
|
||||
|
||||
/// Check if prefix matches
|
||||
fn match_prefix(prefix: &[u8], key: &[u8], depth: usize) -> bool {
|
||||
if depth + prefix.len() > key.len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
for i in 0..prefix.len() {
|
||||
if prefix[i] != key[depth + i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Clone> Default for AdaptiveRadixTree<V> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// ART node types with adaptive sizing
|
||||
pub enum ArtNode<V> {
|
||||
/// Leaf node containing value
|
||||
Leaf { key: Vec<u8>, value: V },
|
||||
|
||||
/// Node with 4 children (smallest)
|
||||
Node4 {
|
||||
prefix: Vec<u8>,
|
||||
children: [Option<Box<ArtNode<V>>>; 4],
|
||||
keys: [u8; 4],
|
||||
num_children: u8,
|
||||
},
|
||||
|
||||
/// Node with 16 children
|
||||
Node16 {
|
||||
prefix: Vec<u8>,
|
||||
children: [Option<Box<ArtNode<V>>>; 16],
|
||||
keys: [u8; 16],
|
||||
num_children: u8,
|
||||
},
|
||||
|
||||
/// Node with 48 children (using index array)
|
||||
Node48 {
|
||||
prefix: Vec<u8>,
|
||||
children: [Option<Box<ArtNode<V>>>; 48],
|
||||
index: [u8; 256], // Maps key byte to child index
|
||||
num_children: u8,
|
||||
},
|
||||
|
||||
/// Node with 256 children (full array)
|
||||
Node256 {
|
||||
prefix: Vec<u8>,
|
||||
children: [Option<Box<ArtNode<V>>>; 256],
|
||||
num_children: u16,
|
||||
},
|
||||
}
|
||||
|
||||
impl<V> ArtNode<V> {
|
||||
/// Check if node is a leaf
|
||||
pub fn is_leaf(&self) -> bool {
|
||||
matches!(self, ArtNode::Leaf { .. })
|
||||
}
|
||||
|
||||
/// Get node type name
|
||||
pub fn node_type(&self) -> &str {
|
||||
match self {
|
||||
ArtNode::Leaf { .. } => "Leaf",
|
||||
ArtNode::Node4 { .. } => "Node4",
|
||||
ArtNode::Node16 { .. } => "Node16",
|
||||
ArtNode::Node48 { .. } => "Node48",
|
||||
ArtNode::Node256 { .. } => "Node256",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator over ART entries
|
||||
pub struct ArtIter<'a, V> {
|
||||
stack: Vec<&'a ArtNode<V>>,
|
||||
}
|
||||
|
||||
impl<'a, V> Iterator for ArtIter<'a, V> {
|
||||
type Item = (&'a [u8], &'a V);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(node) = self.stack.pop() {
|
||||
match node {
|
||||
ArtNode::Leaf { key, value } => {
|
||||
return Some((key.as_slice(), value));
|
||||
}
|
||||
ArtNode::Node4 {
|
||||
children,
|
||||
num_children,
|
||||
..
|
||||
} => {
|
||||
for i in (0..*num_children as usize).rev() {
|
||||
if let Some(child) = &children[i] {
|
||||
self.stack.push(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Handle other node types
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip insert/lookup, including keys that share a prefix
    /// ("hello"/"help" force a leaf split into a Node4).
    #[test]
    fn test_art_basic() {
        let mut tree = AdaptiveRadixTree::new();

        tree.insert(b"hello", 1);
        tree.insert(b"world", 2);
        tree.insert(b"help", 3);

        assert_eq!(tree.get(b"hello"), Some(&1));
        assert_eq!(tree.get(b"world"), Some(&2));
        assert_eq!(tree.get(b"help"), Some(&3));
        assert_eq!(tree.get(b"nonexistent"), None);
    }

    /// `contains_key` is a thin wrapper over `get`.
    #[test]
    fn test_art_contains() {
        let mut tree = AdaptiveRadixTree::new();

        tree.insert(b"test", 42);

        assert!(tree.contains_key(b"test"));
        assert!(!tree.contains_key(b"other"));
    }

    /// `len`/`is_empty` track the number of inserted keys.
    #[test]
    fn test_art_len() {
        let mut tree = AdaptiveRadixTree::new();

        assert_eq!(tree.len(), 0);
        assert!(tree.is_empty());

        tree.insert(b"a", 1);
        tree.insert(b"b", 2);

        assert_eq!(tree.len(), 2);
        assert!(!tree.is_empty());
    }

    /// Keys sharing a long common prefix exercise path compression and the
    /// inner-node prefix-split path.
    #[test]
    fn test_art_common_prefix() {
        let mut tree = AdaptiveRadixTree::new();

        tree.insert(b"prefix_one", 1);
        tree.insert(b"prefix_two", 2);
        tree.insert(b"other", 3);

        assert_eq!(tree.get(b"prefix_one"), Some(&1));
        assert_eq!(tree.get(b"prefix_two"), Some(&2));
        assert_eq!(tree.get(b"other"), Some(&3));
    }
}
|
||||
336
vendor/ruvector/crates/ruvector-graph/src/optimization/bloom_filter.rs
vendored
Normal file
336
vendor/ruvector/crates/ruvector-graph/src/optimization/bloom_filter.rs
vendored
Normal file
@@ -0,0 +1,336 @@
|
||||
//! Bloom filters for fast negative lookups
|
||||
//!
|
||||
//! Bloom filters provide O(1) membership tests with false positives
|
||||
//! but no false negatives, perfect for quickly eliminating non-existent keys.
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Standard bloom filter with configurable size and hash functions.
///
/// Membership tests may yield false positives but never false negatives.
pub struct BloomFilter {
    /// Bit array packed into 64-bit words.
    bits: Vec<u64>,
    /// Number of hash functions applied per item.
    num_hashes: usize,
    /// Total number of addressable bits.
    num_bits: usize,
}

impl BloomFilter {
    /// Create a new bloom filter.
    ///
    /// # Arguments
    /// * `expected_items` - expected number of items to be inserted
    /// * `false_positive_rate` - desired false positive rate (e.g., 0.01 for 1%)
    ///
    /// Fix: degenerate parameters are clamped. Previously `expected_items == 0`
    /// produced `num_bits == 0` (every `hash % num_bits` then panics with a
    /// division by zero) and an effectively unbounded `num_hashes`.
    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
        let n = expected_items.max(1);
        // Keep the rate strictly inside (0, 1) so ln() stays finite.
        let p = false_positive_rate.clamp(f64::EPSILON, 1.0 - f64::EPSILON);

        let num_bits = Self::optimal_num_bits(n, p).max(1);
        let num_hashes = Self::optimal_num_hashes(n, num_bits).max(1);

        // Round the bit count up to whole 64-bit words.
        let num_u64s = (num_bits + 63) / 64;

        Self {
            bits: vec![0; num_u64s],
            num_hashes,
            num_bits,
        }
    }

    /// Optimal number of bits: m = -n * ln(p) / (ln 2)^2.
    fn optimal_num_bits(n: usize, p: f64) -> usize {
        let ln2 = std::f64::consts::LN_2;
        (-(n as f64) * p.ln() / (ln2 * ln2)).ceil() as usize
    }

    /// Optimal number of hash functions: k = (m / n) * ln 2.
    fn optimal_num_hashes(n: usize, m: usize) -> usize {
        let ln2 = std::f64::consts::LN_2;
        ((m as f64 / n as f64) * ln2).ceil() as usize
    }

    /// Map (item, hash index) to a (word index, bit mask) pair.
    fn bit_position<T: Hash>(&self, item: &T, i: usize) -> (usize, u64) {
        let bit_index = self.hash(item, i) % self.num_bits;
        (bit_index / 64, 1u64 << (bit_index % 64))
    }

    /// Insert an item into the bloom filter.
    pub fn insert<T: Hash>(&mut self, item: &T) {
        for i in 0..self.num_hashes {
            let (word, mask) = self.bit_position(item, i);
            self.bits[word] |= mask;
        }
    }

    /// Check if an item might be in the set.
    ///
    /// Returns `true` if the item might be present (possible false positive),
    /// `false` if it is definitely not present.
    pub fn contains<T: Hash>(&self, item: &T) -> bool {
        (0..self.num_hashes).all(|i| {
            let (word, mask) = self.bit_position(item, i);
            self.bits[word] & mask != 0
        })
    }

    /// Double hashing via the item hash salted with the function index.
    fn hash<T: Hash>(&self, item: &T, i: usize) -> usize {
        let mut hasher = DefaultHasher::new();
        item.hash(&mut hasher);
        i.hash(&mut hasher);
        hasher.finish() as usize
    }

    /// Reset the filter to empty.
    pub fn clear(&mut self) {
        self.bits.fill(0);
    }

    /// Approximate number of distinct items, estimated from bit saturation:
    /// n ≈ -(m/k) * ln(1 - x/m).
    ///
    /// NOTE(review): when every bit is set this evaluates to +inf and the
    /// float→usize cast saturates to `usize::MAX` — treat a fully saturated
    /// filter's estimate as "at least this many".
    pub fn approximate_count(&self) -> usize {
        let set_bits: u32 = self.bits.iter().map(|&word| word.count_ones()).sum();

        let m = self.num_bits as f64;
        let k = self.num_hashes as f64;
        let x = set_bits as f64;

        let n = -(m / k) * (1.0 - x / m).ln();
        n as usize
    }

    /// Estimate of the current false positive rate: (fraction of set bits)^k.
    pub fn current_false_positive_rate(&self) -> f64 {
        let set_bits: u32 = self.bits.iter().map(|&word| word.count_ones()).sum();

        let p = set_bits as f64 / self.num_bits as f64;
        p.powi(self.num_hashes as i32)
    }
}
|
||||
|
||||
/// Scalable bloom filter that grows as needed
|
||||
pub struct ScalableBloomFilter {
|
||||
/// Current active filter
|
||||
filters: Vec<BloomFilter>,
|
||||
/// Items per filter
|
||||
items_per_filter: usize,
|
||||
/// Target false positive rate
|
||||
false_positive_rate: f64,
|
||||
/// Growth factor
|
||||
growth_factor: f64,
|
||||
/// Current item count
|
||||
item_count: usize,
|
||||
}
|
||||
|
||||
impl ScalableBloomFilter {
|
||||
/// Create a new scalable bloom filter
|
||||
pub fn new(initial_capacity: usize, false_positive_rate: f64) -> Self {
|
||||
let initial_filter = BloomFilter::new(initial_capacity, false_positive_rate);
|
||||
|
||||
Self {
|
||||
filters: vec![initial_filter],
|
||||
items_per_filter: initial_capacity,
|
||||
false_positive_rate,
|
||||
growth_factor: 2.0,
|
||||
item_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert an item
|
||||
pub fn insert<T: Hash>(&mut self, item: &T) {
|
||||
// Check if we need to add a new filter
|
||||
if self.item_count >= self.items_per_filter * self.filters.len() {
|
||||
let new_capacity = (self.items_per_filter as f64 * self.growth_factor) as usize;
|
||||
let new_filter = BloomFilter::new(new_capacity, self.false_positive_rate);
|
||||
self.filters.push(new_filter);
|
||||
}
|
||||
|
||||
// Insert into the most recent filter
|
||||
if let Some(filter) = self.filters.last_mut() {
|
||||
filter.insert(item);
|
||||
}
|
||||
|
||||
self.item_count += 1;
|
||||
}
|
||||
|
||||
/// Check if item might be present
|
||||
pub fn contains<T: Hash>(&self, item: &T) -> bool {
|
||||
// Check all filters (item could be in any of them)
|
||||
self.filters.iter().any(|filter| filter.contains(item))
|
||||
}
|
||||
|
||||
/// Clear all filters
|
||||
pub fn clear(&mut self) {
|
||||
for filter in &mut self.filters {
|
||||
filter.clear();
|
||||
}
|
||||
self.item_count = 0;
|
||||
}
|
||||
|
||||
/// Get number of filters
|
||||
pub fn num_filters(&self) -> usize {
|
||||
self.filters.len()
|
||||
}
|
||||
|
||||
/// Get total memory usage in bytes
|
||||
pub fn memory_usage(&self) -> usize {
|
||||
self.filters.iter().map(|f| f.bits.len() * 8).sum()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ScalableBloomFilter {
|
||||
fn default() -> Self {
|
||||
Self::new(1000, 0.01)
|
||||
}
|
||||
}
|
||||
|
||||
/// Counting bloom filter (supports deletion).
///
/// Each position holds a small saturating counter instead of a single bit, so
/// items can be removed by decrementing the counters they hashed to.
pub struct CountingBloomFilter {
    /// Counter array. NOTE(review): the original comment said "4-bit
    /// counters", but each counter occupies a full `u8` — only the saturation
    /// value used in `insert` (15) is 4-bit. Packing two counters per byte
    /// would halve memory if that was the intent; confirm.
    counters: Vec<u8>,
    /// Number of hash functions applied per item.
    num_hashes: usize,
    /// Number of counters (analogous to a plain filter's bit count).
    num_counters: usize,
}
|
||||
|
||||
impl CountingBloomFilter {
|
||||
pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
|
||||
let num_counters = BloomFilter::optimal_num_bits(expected_items, false_positive_rate);
|
||||
let num_hashes = BloomFilter::optimal_num_hashes(expected_items, num_counters);
|
||||
|
||||
Self {
|
||||
counters: vec![0; num_counters],
|
||||
num_hashes,
|
||||
num_counters,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert<T: Hash>(&mut self, item: &T) {
|
||||
for i in 0..self.num_hashes {
|
||||
let hash = self.hash(item, i);
|
||||
let index = hash % self.num_counters;
|
||||
|
||||
// Increment counter (saturate at 15)
|
||||
if self.counters[index] < 15 {
|
||||
self.counters[index] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove<T: Hash>(&mut self, item: &T) {
|
||||
for i in 0..self.num_hashes {
|
||||
let hash = self.hash(item, i);
|
||||
let index = hash % self.num_counters;
|
||||
|
||||
// Decrement counter
|
||||
if self.counters[index] > 0 {
|
||||
self.counters[index] -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contains<T: Hash>(&self, item: &T) -> bool {
|
||||
for i in 0..self.num_hashes {
|
||||
let hash = self.hash(item, i);
|
||||
let index = hash % self.num_counters;
|
||||
|
||||
if self.counters[index] == 0 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn hash<T: Hash>(&self, item: &T, i: usize) -> usize {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
item.hash(&mut hasher);
|
||||
i.hash(&mut hasher);
|
||||
hasher.finish() as usize
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic insert/membership round-trip over mixed key types.
    /// (The negative assertion relies on the configured 1% false-positive
    /// rate being small; with 3 items in a 1000-capacity filter a false
    /// positive is effectively impossible.)
    #[test]
    fn test_bloom_filter() {
        let mut filter = BloomFilter::new(1000, 0.01);

        filter.insert(&"hello");
        filter.insert(&"world");
        filter.insert(&12345);

        assert!(filter.contains(&"hello"));
        assert!(filter.contains(&"world"));
        assert!(filter.contains(&12345));
        assert!(!filter.contains(&"nonexistent"));
    }

    /// Empirically checks the false-positive rate stays near the configured
    /// 1% (5% tolerance to keep the test stable).
    #[test]
    fn test_bloom_filter_false_positive_rate() {
        let mut filter = BloomFilter::new(100, 0.01);

        // Insert 100 items
        for i in 0..100 {
            filter.insert(&i);
        }

        // Probe 1000 items that were never inserted and count hits.
        let mut false_positives = 0;
        let test_items = 1000;

        for i in 100..(100 + test_items) {
            if filter.contains(&i) {
                false_positives += 1;
            }
        }

        let rate = false_positives as f64 / test_items as f64;
        assert!(rate < 0.05, "False positive rate too high: {}", rate);
    }

    /// Inserting past the initial capacity must spawn additional sub-filters
    /// without losing any previously inserted item.
    #[test]
    fn test_scalable_bloom_filter() {
        let mut filter = ScalableBloomFilter::new(10, 0.01);

        // Insert many items (more than initial capacity)
        for i in 0..100 {
            filter.insert(&i);
        }

        assert!(filter.num_filters() > 1);

        // All items should be found
        for i in 0..100 {
            assert!(filter.contains(&i));
        }
    }

    /// Counting filter supports deletion: after removing the only insert of
    /// a key, membership must report false.
    #[test]
    fn test_counting_bloom_filter() {
        let mut filter = CountingBloomFilter::new(100, 0.01);

        filter.insert(&"test");
        assert!(filter.contains(&"test"));

        filter.remove(&"test");
        assert!(!filter.contains(&"test"));
    }

    /// `clear` must reset all bits so prior members are forgotten.
    #[test]
    fn test_bloom_clear() {
        let mut filter = BloomFilter::new(100, 0.01);

        filter.insert(&"test");
        assert!(filter.contains(&"test"));

        filter.clear();
        assert!(!filter.contains(&"test"));
    }
}
|
||||
412
vendor/ruvector/crates/ruvector-graph/src/optimization/cache_hierarchy.rs
vendored
Normal file
412
vendor/ruvector/crates/ruvector-graph/src/optimization/cache_hierarchy.rs
vendored
Normal file
@@ -0,0 +1,412 @@
|
||||
//! Cache-optimized data layouts with hot/cold data separation
|
||||
//!
|
||||
//! This module implements cache-friendly storage patterns to minimize
|
||||
//! cache misses and maximize memory bandwidth utilization.
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use std::alloc::{alloc, dealloc, Layout};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Cache line size in bytes (64 on x86-64); used for `#[repr(align(64))]`
/// decisions in this module.
const CACHE_LINE_SIZE: usize = 64;

/// L1 cache size estimate (32KB typical).
// NOTE(review): the L1/L2/L3 size constants are not referenced by any code
// visible in this module — confirm intended use before removing them.
const L1_CACHE_SIZE: usize = 32 * 1024;

/// L2 cache size estimate (256KB typical).
const L2_CACHE_SIZE: usize = 256 * 1024;

/// L3 cache size estimate (8MB typical).
const L3_CACHE_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
/// Cache hierarchy manager for graph data.
///
/// Splits node data into a small "hot" tier (uncompressed, scanned linearly)
/// and a larger "cold" tier (bincode-encoded), promoting nodes between them
/// based on observed access frequency.
pub struct CacheHierarchy {
    /// Hot data stored in an L1-friendly, cache-line-aligned layout.
    hot_storage: Arc<RwLock<HotStorage>>,
    /// Cold data stored in a compact serialized format.
    cold_storage: Arc<RwLock<ColdStorage>>,
    /// Per-node access counts and recency, driving promotion/eviction.
    access_tracker: Arc<RwLock<AccessTracker>>,
}
|
||||
|
||||
impl CacheHierarchy {
    /// Create a new cache hierarchy with the given per-tier capacities
    /// (hot capacity in entries; cold capacity is stored but see the note on
    /// `ColdStorage` — it is not enforced there).
    pub fn new(hot_capacity: usize, cold_capacity: usize) -> Self {
        Self {
            hot_storage: Arc::new(RwLock::new(HotStorage::new(hot_capacity))),
            cold_storage: Arc::new(RwLock::new(ColdStorage::new(cold_capacity))),
            access_tracker: Arc::new(RwLock::new(AccessTracker::new())),
        }
    }

    /// Access node data with automatic hot/cold promotion.
    ///
    /// Every lookup (hit or miss) is recorded in the access tracker; a cold
    /// hit that crosses the promotion threshold is copied into hot storage.
    pub fn get_node(&self, node_id: u64) -> Option<NodeData> {
        // Record access
        self.access_tracker.write().record_access(node_id);

        // Try hot storage first
        if let Some(data) = self.hot_storage.read().get(node_id) {
            return Some(data);
        }

        // Fall back to cold storage
        if let Some(data) = self.cold_storage.read().get(node_id) {
            // Promote to hot if frequently accessed
            if self.access_tracker.read().should_promote(node_id) {
                self.promote_to_hot(node_id, data.clone());
            }
            return Some(data);
        }

        None
    }

    /// Insert node data with automatic placement (new data always enters the
    /// hot tier; something else is demoted to cold if the tier is full).
    pub fn insert_node(&self, node_id: u64, data: NodeData) {
        // Record initial access for the new node
        self.access_tracker.write().record_access(node_id);

        // Check if we need to evict before inserting (to avoid double eviction with HotStorage)
        if self.hot_storage.read().is_at_capacity() {
            self.evict_one_to_cold(node_id); // Don't evict the one we're about to insert
        }

        // New data goes to hot storage
        self.hot_storage.write().insert(node_id, data.clone());
    }

    /// Promote node from cold to hot storage.
    ///
    /// Insertion into hot happens before removal from cold, so there is a
    /// brief window where both tiers hold the node — readers still find it.
    fn promote_to_hot(&self, node_id: u64, data: NodeData) {
        // First evict if needed to make room
        if self.hot_storage.read().is_full() {
            self.evict_one_to_cold(node_id); // Pass node_id to avoid evicting the one we're promoting
        }

        self.hot_storage.write().insert(node_id, data);
        self.cold_storage.write().remove(node_id);
    }

    /// Evict up to 10 least-frequently-used hot entries to cold storage.
    ///
    /// NOTE(review): this method is not called from any code visible in this
    /// module — possibly dead code or reserved for a future batch-eviction
    /// path; confirm before removing. Both eviction paths acquire the hot
    /// lock before the cold lock, keeping the lock order consistent.
    fn evict_cold(&self) {
        let tracker = self.access_tracker.read();
        let lru_nodes = tracker.get_lru_nodes_by_frequency(10);
        drop(tracker);

        let mut hot = self.hot_storage.write();
        let mut cold = self.cold_storage.write();

        for node_id in lru_nodes {
            if let Some(data) = hot.remove(node_id) {
                cold.insert(node_id, data);
            }
        }
    }

    /// Evict one node to cold storage, avoiding the protected `node_id`.
    ///
    /// Candidates come from the 5 least-frequently-accessed nodes; if none of
    /// them is currently in hot storage, nothing is evicted (HotStorage then
    /// falls back to its own FIFO eviction on insert).
    fn evict_one_to_cold(&self, protected_id: u64) {
        let tracker = self.access_tracker.read();
        // Get nodes sorted by frequency (least frequently accessed first)
        let candidates = tracker.get_lru_nodes_by_frequency(5);
        drop(tracker);

        let mut hot = self.hot_storage.write();
        let mut cold = self.cold_storage.write();

        for node_id in candidates {
            if node_id != protected_id {
                if let Some(data) = hot.remove(node_id) {
                    cold.insert(node_id, data);
                    return;
                }
            }
        }
    }

    /// Prefetch nodes that are likely to be accessed soon.
    ///
    /// NOTE(review): as written this prefetches the address of the *local*
    /// `node_id` copy on the stack, not the node's stored data, so it is
    /// unlikely to have the intended effect — confirm intent before relying
    /// on it for performance.
    pub fn prefetch_neighbors(&self, node_ids: &[u64]) {
        // Use software prefetching hints
        for &node_id in node_ids {
            #[cfg(target_arch = "x86_64")]
            unsafe {
                // Prefetch to L1 cache
                std::arch::x86_64::_mm_prefetch(
                    &node_id as *const u64 as *const i8,
                    std::arch::x86_64::_MM_HINT_T0,
                );
            }
        }
    }
}
|
||||
|
||||
/// Hot storage with cache-line aligned entries.
///
/// NOTE(review): `#[repr(align(64))]` aligns the *struct* (i.e. the `Vec`
/// header), not the heap buffer the entries live in; the entries themselves
/// get their alignment from `CacheLineEntry`'s own `repr(align(64))`.
/// Confirm this layout achieves the intended cache behavior.
#[repr(align(64))]
struct HotStorage {
    /// Linear store of entries; scanned sequentially on lookup.
    entries: Vec<CacheLineEntry>,
    /// Capacity in number of entries
    capacity: usize,
    /// Current size; kept in sync with `entries.len()` after every mutation.
    size: usize,
}
|
||||
|
||||
impl HotStorage {
|
||||
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
entries: Vec::with_capacity(capacity),
|
||||
capacity,
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, node_id: u64) -> Option<NodeData> {
|
||||
self.entries
|
||||
.iter()
|
||||
.find(|e| e.node_id == node_id)
|
||||
.map(|e| e.data.clone())
|
||||
}
|
||||
|
||||
fn insert(&mut self, node_id: u64, data: NodeData) {
|
||||
// Remove old entry if exists
|
||||
self.entries.retain(|e| e.node_id != node_id);
|
||||
|
||||
if self.entries.len() >= self.capacity {
|
||||
self.entries.remove(0); // Simple FIFO eviction
|
||||
}
|
||||
|
||||
self.entries.push(CacheLineEntry { node_id, data });
|
||||
self.size = self.entries.len();
|
||||
}
|
||||
|
||||
fn remove(&mut self, node_id: u64) -> Option<NodeData> {
|
||||
if let Some(pos) = self.entries.iter().position(|e| e.node_id == node_id) {
|
||||
let entry = self.entries.remove(pos);
|
||||
self.size = self.entries.len();
|
||||
Some(entry.data)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn is_full(&self) -> bool {
|
||||
self.size >= self.capacity
|
||||
}
|
||||
|
||||
fn is_at_capacity(&self) -> bool {
|
||||
self.size >= self.capacity
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache-line aligned entry (64 bytes of alignment).
///
/// NOTE(review): `NodeData` contains `Vec`s/`String`s, so an entry is larger
/// than a single cache line; the alignment controls placement but entries
/// still span multiple lines. Confirm the measured benefit.
#[repr(align(64))]
#[derive(Clone)]
struct CacheLineEntry {
    // Id of the cached node.
    node_id: u64,
    // Uncompressed node payload.
    data: NodeData,
}
|
||||
|
||||
/// Cold storage tier.
///
/// NOTE(review): entries are bincode-*encoded*, which is compact
/// serialization rather than compression proper — the "compression" wording
/// elsewhere overstates it.
struct ColdStorage {
    /// Encoded node data keyed by node id.
    entries: dashmap::DashMap<u64, Vec<u8>>,
    // NOTE(review): `capacity` is stored but never enforced anywhere in this
    // module — cold storage grows unbounded. Confirm whether eviction was
    // intended here.
    capacity: usize,
}
|
||||
|
||||
impl ColdStorage {
    /// Create an empty cold tier. `capacity` is recorded but not enforced
    /// (see the struct-level note).
    fn new(capacity: usize) -> Self {
        Self {
            entries: dashmap::DashMap::new(),
            capacity,
        }
    }

    /// Decode the stored bytes back into `NodeData`.
    /// Returns `None` on a miss — and also on a decode failure, which is
    /// silently treated as a miss.
    fn get(&self, node_id: u64) -> Option<NodeData> {
        self.entries.get(&node_id).and_then(|compressed| {
            // Decompress data using bincode 2.0 API
            bincode::decode_from_slice(&compressed, bincode::config::standard())
                .ok()
                .map(|(data, _)| data)
        })
    }

    /// Encode and store the node. An encoding failure silently drops the
    /// value.
    // NOTE(review): takes `&mut self` although DashMap supports `&self`
    // insertion — confirm whether exclusive access is intentional (it forces
    // callers to hold the outer write lock).
    fn insert(&mut self, node_id: u64, data: NodeData) {
        // Compress data using bincode 2.0 API
        if let Ok(compressed) = bincode::encode_to_vec(&data, bincode::config::standard()) {
            self.entries.insert(node_id, compressed);
        }
    }

    /// Remove and decode the node's entry, if present. A stored-but-undecodable
    /// entry is removed and reported as `None`.
    fn remove(&mut self, node_id: u64) -> Option<NodeData> {
        self.entries.remove(&node_id).and_then(|(_, compressed)| {
            bincode::decode_from_slice(&compressed, bincode::config::standard())
                .ok()
                .map(|(data, _)| data)
        })
    }
}
|
||||
|
||||
/// Access frequency tracker for hot/cold promotion decisions.
struct AccessTracker {
    /// Total access count per node (drives promotion and frequency-based
    /// eviction ordering).
    access_counts: dashmap::DashMap<u64, u32>,
    /// Logical timestamp of each node's most recent access (drives the
    /// recency-based ordering).
    last_access: dashmap::DashMap<u64, u64>,
    /// Monotonic logical clock, incremented on every recorded access.
    timestamp: u64,
}
|
||||
|
||||
impl AccessTracker {
    /// Create an empty tracker with the logical clock at zero.
    fn new() -> Self {
        Self {
            access_counts: dashmap::DashMap::new(),
            last_access: dashmap::DashMap::new(),
            timestamp: 0,
        }
    }

    /// Record one access: bump the logical clock, increment the node's count,
    /// and stamp its recency.
    // Takes `&mut self` because `timestamp` is a plain field; callers hold
    // the outer `RwLock` in write mode for this.
    fn record_access(&mut self, node_id: u64) {
        self.timestamp += 1;

        self.access_counts
            .entry(node_id)
            .and_modify(|count| *count += 1)
            .or_insert(1);

        self.last_access.insert(node_id, self.timestamp);
    }

    /// Whether a node has been accessed often enough (more than 5 times) to
    /// be promoted into hot storage.
    fn should_promote(&self, node_id: u64) -> bool {
        // Promote if accessed more than 5 times
        self.access_counts
            .get(&node_id)
            .map(|count| *count > 5)
            .unwrap_or(false)
    }

    /// The `count` least *recently* accessed nodes (ascending timestamp).
    // NOTE(review): not referenced by any code visible in this module — the
    // eviction paths use `get_lru_nodes_by_frequency` instead. Confirm
    // whether this recency-based variant is still needed.
    fn get_lru_nodes(&self, count: usize) -> Vec<u64> {
        let mut nodes: Vec<_> = self
            .last_access
            .iter()
            .map(|entry| (*entry.key(), *entry.value()))
            .collect();

        nodes.sort_by_key(|(_, timestamp)| *timestamp);
        nodes
            .into_iter()
            .take(count)
            .map(|(node_id, _)| node_id)
            .collect()
    }

    /// The `count` least *frequently* accessed nodes (ascending access
    /// count), used for smart eviction.
    fn get_lru_nodes_by_frequency(&self, count: usize) -> Vec<u64> {
        let mut nodes: Vec<_> = self
            .access_counts
            .iter()
            .map(|entry| (*entry.key(), *entry.value()))
            .collect();

        // Sort by access count (ascending - least frequently accessed first)
        nodes.sort_by_key(|(_, access_count)| *access_count);
        nodes
            .into_iter()
            .take(count)
            .map(|(node_id, _)| node_id)
            .collect()
    }
}
|
||||
|
||||
/// Node payload cached by the hierarchy: id, labels, and a flat property
/// list. Derives both serde and bincode traits so it can be serde-serialized
/// and bincode-encoded for the cold tier.
#[derive(Clone, serde::Serialize, serde::Deserialize, bincode::Encode, bincode::Decode)]
pub struct NodeData {
    // Node identifier (matches the cache key).
    pub id: u64,
    // Node labels, e.g. "Person".
    pub labels: Vec<String>,
    // Property name/value pairs; stored as a Vec rather than a map.
    pub properties: Vec<(String, CachePropertyValue)>,
}
|
||||
|
||||
/// Property value types supported by the cache storage layer.
#[derive(Clone, serde::Serialize, serde::Deserialize, bincode::Encode, bincode::Decode)]
pub enum CachePropertyValue {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
}
|
||||
|
||||
/// Thin public facade over [`CacheHierarchy`] with fixed default capacities.
pub struct HotColdStorage {
    // The underlying two-tier cache.
    cache_hierarchy: CacheHierarchy,
}
|
||||
|
||||
impl HotColdStorage {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cache_hierarchy: CacheHierarchy::new(1000, 10000),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self, node_id: u64) -> Option<NodeData> {
|
||||
self.cache_hierarchy.get_node(node_id)
|
||||
}
|
||||
|
||||
pub fn insert(&self, node_id: u64, data: NodeData) {
|
||||
self.cache_hierarchy.insert_node(node_id, data);
|
||||
}
|
||||
|
||||
pub fn prefetch(&self, node_ids: &[u64]) {
|
||||
self.cache_hierarchy.prefetch_neighbors(node_ids);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HotColdStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    // Basic round trip: insert one node and read it back through the tiers.
    #[test]
    fn test_cache_hierarchy() {
        let cache = CacheHierarchy::new(10, 100);

        let data = NodeData {
            id: 1,
            labels: vec!["Person".to_string()],
            properties: vec![(
                "name".to_string(),
                CachePropertyValue::String("Alice".to_string()),
            )],
        };

        cache.insert_node(1, data.clone());

        let retrieved = cache.get_node(1);
        assert!(retrieved.is_some());
    }

    // Overflow the hot tier, then hammer one node so repeated accesses
    // exercise the hot/cold promotion path without losing the node.
    #[test]
    fn test_hot_cold_promotion() {
        let cache = CacheHierarchy::new(2, 10);

        // Insert 3 nodes (exceeds hot capacity)
        for i in 1..=3 {
            cache.insert_node(
                i,
                NodeData {
                    id: i,
                    labels: vec![],
                    properties: vec![],
                },
            );
        }

        // Access node 1 multiple times to trigger promotion
        for _ in 0..10 {
            cache.get_node(1);
        }

        // Node 1 should still be accessible
        assert!(cache.get_node(1).is_some());
    }
}
|
||||
429
vendor/ruvector/crates/ruvector-graph/src/optimization/index_compression.rs
vendored
Normal file
429
vendor/ruvector/crates/ruvector-graph/src/optimization/index_compression.rs
vendored
Normal file
@@ -0,0 +1,429 @@
|
||||
//! Compressed index structures for massive space savings
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - Roaring bitmaps for label indexes
|
||||
//! - Delta encoding for sorted ID lists
|
||||
//! - Dictionary encoding for string properties
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use roaring::RoaringBitmap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Compressed index using multiple encoding strategies
|
||||
pub struct CompressedIndex {
|
||||
/// Bitmap indexes for labels
|
||||
label_indexes: Arc<RwLock<HashMap<String, RoaringBitmap>>>,
|
||||
/// Delta-encoded sorted ID lists
|
||||
sorted_indexes: Arc<RwLock<HashMap<String, DeltaEncodedList>>>,
|
||||
/// Dictionary encoding for string properties
|
||||
string_dict: Arc<RwLock<StringDictionary>>,
|
||||
}
|
||||
|
||||
impl CompressedIndex {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
label_indexes: Arc::new(RwLock::new(HashMap::new())),
|
||||
sorted_indexes: Arc::new(RwLock::new(HashMap::new())),
|
||||
string_dict: Arc::new(RwLock::new(StringDictionary::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add node to label index
|
||||
pub fn add_to_label_index(&self, label: &str, node_id: u64) {
|
||||
let mut indexes = self.label_indexes.write();
|
||||
indexes
|
||||
.entry(label.to_string())
|
||||
.or_insert_with(RoaringBitmap::new)
|
||||
.insert(node_id as u32);
|
||||
}
|
||||
|
||||
/// Get all nodes with a specific label
|
||||
pub fn get_nodes_by_label(&self, label: &str) -> Vec<u64> {
|
||||
self.label_indexes
|
||||
.read()
|
||||
.get(label)
|
||||
.map(|bitmap| bitmap.iter().map(|id| id as u64).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Check if node has label (fast bitmap lookup)
|
||||
pub fn has_label(&self, label: &str, node_id: u64) -> bool {
|
||||
self.label_indexes
|
||||
.read()
|
||||
.get(label)
|
||||
.map(|bitmap| bitmap.contains(node_id as u32))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Count nodes with label
|
||||
pub fn count_label(&self, label: &str) -> u64 {
|
||||
self.label_indexes
|
||||
.read()
|
||||
.get(label)
|
||||
.map(|bitmap| bitmap.len())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Intersect multiple labels (efficient bitmap AND)
|
||||
pub fn intersect_labels(&self, labels: &[&str]) -> Vec<u64> {
|
||||
let indexes = self.label_indexes.read();
|
||||
|
||||
if labels.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut result = indexes
|
||||
.get(labels[0])
|
||||
.cloned()
|
||||
.unwrap_or_else(RoaringBitmap::new);
|
||||
|
||||
for &label in &labels[1..] {
|
||||
if let Some(bitmap) = indexes.get(label) {
|
||||
result &= bitmap;
|
||||
} else {
|
||||
return Vec::new();
|
||||
}
|
||||
}
|
||||
|
||||
result.iter().map(|id| id as u64).collect()
|
||||
}
|
||||
|
||||
/// Union multiple labels (efficient bitmap OR)
|
||||
pub fn union_labels(&self, labels: &[&str]) -> Vec<u64> {
|
||||
let indexes = self.label_indexes.read();
|
||||
let mut result = RoaringBitmap::new();
|
||||
|
||||
for &label in labels {
|
||||
if let Some(bitmap) = indexes.get(label) {
|
||||
result |= bitmap;
|
||||
}
|
||||
}
|
||||
|
||||
result.iter().map(|id| id as u64).collect()
|
||||
}
|
||||
|
||||
/// Encode string using dictionary
|
||||
pub fn encode_string(&self, s: &str) -> u32 {
|
||||
self.string_dict.write().encode(s)
|
||||
}
|
||||
|
||||
/// Decode string from dictionary
|
||||
pub fn decode_string(&self, id: u32) -> Option<String> {
|
||||
self.string_dict.read().decode(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CompressedIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Roaring bitmap index for efficient set operations
|
||||
pub struct RoaringBitmapIndex {
|
||||
bitmap: RoaringBitmap,
|
||||
}
|
||||
|
||||
impl RoaringBitmapIndex {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
bitmap: RoaringBitmap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, id: u64) {
|
||||
self.bitmap.insert(id as u32);
|
||||
}
|
||||
|
||||
pub fn contains(&self, id: u64) -> bool {
|
||||
self.bitmap.contains(id as u32)
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: u64) {
|
||||
self.bitmap.remove(id as u32);
|
||||
}
|
||||
|
||||
pub fn len(&self) -> u64 {
|
||||
self.bitmap.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.bitmap.is_empty()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
|
||||
self.bitmap.iter().map(|id| id as u64)
|
||||
}
|
||||
|
||||
/// Intersect with another bitmap
|
||||
pub fn intersect(&self, other: &Self) -> Self {
|
||||
Self {
|
||||
bitmap: &self.bitmap & &other.bitmap,
|
||||
}
|
||||
}
|
||||
|
||||
/// Union with another bitmap
|
||||
pub fn union(&self, other: &Self) -> Self {
|
||||
Self {
|
||||
bitmap: &self.bitmap | &other.bitmap,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize to bytes
|
||||
pub fn serialize(&self) -> Vec<u8> {
|
||||
let mut bytes = Vec::new();
|
||||
self.bitmap
|
||||
.serialize_into(&mut bytes)
|
||||
.expect("Failed to serialize bitmap");
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Deserialize from bytes
|
||||
pub fn deserialize(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let bitmap = RoaringBitmap::deserialize_from(bytes)?;
|
||||
Ok(Self { bitmap })
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RoaringBitmapIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Delta encoding for sorted ID lists
/// Stores differences between consecutive IDs for better compression
pub struct DeltaEncodedList {
    /// Base value (first ID)
    base: u64,
    /// Gaps between consecutive ids (each must fit in u32)
    deltas: Vec<u32>,
    /// True when no ids are stored.
    /// BUG FIX: without this flag, the empty list and the one-element list
    /// `[0]` encoded identically (base == 0, no deltas), so `decode`
    /// silently dropped a stored id of 0.
    is_empty: bool,
}

impl DeltaEncodedList {
    /// Create an empty list.
    pub fn new() -> Self {
        Self {
            base: 0,
            deltas: Vec::new(),
            is_empty: true,
        }
    }

    /// Encode a sorted list of IDs
    ///
    /// Input must be sorted ascending with consecutive gaps that fit in
    /// `u32`; violations are caught by debug assertions (release builds
    /// would otherwise wrap/truncate silently).
    pub fn encode(ids: &[u64]) -> Self {
        if ids.is_empty() {
            return Self::new();
        }

        debug_assert!(
            ids.windows(2).all(|pair| pair[0] <= pair[1]),
            "ids must be sorted ascending"
        );

        let base = ids[0];
        let deltas = ids
            .windows(2)
            .map(|pair| {
                let gap = pair[1] - pair[0];
                debug_assert!(gap <= u64::from(u32::MAX), "gap exceeds u32 range");
                gap as u32
            })
            .collect();

        Self {
            base,
            deltas,
            is_empty: false,
        }
    }

    /// Decode to original ID list
    pub fn decode(&self) -> Vec<u64> {
        if self.is_empty {
            return Vec::new();
        }

        let mut result = Vec::with_capacity(self.deltas.len() + 1);
        result.push(self.base);

        // Reconstruct by running prefix sums over the gaps.
        let mut current = self.base;
        for &delta in &self.deltas {
            current += u64::from(delta);
            result.push(current);
        }

        result
    }

    /// Get compression ratio (bytes as plain u64s vs. base + u32 deltas).
    pub fn compression_ratio(&self) -> f64 {
        let original_size = (self.deltas.len() + 1) * 8; // u64s
        let compressed_size = 8 + self.deltas.len() * 4; // base + u32 deltas
        original_size as f64 / compressed_size as f64
    }
}

impl Default for DeltaEncodedList {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Delta encoder utility
///
/// Byte-level variant of [`DeltaEncodedList`]: a `u64` little-endian base
/// followed by one `u32` little-endian gap per remaining value.
pub struct DeltaEncoder;

impl DeltaEncoder {
    /// Encode sorted u64 slice to delta-encoded format
    ///
    /// Input must be sorted ascending with gaps that fit in `u32`;
    /// violations are caught by debug assertions (release builds would
    /// otherwise wrap silently on the subtraction / truncate on the cast).
    pub fn encode(values: &[u64]) -> Vec<u8> {
        if values.is_empty() {
            return Vec::new();
        }

        debug_assert!(
            values.windows(2).all(|w| w[0] <= w[1]),
            "values must be sorted ascending"
        );

        // 8 bytes for the base plus 4 bytes per gap — reserve once.
        let mut result = Vec::with_capacity(8 + 4 * (values.len() - 1));

        // Write base value
        result.extend_from_slice(&values[0].to_le_bytes());

        // Write deltas
        for window in values.windows(2) {
            let gap = window[1] - window[0];
            debug_assert!(gap <= u64::from(u32::MAX), "gap exceeds u32 range");
            result.extend_from_slice(&(gap as u32).to_le_bytes());
        }

        result
    }

    /// Decode delta-encoded format back to u64 values
    ///
    /// Inputs shorter than the 8-byte base decode to an empty list; a
    /// trailing partial delta (fewer than 4 bytes) is ignored, matching
    /// the encoder's output shape.
    pub fn decode(bytes: &[u8]) -> Vec<u64> {
        if bytes.len() < 8 {
            return Vec::new();
        }

        // Base value plus one value per complete 4-byte delta.
        let mut result = Vec::with_capacity(1 + (bytes.len() - 8) / 4);

        // Read base value
        let base = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
        result.push(base);

        // Read deltas; chunks_exact skips any trailing partial chunk.
        let mut current = base;
        for chunk in bytes[8..].chunks_exact(4) {
            let delta = u32::from_le_bytes(chunk.try_into().unwrap());
            current += u64::from(delta);
            result.push(current);
        }

        result
    }
}
|
||||
|
||||
/// String dictionary for deduplication and compression
///
/// Interns strings to dense `u32` ids; an identical string always maps to
/// the same id.
struct StringDictionary {
    /// String to ID mapping
    string_to_id: HashMap<String, u32>,
    /// ID to string mapping
    id_to_string: HashMap<u32, String>,
    /// Next available ID
    next_id: u32,
}

impl StringDictionary {
    fn new() -> Self {
        Self {
            string_to_id: HashMap::new(),
            id_to_string: HashMap::new(),
            next_id: 0,
        }
    }

    /// Return the id for `s`, interning it on first sight.
    fn encode(&mut self, s: &str) -> u32 {
        if let Some(&existing) = self.string_to_id.get(s) {
            existing
        } else {
            let id = self.next_id;
            self.next_id += 1;

            self.string_to_id.insert(s.to_owned(), id);
            self.id_to_string.insert(id, s.to_owned());
            id
        }
    }

    /// Look up the string previously interned under `id`.
    fn decode(&self, id: u32) -> Option<String> {
        self.id_to_string.get(&id).cloned()
    }

    /// Number of distinct strings interned so far.
    fn len(&self) -> usize {
        self.string_to_id.len()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Label bitmaps: membership, intersection (AND), and union (OR).
    #[test]
    fn test_compressed_index() {
        let index = CompressedIndex::new();

        index.add_to_label_index("Person", 1);
        index.add_to_label_index("Person", 2);
        index.add_to_label_index("Person", 3);
        index.add_to_label_index("Employee", 2);
        index.add_to_label_index("Employee", 3);

        let persons = index.get_nodes_by_label("Person");
        assert_eq!(persons.len(), 3);

        // Nodes 2 and 3 carry both labels.
        let intersection = index.intersect_labels(&["Person", "Employee"]);
        assert_eq!(intersection.len(), 2);

        let union = index.union_labels(&["Person", "Employee"]);
        assert_eq!(union.len(), 3);
    }

    // Sparse ids exercise roaring's container switching.
    #[test]
    fn test_roaring_bitmap() {
        let mut bitmap = RoaringBitmapIndex::new();

        bitmap.insert(1);
        bitmap.insert(100);
        bitmap.insert(1000);

        assert!(bitmap.contains(1));
        assert!(bitmap.contains(100));
        assert!(!bitmap.contains(50));

        assert_eq!(bitmap.len(), 3);
    }

    // Struct-level delta encoding must round-trip and actually compress.
    #[test]
    fn test_delta_encoding() {
        let ids = vec![100, 102, 105, 110, 120];
        let encoded = DeltaEncodedList::encode(&ids);
        let decoded = encoded.decode();

        assert_eq!(ids, decoded);
        assert!(encoded.compression_ratio() > 1.0);
    }

    // Byte-level delta encoding: round trip plus a size bound
    // (8-byte base + 4 bytes per gap < 8 bytes per value).
    #[test]
    fn test_delta_encoder() {
        let values = vec![1000, 1005, 1010, 1020, 1030];
        let encoded = DeltaEncoder::encode(&values);
        let decoded = DeltaEncoder::decode(&encoded);

        assert_eq!(values, decoded);

        // Encoded size should be smaller
        assert!(encoded.len() < values.len() * 8);
    }

    // Dictionary interning: duplicates share an id, decode restores text.
    #[test]
    fn test_string_dictionary() {
        let index = CompressedIndex::new();

        let id1 = index.encode_string("hello");
        let id2 = index.encode_string("world");
        let id3 = index.encode_string("hello"); // Duplicate

        assert_eq!(id1, id3); // Same string gets same ID
        assert_ne!(id1, id2);

        assert_eq!(index.decode_string(id1), Some("hello".to_string()));
        assert_eq!(index.decode_string(id2), Some("world".to_string()));
    }
}
|
||||
432
vendor/ruvector/crates/ruvector-graph/src/optimization/memory_pool.rs
vendored
Normal file
432
vendor/ruvector/crates/ruvector-graph/src/optimization/memory_pool.rs
vendored
Normal file
@@ -0,0 +1,432 @@
|
||||
//! Custom memory allocators for graph query execution
|
||||
//!
|
||||
//! This module provides specialized allocators:
|
||||
//! - Arena allocation for query-scoped memory
|
||||
//! - Object pooling for frequent allocations
|
||||
//! - NUMA-aware allocation for distributed systems
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use std::alloc::{alloc, dealloc, Layout};
|
||||
use std::cell::Cell;
|
||||
use std::ptr::{self, NonNull};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Arena allocator for query execution
/// All allocations are freed together when the arena is dropped
pub struct ArenaAllocator {
    /// Current chunk
    // The chunk new allocations are bumped from; None until first use.
    current: Cell<Option<NonNull<Chunk>>>,
    /// All chunks (for cleanup)
    chunks: Mutex<Vec<NonNull<Chunk>>>,
    /// Default chunk size
    chunk_size: usize,
}

// A single bump-allocated buffer; chunks form a singly linked list as the
// arena grows.
struct Chunk {
    /// Data buffer
    data: NonNull<u8>,
    /// Current offset in buffer
    // Bump pointer: next free byte within `data`.
    offset: Cell<usize>,
    /// Total capacity
    capacity: usize,
    /// Next chunk in linked list
    next: Cell<Option<NonNull<Chunk>>>,
}
|
||||
|
||||
impl ArenaAllocator {
|
||||
/// Create a new arena with default chunk size (1MB)
|
||||
pub fn new() -> Self {
|
||||
Self::with_chunk_size(1024 * 1024)
|
||||
}
|
||||
|
||||
/// Create arena with specific chunk size
|
||||
pub fn with_chunk_size(chunk_size: usize) -> Self {
|
||||
Self {
|
||||
current: Cell::new(None),
|
||||
chunks: Mutex::new(Vec::new()),
|
||||
chunk_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate memory from the arena
|
||||
pub fn alloc<T>(&self) -> NonNull<T> {
|
||||
let layout = Layout::new::<T>();
|
||||
let ptr = self.alloc_layout(layout);
|
||||
ptr.cast()
|
||||
}
|
||||
|
||||
/// Allocate with specific layout
|
||||
pub fn alloc_layout(&self, layout: Layout) -> NonNull<u8> {
|
||||
let size = layout.size();
|
||||
let align = layout.align();
|
||||
|
||||
// SECURITY: Validate layout parameters
|
||||
assert!(size > 0, "Cannot allocate zero bytes");
|
||||
assert!(
|
||||
align > 0 && align.is_power_of_two(),
|
||||
"Alignment must be a power of 2"
|
||||
);
|
||||
assert!(size <= isize::MAX as usize, "Allocation size too large");
|
||||
|
||||
// Get current chunk or allocate new one
|
||||
let chunk = match self.current.get() {
|
||||
Some(chunk) => chunk,
|
||||
None => {
|
||||
let chunk = self.allocate_chunk();
|
||||
self.current.set(Some(chunk));
|
||||
chunk
|
||||
}
|
||||
};
|
||||
|
||||
unsafe {
|
||||
let chunk_ref = chunk.as_ref();
|
||||
let offset = chunk_ref.offset.get();
|
||||
|
||||
// Align offset
|
||||
let aligned_offset = (offset + align - 1) & !(align - 1);
|
||||
|
||||
// SECURITY: Check for overflow in alignment calculation
|
||||
if aligned_offset < offset {
|
||||
panic!("Alignment calculation overflow");
|
||||
}
|
||||
|
||||
let new_offset = aligned_offset
|
||||
.checked_add(size)
|
||||
.expect("Arena allocation overflow");
|
||||
|
||||
if new_offset > chunk_ref.capacity {
|
||||
// Need a new chunk
|
||||
let new_chunk = self.allocate_chunk();
|
||||
chunk_ref.next.set(Some(new_chunk));
|
||||
self.current.set(Some(new_chunk));
|
||||
|
||||
// Retry allocation with new chunk
|
||||
return self.alloc_layout(layout);
|
||||
}
|
||||
|
||||
chunk_ref.offset.set(new_offset);
|
||||
|
||||
// SECURITY: Verify pointer arithmetic is safe
|
||||
let result_ptr = chunk_ref.data.as_ptr().add(aligned_offset);
|
||||
debug_assert!(
|
||||
result_ptr as usize >= chunk_ref.data.as_ptr() as usize,
|
||||
"Pointer arithmetic underflow"
|
||||
);
|
||||
debug_assert!(
|
||||
result_ptr as usize <= chunk_ref.data.as_ptr().add(chunk_ref.capacity) as usize,
|
||||
"Pointer arithmetic overflow"
|
||||
);
|
||||
|
||||
NonNull::new_unchecked(result_ptr)
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a new chunk
|
||||
fn allocate_chunk(&self) -> NonNull<Chunk> {
|
||||
unsafe {
|
||||
let layout = Layout::from_size_align_unchecked(self.chunk_size, 64);
|
||||
let data = NonNull::new_unchecked(alloc(layout));
|
||||
|
||||
let chunk_layout = Layout::new::<Chunk>();
|
||||
let chunk_ptr = alloc(chunk_layout) as *mut Chunk;
|
||||
|
||||
ptr::write(
|
||||
chunk_ptr,
|
||||
Chunk {
|
||||
data,
|
||||
offset: Cell::new(0),
|
||||
capacity: self.chunk_size,
|
||||
next: Cell::new(None),
|
||||
},
|
||||
);
|
||||
|
||||
let chunk = NonNull::new_unchecked(chunk_ptr);
|
||||
self.chunks.lock().push(chunk);
|
||||
chunk
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset arena (reuse existing chunks)
|
||||
pub fn reset(&self) {
|
||||
let chunks = self.chunks.lock();
|
||||
for &chunk in chunks.iter() {
|
||||
unsafe {
|
||||
chunk.as_ref().offset.set(0);
|
||||
chunk.as_ref().next.set(None);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(first_chunk) = chunks.first() {
|
||||
self.current.set(Some(*first_chunk));
|
||||
}
|
||||
}
|
||||
|
||||
/// Get total allocated bytes across all chunks
|
||||
pub fn total_allocated(&self) -> usize {
|
||||
self.chunks.lock().len() * self.chunk_size
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ArenaAllocator {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for ArenaAllocator {
    // Free every chunk's data buffer and then the chunk headers themselves.
    fn drop(&mut self) {
        let chunks = self.chunks.lock();
        for &chunk in chunks.iter() {
            unsafe {
                let chunk_ref = chunk.as_ref();

                // Deallocate data buffer
                // SAFETY hinges on this layout matching the one used when
                // the buffer was allocated (capacity bytes, 64-byte align).
                let data_layout = Layout::from_size_align_unchecked(chunk_ref.capacity, 64);
                dealloc(chunk_ref.data.as_ptr(), data_layout);

                // Deallocate chunk itself
                let chunk_layout = Layout::new::<Chunk>();
                dealloc(chunk.as_ptr() as *mut u8, chunk_layout);
            }
        }
    }
}

// NOTE(review): `current` and each chunk's `offset` are plain `Cell`s with
// no synchronization, so concurrent `alloc_layout` calls from multiple
// threads would race on the bump pointer. This `Sync` impl looks unsound
// without external locking — confirm how arenas are shared before relying
// on it.
unsafe impl Send for ArenaAllocator {}
unsafe impl Sync for ArenaAllocator {}
|
||||
|
||||
/// Query-scoped arena that resets after each query
|
||||
pub struct QueryArena {
|
||||
arena: Arc<ArenaAllocator>,
|
||||
}
|
||||
|
||||
impl QueryArena {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
arena: Arc::new(ArenaAllocator::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute_query<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&ArenaAllocator) -> R,
|
||||
{
|
||||
let result = f(&self.arena);
|
||||
self.arena.reset();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn arena(&self) -> &ArenaAllocator {
|
||||
&self.arena
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QueryArena {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// NUMA-aware allocator for multi-socket systems
|
||||
pub struct NumaAllocator {
|
||||
/// Allocators per NUMA node
|
||||
node_allocators: Vec<Arc<ArenaAllocator>>,
|
||||
/// Current thread's preferred NUMA node
|
||||
preferred_node: Cell<usize>,
|
||||
}
|
||||
|
||||
impl NumaAllocator {
|
||||
/// Create NUMA-aware allocator
|
||||
pub fn new() -> Self {
|
||||
let num_nodes = Self::detect_numa_nodes();
|
||||
let node_allocators = (0..num_nodes)
|
||||
.map(|_| Arc::new(ArenaAllocator::new()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
node_allocators,
|
||||
preferred_node: Cell::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect number of NUMA nodes (simplified)
|
||||
fn detect_numa_nodes() -> usize {
|
||||
// In a real implementation, this would use platform-specific APIs
|
||||
// For now, assume 1 node per 8 CPUs
|
||||
let cpus = num_cpus::get();
|
||||
((cpus + 7) / 8).max(1)
|
||||
}
|
||||
|
||||
/// Allocate from preferred NUMA node
|
||||
pub fn alloc<T>(&self) -> NonNull<T> {
|
||||
let node = self.preferred_node.get();
|
||||
self.node_allocators[node].alloc()
|
||||
}
|
||||
|
||||
/// Set preferred NUMA node for current thread
|
||||
pub fn set_preferred_node(&self, node: usize) {
|
||||
if node < self.node_allocators.len() {
|
||||
self.preferred_node.set(node);
|
||||
}
|
||||
}
|
||||
|
||||
/// Bind current thread to NUMA node
|
||||
pub fn bind_to_node(&self, node: usize) {
|
||||
self.set_preferred_node(node);
|
||||
|
||||
// In a real implementation, this would use platform-specific APIs
|
||||
// to bind the thread to CPUs on the specified NUMA node
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
// Would use libnuma or similar
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NumaAllocator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Object pool for reducing allocation overhead
|
||||
pub struct ObjectPool<T> {
|
||||
/// Pool of available objects
|
||||
available: Arc<crossbeam::queue::SegQueue<T>>,
|
||||
/// Factory function
|
||||
factory: Arc<dyn Fn() -> T + Send + Sync>,
|
||||
/// Maximum pool size
|
||||
max_size: usize,
|
||||
}
|
||||
|
||||
impl<T> ObjectPool<T> {
|
||||
pub fn new<F>(max_size: usize, factory: F) -> Self
|
||||
where
|
||||
F: Fn() -> T + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
available: Arc::new(crossbeam::queue::SegQueue::new()),
|
||||
factory: Arc::new(factory),
|
||||
max_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn acquire(&self) -> PooledObject<T> {
|
||||
let object = self.available.pop().unwrap_or_else(|| (self.factory)());
|
||||
|
||||
PooledObject {
|
||||
object: Some(object),
|
||||
pool: Arc::clone(&self.available),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.available.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.available.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII wrapper for pooled objects
|
||||
pub struct PooledObject<T> {
|
||||
object: Option<T>,
|
||||
pool: Arc<crossbeam::queue::SegQueue<T>>,
|
||||
}
|
||||
|
||||
impl<T> PooledObject<T> {
|
||||
pub fn get(&self) -> &T {
|
||||
self.object.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self) -> &mut T {
|
||||
self.object.as_mut().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for PooledObject<T> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(object) = self.object.take() {
|
||||
let _ = self.pool.push(object);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for PooledObject<T> {
|
||||
type Target = T;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.object.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for PooledObject<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.object.as_mut().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Two bump allocations must not alias: writes to one must not clobber
    // the other.
    #[test]
    fn test_arena_allocator() {
        let arena = ArenaAllocator::new();

        let ptr1 = arena.alloc::<u64>();
        let ptr2 = arena.alloc::<u64>();

        unsafe {
            ptr1.as_ptr().write(42);
            ptr2.as_ptr().write(84);

            assert_eq!(ptr1.as_ptr().read(), 42);
            assert_eq!(ptr2.as_ptr().read(), 84);
        }
    }

    // reset() recycles chunks in place, so the total footprint is unchanged.
    #[test]
    fn test_arena_reset() {
        let arena = ArenaAllocator::new();

        for _ in 0..100 {
            arena.alloc::<u64>();
        }

        let allocated_before = arena.total_allocated();
        arena.reset();
        let allocated_after = arena.total_allocated();

        assert_eq!(allocated_before, allocated_after);
    }

    // The closure's return value survives the automatic post-query reset.
    #[test]
    fn test_query_arena() {
        let query_arena = QueryArena::new();

        let result = query_arena.execute_query(|arena| {
            let ptr = arena.alloc::<u64>();
            unsafe {
                ptr.as_ptr().write(123);
                ptr.as_ptr().read()
            }
        });

        assert_eq!(result, 123);
    }

    // Dropping a guard returns the object (with its capacity) to the pool.
    #[test]
    fn test_object_pool() {
        let pool = ObjectPool::new(10, || Vec::<u8>::with_capacity(1024));

        let mut obj = pool.acquire();
        obj.push(42);
        assert_eq!(obj[0], 42);

        drop(obj);

        let obj2 = pool.acquire();
        assert!(obj2.capacity() >= 1024);
    }
}
|
||||
39
vendor/ruvector/crates/ruvector-graph/src/optimization/mod.rs
vendored
Normal file
39
vendor/ruvector/crates/ruvector-graph/src/optimization/mod.rs
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
//! Performance optimization modules for orders of magnitude speedup
//!
//! This module provides cutting-edge optimizations targeting 100x performance
//! improvement over Neo4j through:
//! - SIMD-vectorized graph traversal
//! - Cache-optimized data layouts
//! - Custom memory allocators
//! - Compressed indexes
//! - JIT-compiled query operators
//! - Bloom filters for negative lookups
//! - Adaptive radix trees for property indexes

pub mod adaptive_radix;
pub mod bloom_filter;
pub mod cache_hierarchy;
pub mod index_compression;
pub mod memory_pool;
pub mod query_jit;
pub mod simd_traversal;

// Re-exports for convenience
pub use adaptive_radix::{AdaptiveRadixTree, ArtNode};
pub use bloom_filter::{BloomFilter, ScalableBloomFilter};
pub use cache_hierarchy::{CacheHierarchy, HotColdStorage};
pub use index_compression::{CompressedIndex, DeltaEncoder, RoaringBitmapIndex};
pub use memory_pool::{ArenaAllocator, NumaAllocator, QueryArena};
pub use query_jit::{JitCompiler, JitQuery, QueryOperator};
pub use simd_traversal::{SimdBfsIterator, SimdDfsIterator, SimdTraversal};

#[cfg(test)]
mod tests {
    use super::*;

    // Merely building this module proves every submodule and re-export
    // resolves; the assertion itself is a placeholder.
    #[test]
    fn test_optimization_modules_compile() {
        // Smoke test to ensure all modules compile
        assert!(true);
    }
}
|
||||
337
vendor/ruvector/crates/ruvector-graph/src/optimization/query_jit.rs
vendored
Normal file
337
vendor/ruvector/crates/ruvector-graph/src/optimization/query_jit.rs
vendored
Normal file
@@ -0,0 +1,337 @@
|
||||
//! JIT compilation for hot query paths
|
||||
//!
|
||||
//! This module provides specialized query operators that are
|
||||
//! compiled/optimized for common query patterns.
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// JIT compiler for graph queries
|
||||
pub struct JitCompiler {
|
||||
/// Compiled query cache
|
||||
compiled_cache: Arc<RwLock<HashMap<String, Arc<JitQuery>>>>,
|
||||
/// Query execution statistics
|
||||
stats: Arc<RwLock<QueryStats>>,
|
||||
}
|
||||
|
||||
impl JitCompiler {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
compiled_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
stats: Arc::new(RwLock::new(QueryStats::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile a query pattern into optimized operators
|
||||
pub fn compile(&self, pattern: &str) -> Arc<JitQuery> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.compiled_cache.read();
|
||||
if let Some(compiled) = cache.get(pattern) {
|
||||
return Arc::clone(compiled);
|
||||
}
|
||||
}
|
||||
|
||||
// Compile new query
|
||||
let query = Arc::new(self.compile_pattern(pattern));
|
||||
|
||||
// Cache it
|
||||
self.compiled_cache
|
||||
.write()
|
||||
.insert(pattern.to_string(), Arc::clone(&query));
|
||||
|
||||
query
|
||||
}
|
||||
|
||||
/// Compile pattern into specialized operators
|
||||
fn compile_pattern(&self, pattern: &str) -> JitQuery {
|
||||
// Parse pattern and generate optimized operator chain
|
||||
let operators = self.parse_and_optimize(pattern);
|
||||
|
||||
JitQuery {
|
||||
pattern: pattern.to_string(),
|
||||
operators,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse query and generate optimized operator chain
|
||||
fn parse_and_optimize(&self, pattern: &str) -> Vec<QueryOperator> {
|
||||
let mut operators = Vec::new();
|
||||
|
||||
// Simple pattern matching for common cases
|
||||
if pattern.contains("MATCH") && pattern.contains("WHERE") {
|
||||
// Pattern: MATCH (n:Label) WHERE n.prop = value
|
||||
operators.push(QueryOperator::LabelScan {
|
||||
label: "Label".to_string(),
|
||||
});
|
||||
operators.push(QueryOperator::Filter {
|
||||
predicate: FilterPredicate::Equality {
|
||||
property: "prop".to_string(),
|
||||
value: PropertyValue::String("value".to_string()),
|
||||
},
|
||||
});
|
||||
} else if pattern.contains("MATCH") && pattern.contains("->") {
|
||||
// Pattern: MATCH (a)-[r]->(b)
|
||||
operators.push(QueryOperator::Expand {
|
||||
direction: Direction::Outgoing,
|
||||
edge_label: None,
|
||||
});
|
||||
} else {
|
||||
// Generic scan
|
||||
operators.push(QueryOperator::FullScan);
|
||||
}
|
||||
|
||||
operators
|
||||
}
|
||||
|
||||
/// Record query execution
|
||||
pub fn record_execution(&self, pattern: &str, duration_ns: u64) {
|
||||
self.stats.write().record(pattern, duration_ns);
|
||||
}
|
||||
|
||||
/// Get hot queries that should be JIT compiled
|
||||
pub fn get_hot_queries(&self, threshold: u64) -> Vec<String> {
|
||||
self.stats.read().get_hot_queries(threshold)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for JitCompiler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compiled query with specialized operators
pub struct JitQuery {
    /// Original query pattern
    pub pattern: String,
    /// Optimized operator chain
    pub operators: Vec<QueryOperator>,
}

impl JitQuery {
    /// Execute query with specialized operators
    ///
    /// `executor` is invoked once per operator, in chain order.
    ///
    /// NOTE(review): each call's result *replaces* the previous one — the
    /// executor closure is expected to carry intermediate state itself,
    /// since only the final operator's `IntermediateResult` reaches the
    /// caller (an empty chain yields an empty result). Confirm this is
    /// the intended pipelining contract.
    pub fn execute<F>(&self, mut executor: F) -> QueryResult
    where
        F: FnMut(&QueryOperator) -> IntermediateResult,
    {
        let mut result = IntermediateResult::default();

        for operator in &self.operators {
            result = executor(operator);
        }

        QueryResult {
            nodes: result.nodes,
            edges: result.edges,
        }
    }
}
|
||||
|
||||
/// Specialized query operators
///
/// Physical-plan primitives emitted by the JIT pattern compiler; an executor
/// interprets each variant against the graph.
#[derive(Debug, Clone)]
pub enum QueryOperator {
    /// Full table scan
    FullScan,

    /// Label index scan
    LabelScan { label: String },

    /// Property index scan
    PropertyScan {
        property: String,
        value: PropertyValue,
    },

    /// Expand edges from nodes
    Expand {
        direction: Direction,
        /// `None` means any edge label.
        edge_label: Option<String>,
    },

    /// Filter nodes/edges
    Filter { predicate: FilterPredicate },

    /// Project properties
    Project { properties: Vec<String> },

    /// Aggregate results
    Aggregate { function: AggregateFunction },

    /// Sort results
    Sort { property: String, ascending: bool },

    /// Limit results
    Limit { count: usize },
}
|
||||
|
||||
/// Edge orientation used by `QueryOperator::Expand`.
#[derive(Debug, Clone)]
pub enum Direction {
    /// Follow edges that point into the current node.
    Incoming,
    /// Follow edges that point out of the current node.
    Outgoing,
    /// Follow edges regardless of orientation.
    Both,
}
|
||||
|
||||
/// Predicate evaluated by `QueryOperator::Filter` against a node/edge property.
#[derive(Debug, Clone)]
pub enum FilterPredicate {
    /// Property must equal `value` exactly.
    Equality {
        property: String,
        value: PropertyValue,
    },
    /// Property must fall between `min` and `max`.
    // Bound inclusivity is decided by the executor; presumably inclusive —
    // TODO confirm against the executor implementation.
    Range {
        property: String,
        min: PropertyValue,
        max: PropertyValue,
    },
    /// Property (as text) must match `pattern`.
    // Regex dialect is not fixed here; evaluation happens in the executor.
    Regex {
        property: String,
        pattern: String,
    },
}
|
||||
|
||||
/// Literal value carried inside JIT operators/predicates.
// NOTE(review): this is a small, operator-local value type; a richer
// `PropertyValue` (with Null/Array/Map) exists in src/property.rs — the two
// are distinct types and are not interchangeable.
#[derive(Debug, Clone)]
pub enum PropertyValue {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
}
|
||||
|
||||
/// Aggregation applied by `QueryOperator::Aggregate`.
#[derive(Debug, Clone)]
pub enum AggregateFunction {
    /// Row count; needs no property.
    Count,
    /// Sum of the named property over matched items.
    Sum { property: String },
    /// Arithmetic mean of the named property.
    Avg { property: String },
    /// Minimum of the named property.
    Min { property: String },
    /// Maximum of the named property.
    Max { property: String },
}
|
||||
|
||||
/// Intermediate result during query execution
///
/// Scratch value passed between operator evaluations; `default()` is the
/// empty result (no nodes, no edges).
#[derive(Default)]
pub struct IntermediateResult {
    /// Node ids produced so far.
    pub nodes: Vec<u64>,
    /// Edges as (source, target) id pairs.
    pub edges: Vec<(u64, u64)>,
}
|
||||
|
||||
/// Final query result
///
/// The last `IntermediateResult` of the operator chain, re-wrapped for
/// callers of `JitQuery::execute`.
pub struct QueryResult {
    /// Matched node ids.
    pub nodes: Vec<u64>,
    /// Matched edges as (source, target) id pairs.
    pub edges: Vec<(u64, u64)>,
}
|
||||
|
||||
/// Query execution statistics
///
/// Tracks, per pattern string, how many times it ran and the cumulative wall
/// time, so the compiler can identify hot patterns worth JIT-compiling.
struct QueryStats {
    /// Execution count per pattern
    execution_counts: HashMap<String, u64>,
    /// Total execution time per pattern
    total_time_ns: HashMap<String, u64>,
}

impl QueryStats {
    /// Empty statistics.
    fn new() -> Self {
        Self {
            execution_counts: HashMap::new(),
            total_time_ns: HashMap::new(),
        }
    }

    /// Fold one execution (pattern + nanoseconds) into the counters.
    fn record(&mut self, pattern: &str, duration_ns: u64) {
        let key = pattern.to_string();
        *self.execution_counts.entry(key.clone()).or_insert(0) += 1;
        *self.total_time_ns.entry(key).or_insert(0) += duration_ns;
    }

    /// Patterns that ran at least `threshold` times, in unspecified order.
    fn get_hot_queries(&self, threshold: u64) -> Vec<String> {
        self.execution_counts
            .iter()
            .filter_map(|(pattern, &count)| (count >= threshold).then(|| pattern.clone()))
            .collect()
    }

    /// Mean execution time for `pattern`, or `None` if it was never recorded.
    fn avg_time_ns(&self, pattern: &str) -> Option<u64> {
        let count = *self.execution_counts.get(pattern)?;
        let total = *self.total_time_ns.get(pattern)?;
        // Guard against a zero count; the closure keeps the division lazy.
        (count > 0).then(|| total / count)
    }
}
|
||||
|
||||
/// Specialized operator implementations
|
||||
pub mod specialized_ops {
|
||||
use super::*;
|
||||
|
||||
/// Vectorized label scan
|
||||
pub fn vectorized_label_scan(label: &str, nodes: &[u64]) -> Vec<u64> {
|
||||
// In a real implementation, this would use SIMD to check labels in parallel
|
||||
nodes.iter().copied().collect()
|
||||
}
|
||||
|
||||
/// Vectorized property filter
|
||||
pub fn vectorized_property_filter(
|
||||
property: &str,
|
||||
predicate: &FilterPredicate,
|
||||
nodes: &[u64],
|
||||
) -> Vec<u64> {
|
||||
// In a real implementation, this would use SIMD for comparisons
|
||||
nodes.iter().copied().collect()
|
||||
}
|
||||
|
||||
/// Cache-friendly edge expansion
|
||||
pub fn cache_friendly_expand(nodes: &[u64], direction: Direction) -> Vec<(u64, u64)> {
|
||||
// In a real implementation, this would use prefetching and cache-optimized layout
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiling any pattern must yield a non-empty operator chain.
    #[test]
    fn test_jit_compiler() {
        let jit = JitCompiler::new();
        let compiled = jit.compile("MATCH (n:Person) WHERE n.age > 18");
        assert!(!compiled.operators.is_empty());
    }

    /// Only patterns recorded at/above the threshold are reported as hot.
    #[test]
    fn test_query_stats() {
        let jit = JitCompiler::new();
        for duration_ns in [1000u64, 2000, 3000] {
            jit.record_execution("MATCH (n)", duration_ns);
        }
        let hot_patterns = jit.get_hot_queries(2);
        assert_eq!(hot_patterns.len(), 1);
        assert_eq!(hot_patterns[0], "MATCH (n)");
    }

    /// Operator chains are plain vectors; construction preserves length.
    #[test]
    fn test_operator_chain() {
        let scan = QueryOperator::LabelScan {
            label: "Person".to_string(),
        };
        let age_filter = QueryOperator::Filter {
            predicate: FilterPredicate::Range {
                property: "age".to_string(),
                min: PropertyValue::Integer(18),
                max: PropertyValue::Integer(65),
            },
        };
        let cap = QueryOperator::Limit { count: 10 };

        let chain = vec![scan, age_filter, cap];
        assert_eq!(chain.len(), 3);
    }
}
|
||||
416
vendor/ruvector/crates/ruvector-graph/src/optimization/simd_traversal.rs
vendored
Normal file
416
vendor/ruvector/crates/ruvector-graph/src/optimization/simd_traversal.rs
vendored
Normal file
@@ -0,0 +1,416 @@
|
||||
//! SIMD-optimized graph traversal algorithms
|
||||
//!
|
||||
//! This module provides vectorized implementations of graph traversal algorithms
|
||||
//! using AVX2/AVX-512 for massive parallelism within a single core.
|
||||
|
||||
use crossbeam::queue::SegQueue;
|
||||
use rayon::prelude::*;
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
/// SIMD-optimized graph traversal engine
///
/// Configuration holder for the batched BFS/DFS routines below; construct via
/// `SimdTraversal::new()` (or `default()`).
pub struct SimdTraversal {
    /// Number of threads to use for parallel traversal
    // Initialized from `num_cpus::get()` in `new()`.
    num_threads: usize,
    /// Batch size for SIMD operations
    // 256 in `new()`, per the constructor's cache-efficiency comment.
    batch_size: usize,
}
|
||||
|
||||
impl Default for SimdTraversal {
    /// Equivalent to `SimdTraversal::new()`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl SimdTraversal {
    /// Create a new SIMD traversal engine
    ///
    /// Thread count follows the detected logical CPU count; batch size is a
    /// fixed 256.
    pub fn new() -> Self {
        Self {
            num_threads: num_cpus::get(),
            batch_size: 256, // Process 256 nodes at a time for cache efficiency
        }
    }

    /// Perform batched BFS with SIMD-optimized neighbor processing
    ///
    /// Pops up to `batch_size` frontier nodes at a time and expands them on a
    /// rayon thread pool; `visit_fn(node)` must return that node's neighbors.
    /// Returns every reachable node (start nodes included, de-duplicated).
    ///
    /// NOTE(review): the output order is whatever the concurrent queues yield
    /// — it is NOT strict BFS level order; sort if ordering matters.
    /// NOTE(review): `visit_fn` is wrapped in a Mutex below, so neighbor
    /// lookups are fully serialized across the parallel chunks — the
    /// parallelism only covers the (currently scalar) filtering. Consider an
    /// `Fn` bound to remove the lock.
    pub fn simd_bfs<F>(&self, start_nodes: &[u64], mut visit_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64> + Send + Sync,
    {
        let visited = Arc::new(dashmap::DashSet::new());
        let queue = Arc::new(SegQueue::new());
        let result = Arc::new(SegQueue::new());

        // Initialize queue with start nodes
        // DashSet::insert returns false for duplicates, so repeated start
        // nodes are enqueued only once.
        for &node in start_nodes {
            if visited.insert(node) {
                queue.push(node);
                result.push(node);
            }
        }

        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));

        // Process nodes in batches
        while !queue.is_empty() {
            let mut batch = Vec::with_capacity(self.batch_size);

            // Collect a batch of nodes
            for _ in 0..self.batch_size {
                if let Some(node) = queue.pop() {
                    batch.push(node);
                } else {
                    break;
                }
            }

            if batch.is_empty() {
                break;
            }

            // Process batch in parallel with SIMD-friendly chunking
            // Ceiling division so every thread gets at most one chunk.
            let chunk_size = (batch.len() + self.num_threads - 1) / self.num_threads;

            batch.par_chunks(chunk_size).for_each(|chunk| {
                for &node in chunk {
                    let neighbors = {
                        let mut vf = visit_fn.lock().unwrap();
                        vf(node)
                    };

                    // SIMD-accelerated neighbor filtering
                    self.filter_unvisited_simd(&neighbors, &visited, &queue, &result);
                }
            });
        }

        // Collect results
        let mut output = Vec::new();
        while let Some(node) = result.pop() {
            output.push(node);
        }
        output
    }

    /// SIMD-optimized filtering of unvisited neighbors
    // NOTE(review): despite the name, this is a scalar loop on all targets —
    // the x86_64 and fallback bodies are identical. The DashSet insert is the
    // atomic claim that prevents double-visiting across threads.
    #[cfg(target_arch = "x86_64")]
    fn filter_unvisited_simd(
        &self,
        neighbors: &[u64],
        visited: &Arc<dashmap::DashSet<u64>>,
        queue: &Arc<SegQueue<u64>>,
        result: &Arc<SegQueue<u64>>,
    ) {
        // Process neighbors in SIMD-width chunks
        for neighbor in neighbors {
            if visited.insert(*neighbor) {
                queue.push(*neighbor);
                result.push(*neighbor);
            }
        }
    }

    // Non-x86_64 fallback; identical scalar logic to the variant above.
    #[cfg(not(target_arch = "x86_64"))]
    fn filter_unvisited_simd(
        &self,
        neighbors: &[u64],
        visited: &Arc<dashmap::DashSet<u64>>,
        queue: &Arc<SegQueue<u64>>,
        result: &Arc<SegQueue<u64>>,
    ) {
        for neighbor in neighbors {
            if visited.insert(*neighbor) {
                queue.push(*neighbor);
                result.push(*neighbor);
            }
        }
    }

    /// Vectorized property access across multiple nodes
    ///
    /// Gathers `properties[idx]` for each index. Panics if any index is out
    /// of bounds (both paths assert).
    #[cfg(target_arch = "x86_64")]
    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
        if is_x86_feature_detected!("avx2") {
            unsafe { self.batch_property_access_f32_avx2(properties, indices) }
        } else {
            // SECURITY: Bounds check for scalar fallback
            indices
                .iter()
                .map(|&idx| {
                    assert!(
                        idx < properties.len(),
                        "Index out of bounds: {} >= {}",
                        idx,
                        properties.len()
                    );
                    properties[idx]
                })
                .collect()
        }
    }

    /// AVX2 variant of the gather.
    ///
    /// # Safety
    /// Caller must ensure AVX2 is available (the public wrapper guards with
    /// `is_x86_feature_detected!("avx2")`). The body itself performs only
    /// bounds-checked scalar accesses.
    // NOTE(review): currently identical scalar work to the fallback — no
    // actual `_mm256_i32gather_ps` is issued; the attribute only enables the
    // compiler to auto-vectorize.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn batch_property_access_f32_avx2(
        &self,
        properties: &[f32],
        indices: &[usize],
    ) -> Vec<f32> {
        let mut result = Vec::with_capacity(indices.len());

        // Gather operation using AVX2
        // Note: True AVX2 gather is complex; this is a simplified version
        // SECURITY: Bounds check each index before access
        for &idx in indices {
            assert!(
                idx < properties.len(),
                "Index out of bounds: {} >= {}",
                idx,
                properties.len()
            );
            result.push(properties[idx]);
        }

        result
    }

    // Non-x86 platforms: same contract, scalar implementation.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
        // SECURITY: Bounds check for non-x86 platforms
        indices
            .iter()
            .map(|&idx| {
                assert!(
                    idx < properties.len(),
                    "Index out of bounds: {} >= {}",
                    idx,
                    properties.len()
                );
                properties[idx]
            })
            .collect()
    }

    /// Parallel DFS with work-stealing for load balancing
    ///
    /// Spawns `num_threads` scoped workers that share one lock-free work
    /// queue; each worker pops a node, fetches neighbors through the (locked)
    /// `visit_fn`, and pushes unseen neighbors back. Returns all reachable
    /// nodes in unspecified order.
    ///
    /// NOTE(review): the idle check (`active_workers == 0 && queue empty`)
    /// appears racy — a worker that has popped a node but not yet incremented
    /// the counter can let peers exit early. The still-active worker then
    /// drains the remaining work itself, so the result looks complete, but
    /// parallelism can silently degrade to one thread. Confirm and consider
    /// incrementing the counter before `pop`.
    pub fn parallel_dfs<F>(&self, start_node: u64, mut visit_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64> + Send + Sync,
    {
        let visited = Arc::new(dashmap::DashSet::new());
        let result = Arc::new(SegQueue::new());
        let work_queue = Arc::new(SegQueue::new());

        visited.insert(start_node);
        result.push(start_node);
        work_queue.push(start_node);

        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
        let active_workers = Arc::new(AtomicUsize::new(0));

        // Spawn worker threads
        // std::thread::scope lets workers borrow without 'static bounds.
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..self.num_threads)
                .map(|_| {
                    let work_queue = Arc::clone(&work_queue);
                    let visited = Arc::clone(&visited);
                    let result = Arc::clone(&result);
                    let visit_fn = Arc::clone(&visit_fn);
                    let active_workers = Arc::clone(&active_workers);

                    s.spawn(move || {
                        loop {
                            if let Some(node) = work_queue.pop() {
                                active_workers.fetch_add(1, Ordering::SeqCst);

                                let neighbors = {
                                    let mut vf = visit_fn.lock().unwrap();
                                    vf(node)
                                };

                                for neighbor in neighbors {
                                    if visited.insert(neighbor) {
                                        result.push(neighbor);
                                        work_queue.push(neighbor);
                                    }
                                }

                                active_workers.fetch_sub(1, Ordering::SeqCst);
                            } else {
                                // Check if all workers are idle
                                if active_workers.load(Ordering::SeqCst) == 0
                                    && work_queue.is_empty()
                                {
                                    break;
                                }
                                std::thread::yield_now();
                            }
                        }
                    })
                })
                .collect();

            for handle in handles {
                handle.join().unwrap();
            }
        });

        // Collect results
        let mut output = Vec::new();
        while let Some(node) = result.pop() {
            output.push(node);
        }
        output
    }
}
|
||||
|
||||
/// SIMD BFS iterator
///
/// Single-threaded, batched breadth-first traversal over `u64` node ids.
/// `next_batch` yields up to `batch_size` nodes per call and enqueues their
/// unseen neighbors for later batches.
pub struct SimdBfsIterator {
    /// Nodes discovered but not yet yielded (FIFO frontier).
    frontier: VecDeque<u64>,
    /// Every node ever enqueued; guards against revisiting.
    seen: HashSet<u64>,
}

impl SimdBfsIterator {
    /// Seed the frontier, de-duplicating the start set while preserving
    /// first-occurrence order.
    pub fn new(start_nodes: Vec<u64>) -> Self {
        let mut seen = HashSet::new();
        let frontier: VecDeque<u64> = start_nodes
            .into_iter()
            .filter(|&node| seen.insert(node))
            .collect();

        Self { frontier, seen }
    }

    /// Pop up to `batch_size` nodes in BFS order, expanding each popped
    /// node's neighbors (via `neighbor_fn`) into the frontier.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        let mut emitted = Vec::new();

        while emitted.len() < batch_size {
            let node = match self.frontier.pop_front() {
                Some(n) => n,
                None => break,
            };
            emitted.push(node);

            for neighbor in neighbor_fn(node) {
                if self.seen.insert(neighbor) {
                    self.frontier.push_back(neighbor);
                }
            }
        }

        emitted
    }

    /// True once the frontier is exhausted.
    pub fn is_empty(&self) -> bool {
        self.frontier.is_empty()
    }
}
|
||||
|
||||
/// SIMD DFS iterator
///
/// Single-threaded, batched depth-first traversal over `u64` node ids.
pub struct SimdDfsIterator {
    /// LIFO work list; the top of the stack is the next node to yield.
    pending: Vec<u64>,
    /// Nodes already pushed; prevents cycles and re-visits.
    seen: HashSet<u64>,
}

impl SimdDfsIterator {
    /// Start a traversal rooted at `start_node`.
    pub fn new(start_node: u64) -> Self {
        let mut seen = HashSet::new();
        seen.insert(start_node);

        Self {
            pending: vec![start_node],
            seen,
        }
    }

    /// Pop up to `batch_size` nodes in DFS order, expanding each popped
    /// node's neighbors (via `neighbor_fn`) onto the stack.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        let mut emitted = Vec::new();

        while emitted.len() < batch_size {
            let node = match self.pending.pop() {
                Some(n) => n,
                None => break,
            };
            emitted.push(node);

            // Push neighbors in reverse so the first neighbor comes off the
            // stack first, giving conventional left-to-right DFS order.
            for neighbor in neighbor_fn(node).into_iter().rev() {
                if self.seen.insert(neighbor) {
                    self.pending.push(neighbor);
                }
            }
        }

        emitted
    }

    /// True once the stack is exhausted.
    pub fn is_empty(&self) -> bool {
        self.pending.is_empty()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// BFS from node 0 must reach all five nodes (order is unspecified).
    #[test]
    fn test_simd_bfs() {
        let engine = SimdTraversal::new();

        // Adjacency list: 0 -> [1, 2], 1 -> [3], 2 -> [4]
        let adjacency = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let reached = engine.simd_bfs(&[0], |node| {
            adjacency.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(reached.len(), 5);
    }

    /// Parallel DFS must also reach all five nodes.
    #[test]
    fn test_parallel_dfs() {
        let engine = SimdTraversal::new();
        let adjacency = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let reached = engine.parallel_dfs(0, |node| {
            adjacency.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(reached.len(), 5);
    }

    /// The batched iterator yields every node exactly once across batches.
    #[test]
    fn test_simd_bfs_iterator() {
        let adjacency = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];
        let mut frontier = SimdBfsIterator::new(vec![0]);

        let mut yielded = Vec::new();
        while !frontier.is_empty() {
            yielded.extend(frontier.next_batch(2, |node| {
                adjacency.get(node as usize).cloned().unwrap_or_default()
            }));
        }

        assert_eq!(yielded.len(), 5);
    }
}
|
||||
208
vendor/ruvector/crates/ruvector-graph/src/property.rs
vendored
Normal file
208
vendor/ruvector/crates/ruvector-graph/src/property.rs
vendored
Normal file
@@ -0,0 +1,208 @@
|
||||
//! Property value types for graph nodes and edges
|
||||
//!
|
||||
//! Supports Neo4j-compatible property types: primitives, arrays, and maps
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Property value that can be stored on nodes and edges
///
/// NOTE(review): `#[serde(untagged)]` makes deserialization pick the first
/// variant that structurally matches, in declaration order (e.g. a whole JSON
/// number becomes `Int`, a fractional one `Float`). Reordering variants
/// changes wire behavior — confirm before editing.
// Only PartialEq (not Eq/Hash) because of the f64 variant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PropertyValue {
    /// Null value
    Null,
    /// Boolean value
    Bool(bool),
    /// 64-bit integer
    Int(i64),
    /// 64-bit floating point
    Float(f64),
    /// UTF-8 string
    String(String),
    /// Array of homogeneous values
    // Homogeneity is by convention only; the type does not enforce it.
    Array(Vec<PropertyValue>),
    /// Map of string keys to values
    Map(HashMap<String, PropertyValue>),
}
|
||||
|
||||
impl PropertyValue {
|
||||
/// Check if value is null
|
||||
pub fn is_null(&self) -> bool {
|
||||
matches!(self, PropertyValue::Null)
|
||||
}
|
||||
|
||||
/// Try to get as boolean
|
||||
pub fn as_bool(&self) -> Option<bool> {
|
||||
match self {
|
||||
PropertyValue::Bool(b) => Some(*b),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get as integer
|
||||
pub fn as_int(&self) -> Option<i64> {
|
||||
match self {
|
||||
PropertyValue::Int(i) => Some(*i),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get as float
|
||||
pub fn as_float(&self) -> Option<f64> {
|
||||
match self {
|
||||
PropertyValue::Float(f) => Some(*f),
|
||||
PropertyValue::Int(i) => Some(*i as f64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get as string
|
||||
pub fn as_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
PropertyValue::String(s) => Some(s),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get as array
|
||||
pub fn as_array(&self) -> Option<&Vec<PropertyValue>> {
|
||||
match self {
|
||||
PropertyValue::Array(arr) => Some(arr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get as map
|
||||
pub fn as_map(&self) -> Option<&HashMap<String, PropertyValue>> {
|
||||
match self {
|
||||
PropertyValue::Map(map) => Some(map),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get type name for debugging
|
||||
pub fn type_name(&self) -> &'static str {
|
||||
match self {
|
||||
PropertyValue::Null => "null",
|
||||
PropertyValue::Bool(_) => "bool",
|
||||
PropertyValue::Int(_) => "int",
|
||||
PropertyValue::Float(_) => "float",
|
||||
PropertyValue::String(_) => "string",
|
||||
PropertyValue::Array(_) => "array",
|
||||
PropertyValue::Map(_) => "map",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Conversion impls: lift plain Rust values into `PropertyValue` via
// `.into()` / `From::from` (exercised by `test_property_conversions`).
impl From<bool> for PropertyValue {
    fn from(b: bool) -> Self {
        PropertyValue::Bool(b)
    }
}

impl From<i64> for PropertyValue {
    fn from(i: i64) -> Self {
        PropertyValue::Int(i)
    }
}

// i32 widens losslessly into the canonical i64 representation.
impl From<i32> for PropertyValue {
    fn from(i: i32) -> Self {
        PropertyValue::Int(i as i64)
    }
}

impl From<f64> for PropertyValue {
    fn from(f: f64) -> Self {
        PropertyValue::Float(f)
    }
}

// f32 widens losslessly into f64.
impl From<f32> for PropertyValue {
    fn from(f: f32) -> Self {
        PropertyValue::Float(f as f64)
    }
}

impl From<String> for PropertyValue {
    fn from(s: String) -> Self {
        PropertyValue::String(s)
    }
}

// Borrowed strings are copied into an owned String.
impl From<&str> for PropertyValue {
    fn from(s: &str) -> Self {
        PropertyValue::String(s.to_string())
    }
}

impl From<Vec<PropertyValue>> for PropertyValue {
    fn from(arr: Vec<PropertyValue>) -> Self {
        PropertyValue::Array(arr)
    }
}

impl From<HashMap<String, PropertyValue>> for PropertyValue {
    fn from(map: HashMap<String, PropertyValue>) -> Self {
        PropertyValue::Map(map)
    }
}
|
||||
|
||||
/// Collection of properties (key-value pairs)
///
/// Convenience alias for the property bag attached to nodes and edges.
pub type Properties = HashMap<String, PropertyValue>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Accessors return `Some` only for the matching variant; ints widen to
    /// float via `as_float`.
    #[test]
    fn test_property_value_types() {
        assert!(PropertyValue::Null.is_null());
        assert_eq!(PropertyValue::Bool(true).as_bool(), Some(true));

        let int_val = PropertyValue::Int(42);
        assert_eq!(int_val.as_int(), Some(42));
        assert_eq!(int_val.as_float(), Some(42.0));

        assert_eq!(PropertyValue::Float(3.14).as_float(), Some(3.14));
        assert_eq!(
            PropertyValue::String("hello".to_string()).as_str(),
            Some("hello")
        );
    }

    /// Every `From` impl is exercised once; compiling is the real assertion.
    #[test]
    fn test_property_conversions() {
        let _: PropertyValue = true.into();
        let _: PropertyValue = 42i64.into();
        let _: PropertyValue = 42i32.into();
        let _: PropertyValue = 3.14f64.into();
        let _: PropertyValue = 3.14f32.into();
        let _: PropertyValue = "test".into();
        let _: PropertyValue = "test".to_string().into();
    }

    /// Arrays and maps nest arbitrarily inside a map value.
    #[test]
    fn test_nested_properties() {
        let mut inner = HashMap::new();
        inner.insert("nested".to_string(), PropertyValue::Int(123));

        let numbers: Vec<PropertyValue> = (1i64..=3).map(PropertyValue::Int).collect();

        let mut outer = HashMap::new();
        outer.insert("array".to_string(), PropertyValue::Array(numbers));
        outer.insert("map".to_string(), PropertyValue::Map(inner));
        let complex = PropertyValue::Map(outer);

        assert!(complex.as_map().is_some());
    }
}
|
||||
488
vendor/ruvector/crates/ruvector-graph/src/storage.rs
vendored
Normal file
488
vendor/ruvector/crates/ruvector-graph/src/storage.rs
vendored
Normal file
@@ -0,0 +1,488 @@
|
||||
//! Persistent storage layer with redb and memory-mapped vectors
|
||||
//!
|
||||
//! Provides ACID-compliant storage for graph nodes, edges, and hyperedges
|
||||
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::edge::Edge;
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::hyperedge::{Hyperedge, HyperedgeId};
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::node::Node;
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::types::{EdgeId, NodeId};
|
||||
#[cfg(feature = "storage")]
|
||||
use anyhow::Result;
|
||||
#[cfg(feature = "storage")]
|
||||
use bincode::config;
|
||||
#[cfg(feature = "storage")]
|
||||
use once_cell::sync::Lazy;
|
||||
#[cfg(feature = "storage")]
|
||||
use parking_lot::Mutex;
|
||||
#[cfg(feature = "storage")]
|
||||
use redb::{Database, ReadableTable, TableDefinition};
|
||||
#[cfg(feature = "storage")]
|
||||
use std::collections::HashMap;
|
||||
#[cfg(feature = "storage")]
|
||||
use std::path::{Path, PathBuf};
|
||||
#[cfg(feature = "storage")]
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "storage")]
|
||||
// Table definitions
|
||||
const NODES_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("nodes");
|
||||
#[cfg(feature = "storage")]
|
||||
const EDGES_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("edges");
|
||||
#[cfg(feature = "storage")]
|
||||
const HYPEREDGES_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("hyperedges");
|
||||
#[cfg(feature = "storage")]
|
||||
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
|
||||
|
||||
#[cfg(feature = "storage")]
// Global database connection pool to allow multiple GraphStorage instances
// to share the same underlying database file
// NOTE(review): entries are never evicted, so each opened Database lives for
// the rest of the process. Keys are absolute-but-not-canonicalized paths —
// two spellings of the same file (e.g. via symlink) could presumably open it
// twice; verify against redb's locking behavior.
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

#[cfg(feature = "storage")]
/// Storage backend for graph database
///
/// Thin handle over a pooled redb `Database`; all operations open their own
/// read or write transaction.
pub struct GraphStorage {
    // Shared with every other GraphStorage opened on the same path.
    db: Arc<Database>,
}
|
||||
|
||||
#[cfg(feature = "storage")]
|
||||
impl GraphStorage {
|
||||
    /// Create or open a graph storage at the given path
    ///
    /// Uses a global connection pool to allow multiple GraphStorage
    /// instances to share the same underlying database file
    ///
    /// Creates missing parent directories, resolves relative paths against
    /// the current working directory, and initializes the four tables on
    /// first open.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path_ref = path.as_ref();

        // Create parent directories if they don't exist
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent)?;
            }
        }

        // Convert to absolute path
        // This (not a canonicalized path) is also the pool key below.
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()?.join(path_ref)
        };

        // SECURITY: Check for path traversal attempts
        // NOTE(review): only *relative* paths containing ".." are checked;
        // absolute paths with ".." components bypass this entirely. The
        // `normalized` path is used purely as a check (its value is discarded
        // afterwards) — confirm this matches the intended threat model.
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") && !path_ref.is_absolute() {
            if let Ok(cwd) = std::env::current_dir() {
                let mut normalized = cwd.clone();
                for component in path_ref.components() {
                    match component {
                        std::path::Component::ParentDir => {
                            // Bail if ".." climbs out of (or above) the cwd.
                            if !normalized.pop() || !normalized.starts_with(&cwd) {
                                anyhow::bail!("Path traversal attempt detected");
                            }
                        }
                        std::path::Component::Normal(c) => normalized.push(c),
                        _ => {}
                    }
                }
            }
        }

        // Check if we already have a Database instance for this path
        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                // Reuse existing database connection
                Arc::clone(existing_db)
            } else {
                // Create new database and add to pool
                let new_db = Arc::new(Database::create(&path_buf)?);

                // Initialize tables
                // Opening each table inside one write txn creates it if absent.
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(NODES_TABLE)?;
                    let _ = write_txn.open_table(EDGES_TABLE)?;
                    let _ = write_txn.open_table(HYPEREDGES_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                }
                write_txn.commit()?;

                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db })
    }
|
||||
|
||||
    // Node operations
    //
    // Each single-item call opens and commits its own write transaction;
    // prefer the *_batch variant to amortize commit cost over many inserts.

    /// Insert a node
    ///
    /// Upserts by `node.id` (an existing row with the same id is replaced).
    /// Returns the inserted node's id.
    pub fn insert_node(&self, node: &Node) -> Result<NodeId> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(NODES_TABLE)?;

            // Serialize node data
            let node_data = bincode::encode_to_vec(node, config::standard())?;
            table.insert(node.id.as_str(), node_data.as_slice())?;
        }
        write_txn.commit()?;

        Ok(node.id.clone())
    }

    /// Insert multiple nodes in a batch
    ///
    /// All nodes are written in one transaction: either all commit or none.
    pub fn insert_nodes_batch(&self, nodes: &[Node]) -> Result<Vec<NodeId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(nodes.len());

        {
            let mut table = write_txn.open_table(NODES_TABLE)?;

            for node in nodes {
                let node_data = bincode::encode_to_vec(node, config::standard())?;
                table.insert(node.id.as_str(), node_data.as_slice())?;
                ids.push(node.id.clone());
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a node by ID
    ///
    /// Returns `Ok(None)` if the id is absent; a row that fails to decode
    /// surfaces as `Err`.
    pub fn get_node(&self, id: &str) -> Result<Option<Node>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(NODES_TABLE)?;

        let Some(node_data) = table.get(id)? else {
            return Ok(None);
        };

        let (node, _): (Node, usize) =
            bincode::decode_from_slice(node_data.value(), config::standard())?;
        Ok(Some(node))
    }

    /// Delete a node by ID
    ///
    /// Returns `Ok(true)` iff the key was present.
    // NOTE(review): edges referencing this node are not touched here; confirm
    // referential cleanup happens at a higher layer.
    pub fn delete_node(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;
        {
            let mut table = write_txn.open_table(NODES_TABLE)?;
            let result = table.remove(id)?;
            deleted = result.is_some();
        }
        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get all node IDs
    ///
    /// Materializes every key into memory; O(total nodes).
    pub fn all_node_ids(&self) -> Result<Vec<NodeId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(NODES_TABLE)?;

        let mut ids = Vec::new();
        let iter = table.iter()?;
        for item in iter {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }
|
||||
|
||||
    // Edge operations
    //
    // Mirrors the node API: per-call transactions, upsert-by-id semantics,
    // batch variant for bulk writes.

    /// Insert an edge
    ///
    /// Upserts by `edge.id`; returns the inserted edge's id.
    pub fn insert_edge(&self, edge: &Edge) -> Result<EdgeId> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(EDGES_TABLE)?;

            // Serialize edge data
            let edge_data = bincode::encode_to_vec(edge, config::standard())?;
            table.insert(edge.id.as_str(), edge_data.as_slice())?;
        }
        write_txn.commit()?;

        Ok(edge.id.clone())
    }

    /// Insert multiple edges in a batch
    ///
    /// All edges are written in one transaction: either all commit or none.
    pub fn insert_edges_batch(&self, edges: &[Edge]) -> Result<Vec<EdgeId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(edges.len());

        {
            let mut table = write_txn.open_table(EDGES_TABLE)?;

            for edge in edges {
                let edge_data = bincode::encode_to_vec(edge, config::standard())?;
                table.insert(edge.id.as_str(), edge_data.as_slice())?;
                ids.push(edge.id.clone());
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get an edge by ID
    ///
    /// Returns `Ok(None)` if absent; decode failures surface as `Err`.
    pub fn get_edge(&self, id: &str) -> Result<Option<Edge>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(EDGES_TABLE)?;

        let Some(edge_data) = table.get(id)? else {
            return Ok(None);
        };

        let (edge, _): (Edge, usize) =
            bincode::decode_from_slice(edge_data.value(), config::standard())?;
        Ok(Some(edge))
    }

    /// Delete an edge by ID
    ///
    /// Returns `Ok(true)` iff the key was present.
    pub fn delete_edge(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;
        {
            let mut table = write_txn.open_table(EDGES_TABLE)?;
            let result = table.remove(id)?;
            deleted = result.is_some();
        }
        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get all edge IDs
    ///
    /// Materializes every key into memory; O(total edges).
    pub fn all_edge_ids(&self) -> Result<Vec<EdgeId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(EDGES_TABLE)?;

        let mut ids = Vec::new();
        let iter = table.iter()?;
        for item in iter {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }
|
||||
|
||||
    // Hyperedge operations
    //
    // Same per-call-transaction pattern as nodes/edges, against the
    // hyperedges table.

    /// Insert a hyperedge
    ///
    /// Upserts by `hyperedge.id`; returns the inserted hyperedge's id.
    pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<HyperedgeId> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(HYPEREDGES_TABLE)?;

            // Serialize hyperedge data
            let hyperedge_data = bincode::encode_to_vec(hyperedge, config::standard())?;
            table.insert(hyperedge.id.as_str(), hyperedge_data.as_slice())?;
        }
        write_txn.commit()?;

        Ok(hyperedge.id.clone())
    }

    /// Insert multiple hyperedges in a batch
    ///
    /// All hyperedges are written in one transaction: all commit or none.
    pub fn insert_hyperedges_batch(&self, hyperedges: &[Hyperedge]) -> Result<Vec<HyperedgeId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(hyperedges.len());

        {
            let mut table = write_txn.open_table(HYPEREDGES_TABLE)?;

            for hyperedge in hyperedges {
                let hyperedge_data = bincode::encode_to_vec(hyperedge, config::standard())?;
                table.insert(hyperedge.id.as_str(), hyperedge_data.as_slice())?;
                ids.push(hyperedge.id.clone());
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a hyperedge by ID
    ///
    /// Returns `Ok(None)` if absent; decode failures surface as `Err`.
    pub fn get_hyperedge(&self, id: &str) -> Result<Option<Hyperedge>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(HYPEREDGES_TABLE)?;

        let Some(hyperedge_data) = table.get(id)? else {
            return Ok(None);
        };

        let (hyperedge, _): (Hyperedge, usize) =
            bincode::decode_from_slice(hyperedge_data.value(), config::standard())?;
        Ok(Some(hyperedge))
    }

    /// Delete a hyperedge by ID
    ///
    /// Returns `Ok(true)` iff the key was present.
    pub fn delete_hyperedge(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;
        {
            let mut table = write_txn.open_table(HYPEREDGES_TABLE)?;
            let result = table.remove(id)?;
            deleted = result.is_some();
        }
        write_txn.commit()?;
        Ok(deleted)
    }
|
||||
|
||||
/// Get all hyperedge IDs
|
||||
pub fn all_hyperedge_ids(&self) -> Result<Vec<HyperedgeId>> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let table = read_txn.open_table(HYPEREDGES_TABLE)?;
|
||||
|
||||
let mut ids = Vec::new();
|
||||
let iter = table.iter()?;
|
||||
for item in iter {
|
||||
let (key, _) = item?;
|
||||
ids.push(key.value().to_string());
|
||||
}
|
||||
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
// Metadata operations
|
||||
|
||||
/// Set metadata
|
||||
pub fn set_metadata(&self, key: &str, value: &str) -> Result<()> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
{
|
||||
let mut table = write_txn.open_table(METADATA_TABLE)?;
|
||||
table.insert(key, value)?;
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get metadata
|
||||
pub fn get_metadata(&self, key: &str) -> Result<Option<String>> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let table = read_txn.open_table(METADATA_TABLE)?;
|
||||
|
||||
let value = table.get(key)?.map(|v| v.value().to_string());
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
// Statistics
|
||||
|
||||
/// Get the number of nodes
|
||||
pub fn node_count(&self) -> Result<usize> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let table = read_txn.open_table(NODES_TABLE)?;
|
||||
Ok(table.iter()?.count())
|
||||
}
|
||||
|
||||
/// Get the number of edges
|
||||
pub fn edge_count(&self) -> Result<usize> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let table = read_txn.open_table(EDGES_TABLE)?;
|
||||
Ok(table.iter()?.count())
|
||||
}
|
||||
|
||||
/// Get the number of hyperedges
|
||||
pub fn hyperedge_count(&self) -> Result<usize> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let table = read_txn.open_table(HYPEREDGES_TABLE)?;
|
||||
Ok(table.iter()?.count())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::EdgeBuilder;
    use crate::hyperedge::HyperedgeBuilder;
    use crate::node::NodeBuilder;
    use tempfile::tempdir;

    // Round-trip a single node: insert, read back, verify id and label.
    #[test]
    fn test_node_storage() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let node = NodeBuilder::new()
            .label("Person")
            .property("name", "Alice")
            .build();

        let id = storage.insert_node(&node)?;
        assert_eq!(id, node.id);

        let retrieved = storage.get_node(&id)?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.id, node.id);
        assert!(retrieved.has_label("Person"));

        Ok(())
    }

    // Round-trip a single edge between two (non-persisted) node ids.
    // Note: storage does not enforce that "n1"/"n2" exist as nodes.
    #[test]
    fn test_edge_storage() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let edge = EdgeBuilder::new("n1".to_string(), "n2".to_string(), "KNOWS")
            .property("since", 2020i64)
            .build();

        let id = storage.insert_edge(&edge)?;
        assert_eq!(id, edge.id);

        let retrieved = storage.get_edge(&id)?;
        assert!(retrieved.is_some());

        Ok(())
    }

    // Batch insert returns one id per node and the count reflects both.
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let nodes = vec![
            NodeBuilder::new().label("Person").build(),
            NodeBuilder::new().label("Person").build(),
        ];

        let ids = storage.insert_nodes_batch(&nodes)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.node_count()?, 2);

        Ok(())
    }

    // Round-trip a 3-member hyperedge.
    #[test]
    fn test_hyperedge_storage() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let hyperedge = HyperedgeBuilder::new(
            vec!["n1".to_string(), "n2".to_string(), "n3".to_string()],
            "MEETING",
        )
        .description("Team meeting")
        .build();

        let id = storage.insert_hyperedge(&hyperedge)?;
        assert_eq!(id, hyperedge.id);

        let retrieved = storage.get_hyperedge(&id)?;
        assert!(retrieved.is_some());

        Ok(())
    }
}
|
||||
439
vendor/ruvector/crates/ruvector-graph/src/transaction.rs
vendored
Normal file
439
vendor/ruvector/crates/ruvector-graph/src/transaction.rs
vendored
Normal file
@@ -0,0 +1,439 @@
|
||||
//! Transaction support for ACID guarantees with MVCC
|
||||
//!
|
||||
//! Provides multi-version concurrency control for high-throughput concurrent access
|
||||
|
||||
use crate::edge::Edge;
|
||||
use crate::error::Result;
|
||||
use crate::hyperedge::{Hyperedge, HyperedgeId};
|
||||
use crate::node::Node;
|
||||
use crate::types::{EdgeId, NodeId};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Transaction isolation level.
///
/// Declares which concurrent effects a transaction is allowed to observe.
/// NOTE(review): the level is stored on the transaction but the visible
/// read path applies one snapshot rule regardless of level — confirm the
/// levels are actually differentiated elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Dirty reads allowed
    ReadUncommitted,
    /// Only committed data visible
    ReadCommitted,
    /// Repeatable reads (default)
    RepeatableRead,
    /// Full isolation
    Serializable,
}
|
||||
|
||||
/// Transaction ID type (monotonically allocated by `TransactionManager`).
pub type TxnId = u64;

/// Timestamp for MVCC (microseconds since the Unix epoch; see `now`).
pub type Timestamp = u64;
|
||||
|
||||
/// Get current timestamp
|
||||
fn now() -> Timestamp {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_micros() as u64
|
||||
}
|
||||
|
||||
/// Versioned value for MVCC.
///
/// Each key maps to a list of `Version`s in commit order; readers scan from
/// the newest backwards for the first version visible at their snapshot.
#[derive(Debug, Clone)]
struct Version<T> {
    /// Creation timestamp
    created_at: Timestamp,
    /// Deletion timestamp (None if not deleted)
    deleted_at: Option<Timestamp>,
    /// Transaction ID that created this version
    created_by: TxnId,
    /// Transaction ID that deleted this version
    deleted_by: Option<TxnId>,
    /// The actual value
    value: T,
}
|
||||
|
||||
/// Transaction lifecycle state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TxnState {
    /// Begun, not yet committed or aborted.
    Active,
    /// Writes published to the version store.
    Committed,
    /// Rolled back; buffered writes discarded.
    Aborted,
}
|
||||
|
||||
/// Bookkeeping record for a transaction tracked by the manager.
struct TxnMetadata {
    /// Transaction ID (mirrors the `active_txns` map key).
    id: TxnId,
    /// Current lifecycle state.
    state: TxnState,
    /// Isolation level requested at `begin`.
    isolation_level: IsolationLevel,
    /// Snapshot timestamp taken at `begin`.
    start_time: Timestamp,
    /// Commit timestamp; `None` until committed.
    commit_time: Option<Timestamp>,
}
|
||||
|
||||
/// Transaction manager for MVCC.
///
/// All maps live behind `Arc`s so clones share the same state; only the id
/// counter is per-instance (see the `Clone` impl).
pub struct TransactionManager {
    /// Next transaction ID
    next_txn_id: AtomicU64,
    /// Active transactions
    active_txns: Arc<DashMap<TxnId, TxnMetadata>>,
    /// Committed transactions (for cleanup)
    committed_txns: Arc<DashMap<TxnId, Timestamp>>,
    /// Node versions (key -> list of versions)
    node_versions: Arc<DashMap<NodeId, Vec<Version<Node>>>>,
    /// Edge versions
    edge_versions: Arc<DashMap<EdgeId, Vec<Version<Edge>>>>,
    /// Hyperedge versions
    hyperedge_versions: Arc<DashMap<HyperedgeId, Vec<Version<Hyperedge>>>>,
}
|
||||
|
||||
impl TransactionManager {
|
||||
/// Create a new transaction manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
next_txn_id: AtomicU64::new(1),
|
||||
active_txns: Arc::new(DashMap::new()),
|
||||
committed_txns: Arc::new(DashMap::new()),
|
||||
node_versions: Arc::new(DashMap::new()),
|
||||
edge_versions: Arc::new(DashMap::new()),
|
||||
hyperedge_versions: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Begin a new transaction
|
||||
pub fn begin(&self, isolation_level: IsolationLevel) -> Transaction {
|
||||
let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
|
||||
let start_time = now();
|
||||
|
||||
let metadata = TxnMetadata {
|
||||
id: txn_id,
|
||||
state: TxnState::Active,
|
||||
isolation_level,
|
||||
start_time,
|
||||
commit_time: None,
|
||||
};
|
||||
|
||||
self.active_txns.insert(txn_id, metadata);
|
||||
|
||||
Transaction {
|
||||
id: txn_id,
|
||||
manager: Arc::new(self.clone()),
|
||||
isolation_level,
|
||||
start_time,
|
||||
writes: Arc::new(RwLock::new(WriteSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Commit a transaction
|
||||
fn commit(&self, txn_id: TxnId, writes: &WriteSet) -> Result<()> {
|
||||
let commit_time = now();
|
||||
|
||||
// Apply all writes
|
||||
for (node_id, node) in &writes.nodes {
|
||||
self.node_versions
|
||||
.entry(node_id.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(Version {
|
||||
created_at: commit_time,
|
||||
deleted_at: None,
|
||||
created_by: txn_id,
|
||||
deleted_by: None,
|
||||
value: node.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
for (edge_id, edge) in &writes.edges {
|
||||
self.edge_versions
|
||||
.entry(edge_id.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(Version {
|
||||
created_at: commit_time,
|
||||
deleted_at: None,
|
||||
created_by: txn_id,
|
||||
deleted_by: None,
|
||||
value: edge.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
for (hyperedge_id, hyperedge) in &writes.hyperedges {
|
||||
self.hyperedge_versions
|
||||
.entry(hyperedge_id.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(Version {
|
||||
created_at: commit_time,
|
||||
deleted_at: None,
|
||||
created_by: txn_id,
|
||||
deleted_by: None,
|
||||
value: hyperedge.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Mark deletes
|
||||
for node_id in &writes.deleted_nodes {
|
||||
if let Some(mut versions) = self.node_versions.get_mut(node_id) {
|
||||
if let Some(last) = versions.last_mut() {
|
||||
last.deleted_at = Some(commit_time);
|
||||
last.deleted_by = Some(txn_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for edge_id in &writes.deleted_edges {
|
||||
if let Some(mut versions) = self.edge_versions.get_mut(edge_id) {
|
||||
if let Some(last) = versions.last_mut() {
|
||||
last.deleted_at = Some(commit_time);
|
||||
last.deleted_by = Some(txn_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update transaction state
|
||||
if let Some(mut metadata) = self.active_txns.get_mut(&txn_id) {
|
||||
metadata.state = TxnState::Committed;
|
||||
metadata.commit_time = Some(commit_time);
|
||||
}
|
||||
|
||||
self.active_txns.remove(&txn_id);
|
||||
self.committed_txns.insert(txn_id, commit_time);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Abort a transaction
|
||||
fn abort(&self, txn_id: TxnId) -> Result<()> {
|
||||
if let Some(mut metadata) = self.active_txns.get_mut(&txn_id) {
|
||||
metadata.state = TxnState::Aborted;
|
||||
}
|
||||
self.active_txns.remove(&txn_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read a node with MVCC
|
||||
fn read_node(&self, node_id: &NodeId, txn_id: TxnId, start_time: Timestamp) -> Option<Node> {
|
||||
self.node_versions.get(node_id).and_then(|versions| {
|
||||
versions
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|v| {
|
||||
v.created_at <= start_time
|
||||
&& v.deleted_at.map_or(true, |d| d > start_time)
|
||||
&& v.created_by != txn_id
|
||||
})
|
||||
.map(|v| v.value.clone())
|
||||
})
|
||||
}
|
||||
|
||||
/// Read an edge with MVCC
|
||||
fn read_edge(&self, edge_id: &EdgeId, txn_id: TxnId, start_time: Timestamp) -> Option<Edge> {
|
||||
self.edge_versions.get(edge_id).and_then(|versions| {
|
||||
versions
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|v| {
|
||||
v.created_at <= start_time
|
||||
&& v.deleted_at.map_or(true, |d| d > start_time)
|
||||
&& v.created_by != txn_id
|
||||
})
|
||||
.map(|v| v.value.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TransactionManager {
    /// Clones share the version store and transaction tables (all behind
    /// `Arc`s) but seed a *fresh* id counter from the current value.
    ///
    /// NOTE(review): two live clones allocating ids independently could mint
    /// duplicate `TxnId`s; an `Arc<AtomicU64>` counter would avoid that —
    /// confirm clones are never used to `begin` transactions concurrently.
    fn clone(&self) -> Self {
        Self {
            next_txn_id: AtomicU64::new(self.next_txn_id.load(Ordering::SeqCst)),
            active_txns: Arc::clone(&self.active_txns),
            committed_txns: Arc::clone(&self.committed_txns),
            node_versions: Arc::clone(&self.node_versions),
            edge_versions: Arc::clone(&self.edge_versions),
            hyperedge_versions: Arc::clone(&self.hyperedge_versions),
        }
    }
}
|
||||
|
||||
impl Default for TransactionManager {
    /// Equivalent to [`TransactionManager::new`]: an empty version store.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Write set for a transaction: all mutations buffered until commit.
#[derive(Debug, Clone, Default)]
struct WriteSet {
    /// Pending node inserts/updates, keyed by node id.
    nodes: HashMap<NodeId, Node>,
    /// Pending edge inserts/updates.
    edges: HashMap<EdgeId, Edge>,
    /// Pending hyperedge inserts/updates.
    hyperedges: HashMap<HyperedgeId, Hyperedge>,
    /// Node ids marked for deletion.
    deleted_nodes: HashSet<NodeId>,
    /// Edge ids marked for deletion.
    deleted_edges: HashSet<EdgeId>,
    /// Hyperedge ids marked for deletion.
    deleted_hyperedges: HashSet<HyperedgeId>,
}
|
||||
|
||||
impl WriteSet {
    /// Create an empty write set (all maps/sets empty).
    fn new() -> Self {
        Self::default()
    }
}
|
||||
|
||||
/// Transaction handle: buffers writes and serves snapshot reads.
pub struct Transaction {
    /// Transaction id allocated by the manager.
    id: TxnId,
    /// Manager holding the shared version store this handle reads/commits to.
    manager: Arc<TransactionManager>,
    /// The isolation level for this transaction
    pub isolation_level: IsolationLevel,
    /// Snapshot timestamp taken at `begin`.
    start_time: Timestamp,
    /// Buffered, uncommitted mutations (shared for interior mutability).
    writes: Arc<RwLock<WriteSet>>,
}
|
||||
|
||||
impl Transaction {
|
||||
/// Begin a new standalone transaction
|
||||
///
|
||||
/// This creates an internal TransactionManager for simple use cases.
|
||||
/// For production use, prefer using a shared TransactionManager.
|
||||
pub fn begin(isolation_level: IsolationLevel) -> Result<Self> {
|
||||
let manager = TransactionManager::new();
|
||||
Ok(manager.begin(isolation_level))
|
||||
}
|
||||
|
||||
/// Get transaction ID
|
||||
pub fn id(&self) -> TxnId {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Write a node (buffered until commit)
|
||||
pub fn write_node(&self, node: Node) {
|
||||
let mut writes = self.writes.write();
|
||||
writes.nodes.insert(node.id.clone(), node);
|
||||
}
|
||||
|
||||
/// Write an edge (buffered until commit)
|
||||
pub fn write_edge(&self, edge: Edge) {
|
||||
let mut writes = self.writes.write();
|
||||
writes.edges.insert(edge.id.clone(), edge);
|
||||
}
|
||||
|
||||
/// Write a hyperedge (buffered until commit)
|
||||
pub fn write_hyperedge(&self, hyperedge: Hyperedge) {
|
||||
let mut writes = self.writes.write();
|
||||
writes.hyperedges.insert(hyperedge.id.clone(), hyperedge);
|
||||
}
|
||||
|
||||
/// Delete a node (buffered until commit)
|
||||
pub fn delete_node(&self, node_id: NodeId) {
|
||||
let mut writes = self.writes.write();
|
||||
writes.deleted_nodes.insert(node_id);
|
||||
}
|
||||
|
||||
/// Delete an edge (buffered until commit)
|
||||
pub fn delete_edge(&self, edge_id: EdgeId) {
|
||||
let mut writes = self.writes.write();
|
||||
writes.deleted_edges.insert(edge_id);
|
||||
}
|
||||
|
||||
/// Read a node (with MVCC visibility)
|
||||
pub fn read_node(&self, node_id: &NodeId) -> Option<Node> {
|
||||
// Check write set first
|
||||
{
|
||||
let writes = self.writes.read();
|
||||
if writes.deleted_nodes.contains(node_id) {
|
||||
return None;
|
||||
}
|
||||
if let Some(node) = writes.nodes.get(node_id) {
|
||||
return Some(node.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Read from MVCC store
|
||||
self.manager.read_node(node_id, self.id, self.start_time)
|
||||
}
|
||||
|
||||
/// Read an edge (with MVCC visibility)
|
||||
pub fn read_edge(&self, edge_id: &EdgeId) -> Option<Edge> {
|
||||
// Check write set first
|
||||
{
|
||||
let writes = self.writes.read();
|
||||
if writes.deleted_edges.contains(edge_id) {
|
||||
return None;
|
||||
}
|
||||
if let Some(edge) = writes.edges.get(edge_id) {
|
||||
return Some(edge.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Read from MVCC store
|
||||
self.manager.read_edge(edge_id, self.id, self.start_time)
|
||||
}
|
||||
|
||||
/// Commit the transaction
|
||||
pub fn commit(self) -> Result<()> {
|
||||
let writes = self.writes.read();
|
||||
self.manager.commit(self.id, &writes)
|
||||
}
|
||||
|
||||
/// Rollback the transaction
|
||||
pub fn rollback(self) -> Result<()> {
|
||||
self.manager.abort(self.id)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::NodeBuilder;

    // Beginning a transaction records the requested level and a nonzero id.
    #[test]
    fn test_transaction_basic() {
        let manager = TransactionManager::new();
        let txn = manager.begin(IsolationLevel::ReadCommitted);

        assert_eq!(txn.isolation_level, IsolationLevel::ReadCommitted);
        assert!(txn.id() > 0);
    }

    // A committed write is visible to a transaction begun afterwards.
    #[test]
    fn test_mvcc_read_write() {
        let manager = TransactionManager::new();

        // Transaction 1: Write a node
        let txn1 = manager.begin(IsolationLevel::ReadCommitted);
        let node = NodeBuilder::new()
            .label("Person")
            .property("name", "Alice")
            .build();
        let node_id = node.id.clone();
        txn1.write_node(node.clone());
        txn1.commit().unwrap();

        // Transaction 2: Read the node
        let txn2 = manager.begin(IsolationLevel::ReadCommitted);
        let read_node = txn2.read_node(&node_id);
        assert!(read_node.is_some());
        assert_eq!(read_node.unwrap().id, node_id);
    }

    // Uncommitted writes stay private to their transaction; committed ones
    // become visible to later transactions.
    #[test]
    fn test_transaction_isolation() {
        let manager = TransactionManager::new();

        let node = NodeBuilder::new().build();
        let node_id = node.id.clone();

        // Txn1: Write but don't commit
        let txn1 = manager.begin(IsolationLevel::ReadCommitted);
        txn1.write_node(node.clone());

        // Txn2: Should not see uncommitted write
        let txn2 = manager.begin(IsolationLevel::ReadCommitted);
        assert!(txn2.read_node(&node_id).is_none());

        // Commit txn1
        txn1.commit().unwrap();

        // Txn3: Should see committed write
        let txn3 = manager.begin(IsolationLevel::ReadCommitted);
        assert!(txn3.read_node(&node_id).is_some());
    }
}
|
||||
136
vendor/ruvector/crates/ruvector-graph/src/types.rs
vendored
Normal file
136
vendor/ruvector/crates/ruvector-graph/src/types.rs
vendored
Normal file
@@ -0,0 +1,136 @@
|
||||
//! Core types for graph database
|
||||
|
||||
use bincode::{Decode, Encode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Node identifier (string-keyed; doubles as the storage key).
pub type NodeId = String;
/// Edge identifier.
pub type EdgeId = String;
|
||||
|
||||
/// Property value types for graph nodes and edges
///
/// NOTE(review): `Array` and `List` are distinct variants with identical
/// payloads; `PartialEq` treats them as unequal, so producers and consumers
/// must agree on which to use — consider consolidating. `Eq` is deliberately
/// absent because `Float(f64)` is not totally ordered.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Encode, Decode)]
pub enum PropertyValue {
    /// Null value
    Null,
    /// Boolean value
    Boolean(bool),
    /// 64-bit integer
    Integer(i64),
    /// 64-bit floating point
    Float(f64),
    /// UTF-8 string
    String(String),
    /// Array of values
    Array(Vec<PropertyValue>),
    /// List of values (alias for Array)
    List(Vec<PropertyValue>),
    /// Map of string keys to values
    Map(HashMap<String, PropertyValue>),
}
|
||||
|
||||
// Convenience constructors for PropertyValue
|
||||
impl PropertyValue {
|
||||
/// Create a boolean value
|
||||
pub fn boolean(b: bool) -> Self {
|
||||
PropertyValue::Boolean(b)
|
||||
}
|
||||
/// Create an integer value
|
||||
pub fn integer(i: i64) -> Self {
|
||||
PropertyValue::Integer(i)
|
||||
}
|
||||
/// Create a float value
|
||||
pub fn float(f: f64) -> Self {
|
||||
PropertyValue::Float(f)
|
||||
}
|
||||
/// Create a string value
|
||||
pub fn string(s: impl Into<String>) -> Self {
|
||||
PropertyValue::String(s.into())
|
||||
}
|
||||
/// Create an array value
|
||||
pub fn array(arr: Vec<PropertyValue>) -> Self {
|
||||
PropertyValue::Array(arr)
|
||||
}
|
||||
/// Create a map value
|
||||
pub fn map(m: HashMap<String, PropertyValue>) -> Self {
|
||||
PropertyValue::Map(m)
|
||||
}
|
||||
}
|
||||
|
||||
// From implementations for convenient property value creation
impl From<bool> for PropertyValue {
    fn from(b: bool) -> Self {
        PropertyValue::Boolean(b)
    }
}

impl From<i64> for PropertyValue {
    fn from(i: i64) -> Self {
        PropertyValue::Integer(i)
    }
}

// Widening i32 -> i64 is lossless.
impl From<i32> for PropertyValue {
    fn from(i: i32) -> Self {
        PropertyValue::Integer(i as i64)
    }
}

impl From<f64> for PropertyValue {
    fn from(f: f64) -> Self {
        PropertyValue::Float(f)
    }
}

// Widening f32 -> f64 is lossless.
impl From<f32> for PropertyValue {
    fn from(f: f32) -> Self {
        PropertyValue::Float(f as f64)
    }
}

impl From<String> for PropertyValue {
    fn from(s: String) -> Self {
        PropertyValue::String(s)
    }
}

impl From<&str> for PropertyValue {
    fn from(s: &str) -> Self {
        PropertyValue::String(s.to_string())
    }
}

// NOTE(review): Vec conversions produce the `Array` variant, never `List`,
// even though both variants exist — confirm consumers expect `Array`.
impl<T: Into<PropertyValue>> From<Vec<T>> for PropertyValue {
    fn from(v: Vec<T>) -> Self {
        PropertyValue::Array(v.into_iter().map(Into::into).collect())
    }
}

impl From<HashMap<String, PropertyValue>> for PropertyValue {
    fn from(m: HashMap<String, PropertyValue>) -> Self {
        PropertyValue::Map(m)
    }
}
|
||||
|
||||
/// Property bag keyed by property name.
pub type Properties = HashMap<String, PropertyValue>;
|
||||
|
||||
/// A node label (e.g. `Person`, `Movie`), mirroring Neo4j labels.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Encode, Decode)]
pub struct Label {
    /// Label name as written in queries.
    pub name: String,
}

impl Label {
    /// Build a label from anything convertible to `String`.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}
|
||||
|
||||
/// A relationship type (e.g. `ACTED_IN`), mirroring Neo4j relationship types.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct RelationType {
    /// Relationship type name as written in queries.
    pub name: String,
}

impl RelationType {
    /// Build a relation type from anything convertible to `String`.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}
|
||||
363
vendor/ruvector/crates/ruvector-graph/tests/compatibility_tests.rs
vendored
Normal file
363
vendor/ruvector/crates/ruvector-graph/tests/compatibility_tests.rs
vendored
Normal file
@@ -0,0 +1,363 @@
|
||||
//! Neo4j compatibility tests
|
||||
//!
|
||||
//! Tests to verify that RuVector graph database is compatible with Neo4j
|
||||
//! in terms of query syntax and result format.
|
||||
|
||||
use ruvector_graph::{Edge, GraphDB, Label, Node, Properties, PropertyValue};
|
||||
|
||||
/// Build the canonical Neo4j "movie graph" fixture: three `Person` nodes
/// (Keanu, Carrie-Anne, Laurence), one `Movie` node (The Matrix), and three
/// `ACTED_IN` edges carrying a `roles` list property.
fn setup_movie_graph() -> GraphDB {
    let db = GraphDB::new();

    // Actors
    let mut keanu_props = Properties::new();
    keanu_props.insert(
        "name".to_string(),
        PropertyValue::String("Keanu Reeves".to_string()),
    );
    keanu_props.insert("born".to_string(), PropertyValue::Integer(1964));

    let mut carrie_props = Properties::new();
    carrie_props.insert(
        "name".to_string(),
        PropertyValue::String("Carrie-Anne Moss".to_string()),
    );
    carrie_props.insert("born".to_string(), PropertyValue::Integer(1967));

    let mut laurence_props = Properties::new();
    laurence_props.insert(
        "name".to_string(),
        PropertyValue::String("Laurence Fishburne".to_string()),
    );
    laurence_props.insert("born".to_string(), PropertyValue::Integer(1961));

    // Movies
    let mut matrix_props = Properties::new();
    matrix_props.insert(
        "title".to_string(),
        PropertyValue::String("The Matrix".to_string()),
    );
    matrix_props.insert("released".to_string(), PropertyValue::Integer(1999));
    matrix_props.insert(
        "tagline".to_string(),
        PropertyValue::String("Welcome to the Real World".to_string()),
    );

    db.create_node(Node::new(
        "keanu".to_string(),
        vec![Label {
            name: "Person".to_string(),
        }],
        keanu_props,
    ))
    .unwrap();

    db.create_node(Node::new(
        "carrie".to_string(),
        vec![Label {
            name: "Person".to_string(),
        }],
        carrie_props,
    ))
    .unwrap();

    db.create_node(Node::new(
        "laurence".to_string(),
        vec![Label {
            name: "Person".to_string(),
        }],
        laurence_props,
    ))
    .unwrap();

    db.create_node(Node::new(
        "matrix".to_string(),
        vec![Label {
            name: "Movie".to_string(),
        }],
        matrix_props,
    ))
    .unwrap();

    // Relationships
    let mut keanu_role = Properties::new();
    keanu_role.insert(
        "roles".to_string(),
        PropertyValue::List(vec![PropertyValue::String("Neo".to_string())]),
    );

    let mut carrie_role = Properties::new();
    carrie_role.insert(
        "roles".to_string(),
        PropertyValue::List(vec![PropertyValue::String("Trinity".to_string())]),
    );

    let mut laurence_role = Properties::new();
    laurence_role.insert(
        "roles".to_string(),
        PropertyValue::List(vec![PropertyValue::String("Morpheus".to_string())]),
    );

    db.create_edge(Edge::new(
        "e1".to_string(),
        "keanu".to_string(),
        "matrix".to_string(),
        "ACTED_IN".to_string(),
        keanu_role,
    ))
    .unwrap();

    db.create_edge(Edge::new(
        "e2".to_string(),
        "carrie".to_string(),
        "matrix".to_string(),
        "ACTED_IN".to_string(),
        carrie_role,
    ))
    .unwrap();

    db.create_edge(Edge::new(
        "e3".to_string(),
        "laurence".to_string(),
        "matrix".to_string(),
        "ACTED_IN".to_string(),
        laurence_role,
    ))
    .unwrap();

    db
}
|
||||
|
||||
// ============================================================================
|
||||
// Neo4j Query Compatibility Tests
|
||||
// ============================================================================
|
||||
|
||||
// Placeholder for: MATCH (n) RETURN n — query engine not yet wired up, so
// only the fixture setup is verified for now.
#[test]
fn test_neo4j_match_all_nodes() {
    let db = setup_movie_graph();

    // Neo4j query: MATCH (n) RETURN n
    // TODO: Implement query execution
    // let results = db.execute("MATCH (n) RETURN n").unwrap();
    // assert_eq!(results.len(), 4); // 3 people + 1 movie

    // For now, verify graph setup
    assert!(db.get_node("keanu").is_some());
    assert!(db.get_node("matrix").is_some());
}

// Placeholder for: MATCH (p:Person) RETURN p — verifies the label data the
// future query would filter on.
#[test]
fn test_neo4j_match_with_label() {
    let db = setup_movie_graph();

    // Neo4j query: MATCH (p:Person) RETURN p
    // TODO: Implement
    // let results = db.execute("MATCH (p:Person) RETURN p").unwrap();
    // assert_eq!(results.len(), 3);

    // Verify label filtering would work
    let keanu = db.get_node("keanu").unwrap();
    assert_eq!(keanu.labels[0].name, "Person");
}

// Placeholder for: MATCH with inline property map — verifies the property
// value the future query would match on.
#[test]
fn test_neo4j_match_with_properties() {
    let db = setup_movie_graph();

    // Neo4j query: MATCH (m:Movie {title: 'The Matrix'}) RETURN m
    // TODO: Implement
    // let results = db.execute("MATCH (m:Movie {title: 'The Matrix'}) RETURN m").unwrap();
    // assert_eq!(results.len(), 1);

    let matrix = db.get_node("matrix").unwrap();
    assert_eq!(
        matrix.properties.get("title"),
        Some(&PropertyValue::String("The Matrix".to_string()))
    );
}

// Placeholder for: relationship pattern match — verifies the edge type the
// future query would traverse.
#[test]
fn test_neo4j_match_relationship() {
    let db = setup_movie_graph();

    // Neo4j query: MATCH (a:Person)-[r:ACTED_IN]->(m:Movie) RETURN a, r, m
    // TODO: Implement
    // let results = db.execute("MATCH (a:Person)-[r:ACTED_IN]->(m:Movie) RETURN a, r, m").unwrap();
    // assert_eq!(results.len(), 3);

    let edge = db.get_edge("e1").unwrap();
    assert_eq!(edge.edge_type, "ACTED_IN");
}
|
||||
|
||||
// Placeholder for: WHERE comparison — verifies the comparison data directly.
#[test]
fn test_neo4j_where_clause() {
    let db = setup_movie_graph();

    // Neo4j query: MATCH (p:Person) WHERE p.born > 1965 RETURN p
    // TODO: Implement
    // let results = db.execute("MATCH (p:Person) WHERE p.born > 1965 RETURN p").unwrap();
    // assert_eq!(results.len(), 1); // Only Carrie-Anne Moss

    // Fix: the previous `if let` passed vacuously when the `born` property
    // was missing or had the wrong variant; require it to be present and an
    // integer before comparing.
    let carrie = db.get_node("carrie").unwrap();
    match carrie.properties.get("born") {
        Some(PropertyValue::Integer(born)) => assert!(*born > 1965),
        other => panic!("expected integer `born` property, got {:?}", other),
    }
}
|
||||
|
||||
// Placeholder for: COUNT aggregation — verifies the rows a future COUNT
// would tally.
#[test]
fn test_neo4j_count_aggregation() {
    let db = setup_movie_graph();

    // Neo4j query: MATCH (p:Person) RETURN COUNT(p)
    // TODO: Implement
    // let results = db.execute("MATCH (p:Person) RETURN COUNT(p)").unwrap();
    // assert_eq!(results[0]["count"], 3);

    // Manually verify
    assert!(db.get_node("keanu").is_some());
    assert!(db.get_node("carrie").is_some());
    assert!(db.get_node("laurence").is_some());
}

// Placeholder for: COLLECT aggregation — verifies the relationships a
// future COLLECT would group.
#[test]
fn test_neo4j_collect_aggregation() {
    let db = setup_movie_graph();

    // Neo4j query: MATCH (p:Person)-[:ACTED_IN]->(m:Movie)
    //              RETURN m.title, COLLECT(p.name) AS actors
    // TODO: Implement
    // let results = db.execute("...").unwrap();

    // Verify relationships exist
    assert!(db.get_edge("e1").is_some());
    assert!(db.get_edge("e2").is_some());
    assert!(db.get_edge("e3").is_some());
}
|
||||
|
||||
// ============================================================================
|
||||
// Neo4j Data Type Compatibility
|
||||
// ============================================================================
|
||||
|
||||
// Round-trip of a string property through node creation and retrieval.
#[test]
fn test_neo4j_string_property() {
    let db = GraphDB::new();

    let mut props = Properties::new();
    props.insert(
        "name".to_string(),
        PropertyValue::String("Test".to_string()),
    );

    db.create_node(Node::new("n1".to_string(), vec![], props))
        .unwrap();

    let node = db.get_node("n1").unwrap();
    assert!(matches!(
        node.properties.get("name"),
        Some(PropertyValue::String(_))
    ));
}

// Round-trip of an integer property, value compared exactly.
#[test]
fn test_neo4j_integer_property() {
    let db = GraphDB::new();

    let mut props = Properties::new();
    props.insert("count".to_string(), PropertyValue::Integer(42));

    db.create_node(Node::new("n1".to_string(), vec![], props))
        .unwrap();

    let node = db.get_node("n1").unwrap();
    assert_eq!(
        node.properties.get("count"),
        Some(&PropertyValue::Integer(42))
    );
}

// Round-trip of a float property; exact equality is fine here because the
// value is stored and returned without arithmetic.
#[test]
fn test_neo4j_float_property() {
    let db = GraphDB::new();

    let mut props = Properties::new();
    props.insert("score".to_string(), PropertyValue::Float(3.14));

    db.create_node(Node::new("n1".to_string(), vec![], props))
        .unwrap();

    let node = db.get_node("n1").unwrap();
    assert_eq!(
        node.properties.get("score"),
        Some(&PropertyValue::Float(3.14))
    );
}

// Round-trip of a boolean property.
#[test]
fn test_neo4j_boolean_property() {
    let db = GraphDB::new();

    let mut props = Properties::new();
    props.insert("active".to_string(), PropertyValue::Boolean(true));

    db.create_node(Node::new("n1".to_string(), vec![], props))
        .unwrap();

    let node = db.get_node("n1").unwrap();
    assert_eq!(
        node.properties.get("active"),
        Some(&PropertyValue::Boolean(true))
    );
}

// Round-trip of a list property; uses the `List` variant (not `Array`).
#[test]
fn test_neo4j_list_property() {
    let db = GraphDB::new();

    let mut props = Properties::new();
    props.insert(
        "tags".to_string(),
        PropertyValue::List(vec![
            PropertyValue::String("tag1".to_string()),
            PropertyValue::String("tag2".to_string()),
        ]),
    );

    db.create_node(Node::new("n1".to_string(), vec![], props))
        .unwrap();

    let node = db.get_node("n1").unwrap();
    assert!(matches!(
        node.properties.get("tags"),
        Some(PropertyValue::List(_))
    ));
}
|
||||
|
||||
#[test]
|
||||
fn test_neo4j_null_property() {
|
||||
let db = GraphDB::new();
|
||||
|
||||
let mut props = Properties::new();
|
||||
props.insert("optional".to_string(), PropertyValue::Null);
|
||||
|
||||
db.create_node(Node::new("n1".to_string(), vec![], props))
|
||||
.unwrap();
|
||||
|
||||
let node = db.get_node("n1").unwrap();
|
||||
assert_eq!(node.properties.get("optional"), Some(&PropertyValue::Null));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Known Differences from Neo4j
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_documented_differences() {
    // Documentation-only test: records intentional divergences from Neo4j
    // behavior so they stay visible in the test suite. Record things like:
    // - Different default values
    // - Different error messages
    // - Different performance characteristics
    // - Missing features
    //
    // No runtime assertions: the test passes as long as it compiles, which is
    // the contract we want for living documentation. (The previous bare
    // `assert!(true)` was removed; clippy flags assertions on constants.)
}
|
||||
396
vendor/ruvector/crates/ruvector-graph/tests/concurrent_tests.rs
vendored
Normal file
396
vendor/ruvector/crates/ruvector-graph/tests/concurrent_tests.rs
vendored
Normal file
@@ -0,0 +1,396 @@
|
||||
//! Concurrent access pattern tests
|
||||
//!
|
||||
//! Tests for multi-threaded access, lock-free operations, and concurrent modifications.
|
||||
|
||||
use ruvector_graph::{Edge, GraphDB, Label, Node, Properties, PropertyValue};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
#[test]
fn test_concurrent_node_creation() {
    // Ten writer threads, each inserting 100 uniquely-named nodes.
    let db = Arc::new(GraphDB::new());
    let num_threads = 10;
    let nodes_per_thread = 100;

    let mut workers = Vec::new();
    for thread_id in 0..num_threads {
        let graph = Arc::clone(&db);
        workers.push(thread::spawn(move || {
            for i in 0..nodes_per_thread {
                let mut props = Properties::new();
                props.insert("thread".to_string(), PropertyValue::Integer(thread_id));
                props.insert("index".to_string(), PropertyValue::Integer(i));

                let node = Node::new(
                    format!("node_{}_{}", thread_id, i),
                    vec![Label {
                        name: "Concurrent".to_string(),
                    }],
                    props,
                );
                graph.create_node(node).unwrap();
            }
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }

    // Verify all nodes were created.
    // Note: would need a node_count() accessor:
    // assert_eq!(db.node_count(), num_threads * nodes_per_thread);
}

#[test]
fn test_concurrent_reads() {
    let db = Arc::new(GraphDB::new());

    // Seed 100 nodes that the readers will look up.
    for idx in 0..100 {
        db.create_node(Node::new(format!("node_{}", idx), vec![], Properties::new()))
            .unwrap();
    }

    let num_readers = 20;
    let reads_per_thread = 1000;

    let mut readers = Vec::new();
    for thread_id in 0..num_readers {
        let graph = Arc::clone(&db);
        readers.push(thread::spawn(move || {
            for i in 0..reads_per_thread {
                let node_id = format!("node_{}", (thread_id * 10 + i) % 100);
                assert!(graph.get_node(&node_id).is_some());
            }
        }));
    }

    for reader in readers {
        reader.join().unwrap();
    }
}

#[test]
fn test_concurrent_writes_no_collision() {
    // Each thread writes to its own id namespace, so no write can collide.
    let db = Arc::new(GraphDB::new());
    let num_threads = 10;

    let mut workers = Vec::new();
    for thread_id in 0..num_threads {
        let graph = Arc::clone(&db);
        workers.push(thread::spawn(move || {
            for i in 0..50 {
                let node_id = format!("t{}_n{}", thread_id, i);
                graph
                    .create_node(Node::new(node_id, vec![], Properties::new()))
                    .unwrap();
            }
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }

    // All 500 nodes should be created.
}
|
||||
|
||||
#[test]
fn test_concurrent_edge_creation() {
    let db = Arc::new(GraphDB::new());

    // Endpoints must exist before any edges reference them.
    for idx in 0..100 {
        db.create_node(Node::new(format!("n{}", idx), vec![], Properties::new()))
            .unwrap();
    }

    let num_threads = 10;

    let mut workers = Vec::new();
    for thread_id in 0..num_threads {
        let graph = Arc::clone(&db);
        workers.push(thread::spawn(move || {
            for i in 0..50 {
                let from = format!("n{}", (thread_id * 10 + i) % 100);
                let to = format!("n{}", (thread_id * 10 + i + 1) % 100);

                let edge = Edge::new(
                    format!("e_{}_{}", thread_id, i),
                    from,
                    to,
                    "LINK".to_string(),
                    Properties::new(),
                );
                graph.create_edge(edge).unwrap();
            }
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}

#[test]
fn test_concurrent_read_while_writing() {
    let db = Arc::new(GraphDB::new());

    // Seed the nodes the readers will target.
    for idx in 0..50 {
        db.create_node(Node::new(
            format!("initial_{}", idx),
            vec![],
            Properties::new(),
        ))
        .unwrap();
    }

    let num_readers = 5;
    let num_writers = 3;

    let mut handles = vec![];

    // Readers repeatedly fetch pre-existing nodes.
    for reader_id in 0..num_readers {
        let graph = Arc::clone(&db);
        handles.push(thread::spawn(move || {
            for i in 0..100 {
                let node_id = format!("initial_{}", (reader_id * 10 + i) % 50);
                let _ = graph.get_node(&node_id);
            }
        }));
    }

    // Writers insert fresh nodes at the same time.
    for writer_id in 0..num_writers {
        let graph = Arc::clone(&db);
        handles.push(thread::spawn(move || {
            for i in 0..100 {
                let node = Node::new(
                    format!("new_{}_{}", writer_id, i),
                    vec![],
                    Properties::new(),
                );
                graph.create_node(node).unwrap();
            }
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }
}
|
||||
|
||||
#[test]
fn test_concurrent_property_updates() {
    let db = Arc::new(GraphDB::new());

    // Shared "counter" node that every thread inspects.
    let mut props = Properties::new();
    props.insert("counter".to_string(), PropertyValue::Integer(0));
    db.create_node(Node::new("counter".to_string(), vec![], props))
        .unwrap();

    // TODO: exercise atomic property updates once they exist; until then
    // this only hammers the node with concurrent reads.
    let num_threads = 10;

    let mut workers = Vec::new();
    for _ in 0..num_threads {
        let graph = Arc::clone(&db);
        workers.push(thread::spawn(move || {
            for _ in 0..100 {
                let _node = graph.get_node("counter");
            }
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}
|
||||
|
||||
#[test]
fn test_lock_free_reads() {
    // 50 reader threads x 100 lookups each over a 1000-node graph; with
    // lock-free reads they should not serialize behind one another.
    let db = Arc::new(GraphDB::new());

    // Populate database.
    for i in 0..1000 {
        db.create_node(Node::new(
            format!("node_{}", i),
            vec![Label {
                name: "Test".to_string(),
            }],
            Properties::new(),
        ))
        .unwrap();
    }

    // Many concurrent readers should not block each other.
    let num_readers = 50;

    // Fix: start the clock BEFORE spawning. Previously `Instant::now()` was
    // taken after all threads were already running (and possibly finished),
    // so the printed duration excluded most of the concurrent work.
    let start = std::time::Instant::now();

    let handles: Vec<_> = (0..num_readers)
        .map(|reader_id| {
            let db_clone = Arc::clone(&db);
            thread::spawn(move || {
                for i in 0..100 {
                    let node_id = format!("node_{}", (reader_id * 20 + i) % 1000);
                    let result = db_clone.get_node(&node_id);
                    assert!(result.is_some());
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    let duration = start.elapsed();

    // With lock-free reads, this should complete quickly.
    // Even with 50 threads doing 100 reads each (5000 reads total).
    println!("Concurrent reads took: {:?}", duration);
}
|
||||
|
||||
#[test]
fn test_writer_starvation_prevention() {
    // Ensure that heavy read load doesn't prevent writes.
    use std::sync::atomic::{AtomicUsize, Ordering};

    let db = Arc::new(GraphDB::new());

    // Initial data.
    for i in 0..100 {
        db.create_node(Node::new(
            format!("initial_{}", i),
            vec![],
            Properties::new(),
        ))
        .unwrap();
    }

    // Fix: the original AtomicBool flags were stored by *every* finishing
    // thread, so the final asserts only proved that ONE reader and ONE
    // writer ran. Counters verify that every thread completed.
    let readers_finished = Arc::new(AtomicUsize::new(0));
    let writers_finished = Arc::new(AtomicUsize::new(0));

    let mut handles = vec![];

    // Heavy read load.
    for reader_id in 0..20i64 {
        let graph = Arc::clone(&db);
        let counter = Arc::clone(&readers_finished);
        handles.push(thread::spawn(move || {
            for i in 0..1000i64 {
                let node_id = format!("initial_{}", (reader_id + i) % 100);
                let _ = graph.get_node(&node_id);
            }
            counter.fetch_add(1, Ordering::Relaxed);
        }));
    }

    // Writers should still make progress under the read load.
    for writer_id in 0..5 {
        let graph = Arc::clone(&db);
        let counter = Arc::clone(&writers_finished);
        handles.push(thread::spawn(move || {
            for i in 0..50 {
                let node = Node::new(
                    format!("writer_{}_{}", writer_id, i),
                    vec![],
                    Properties::new(),
                );
                graph.create_node(node).unwrap();
            }
            counter.fetch_add(1, Ordering::Relaxed);
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Every reader and every writer must have finished its loop.
    assert_eq!(readers_finished.load(Ordering::Relaxed), 20);
    assert_eq!(writers_finished.load(Ordering::Relaxed), 5);
}
|
||||
|
||||
// ============================================================================
|
||||
// Stress Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_high_concurrency_stress() {
    // 50 threads interleaving creates (every third op) with reads.
    let db = Arc::new(GraphDB::new());
    let num_threads = 50;

    let mut workers = Vec::new();
    for thread_id in 0..num_threads {
        let graph = Arc::clone(&db);
        workers.push(thread::spawn(move || {
            for i in 0i32..100 {
                if i % 3 == 0 {
                    // Create a node.
                    let node = Node::new(
                        format!("stress_{}_{}", thread_id, i),
                        vec![],
                        Properties::new(),
                    );
                    graph.create_node(node).unwrap();
                } else {
                    // Read a node that may or may not exist yet.
                    let node_id = format!("stress_{}_{}", thread_id, i.saturating_sub(1));
                    let _ = graph.get_node(&node_id);
                }
            }
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}

#[test]
fn test_concurrent_batch_operations() {
    let db = Arc::new(GraphDB::new());
    let num_threads = 10;
    let batch_size = 100;

    let mut workers = Vec::new();
    for thread_id in 0..num_threads {
        let graph = Arc::clone(&db);
        workers.push(thread::spawn(move || {
            // TODO: switch to a real batch-insert API once one exists; for
            // now each "batch" is a sequence of single inserts.
            for i in 0..batch_size {
                let node = Node::new(
                    format!("batch_{}_{}", thread_id, i),
                    vec![],
                    Properties::new(),
                );
                graph.create_node(node).unwrap();
            }
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}
|
||||
405
vendor/ruvector/crates/ruvector-graph/tests/cypher_execution_tests.rs
vendored
Normal file
405
vendor/ruvector/crates/ruvector-graph/tests/cypher_execution_tests.rs
vendored
Normal file
@@ -0,0 +1,405 @@
|
||||
//! Cypher query execution correctness tests
|
||||
//!
|
||||
//! Tests to verify that Cypher queries execute correctly and return expected results.
|
||||
|
||||
use ruvector_graph::{Edge, GraphDB, Label, Node, Properties, PropertyValue};
|
||||
|
||||
fn setup_test_graph() -> GraphDB {
|
||||
let db = GraphDB::new();
|
||||
|
||||
// Create people
|
||||
let mut alice_props = Properties::new();
|
||||
alice_props.insert(
|
||||
"name".to_string(),
|
||||
PropertyValue::String("Alice".to_string()),
|
||||
);
|
||||
alice_props.insert("age".to_string(), PropertyValue::Integer(30));
|
||||
|
||||
let mut bob_props = Properties::new();
|
||||
bob_props.insert("name".to_string(), PropertyValue::String("Bob".to_string()));
|
||||
bob_props.insert("age".to_string(), PropertyValue::Integer(35));
|
||||
|
||||
let mut charlie_props = Properties::new();
|
||||
charlie_props.insert(
|
||||
"name".to_string(),
|
||||
PropertyValue::String("Charlie".to_string()),
|
||||
);
|
||||
charlie_props.insert("age".to_string(), PropertyValue::Integer(28));
|
||||
|
||||
db.create_node(Node::new(
|
||||
"alice".to_string(),
|
||||
vec![Label {
|
||||
name: "Person".to_string(),
|
||||
}],
|
||||
alice_props,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
db.create_node(Node::new(
|
||||
"bob".to_string(),
|
||||
vec![Label {
|
||||
name: "Person".to_string(),
|
||||
}],
|
||||
bob_props,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
db.create_node(Node::new(
|
||||
"charlie".to_string(),
|
||||
vec![Label {
|
||||
name: "Person".to_string(),
|
||||
}],
|
||||
charlie_props,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// Create relationships
|
||||
db.create_edge(Edge::new(
|
||||
"e1".to_string(),
|
||||
"alice".to_string(),
|
||||
"bob".to_string(),
|
||||
"KNOWS".to_string(),
|
||||
Properties::new(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
db.create_edge(Edge::new(
|
||||
"e2".to_string(),
|
||||
"bob".to_string(),
|
||||
"charlie".to_string(),
|
||||
"KNOWS".to_string(),
|
||||
Properties::new(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[test]
fn test_execute_simple_match_all_nodes() {
    let db = setup_test_graph();

    // TODO: once query execution lands:
    //   let results = db.execute("MATCH (n) RETURN n").unwrap();
    //   assert_eq!(results.len(), 3);

    // Until then, confirm the fixture contains all three people.
    for id in ["alice", "bob", "charlie"] {
        assert!(db.get_node(id).is_some());
    }
}

#[test]
fn test_execute_match_with_label_filter() {
    let db = setup_test_graph();

    // TODO: let results = db.execute("MATCH (n:Person) RETURN n").unwrap();
    //       assert_eq!(results.len(), 3);
    assert!(db.get_node("alice").is_some());
}

#[test]
fn test_execute_match_with_property_filter() {
    let db = setup_test_graph();

    // TODO: let results = db.execute("MATCH (n:Person {name: 'Alice'}) RETURN n").unwrap();
    //       assert_eq!(results.len(), 1);
    let alice = db.get_node("alice").unwrap();
    assert_eq!(
        alice.properties.get("name"),
        Some(&PropertyValue::String("Alice".to_string()))
    );
}

#[test]
fn test_execute_match_with_where_clause() {
    let db = setup_test_graph();

    // TODO: let results = db.execute("MATCH (n:Person) WHERE n.age > 30 RETURN n").unwrap();
    // Should return only Bob (35).
    let bob = db.get_node("bob").unwrap();
    if let Some(PropertyValue::Integer(age)) = bob.properties.get("age") {
        assert!(*age > 30);
    }
}

#[test]
fn test_execute_match_relationship() {
    let db = setup_test_graph();

    // TODO: let results = db.execute("MATCH (a)-[r:KNOWS]->(b) RETURN a, r, b").unwrap();
    // Should return 2 relationships.
    assert!(db.get_edge("e1").is_some());
    assert!(db.get_edge("e2").is_some());
}

#[test]
fn test_execute_create_node() {
    let db = GraphDB::new();

    // TODO: db.execute("CREATE (n:Person {name: 'David', age: 40})").unwrap();

    // Until CREATE executes, build the node by hand.
    let mut props = Properties::new();
    props.insert(
        "name".to_string(),
        PropertyValue::String("David".to_string()),
    );
    props.insert("age".to_string(), PropertyValue::Integer(40));

    db.create_node(Node::new(
        "david".to_string(),
        vec![Label {
            name: "Person".to_string(),
        }],
        props,
    ))
    .unwrap();

    let david = db.get_node("david").unwrap();
    assert_eq!(
        david.properties.get("name"),
        Some(&PropertyValue::String("David".to_string()))
    );
}
|
||||
|
||||
#[test]
|
||||
fn test_execute_count_aggregation() {
|
||||
let db = setup_test_graph();
|
||||
|
||||
// TODO: Implement
|
||||
// let results = db.execute("MATCH (n:Person) RETURN COUNT(n) AS count").unwrap();
|
||||
// assert_eq!(results[0]["count"], 3);
|
||||
|
||||
// Manual verification
|
||||
assert!(db.get_node("alice").is_some());
|
||||
assert!(db.get_node("bob").is_some());
|
||||
assert!(db.get_node("charlie").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_execute_sum_aggregation() {
|
||||
let db = setup_test_graph();
|
||||
|
||||
// TODO: Implement
|
||||
// let results = db.execute("MATCH (n:Person) RETURN SUM(n.age) AS total_age").unwrap();
|
||||
// assert_eq!(results[0]["total_age"], 93); // 30 + 35 + 28
|
||||
|
||||
// Manual verification
|
||||
let ages: Vec<i64> = ["alice", "bob", "charlie"]
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
db.get_node(*id).and_then(|n| {
|
||||
if let Some(PropertyValue::Integer(age)) = n.properties.get("age") {
|
||||
Some(*age)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(ages.iter().sum::<i64>(), 93);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_execute_avg_aggregation() {
|
||||
let db = setup_test_graph();
|
||||
|
||||
// TODO: Implement
|
||||
// let results = db.execute("MATCH (n:Person) RETURN AVG(n.age) AS avg_age").unwrap();
|
||||
// assert_eq!(results[0]["avg_age"], 31.0); // (30 + 35 + 28) / 3
|
||||
|
||||
let ages: Vec<i64> = ["alice", "bob", "charlie"]
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
db.get_node(*id).and_then(|n| {
|
||||
if let Some(PropertyValue::Integer(age)) = n.properties.get("age") {
|
||||
Some(*age)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let avg = ages.iter().sum::<i64>() as f64 / ages.len() as f64;
|
||||
assert!((avg - 31.0).abs() < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_execute_order_by() {
|
||||
let db = setup_test_graph();
|
||||
|
||||
// TODO: Implement
|
||||
// let results = db.execute("MATCH (n:Person) RETURN n ORDER BY n.age ASC").unwrap();
|
||||
// First should be Charlie (28), last should be Bob (35)
|
||||
|
||||
let mut ages: Vec<i64> = ["alice", "bob", "charlie"]
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
db.get_node(*id).and_then(|n| {
|
||||
if let Some(PropertyValue::Integer(age)) = n.properties.get("age") {
|
||||
Some(*age)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
ages.sort();
|
||||
assert_eq!(ages, vec![28, 30, 35]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_execute_limit() {
|
||||
let db = setup_test_graph();
|
||||
|
||||
// TODO: Implement
|
||||
// let results = db.execute("MATCH (n:Person) RETURN n LIMIT 2").unwrap();
|
||||
// assert_eq!(results.len(), 2);
|
||||
|
||||
assert!(db.get_node("alice").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_execute_path_query() {
|
||||
let db = setup_test_graph();
|
||||
|
||||
// TODO: Implement
|
||||
// let results = db.execute("MATCH p = (a:Person)-[:KNOWS*1..2]->(b:Person) RETURN p").unwrap();
|
||||
// Should find paths: Alice->Bob, Bob->Charlie, Alice->Bob->Charlie
|
||||
|
||||
let e1 = db.get_edge("e1").unwrap();
|
||||
let e2 = db.get_edge("e2").unwrap();
|
||||
|
||||
assert_eq!(e1.from, "alice");
|
||||
assert_eq!(e1.to, "bob");
|
||||
assert_eq!(e2.from, "bob");
|
||||
assert_eq!(e2.to, "charlie");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Complex Query Execution Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_execute_multi_hop_traversal() {
    let db = setup_test_graph();

    // TODO:
    //   MATCH (alice:Person {name: 'Alice'})-[:KNOWS*1..2]->(connected)
    //   RETURN DISTINCT connected.name
    // Expected: Bob (1 hop) and Charlie (2 hops).
    assert!(db.get_node("bob").is_some());
    assert!(db.get_node("charlie").is_some());
}

#[test]
fn test_execute_pattern_matching() {
    let db = setup_test_graph();

    // TODO:
    //   MATCH (a:Person)-[:KNOWS]->(b:Person)-[:KNOWS]->(c:Person)
    //   RETURN a.name, c.name
    // Expected: Alice reaches Charlie through Bob.
    assert!(db.get_edge("e1").is_some());
    assert!(db.get_edge("e2").is_some());
}

#[test]
fn test_execute_collect_aggregation() {
    let db = setup_test_graph();

    // TODO:
    //   MATCH (p:Person)-[:KNOWS]->(friend)
    //   RETURN p.name, COLLECT(friend.name) AS friends
    // Expected: Alice: [Bob], Bob: [Charlie], Charlie: [].
    assert!(db.get_edge("e1").is_some());
}

#[test]
fn test_execute_optional_match() {
    let db = setup_test_graph();

    // TODO:
    //   MATCH (p:Person)
    //   OPTIONAL MATCH (p)-[:KNOWS]->(friend)
    //   RETURN p.name, friend.name
    // Expected: every person, with a null friend where none exists.
    assert!(db.get_node("charlie").is_some());
}
|
||||
|
||||
// ============================================================================
|
||||
// Result Verification Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_query_result_schema() {
    // TODO: verify the result-set schema once execution exists, e.g.:
    //   let db = setup_test_graph();
    //   let results = db.execute("MATCH (n:Person) RETURN n.name AS name, n.age AS age").unwrap();
    //   assert!(results.has_column("name"));
    //   assert!(results.has_column("age"));
}

#[test]
fn test_query_result_ordering() {
    // TODO: verify that ORDER BY is applied to result rows.
}

#[test]
fn test_query_result_pagination() {
    // TODO: verify that SKIP and LIMIT compose correctly.
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handling Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_execute_invalid_property_access() {
    // TODO: e.g. db.execute("MATCH (n:Person) WHERE n.nonexistent > 5 RETURN n")
    // should be handled gracefully (empty result or error, per chosen semantics).
}

#[test]
fn test_execute_type_mismatch() {
    // TODO: e.g. db.execute("MATCH (n:Person) WHERE n.name > 5 RETURN n")
    // should surface a type-mismatch error rather than panic.
}
|
||||
166
vendor/ruvector/crates/ruvector-graph/tests/cypher_parser_integration.rs
vendored
Normal file
166
vendor/ruvector/crates/ruvector-graph/tests/cypher_parser_integration.rs
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
//! Integration tests for Cypher parser
|
||||
|
||||
use ruvector_graph::cypher::{ast::*, parse_cypher};
|
||||
|
||||
#[test]
fn test_simple_match_query() {
    let parsed = parse_cypher("MATCH (n:Person) RETURN n");
    assert!(
        parsed.is_ok(),
        "Failed to parse simple MATCH query: {:?}",
        parsed.err()
    );

    // A minimal query decomposes into a MATCH statement and a RETURN statement.
    let ast = parsed.unwrap();
    assert_eq!(ast.statements.len(), 2);
}

#[test]
fn test_match_with_where() {
    let parsed = parse_cypher("MATCH (n:Person) WHERE n.age > 30 RETURN n.name");
    assert!(
        parsed.is_ok(),
        "Failed to parse MATCH with WHERE: {:?}",
        parsed.err()
    );
}

#[test]
fn test_relationship_pattern() {
    let parsed = parse_cypher("MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a, r, b");
    assert!(
        parsed.is_ok(),
        "Failed to parse relationship pattern: {:?}",
        parsed.err()
    );
}

#[test]
fn test_create_node() {
    let parsed = parse_cypher("CREATE (n:Person {name: 'Alice', age: 30})");
    assert!(
        parsed.is_ok(),
        "Failed to parse CREATE query: {:?}",
        parsed.err()
    );
}
|
||||
|
||||
#[test]
#[ignore = "Hyperedge syntax not yet implemented in parser"]
fn test_hyperedge_pattern() {
    let parsed = parse_cypher("MATCH (a)-[r:TRANSACTION]->(b, c, d) RETURN a, r, b, c, d");
    assert!(
        parsed.is_ok(),
        "Failed to parse hyperedge: {:?}",
        parsed.err()
    );

    let ast = parsed.unwrap();
    assert!(ast.has_hyperedges(), "Query should contain hyperedges");
}

#[test]
fn test_aggregation_functions() {
    let parsed = parse_cypher("MATCH (n:Person) RETURN COUNT(n), AVG(n.age), MAX(n.salary)");
    assert!(
        parsed.is_ok(),
        "Failed to parse aggregation query: {:?}",
        parsed.err()
    );
}

#[test]
fn test_order_by_limit() {
    let parsed = parse_cypher("MATCH (n:Person) RETURN n.name ORDER BY n.age DESC LIMIT 10");
    assert!(
        parsed.is_ok(),
        "Failed to parse ORDER BY LIMIT: {:?}",
        parsed.err()
    );
}

#[test]
fn test_complex_query() {
    let query = r#"
        MATCH (a:Person)-[r:KNOWS]->(b:Person)
        WHERE a.age > 30 AND b.name = 'Alice'
        RETURN a.name, b.name, r.since
        ORDER BY r.since DESC
        LIMIT 10
    "#;
    let parsed = parse_cypher(query);
    assert!(
        parsed.is_ok(),
        "Failed to parse complex query: {:?}",
        parsed.err()
    );
}
|
||||
|
||||
#[test]
#[ignore = "CREATE relationship with properties not yet fully implemented"]
fn test_create_relationship() {
    let query = r#"
        MATCH (a:Person), (b:Person)
        WHERE a.name = 'Alice' AND b.name = 'Bob'
        CREATE (a)-[:KNOWS {since: 2024}]->(b)
    "#;
    let parsed = parse_cypher(query);
    assert!(
        parsed.is_ok(),
        "Failed to parse CREATE relationship: {:?}",
        parsed.err()
    );
}

#[test]
#[ignore = "MERGE ON CREATE SET not yet implemented"]
fn test_merge_pattern() {
    let parsed = parse_cypher("MERGE (n:Person {name: 'Alice'}) ON CREATE SET n.created = 2024");
    assert!(parsed.is_ok(), "Failed to parse MERGE: {:?}", parsed.err());
}

#[test]
fn test_with_clause() {
    let query = r#"
        MATCH (n:Person)
        WITH n, n.age AS age
        WHERE age > 30
        RETURN n.name, age
    "#;
    let parsed = parse_cypher(query);
    assert!(
        parsed.is_ok(),
        "Failed to parse WITH clause: {:?}",
        parsed.err()
    );
}

#[test]
fn test_path_pattern() {
    let parsed = parse_cypher("MATCH p = (a:Person)-[*1..3]->(b:Person) RETURN p");
    assert!(
        parsed.is_ok(),
        "Failed to parse path pattern: {:?}",
        parsed.err()
    );
}

#[test]
fn test_is_read_only() {
    // A pure MATCH/RETURN query is read-only; a CREATE query is not.
    let read_only = parse_cypher("MATCH (n:Person) RETURN n").unwrap();
    assert!(read_only.is_read_only());

    let mutating = parse_cypher("CREATE (n:Person {name: 'Alice'})").unwrap();
    assert!(!mutating.is_read_only());
}
|
||||
223
vendor/ruvector/crates/ruvector-graph/tests/cypher_parser_tests.rs
vendored
Normal file
223
vendor/ruvector/crates/ruvector-graph/tests/cypher_parser_tests.rs
vendored
Normal file
@@ -0,0 +1,223 @@
|
||||
//! Cypher query parser tests
|
||||
//!
|
||||
//! Tests for parsing valid and invalid Cypher queries to ensure syntax correctness.
|
||||
|
||||
use ruvector_graph::cypher::parse_cypher;
|
||||
|
||||
// ============================================================================
|
||||
// Valid Cypher Queries
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_parse_simple_match() {
|
||||
let result = parse_cypher("MATCH (n) RETURN n");
|
||||
assert!(result.is_ok(), "Parse failed: {:?}", result.err());
|
||||
}
|
||||
|
||||
/// A node pattern carrying a single label must parse.
#[test]
fn test_parse_match_with_label() {
    assert!(parse_cypher("MATCH (n:Person) RETURN n").is_ok());
}
|
||||
|
||||
/// An inline property map on a node pattern must parse.
#[test]
fn test_parse_match_with_properties() {
    assert!(parse_cypher("MATCH (n:Person {name: 'Alice'}) RETURN n").is_ok());
}
|
||||
|
||||
/// A directed, typed relationship pattern must parse.
#[test]
fn test_parse_match_relationship() {
    assert!(parse_cypher("MATCH (a)-[r:KNOWS]->(b) RETURN a, r, b").is_ok());
}
|
||||
|
||||
/// An undirected relationship (no arrowhead) must parse.
#[test]
fn test_parse_match_undirected_relationship() {
    let parsed = parse_cypher("MATCH (a)-[r:FRIEND]-(b) RETURN a, b");
    assert!(parsed.is_ok(), "Parse failed: {:?}", parsed.err());
}
|
||||
|
||||
/// A named path over a typed variable-length relationship must parse.
#[test]
fn test_parse_match_path() {
    assert!(parse_cypher("MATCH p = (a)-[:KNOWS*1..3]->(b) RETURN p").is_ok());
}
|
||||
|
||||
/// CREATE with a label and multiple typed properties must parse.
#[test]
fn test_parse_create_node() {
    assert!(parse_cypher("CREATE (n:Person {name: 'Bob', age: 30})").is_ok());
}
|
||||
|
||||
/// CREATE of a full node-relationship-node pattern must parse.
#[test]
fn test_parse_create_relationship() {
    let parsed =
        parse_cypher("CREATE (a:Person {name: 'Alice'})-[r:KNOWS]->(b:Person {name: 'Bob'})");
    assert!(parsed.is_ok());
}
|
||||
|
||||
/// A bare MERGE clause must parse.
#[test]
fn test_parse_merge() {
    assert!(parse_cypher("MERGE (n:Person {name: 'Charlie'})").is_ok());
}
|
||||
|
||||
/// MATCH followed by DELETE must parse.
#[test]
fn test_parse_delete() {
    assert!(parse_cypher("MATCH (n:Person {name: 'Alice'}) DELETE n").is_ok());
}
|
||||
|
||||
/// SET of a single property after a MATCH must parse.
#[test]
fn test_parse_set_property() {
    assert!(parse_cypher("MATCH (n:Person {name: 'Alice'}) SET n.age = 31").is_ok());
}
|
||||
|
||||
/// REMOVE of a node property must parse.
#[test]
fn test_parse_remove_property() {
    let parsed = parse_cypher("MATCH (n:Person {name: 'Alice'}) REMOVE n.age");
    assert!(parsed.is_ok(), "Parse failed: {:?}", parsed.err());
}
|
||||
|
||||
/// A WHERE predicate with a numeric comparison must parse.
#[test]
fn test_parse_where_clause() {
    assert!(parse_cypher("MATCH (n:Person) WHERE n.age > 25 RETURN n").is_ok());
}
|
||||
|
||||
/// ORDER BY with an explicit DESC direction must parse.
#[test]
fn test_parse_order_by() {
    assert!(parse_cypher("MATCH (n:Person) RETURN n ORDER BY n.age DESC").is_ok());
}
|
||||
|
||||
/// A trailing LIMIT clause must parse.
#[test]
fn test_parse_limit() {
    assert!(parse_cypher("MATCH (n:Person) RETURN n LIMIT 10").is_ok());
}
|
||||
|
||||
/// SKIP combined with LIMIT (pagination) must parse.
#[test]
fn test_parse_skip() {
    assert!(parse_cypher("MATCH (n:Person) RETURN n SKIP 5 LIMIT 10").is_ok());
}
|
||||
|
||||
/// The COUNT aggregate in a RETURN clause must parse.
#[test]
fn test_parse_aggregate_count() {
    assert!(parse_cypher("MATCH (n:Person) RETURN COUNT(n)").is_ok());
}
|
||||
|
||||
/// The SUM aggregate over a property must parse.
#[test]
fn test_parse_aggregate_sum() {
    assert!(parse_cypher("MATCH (n:Person) RETURN SUM(n.age)").is_ok());
}
|
||||
|
||||
/// The AVG aggregate over a property must parse.
#[test]
fn test_parse_aggregate_avg() {
    assert!(parse_cypher("MATCH (n:Person) RETURN AVG(n.age)").is_ok());
}
|
||||
|
||||
/// Single-line WITH ... AS ... WHERE ... RETURN pipeline must parse.
#[test]
fn test_parse_with_clause() {
    assert!(
        parse_cypher("MATCH (n:Person) WITH n.age AS age WHERE age > 25 RETURN age").is_ok()
    );
}
|
||||
|
||||
/// OPTIONAL MATCH over a typed relationship must parse.
#[test]
fn test_parse_optional_match() {
    assert!(parse_cypher("OPTIONAL MATCH (n:Person)-[r:KNOWS]->(m) RETURN n, m").is_ok());
}
|
||||
|
||||
// ============================================================================
|
||||
// Complex Query Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Collaborative-filtering style query mixing both relationship directions,
/// WITH re-projection, aggregation, ORDER BY and LIMIT.
#[test]
#[ignore = "Complex multi-direction patterns with <- not yet fully implemented"]
fn test_parse_complex_graph_pattern() {
    let result = parse_cypher(
        "
        MATCH (user:User)-[:PURCHASED]->(product:Product)<-[:PURCHASED]-(other:User)
        WHERE other.id <> 123
        WITH other, COUNT(*) AS commonProducts
        WHERE commonProducts > 3
        RETURN other.name
        ORDER BY commonProducts DESC
        LIMIT 10
        ",
    );
    // Once `<-` support lands, this should become a plain passing test.
    assert!(result.is_ok());
}
|
||||
|
||||
/// Variable-length traversal combined with a WHERE filter must parse.
#[test]
fn test_parse_variable_length_path() {
    let parsed =
        parse_cypher("MATCH (a:Person)-[:KNOWS*1..5]->(b:Person) WHERE a.name = 'Alice' RETURN b");
    assert!(parsed.is_ok());
}
|
||||
|
||||
/// Two consecutive MATCH clauses sharing a variable (`b`) must parse.
#[test]
fn test_parse_multiple_patterns() {
    let result = parse_cypher(
        "
        MATCH (a:Person)-[:KNOWS]->(b:Person)
        MATCH (b)-[:WORKS_AT]->(c:Company)
        RETURN a.name, b.name, c.name
        ",
    );
    assert!(result.is_ok());
}
|
||||
|
||||
/// COLLECT aggregation with an AS alias must parse.
#[test]
fn test_parse_collect_aggregation() {
    let parsed = parse_cypher(
        "MATCH (p:Person)-[:KNOWS]->(f:Person) RETURN p.name, COLLECT(f.name) AS friends",
    );
    assert!(parsed.is_ok());
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Cases
|
||||
// ============================================================================
|
||||
|
||||
/// An empty input string should be rejected by the parser.
#[test]
#[ignore = "Empty query validation not yet implemented"]
fn test_parse_empty_query() {
    assert!(parse_cypher("").is_err());
}
|
||||
|
||||
/// Input consisting solely of whitespace should be rejected by the parser.
#[test]
#[ignore = "Whitespace-only query validation not yet implemented"]
fn test_parse_whitespace_only() {
    assert!(parse_cypher("   \n\t   ").is_err());
}
|
||||
|
||||
/// `$name`-style query parameters inside a property map must parse.
#[test]
fn test_parse_parameters() {
    assert!(parse_cypher("MATCH (n:Person {name: $name, age: $age}) RETURN n").is_ok());
}
|
||||
|
||||
/// A list literal in a RETURN projection must parse.
#[test]
fn test_parse_list_literal() {
    assert!(parse_cypher("RETURN [1, 2, 3, 4, 5] AS numbers").is_ok());
}
|
||||
|
||||
/// A map literal in a RETURN projection should parse once implemented.
#[test]
#[ignore = "Map literal in RETURN not yet implemented"]
fn test_parse_map_literal() {
    assert!(parse_cypher("RETURN {name: 'Alice', age: 30} AS person").is_ok());
}
|
||||
295
vendor/ruvector/crates/ruvector-graph/tests/distributed_tests.rs
vendored
Normal file
295
vendor/ruvector/crates/ruvector-graph/tests/distributed_tests.rs
vendored
Normal file
@@ -0,0 +1,295 @@
|
||||
//! Distributed graph database tests
|
||||
//!
|
||||
//! Tests for clustering, replication, sharding, and federation.
|
||||
|
||||
/// Placeholder keeping this test target compiling until distributed features exist.
#[test]
fn test_placeholder_distributed() {
    // TODO: Implement distributed tests when distributed features are available.
    // Intentionally empty: the previous `assert!(true)` was a no-op flagged by
    // `clippy::assertions_on_constants`; an empty body documents intent just as well.
}
|
||||
|
||||
// ============================================================================
|
||||
// Cluster Setup Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_three_node_cluster() {
|
||||
// // TODO: Set up a 3-node cluster
|
||||
// // Verify all nodes can communicate
|
||||
// // Verify leader election works
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_cluster_discovery() {
|
||||
// // TODO: Test node discovery mechanism
|
||||
// // New node should discover existing cluster
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Data Sharding Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_hash_based_sharding() {
|
||||
// // TODO: Test that data is distributed across shards based on hash
|
||||
// // Create nodes on different shards
|
||||
// // Verify they end up on correct nodes
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_range_based_sharding() {
|
||||
// // TODO: Test range-based sharding for ordered data
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_shard_rebalancing() {
|
||||
// // TODO: Test automatic rebalancing when adding/removing nodes
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Replication Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_synchronous_replication() {
|
||||
// // TODO: Write to leader, verify data appears on all replicas
|
||||
// // before write is acknowledged
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_asynchronous_replication() {
|
||||
// // TODO: Write to leader, verify data eventually appears on replicas
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_replica_consistency() {
|
||||
// // TODO: Verify all replicas have same data
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_read_from_replica() {
|
||||
// // TODO: Verify reads can be served from replicas
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_replica_lag_monitoring() {
|
||||
// // TODO: Monitor replication lag
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Leader Election Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_leader_election_on_startup() {
|
||||
// // TODO: Start cluster, verify leader is elected
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_leader_failover() {
|
||||
// // TODO: Kill leader, verify new leader is elected
|
||||
// // Verify cluster remains available
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_split_brain_prevention() {
|
||||
// // TODO: Simulate network partition
|
||||
// // Verify that split brain doesn't occur
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Distributed Queries
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_cross_shard_query() {
|
||||
// // TODO: Query that requires data from multiple shards
|
||||
// // Verify correct results are returned
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_distributed_aggregation() {
|
||||
// // TODO: Aggregation query across shards
|
||||
// // COUNT, SUM, etc. should work correctly
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_distributed_traversal() {
|
||||
// // TODO: Graph traversal that crosses shard boundaries
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_distributed_shortest_path() {
|
||||
// // TODO: Shortest path query where path crosses shards
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Distributed Transactions
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_two_phase_commit() {
|
||||
// // TODO: Transaction spanning multiple shards
|
||||
// // Verify 2PC ensures atomicity
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_distributed_deadlock_detection() {
|
||||
// // TODO: Create scenario that could cause distributed deadlock
|
||||
// // Verify detection and resolution
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_distributed_rollback() {
|
||||
// // TODO: Transaction that fails on one shard
|
||||
// // Verify all shards roll back
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Fault Tolerance Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_node_failure_recovery() {
|
||||
// // TODO: Kill a node, verify cluster recovers
|
||||
// // Data should still be accessible via replicas
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_network_partition_handling() {
|
||||
// // TODO: Simulate network partition
|
||||
// // Verify cluster handles it gracefully
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_data_recovery_after_crash() {
|
||||
// // TODO: Node crashes, then restarts
|
||||
// // Verify it can rejoin cluster and catch up
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_quorum_based_operations() {
|
||||
// // TODO: Verify operations require quorum
|
||||
// // If quorum lost, writes should fail
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Federation Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_cross_cluster_query() {
|
||||
// // TODO: Query that spans multiple independent clusters
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_federated_search() {
|
||||
// // TODO: Search across federated clusters
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_cluster_to_cluster_replication() {
|
||||
// // TODO: Data replication between clusters
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Consistency Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_strong_consistency() {
|
||||
// // TODO: With strong consistency level, verify linearizability
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_eventual_consistency() {
|
||||
// // TODO: With eventual consistency, verify data converges
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_causal_consistency() {
|
||||
// // TODO: Verify causal relationships are preserved
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_read_your_writes() {
|
||||
// // TODO: Client should always see its own writes
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Performance Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_horizontal_scalability() {
|
||||
// // TODO: Measure throughput with 1, 2, 4, 8 nodes
|
||||
// // Verify near-linear scaling
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_load_balancing() {
|
||||
// // TODO: Verify load is balanced across nodes
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_hotspot_handling() {
|
||||
// // TODO: Create hotspot (frequently accessed data)
|
||||
// // Verify system handles it gracefully
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Configuration Tests
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_replication_factor_configuration() {
|
||||
// // TODO: Test different replication factors (1, 2, 3)
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_consistency_level_configuration() {
|
||||
// // TODO: Test different consistency levels
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_partition_strategy_configuration() {
|
||||
// // TODO: Test different partitioning strategies
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Monitoring and Observability
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_cluster_health_monitoring() {
|
||||
// // TODO: Verify cluster health metrics are available
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_shard_distribution_metrics() {
|
||||
// // TODO: Verify we can monitor shard distribution
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_replication_lag_metrics() {
|
||||
// // TODO: Verify replication lag is monitored
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Backup and Restore
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_distributed_backup() {
|
||||
// // TODO: Create backup of distributed database
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_distributed_restore() {
|
||||
// // TODO: Restore from backup to new cluster
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_point_in_time_recovery() {
|
||||
// // TODO: Restore to specific point in time
|
||||
// }
|
||||
371
vendor/ruvector/crates/ruvector-graph/tests/edge_tests.rs
vendored
Normal file
371
vendor/ruvector/crates/ruvector-graph/tests/edge_tests.rs
vendored
Normal file
@@ -0,0 +1,371 @@
|
||||
//! Edge (relationship) operation tests
|
||||
//!
|
||||
//! Tests for creating edges, querying relationships, and graph traversals.
|
||||
|
||||
use ruvector_graph::{Edge, EdgeBuilder, GraphDB, Label, Node, Properties, PropertyValue};
|
||||
|
||||
/// Creating an edge between two existing nodes returns the edge's own id.
#[test]
fn test_create_edge_basic() {
    let db = GraphDB::new();

    // Both endpoints must exist before the edge is inserted.
    let person_labels = || {
        vec![Label {
            name: "Person".to_string(),
        }]
    };
    db.create_node(Node::new(
        "person1".to_string(),
        person_labels(),
        Properties::new(),
    ))
    .unwrap();
    db.create_node(Node::new(
        "person2".to_string(),
        person_labels(),
        Properties::new(),
    ))
    .unwrap();

    let knows = Edge::new(
        "edge1".to_string(),
        "person1".to_string(),
        "person2".to_string(),
        "KNOWS".to_string(),
        Properties::new(),
    );
    assert_eq!(db.create_edge(knows).unwrap(), "edge1");
}
|
||||
|
||||
/// A stored edge round-trips with id, endpoints and type intact.
#[test]
fn test_get_edge_existing() {
    let db = GraphDB::new();

    for node_id in ["n1", "n2"] {
        db.create_node(Node::new(node_id.to_string(), vec![], Properties::new()))
            .unwrap();
    }

    let mut props = Properties::new();
    props.insert("since".to_string(), PropertyValue::Integer(2020));
    db.create_edge(Edge::new(
        "e1".to_string(),
        "n1".to_string(),
        "n2".to_string(),
        "FRIEND_OF".to_string(),
        props,
    ))
    .unwrap();

    let fetched = db.get_edge("e1").unwrap();
    assert_eq!(fetched.id, "e1");
    assert_eq!(fetched.from, "n1");
    assert_eq!(fetched.to, "n2");
    assert_eq!(fetched.edge_type, "FRIEND_OF");
}
|
||||
|
||||
/// Float, string and boolean edge properties all survive a store/load cycle.
#[test]
fn test_edge_with_properties() {
    let db = GraphDB::new();

    for node_id in ["a", "b"] {
        db.create_node(Node::new(node_id.to_string(), vec![], Properties::new()))
            .unwrap();
    }

    let mut props = Properties::new();
    props.insert("weight".to_string(), PropertyValue::Float(0.85));
    props.insert(
        "type".to_string(),
        PropertyValue::String("strong".to_string()),
    );
    props.insert("verified".to_string(), PropertyValue::Boolean(true));

    db.create_edge(Edge::new(
        "weighted_edge".to_string(),
        "a".to_string(),
        "b".to_string(),
        "CONNECTED_TO".to_string(),
        props,
    ))
    .unwrap();

    let fetched = db.get_edge("weighted_edge").unwrap();
    assert_eq!(
        fetched.properties.get("weight"),
        Some(&PropertyValue::Float(0.85))
    );
    assert_eq!(
        fetched.properties.get("verified"),
        Some(&PropertyValue::Boolean(true))
    );
}
|
||||
|
||||
/// Two FOLLOWS edges in opposite directions each keep their own orientation.
#[test]
fn test_bidirectional_edges() {
    let db = GraphDB::new();

    for node_id in ["alice", "bob"] {
        db.create_node(Node::new(node_id.to_string(), vec![], Properties::new()))
            .unwrap();
    }

    // Small factory for a FOLLOWS edge to avoid repeating the constructor.
    let follows = |edge_id: &str, from: &str, to: &str| {
        Edge::new(
            edge_id.to_string(),
            from.to_string(),
            to.to_string(),
            "FOLLOWS".to_string(),
            Properties::new(),
        )
    };

    db.create_edge(follows("e1", "alice", "bob")).unwrap();
    db.create_edge(follows("e2", "bob", "alice")).unwrap();

    let forward = db.get_edge("e1").unwrap();
    assert_eq!(forward.from, "alice");
    assert_eq!(forward.to, "bob");

    let backward = db.get_edge("e2").unwrap();
    assert_eq!(backward.from, "bob");
    assert_eq!(backward.to, "alice");
}
|
||||
|
||||
/// An edge may start and end on the same node (self-loop).
#[test]
fn test_self_loop_edge() {
    let db = GraphDB::new();
    db.create_node(Node::new("node".to_string(), vec![], Properties::new()))
        .unwrap();

    db.create_edge(Edge::new(
        "self_loop".to_string(),
        "node".to_string(),
        "node".to_string(),
        "REFERENCES".to_string(),
        Properties::new(),
    ))
    .unwrap();

    let loop_edge = db.get_edge("self_loop").unwrap();
    assert_eq!(loop_edge.from, loop_edge.to);
}
|
||||
|
||||
/// Distinct relationship types may coexist between the same pair of nodes.
#[test]
fn test_multiple_edges_same_nodes() {
    let db = GraphDB::new();

    for node_id in ["x", "y"] {
        db.create_node(Node::new(node_id.to_string(), vec![], Properties::new()))
            .unwrap();
    }

    for (edge_id, edge_type) in [("e1", "WORKS_WITH"), ("e2", "FRIENDS_WITH")] {
        db.create_edge(Edge::new(
            edge_id.to_string(),
            "x".to_string(),
            "y".to_string(),
            edge_type.to_string(),
            Properties::new(),
        ))
        .unwrap();
    }

    assert!(db.get_edge("e1").is_some());
    assert!(db.get_edge("e2").is_some());
}
|
||||
|
||||
/// Timestamp-style integer properties are stored and retrievable on edges.
#[test]
fn test_edge_timestamp_property() {
    let db = GraphDB::new();

    for node_id in ["user1", "post1"] {
        db.create_node(Node::new(node_id.to_string(), vec![], Properties::new()))
            .unwrap();
    }

    let mut props = Properties::new();
    props.insert("timestamp".to_string(), PropertyValue::Integer(1699564800));
    props.insert(
        "action".to_string(),
        PropertyValue::String("liked".to_string()),
    );

    db.create_edge(Edge::new(
        "interaction".to_string(),
        "user1".to_string(),
        "post1".to_string(),
        "INTERACTED".to_string(),
        props,
    ))
    .unwrap();

    let fetched = db.get_edge("interaction").unwrap();
    assert!(fetched.properties.contains_key("timestamp"));
}
|
||||
|
||||
/// Looking up an unknown edge id yields `None` rather than panicking.
#[test]
fn test_get_nonexistent_edge() {
    assert!(GraphDB::new().get_edge("does_not_exist").is_none());
}
|
||||
|
||||
/// A hub node fans out to 100 spokes; every edge stays retrievable.
#[test]
fn test_create_many_edges() {
    let db = GraphDB::new();

    db.create_node(Node::new("hub".to_string(), vec![], Properties::new()))
        .unwrap();

    for i in 0..100 {
        // Each spoke gets its own node plus a hub->spoke CONNECTS edge.
        let spoke_id = format!("spoke_{}", i);
        db.create_node(Node::new(spoke_id.clone(), vec![], Properties::new()))
            .unwrap();
        db.create_edge(Edge::new(
            format!("edge_{}", i),
            "hub".to_string(),
            spoke_id,
            "CONNECTS".to_string(),
            Properties::new(),
        ))
        .unwrap();
    }

    assert!((0..100).all(|i| db.get_edge(&format!("edge_{}", i)).is_some()));
}
|
||||
|
||||
/// EdgeBuilder populates id, endpoints, type, and typed properties.
#[test]
fn test_edge_builder() {
    let db = GraphDB::new();

    for node_id in ["a", "b"] {
        db.create_node(Node::new(node_id.to_string(), vec![], Properties::new()))
            .unwrap();
    }

    db.create_edge(
        EdgeBuilder::new("a".to_string(), "b".to_string(), "KNOWS")
            .id("e1")
            .property("since", 2020i64)
            .property("weight", 0.95f64)
            .build(),
    )
    .unwrap();

    let built = db.get_edge("e1").unwrap();
    assert_eq!(built.from, "a");
    assert_eq!(built.to, "b");
    assert_eq!(built.edge_type, "KNOWS");
    assert_eq!(
        built.get_property("since"),
        Some(&PropertyValue::Integer(2020))
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Property-based tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod property_tests {
    //! Property-based tests: proptest generates random edge ids/types and
    //! checks that create/get round-trips hold for arbitrary valid inputs.
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing identifier-like edge ids: a lowercase letter
    /// followed by up to 20 lowercase alphanumerics/underscores.
    fn edge_id_strategy() -> impl Strategy<Value = String> {
        "[a-z][a-z0-9_]{0,20}".prop_map(|s| s.to_string())
    }

    /// Strategy producing SCREAMING_SNAKE-style relationship types (2-15 chars).
    fn edge_type_strategy() -> impl Strategy<Value = String> {
        "[A-Z_]{2,15}".prop_map(|s| s.to_string())
    }

    proptest! {
        /// Any generated (id, type) pair must round-trip through create/get.
        #[test]
        fn test_edge_roundtrip(
            edge_id in edge_id_strategy(),
            edge_type in edge_type_strategy()
        ) {
            let db = GraphDB::new();

            // Setup nodes: both endpoints must exist before the edge.
            db.create_node(Node::new("from".to_string(), vec![], Properties::new())).unwrap();
            db.create_node(Node::new("to".to_string(), vec![], Properties::new())).unwrap();

            let edge = Edge::new(
                edge_id.clone(),
                "from".to_string(),
                "to".to_string(),
                edge_type.clone(),
                Properties::new(),
            );

            db.create_edge(edge).unwrap();

            let retrieved = db.get_edge(&edge_id).unwrap();
            assert_eq!(retrieved.id, edge_id);
            assert_eq!(retrieved.edge_type, edge_type);
        }

        /// Many distinct edge ids between the same node pair must all be stored.
        /// A hash_set strategy guarantees the generated ids are unique.
        #[test]
        fn test_many_edges_unique(
            edge_ids in prop::collection::hash_set(edge_id_strategy(), 10..50)
        ) {
            let db = GraphDB::new();

            // Create source and target nodes shared by every generated edge.
            db.create_node(Node::new("source".to_string(), vec![], Properties::new())).unwrap();
            db.create_node(Node::new("target".to_string(), vec![], Properties::new())).unwrap();

            for edge_id in &edge_ids {
                let edge = Edge::new(
                    edge_id.clone(),
                    "source".to_string(),
                    "target".to_string(),
                    "TEST".to_string(),
                    Properties::new(),
                );
                db.create_edge(edge).unwrap();
            }

            for edge_id in &edge_ids {
                assert!(db.get_edge(edge_id).is_some());
            }
        }
    }
}
|
||||
51
vendor/ruvector/crates/ruvector-graph/tests/fixtures/README.md
vendored
Normal file
51
vendor/ruvector/crates/ruvector-graph/tests/fixtures/README.md
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
# Test Fixtures
|
||||
|
||||
This directory contains sample datasets and expected results for testing the RuVector graph database.
|
||||
|
||||
## Datasets
|
||||
|
||||
### movie_database.json
|
||||
A small movie database inspired by Neo4j's example dataset:
|
||||
- 3 actors (Keanu Reeves, Carrie-Anne Moss, Laurence Fishburne)
|
||||
- 1 movie (The Matrix)
|
||||
- 3 ACTED_IN relationships with role properties
|
||||
|
||||
### social_network.json
|
||||
A simple social network for testing graph algorithms:
|
||||
- 5 people
|
||||
- 6 KNOWS relationships forming a small network
|
||||
|
||||
## Expected Results
|
||||
|
||||
### expected_results.json
|
||||
Contains test cases with:
|
||||
- Query text (Cypher)
|
||||
- Which dataset to use
|
||||
- Expected query results
|
||||
|
||||
Use these to validate that query execution returns correct results.
|
||||
|
||||
## Usage in Tests
|
||||
|
||||
```rust
|
||||
use std::fs;
|
||||
use serde_json::Value;
|
||||
|
||||
#[test]
|
||||
fn test_with_fixture() {
|
||||
let fixture = fs::read_to_string("tests/fixtures/movie_database.json").unwrap();
|
||||
let data: Value = serde_json::from_str(&fixture).unwrap();
|
||||
|
||||
// Load data into graph
|
||||
// Execute queries
|
||||
// Validate against expected results
|
||||
}
|
||||
```
|
||||
|
||||
## Adding New Fixtures
|
||||
|
||||
When adding new fixtures:
|
||||
1. Follow the JSON schema used in existing files
|
||||
2. Add corresponding expected results
|
||||
3. Document the dataset purpose
|
||||
4. Keep datasets small and focused on specific test scenarios
|
||||
48
vendor/ruvector/crates/ruvector-graph/tests/fixtures/expected_results.json
vendored
Normal file
48
vendor/ruvector/crates/ruvector-graph/tests/fixtures/expected_results.json
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
{
|
||||
"description": "Expected query results for validation",
|
||||
"test_cases": [
|
||||
{
|
||||
"name": "count_all_nodes",
|
||||
"query": "MATCH (n) RETURN COUNT(n)",
|
||||
"dataset": "movie_database",
|
||||
"expected": [{"COUNT(n)": 4}]
|
||||
},
|
||||
{
|
||||
"name": "match_person_nodes",
|
||||
"query": "MATCH (p:Person) RETURN p.name ORDER BY p.name",
|
||||
"dataset": "movie_database",
|
||||
"expected": [
|
||||
{"p.name": "Carrie-Anne Moss"},
|
||||
{"p.name": "Keanu Reeves"},
|
||||
{"p.name": "Laurence Fishburne"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "count_relationships",
|
||||
"query": "MATCH ()-[r:ACTED_IN]->() RETURN COUNT(r)",
|
||||
"dataset": "movie_database",
|
||||
"expected": [{"COUNT(r)": 3}]
|
||||
},
|
||||
{
|
||||
"name": "social_network_friends",
|
||||
"query": "MATCH (p:Person {name: 'Alice'})-[:KNOWS]->(friend) RETURN friend.name ORDER BY friend.name",
|
||||
"dataset": "social_network",
|
||||
"expected": [
|
||||
{"friend.name": "Bob"},
|
||||
{"friend.name": "Charlie"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "average_age",
|
||||
"query": "MATCH (p:Person) RETURN AVG(p.age) AS avg_age",
|
||||
"dataset": "social_network",
|
||||
"expected": [{"avg_age": 30.4}]
|
||||
},
|
||||
{
|
||||
"name": "people_born_after_1965",
|
||||
"query": "MATCH (p:Person) WHERE p.born > 1965 RETURN p.name",
|
||||
"dataset": "movie_database",
|
||||
"expected": [{"p.name": "Carrie-Anne Moss"}]
|
||||
}
|
||||
]
|
||||
}
|
||||
67
vendor/ruvector/crates/ruvector-graph/tests/fixtures/movie_database.json
vendored
Normal file
67
vendor/ruvector/crates/ruvector-graph/tests/fixtures/movie_database.json
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
{
|
||||
"description": "Sample movie database for testing",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "keanu",
|
||||
"labels": ["Person"],
|
||||
"properties": {
|
||||
"name": "Keanu Reeves",
|
||||
"born": 1964
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "carrie",
|
||||
"labels": ["Person"],
|
||||
"properties": {
|
||||
"name": "Carrie-Anne Moss",
|
||||
"born": 1967
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "laurence",
|
||||
"labels": ["Person"],
|
||||
"properties": {
|
||||
"name": "Laurence Fishburne",
|
||||
"born": 1961
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "matrix",
|
||||
"labels": ["Movie"],
|
||||
"properties": {
|
||||
"title": "The Matrix",
|
||||
"released": 1999,
|
||||
"tagline": "Welcome to the Real World"
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"from": "keanu",
|
||||
"to": "matrix",
|
||||
"type": "ACTED_IN",
|
||||
"properties": {
|
||||
"roles": ["Neo"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"from": "carrie",
|
||||
"to": "matrix",
|
||||
"type": "ACTED_IN",
|
||||
"properties": {
|
||||
"roles": ["Trinity"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "e3",
|
||||
"from": "laurence",
|
||||
"to": "matrix",
|
||||
"type": "ACTED_IN",
|
||||
"properties": {
|
||||
"roles": ["Morpheus"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
18
vendor/ruvector/crates/ruvector-graph/tests/fixtures/social_network.json
vendored
Normal file
18
vendor/ruvector/crates/ruvector-graph/tests/fixtures/social_network.json
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"description": "Sample social network for testing",
|
||||
"nodes": [
|
||||
{"id": "alice", "labels": ["Person"], "properties": {"name": "Alice", "age": 30}},
|
||||
{"id": "bob", "labels": ["Person"], "properties": {"name": "Bob", "age": 35}},
|
||||
{"id": "charlie", "labels": ["Person"], "properties": {"name": "Charlie", "age": 28}},
|
||||
{"id": "diana", "labels": ["Person"], "properties": {"name": "Diana", "age": 32}},
|
||||
{"id": "eve", "labels": ["Person"], "properties": {"name": "Eve", "age": 27}}
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "from": "alice", "to": "bob", "type": "KNOWS", "properties": {"since": 2015}},
|
||||
{"id": "e2", "from": "alice", "to": "charlie", "type": "KNOWS", "properties": {"since": 2018}},
|
||||
{"id": "e3", "from": "bob", "to": "charlie", "type": "KNOWS", "properties": {"since": 2016}},
|
||||
{"id": "e4", "from": "bob", "to": "diana", "type": "KNOWS", "properties": {"since": 2019}},
|
||||
{"id": "e5", "from": "charlie", "to": "eve", "type": "KNOWS", "properties": {"since": 2020}},
|
||||
{"id": "e6", "from": "diana", "to": "eve", "type": "KNOWS", "properties": {"since": 2017}}
|
||||
]
|
||||
}
|
||||
461
vendor/ruvector/crates/ruvector-graph/tests/hyperedge_tests.rs
vendored
Normal file
461
vendor/ruvector/crates/ruvector-graph/tests/hyperedge_tests.rs
vendored
Normal file
@@ -0,0 +1,461 @@
|
||||
//! Hyperedge (N-ary relationship) tests
|
||||
//!
|
||||
//! Tests for hypergraph features supporting relationships between multiple nodes.
|
||||
//! Based on the existing hypergraph implementation in ruvector-core.
|
||||
|
||||
use ruvector_core::advanced::hypergraph::{
|
||||
Hyperedge, HypergraphIndex, TemporalGranularity, TemporalHyperedge,
|
||||
};
|
||||
use ruvector_core::types::DistanceMetric;
|
||||
|
||||
#[test]
|
||||
fn test_create_binary_hyperedge() {
|
||||
let edge = Hyperedge::new(
|
||||
vec!["1".to_string(), "2".to_string()],
|
||||
"Alice knows Bob".to_string(),
|
||||
vec![0.1, 0.2, 0.3],
|
||||
0.95,
|
||||
);
|
||||
|
||||
assert_eq!(edge.order(), 2);
|
||||
assert!(edge.contains_node(&"1".to_string()));
|
||||
assert!(edge.contains_node(&"2".to_string()));
|
||||
assert!(!edge.contains_node(&"3".to_string()));
|
||||
}
|
||||
|
||||
#[test]
fn test_create_ternary_hyperedge() {
    // A three-member hyperedge models an N-ary relationship a plain edge cannot.
    let members: Vec<String> = ["1", "2", "3"].iter().map(|s| s.to_string()).collect();
    let edge = Hyperedge::new(
        members.clone(),
        "Meeting between Alice, Bob, and Charlie".to_string(),
        vec![0.5; 128],
        0.90,
    );

    assert_eq!(edge.order(), 3);
    for id in &members {
        assert!(edge.contains_node(id));
    }
}
|
||||
|
||||
#[test]
fn test_create_large_hyperedge() {
    // Hyperedge arity is unbounded; build one spanning 100 nodes.
    let members: Vec<String> = (0..100).map(|i| i.to_string()).collect();
    let edge = Hyperedge::new(
        members.clone(),
        "Large group collaboration".to_string(),
        vec![0.1; 64],
        0.75,
    );

    assert_eq!(edge.order(), 100);
    // Every listed member must be reported as contained.
    assert!(members.iter().all(|n| edge.contains_node(n)));
}
|
||||
|
||||
#[test]
fn test_hyperedge_confidence_clamping() {
    // Confidence is clamped into [0.0, 1.0]; in-range values pass through.
    let cases = [(1.5, 1.0), (-0.5, 0.0), (0.75, 0.75)];
    for (input, expected) in cases {
        let edge = Hyperedge::new(
            vec!["1".to_string(), "2".to_string()],
            "Test".to_string(),
            vec![0.1],
            input,
        );
        assert_eq!(edge.confidence, expected);
    }
}
|
||||
|
||||
#[test]
fn test_temporal_hyperedge_creation() {
    let base = Hyperedge::new(
        vec!["1".to_string(), "2".to_string(), "3".to_string()],
        "Temporal relationship".to_string(),
        vec![0.5; 32],
        0.9,
    );

    let temporal = TemporalHyperedge::new(base, TemporalGranularity::Hourly);

    // A freshly created temporal edge carries a positive timestamp and
    // bucket index, and must not already be expired.
    assert!(!temporal.is_expired());
    assert!(temporal.timestamp > 0);
    assert!(temporal.time_bucket() > 0);
}
|
||||
|
||||
#[test]
fn test_temporal_granularity_bucketing() {
    let edge = Hyperedge::new(
        vec!["1".to_string(), "2".to_string()],
        "Test".to_string(),
        vec![0.1],
        1.0,
    );

    let hourly = TemporalHyperedge::new(edge.clone(), TemporalGranularity::Hourly);
    let daily = TemporalHyperedge::new(edge.clone(), TemporalGranularity::Daily);
    let monthly = TemporalHyperedge::new(edge.clone(), TemporalGranularity::Monthly);

    // Coarser granularity divides the timestamp by a larger window, so the
    // bucket index can only shrink (or stay equal) as granularity coarsens.
    assert!(hourly.time_bucket() >= daily.time_bucket());
    assert!(daily.time_bucket() >= monthly.time_bucket());
}
|
||||
|
||||
#[test]
fn test_hypergraph_index_basic() {
    let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

    // Register three orthogonal unit vectors as entities.
    let basis = [
        ("1", vec![1.0, 0.0, 0.0]),
        ("2", vec![0.0, 1.0, 0.0]),
        ("3", vec![0.0, 0.0, 1.0]),
    ];
    for (id, embedding) in basis {
        index.add_entity(id.to_string(), embedding);
    }

    // Connect all three entities with a single ternary hyperedge.
    let triangle = Hyperedge::new(
        vec!["1".to_string(), "2".to_string(), "3".to_string()],
        "Triangle relationship".to_string(),
        vec![0.33, 0.33, 0.34],
        0.95,
    );
    index.add_hyperedge(triangle).unwrap();

    let stats = index.stats();
    assert_eq!(stats.total_entities, 3);
    assert_eq!(stats.total_hyperedges, 1);
}
|
||||
|
||||
#[test]
fn test_hypergraph_multiple_hyperedges() {
    let mut index = HypergraphIndex::new(DistanceMetric::Euclidean);

    for i in 1..=5 {
        index.add_entity(i.to_string(), vec![i as f32; 64]);
    }

    // One hyperedge per arity from 2 (binary) up to 5 (quinary), each over a
    // growing prefix of the entity ids.
    let arities = ["Binary", "Ternary", "Quaternary", "Quinary"];
    for (offset, name) in arities.iter().enumerate() {
        let members: Vec<String> = (1..=(offset + 2)).map(|i| i.to_string()).collect();
        let edge = Hyperedge::new(members, name.to_string(), vec![0.5; 64], 1.0);
        index.add_hyperedge(edge).unwrap();
    }

    let stats = index.stats();
    assert_eq!(stats.total_hyperedges, 4);
    // Each entity participates in at least one edge, so the mean degree is positive.
    assert!(stats.avg_entity_degree > 0.0);
}
|
||||
|
||||
#[test]
fn test_hypergraph_search() {
    let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

    for i in 1..=10 {
        index.add_entity(i.to_string(), vec![i as f32 * 0.1; 32]);
    }

    // Five binary hyperedges linking consecutive entity ids.
    for i in 1..=5 {
        let edge = Hyperedge::new(
            vec![i.to_string(), (i + 1).to_string()],
            format!("Edge {}", i),
            vec![i as f32 * 0.1; 32],
            0.9,
        );
        index.add_hyperedge(edge).unwrap();
    }

    let query = vec![0.3; 32];
    let hits = index.search_hyperedges(&query, 3);

    assert_eq!(hits.len(), 3);
    // The result list must come back ordered by ascending distance.
    assert!(hits.windows(2).all(|pair| pair[0].1 <= pair[1].1));
}
|
||||
|
||||
#[test]
fn test_k_hop_neighbors_simple() {
    let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

    // Path graph 1-2-3-4 built from three binary hyperedges.
    for i in 1..=4 {
        index.add_entity(i.to_string(), vec![i as f32]);
    }
    for (name, (a, b)) in [("e1", ("1", "2")), ("e2", ("2", "3")), ("e3", ("3", "4"))] {
        let edge = Hyperedge::new(
            vec![a.to_string(), b.to_string()],
            name.to_string(),
            vec![1.0],
            1.0,
        );
        index.add_hyperedge(edge).unwrap();
    }

    // Each extra hop widens the reachable frontier by one node along the path.
    let one_hop = index.k_hop_neighbors("1".to_string(), 1);
    assert!(one_hop.contains(&"1".to_string()));
    assert!(one_hop.contains(&"2".to_string()));

    let two_hop = index.k_hop_neighbors("1".to_string(), 2);
    for id in ["1", "2", "3"] {
        assert!(two_hop.contains(&id.to_string()));
    }

    // Three hops from one end of the path covers all four nodes.
    let three_hop = index.k_hop_neighbors("1".to_string(), 3);
    assert_eq!(three_hop.len(), 4);
}
|
||||
|
||||
#[test]
fn test_k_hop_neighbors_complex() {
    let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

    // Star topology: hub "0" linked to five spokes "1".."5".
    for i in 0..=5 {
        index.add_entity(i.to_string(), vec![i as f32]);
    }
    for i in 1..=5 {
        let spoke = Hyperedge::new(
            vec!["0".to_string(), i.to_string()],
            format!("e{}", i),
            vec![1.0],
            1.0,
        );
        index.add_hyperedge(spoke).unwrap();
    }

    // From the hub, a single hop already covers every node.
    let from_hub = index.k_hop_neighbors("0".to_string(), 1);
    assert_eq!(from_hub.len(), 6);

    // From a spoke, one hop reaches the hub (and includes the start node itself).
    let from_spoke = index.k_hop_neighbors("1".to_string(), 1);
    assert!(from_spoke.contains(&"0".to_string()));
    assert!(from_spoke.contains(&"1".to_string()));
}
|
||||
|
||||
#[test]
fn test_temporal_range_query() {
    let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

    for i in 1..=3 {
        index.add_entity(i.to_string(), vec![i as f32]);
    }

    // Both edges are created "now", so they should share one hourly bucket.
    // NOTE(review): this could flake if the hour rolls over between the two
    // constructions — consider injecting a fixed timestamp if that becomes an issue.
    let first = Hyperedge::new(
        vec!["1".to_string(), "2".to_string()],
        "t1".to_string(),
        vec![1.0],
        1.0,
    );
    let second = Hyperedge::new(
        vec!["2".to_string(), "3".to_string()],
        "t2".to_string(),
        vec![1.0],
        1.0,
    );

    let temporal_first = TemporalHyperedge::new(first, TemporalGranularity::Hourly);
    let temporal_second = TemporalHyperedge::new(second, TemporalGranularity::Hourly);
    let bucket = temporal_first.time_bucket();

    index.add_temporal_hyperedge(temporal_first).unwrap();
    index.add_temporal_hyperedge(temporal_second).unwrap();

    // A degenerate range [bucket, bucket] must return both edges.
    let found = index.query_temporal_range(bucket, bucket);
    assert_eq!(found.len(), 2);
}
|
||||
|
||||
#[test]
fn test_hyperedge_with_duplicate_nodes() {
    // Node "2" appears twice: order() counts occurrences, not distinct members.
    let members: Vec<String> = ["1", "2", "2", "3"].iter().map(|s| s.to_string()).collect();
    let edge = Hyperedge::new(members, "Duplicate test".to_string(), vec![0.5; 16], 0.8);

    assert_eq!(edge.order(), 4);
    assert!(edge.contains_node(&"2".to_string()));
}
|
||||
|
||||
#[test]
fn test_hypergraph_error_on_missing_entity() {
    let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

    // Register entity "1" only; "2" stays unknown to the index.
    index.add_entity("1".to_string(), vec![1.0]);

    // Inserting a hyperedge that references the unregistered entity must fail.
    let edge = Hyperedge::new(
        vec!["1".to_string(), "2".to_string()],
        "Test".to_string(),
        vec![0.5],
        1.0,
    );
    assert!(index.add_hyperedge(edge).is_err());
}
|
||||
|
||||
// ============================================================================
|
||||
// Property-based tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Random member lists: 2..20 short lowercase identifiers.
    fn node_vec_strategy() -> impl Strategy<Value = Vec<String>> {
        prop::collection::vec("[a-z]{1,5}".prop_map(|s| s), 2..20)
    }

    /// Random embeddings with every component drawn from [-1, 1).
    fn embedding_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
        prop::collection::vec(-1.0f32..1.0f32, dim)
    }

    proptest! {
        /// order() always equals the number of member ids supplied.
        #[test]
        fn test_hyperedge_order_property(nodes in node_vec_strategy()) {
            let edge = Hyperedge::new(nodes.clone(), "Test".to_string(), vec![0.5; 32], 0.9);
            assert_eq!(edge.order(), nodes.len());
        }

        /// Every supplied member id is reported as contained.
        #[test]
        fn test_hyperedge_contains_all_nodes(nodes in node_vec_strategy()) {
            let edge = Hyperedge::new(nodes.clone(), "Test".to_string(), vec![0.5; 32], 0.9);
            assert!(nodes.iter().all(|n| edge.contains_node(n)));
        }

        /// Search never returns more than k hits and always sorts by distance.
        #[test]
        fn test_hypergraph_search_consistency(
            query in embedding_strategy(32),
            k in 1usize..10
        ) {
            let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

            for i in 1..=10 {
                index.add_entity(i.to_string(), vec![i as f32 * 0.1; 32]);
            }
            for i in 1..=10 {
                let edge = Hyperedge::new(
                    vec![i.to_string()],
                    format!("Edge {}", i),
                    vec![i as f32 * 0.1; 32],
                    0.9
                );
                index.add_hyperedge(edge).unwrap();
            }

            let results = index.search_hyperedges(&query, k.min(10));
            assert!(results.len() <= k.min(10));
            // Adjacent pairs must be in ascending distance order.
            for pair in results.windows(2) {
                assert!(pair[0].1 <= pair[1].1);
            }
        }
    }
}
|
||||
386
vendor/ruvector/crates/ruvector-graph/tests/node_tests.rs
vendored
Normal file
386
vendor/ruvector/crates/ruvector-graph/tests/node_tests.rs
vendored
Normal file
@@ -0,0 +1,386 @@
|
||||
//! Node CRUD operation tests
|
||||
//!
|
||||
//! Tests for creating, reading, updating, and deleting nodes in the graph database.
|
||||
|
||||
use ruvector_graph::{GraphDB, Label, Node, Properties, PropertyValue};
|
||||
|
||||
#[test]
fn test_create_node_basic() {
    let db = GraphDB::new();

    // A Person node with two simple scalar properties.
    let mut props = Properties::new();
    props.insert("name".to_string(), PropertyValue::String("Alice".to_string()));
    props.insert("age".to_string(), PropertyValue::Integer(30));

    let node = Node::new(
        "node1".to_string(),
        vec![Label { name: "Person".to_string() }],
        props,
    );

    // create_node echoes back the id of the stored node.
    assert_eq!(db.create_node(node).unwrap(), "node1");
}
|
||||
|
||||
#[test]
fn test_get_node_existing() {
    let db = GraphDB::new();

    let mut props = Properties::new();
    props.insert("name".to_string(), PropertyValue::String("Bob".to_string()));

    db.create_node(Node::new(
        "node2".to_string(),
        vec![Label { name: "Person".to_string() }],
        props.clone(),
    ))
    .unwrap();

    // The node read back must carry the same id and property value.
    let fetched = db.get_node("node2").unwrap();
    assert_eq!(fetched.id, "node2");
    assert_eq!(
        fetched.properties.get("name"),
        Some(&PropertyValue::String("Bob".to_string()))
    );
}
|
||||
|
||||
#[test]
fn test_get_node_nonexistent() {
    // Looking up an id that was never inserted yields None, not an error.
    let db = GraphDB::new();
    assert!(db.get_node("nonexistent").is_none());
}
|
||||
|
||||
#[test]
fn test_node_with_multiple_labels() {
    let db = GraphDB::new();

    // A node may carry any number of labels simultaneously.
    let labels: Vec<Label> = ["Person", "Employee", "Manager"]
        .iter()
        .map(|n| Label { name: n.to_string() })
        .collect();

    let mut props = Properties::new();
    props.insert(
        "name".to_string(),
        PropertyValue::String("Charlie".to_string()),
    );

    db.create_node(Node::new("node3".to_string(), labels, props))
        .unwrap();

    assert_eq!(db.get_node("node3").unwrap().labels.len(), 3);
}
|
||||
|
||||
#[test]
fn test_node_with_complex_properties() {
    let db = GraphDB::new();

    // Exercise every scalar property kind plus a list value on one node.
    let mut props = Properties::new();
    props.insert("name".to_string(), PropertyValue::String("David".to_string()));
    props.insert("age".to_string(), PropertyValue::Integer(35));
    props.insert("height".to_string(), PropertyValue::Float(1.82));
    props.insert("active".to_string(), PropertyValue::Boolean(true));
    props.insert(
        "tags".to_string(),
        PropertyValue::List(vec![
            PropertyValue::String("developer".to_string()),
            PropertyValue::String("team-lead".to_string()),
        ]),
    );

    db.create_node(Node::new(
        "node4".to_string(),
        vec![Label { name: "Person".to_string() }],
        props,
    ))
    .unwrap();

    let fetched = db.get_node("node4").unwrap();
    assert_eq!(fetched.properties.len(), 5);
    // The list-typed property must survive the round trip as a list.
    assert!(matches!(
        fetched.properties.get("tags"),
        Some(PropertyValue::List(_))
    ));
}
|
||||
|
||||
#[test]
fn test_node_with_empty_properties() {
    let db = GraphDB::new();

    // A labeled node with zero properties is perfectly valid.
    db.create_node(Node::new(
        "node5".to_string(),
        vec![Label { name: "EmptyNode".to_string() }],
        Properties::new(),
    ))
    .unwrap();

    assert!(db.get_node("node5").unwrap().properties.is_empty());
}
|
||||
|
||||
#[test]
fn test_node_with_no_labels() {
    let db = GraphDB::new();

    // Labels are optional: a node with properties but no labels is accepted.
    let mut props = Properties::new();
    props.insert("data".to_string(), PropertyValue::String("test".to_string()));

    db.create_node(Node::new("node6".to_string(), vec![], props))
        .unwrap();

    assert!(db.get_node("node6").unwrap().labels.is_empty());
}
|
||||
|
||||
#[test]
fn test_node_property_update() {
    let db = GraphDB::new();

    // Helper so both versions of the node carry the identical label set.
    let counter_label = || vec![Label { name: "Counter".to_string() }];

    let mut initial = Properties::new();
    initial.insert("counter".to_string(), PropertyValue::Integer(0));
    db.create_node(Node::new("node7".to_string(), counter_label(), initial))
        .unwrap();

    // TODO: switch to a dedicated update_node method once it exists.
    // Until then the node is re-created under the same id with new properties
    // (assumes create_node upserts on a duplicate id — confirm against GraphDB).
    let mut updated = Properties::new();
    updated.insert("counter".to_string(), PropertyValue::Integer(1));
    db.create_node(Node::new("node7".to_string(), counter_label(), updated))
        .unwrap();

    assert_eq!(
        db.get_node("node7").unwrap().properties.get("counter"),
        Some(&PropertyValue::Integer(1))
    );
}
|
||||
|
||||
#[test]
fn test_create_1000_nodes() {
    let db = GraphDB::new();

    // Bulk-insert 1000 nodes, each tagged with its own index.
    for i in 0..1000 {
        let mut props = Properties::new();
        props.insert("index".to_string(), PropertyValue::Integer(i));

        db.create_node(Node::new(
            format!("node_{}", i),
            vec![Label { name: "TestNode".to_string() }],
            props,
        ))
        .unwrap();
    }

    // Every inserted id must be retrievable afterwards.
    assert!((0..1000).all(|i| db.get_node(&format!("node_{}", i)).is_some()));
}
|
||||
|
||||
#[test]
fn test_node_property_null_value() {
    let db = GraphDB::new();

    // Null is a first-class property value, distinct from an absent key.
    let mut props = Properties::new();
    props.insert("nullable".to_string(), PropertyValue::Null);

    db.create_node(Node::new(
        "node8".to_string(),
        vec![Label { name: "NullTest".to_string() }],
        props,
    ))
    .unwrap();

    let fetched = db.get_node("node8").unwrap();
    assert_eq!(fetched.properties.get("nullable"), Some(&PropertyValue::Null));
}
|
||||
|
||||
#[test]
fn test_node_nested_list_properties() {
    let db = GraphDB::new();

    // A 2x2 integer matrix encoded as a list of lists.
    let row = |a: i64, b: i64| {
        PropertyValue::List(vec![PropertyValue::Integer(a), PropertyValue::Integer(b)])
    };
    let mut props = Properties::new();
    props.insert(
        "matrix".to_string(),
        PropertyValue::List(vec![row(1, 2), row(3, 4)]),
    );

    db.create_node(Node::new(
        "node9".to_string(),
        vec![Label { name: "Matrix".to_string() }],
        props,
    ))
    .unwrap();

    // Both nesting levels must survive the round trip.
    let fetched = db.get_node("node9").unwrap();
    match fetched.properties.get("matrix") {
        Some(PropertyValue::List(rows)) => {
            assert_eq!(rows.len(), 2);
            match &rows[0] {
                PropertyValue::List(cells) => assert_eq!(cells.len(), 2),
                _ => panic!("Expected inner list"),
            }
        }
        _ => panic!("Expected outer list"),
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Property-based tests using proptest
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Lowercase alphanumeric node ids, 1..=21 characters.
    fn node_id_strategy() -> impl Strategy<Value = String> {
        "[a-z][a-z0-9_]{0,20}".prop_map(|s| s.to_string())
    }

    /// CamelCase-ish label names, 1..=11 characters.
    fn label_strategy() -> impl Strategy<Value = Label> {
        "[A-Z][a-zA-Z]{0,10}".prop_map(|name| Label { name })
    }

    /// Any property value kind; floats restricted to finite so equality holds.
    fn property_value_strategy() -> impl Strategy<Value = PropertyValue> {
        prop_oneof![
            any::<String>().prop_map(PropertyValue::String),
            any::<i64>().prop_map(PropertyValue::Integer),
            any::<f64>()
                .prop_filter("Must be finite", |x| x.is_finite())
                .prop_map(PropertyValue::Float),
            any::<bool>().prop_map(PropertyValue::Boolean),
            Just(PropertyValue::Null),
        ]
    }

    proptest! {
        /// Id, label count and property count all survive a store/load cycle.
        #[test]
        fn test_node_roundtrip(
            id in node_id_strategy(),
            labels in prop::collection::vec(label_strategy(), 0..5),
            prop_count in 0..10usize
        ) {
            let db = GraphDB::new();

            let mut properties = Properties::new();
            for i in 0..prop_count {
                properties.insert(
                    format!("prop_{}", i),
                    PropertyValue::String(format!("value_{}", i))
                );
            }

            db.create_node(Node::new(id.clone(), labels.clone(), properties.clone()))
                .unwrap();

            let fetched = db.get_node(&id).unwrap();
            assert_eq!(fetched.id, id);
            assert_eq!(fetched.labels.len(), labels.len());
            assert_eq!(fetched.properties.len(), properties.len());
        }

        /// A stored property value is returned byte-for-byte equal.
        #[test]
        fn test_property_value_consistency(value in property_value_strategy()) {
            let db = GraphDB::new();

            let mut properties = Properties::new();
            properties.insert("test_prop".to_string(), value.clone());

            db.create_node(Node::new("test_node".to_string(), vec![], properties))
                .unwrap();

            let fetched = db.get_node("test_node").unwrap();
            assert_eq!(fetched.properties.get("test_prop"), Some(&value));
        }

        /// Distinct ids never collide: every inserted id remains retrievable.
        #[test]
        fn test_many_nodes_no_collision(
            ids in prop::collection::hash_set(node_id_strategy(), 10..100)
        ) {
            let db = GraphDB::new();

            for id in &ids {
                db.create_node(Node::new(id.clone(), vec![], Properties::new()))
                    .unwrap();
            }

            assert!(ids.iter().all(|id| db.get_node(id).is_some()));
        }
    }
}
|
||||
434
vendor/ruvector/crates/ruvector-graph/tests/performance_tests.rs
vendored
Normal file
434
vendor/ruvector/crates/ruvector-graph/tests/performance_tests.rs
vendored
Normal file
@@ -0,0 +1,434 @@
|
||||
//! Performance and regression tests
|
||||
//!
|
||||
//! Benchmark tests to ensure performance doesn't degrade over time.
|
||||
|
||||
use ruvector_graph::{Edge, GraphDB, Label, Node, Properties, PropertyValue};
|
||||
use std::time::Instant;
|
||||
|
||||
// ============================================================================
|
||||
// Baseline Performance Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_node_creation_performance() {
    // Insert 10k labeled nodes and assert a coarse regression floor on throughput.
    let db = GraphDB::new();
    let num_nodes = 10_000;

    let start = Instant::now();

    for i in 0..num_nodes {
        let mut props = Properties::new();
        props.insert("id".to_string(), PropertyValue::Integer(i));

        let node = Node::new(
            format!("node_{}", i),
            vec![Label {
                name: "Benchmark".to_string(),
            }],
            props,
        );

        db.create_node(node).unwrap();
    }

    let duration = start.elapsed();

    println!("Created {} nodes in {:?}", num_nodes, duration);
    println!(
        "Rate: {:.2} nodes/sec",
        num_nodes as f64 / duration.as_secs_f64()
    );

    // Regression floor: 10k inserts must finish within 5s, i.e. >= ~2k nodes/sec.
    // (A previous comment claimed a 10k nodes/sec baseline, but the assertion
    // below has only ever enforced the 5-second bound.)
    assert!(
        duration.as_secs() < 5,
        "Node creation too slow: {:?}",
        duration
    );
}
|
||||
|
||||
#[test]
fn test_node_retrieval_performance() {
    let db = GraphDB::new();
    let num_nodes = 10_000;

    // Seed the store with bare (label-less, property-less) nodes.
    for i in 0..num_nodes {
        db.create_node(Node::new(format!("node_{}", i), vec![], Properties::new()))
            .unwrap();
    }

    // Time a point lookup for every inserted id.
    let start = Instant::now();
    for i in 0..num_nodes {
        assert!(db.get_node(&format!("node_{}", i)).is_some());
    }
    let duration = start.elapsed();

    println!("Retrieved {} nodes in {:?}", num_nodes, duration);
    println!(
        "Rate: {:.2} reads/sec",
        num_nodes as f64 / duration.as_secs_f64()
    );

    // In-memory lookups should comfortably finish inside a second.
    assert!(
        duration.as_secs() < 1,
        "Node retrieval too slow: {:?}",
        duration
    );
}
|
||||
|
||||
#[test]
fn test_edge_creation_performance() {
    let db = GraphDB::new();
    let num_nodes = 1000;
    let edges_per_node = 10;

    for i in 0..num_nodes {
        db.create_node(Node::new(format!("n{}", i), vec![], Properties::new()))
            .unwrap();
    }

    // Fan each node out to the next `edges_per_node` nodes (indices wrap around).
    let start = Instant::now();
    for i in 0..num_nodes {
        for j in 0..edges_per_node {
            let target = (i + j + 1) % num_nodes;
            db.create_edge(Edge::new(
                format!("e_{}_{}", i, j),
                format!("n{}", i),
                format!("n{}", target),
                "CONNECTS".to_string(),
                Properties::new(),
            ))
            .unwrap();
        }
    }
    let duration = start.elapsed();

    let total_edges = num_nodes * edges_per_node;
    println!("Created {} edges in {:?}", total_edges, duration);
    println!(
        "Rate: {:.2} edges/sec",
        total_edges as f64 / duration.as_secs_f64()
    );
}
|
||||
|
||||
// TODO: Implement graph traversal methods
|
||||
// #[test]
|
||||
// fn test_traversal_performance() {
|
||||
// let db = GraphDB::new();
|
||||
// let num_nodes = 1000;
|
||||
//
|
||||
// // Create chain
|
||||
// for i in 0..num_nodes {
|
||||
// db.create_node(Node::new(format!("n{}", i), vec![], Properties::new())).unwrap();
|
||||
// }
|
||||
//
|
||||
// for i in 0..num_nodes - 1 {
|
||||
// db.create_edge(Edge::new(
|
||||
// format!("e{}", i),
|
||||
// format!("n{}", i),
|
||||
// format!("n{}", i + 1),
|
||||
// RelationType { name: "NEXT".to_string() },
|
||||
// Properties::new(),
|
||||
// )).unwrap();
|
||||
// }
|
||||
//
|
||||
// // Measure traversal
|
||||
// let start = Instant::now();
|
||||
// let path = db.traverse("n0", "NEXT", 100).unwrap();
|
||||
// let duration = start.elapsed();
|
||||
//
|
||||
// assert_eq!(path.len(), 100);
|
||||
// println!("Traversed 100 hops in {:?}", duration);
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Scalability Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_large_graph_creation() {
    // Builds a 100k-node graph, logging progress every 10k inserts,
    // and reports overall node-creation throughput.
    let db = GraphDB::new();
    let num_nodes = 100_000;

    let timer = Instant::now();
    for idx in 0..num_nodes {
        if idx % 10_000 == 0 {
            println!("Created {} nodes...", idx);
        }
        db.create_node(Node::new(format!("large_{}", idx), vec![], Properties::new()))
            .unwrap();
    }
    let elapsed = timer.elapsed();

    println!("Created {} nodes in {:?}", num_nodes, elapsed);
    println!(
        "Rate: {:.2} nodes/sec",
        num_nodes as f64 / elapsed.as_secs_f64()
    );
}
|
||||
|
||||
#[test]
#[ignore] // Long-running test
fn test_million_node_graph() {
    // Stress variant of the scalability test: one million node inserts,
    // progress-logged every 100k. Ignored by default because of runtime.
    let db = GraphDB::new();
    let num_nodes = 1_000_000;

    let timer = Instant::now();
    for idx in 0..num_nodes {
        if idx % 100_000 == 0 {
            println!("Created {} nodes...", idx);
        }
        db.create_node(Node::new(format!("mega_{}", idx), vec![], Properties::new()))
            .unwrap();
    }
    let elapsed = timer.elapsed();

    println!("Created {} nodes in {:?}", num_nodes, elapsed);
    println!(
        "Rate: {:.2} nodes/sec",
        num_nodes as f64 / elapsed.as_secs_f64()
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Usage Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_memory_efficiency() {
    // Inserts 10k nodes that each carry a 100-byte string property.
    // Currently only exercises the allocation path; actual memory
    // measurement is still TODO.
    let db = GraphDB::new();
    let num_nodes = 10_000;

    for idx in 0..num_nodes {
        let mut props = Properties::new();
        props.insert("data".to_string(), PropertyValue::String("x".repeat(100)));
        db.create_node(Node::new(format!("mem_{}", idx), vec![], props))
            .unwrap();
    }

    // TODO: measure actual memory usage — requires platform-specific APIs.
}
|
||||
|
||||
// ============================================================================
|
||||
// Property-based Performance Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_property_heavy_nodes() {
    // Times insertion of 1000 nodes that each carry 50 integer properties,
    // exercising the per-property storage path.
    let db = GraphDB::new();
    let node_count = 1_000;
    let prop_count = 50;

    let timer = Instant::now();
    for idx in 0..node_count {
        let mut props = Properties::new();
        for p in 0..prop_count {
            props.insert(format!("prop_{}", p), PropertyValue::Integer(p as i64));
        }
        db.create_node(Node::new(format!("heavy_{}", idx), vec![], props))
            .unwrap();
    }
    let elapsed = timer.elapsed();

    println!(
        "Created {} property-heavy nodes in {:?}",
        node_count, elapsed
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Query Performance Tests (TODO)
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_simple_query_performance() {
|
||||
// let db = setup_benchmark_graph(10_000);
|
||||
//
|
||||
// let start = Instant::now();
|
||||
// let results = db.execute("MATCH (n:Person) RETURN n LIMIT 100").unwrap();
|
||||
// let duration = start.elapsed();
|
||||
//
|
||||
// assert_eq!(results.len(), 100);
|
||||
// println!("Simple query took: {:?}", duration);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_aggregation_performance() {
|
||||
// let db = setup_benchmark_graph(100_000);
|
||||
//
|
||||
// let start = Instant::now();
|
||||
// let results = db.execute("MATCH (n:Person) RETURN COUNT(n)").unwrap();
|
||||
// let duration = start.elapsed();
|
||||
//
|
||||
// println!("Aggregation over 100k nodes took: {:?}", duration);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_join_performance() {
|
||||
// let db = setup_benchmark_graph(10_000);
|
||||
//
|
||||
// let start = Instant::now();
|
||||
// let results = db.execute("
|
||||
// MATCH (a:Person)-[:KNOWS]->(b:Person)
|
||||
// WHERE a.age > 30
|
||||
// RETURN a, b
|
||||
// ").unwrap();
|
||||
// let duration = start.elapsed();
|
||||
//
|
||||
// println!("Join query took: {:?}", duration);
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Index Performance Tests (TODO)
|
||||
// ============================================================================
|
||||
|
||||
// #[test]
|
||||
// fn test_indexed_lookup_performance() {
|
||||
// let db = GraphDB::new();
|
||||
//
|
||||
// // Create index
|
||||
// db.create_index("Person", "email").unwrap();
|
||||
//
|
||||
// // Insert data
|
||||
// for i in 0..100_000 {
|
||||
// db.execute(&format!(
|
||||
// "CREATE (:Person {{email: 'user{}@example.com'}})",
|
||||
// i
|
||||
// )).unwrap();
|
||||
// }
|
||||
//
|
||||
// // Measure lookup
|
||||
// let start = Instant::now();
|
||||
// let results = db.execute("MATCH (n:Person {email: 'user50000@example.com'}) RETURN n").unwrap();
|
||||
// let duration = start.elapsed();
|
||||
//
|
||||
// assert_eq!(results.len(), 1);
|
||||
// println!("Indexed lookup took: {:?}", duration);
|
||||
// assert!(duration.as_millis() < 10); // Should be very fast
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Regression Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_regression_node_creation() {
    // Regression guard: 1000 node inserts must finish inside the
    // 500 ms baseline budget.
    let db = GraphDB::new();

    let timer = Instant::now();
    for idx in 0..1000 {
        let node = Node::new(format!("regr_{}", idx), vec![], Properties::new());
        db.create_node(node).unwrap();
    }
    let elapsed = timer.elapsed();

    // Baseline threshold — adjust only when the baseline is re-measured.
    assert!(
        elapsed.as_millis() < 500,
        "Regression detected: {:?}",
        elapsed
    );
}
|
||||
|
||||
#[test]
fn test_regression_node_retrieval() {
    // Regression guard: 1000 lookups of pre-inserted nodes must finish
    // inside the 100 ms baseline budget.
    let db = GraphDB::new();

    // Fixture: the nodes that will be fetched.
    for idx in 0..1000 {
        let node = Node::new(format!("regr_{}", idx), vec![], Properties::new());
        db.create_node(node).unwrap();
    }

    let timer = Instant::now();
    for idx in 0..1000 {
        let _ = db.get_node(&format!("regr_{}", idx));
    }
    let elapsed = timer.elapsed();

    // Lookups should be very fast relative to inserts.
    assert!(
        elapsed.as_millis() < 100,
        "Regression detected: {:?}",
        elapsed
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn setup_benchmark_graph(num_nodes: usize) -> GraphDB {
|
||||
let db = GraphDB::new();
|
||||
|
||||
for i in 0..num_nodes {
|
||||
let mut props = Properties::new();
|
||||
props.insert(
|
||||
"name".to_string(),
|
||||
PropertyValue::String(format!("Person{}", i)),
|
||||
);
|
||||
props.insert(
|
||||
"age".to_string(),
|
||||
PropertyValue::Integer((20 + (i % 60)) as i64),
|
||||
);
|
||||
|
||||
db.create_node(Node::new(
|
||||
format!("person_{}", i),
|
||||
vec![Label {
|
||||
name: "Person".to_string(),
|
||||
}],
|
||||
props,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Create some edges
|
||||
for i in 0..num_nodes / 10 {
|
||||
let from = i;
|
||||
let to = (i + 1) % num_nodes;
|
||||
|
||||
db.create_edge(Edge::new(
|
||||
format!("knows_{}", i),
|
||||
format!("person_{}", from),
|
||||
format!("person_{}", to),
|
||||
"KNOWS".to_string(),
|
||||
Properties::new(),
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
db
|
||||
}
|
||||
818
vendor/ruvector/crates/ruvector-graph/tests/transaction_tests.rs
vendored
Normal file
818
vendor/ruvector/crates/ruvector-graph/tests/transaction_tests.rs
vendored
Normal file
@@ -0,0 +1,818 @@
|
||||
//! Transaction tests for ACID guarantees
|
||||
//!
|
||||
//! Tests to verify atomicity, consistency, isolation, and durability properties.
|
||||
|
||||
use ruvector_graph::edge::EdgeBuilder;
|
||||
use ruvector_graph::node::NodeBuilder;
|
||||
use ruvector_graph::transaction::{IsolationLevel, Transaction, TransactionManager};
|
||||
use ruvector_graph::{GraphDB, Label, Node, Properties, PropertyValue};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
// ============================================================================
|
||||
// Atomicity Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_transaction_commit() {
    // A freshly begun transaction with no staged operations must commit
    // cleanly.
    let _db = GraphDB::new();

    let tx = Transaction::begin(IsolationLevel::ReadCommitted).unwrap();

    // TODO: exercise tx.create_node / tx.create_edge once implemented.

    assert!(tx.commit().is_ok());
}
|
||||
|
||||
#[test]
fn test_transaction_rollback() {
    // A freshly begun transaction with no staged operations must roll
    // back cleanly.
    let _db = GraphDB::new();

    let tx = Transaction::begin(IsolationLevel::ReadCommitted).unwrap();

    // TODO: once tx.create_node exists, stage a write here and verify it
    // is absent after the rollback.

    assert!(tx.rollback().is_ok());
}
|
||||
|
||||
#[test]
fn test_transaction_atomic_batch_insert() {
    // Intended contract: a batch insert inside one transaction is
    // all-or-nothing. Transactional batches are not implemented yet, so
    // for now this only verifies plain (non-transactional) inserts land.
    let db = GraphDB::new();

    // TODO: once db.begin_transaction() exists, insert ~100 nodes inside
    // a Serializable transaction, roll back midway, and assert that none
    // of them (e.g. "node_0") are visible afterwards.

    for idx in 0..10 {
        db.create_node(Node::new(format!("node_{}", idx), vec![], Properties::new()))
            .unwrap();
    }

    assert!(db.get_node("node_0").is_some());
}
|
||||
|
||||
#[test]
fn test_transaction_rollback_on_constraint_violation() {
    // A staged write that duplicates an existing node id must not survive
    // a rollback: the original node stays, and no duplicate is committed.
    let db = GraphDB::new();

    let original = NodeBuilder::new()
        .id("unique_node")
        .label("User")
        .property("email", "test@example.com")
        .build();
    db.create_node(original).unwrap();

    // Stage the conflicting write inside a serializable transaction.
    let tx = Transaction::begin(IsolationLevel::Serializable).unwrap();
    let duplicate = NodeBuilder::new()
        .id("unique_node") // same id — would violate uniqueness
        .label("User")
        .property("email", "test2@example.com")
        .build();
    tx.write_node(duplicate);

    // Abort instead of committing the violation.
    assert!(tx.rollback().is_ok());

    // The pre-existing node is intact and remains the only one.
    assert!(db.get_node("unique_node").is_some());
    assert_eq!(db.node_count(), 1);
}
|
||||
|
||||
// ============================================================================
|
||||
// Isolation Level Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_isolation_read_uncommitted() {
    // begin() must record the requested isolation level on the transaction.
    let tx = Transaction::begin(IsolationLevel::ReadUncommitted).unwrap();
    assert_eq!(tx.isolation_level, IsolationLevel::ReadUncommitted);
    tx.commit().unwrap();
}

#[test]
fn test_isolation_read_committed() {
    // Same check for ReadCommitted.
    let tx = Transaction::begin(IsolationLevel::ReadCommitted).unwrap();
    assert_eq!(tx.isolation_level, IsolationLevel::ReadCommitted);
    tx.commit().unwrap();
}

#[test]
fn test_isolation_repeatable_read() {
    // Same check for RepeatableRead.
    let tx = Transaction::begin(IsolationLevel::RepeatableRead).unwrap();
    assert_eq!(tx.isolation_level, IsolationLevel::RepeatableRead);
    tx.commit().unwrap();
}

#[test]
fn test_isolation_serializable() {
    // Same check for Serializable.
    let tx = Transaction::begin(IsolationLevel::Serializable).unwrap();
    assert_eq!(tx.isolation_level, IsolationLevel::Serializable);
    tx.commit().unwrap();
}
|
||||
|
||||
// ============================================================================
|
||||
// Concurrency and Isolation Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_concurrent_transactions_read_committed() {
    // Ten threads read the same node concurrently. Transactional
    // increments are still TODO, so this currently only checks that
    // concurrent reads complete without panicking.
    let db = Arc::new(GraphDB::new());

    // Shared counter node the workers will read.
    let mut props = Properties::new();
    props.insert("counter".to_string(), PropertyValue::Integer(0));
    let labels = vec![Label {
        name: "Counter".to_string(),
    }];
    db.create_node(Node::new("counter".to_string(), labels, props))
        .unwrap();

    // TODO: replace the plain reads with begin / read / increment /
    // commit once transactional updates exist, then assert the final
    // counter value.
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let db = Arc::clone(&db);
            thread::spawn(move || {
                let _node = db.get_node("counter");
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }
}
|
||||
|
||||
#[test]
fn test_dirty_read_prevention() {
    // A ReadCommitted reader must never observe another transaction's
    // uncommitted (and here ultimately rolled-back) write.
    //
    // Note: the file-level `use std::sync::Arc;` / `use std::thread;`
    // imports are in scope; the redundant function-local `use` lines the
    // original carried have been removed.
    let manager = Arc::new(TransactionManager::new());

    // Writer: stages a node, pauses, then rolls back without committing.
    let writer_manager = Arc::clone(&manager);
    let writer = thread::spawn(move || {
        let tx1 = writer_manager.begin(IsolationLevel::ReadCommitted);
        let node = NodeBuilder::new()
            .id("dirty_node")
            .label("Test")
            .property("value", 42i64)
            .build();
        tx1.write_node(node);

        // Give the reader a window while the write is still uncommitted.
        thread::sleep(std::time::Duration::from_millis(50));

        // Never committed — the staged write must stay invisible.
        tx1.rollback().unwrap();
    });

    // Reader: starts after the writer staged its node, before rollback.
    thread::sleep(std::time::Duration::from_millis(10));
    let tx2 = manager.begin(IsolationLevel::ReadCommitted);
    let read_node = tx2.read_node(&"dirty_node".to_string());

    // A dirty read would surface the uncommitted node here.
    assert!(read_node.is_none());

    writer.join().unwrap();
    tx2.commit().unwrap();
}
|
||||
|
||||
#[test]
fn test_non_repeatable_read_prevention() {
    // Under RepeatableRead, two reads of the same node inside one
    // transaction must agree even if a concurrent transaction commits an
    // update in between the reads.
    //
    // Redundant function-local `use std::sync::Arc; use std::thread;`
    // removed — both are already imported at file level.
    let manager = Arc::new(TransactionManager::new());

    // Seed the counter node at 0.
    let node = NodeBuilder::new()
        .id("counter_node")
        .label("Counter")
        .property("count", 0i64)
        .build();
    let tx_init = manager.begin(IsolationLevel::RepeatableRead);
    tx_init.write_node(node);
    tx_init.commit().unwrap();

    // Reader: performs two reads with a pause between them.
    let reader_manager = Arc::clone(&manager);
    let reader = thread::spawn(move || {
        let tx1 = reader_manager.begin(IsolationLevel::RepeatableRead);

        let first = tx1.read_node(&"counter_node".to_string());
        assert!(first.is_some());
        let value1 = first.unwrap().get_property("count").unwrap().clone();

        // Let the writer commit its update in the meantime.
        thread::sleep(std::time::Duration::from_millis(50));

        let second = tx1.read_node(&"counter_node".to_string());
        assert!(second.is_some());
        let value2 = second.unwrap().get_property("count").unwrap().clone();

        // RepeatableRead pins the snapshot: both reads must match.
        assert_eq!(value1, value2);

        tx1.commit().unwrap();
    });

    // Writer: commits count = 100 while the reader is paused.
    thread::sleep(std::time::Duration::from_millis(10));
    let tx2 = manager.begin(IsolationLevel::ReadCommitted);
    let updated = NodeBuilder::new()
        .id("counter_node")
        .label("Counter")
        .property("count", 100i64)
        .build();
    tx2.write_node(updated);
    tx2.commit().unwrap();

    reader.join().unwrap();
}
|
||||
|
||||
#[test]
fn test_phantom_read_prevention() {
    // Under Serializable isolation, repeating the same range scan inside
    // one transaction must not pick up rows inserted concurrently.
    //
    // Redundant function-local `use` lines removed (file-level imports
    // cover Arc and thread); the two hand-rolled counting loops are
    // replaced with iterator `.count()`.
    let manager = Arc::new(TransactionManager::new());

    // Seed three product nodes: node_0 .. node_2.
    for i in 0..3 {
        let node = NodeBuilder::new()
            .id(format!("node_{}", i))
            .label("Product")
            .property("price", 50i64)
            .build();
        let tx = manager.begin(IsolationLevel::Serializable);
        tx.write_node(node);
        tx.commit().unwrap();
    }

    // Scanner: counts visible nodes twice around a concurrent insert.
    let scanner_manager = Arc::clone(&manager);
    let scanner = thread::spawn(move || {
        let tx1 = scanner_manager.begin(IsolationLevel::Serializable);

        // First scan over the candidate id range.
        let count1 = (0..5)
            .filter(|i| tx1.read_node(&format!("node_{}", i)).is_some())
            .count();

        // Give the inserter time to commit node_3.
        thread::sleep(std::time::Duration::from_millis(50));

        // Second scan — must see exactly the same rows (no phantoms).
        let count2 = (0..5)
            .filter(|i| tx1.read_node(&format!("node_{}", i)).is_some())
            .count();

        assert_eq!(count1, count2);

        tx1.commit().unwrap();
        count1
    });

    // Concurrent insert of a fourth product while the scanner is paused.
    thread::sleep(std::time::Duration::from_millis(10));
    let tx2 = manager.begin(IsolationLevel::Serializable);
    let new_node = NodeBuilder::new()
        .id("node_3")
        .label("Product")
        .property("price", 50i64)
        .build();
    tx2.write_node(new_node);
    tx2.commit().unwrap();

    let original_count = scanner.join().unwrap();
    assert_eq!(original_count, 3); // only the original 3 rows were visible
}
|
||||
|
||||
// ============================================================================
|
||||
// Deadlock Detection and Prevention
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_deadlock_detection() {
    // Two transactions touch node_a and node_b in opposite orders — the
    // classic deadlock shape. With MVCC both should still complete; a
    // lock-based engine would have to abort one of them.
    //
    // Redundant function-local `use std::sync::Arc; use std::thread;`
    // removed — both are file-level imports.
    let manager = Arc::new(TransactionManager::new());

    // Seed the two contended resources.
    let node_a = NodeBuilder::new()
        .id("node_a")
        .label("Resource")
        .property("value", 100i64)
        .build();
    let node_b = NodeBuilder::new()
        .id("node_b")
        .label("Resource")
        .property("value", 200i64)
        .build();
    let tx_init = manager.begin(IsolationLevel::Serializable);
    tx_init.write_node(node_a);
    tx_init.write_node(node_b);
    tx_init.commit().unwrap();

    // Transaction 1: modify A first, then touch B.
    let manager_a = Arc::clone(&manager);
    let worker = thread::spawn(move || {
        let tx1 = manager_a.begin(IsolationLevel::Serializable);

        let mut node = tx1.read_node(&"node_a".to_string()).unwrap();
        node.set_property("value", PropertyValue::Integer(150));
        tx1.write_node(node);

        thread::sleep(std::time::Duration::from_millis(50));

        // Under two-phase locking this read would block on B.
        match tx1.read_node(&"node_b".to_string()) {
            Some(_) => {
                tx1.commit().ok();
            }
            None => {
                tx1.rollback().ok();
            }
        }
    });

    // Transaction 2: modify B first, then touch A (reverse order).
    thread::sleep(std::time::Duration::from_millis(10));
    let tx2 = manager.begin(IsolationLevel::Serializable);

    let mut node = tx2.read_node(&"node_b".to_string()).unwrap();
    node.set_property("value", PropertyValue::Integer(250));
    tx2.write_node(node);

    thread::sleep(std::time::Duration::from_millis(50));

    // Under two-phase locking this would block on A — the deadlock point.
    let _node_a = tx2.read_node(&"node_a".to_string());

    // No deadlock detector yet: just verify both sides can finish.
    tx2.commit().ok();

    worker.join().unwrap();
}
|
||||
|
||||
#[test]
fn test_deadlock_timeout() {
    // TODO: once lock acquisition is implemented, verify that a
    // transaction waiting on an unavailable lock eventually times out.
}
|
||||
|
||||
// ============================================================================
|
||||
// Multi-Version Concurrency Control (MVCC) Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_mvcc_snapshot_isolation() {
    // A RepeatableRead transaction takes a snapshot; concurrently
    // committed updates must not change what that snapshot sums to.
    //
    // Fixes vs. original: removed the redundant function-local
    // `use std::sync::Arc; use std::thread;` (file-level imports cover
    // them) and deduplicated the copy-pasted balance-summing pipeline
    // into a single closure.
    let manager = Arc::new(TransactionManager::new());

    // Seed five accounts with 1000 each (total 5000).
    for i in 0..5 {
        let node = NodeBuilder::new()
            .id(format!("account_{}", i))
            .label("Account")
            .property("balance", 1000i64)
            .build();
        let tx = manager.begin(IsolationLevel::RepeatableRead);
        tx.write_node(node);
        tx.commit().unwrap();
    }

    // Long-running reader pinned to its snapshot.
    let reader_manager = Arc::clone(&manager);
    let reader = thread::spawn(move || {
        let tx1 = reader_manager.begin(IsolationLevel::RepeatableRead);

        // Sums all five balances as seen through tx1's snapshot.
        let balance_sum = || -> i64 {
            (0..5)
                .filter_map(|i| tx1.read_node(&format!("account_{}", i)))
                .filter_map(|node| match node.get_property("balance") {
                    Some(PropertyValue::Integer(balance)) => Some(*balance),
                    _ => None,
                })
                .sum()
        };

        let before = balance_sum();

        // Concurrent writers commit new balances during this pause.
        thread::sleep(std::time::Duration::from_millis(100));

        let after = balance_sum();

        // Snapshot isolation: both sums match the pre-update total.
        assert_eq!(before, after);
        assert_eq!(before, 5000); // original total

        tx1.commit().unwrap();
    });

    // Five writers each bump one account to 2000 while the reader sleeps.
    thread::sleep(std::time::Duration::from_millis(10));
    let writers: Vec<_> = (0..5)
        .map(|i| {
            let manager = Arc::clone(&manager);
            thread::spawn(move || {
                let tx = manager.begin(IsolationLevel::ReadCommitted);
                let node = NodeBuilder::new()
                    .id(format!("account_{}", i))
                    .label("Account")
                    .property("balance", 2000i64)
                    .build();
                tx.write_node(node);
                tx.commit().unwrap();
            })
        })
        .collect();

    for writer in writers {
        writer.join().unwrap();
    }

    reader.join().unwrap();
}
|
||||
|
||||
#[test]
fn test_mvcc_concurrent_reads_and_writes() {
    // TODO: verify that under MVCC readers and writers proceed without
    // blocking each other.
}

// ============================================================================
// Write Skew Tests
// ============================================================================

#[test]
fn test_write_skew_detection() {
    // TODO: classic write-skew scenario — two transactions read
    // overlapping data, decide based on what they read, and together
    // commit an inconsistent state. Serializable isolation must reject it.
}

// ============================================================================
// Long-Running Transaction Tests
// ============================================================================

#[test]
fn test_long_running_transaction_timeout() {
    // TODO: verify that long-running transactions can be configured to
    // time out.
}

#[test]
fn test_transaction_progress_tracking() {
    // TODO: verify that the progress of long-running transactions can be
    // tracked.
}
|
||||
|
||||
// ============================================================================
|
||||
// Savepoint Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_transaction_savepoint() {
    // Savepoints are not implemented yet, so this emulates one with two
    // sequential transactions: commit the first ("before savepoint"),
    // roll back the second ("after savepoint"), then verify only the
    // first write survived.
    let manager = TransactionManager::new();

    // Work before the simulated savepoint — committed.
    let tx = manager.begin(IsolationLevel::Serializable);
    let before = NodeBuilder::new()
        .id("before_savepoint")
        .label("Test")
        .property("value", 1i64)
        .build();
    tx.write_node(before);
    tx.commit().unwrap();

    // Work after the simulated savepoint — rolled back.
    let tx2 = manager.begin(IsolationLevel::Serializable);
    let after = NodeBuilder::new()
        .id("after_savepoint")
        .label("Test")
        .property("value", 2i64)
        .build();
    tx2.write_node(after);
    tx2.rollback().unwrap();

    // Only the committed node is visible to a fresh transaction.
    let tx3 = manager.begin(IsolationLevel::ReadCommitted);
    assert!(tx3.read_node(&"before_savepoint".to_string()).is_some());
    assert!(tx3.read_node(&"after_savepoint".to_string()).is_none());
    tx3.commit().unwrap();
}
|
||||
|
||||
#[test]
fn test_nested_savepoints() {
    // TODO: create nested savepoints and roll back to different levels.
}
|
||||
|
||||
// ============================================================================
|
||||
// Consistency Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_referential_integrity() {
    // An edge may only be created when both endpoints exist; an edge to a
    // missing target must be rejected without leaving a partial edge.
    let db = GraphDB::new();

    let alice = NodeBuilder::new()
        .id("existing_node")
        .label("Person")
        .property("name", "Alice")
        .build();
    db.create_node(alice).unwrap();

    // Edge to a node that was never created — must fail.
    let dangling = EdgeBuilder::new(
        "existing_node".to_string(),
        "non_existent_node".to_string(),
        "KNOWS",
    )
    .build();
    assert!(db.create_edge(dangling).is_err());

    // Nothing was persisted for the rejected edge.
    assert_eq!(db.edge_count(), 0);

    // With both endpoints present the same edge shape succeeds.
    let bob = NodeBuilder::new()
        .id("existing_node_2")
        .label("Person")
        .property("name", "Bob")
        .build();
    db.create_node(bob).unwrap();

    let valid = EdgeBuilder::new(
        "existing_node".to_string(),
        "existing_node_2".to_string(),
        "KNOWS",
    )
    .build();
    assert!(db.create_edge(valid).is_ok());
    assert_eq!(db.edge_count(), 1);
}
|
||||
|
||||
#[test]
fn test_unique_constraint_enforcement() {
    // TODO: verify that unique constraints are enforced within
    // transactions.
}

#[test]
fn test_index_consistency() {
    // TODO: verify that indexes remain consistent after transaction
    // commit and rollback.
}
|
||||
|
||||
// ============================================================================
|
||||
// Durability Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_write_ahead_log() {
    // Writes staged in a transaction's write set must stay invisible to
    // other transactions until commit, then become visible afterwards —
    // the observable half of the write-ahead-log contract.
    let manager = TransactionManager::new();

    let tx = manager.begin(IsolationLevel::Serializable);

    let account_a = NodeBuilder::new()
        .id("wal_node_1")
        .label("Account")
        .property("balance", 1000i64)
        .build();
    let account_b = NodeBuilder::new()
        .id("wal_node_2")
        .label("Account")
        .property("balance", 2000i64)
        .build();

    // Staged (logged) but not yet applied.
    tx.write_node(account_a);
    tx.write_node(account_b);

    // A concurrent reader must not see the buffered writes.
    let tx_reader = manager.begin(IsolationLevel::ReadCommitted);
    assert!(tx_reader.read_node(&"wal_node_1".to_string()).is_none());
    assert!(tx_reader.read_node(&"wal_node_2".to_string()).is_none());
    tx_reader.commit().unwrap();

    // Commit applies the logged changes...
    tx.commit().unwrap();

    // ...after which they are visible to new transactions.
    let tx_verify = manager.begin(IsolationLevel::ReadCommitted);
    assert!(tx_verify.read_node(&"wal_node_1".to_string()).is_some());
    assert!(tx_verify.read_node(&"wal_node_2".to_string()).is_some());
    tx_verify.commit().unwrap();
}
|
||||
|
||||
#[test]
fn test_crash_recovery() {
    // TODO: simulate a crash and verify committed transactions survive.
}

#[test]
fn test_checkpoint_mechanism() {
    // TODO: verify that checkpointing preserves durability guarantees.
}
|
||||
|
||||
// ============================================================================
|
||||
// Transaction Isolation Anomaly Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_lost_update_prevention() {
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    let manager = Arc::new(TransactionManager::new());

    // Seed a counter node at value 0 so both writers start from a known state.
    let node = NodeBuilder::new()
        .id("counter")
        .label("Counter")
        .property("value", 0i64)
        .build();

    let tx_init = manager.begin(IsolationLevel::Serializable);
    tx_init.write_node(node);
    tx_init.commit().unwrap();

    // Spawn a transaction that performs a read-modify-write increment of the
    // counter. The 50ms pause between read and write keeps the window open so
    // the two increments overlap, which is exactly the lost-update scenario.
    //
    // Under serializable isolation, one of two conflicting commits is ALLOWED
    // (and expected) to abort — aborting is the correct way to prevent the
    // lost update. So a failed commit is reported, not unwrapped: panicking
    // here would make the test fail precisely when isolation works.
    let spawn_increment = |manager: Arc<TransactionManager>, start_delay_ms: u64| {
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(start_delay_ms));

            let tx = manager.begin(IsolationLevel::Serializable);

            // Read the current value (0 if the property is missing/non-integer).
            let node = tx.read_node(&"counter".to_string()).unwrap();
            let current = match node.get_property("value") {
                Some(PropertyValue::Integer(v)) => *v,
                _ => 0,
            };

            // Hold the stale read long enough for the other writer to overlap.
            thread::sleep(Duration::from_millis(50));

            // Increment and write back.
            let mut updated = node.clone();
            updated.set_property("value", PropertyValue::Integer(current + 1));
            tx.write_node(updated);

            // true = commit succeeded, false = aborted by conflict detection.
            tx.commit().is_ok()
        })
    };

    // Same staggering as before: second writer starts 10ms after the first.
    let handle1 = spawn_increment(Arc::clone(&manager), 0);
    let handle2 = spawn_increment(Arc::clone(&manager), 10);

    let outcomes = [handle1.join().unwrap(), handle2.join().unwrap()];
    let successes = outcomes.iter().filter(|&&committed| committed).count();
    assert!(successes >= 1, "at least one increment must commit");

    // Verify the final counter value. If the engine serializes both commits,
    // the value is 2; if it aborts one (or uses last-writer-wins MVCC), the
    // value is 1. Either way at least one increment must be visible — a
    // final value of 0 would mean an update was lost outright.
    let tx_verify = manager.begin(IsolationLevel::ReadCommitted);
    let final_node = tx_verify.read_node(&"counter".to_string()).unwrap();
    let final_value = match final_node.get_property("value") {
        Some(PropertyValue::Integer(v)) => *v,
        _ => 0,
    };
    assert!(final_value >= 1);
    tx_verify.commit().unwrap();
}
|
||||
|
||||
#[test]
fn test_read_skew_prevention() {
    // TODO: Implement
    // Intended coverage: a transaction reads two related values (e.g. two
    // account balances with an invariant over their sum) at different times
    // while a concurrent writer updates both in between. Under snapshot /
    // serializable isolation the reader must see a consistent pair; under
    // ReadCommitted the skew may be observable — assert per isolation level.
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_transaction_throughput() {
    // TODO: Implement
    // Intended coverage: measure how many small begin/write/commit cycles
    // complete per second and assert a generous lower bound so regressions
    // are caught without making the test flaky on slow CI machines.
    // NOTE(review): consider `#[ignore]` or a criterion benchmark instead if
    // timing-based assertions prove unstable in CI.
}
|
||||
|
||||
#[test]
fn test_lock_contention_handling() {
    // TODO: Implement
    // Intended coverage: many threads hammering the same node(s) so lock /
    // conflict contention is high; verify the manager degrades gracefully
    // (aborts or retries) rather than deadlocking or panicking, and that
    // every thread eventually finishes.
}
|
||||
Reference in New Issue
Block a user