Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,480 @@
//! DAG-based consensus protocol inspired by QuDAG
//!
//! Implements a directed acyclic graph for transaction ordering and consensus.
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use tracing::{debug, info, warn};
use uuid::Uuid;
use crate::{ClusterError, Result};
/// A vertex in the consensus DAG
///
/// Each vertex wraps exactly one [`Transaction`] and references its
/// parent vertices, which form the edges of the DAG.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagVertex {
    /// Unique vertex ID (UUID string assigned at creation)
    pub id: String,
    /// Node that created this vertex
    pub node_id: String,
    /// Transaction data
    pub transaction: Transaction,
    /// Parent vertices (edges in the DAG)
    pub parents: Vec<String>,
    /// Timestamp when vertex was created
    pub timestamp: DateTime<Utc>,
    /// Vector clock for causality tracking (node_id -> event count)
    pub vector_clock: HashMap<String, u64>,
    /// Signature (in production, this would be cryptographic;
    /// currently left empty by the constructor)
    pub signature: String,
}
impl DagVertex {
    /// Construct a vertex for `transaction`, created by `node_id`,
    /// pointing at `parents` and stamped with the given vector clock.
    pub fn new(
        node_id: String,
        transaction: Transaction,
        parents: Vec<String>,
        vector_clock: HashMap<String, u64>,
    ) -> Self {
        let id = Uuid::new_v4().to_string();
        let timestamp = Utc::now();
        Self {
            id,
            timestamp,
            node_id,
            transaction,
            parents,
            vector_clock,
            // Placeholder: a real deployment would sign the vertex here.
            signature: String::new(),
        }
    }

    /// Verify the vertex signature.
    ///
    /// Stub implementation that always accepts; a production build
    /// would check a cryptographic signature over the vertex contents.
    pub fn verify_signature(&self) -> bool {
        true
    }
}
/// A transaction in the consensus system
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Transaction {
    /// Transaction ID (UUID string)
    pub id: String,
    /// Transaction type
    pub tx_type: TransactionType,
    /// Transaction data (opaque payload bytes)
    pub data: Vec<u8>,
    /// Nonce for ordering; strictly increasing per consensus engine
    pub nonce: u64,
}
/// Type of transaction
///
/// Used by conflict detection: write/write and write/delete pairs
/// are treated as conflicting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionType {
    /// Write operation
    Write,
    /// Read operation
    Read,
    /// Delete operation
    Delete,
    /// Batch operation
    Batch,
    /// System operation
    System,
}
/// DAG-based consensus engine
///
/// All fields are wrapped in `Arc` so clones of the handles can be
/// shared across tasks; interior mutability comes from `DashMap`
/// and `parking_lot::RwLock`.
pub struct DagConsensus {
    /// Node ID
    node_id: String,
    /// DAG vertices (vertex_id -> vertex)
    vertices: Arc<DashMap<String, DagVertex>>,
    /// Finalized vertex ids; retained even after pruning so parent
    /// checks in `add_vertex` keep working
    finalized: Arc<RwLock<HashSet<String>>>,
    /// Vector clock for this node (node_id -> event count)
    vector_clock: Arc<RwLock<HashMap<String, u64>>>,
    /// Pending transactions awaiting vertex creation
    pending_txs: Arc<RwLock<VecDeque<Transaction>>>,
    /// Minimum number of distinct confirming nodes for finalization
    min_quorum_size: usize,
    /// Transaction nonce counter
    nonce_counter: Arc<RwLock<u64>>,
}
impl DagConsensus {
    /// Create a new DAG consensus engine for `node_id`.
    ///
    /// `min_quorum_size` is the number of distinct nodes that must
    /// confirm (reference) a vertex before it is finalized.
    pub fn new(node_id: String, min_quorum_size: usize) -> Self {
        // Seed the vector clock with this node's entry at zero.
        let mut vector_clock = HashMap::new();
        vector_clock.insert(node_id.clone(), 0);
        Self {
            node_id,
            vertices: Arc::new(DashMap::new()),
            finalized: Arc::new(RwLock::new(HashSet::new())),
            vector_clock: Arc::new(RwLock::new(vector_clock)),
            pending_txs: Arc::new(RwLock::new(VecDeque::new())),
            min_quorum_size,
            nonce_counter: Arc::new(RwLock::new(0)),
        }
    }

    /// Submit a transaction to the consensus system.
    ///
    /// The transaction is queued until `create_vertex` wraps it in a
    /// DAG vertex. Returns the generated transaction ID.
    pub fn submit_transaction(&self, tx_type: TransactionType, data: Vec<u8>) -> Result<String> {
        // Allocate the nonce in its own scope so the counter lock is
        // released before the pending-queue lock is taken.
        let nonce = {
            let mut counter = self.nonce_counter.write();
            *counter += 1;
            *counter
        };
        let transaction = Transaction {
            id: Uuid::new_v4().to_string(),
            tx_type,
            data,
            nonce,
        };
        let tx_id = transaction.id.clone();
        self.pending_txs.write().push_back(transaction);
        debug!("Transaction {} submitted to consensus", tx_id);
        Ok(tx_id)
    }

    /// Create a new vertex for the next pending transaction.
    ///
    /// Returns `Ok(None)` when nothing is pending. The vertex points
    /// at the current DAG tips and carries a snapshot of this node's
    /// vector clock, incremented for this event.
    pub fn create_vertex(&self) -> Result<Option<DagVertex>> {
        // Pop under a short-lived lock; building the vertex does not
        // need the pending queue.
        let transaction = match self.pending_txs.write().pop_front() {
            Some(tx) => tx,
            None => return Ok(None),
        };
        // Parents are the current tips so the DAG stays connected.
        let parents = self.find_tips();
        // Advance our component of the vector clock and snapshot it.
        let clock_snapshot = {
            let mut clock = self.vector_clock.write();
            *clock.entry(self.node_id.clone()).or_insert(0) += 1;
            clock.clone()
        };
        let vertex = DagVertex::new(self.node_id.clone(), transaction, parents, clock_snapshot);
        let vertex_id = vertex.id.clone();
        self.vertices.insert(vertex_id.clone(), vertex.clone());
        debug!(
            "Created vertex {} for transaction {}",
            vertex_id, vertex.transaction.id
        );
        Ok(Some(vertex))
    }

    /// Find tip vertices (vertices that no other vertex lists as a parent).
    fn find_tips(&self) -> Vec<String> {
        // Every id that appears as somebody's parent has a child.
        let mut has_children = HashSet::new();
        for entry in self.vertices.iter() {
            for parent in &entry.value().parents {
                has_children.insert(parent.clone());
            }
        }
        // Tips are the vertices never referenced as a parent.
        self.vertices
            .iter()
            .filter(|entry| !has_children.contains(entry.key()))
            .map(|entry| entry.key().clone())
            .collect()
    }

    /// Add a vertex received from another node.
    ///
    /// # Errors
    /// Returns `ClusterError::ConsensusError` when the signature is
    /// invalid or a parent is neither present nor finalized.
    pub fn add_vertex(&self, vertex: DagVertex) -> Result<()> {
        // Reject vertices with bad signatures (verification is a
        // stub today and always passes).
        if !vertex.verify_signature() {
            return Err(ClusterError::ConsensusError(
                "Invalid vertex signature".to_string(),
            ));
        }
        // Every parent must be known: either still in the DAG or
        // already finalized (finalized ids survive pruning).
        for parent_id in &vertex.parents {
            if !self.vertices.contains_key(parent_id) && !self.is_finalized(parent_id) {
                return Err(ClusterError::ConsensusError(format!(
                    "Parent vertex {} not found",
                    parent_id
                )));
            }
        }
        // Merge the remote vector clock component-wise (take the max).
        {
            let mut clock = self.vector_clock.write();
            for (node, count) in &vertex.vector_clock {
                let existing = clock.entry(node.clone()).or_insert(0);
                *existing = (*existing).max(*count);
            }
        }
        self.vertices.insert(vertex.id.clone(), vertex);
        Ok(())
    }

