Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
480
vendor/ruvector/crates/ruvector-cluster/src/consensus.rs
vendored
Normal file
480
vendor/ruvector/crates/ruvector-cluster/src/consensus.rs
vendored
Normal file
@@ -0,0 +1,480 @@
|
||||
//! DAG-based consensus protocol inspired by QuDAG
|
||||
//!
|
||||
//! Implements a directed acyclic graph for transaction ordering and consensus.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{ClusterError, Result};
|
||||
|
||||
/// A vertex in the consensus DAG.
///
/// Each vertex wraps one transaction and links back to earlier vertices,
/// forming the directed acyclic graph used for ordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagVertex {
    /// Unique vertex ID (UUID v4, assigned at creation)
    pub id: String,
    /// Node that created this vertex
    pub node_id: String,
    /// Transaction data
    pub transaction: Transaction,
    /// Parent vertices (edges in the DAG); empty for genesis vertices
    pub parents: Vec<String>,
    /// Timestamp when vertex was created
    pub timestamp: DateTime<Utc>,
    /// Vector clock for causality tracking (node_id -> event count)
    pub vector_clock: HashMap<String, u64>,
    /// Signature (in production, this would be cryptographic)
    pub signature: String,
}
|
||||
|
||||
impl DagVertex {
|
||||
/// Create a new DAG vertex
|
||||
pub fn new(
|
||||
node_id: String,
|
||||
transaction: Transaction,
|
||||
parents: Vec<String>,
|
||||
vector_clock: HashMap<String, u64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
node_id,
|
||||
transaction,
|
||||
parents,
|
||||
timestamp: Utc::now(),
|
||||
vector_clock,
|
||||
signature: String::new(), // Would be computed cryptographically
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify the vertex signature
|
||||
pub fn verify_signature(&self) -> bool {
|
||||
// In production, verify cryptographic signature
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// A transaction in the consensus system
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Transaction {
    /// Transaction ID (UUID v4)
    pub id: String,
    /// Transaction type
    pub tx_type: TransactionType,
    /// Opaque transaction payload
    pub data: Vec<u8>,
    /// Monotonic nonce assigned by the submitting node, used for ordering
    pub nonce: u64,
}
|
||||
|
||||
/// Type of transaction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionType {
    /// Write operation
    Write,
    /// Read operation
    Read,
    /// Delete operation
    Delete,
    /// Batch operation
    Batch,
    /// System operation
    System,
}
|
||||
|
||||
/// DAG-based consensus engine
///
/// Shared across threads via `Arc`-wrapped interior state: vertices live in
/// a lock-free `DashMap`, while the smaller bookkeeping structures use
/// `parking_lot::RwLock`.
pub struct DagConsensus {
    /// Node ID of the local node
    node_id: String,
    /// DAG vertices (vertex_id -> vertex)
    vertices: Arc<DashMap<String, DagVertex>>,
    /// IDs of finalized vertices
    finalized: Arc<RwLock<HashSet<String>>>,
    /// Vector clock for this node (node_id -> event count)
    vector_clock: Arc<RwLock<HashMap<String, u64>>>,
    /// Pending transactions awaiting vertex creation (FIFO)
    pending_txs: Arc<RwLock<VecDeque<Transaction>>>,
    /// Minimum number of distinct confirming nodes required to finalize
    min_quorum_size: usize,
    /// Transaction nonce counter (incremented per submission)
    nonce_counter: Arc<RwLock<u64>>,
}
|
||||
|
||||
impl DagConsensus {
|
||||
/// Create a new DAG consensus engine
|
||||
pub fn new(node_id: String, min_quorum_size: usize) -> Self {
|
||||
let mut vector_clock = HashMap::new();
|
||||
vector_clock.insert(node_id.clone(), 0);
|
||||
|
||||
Self {
|
||||
node_id,
|
||||
vertices: Arc::new(DashMap::new()),
|
||||
finalized: Arc::new(RwLock::new(HashSet::new())),
|
||||
vector_clock: Arc::new(RwLock::new(vector_clock)),
|
||||
pending_txs: Arc::new(RwLock::new(VecDeque::new())),
|
||||
min_quorum_size,
|
||||
nonce_counter: Arc::new(RwLock::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Submit a transaction to the consensus system
|
||||
pub fn submit_transaction(&self, tx_type: TransactionType, data: Vec<u8>) -> Result<String> {
|
||||
let mut nonce = self.nonce_counter.write();
|
||||
*nonce += 1;
|
||||
|
||||
let transaction = Transaction {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
tx_type,
|
||||
data,
|
||||
nonce: *nonce,
|
||||
};
|
||||
|
||||
let tx_id = transaction.id.clone();
|
||||
|
||||
let mut pending = self.pending_txs.write();
|
||||
pending.push_back(transaction);
|
||||
|
||||
debug!("Transaction {} submitted to consensus", tx_id);
|
||||
Ok(tx_id)
|
||||
}
|
||||
|
||||
/// Create a new vertex for pending transactions
|
||||
pub fn create_vertex(&self) -> Result<Option<DagVertex>> {
|
||||
let mut pending = self.pending_txs.write();
|
||||
|
||||
if pending.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Take the next transaction
|
||||
let transaction = pending.pop_front().unwrap();
|
||||
|
||||
// Find parent vertices (tips of the DAG)
|
||||
let parents = self.find_tips();
|
||||
|
||||
// Update vector clock
|
||||
let mut clock = self.vector_clock.write();
|
||||
let count = clock.entry(self.node_id.clone()).or_insert(0);
|
||||
*count += 1;
|
||||
|
||||
let vertex = DagVertex::new(self.node_id.clone(), transaction, parents, clock.clone());
|
||||
|
||||
let vertex_id = vertex.id.clone();
|
||||
self.vertices.insert(vertex_id.clone(), vertex.clone());
|
||||
|
||||
debug!(
|
||||
"Created vertex {} for transaction {}",
|
||||
vertex_id, vertex.transaction.id
|
||||
);
|
||||
Ok(Some(vertex))
|
||||
}
|
||||
|
||||
/// Find tip vertices (vertices with no children)
|
||||
fn find_tips(&self) -> Vec<String> {
|
||||
let mut has_children = HashSet::new();
|
||||
|
||||
// Mark all vertices that have children
|
||||
for entry in self.vertices.iter() {
|
||||
for parent in &entry.value().parents {
|
||||
has_children.insert(parent.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Find vertices without children
|
||||
self.vertices
|
||||
.iter()
|
||||
.filter(|entry| !has_children.contains(entry.key()))
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Add a vertex from another node
|
||||
pub fn add_vertex(&self, vertex: DagVertex) -> Result<()> {
|
||||
// Verify signature
|
||||
if !vertex.verify_signature() {
|
||||
return Err(ClusterError::ConsensusError(
|
||||
"Invalid vertex signature".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Verify parents exist
|
||||
for parent_id in &vertex.parents {
|
||||
if !self.vertices.contains_key(parent_id) && !self.is_finalized(parent_id) {
|
||||
return Err(ClusterError::ConsensusError(format!(
|
||||
"Parent vertex {} not found",
|
||||
parent_id
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Merge vector clock
|
||||
let mut clock = self.vector_clock.write();
|
||||
for (node, count) in &vertex.vector_clock {
|
||||
let existing = clock.entry(node.clone()).or_insert(0);
|
||||
*existing = (*existing).max(*count);
|
||||
}
|
||||
|
||||
self.vertices.insert(vertex.id.clone(), vertex);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a vertex is finalized
|
||||
pub fn is_finalized(&self, vertex_id: &str) -> bool {
|
||||
let finalized = self.finalized.read();
|
||||
finalized.contains(vertex_id)
|
||||
}
|
||||
|
||||
/// Finalize vertices using the wave algorithm
|
||||
pub fn finalize_vertices(&self) -> Result<Vec<String>> {
|
||||
let mut finalized_ids = Vec::new();
|
||||
|
||||
// Find vertices that can be finalized
|
||||
// A vertex is finalized if it has enough confirmations from different nodes
|
||||
let mut confirmations: HashMap<String, HashSet<String>> = HashMap::new();
|
||||
|
||||
for entry in self.vertices.iter() {
|
||||
let vertex = entry.value();
|
||||
|
||||
// Count confirmations (vertices that reference this one)
|
||||
for other_entry in self.vertices.iter() {
|
||||
if other_entry.value().parents.contains(&vertex.id) {
|
||||
confirmations
|
||||
.entry(vertex.id.clone())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(other_entry.value().node_id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize vertices with enough confirmations
|
||||
let mut finalized = self.finalized.write();
|
||||
|
||||
for (vertex_id, confirming_nodes) in confirmations {
|
||||
if confirming_nodes.len() >= self.min_quorum_size && !finalized.contains(&vertex_id) {
|
||||
finalized.insert(vertex_id.clone());
|
||||
finalized_ids.push(vertex_id.clone());
|
||||
info!("Finalized vertex {}", vertex_id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(finalized_ids)
|
||||
}
|
||||
|
||||
/// Get the total order of finalized transactions
|
||||
pub fn get_finalized_order(&self) -> Vec<Transaction> {
|
||||
let finalized = self.finalized.read();
|
||||
let mut ordered_txs = Vec::new();
|
||||
|
||||
// Topological sort of finalized vertices
|
||||
let finalized_vertices: Vec<_> = self
|
||||
.vertices
|
||||
.iter()
|
||||
.filter(|entry| finalized.contains(entry.key()))
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect();
|
||||
|
||||
// Sort by vector clock and timestamp
|
||||
let mut sorted = finalized_vertices;
|
||||
sorted.sort_by(|a, b| {
|
||||
// First by vector clock dominance
|
||||
let a_dominates = Self::vector_clock_dominates(&a.vector_clock, &b.vector_clock);
|
||||
let b_dominates = Self::vector_clock_dominates(&b.vector_clock, &a.vector_clock);
|
||||
|
||||
if a_dominates && !b_dominates {
|
||||
std::cmp::Ordering::Less
|
||||
} else if b_dominates && !a_dominates {
|
||||
std::cmp::Ordering::Greater
|
||||
} else {
|
||||
// Fall back to timestamp
|
||||
a.timestamp.cmp(&b.timestamp)
|
||||
}
|
||||
});
|
||||
|
||||
for vertex in sorted {
|
||||
ordered_txs.push(vertex.transaction);
|
||||
}
|
||||
|
||||
ordered_txs
|
||||
}
|
||||
|
||||
/// Check if vector clock a dominates vector clock b
|
||||
fn vector_clock_dominates(a: &HashMap<String, u64>, b: &HashMap<String, u64>) -> bool {
|
||||
let mut dominates = false;
|
||||
|
||||
for (node, &a_count) in a {
|
||||
let b_count = b.get(node).copied().unwrap_or(0);
|
||||
if a_count < b_count {
|
||||
return false;
|
||||
}
|
||||
if a_count > b_count {
|
||||
dominates = true;
|
||||
}
|
||||
}
|
||||
|
||||
dominates
|
||||
}
|
||||
|
||||
/// Detect conflicts between transactions
|
||||
pub fn detect_conflicts(&self, tx1: &Transaction, tx2: &Transaction) -> bool {
|
||||
// In a real implementation, this would analyze transaction data
|
||||
// For now, conservatively assume all writes conflict
|
||||
matches!(
|
||||
(&tx1.tx_type, &tx2.tx_type),
|
||||
(TransactionType::Write, TransactionType::Write)
|
||||
| (TransactionType::Delete, TransactionType::Write)
|
||||
| (TransactionType::Write, TransactionType::Delete)
|
||||
)
|
||||
}
|
||||
|
||||
/// Get consensus statistics
|
||||
pub fn get_stats(&self) -> ConsensusStats {
|
||||
let finalized = self.finalized.read();
|
||||
let pending = self.pending_txs.read();
|
||||
|
||||
ConsensusStats {
|
||||
total_vertices: self.vertices.len(),
|
||||
finalized_vertices: finalized.len(),
|
||||
pending_transactions: pending.len(),
|
||||
tips: self.find_tips().len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Prune old finalized vertices to save memory
|
||||
pub fn prune_old_vertices(&self, keep_count: usize) {
|
||||
let finalized = self.finalized.read();
|
||||
|
||||
if finalized.len() <= keep_count {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove oldest finalized vertices
|
||||
let mut vertices_to_remove = Vec::new();
|
||||
|
||||
for vertex_id in finalized.iter() {
|
||||
if let Some(vertex) = self.vertices.get(vertex_id) {
|
||||
vertices_to_remove.push((vertex_id.clone(), vertex.timestamp));
|
||||
}
|
||||
}
|
||||
|
||||
vertices_to_remove.sort_by_key(|(_, ts)| *ts);
|
||||
|
||||
let to_remove = vertices_to_remove.len().saturating_sub(keep_count);
|
||||
for (vertex_id, _) in vertices_to_remove.iter().take(to_remove) {
|
||||
self.vertices.remove(vertex_id);
|
||||
}
|
||||
|
||||
debug!("Pruned {} old vertices", to_remove);
|
||||
}
|
||||
}
|
||||
|
||||
/// Consensus statistics snapshot (see `DagConsensus::get_stats`)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsensusStats {
    /// Total vertices currently held in the DAG
    pub total_vertices: usize,
    /// Number of finalized vertex IDs tracked
    pub finalized_vertices: usize,
    /// Transactions queued but not yet wrapped in a vertex
    pub pending_transactions: usize,
    /// Current DAG tips (vertices with no children)
    pub tips: usize,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created engine has an empty DAG and empty queue.
    #[test]
    fn test_consensus_creation() {
        let consensus = DagConsensus::new("node1".to_string(), 2);
        let stats = consensus.get_stats();

        assert_eq!(stats.total_vertices, 0);
        assert_eq!(stats.pending_transactions, 0);
    }

    // Submitting a transaction returns an ID and enqueues it.
    #[test]
    fn test_submit_transaction() {
        let consensus = DagConsensus::new("node1".to_string(), 2);

        let tx_id = consensus
            .submit_transaction(TransactionType::Write, vec![1, 2, 3])
            .unwrap();

        assert!(!tx_id.is_empty());

        let stats = consensus.get_stats();
        assert_eq!(stats.pending_transactions, 1);
    }

    // create_vertex consumes one pending transaction and adds one vertex.
    #[test]
    fn test_create_vertex() {
        let consensus = DagConsensus::new("node1".to_string(), 2);

        consensus
            .submit_transaction(TransactionType::Write, vec![1, 2, 3])
            .unwrap();

        let vertex = consensus.create_vertex().unwrap();
        assert!(vertex.is_some());

        let stats = consensus.get_stats();
        assert_eq!(stats.total_vertices, 1);
        assert_eq!(stats.pending_transactions, 0);
    }

    // clock1 >= clock2 componentwise with one strict inequality, so clock1
    // dominates clock2 but not vice versa.
    #[test]
    fn test_vector_clock_dominance() {
        let mut clock1 = HashMap::new();
        clock1.insert("node1".to_string(), 2);
        clock1.insert("node2".to_string(), 1);

        let mut clock2 = HashMap::new();
        clock2.insert("node1".to_string(), 1);
        clock2.insert("node2".to_string(), 1);

        assert!(DagConsensus::vector_clock_dominates(&clock1, &clock2));
        assert!(!DagConsensus::vector_clock_dominates(&clock2, &clock1));
    }

    // Two Write transactions are conservatively treated as conflicting.
    #[test]
    fn test_conflict_detection() {
        let consensus = DagConsensus::new("node1".to_string(), 2);

        let tx1 = Transaction {
            id: "1".to_string(),
            tx_type: TransactionType::Write,
            data: vec![1],
            nonce: 1,
        };

        let tx2 = Transaction {
            id: "2".to_string(),
            tx_type: TransactionType::Write,
            data: vec![2],
            nonce: 2,
        };

        assert!(consensus.detect_conflicts(&tx1, &tx2));
    }

    // All vertices originate from a single node, so the quorum of 2 distinct
    // confirming nodes is never reached and nothing finalizes.
    #[test]
    fn test_finalization() {
        let consensus = DagConsensus::new("node1".to_string(), 2);

        // Create some vertices
        for i in 0..5 {
            consensus
                .submit_transaction(TransactionType::Write, vec![i])
                .unwrap();
            consensus.create_vertex().unwrap();
        }

        // Try to finalize
        let finalized = consensus.finalize_vertices().unwrap();

        // Without enough confirmations, nothing should be finalized yet
        // (would need vertices from other nodes)
        assert_eq!(finalized.len(), 0);
    }
}
|
||||
383
vendor/ruvector/crates/ruvector-cluster/src/discovery.rs
vendored
Normal file
383
vendor/ruvector/crates/ruvector-cluster/src/discovery.rs
vendored
Normal file
@@ -0,0 +1,383 @@
|
||||
//! Node discovery mechanisms for cluster formation
|
||||
//!
|
||||
//! Supports static configuration and gossip-based discovery.
|
||||
|
||||
use crate::{ClusterError, ClusterNode, NodeStatus, Result};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Service for discovering nodes in the cluster.
///
/// Implementations must be thread-safe (`Send + Sync`); they are typically
/// boxed and shared by the cluster manager.
#[async_trait]
pub trait DiscoveryService: Send + Sync {
    /// Discover all currently known nodes in the cluster
    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>>;

    /// Register this node in the discovery service
    async fn register_node(&self, node: ClusterNode) -> Result<()>;

    /// Unregister this node from the discovery service
    async fn unregister_node(&self, node_id: &str) -> Result<()>;

    /// Update the node's heartbeat (refresh its last-seen timestamp)
    async fn heartbeat(&self, node_id: &str) -> Result<()>;
}
|
||||
|
||||
/// Static discovery using a predefined node list (no network activity)
pub struct StaticDiscovery {
    /// Predefined list of nodes, guarded for concurrent add/remove
    nodes: Arc<RwLock<Vec<ClusterNode>>>,
}
|
||||
|
||||
impl StaticDiscovery {
|
||||
/// Create a new static discovery service
|
||||
pub fn new(nodes: Vec<ClusterNode>) -> Self {
|
||||
Self {
|
||||
nodes: Arc::new(RwLock::new(nodes)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to the static list
|
||||
pub fn add_node(&self, node: ClusterNode) {
|
||||
let mut nodes = self.nodes.write();
|
||||
nodes.push(node);
|
||||
}
|
||||
|
||||
/// Remove a node from the static list
|
||||
pub fn remove_node(&self, node_id: &str) {
|
||||
let mut nodes = self.nodes.write();
|
||||
nodes.retain(|n| n.node_id != node_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DiscoveryService for StaticDiscovery {
|
||||
async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
|
||||
let nodes = self.nodes.read();
|
||||
Ok(nodes.clone())
|
||||
}
|
||||
|
||||
async fn register_node(&self, node: ClusterNode) -> Result<()> {
|
||||
self.add_node(node);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unregister_node(&self, node_id: &str) -> Result<()> {
|
||||
self.remove_node(node_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn heartbeat(&self, node_id: &str) -> Result<()> {
|
||||
let mut nodes = self.nodes.write();
|
||||
if let Some(node) = nodes.iter_mut().find(|n| n.node_id == node_id) {
|
||||
node.heartbeat();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Gossip-based discovery protocol
pub struct GossipDiscovery {
    /// Local node information
    local_node: Arc<RwLock<ClusterNode>>,
    /// Known nodes (node_id -> node), including the local node
    nodes: Arc<DashMap<String, ClusterNode>>,
    /// Seed nodes used to bootstrap the gossip exchange
    seed_nodes: Vec<SocketAddr>,
    /// Interval between gossip rounds
    gossip_interval: Duration,
    /// How long a node may go unseen before it is considered failed
    node_timeout: Duration,
}
|
||||
|
||||
impl GossipDiscovery {
|
||||
/// Create a new gossip discovery service
|
||||
pub fn new(
|
||||
local_node: ClusterNode,
|
||||
seed_nodes: Vec<SocketAddr>,
|
||||
gossip_interval: Duration,
|
||||
node_timeout: Duration,
|
||||
) -> Self {
|
||||
let nodes = Arc::new(DashMap::new());
|
||||
nodes.insert(local_node.node_id.clone(), local_node.clone());
|
||||
|
||||
Self {
|
||||
local_node: Arc::new(RwLock::new(local_node)),
|
||||
nodes,
|
||||
seed_nodes,
|
||||
gossip_interval,
|
||||
node_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the gossip protocol
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
info!("Starting gossip discovery protocol");
|
||||
|
||||
// Bootstrap from seed nodes
|
||||
self.bootstrap().await?;
|
||||
|
||||
// Start periodic gossip
|
||||
let nodes = Arc::clone(&self.nodes);
|
||||
let gossip_interval = self.gossip_interval;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut interval = time::interval(gossip_interval);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
Self::gossip_round(&nodes).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bootstrap by contacting seed nodes
|
||||
async fn bootstrap(&self) -> Result<()> {
|
||||
debug!("Bootstrapping from {} seed nodes", self.seed_nodes.len());
|
||||
|
||||
for seed_addr in &self.seed_nodes {
|
||||
// In a real implementation, this would contact the seed node
|
||||
// For now, we'll simulate it
|
||||
debug!("Contacting seed node at {}", seed_addr);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Perform a gossip round
|
||||
async fn gossip_round(nodes: &Arc<DashMap<String, ClusterNode>>) {
|
||||
// Select random subset of nodes to gossip with
|
||||
let node_list: Vec<_> = nodes.iter().map(|e| e.value().clone()).collect();
|
||||
|
||||
if node_list.len() < 2 {
|
||||
return;
|
||||
}
|
||||
|
||||
debug!("Gossiping with {} nodes", node_list.len());
|
||||
|
||||
// In a real implementation, we would:
|
||||
// 1. Select random peers
|
||||
// 2. Exchange node lists
|
||||
// 3. Merge received information
|
||||
// 4. Detect failures
|
||||
}
|
||||
|
||||
/// Merge gossip information from another node
|
||||
pub fn merge_gossip(&self, remote_nodes: Vec<ClusterNode>) {
|
||||
for node in remote_nodes {
|
||||
if let Some(mut existing) = self.nodes.get_mut(&node.node_id) {
|
||||
// Update if remote has newer information
|
||||
if node.last_seen > existing.last_seen {
|
||||
*existing = node;
|
||||
}
|
||||
} else {
|
||||
// Add new node
|
||||
self.nodes.insert(node.node_id.clone(), node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove failed nodes
|
||||
pub fn prune_failed_nodes(&self) {
|
||||
let now = Utc::now();
|
||||
self.nodes.retain(|_, node| {
|
||||
let elapsed = now
|
||||
.signed_duration_since(node.last_seen)
|
||||
.to_std()
|
||||
.unwrap_or(Duration::MAX);
|
||||
elapsed < self.node_timeout
|
||||
});
|
||||
}
|
||||
|
||||
/// Get gossip statistics
|
||||
pub fn get_stats(&self) -> GossipStats {
|
||||
let nodes: Vec<_> = self.nodes.iter().map(|e| e.value().clone()).collect();
|
||||
let healthy = nodes
|
||||
.iter()
|
||||
.filter(|n| n.is_healthy(self.node_timeout))
|
||||
.count();
|
||||
|
||||
GossipStats {
|
||||
total_nodes: nodes.len(),
|
||||
healthy_nodes: healthy,
|
||||
seed_nodes: self.seed_nodes.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DiscoveryService for GossipDiscovery {
|
||||
async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
|
||||
Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
|
||||
}
|
||||
|
||||
async fn register_node(&self, node: ClusterNode) -> Result<()> {
|
||||
self.nodes.insert(node.node_id.clone(), node);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unregister_node(&self, node_id: &str) -> Result<()> {
|
||||
self.nodes.remove(node_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn heartbeat(&self, node_id: &str) -> Result<()> {
|
||||
if let Some(mut node) = self.nodes.get_mut(node_id) {
|
||||
node.heartbeat();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Gossip protocol statistics (see `GossipDiscovery::get_stats`)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GossipStats {
    /// All nodes in the gossip view, including the local node
    pub total_nodes: usize,
    /// Nodes seen within the configured timeout
    pub healthy_nodes: usize,
    /// Number of configured seed nodes
    pub seed_nodes: usize,
}
|
||||
|
||||
/// Multicast-based discovery (for local networks)
pub struct MulticastDiscovery {
    /// Local node (announced to the multicast group)
    local_node: ClusterNode,
    /// Discovered nodes (node_id -> node)
    nodes: Arc<DashMap<String, ClusterNode>>,
    /// Multicast group address
    multicast_addr: String,
    /// Multicast port
    multicast_port: u16,
}
|
||||
|
||||
impl MulticastDiscovery {
|
||||
/// Create a new multicast discovery service
|
||||
pub fn new(local_node: ClusterNode, multicast_addr: String, multicast_port: u16) -> Self {
|
||||
Self {
|
||||
local_node,
|
||||
nodes: Arc::new(DashMap::new()),
|
||||
multicast_addr,
|
||||
multicast_port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start multicast discovery
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
info!(
|
||||
"Starting multicast discovery on {}:{}",
|
||||
self.multicast_addr, self.multicast_port
|
||||
);
|
||||
|
||||
// In a real implementation, this would:
|
||||
// 1. Join multicast group
|
||||
// 2. Send periodic announcements
|
||||
// 3. Listen for other nodes
|
||||
// 4. Update node list
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DiscoveryService for MulticastDiscovery {
|
||||
async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
|
||||
Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
|
||||
}
|
||||
|
||||
async fn register_node(&self, node: ClusterNode) -> Result<()> {
|
||||
self.nodes.insert(node.node_id.clone(), node);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unregister_node(&self, node_id: &str) -> Result<()> {
|
||||
self.nodes.remove(node_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn heartbeat(&self, node_id: &str) -> Result<()> {
|
||||
if let Some(mut node) = self.nodes.get_mut(node_id) {
|
||||
node.heartbeat();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    // Helper: a loopback node with the given ID and port.
    fn create_test_node(id: &str, port: u16) -> ClusterNode {
        ClusterNode::new(
            id.to_string(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
        )
    }

    // Static discovery returns exactly the nodes it was configured with.
    #[tokio::test]
    async fn test_static_discovery() {
        let node1 = create_test_node("node1", 8000);
        let node2 = create_test_node("node2", 8001);

        let discovery = StaticDiscovery::new(vec![node1, node2]);

        let nodes = discovery.discover_nodes().await.unwrap();
        assert_eq!(nodes.len(), 2);
    }

    // Registration appends to an initially empty static list.
    #[tokio::test]
    async fn test_static_discovery_register() {
        let discovery = StaticDiscovery::new(vec![]);

        let node = create_test_node("node1", 8000);
        discovery.register_node(node).await.unwrap();

        let nodes = discovery.discover_nodes().await.unwrap();
        assert_eq!(nodes.len(), 1);
    }

    // A fresh gossip view contains only the local node.
    #[tokio::test]
    async fn test_gossip_discovery() {
        let local_node = create_test_node("local", 8000);
        let seed_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9000);

        let discovery = GossipDiscovery::new(
            local_node,
            vec![seed_addr],
            Duration::from_secs(5),
            Duration::from_secs(30),
        );

        let nodes = discovery.discover_nodes().await.unwrap();
        assert_eq!(nodes.len(), 1); // Only local node initially
    }

    // Merging remote gossip grows the node set by the new entries.
    #[tokio::test]
    async fn test_gossip_merge() {
        let local_node = create_test_node("local", 8000);
        let discovery = GossipDiscovery::new(
            local_node,
            vec![],
            Duration::from_secs(5),
            Duration::from_secs(30),
        );

        let remote_nodes = vec![
            create_test_node("node1", 8001),
            create_test_node("node2", 8002),
        ];

        discovery.merge_gossip(remote_nodes);

        let stats = discovery.get_stats();
        assert_eq!(stats.total_nodes, 3); // local + 2 remote
    }
}
|
||||
513
vendor/ruvector/crates/ruvector-cluster/src/lib.rs
vendored
Normal file
513
vendor/ruvector/crates/ruvector-cluster/src/lib.rs
vendored
Normal file
@@ -0,0 +1,513 @@
|
||||
//! Distributed clustering and sharding for ruvector
|
||||
//!
|
||||
//! This crate provides distributed coordination capabilities including:
|
||||
//! - Cluster node management and health monitoring
|
||||
//! - Consistent hashing for shard distribution
|
||||
//! - DAG-based consensus protocol
|
||||
//! - Dynamic node discovery and topology management
|
||||
|
||||
pub mod consensus;
|
||||
pub mod discovery;
|
||||
pub mod shard;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub use consensus::DagConsensus;
|
||||
pub use discovery::{DiscoveryService, GossipDiscovery, StaticDiscovery};
|
||||
pub use shard::{ConsistentHashRing, ShardRouter};
|
||||
|
||||
/// Cluster-related errors
///
/// Unified error type for all cluster operations; `std::io::Error` converts
/// automatically via `#[from]`.
#[derive(Debug, Error)]
pub enum ClusterError {
    #[error("Node not found: {0}")]
    NodeNotFound(String),

    #[error("Shard not found: {0}")]
    ShardNotFound(u32),

    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    #[error("Consensus error: {0}")]
    ConsensusError(String),

    #[error("Discovery error: {0}")]
    DiscoveryError(String),

    #[error("Network error: {0}")]
    NetworkError(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
|
||||
|
||||
/// Crate-wide result alias using [`ClusterError`].
pub type Result<T> = std::result::Result<T, ClusterError>;
|
||||
|
||||
/// Status of a cluster node
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is the cluster leader
    Leader,
    /// Node is a follower
    Follower,
    /// Node is campaigning to be leader
    Candidate,
    /// Node is offline or unreachable
    Offline,
}
|
||||
|
||||
/// Information about a cluster node
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterNode {
    /// Unique node identifier
    pub node_id: String,
    /// Network address of the node
    pub address: SocketAddr,
    /// Current status of the node
    pub status: NodeStatus,
    /// Last time the node was seen alive (refreshed by `heartbeat`)
    pub last_seen: DateTime<Utc>,
    /// Free-form metadata about the node
    pub metadata: HashMap<String, String>,
    /// Relative node capacity weight for load balancing (1.0 = default)
    pub capacity: f64,
}
|
||||
|
||||
impl ClusterNode {
|
||||
/// Create a new cluster node
|
||||
pub fn new(node_id: String, address: SocketAddr) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
address,
|
||||
status: NodeStatus::Follower,
|
||||
last_seen: Utc::now(),
|
||||
metadata: HashMap::new(),
|
||||
capacity: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the node is healthy (seen recently)
|
||||
pub fn is_healthy(&self, timeout: Duration) -> bool {
|
||||
let now = Utc::now();
|
||||
let elapsed = now
|
||||
.signed_duration_since(self.last_seen)
|
||||
.to_std()
|
||||
.unwrap_or(Duration::MAX);
|
||||
elapsed < timeout
|
||||
}
|
||||
|
||||
/// Update the last seen timestamp
|
||||
pub fn heartbeat(&mut self) {
|
||||
self.last_seen = Utc::now();
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a data shard
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardInfo {
    /// Shard identifier
    pub shard_id: u32,
    /// Primary node responsible for this shard
    pub primary_node: String,
    /// Replica nodes for this shard
    pub replica_nodes: Vec<String>,
    /// Number of vectors currently stored in this shard
    pub vector_count: usize,
    /// Shard status
    pub status: ShardStatus,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last modified timestamp
    pub modified_at: DateTime<Utc>,
}
|
||||
|
||||
/// Status of a shard
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardStatus {
    /// Shard is active and serving requests
    Active,
    /// Shard is being migrated to another node
    Migrating,
    /// Shard is being replicated to additional nodes
    Replicating,
    /// Shard is offline
    Offline,
}
|
||||
|
||||
/// Cluster configuration (see `Default` impl for the default values)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
    /// Number of replica copies for each shard
    pub replication_factor: usize,
    /// Total number of shards in the cluster
    pub shard_count: u32,
    /// Interval between heartbeat checks
    pub heartbeat_interval: Duration,
    /// Timeout before considering a node offline
    pub node_timeout: Duration,
    /// Enable DAG-based consensus
    pub enable_consensus: bool,
    /// Minimum nodes required for quorum
    pub min_quorum_size: usize,
}
|
||||
|
||||
impl Default for ClusterConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
replication_factor: 3,
|
||||
shard_count: 64,
|
||||
heartbeat_interval: Duration::from_secs(5),
|
||||
node_timeout: Duration::from_secs(30),
|
||||
enable_consensus: true,
|
||||
min_quorum_size: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages a distributed cluster of vector database nodes: membership,
/// shard placement via consistent hashing, query routing, and (optionally)
/// DAG-based consensus.
pub struct ClusterManager {
    /// Cluster configuration
    config: ClusterConfig,
    /// Map of node_id to ClusterNode (concurrent; per-shard locking)
    nodes: Arc<DashMap<String, ClusterNode>>,
    /// Map of shard_id to ShardInfo
    shards: Arc<DashMap<u32, ShardInfo>>,
    /// Consistent hash ring for shard-to-node assignment
    hash_ring: Arc<RwLock<ConsistentHashRing>>,
    /// Shard router for key-to-shard query routing
    router: Arc<ShardRouter>,
    /// DAG-based consensus engine (None when disabled in config)
    consensus: Option<Arc<DagConsensus>>,
    /// Discovery service (boxed for type erasure)
    discovery: Box<dyn DiscoveryService>,
    /// This process's own node ID
    node_id: String,
}
|
||||
|
||||
impl ClusterManager {
|
||||
/// Create a new cluster manager
|
||||
pub fn new(
|
||||
config: ClusterConfig,
|
||||
node_id: String,
|
||||
discovery: Box<dyn DiscoveryService>,
|
||||
) -> Result<Self> {
|
||||
let nodes = Arc::new(DashMap::new());
|
||||
let shards = Arc::new(DashMap::new());
|
||||
let hash_ring = Arc::new(RwLock::new(ConsistentHashRing::new(
|
||||
config.replication_factor,
|
||||
)));
|
||||
let router = Arc::new(ShardRouter::new(config.shard_count));
|
||||
|
||||
let consensus = if config.enable_consensus {
|
||||
Some(Arc::new(DagConsensus::new(
|
||||
node_id.clone(),
|
||||
config.min_quorum_size,
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
nodes,
|
||||
shards,
|
||||
hash_ring,
|
||||
router,
|
||||
consensus,
|
||||
discovery,
|
||||
node_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a node to the cluster
|
||||
pub async fn add_node(&self, node: ClusterNode) -> Result<()> {
|
||||
info!("Adding node {} to cluster", node.node_id);
|
||||
|
||||
// Add to hash ring
|
||||
{
|
||||
let mut ring = self.hash_ring.write();
|
||||
ring.add_node(node.node_id.clone());
|
||||
}
|
||||
|
||||
// Store node information
|
||||
self.nodes.insert(node.node_id.clone(), node.clone());
|
||||
|
||||
// Rebalance shards if needed
|
||||
self.rebalance_shards().await?;
|
||||
|
||||
info!("Node {} successfully added", node.node_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a node from the cluster
|
||||
pub async fn remove_node(&self, node_id: &str) -> Result<()> {
|
||||
info!("Removing node {} from cluster", node_id);
|
||||
|
||||
// Remove from hash ring
|
||||
{
|
||||
let mut ring = self.hash_ring.write();
|
||||
ring.remove_node(node_id);
|
||||
}
|
||||
|
||||
// Remove node information
|
||||
self.nodes.remove(node_id);
|
||||
|
||||
// Rebalance shards
|
||||
self.rebalance_shards().await?;
|
||||
|
||||
info!("Node {} successfully removed", node_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get node by ID
|
||||
pub fn get_node(&self, node_id: &str) -> Option<ClusterNode> {
|
||||
self.nodes.get(node_id).map(|n| n.clone())
|
||||
}
|
||||
|
||||
/// List all nodes in the cluster
|
||||
pub fn list_nodes(&self) -> Vec<ClusterNode> {
|
||||
self.nodes
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get healthy nodes only
|
||||
pub fn healthy_nodes(&self) -> Vec<ClusterNode> {
|
||||
self.nodes
|
||||
.iter()
|
||||
.filter(|entry| entry.value().is_healthy(self.config.node_timeout))
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get shard information
|
||||
pub fn get_shard(&self, shard_id: u32) -> Option<ShardInfo> {
|
||||
self.shards.get(&shard_id).map(|s| s.clone())
|
||||
}
|
||||
|
||||
/// List all shards
|
||||
pub fn list_shards(&self) -> Vec<ShardInfo> {
|
||||
self.shards
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Assign a shard to nodes using consistent hashing
|
||||
pub fn assign_shard(&self, shard_id: u32) -> Result<ShardInfo> {
|
||||
let ring = self.hash_ring.read();
|
||||
let key = format!("shard:{}", shard_id);
|
||||
|
||||
let nodes = ring.get_nodes(&key, self.config.replication_factor);
|
||||
|
||||
if nodes.is_empty() {
|
||||
return Err(ClusterError::InvalidConfig(
|
||||
"No nodes available for shard assignment".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let primary_node = nodes[0].clone();
|
||||
let replica_nodes = nodes.into_iter().skip(1).collect();
|
||||
|
||||
let shard_info = ShardInfo {
|
||||
shard_id,
|
||||
primary_node,
|
||||
replica_nodes,
|
||||
vector_count: 0,
|
||||
status: ShardStatus::Active,
|
||||
created_at: Utc::now(),
|
||||
modified_at: Utc::now(),
|
||||
};
|
||||
|
||||
self.shards.insert(shard_id, shard_info.clone());
|
||||
Ok(shard_info)
|
||||
}
|
||||
|
||||
/// Rebalance shards across nodes
|
||||
async fn rebalance_shards(&self) -> Result<()> {
|
||||
debug!("Rebalancing shards across cluster");
|
||||
|
||||
for shard_id in 0..self.config.shard_count {
|
||||
if let Some(mut shard) = self.shards.get_mut(&shard_id) {
|
||||
let ring = self.hash_ring.read();
|
||||
let key = format!("shard:{}", shard_id);
|
||||
let nodes = ring.get_nodes(&key, self.config.replication_factor);
|
||||
|
||||
if !nodes.is_empty() {
|
||||
shard.primary_node = nodes[0].clone();
|
||||
shard.replica_nodes = nodes.into_iter().skip(1).collect();
|
||||
shard.modified_at = Utc::now();
|
||||
}
|
||||
} else {
|
||||
// Create new shard assignment
|
||||
self.assign_shard(shard_id)?;
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Shard rebalancing complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run periodic health checks
|
||||
pub async fn run_health_checks(&self) -> Result<()> {
|
||||
debug!("Running health checks");
|
||||
|
||||
let mut unhealthy_nodes = Vec::new();
|
||||
|
||||
for entry in self.nodes.iter() {
|
||||
let node = entry.value();
|
||||
if !node.is_healthy(self.config.node_timeout) {
|
||||
warn!("Node {} is unhealthy", node.node_id);
|
||||
unhealthy_nodes.push(node.node_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Mark unhealthy nodes as offline
|
||||
for node_id in unhealthy_nodes {
|
||||
if let Some(mut node) = self.nodes.get_mut(&node_id) {
|
||||
node.status = NodeStatus::Offline;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start the cluster manager (health checks, discovery, etc.)
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
info!("Starting cluster manager for node {}", self.node_id);
|
||||
|
||||
// Start discovery service
|
||||
let discovered = self.discovery.discover_nodes().await?;
|
||||
for node in discovered {
|
||||
if node.node_id != self.node_id {
|
||||
self.add_node(node).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize shards
|
||||
for shard_id in 0..self.config.shard_count {
|
||||
self.assign_shard(shard_id)?;
|
||||
}
|
||||
|
||||
info!("Cluster manager started successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get cluster statistics
|
||||
pub fn get_stats(&self) -> ClusterStats {
|
||||
let nodes = self.list_nodes();
|
||||
let shards = self.list_shards();
|
||||
let healthy = self.healthy_nodes();
|
||||
|
||||
ClusterStats {
|
||||
total_nodes: nodes.len(),
|
||||
healthy_nodes: healthy.len(),
|
||||
total_shards: shards.len(),
|
||||
active_shards: shards
|
||||
.iter()
|
||||
.filter(|s| s.status == ShardStatus::Active)
|
||||
.count(),
|
||||
total_vectors: shards.iter().map(|s| s.vector_count).sum(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the shard router
|
||||
pub fn router(&self) -> Arc<ShardRouter> {
|
||||
Arc::clone(&self.router)
|
||||
}
|
||||
|
||||
/// Get the consensus engine
|
||||
pub fn consensus(&self) -> Option<Arc<DagConsensus>> {
|
||||
self.consensus.as_ref().map(Arc::clone)
|
||||
}
|
||||
}
|
||||
|
||||
/// Point-in-time cluster statistics as reported by `ClusterManager::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    // Total number of nodes known to the cluster.
    pub total_nodes: usize,
    // Nodes whose last heartbeat is within the configured timeout.
    pub healthy_nodes: usize,
    // Total number of shards tracked.
    pub total_shards: usize,
    // Shards currently in the Active state.
    pub active_shards: usize,
    // Sum of vector counts across all shards.
    pub total_vectors: usize,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    // Helper: build a loopback node with the given id and port.
    fn create_test_node(id: &str, port: u16) -> ClusterNode {
        ClusterNode::new(
            id.to_string(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
        )
    }

    #[tokio::test]
    async fn test_cluster_node_creation() {
        let node = create_test_node("node1", 8000);
        assert_eq!(node.node_id, "node1");
        // New nodes start as followers and count as freshly seen.
        assert_eq!(node.status, NodeStatus::Follower);
        assert!(node.is_healthy(Duration::from_secs(60)));
    }

    #[tokio::test]
    async fn test_cluster_manager_creation() {
        let config = ClusterConfig::default();
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery);
        assert!(manager.is_ok());
    }

    #[tokio::test]
    async fn test_add_remove_node() {
        let config = ClusterConfig::default();
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery).unwrap();

        let node = create_test_node("node1", 8000);
        manager.add_node(node).await.unwrap();

        assert_eq!(manager.list_nodes().len(), 1);

        manager.remove_node("node1").await.unwrap();
        assert_eq!(manager.list_nodes().len(), 0);
    }

    #[tokio::test]
    async fn test_shard_assignment() {
        let config = ClusterConfig {
            shard_count: 4,
            replication_factor: 2,
            ..Default::default()
        };
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery).unwrap();

        // Add some nodes
        for i in 0..3 {
            let node = create_test_node(&format!("node{}", i), 8000 + i);
            manager.add_node(node).await.unwrap();
        }

        // Assign a shard
        let shard = manager.assign_shard(0).unwrap();
        assert_eq!(shard.shard_id, 0);
        assert!(!shard.primary_node.is_empty());
    }
}
|
||||
443
vendor/ruvector/crates/ruvector-cluster/src/shard.rs
vendored
Normal file
443
vendor/ruvector/crates/ruvector-cluster/src/shard.rs
vendored
Normal file
@@ -0,0 +1,443 @@
|
||||
//! Sharding logic for distributed vector storage
|
||||
//!
|
||||
//! Implements consistent hashing for shard distribution and routing.
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
const VIRTUAL_NODE_COUNT: usize = 150;
|
||||
|
||||
/// Consistent hash ring for node assignment. Each real node is projected
/// onto the ring as many virtual points to smooth the key distribution.
#[derive(Debug)]
pub struct ConsistentHashRing {
    /// Virtual nodes on the ring (hash -> node_id)
    ring: BTreeMap<u64, String>,
    /// Real nodes in the cluster (node_id -> virtual-node count)
    nodes: HashMap<String, usize>,
    /// Replication factor
    // NOTE(review): stored but not consulted by any method in this file;
    // callers pass the desired replica count to `get_nodes` explicitly.
    replication_factor: usize,
}
|
||||
|
||||
impl ConsistentHashRing {
|
||||
/// Create a new consistent hash ring
|
||||
pub fn new(replication_factor: usize) -> Self {
|
||||
Self {
|
||||
ring: BTreeMap::new(),
|
||||
nodes: HashMap::new(),
|
||||
replication_factor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to the ring
|
||||
pub fn add_node(&mut self, node_id: String) {
|
||||
if self.nodes.contains_key(&node_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Add virtual nodes for better distribution
|
||||
for i in 0..VIRTUAL_NODE_COUNT {
|
||||
let virtual_key = format!("{}:{}", node_id, i);
|
||||
let hash = Self::hash_key(&virtual_key);
|
||||
self.ring.insert(hash, node_id.clone());
|
||||
}
|
||||
|
||||
self.nodes.insert(node_id, VIRTUAL_NODE_COUNT);
|
||||
debug!(
|
||||
"Added node to hash ring with {} virtual nodes",
|
||||
VIRTUAL_NODE_COUNT
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove a node from the ring
|
||||
pub fn remove_node(&mut self, node_id: &str) {
|
||||
if !self.nodes.contains_key(node_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove all virtual nodes
|
||||
self.ring.retain(|_, v| v != node_id);
|
||||
self.nodes.remove(node_id);
|
||||
debug!("Removed node from hash ring");
|
||||
}
|
||||
|
||||
/// Get nodes responsible for a key
|
||||
pub fn get_nodes(&self, key: &str, count: usize) -> Vec<String> {
|
||||
if self.ring.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let hash = Self::hash_key(key);
|
||||
let mut nodes = Vec::new();
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
|
||||
// Find the first node on or after the hash
|
||||
for (_, node_id) in self.ring.range(hash..) {
|
||||
if seen.insert(node_id.clone()) {
|
||||
nodes.push(node_id.clone());
|
||||
if nodes.len() >= count {
|
||||
return nodes;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap around to the beginning if needed
|
||||
for (_, node_id) in self.ring.iter() {
|
||||
if seen.insert(node_id.clone()) {
|
||||
nodes.push(node_id.clone());
|
||||
if nodes.len() >= count {
|
||||
return nodes;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nodes
|
||||
}
|
||||
|
||||
/// Get the primary node for a key
|
||||
pub fn get_primary_node(&self, key: &str) -> Option<String> {
|
||||
self.get_nodes(key, 1).first().cloned()
|
||||
}
|
||||
|
||||
/// Hash a key to a u64
|
||||
fn hash_key(key: &str) -> u64 {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
let mut hasher = DefaultHasher::new();
|
||||
key.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
/// Get the number of real nodes
|
||||
pub fn node_count(&self) -> usize {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// List all real nodes
|
||||
pub fn list_nodes(&self) -> Vec<String> {
|
||||
self.nodes.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Routes keys to shard IDs via jump consistent hashing, with a memoizing
/// cache of previous lookups.
pub struct ShardRouter {
    /// Total number of shards
    shard_count: u32,
    /// Shard assignment cache (key -> shard_id); unbounded, cleared
    /// explicitly via `clear_cache`.
    cache: Arc<RwLock<HashMap<String, u32>>>,
}
|
||||
|
||||
impl ShardRouter {
    /// Create a new shard router for `shard_count` shards with an empty cache.
    pub fn new(shard_count: u32) -> Self {
        Self {
            shard_count,
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Get the shard ID for a key using jump consistent hashing.
    ///
    /// Results are memoized; identical keys always route to the same shard.
    pub fn get_shard(&self, key: &str) -> u32 {
        // Check cache first
        {
            let cache = self.cache.read();
            if let Some(&shard_id) = cache.get(key) {
                return shard_id;
            }
        }

        // Calculate using jump consistent hash
        let shard_id = self.jump_consistent_hash(key, self.shard_count);

        // Update cache. Two threads may race past the read check above,
        // but both compute the same deterministic value, so the duplicate
        // insert is harmless.
        {
            let mut cache = self.cache.write();
            cache.insert(key.to_string(), shard_id);
        }

        shard_id
    }

    /// Jump consistent hash algorithm
    /// Provides minimal key migration on shard count changes
    ///
    /// Deterministic pseudo-random walk over bucket indices (Lamping &
    /// Veach); the last index reached below `num_buckets` is the answer.
    // NOTE(review): assumes num_buckets >= 1 — with 0 buckets, `b` stays
    // -1 and the final `as u32` cast wraps; confirm callers never build a
    // router with shard_count == 0.
    fn jump_consistent_hash(&self, key: &str, num_buckets: u32) -> u32 {
        use std::collections::hash_map::DefaultHasher;

        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let mut hash = hasher.finish();

        let mut b: i64 = -1;
        let mut j: i64 = 0;

        while j < num_buckets as i64 {
            b = j;
            // Linear congruential step drives the pseudo-random jumps.
            hash = hash.wrapping_mul(2862933555777941757).wrapping_add(1);
            j = ((b.wrapping_add(1) as f64)
                * ((1i64 << 31) as f64 / ((hash >> 33).wrapping_add(1) as f64)))
                as i64;
        }

        b as u32
    }

    /// Get shard ID for a vector ID (thin alias over `get_shard`).
    pub fn get_shard_for_vector(&self, vector_id: &str) -> u32 {
        self.get_shard(vector_id)
    }

    /// Get shard IDs for a range query (may span multiple shards).
    pub fn get_shards_for_range(&self, _start: &str, _end: &str) -> Vec<u32> {
        // For range queries, we might need to check multiple shards
        // For simplicity, return all shards (can be optimized based on key distribution)
        (0..self.shard_count).collect()
    }

    /// Clear the routing cache (e.g. after a topology change).
    pub fn clear_cache(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }

    /// Get cache statistics (entry count and total shard count).
    pub fn cache_stats(&self) -> CacheStats {
        let cache = self.cache.read();
        CacheStats {
            entries: cache.len(),
            shard_count: self.shard_count as usize,
        }
    }
}
|
||||
|
||||
/// Routing-cache statistics as reported by `ShardRouter::cache_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    // Number of memoized key -> shard entries.
    pub entries: usize,
    // Total number of shards the router distributes over.
    pub shard_count: usize,
}
|
||||
|
||||
/// Tracks the progress of moving keys from one shard to another.
pub struct ShardMigration {
    /// Source shard ID
    pub source_shard: u32,
    /// Target shard ID
    pub target_shard: u32,
    /// Migration progress (0.0 to 1.0)
    pub progress: f64,
    /// Keys migrated so far
    pub keys_migrated: usize,
    /// Total keys to migrate
    pub total_keys: usize,
}
|
||||
|
||||
impl ShardMigration {
|
||||
/// Create a new shard migration
|
||||
pub fn new(source_shard: u32, target_shard: u32, total_keys: usize) -> Self {
|
||||
Self {
|
||||
source_shard,
|
||||
target_shard,
|
||||
progress: 0.0,
|
||||
keys_migrated: 0,
|
||||
total_keys,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update migration progress
|
||||
pub fn update_progress(&mut self, keys_migrated: usize) {
|
||||
self.keys_migrated = keys_migrated;
|
||||
self.progress = if self.total_keys > 0 {
|
||||
keys_migrated as f64 / self.total_keys as f64
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
}
|
||||
|
||||
/// Check if migration is complete
|
||||
pub fn is_complete(&self) -> bool {
|
||||
self.progress >= 1.0 || self.keys_migrated >= self.total_keys
|
||||
}
|
||||
}
|
||||
|
||||
/// Tracks per-shard load figures and picks the least-loaded shard.
pub struct LoadBalancer {
    /// Shard load statistics (shard_id -> load); shared behind a RwLock
    /// so readers don't block each other.
    loads: Arc<RwLock<HashMap<u32, f64>>>,
}
|
||||
|
||||
impl LoadBalancer {
|
||||
/// Create a new load balancer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
loads: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update load for a shard
|
||||
pub fn update_load(&self, shard_id: u32, load: f64) {
|
||||
let mut loads = self.loads.write();
|
||||
loads.insert(shard_id, load);
|
||||
}
|
||||
|
||||
/// Get load for a shard
|
||||
pub fn get_load(&self, shard_id: u32) -> f64 {
|
||||
let loads = self.loads.read();
|
||||
loads.get(&shard_id).copied().unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// Get the least loaded shard
|
||||
pub fn get_least_loaded_shard(&self, shard_ids: &[u32]) -> Option<u32> {
|
||||
let loads = self.loads.read();
|
||||
|
||||
shard_ids
|
||||
.iter()
|
||||
.min_by(|&&a, &&b| {
|
||||
let load_a = loads.get(&a).copied().unwrap_or(0.0);
|
||||
let load_b = loads.get(&b).copied().unwrap_or(0.0);
|
||||
load_a
|
||||
.partial_cmp(&load_b)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.copied()
|
||||
}
|
||||
|
||||
/// Get load statistics
|
||||
pub fn get_stats(&self) -> LoadStats {
|
||||
let loads = self.loads.read();
|
||||
|
||||
let total: f64 = loads.values().sum();
|
||||
let count = loads.len();
|
||||
let avg = if count > 0 { total / count as f64 } else { 0.0 };
|
||||
|
||||
let max = loads.values().copied().fold(f64::NEG_INFINITY, f64::max);
|
||||
let min = loads.values().copied().fold(f64::INFINITY, f64::min);
|
||||
|
||||
LoadStats {
|
||||
total_load: total,
|
||||
avg_load: avg,
|
||||
max_load: if max.is_finite() { max } else { 0.0 },
|
||||
min_load: if min.is_finite() { min } else { 0.0 },
|
||||
shard_count: count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LoadBalancer {
    /// Equivalent to `LoadBalancer::new()`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Aggregate load statistics as reported by `LoadBalancer::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadStats {
    // Sum of all recorded shard loads.
    pub total_load: f64,
    // Mean load (0.0 when no shards are recorded).
    pub avg_load: f64,
    // Highest recorded load (0.0 when empty).
    pub max_load: f64,
    // Lowest recorded load (0.0 when empty).
    pub min_load: f64,
    // Number of shards with a recorded load.
    pub shard_count: usize,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(3);

        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());

        assert_eq!(ring.node_count(), 3);

        let nodes = ring.get_nodes("test-key", 3);
        assert_eq!(nodes.len(), 3);

        // Test primary node selection
        let primary = ring.get_primary_node("test-key");
        assert!(primary.is_some());
    }

    #[test]
    fn test_consistent_hashing_distribution() {
        let mut ring = ConsistentHashRing::new(3);

        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());

        let mut distribution: HashMap<String, usize> = HashMap::new();

        // Test distribution across many keys
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get_primary_node(&key) {
                *distribution.entry(node).or_insert(0) += 1;
            }
        }

        // Each node should get roughly 1/3 of the keys (within 20% tolerance)
        for count in distribution.values() {
            let ratio = *count as f64 / 1000.0;
            assert!(ratio > 0.2 && ratio < 0.5, "Distribution ratio: {}", ratio);
        }
    }

    #[test]
    fn test_shard_router() {
        let router = ShardRouter::new(16);

        let shard1 = router.get_shard("test-key-1");
        let shard2 = router.get_shard("test-key-1"); // Should be cached

        assert_eq!(shard1, shard2);
        assert!(shard1 < 16);

        let stats = router.cache_stats();
        assert_eq!(stats.entries, 1);
    }

    #[test]
    fn test_jump_consistent_hash() {
        let router = ShardRouter::new(10);

        // Same key should always map to same shard
        let shard1 = router.get_shard("consistent-key");
        let shard2 = router.get_shard("consistent-key");

        assert_eq!(shard1, shard2);
    }

    #[test]
    fn test_shard_migration() {
        let mut migration = ShardMigration::new(0, 1, 100);

        assert!(!migration.is_complete());
        assert_eq!(migration.progress, 0.0);

        migration.update_progress(50);
        assert_eq!(migration.progress, 0.5);

        migration.update_progress(100);
        assert!(migration.is_complete());
    }

    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new();

        balancer.update_load(0, 0.5);
        balancer.update_load(1, 0.8);
        balancer.update_load(2, 0.3);

        let least_loaded = balancer.get_least_loaded_shard(&[0, 1, 2]);
        assert_eq!(least_loaded, Some(2));

        let stats = balancer.get_stats();
        assert_eq!(stats.shard_count, 3);
        assert!(stats.avg_load > 0.0);
    }
}
|
||||
Reference in New Issue
Block a user