    /// Check if a vertex is finalized.
    pub fn is_finalized(&self, vertex_id: &str) -> bool {
        self.finalized.read().contains(vertex_id)
    }

    /// Finalize vertices that have been confirmed by a quorum.
    ///
    /// A vertex is confirmed by node N when any vertex created by N
    /// lists it as a parent; once `min_quorum_size` distinct nodes
    /// have confirmed it, it is finalized. Returns the ids newly
    /// finalized in this pass.
    pub fn finalize_vertices(&self) -> Result<Vec<String>> {
        // Single pass over every (child, parent) edge: each edge is a
        // confirmation of `parent` by the child's creator. This is
        // O(edges), replacing the previous O(vertices^2) double scan
        // while producing the same confirmation sets.
        let mut confirmations: HashMap<String, HashSet<String>> = HashMap::new();
        for entry in self.vertices.iter() {
            let child = entry.value();
            for parent_id in &child.parents {
                confirmations
                    .entry(parent_id.clone())
                    .or_default()
                    .insert(child.node_id.clone());
            }
        }
        // Promote every sufficiently-confirmed, still-present vertex.
        let mut finalized = self.finalized.write();
        let mut finalized_ids = Vec::new();
        for (vertex_id, confirming_nodes) in confirmations {
            if confirming_nodes.len() >= self.min_quorum_size
                && self.vertices.contains_key(&vertex_id)
                && !finalized.contains(&vertex_id)
            {
                finalized.insert(vertex_id.clone());
                finalized_ids.push(vertex_id.clone());
                info!("Finalized vertex {}", vertex_id);
            }
        }
        Ok(finalized_ids)
    }

    /// Get the total order of finalized transactions.
    ///
    /// NOTE(review): vector-clock dominance is a partial order, so the
    /// comparator below (dominance first, timestamp tiebreak) is not
    /// guaranteed to be a strict weak ordering as `sort_by` expects;
    /// confirm this is acceptable or switch to a topological sort.
    pub fn get_finalized_order(&self) -> Vec<Transaction> {
        let finalized = self.finalized.read();
        // Collect the finalized vertices still present in the DAG.
        let mut sorted: Vec<_> = self
            .vertices
            .iter()
            .filter(|entry| finalized.contains(entry.key()))
            .map(|entry| entry.value().clone())
            .collect();
        // Order by causality where determinable, timestamps otherwise.
        sorted.sort_by(|a, b| {
            let a_dominates = Self::vector_clock_dominates(&a.vector_clock, &b.vector_clock);
            let b_dominates = Self::vector_clock_dominates(&b.vector_clock, &a.vector_clock);
            if a_dominates && !b_dominates {
                std::cmp::Ordering::Less
            } else if b_dominates && !a_dominates {
                std::cmp::Ordering::Greater
            } else {
                // Concurrent (or equal) clocks: fall back to timestamp.
                a.timestamp.cmp(&b.timestamp)
            }
        });
        sorted.into_iter().map(|v| v.transaction).collect()
    }

    /// Check if vector clock `a` dominates vector clock `b`: every
    /// component of `a` is >= the matching component of `b`, and at
    /// least one is strictly greater. Missing entries count as zero.
    fn vector_clock_dominates(a: &HashMap<String, u64>, b: &HashMap<String, u64>) -> bool {
        let mut strictly_greater = false;
        for (node, &a_count) in a {
            let b_count = b.get(node).copied().unwrap_or(0);
            if a_count < b_count {
                return false;
            }
            if a_count > b_count {
                strictly_greater = true;
            }
        }
        // Bug fix: a key present only in `b` with a positive count
        // means `a`'s implicit zero is smaller, so `a` cannot
        // dominate. The previous loop over `a` alone missed this and
        // could report dominance for concurrent clocks.
        for (node, &b_count) in b {
            if b_count > 0 && !a.contains_key(node) {
                return false;
            }
        }
        strictly_greater
    }

    /// Detect conflicts between two transactions.
    ///
    /// Conservative placeholder: any write/write or write/delete pair
    /// conflicts; real logic would inspect the transaction payloads.
    pub fn detect_conflicts(&self, tx1: &Transaction, tx2: &Transaction) -> bool {
        matches!(
            (&tx1.tx_type, &tx2.tx_type),
            (TransactionType::Write, TransactionType::Write)
                | (TransactionType::Delete, TransactionType::Write)
                | (TransactionType::Write, TransactionType::Delete)
        )
    }

    /// Get consensus statistics (vertex, finalized, pending, tip counts).
    pub fn get_stats(&self) -> ConsensusStats {
        ConsensusStats {
            total_vertices: self.vertices.len(),
            finalized_vertices: self.finalized.read().len(),
            pending_transactions: self.pending_txs.read().len(),
            tips: self.find_tips().len(),
        }
    }

    /// Prune old finalized vertices to save memory, keeping at most
    /// `keep_count` of the most recent ones in the vertex map.
    ///
    /// Finalized ids are deliberately retained in `self.finalized` so
    /// `add_vertex` / `is_finalized` still accept pruned parents.
    /// NOTE(review): that id set therefore grows without bound.
    pub fn prune_old_vertices(&self, keep_count: usize) {
        let finalized = self.finalized.read();
        if finalized.len() <= keep_count {
            return;
        }
        // Pair each still-present finalized vertex with its timestamp.
        let mut candidates = Vec::new();
        for vertex_id in finalized.iter() {
            if let Some(vertex) = self.vertices.get(vertex_id) {
                candidates.push((vertex_id.clone(), vertex.timestamp));
            }
        }
        // Oldest first; drop everything beyond `keep_count`.
        candidates.sort_by_key(|(_, ts)| *ts);
        let to_remove = candidates.len().saturating_sub(keep_count);
        for (vertex_id, _) in candidates.iter().take(to_remove) {
            self.vertices.remove(vertex_id);
        }
        debug!("Pruned {} old vertices", to_remove);
    }
}
/// Consensus statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsensusStats {
    // Total vertices currently held in the DAG
    pub total_vertices: usize,
    // Number of finalized vertex ids recorded
    pub finalized_vertices: usize,
    // Transactions queued but not yet turned into vertices
    pub pending_transactions: usize,
    // Current DAG tips (vertices with no children)
    pub tips: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh engine starts with no vertices and nothing pending.
    #[test]
    fn test_consensus_creation() {
        let consensus = DagConsensus::new("node1".to_string(), 2);
        let stats = consensus.get_stats();
        assert_eq!(stats.total_vertices, 0);
        assert_eq!(stats.pending_transactions, 0);
    }

    /// Submitting queues the transaction and returns a non-empty id.
    #[test]
    fn test_submit_transaction() {
        let consensus = DagConsensus::new("node1".to_string(), 2);
        let tx_id = consensus
            .submit_transaction(TransactionType::Write, vec![1, 2, 3])
            .unwrap();
        assert!(!tx_id.is_empty());
        let stats = consensus.get_stats();
        assert_eq!(stats.pending_transactions, 1);
    }

    /// Creating a vertex drains the pending queue into the DAG.
    #[test]
    fn test_create_vertex() {
        let consensus = DagConsensus::new("node1".to_string(), 2);
        consensus
            .submit_transaction(TransactionType::Write, vec![1, 2, 3])
            .unwrap();
        let vertex = consensus.create_vertex().unwrap();
        assert!(vertex.is_some());
        let stats = consensus.get_stats();
        assert_eq!(stats.total_vertices, 1);
        assert_eq!(stats.pending_transactions, 0);
    }

    /// clock1 strictly exceeds clock2 on node1, so it dominates;
    /// the reverse direction must not hold.
    #[test]
    fn test_vector_clock_dominance() {
        let mut clock1 = HashMap::new();
        clock1.insert("node1".to_string(), 2);
        clock1.insert("node2".to_string(), 1);
        let mut clock2 = HashMap::new();
        clock2.insert("node1".to_string(), 1);
        clock2.insert("node2".to_string(), 1);
        assert!(DagConsensus::vector_clock_dominates(&clock1, &clock2));
        assert!(!DagConsensus::vector_clock_dominates(&clock2, &clock1));
    }

    /// Two writes are conservatively treated as conflicting.
    #[test]
    fn test_conflict_detection() {
        let consensus = DagConsensus::new("node1".to_string(), 2);
        let tx1 = Transaction {
            id: "1".to_string(),
            tx_type: TransactionType::Write,
            data: vec![1],
            nonce: 1,
        };
        let tx2 = Transaction {
            id: "2".to_string(),
            tx_type: TransactionType::Write,
            data: vec![2],
            nonce: 2,
        };
        assert!(consensus.detect_conflicts(&tx1, &tx2));
    }

    /// Vertices from a single node cannot reach a quorum of 2.
    #[test]
    fn test_finalization() {
        let consensus = DagConsensus::new("node1".to_string(), 2);
        // Create some vertices
        for i in 0..5 {
            consensus
                .submit_transaction(TransactionType::Write, vec![i])
                .unwrap();
            consensus.create_vertex().unwrap();
        }
        // Try to finalize
        let finalized = consensus.finalize_vertices().unwrap();
        // Without enough confirmations, nothing should be finalized yet
        // (would need vertices from other nodes)
        assert_eq!(finalized.len(), 0);
    }
}

View File

@@ -0,0 +1,383 @@
//! Node discovery mechanisms for cluster formation
//!
//! Supports static configuration and gossip-based discovery.
use crate::{ClusterError, ClusterNode, NodeStatus, Result};
use async_trait::async_trait;
use chrono::Utc;
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::time;
use tracing::{debug, info, warn};
/// Service for discovering nodes in the cluster
///
/// Implementations range from a fixed static list to gossip- and
/// multicast-based protocols; the trait is object-safe so it can be
/// boxed and swapped at runtime.
#[async_trait]
pub trait DiscoveryService: Send + Sync {
    /// Discover nodes in the cluster
    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>>;
    /// Register this node in the discovery service
    async fn register_node(&self, node: ClusterNode) -> Result<()>;
    /// Unregister this node from the discovery service
    async fn unregister_node(&self, node_id: &str) -> Result<()>;
    /// Update node heartbeat (refresh its last-seen timestamp)
    async fn heartbeat(&self, node_id: &str) -> Result<()>;
}
/// Static discovery using predefined node list
///
/// No network activity: the node set only changes via explicit calls.
pub struct StaticDiscovery {
    /// Predefined list of nodes
    nodes: Arc<RwLock<Vec<ClusterNode>>>,
}
impl StaticDiscovery {
    /// Build a static discovery service over a fixed node list.
    pub fn new(nodes: Vec<ClusterNode>) -> Self {
        let nodes = Arc::new(RwLock::new(nodes));
        Self { nodes }
    }

    /// Append a node to the static list.
    pub fn add_node(&self, node: ClusterNode) {
        self.nodes.write().push(node);
    }

    /// Drop every entry whose ID matches `node_id`.
    pub fn remove_node(&self, node_id: &str) {
        self.nodes.write().retain(|n| n.node_id != node_id);
    }
}
#[async_trait]
impl DiscoveryService for StaticDiscovery {
    /// Return a snapshot of the configured node list.
    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
        Ok(self.nodes.read().clone())
    }

    /// Registration simply appends to the static list.
    async fn register_node(&self, node: ClusterNode) -> Result<()> {
        self.add_node(node);
        Ok(())
    }

    /// Unregistration removes matching entries from the list.
    async fn unregister_node(&self, node_id: &str) -> Result<()> {
        self.remove_node(node_id);
        Ok(())
    }

    /// Refresh the heartbeat timestamp of the first matching node.
    async fn heartbeat(&self, node_id: &str) -> Result<()> {
        let mut nodes = self.nodes.write();
        for node in nodes.iter_mut() {
            if node.node_id == node_id {
                node.heartbeat();
                break;
            }
        }
        Ok(())
    }
}
/// Gossip-based discovery protocol
pub struct GossipDiscovery {
    /// Local node information
    local_node: Arc<RwLock<ClusterNode>>,
    /// Known nodes (node_id -> node); includes the local node
    nodes: Arc<DashMap<String, ClusterNode>>,
    /// Seed nodes to bootstrap gossip
    seed_nodes: Vec<SocketAddr>,
    /// Gossip interval (period of the background gossip task)
    gossip_interval: Duration,
    /// Node timeout; nodes unseen for longer are considered failed
    node_timeout: Duration,
}
impl GossipDiscovery {
    /// Create a new gossip discovery service
    ///
    /// The local node is inserted into the known-node map immediately,
    /// so `discover_nodes` always returns at least one entry.
    pub fn new(
        local_node: ClusterNode,
        seed_nodes: Vec<SocketAddr>,
        gossip_interval: Duration,
        node_timeout: Duration,
    ) -> Self {
        let nodes = Arc::new(DashMap::new());
        nodes.insert(local_node.node_id.clone(), local_node.clone());
        Self {
            local_node: Arc::new(RwLock::new(local_node)),
            nodes,
            seed_nodes,
            gossip_interval,
            node_timeout,
        }
    }
    /// Start the gossip protocol
    ///
    /// Bootstraps from the seed nodes, then spawns a background task
    /// that runs a gossip round every `gossip_interval`.
    /// NOTE(review): the spawned task loops forever with no shutdown
    /// handle; confirm that is acceptable to callers.
    pub async fn start(&self) -> Result<()> {
        info!("Starting gossip discovery protocol");
        // Bootstrap from seed nodes
        self.bootstrap().await?;
        // Start periodic gossip; the task holds its own Arc to the
        // node map so it outlives `self`.
        let nodes = Arc::clone(&self.nodes);
        let gossip_interval = self.gossip_interval;
        tokio::spawn(async move {
            let mut interval = time::interval(gossip_interval);
            loop {
                interval.tick().await;
                Self::gossip_round(&nodes).await;
            }
        });
        Ok(())
    }
    /// Bootstrap by contacting seed nodes
    ///
    /// Currently a stub: it only logs; no network contact is made.
    async fn bootstrap(&self) -> Result<()> {
        debug!("Bootstrapping from {} seed nodes", self.seed_nodes.len());
        for seed_addr in &self.seed_nodes {
            // In a real implementation, this would contact the seed node
            // For now, we'll simulate it
            debug!("Contacting seed node at {}", seed_addr);
        }
        Ok(())
    }
    /// Perform a gossip round
    ///
    /// Stub: requires at least two known nodes, then only logs.
    async fn gossip_round(nodes: &Arc<DashMap<String, ClusterNode>>) {
        // Select random subset of nodes to gossip with
        let node_list: Vec<_> = nodes.iter().map(|e| e.value().clone()).collect();
        if node_list.len() < 2 {
            return;
        }
        debug!("Gossiping with {} nodes", node_list.len());
        // In a real implementation, we would:
        // 1. Select random peers
        // 2. Exchange node lists
        // 3. Merge received information
        // 4. Detect failures
    }
    /// Merge gossip information from another node
    ///
    /// Known nodes are overwritten only when the remote copy has a
    /// newer `last_seen`; unknown nodes are added as-is.
    pub fn merge_gossip(&self, remote_nodes: Vec<ClusterNode>) {
        for node in remote_nodes {
            if let Some(mut existing) = self.nodes.get_mut(&node.node_id) {
                // Update if remote has newer information
                if node.last_seen > existing.last_seen {
                    *existing = node;
                }
            } else {
                // Add new node
                self.nodes.insert(node.node_id.clone(), node);
            }
        }
    }
    /// Remove failed nodes
    ///
    /// Drops every node not seen within `node_timeout`; a negative
    /// elapsed time (clock skew) maps to `Duration::MAX` and removal.
    /// NOTE(review): the local node itself can be evicted here if its
    /// own heartbeat is never refreshed — confirm that is intended.
    pub fn prune_failed_nodes(&self) {
        let now = Utc::now();
        self.nodes.retain(|_, node| {
            let elapsed = now
                .signed_duration_since(node.last_seen)
                .to_std()
                .unwrap_or(Duration::MAX);
            elapsed < self.node_timeout
        });
    }
    /// Get gossip statistics
    pub fn get_stats(&self) -> GossipStats {
        let nodes: Vec<_> = self.nodes.iter().map(|e| e.value().clone()).collect();
        let healthy = nodes
            .iter()
            .filter(|n| n.is_healthy(self.node_timeout))
            .count();
        GossipStats {
            total_nodes: nodes.len(),
            healthy_nodes: healthy,
            seed_nodes: self.seed_nodes.len(),
        }
    }
}
#[async_trait]
impl DiscoveryService for GossipDiscovery {
    /// Snapshot of every known node, including the local node.
    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
        Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
    }
    /// Insert (or replace) the node in the known-node map.
    async fn register_node(&self, node: ClusterNode) -> Result<()> {
        self.nodes.insert(node.node_id.clone(), node);
        Ok(())
    }
    /// Remove the node from the known-node map; no-op if absent.
    async fn unregister_node(&self, node_id: &str) -> Result<()> {
        self.nodes.remove(node_id);
        Ok(())
    }
    /// Refresh the node's last-seen timestamp; no-op if unknown.
    async fn heartbeat(&self, node_id: &str) -> Result<()> {
        if let Some(mut node) = self.nodes.get_mut(node_id) {
            node.heartbeat();
        }
        Ok(())
    }
}
/// Gossip protocol statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GossipStats {
    // All known nodes (healthy or not)
    pub total_nodes: usize,
    // Nodes seen within the configured timeout
    pub healthy_nodes: usize,
    // Configured seed addresses used for bootstrap
    pub seed_nodes: usize,
}
/// Multicast-based discovery (for local networks)
pub struct MulticastDiscovery {
    /// Local node
    // NOTE(review): stored but not read by any visible method; the
    // intended multicast announcement logic is not implemented yet.
    local_node: ClusterNode,
    /// Discovered nodes
    nodes: Arc<DashMap<String, ClusterNode>>,
    /// Multicast address
    multicast_addr: String,
    /// Multicast port
    multicast_port: u16,
}
impl MulticastDiscovery {
    /// Create a new multicast discovery service
    pub fn new(local_node: ClusterNode, multicast_addr: String, multicast_port: u16) -> Self {
        Self {
            local_node,
            nodes: Arc::new(DashMap::new()),
            multicast_addr,
            multicast_port,
        }
    }
    /// Start multicast discovery
    ///
    /// Stub: logs the configured address/port and returns; no socket
    /// is opened and no multicast group is joined.
    pub async fn start(&self) -> Result<()> {
        info!(
            "Starting multicast discovery on {}:{}",
            self.multicast_addr, self.multicast_port
        );
        // In a real implementation, this would:
        // 1. Join multicast group
        // 2. Send periodic announcements
        // 3. Listen for other nodes
        // 4. Update node list
        Ok(())
    }
}
#[async_trait]
impl DiscoveryService for MulticastDiscovery {
    /// Snapshot of the discovered-node map.
    async fn discover_nodes(&self) -> Result<Vec<ClusterNode>> {
        Ok(self.nodes.iter().map(|e| e.value().clone()).collect())
    }
    /// Insert (or replace) the node in the discovered-node map.
    async fn register_node(&self, node: ClusterNode) -> Result<()> {
        self.nodes.insert(node.node_id.clone(), node);
        Ok(())
    }
    /// Remove the node from the map; no-op if absent.
    async fn unregister_node(&self, node_id: &str) -> Result<()> {
        self.nodes.remove(node_id);
        Ok(())
    }
    /// Refresh the node's last-seen timestamp; no-op if unknown.
    async fn heartbeat(&self, node_id: &str) -> Result<()> {
        if let Some(mut node) = self.nodes.get_mut(node_id) {
            node.heartbeat();
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Helper: a follower node on 127.0.0.1 at the given port.
    fn create_test_node(id: &str, port: u16) -> ClusterNode {
        ClusterNode::new(
            id.to_string(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
        )
    }

    /// Static discovery returns exactly the preconfigured nodes.
    #[tokio::test]
    async fn test_static_discovery() {
        let node1 = create_test_node("node1", 8000);
        let node2 = create_test_node("node2", 8001);
        let discovery = StaticDiscovery::new(vec![node1, node2]);
        let nodes = discovery.discover_nodes().await.unwrap();
        assert_eq!(nodes.len(), 2);
    }

    /// Registering adds a node to an initially empty static list.
    #[tokio::test]
    async fn test_static_discovery_register() {
        let discovery = StaticDiscovery::new(vec![]);
        let node = create_test_node("node1", 8000);
        discovery.register_node(node).await.unwrap();
        let nodes = discovery.discover_nodes().await.unwrap();
        assert_eq!(nodes.len(), 1);
    }

    /// Gossip discovery starts out knowing only the local node.
    #[tokio::test]
    async fn test_gossip_discovery() {
        let local_node = create_test_node("local", 8000);
        let seed_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9000);
        let discovery = GossipDiscovery::new(
            local_node,
            vec![seed_addr],
            Duration::from_secs(5),
            Duration::from_secs(30),
        );
        let nodes = discovery.discover_nodes().await.unwrap();
        assert_eq!(nodes.len(), 1); // Only local node initially
    }

    /// Merging remote gossip grows the known-node set.
    #[tokio::test]
    async fn test_gossip_merge() {
        let local_node = create_test_node("local", 8000);
        let discovery = GossipDiscovery::new(
            local_node,
            vec![],
            Duration::from_secs(5),
            Duration::from_secs(30),
        );
        let remote_nodes = vec![
            create_test_node("node1", 8001),
            create_test_node("node2", 8002),
        ];
        discovery.merge_gossip(remote_nodes);
        let stats = discovery.get_stats();
        assert_eq!(stats.total_nodes, 3); // local + 2 remote
    }
}

View File

@@ -0,0 +1,513 @@
//! Distributed clustering and sharding for ruvector
//!
//! This crate provides distributed coordination capabilities including:
//! - Cluster node management and health monitoring
//! - Consistent hashing for shard distribution
//! - DAG-based consensus protocol
//! - Dynamic node discovery and topology management
pub mod consensus;
pub mod discovery;
pub mod shard;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
pub use consensus::DagConsensus;
pub use discovery::{DiscoveryService, GossipDiscovery, StaticDiscovery};
pub use shard::{ConsistentHashRing, ShardRouter};
/// Cluster-related errors
///
/// Covers node/shard lookup failures, configuration, consensus,
/// discovery, network, serialization, and wrapped I/O errors.
#[derive(Debug, Error)]
pub enum ClusterError {
    #[error("Node not found: {0}")]
    NodeNotFound(String),
    #[error("Shard not found: {0}")]
    ShardNotFound(u32),
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    #[error("Consensus error: {0}")]
    ConsensusError(String),
    #[error("Discovery error: {0}")]
    DiscoveryError(String),
    #[error("Network error: {0}")]
    NetworkError(String),
    #[error("Serialization error: {0}")]
    SerializationError(String),
    // `#[from]` lets `?` convert std::io::Error automatically.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
/// Crate-wide result alias: failures are always `ClusterError`.
pub type Result<T> = std::result::Result<T, ClusterError>;
/// Status of a cluster node
///
/// Raft-style roles plus an explicit `Offline` state assigned by
/// health checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is the cluster leader
    Leader,
    /// Node is a follower (the default for new nodes)
    Follower,
    /// Node is campaigning to be leader
    Candidate,
    /// Node is offline or unreachable
    Offline,
}
/// Information about a cluster node
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterNode {
    /// Unique node identifier
    pub node_id: String,
    /// Network address of the node
    pub address: SocketAddr,
    /// Current status of the node
    pub status: NodeStatus,
    /// Last time the node was seen alive; refreshed by `heartbeat`
    pub last_seen: DateTime<Utc>,
    /// Metadata about the node (free-form key/value pairs)
    pub metadata: HashMap<String, String>,
    /// Node capacity (for load balancing); defaults to 1.0
    pub capacity: f64,
}
impl ClusterNode {
    /// Construct a node record with default status and full capacity.
    pub fn new(node_id: String, address: SocketAddr) -> Self {
        Self {
            node_id,
            address,
            // New members join as followers until told otherwise.
            status: NodeStatus::Follower,
            last_seen: Utc::now(),
            metadata: HashMap::new(),
            capacity: 1.0,
        }
    }

    /// True when the node has been seen within `timeout`.
    pub fn is_healthy(&self, timeout: Duration) -> bool {
        // A negative chrono duration (last_seen in the future, e.g.
        // clock skew) fails `to_std`, maps to Duration::MAX, and is
        // therefore reported unhealthy.
        let elapsed = Utc::now()
            .signed_duration_since(self.last_seen)
            .to_std()
            .unwrap_or(Duration::MAX);
        elapsed < timeout
    }

    /// Record that the node was just seen alive.
    pub fn heartbeat(&mut self) {
        self.last_seen = Utc::now();
    }
}
/// Information about a data shard
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardInfo {
    /// Shard identifier
    pub shard_id: u32,
    /// Primary node responsible for this shard
    pub primary_node: String,
    /// Replica nodes for this shard (replication_factor - 1 entries)
    pub replica_nodes: Vec<String>,
    /// Number of vectors in this shard
    pub vector_count: usize,
    /// Shard status
    pub status: ShardStatus,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last modified timestamp; updated on rebalance
    pub modified_at: DateTime<Utc>,
}
/// Status of a shard
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardStatus {
    /// Shard is active and serving requests
    Active,
    /// Shard is being migrated
    Migrating,
    /// Shard is being replicated
    Replicating,
    /// Shard is offline
    Offline,
}
/// Cluster configuration
///
/// See the `Default` impl for the out-of-the-box values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
    /// Number of replica copies for each shard
    pub replication_factor: usize,
    /// Total number of shards in the cluster
    pub shard_count: u32,
    /// Interval between heartbeat checks
    pub heartbeat_interval: Duration,
    /// Timeout before considering a node offline
    pub node_timeout: Duration,
    /// Enable DAG-based consensus
    pub enable_consensus: bool,
    /// Minimum nodes required for quorum
    pub min_quorum_size: usize,
}
impl Default for ClusterConfig {
    /// Defaults: 64 shards replicated 3x, 5s heartbeats, 30s node
    /// timeout, consensus enabled with a quorum of 2.
    fn default() -> Self {
        let heartbeat_interval = Duration::from_secs(5);
        let node_timeout = Duration::from_secs(30);
        Self {
            shard_count: 64,
            replication_factor: 3,
            heartbeat_interval,
            node_timeout,
            enable_consensus: true,
            min_quorum_size: 2,
        }
    }
}
/// Manages a distributed cluster of vector database nodes
pub struct ClusterManager {
    /// Cluster configuration
    config: ClusterConfig,
    /// Map of node_id to ClusterNode
    nodes: Arc<DashMap<String, ClusterNode>>,
    /// Map of shard_id to ShardInfo
    shards: Arc<DashMap<u32, ShardInfo>>,
    /// Consistent hash ring for shard assignment
    hash_ring: Arc<RwLock<ConsistentHashRing>>,
    /// Shard router for query routing
    router: Arc<ShardRouter>,
    /// DAG-based consensus engine; `None` when disabled in config
    consensus: Option<Arc<DagConsensus>>,
    /// Discovery service (boxed for type erasure)
    discovery: Box<dyn DiscoveryService>,
    /// Current node ID
    node_id: String,
}
impl ClusterManager {
    /// Create a new cluster manager
    ///
    /// Builds the hash ring, router, and (if enabled in `config`) the
    /// DAG consensus engine. Does not contact the network; call
    /// `start` for that.
    pub fn new(
        config: ClusterConfig,
        node_id: String,
        discovery: Box<dyn DiscoveryService>,
    ) -> Result<Self> {
        let nodes = Arc::new(DashMap::new());
        let shards = Arc::new(DashMap::new());
        let hash_ring = Arc::new(RwLock::new(ConsistentHashRing::new(
            config.replication_factor,
        )));
        let router = Arc::new(ShardRouter::new(config.shard_count));
        let consensus = if config.enable_consensus {
            Some(Arc::new(DagConsensus::new(
                node_id.clone(),
                config.min_quorum_size,
            )))
        } else {
            None
        };
        Ok(Self {
            config,
            nodes,
            shards,
            hash_ring,
            router,
            consensus,
            discovery,
            node_id,
        })
    }
    /// Add a node to the cluster
    ///
    /// Inserts the node into the hash ring and node map, then
    /// rebalances all shards to account for the new member.
    pub async fn add_node(&self, node: ClusterNode) -> Result<()> {
        info!("Adding node {} to cluster", node.node_id);
        // Add to hash ring; the scope drops the ring write lock
        // before rebalancing re-acquires it for reading.
        {
            let mut ring = self.hash_ring.write();
            ring.add_node(node.node_id.clone());
        }
        // Store node information
        self.nodes.insert(node.node_id.clone(), node.clone());
        // Rebalance shards if needed
        self.rebalance_shards().await?;
        info!("Node {} successfully added", node.node_id);
        Ok(())
    }
    /// Remove a node from the cluster
    ///
    /// Mirror of `add_node`: drops the node from the ring and map,
    /// then rebalances shards onto the remaining nodes.
    pub async fn remove_node(&self, node_id: &str) -> Result<()> {
        info!("Removing node {} from cluster", node_id);
        // Remove from hash ring (scoped so the write lock is released)
        {
            let mut ring = self.hash_ring.write();
            ring.remove_node(node_id);
        }
        // Remove node information
        self.nodes.remove(node_id);
        // Rebalance shards
        self.rebalance_shards().await?;
        info!("Node {} successfully removed", node_id);
        Ok(())
    }
    /// Get node by ID
    pub fn get_node(&self, node_id: &str) -> Option<ClusterNode> {
        self.nodes.get(node_id).map(|n| n.clone())
    }
    /// List all nodes in the cluster
    pub fn list_nodes(&self) -> Vec<ClusterNode> {
        self.nodes
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }
    /// Get healthy nodes only (seen within `config.node_timeout`)
    pub fn healthy_nodes(&self) -> Vec<ClusterNode> {
        self.nodes
            .iter()
            .filter(|entry| entry.value().is_healthy(self.config.node_timeout))
            .map(|entry| entry.value().clone())
            .collect()
    }
    /// Get shard information
    pub fn get_shard(&self, shard_id: u32) -> Option<ShardInfo> {
        self.shards.get(&shard_id).map(|s| s.clone())
    }
    /// List all shards
    pub fn list_shards(&self) -> Vec<ShardInfo> {
        self.shards
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }
    /// Assign a shard to nodes using consistent hashing
    ///
    /// The first ring node becomes primary, the rest replicas.
    ///
    /// # Errors
    /// Returns `InvalidConfig` when the ring is empty (no nodes).
    pub fn assign_shard(&self, shard_id: u32) -> Result<ShardInfo> {
        let ring = self.hash_ring.read();
        let key = format!("shard:{}", shard_id);
        let nodes = ring.get_nodes(&key, self.config.replication_factor);
        if nodes.is_empty() {
            return Err(ClusterError::InvalidConfig(
                "No nodes available for shard assignment".to_string(),
            ));
        }
        let primary_node = nodes[0].clone();
        let replica_nodes = nodes.into_iter().skip(1).collect();
        let shard_info = ShardInfo {
            shard_id,
            primary_node,
            replica_nodes,
            vector_count: 0,
            status: ShardStatus::Active,
            created_at: Utc::now(),
            modified_at: Utc::now(),
        };
        self.shards.insert(shard_id, shard_info.clone());
        Ok(shard_info)
    }
    /// Rebalance shards across nodes
    ///
    /// Reassigns every shard from the current ring; missing shards
    /// are created via `assign_shard`. Note the DashMap entry guard
    /// is held while the ring read lock is taken — the consistent
    /// acquisition order (shards then ring) avoids deadlock.
    async fn rebalance_shards(&self) -> Result<()> {
        debug!("Rebalancing shards across cluster");
        for shard_id in 0..self.config.shard_count {
            if let Some(mut shard) = self.shards.get_mut(&shard_id) {
                let ring = self.hash_ring.read();
                let key = format!("shard:{}", shard_id);
                let nodes = ring.get_nodes(&key, self.config.replication_factor);
                if !nodes.is_empty() {
                    shard.primary_node = nodes[0].clone();
                    shard.replica_nodes = nodes.into_iter().skip(1).collect();
                    shard.modified_at = Utc::now();
                }
            } else {
                // Create new shard assignment
                self.assign_shard(shard_id)?;
            }
        }
        debug!("Shard rebalancing complete");
        Ok(())
    }
    /// Run periodic health checks
    ///
    /// Collects unhealthy node ids first, then marks them offline in
    /// a second pass so no DashMap guard is held across both steps.
    pub async fn run_health_checks(&self) -> Result<()> {
        debug!("Running health checks");
        let mut unhealthy_nodes = Vec::new();
        for entry in self.nodes.iter() {
            let node = entry.value();
            if !node.is_healthy(self.config.node_timeout) {
                warn!("Node {} is unhealthy", node.node_id);
                unhealthy_nodes.push(node.node_id.clone());
            }
        }
        // Mark unhealthy nodes as offline
        for node_id in unhealthy_nodes {
            if let Some(mut node) = self.nodes.get_mut(&node_id) {
                node.status = NodeStatus::Offline;
            }
        }
        Ok(())
    }
    /// Start the cluster manager (health checks, discovery, etc.)
    ///
    /// Discovers peers (skipping ourselves), adds each — which
    /// rebalances per node — then assigns all shards once more.
    pub async fn start(&self) -> Result<()> {
        info!("Starting cluster manager for node {}", self.node_id);
        // Start discovery service
        let discovered = self.discovery.discover_nodes().await?;
        for node in discovered {
            if node.node_id != self.node_id {
                self.add_node(node).await?;
            }
        }
        // Initialize shards
        for shard_id in 0..self.config.shard_count {
            self.assign_shard(shard_id)?;
        }
        info!("Cluster manager started successfully");
        Ok(())
    }
    /// Get cluster statistics
    pub fn get_stats(&self) -> ClusterStats {
        let nodes = self.list_nodes();
        let shards = self.list_shards();
        let healthy = self.healthy_nodes();
        ClusterStats {
            total_nodes: nodes.len(),
            healthy_nodes: healthy.len(),
            total_shards: shards.len(),
            active_shards: shards
                .iter()
                .filter(|s| s.status == ShardStatus::Active)
                .count(),
            total_vectors: shards.iter().map(|s| s.vector_count).sum(),
        }
    }
    /// Get the shard router
    pub fn router(&self) -> Arc<ShardRouter> {
        Arc::clone(&self.router)
    }
    /// Get the consensus engine (`None` when consensus is disabled)
    pub fn consensus(&self) -> Option<Arc<DagConsensus>> {
        self.consensus.as_ref().map(Arc::clone)
    }
}
/// Cluster statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    // All registered nodes
    pub total_nodes: usize,
    // Nodes seen within the configured timeout
    pub healthy_nodes: usize,
    // All shards, regardless of status
    pub total_shards: usize,
    // Shards currently in the Active state
    pub active_shards: usize,
    // Sum of vector counts across all shards
    pub total_vectors: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Helper: a follower node on 127.0.0.1 at the given port.
    fn create_test_node(id: &str, port: u16) -> ClusterNode {
        ClusterNode::new(
            id.to_string(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
        )
    }

    /// A new node defaults to Follower and is initially healthy.
    #[tokio::test]
    async fn test_cluster_node_creation() {
        let node = create_test_node("node1", 8000);
        assert_eq!(node.node_id, "node1");
        assert_eq!(node.status, NodeStatus::Follower);
        assert!(node.is_healthy(Duration::from_secs(60)));
    }

    /// Construction with an empty static discovery succeeds.
    #[tokio::test]
    async fn test_cluster_manager_creation() {
        let config = ClusterConfig::default();
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery);
        assert!(manager.is_ok());
    }

    /// Adding then removing a node round-trips the node list.
    #[tokio::test]
    async fn test_add_remove_node() {
        let config = ClusterConfig::default();
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery).unwrap();
        let node = create_test_node("node1", 8000);
        manager.add_node(node).await.unwrap();
        assert_eq!(manager.list_nodes().len(), 1);
        manager.remove_node("node1").await.unwrap();
        assert_eq!(manager.list_nodes().len(), 0);
    }

    /// With nodes present, shard 0 gets a non-empty primary.
    #[tokio::test]
    async fn test_shard_assignment() {
        let config = ClusterConfig {
            shard_count: 4,
            replication_factor: 2,
            ..Default::default()
        };
        let discovery = Box::new(StaticDiscovery::new(vec![]));
        let manager = ClusterManager::new(config, "test-node".to_string(), discovery).unwrap();
        // Add some nodes
        for i in 0..3 {
            let node = create_test_node(&format!("node{}", i), 8000 + i);
            manager.add_node(node).await.unwrap();
        }
        // Assign a shard
        let shard = manager.assign_shard(0).unwrap();
        assert_eq!(shard.shard_id, 0);
        assert!(!shard.primary_node.is_empty());
    }
}

// NOTE(review): diff-viewer artifact left over from vendoring — the
// original "View File" header and hunk marker "@@ -0,0 +1,443 @@" mark
// the point where a separate 443-line sharding source file begins.
//! Sharding logic for distributed vector storage
//!
//! Implements consistent hashing for shard distribution and routing.
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tracing::debug;
/// Number of virtual points each physical node contributes to the hash
/// ring; more virtual nodes give a smoother key distribution at the cost
/// of ring size (node removal scans the whole ring).
const VIRTUAL_NODE_COUNT: usize = 150;
/// Consistent hash ring for node assignment
///
/// Each physical node is expanded into `VIRTUAL_NODE_COUNT` virtual points
/// on a 64-bit ring so keys distribute evenly across nodes.
#[derive(Debug)]
pub struct ConsistentHashRing {
    /// Virtual nodes on the ring (hash -> node_id)
    ring: BTreeMap<u64, String>,
    /// Real nodes in the cluster (node_id -> number of virtual nodes placed)
    nodes: HashMap<String, usize>,
    /// Replication factor
    // NOTE(review): stored but never read by any method in this file;
    // callers presumably pass the desired replica count to `get_nodes`
    // as `count` — confirm intended use.
    replication_factor: usize,
}
impl ConsistentHashRing {
    /// Create an empty ring.
    ///
    /// `replication_factor` is stored for callers; the ring itself does
    /// not apply it automatically (see `get_nodes`).
    pub fn new(replication_factor: usize) -> Self {
        Self {
            ring: BTreeMap::new(),
            nodes: HashMap::new(),
            replication_factor,
        }
    }

    /// Add a node to the ring. Idempotent: re-adding an existing node is
    /// a no-op.
    pub fn add_node(&mut self, node_id: String) {
        if self.nodes.contains_key(&node_id) {
            return;
        }
        // Spread the node across many virtual points for better key
        // distribution and minimal movement on membership change.
        for i in 0..VIRTUAL_NODE_COUNT {
            let virtual_key = format!("{}:{}", node_id, i);
            let hash = Self::hash_key(&virtual_key);
            self.ring.insert(hash, node_id.clone());
        }
        self.nodes.insert(node_id, VIRTUAL_NODE_COUNT);
        debug!(
            "Added node to hash ring with {} virtual nodes",
            VIRTUAL_NODE_COUNT
        );
    }

    /// Remove a node and all of its virtual points. Idempotent: removing
    /// an unknown node is a no-op.
    pub fn remove_node(&mut self, node_id: &str) {
        if !self.nodes.contains_key(node_id) {
            return;
        }
        // Drop every virtual point owned by this node (O(ring size)).
        self.ring.retain(|_, v| v != node_id);
        self.nodes.remove(node_id);
        debug!("Removed node from hash ring");
    }

    /// Get up to `count` distinct nodes responsible for `key`, in
    /// clockwise ring order starting at the key's hash.
    ///
    /// Returns an empty vector when the ring is empty or `count` is zero.
    /// (Previously `count == 0` incorrectly returned one node, because
    /// the length check ran only after the first push.)
    pub fn get_nodes(&self, key: &str, count: usize) -> Vec<String> {
        if self.ring.is_empty() || count == 0 {
            return Vec::new();
        }
        let hash = Self::hash_key(key);
        let mut nodes = Vec::new();
        let mut seen = std::collections::HashSet::new();
        // Walk from the key's position to the end of the ring, then wrap
        // around to the beginning; `seen` collapses virtual nodes that
        // belong to the same physical node.
        for (_, node_id) in self.ring.range(hash..).chain(self.ring.iter()) {
            if seen.insert(node_id.clone()) {
                nodes.push(node_id.clone());
                if nodes.len() >= count {
                    break;
                }
            }
        }
        nodes
    }

    /// Get the primary node for a key, if the ring is non-empty.
    pub fn get_primary_node(&self, key: &str) -> Option<String> {
        // `into_iter().next()` moves the single element out without the
        // extra clone that `first().cloned()` would do.
        self.get_nodes(key, 1).into_iter().next()
    }

    /// Hash a key to a u64 ring position.
    ///
    /// Uses `DefaultHasher`, so positions are stable within a process but
    /// not guaranteed stable across Rust releases.
    fn hash_key(key: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Get the number of real (physical) nodes.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// List all real (physical) node IDs, in arbitrary order.
    pub fn list_nodes(&self) -> Vec<String> {
        self.nodes.keys().cloned().collect()
    }
}
/// Routes queries to the correct shard
///
/// Uses jump consistent hashing, so changing the shard count relocates
/// only a minimal fraction of keys.
pub struct ShardRouter {
    /// Total number of shards
    shard_count: u32,
    /// Shard assignment cache (key -> shard id); grows with every
    /// distinct key routed — see `clear_cache`
    cache: Arc<RwLock<HashMap<String, u32>>>,
}
impl ShardRouter {
    /// Create a new router distributing keys over `shard_count` shards.
    pub fn new(shard_count: u32) -> Self {
        Self {
            shard_count,
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Get the shard ID for a key using jump consistent hashing.
    ///
    /// Assignments are memoized per key.
    // NOTE(review): the cache grows without bound as distinct keys are
    // routed; callers with large key spaces should call `clear_cache`
    // periodically or replace this with an LRU.
    pub fn get_shard(&self, key: &str) -> u32 {
        // Fast path: previously computed assignment.
        {
            let cache = self.cache.read();
            if let Some(&shard_id) = cache.get(key) {
                return shard_id;
            }
        }
        // Slow path: compute, then memoize under the write lock.
        let shard_id = self.jump_consistent_hash(key, self.shard_count);
        {
            let mut cache = self.cache.write();
            cache.insert(key.to_string(), shard_id);
        }
        shard_id
    }

    /// Jump consistent hash (Lamping & Veach, "A Fast, Minimal Memory,
    /// Consistent Hash Algorithm").
    ///
    /// Growing the bucket count from n to n+1 moves only ~1/(n+1) of
    /// keys. Returns 0 in the degenerate `num_buckets == 0` case —
    /// previously the loop never ran, `b` stayed at -1, and `b as u32`
    /// wrapped to `u32::MAX`.
    fn jump_consistent_hash(&self, key: &str, num_buckets: u32) -> u32 {
        use std::collections::hash_map::DefaultHasher;
        if num_buckets == 0 {
            return 0;
        }
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let mut hash = hasher.finish();
        let mut b: i64 = -1;
        let mut j: i64 = 0;
        while j < num_buckets as i64 {
            b = j;
            // Linear congruential step from the reference algorithm.
            hash = hash.wrapping_mul(2862933555777941757).wrapping_add(1);
            j = ((b.wrapping_add(1) as f64)
                * ((1i64 << 31) as f64 / ((hash >> 33).wrapping_add(1) as f64)))
                as i64;
        }
        b as u32
    }

    /// Get the shard ID for a vector ID (delegates to `get_shard`).
    pub fn get_shard_for_vector(&self, vector_id: &str) -> u32 {
        self.get_shard(vector_id)
    }

    /// Get shard IDs for a range query (may span multiple shards).
    ///
    /// Hash-based routing cannot map a key range to a shard subset, so
    /// this conservatively returns every shard.
    pub fn get_shards_for_range(&self, _start: &str, _end: &str) -> Vec<u32> {
        (0..self.shard_count).collect()
    }

    /// Clear the routing cache (e.g. after a topology change).
    pub fn clear_cache(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }

    /// Get cache statistics.
    pub fn cache_stats(&self) -> CacheStats {
        let cache = self.cache.read();
        CacheStats {
            entries: cache.len(),
            shard_count: self.shard_count as usize,
        }
    }
}
/// Cache statistics
///
/// Snapshot returned by `ShardRouter::cache_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of cached key -> shard assignments
    pub entries: usize,
    /// Total shard count of the owning router
    pub shard_count: usize,
}
/// Shard migration manager
///
/// Tracks progress of moving keys from one shard to another.
pub struct ShardMigration {
    /// Source shard ID
    pub source_shard: u32,
    /// Target shard ID
    pub target_shard: u32,
    /// Migration progress, clamped to the documented range 0.0 to 1.0
    pub progress: f64,
    /// Keys migrated so far (as reported by `update_progress`)
    pub keys_migrated: usize,
    /// Total keys to migrate
    pub total_keys: usize,
}
impl ShardMigration {
    /// Create a new shard migration of `total_keys` keys.
    ///
    /// A migration with zero total keys is considered complete
    /// immediately (there is nothing to move).
    pub fn new(source_shard: u32, target_shard: u32, total_keys: usize) -> Self {
        Self {
            source_shard,
            target_shard,
            progress: 0.0,
            keys_migrated: 0,
            total_keys,
        }
    }
    /// Update migration progress with the cumulative number of migrated
    /// keys.
    ///
    /// `progress` is clamped to 1.0 even if `keys_migrated` overshoots
    /// `total_keys` (previously an overshoot produced a ratio > 1.0,
    /// violating the field's documented 0.0–1.0 range).
    pub fn update_progress(&mut self, keys_migrated: usize) {
        self.keys_migrated = keys_migrated;
        self.progress = if self.total_keys > 0 {
            (keys_migrated as f64 / self.total_keys as f64).min(1.0)
        } else {
            // Nothing to migrate: trivially done.
            1.0
        };
    }
    /// Check if migration is complete.
    pub fn is_complete(&self) -> bool {
        self.progress >= 1.0 || self.keys_migrated >= self.total_keys
    }
}
/// Load balancer for shard distribution
///
/// Wraps a shared, lock-protected map of per-shard load figures.
pub struct LoadBalancer {
    /// Shard load statistics (shard_id -> load)
    // NOTE(review): load is a unit-less gauge supplied by callers of
    // `update_load`; its scale is not defined in this file — confirm.
    loads: Arc<RwLock<HashMap<u32, f64>>>,
}
impl LoadBalancer {
    /// Construct a balancer with no recorded loads.
    pub fn new() -> Self {
        Self {
            loads: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Record the current load figure for a shard, replacing any prior
    /// value.
    pub fn update_load(&self, shard_id: u32, load: f64) {
        self.loads.write().insert(shard_id, load);
    }
    /// Look up a shard's recorded load; unknown shards report 0.0.
    pub fn get_load(&self, shard_id: u32) -> f64 {
        self.loads.read().get(&shard_id).copied().unwrap_or(0.0)
    }
    /// Among the given candidates, pick the shard with the smallest
    /// recorded load (unknown shards count as 0.0). Returns `None` for
    /// an empty candidate list; ties keep the earliest candidate.
    pub fn get_least_loaded_shard(&self, shard_ids: &[u32]) -> Option<u32> {
        let table = self.loads.read();
        let load_of = |id: u32| table.get(&id).copied().unwrap_or(0.0);
        shard_ids.iter().copied().min_by(|&a, &b| {
            load_of(a)
                .partial_cmp(&load_of(b))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }
    /// Aggregate statistics across all recorded shard loads; min/max
    /// report 0.0 when nothing has been recorded yet.
    pub fn get_stats(&self) -> LoadStats {
        let table = self.loads.read();
        let count = table.len();
        let total: f64 = table.values().sum();
        let max = table.values().copied().fold(f64::NEG_INFINITY, f64::max);
        let min = table.values().copied().fold(f64::INFINITY, f64::min);
        LoadStats {
            total_load: total,
            avg_load: if count > 0 { total / count as f64 } else { 0.0 },
            max_load: if max.is_finite() { max } else { 0.0 },
            min_load: if min.is_finite() { min } else { 0.0 },
            shard_count: count,
        }
    }
}
impl Default for LoadBalancer {
fn default() -> Self {
Self::new()
}
}
/// Load statistics
///
/// Aggregates produced by `LoadBalancer::get_stats`; `max_load` and
/// `min_load` are 0.0 when no loads have been recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadStats {
    /// Sum of all recorded shard loads
    pub total_load: f64,
    /// Mean recorded load (0.0 when no loads recorded)
    pub avg_load: f64,
    /// Largest recorded load (0.0 when no loads recorded)
    pub max_load: f64,
    /// Smallest recorded load (0.0 when no loads recorded)
    pub min_load: f64,
    /// Number of shards with a recorded load
    pub shard_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    /// With several nodes on the ring, lookups return the requested
    /// number of distinct nodes and a primary can be selected.
    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(3);
        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());
        assert_eq!(ring.node_count(), 3);
        let nodes = ring.get_nodes("test-key", 3);
        assert_eq!(nodes.len(), 3);
        // Test primary node selection
        let primary = ring.get_primary_node("test-key");
        assert!(primary.is_some());
    }
    /// Virtual nodes should spread 1000 keys roughly evenly across three
    /// physical nodes.
    #[test]
    fn test_consistent_hashing_distribution() {
        let mut ring = ConsistentHashRing::new(3);
        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());
        let mut distribution: HashMap<String, usize> = HashMap::new();
        // Test distribution across many keys
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get_primary_node(&key) {
                *distribution.entry(node).or_insert(0) += 1;
            }
        }
        // Each node should get roughly 1/3 of the keys (within 20% tolerance)
        for count in distribution.values() {
            let ratio = *count as f64 / 1000.0;
            assert!(ratio > 0.2 && ratio < 0.5, "Distribution ratio: {}", ratio);
        }
    }
    /// Routing is deterministic, stays in range, and populates the cache.
    #[test]
    fn test_shard_router() {
        let router = ShardRouter::new(16);
        let shard1 = router.get_shard("test-key-1");
        let shard2 = router.get_shard("test-key-1"); // Should be cached
        assert_eq!(shard1, shard2);
        assert!(shard1 < 16);
        let stats = router.cache_stats();
        assert_eq!(stats.entries, 1);
    }
    /// Jump consistent hashing must be stable for a fixed key.
    #[test]
    fn test_jump_consistent_hash() {
        let router = ShardRouter::new(10);
        // Same key should always map to same shard
        let shard1 = router.get_shard("consistent-key");
        let shard2 = router.get_shard("consistent-key");
        assert_eq!(shard1, shard2);
    }
    /// Progress tracking moves from 0.0 through 0.5 to completion.
    #[test]
    fn test_shard_migration() {
        let mut migration = ShardMigration::new(0, 1, 100);
        assert!(!migration.is_complete());
        assert_eq!(migration.progress, 0.0);
        migration.update_progress(50);
        assert_eq!(migration.progress, 0.5);
        migration.update_progress(100);
        assert!(migration.is_complete());
    }
    /// The least-loaded shard is selected and stats aggregate correctly.
    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new();
        balancer.update_load(0, 0.5);
        balancer.update_load(1, 0.8);
        balancer.update_load(2, 0.3);
        let least_loaded = balancer.get_least_loaded_shard(&[0, 1, 2]);
        assert_eq!(least_loaded, Some(2));
        let stats = balancer.get_stats();
        assert_eq!(stats.shard_count, 3);
        assert!(stats.avg_load > 0.0);
    }
}