Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,359 @@
//! Leader election implementation
//!
//! Implements the Raft leader election algorithm including:
//! - Randomized election timeouts
//! - Vote request handling
//! - Term management
//! - Split vote prevention
use crate::{NodeId, Term};
use rand::Rng;
use std::time::Duration;
use tokio::time::Instant;
/// Election timer with randomized timeout
///
/// Tracks how long it has been since this node last heard from a leader
/// (or granted a vote) and reports when an election should be started.
/// The timeout is re-drawn uniformly at random on every reset to avoid
/// split votes, as recommended by the Raft paper (section 5.2).
#[derive(Debug)]
pub struct ElectionTimer {
    /// Last time the timer was reset
    last_reset: Instant,
    /// Current timeout duration (re-drawn on every reset)
    timeout: Duration,
    /// Minimum election timeout (milliseconds)
    min_timeout_ms: u64,
    /// Maximum election timeout (milliseconds)
    max_timeout_ms: u64,
}

impl ElectionTimer {
    /// Create a new election timer with a timeout drawn uniformly from
    /// `[min_timeout_ms, max_timeout_ms]` (inclusive on both ends).
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if `min_timeout_ms` exceeds
    /// `max_timeout_ms`. Previously an invalid window surfaced only as an
    /// opaque "empty range" panic from inside the random number generator.
    pub fn new(min_timeout_ms: u64, max_timeout_ms: u64) -> Self {
        assert!(
            min_timeout_ms <= max_timeout_ms,
            "invalid election timeout window: min {}ms > max {}ms",
            min_timeout_ms,
            max_timeout_ms
        );
        let timeout = Self::random_timeout(min_timeout_ms, max_timeout_ms);
        Self {
            last_reset: Instant::now(),
            timeout,
            min_timeout_ms,
            max_timeout_ms,
        }
    }

    /// Create with default timeouts (150-300ms as per Raft paper)
    pub fn with_defaults() -> Self {
        Self::new(150, 300)
    }

    /// Reset the election timer with a new random timeout
    pub fn reset(&mut self) {
        self.last_reset = Instant::now();
        self.timeout = Self::random_timeout(self.min_timeout_ms, self.max_timeout_ms);
    }

    /// Check if the election timeout has elapsed
    pub fn is_elapsed(&self) -> bool {
        self.last_reset.elapsed() >= self.timeout
    }

    /// Get time remaining until timeout
    ///
    /// Saturates at zero once the deadline has passed.
    pub fn time_remaining(&self) -> Duration {
        self.timeout.saturating_sub(self.last_reset.elapsed())
    }

    /// Generate a random timeout duration in `[min_ms, max_ms]`.
    ///
    /// The inclusive range means `min_ms == max_ms` yields a fixed timeout.
    fn random_timeout(min_ms: u64, max_ms: u64) -> Duration {
        let mut rng = rand::thread_rng();
        let timeout_ms = rng.gen_range(min_ms..=max_ms);
        Duration::from_millis(timeout_ms)
    }

    /// Get the current timeout duration
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}
/// Vote tracker for an election
///
/// Counts distinct ballots cast in favor of this node and reports when a
/// strict majority (quorum) of the cluster has been reached.
#[derive(Debug)]
pub struct VoteTracker {
    /// Votes received in favor
    votes_received: Vec<NodeId>,
    /// Total number of nodes in the cluster
    cluster_size: usize,
    /// Required number of votes for quorum
    quorum_size: usize,
}

impl VoteTracker {
    /// Build a tracker for a cluster of `cluster_size` nodes.
    ///
    /// Quorum is a strict majority: `floor(cluster_size / 2) + 1`.
    pub fn new(cluster_size: usize) -> Self {
        Self {
            votes_received: Vec::new(),
            cluster_size,
            quorum_size: cluster_size / 2 + 1,
        }
    }

    /// Count a ballot from `node_id`; duplicate ballots are ignored.
    pub fn record_vote(&mut self, node_id: NodeId) {
        let already_counted = self.votes_received.iter().any(|id| *id == node_id);
        if !already_counted {
            self.votes_received.push(node_id);
        }
    }

    /// True once a strict majority of the cluster has voted for us.
    pub fn has_quorum(&self) -> bool {
        self.quorum_size <= self.votes_received.len()
    }

    /// Number of distinct ballots counted so far.
    pub fn vote_count(&self) -> usize {
        self.votes_received.len()
    }

    /// Number of ballots required to win the election.
    pub fn quorum_size(&self) -> usize {
        self.quorum_size
    }

    /// Discard all counted ballots (used at the start of a new election).
    pub fn reset(&mut self) {
        self.votes_received.clear();
    }
}
/// Election state machine
///
/// Bundles the pieces a node needs while contesting an election: the
/// randomized timer, the running vote tally, and the term in play.
#[derive(Debug)]
pub struct ElectionState {
    /// Current election timer
    pub timer: ElectionTimer,
    /// Vote tracker for current election
    pub votes: VoteTracker,
    /// Current term being contested
    pub current_term: Term,
}

impl ElectionState {
    /// Build the election machinery for a cluster of the given size with
    /// the given timeout window (milliseconds), starting at term 0.
    pub fn new(cluster_size: usize, min_timeout_ms: u64, max_timeout_ms: u64) -> Self {
        let timer = ElectionTimer::new(min_timeout_ms, max_timeout_ms);
        let votes = VoteTracker::new(cluster_size);
        Self {
            timer,
            votes,
            current_term: 0,
        }
    }

    /// Begin contesting `term`: restart the randomized timeout, clear the
    /// previous tally, and cast our own ballot.
    pub fn start_election(&mut self, term: Term, self_id: &NodeId) {
        self.timer.reset();
        self.current_term = term;
        self.votes.reset();
        self.votes.record_vote(self_id.clone());
    }

    /// Restart the timer, e.g. after receiving a valid heartbeat.
    pub fn reset_timer(&mut self) {
        self.timer.reset();
    }

    /// True when the timeout has expired and a new election is due.
    pub fn should_start_election(&self) -> bool {
        self.timer.is_elapsed()
    }

    /// Tally a ballot from `node_id`; returns true once quorum is reached.
    pub fn record_vote(&mut self, node_id: NodeId) -> bool {
        self.votes.record_vote(node_id);
        self.votes.has_quorum()
    }

    /// Replace the vote tracker to reflect a new cluster size.
    ///
    /// Note: this discards any in-progress tally.
    pub fn update_cluster_size(&mut self, cluster_size: usize) {
        self.votes = VoteTracker::new(cluster_size);
    }
}
/// Vote request validation
///
/// Stateless helper implementing the voting rules from sections 5.2 and
/// 5.4.1 of the Raft paper.
pub struct VoteValidator;

impl VoteValidator {
    /// Validate if a vote request should be granted
    ///
    /// A vote should be granted if:
    /// 1. The candidate's term is at least as current as receiver's term
    /// 2. The receiver hasn't voted in this term, or has voted for this candidate
    /// 3. The candidate's log is at least as up-to-date as receiver's log
    ///
    /// A vote recorded in an *older* term does not bind the receiver once
    /// a candidate with a strictly higher term appears: per the Raft paper
    /// the receiver adopts the higher term (which clears `votedFor`) and
    /// may then vote again. This function therefore ignores
    /// `receiver_voted_for` when `candidate_term > receiver_term`, so
    /// callers that have not yet persisted the term bump still reach the
    /// correct decision.
    #[allow(clippy::too_many_arguments)]
    pub fn should_grant_vote(
        receiver_term: Term,
        receiver_voted_for: &Option<NodeId>,
        receiver_last_log_index: u64,
        receiver_last_log_term: Term,
        candidate_id: &NodeId,
        candidate_term: Term,
        candidate_last_log_index: u64,
        candidate_last_log_term: Term,
    ) -> bool {
        // Rule 1: reject if the candidate's term is older than ours.
        if candidate_term < receiver_term {
            return false;
        }
        // Rule 2: a vote cast in the current term binds us only for
        // requests in that same term; a strictly newer term resets it.
        let can_vote = candidate_term > receiver_term
            || match receiver_voted_for {
                None => true,
                Some(voted_for) => voted_for == candidate_id,
            };
        if !can_vote {
            return false;
        }
        // Rule 3: candidate's log must be at least as up-to-date as ours.
        Self::is_log_up_to_date(
            candidate_last_log_term,
            candidate_last_log_index,
            receiver_last_log_term,
            receiver_last_log_index,
        )
    }

    /// Check if candidate's log is at least as up-to-date as receiver's
    ///
    /// Raft determines which of two logs is more up-to-date by comparing
    /// the index and term of the last entries in the logs. If the logs have
    /// last entries with different terms, then the log with the later term
    /// is more up-to-date. If the logs end with the same term, then whichever
    /// log is longer is more up-to-date.
    fn is_log_up_to_date(
        candidate_last_term: Term,
        candidate_last_index: u64,
        receiver_last_term: Term,
        receiver_last_index: u64,
    ) -> bool {
        if candidate_last_term != receiver_last_term {
            candidate_last_term >= receiver_last_term
        } else {
            candidate_last_index >= receiver_last_index
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;
    // The timer (50-100ms window here) must not fire immediately, must
    // fire once the maximum has passed, and must rearm on reset().
    #[test]
    fn test_election_timer() {
        let mut timer = ElectionTimer::new(50, 100);
        assert!(!timer.is_elapsed());
        // 150ms exceeds the 100ms maximum, so the timer has definitely fired.
        sleep(Duration::from_millis(150));
        assert!(timer.is_elapsed());
        timer.reset();
        assert!(!timer.is_elapsed());
    }
    // In a 5-node cluster quorum is 3; votes accumulate one node at a time.
    #[test]
    fn test_vote_tracker() {
        let mut tracker = VoteTracker::new(5);
        assert_eq!(tracker.quorum_size(), 3);
        assert!(!tracker.has_quorum());
        tracker.record_vote("node1".to_string());
        assert!(!tracker.has_quorum());
        tracker.record_vote("node2".to_string());
        assert!(!tracker.has_quorum());
        tracker.record_vote("node3".to_string());
        assert!(tracker.has_quorum());
    }
    // Starting an election casts our own vote; two more reach quorum of 3.
    #[test]
    fn test_election_state() {
        let mut state = ElectionState::new(5, 50, 100);
        let self_id = "node1".to_string();
        state.start_election(1, &self_id);
        assert_eq!(state.current_term, 1);
        assert_eq!(state.votes.vote_count(), 1);
        let won = state.record_vote("node2".to_string());
        assert!(!won);
        let won = state.record_vote("node3".to_string());
        assert!(won);
    }
    // Exercises the three voting rules: term ordering, votedFor binding,
    // and re-granting to the same candidate.
    #[test]
    fn test_vote_validation() {
        // Should grant vote when candidate is up-to-date
        assert!(VoteValidator::should_grant_vote(
            1,
            &None,
            10,
            1,
            &"candidate".to_string(),
            2,
            10,
            1
        ));
        // Should reject when candidate term is older
        assert!(!VoteValidator::should_grant_vote(
            2,
            &None,
            10,
            1,
            &"candidate".to_string(),
            1,
            10,
            1
        ));
        // Should reject when already voted for someone else
        assert!(!VoteValidator::should_grant_vote(
            1,
            &Some("other".to_string()),
            10,
            1,
            &"candidate".to_string(),
            1,
            10,
            1
        ));
        // Should grant when voted for same candidate
        assert!(VoteValidator::should_grant_vote(
            1,
            &Some("candidate".to_string()),
            10,
            1,
            &"candidate".to_string(),
            1,
            10,
            1
        ));
    }
    // Up-to-date comparison: term dominates, then log length; ties pass.
    #[test]
    fn test_log_up_to_date() {
        // Higher term is more up-to-date
        assert!(VoteValidator::is_log_up_to_date(2, 5, 1, 10));
        assert!(!VoteValidator::is_log_up_to_date(1, 10, 2, 5));
        // Same term, longer log is more up-to-date
        assert!(VoteValidator::is_log_up_to_date(1, 10, 1, 5));
        assert!(!VoteValidator::is_log_up_to_date(1, 5, 1, 10));
        // Same term and length is up-to-date
        assert!(VoteValidator::is_log_up_to_date(1, 10, 1, 10));
    }
}

View File

@@ -0,0 +1,72 @@
//! Raft consensus implementation for ruvector distributed metadata
//!
//! This crate provides a production-ready Raft consensus implementation
//! following the Raft paper specification for managing distributed metadata
//! in the ruvector vector database.
pub mod election;
pub mod log;
pub mod node;
pub mod rpc;
pub mod state;
pub use node::{RaftNode, RaftNodeConfig};
pub use rpc::{
AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
RequestVoteRequest, RequestVoteResponse,
};
pub use state::{LeaderState, PersistentState, RaftState, VolatileState};
use thiserror::Error;
/// Result type for Raft operations
///
/// Convenience alias so APIs can write `RaftResult<T>` instead of
/// `Result<T, RaftError>`.
pub type RaftResult<T> = Result<T, RaftError>;
/// Errors that can occur during Raft operations
#[derive(Debug, Error)]
pub enum RaftError {
    /// A request that only the leader can serve reached a non-leader node.
    #[error("Node is not the leader")]
    NotLeader,
    /// No node is currently known to be the leader.
    #[error("No leader available")]
    NoLeader,
    /// A term number failed validation.
    #[error("Invalid term: {0}")]
    InvalidTerm(u64),
    /// A log index was out of range (e.g. already compacted into a snapshot).
    #[error("Invalid log index: {0}")]
    InvalidLogIndex(u64),
    /// bincode failed to encode a value.
    #[error("Serialization error: {0}")]
    SerializationEncodeError(#[from] bincode::error::EncodeError),
    /// bincode failed to decode bytes.
    #[error("Deserialization error: {0}")]
    SerializationDecodeError(#[from] bincode::error::DecodeError),
    /// Underlying I/O failure.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// An election timed out without producing a leader.
    #[error("Election timeout")]
    ElectionTimeout,
    /// The log failed a consistency check (gap or conflicting entry).
    #[error("Log inconsistency detected")]
    LogInconsistency,
    /// Snapshot installation failed for the given reason.
    #[error("Snapshot installation failed: {0}")]
    SnapshotFailed(String),
    /// Invalid or inconsistent configuration.
    #[error("Configuration error: {0}")]
    ConfigError(String),
    /// Catch-all for unexpected internal failures.
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Node identifier type
///
/// Nodes are addressed by opaque strings; the crate does not interpret them.
pub type NodeId = String;
/// Term number in Raft consensus
///
/// Monotonically increasing logical clock (Raft paper, section 5.1).
pub type Term = u64;
/// Log index in Raft log
///
/// 1-based position of an entry; 0 denotes "no entry yet".
pub type LogIndex = u64;

View File

@@ -0,0 +1,350 @@
//! Raft log implementation
//!
//! Manages the replicated log with support for:
//! - Appending entries
//! - Truncation and conflict resolution
//! - Snapshots and compaction
//! - Persistence
use crate::{LogIndex, RaftError, RaftResult, Term};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// A single entry in the Raft log
///
/// Pairs a state-machine command with the term and position at which the
/// leader recorded it; equality compares all three fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LogEntry {
    /// Term when entry was received by leader
    pub term: Term,
    /// Index position in the log
    pub index: LogIndex,
    /// State machine command
    pub command: Vec<u8>,
}

impl LogEntry {
    /// Construct an entry from its term, position, and raw command bytes.
    pub fn new(term: Term, index: LogIndex, command: Vec<u8>) -> Self {
        LogEntry { term, index, command }
    }
}
/// Snapshot metadata
///
/// A compacted prefix of the log: everything up to and including
/// `last_included_index` is folded into `data`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snapshot {
    /// Index of last entry in snapshot
    pub last_included_index: LogIndex,
    /// Term of last entry in snapshot
    pub last_included_term: Term,
    /// Snapshot data (opaque serialized state bytes; not interpreted here)
    pub data: Vec<u8>,
    /// Configuration at the time of snapshot
    /// (presumably the cluster member IDs — confirm against callers)
    pub configuration: Vec<String>,
}
/// The Raft replicated log
///
/// In-memory entries cover indexes `base_index + 1 ..= last_index`;
/// anything at or below `base_index` has been compacted into `snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaftLog {
    /// Log entries (index starts at 1)
    entries: VecDeque<LogEntry>,
    /// Current snapshot (if any)
    snapshot: Option<Snapshot>,
    /// Base index from snapshot (0 if no snapshot)
    base_index: LogIndex,
    /// Base term from snapshot (0 if no snapshot)
    base_term: Term,
}
impl RaftLog {
    /// Create a new empty log with no snapshot.
    pub fn new() -> Self {
        Self {
            entries: VecDeque::new(),
            snapshot: None,
            base_index: 0,
            base_term: 0,
        }
    }

    /// Get the index of the last log entry.
    ///
    /// Falls back to the snapshot's base index when the in-memory log is
    /// empty, so the value stays monotonic across compaction (0 for a
    /// brand new log).
    pub fn last_index(&self) -> LogIndex {
        if let Some(entry) = self.entries.back() {
            entry.index
        } else {
            self.base_index
        }
    }

    /// Get the term of the last log entry (or the snapshot's base term).
    pub fn last_term(&self) -> Term {
        if let Some(entry) = self.entries.back() {
            entry.term
        } else {
            self.base_term
        }
    }

    /// Get the term at a specific index.
    ///
    /// Returns `None` for indexes compacted away by a snapshot (strictly
    /// below `base_index`) or beyond the end of the log.
    pub fn term_at(&self, index: LogIndex) -> Option<Term> {
        if index == self.base_index {
            return Some(self.base_term);
        }
        if index < self.base_index {
            return None;
        }
        // In-memory entries are contiguous, starting at base_index + 1.
        let offset = (index - self.base_index - 1) as usize;
        self.entries.get(offset).map(|entry| entry.term)
    }

    /// Get a log entry at a specific index.
    ///
    /// Entries at or below `base_index` live only inside the snapshot and
    /// cannot be retrieved individually, so `None` is returned for them.
    pub fn get(&self, index: LogIndex) -> Option<&LogEntry> {
        if index <= self.base_index {
            return None;
        }
        let offset = (index - self.base_index - 1) as usize;
        self.entries.get(offset)
    }

    /// Get entries starting from an index (inclusive).
    ///
    /// A start index at or below the snapshot base yields every in-memory
    /// entry, since the compacted prefix cannot be reconstructed here.
    pub fn entries_from(&self, start_index: LogIndex) -> Vec<LogEntry> {
        if start_index <= self.base_index {
            return self.entries.iter().cloned().collect();
        }
        let offset = (start_index - self.base_index - 1) as usize;
        self.entries.iter().skip(offset).cloned().collect()
    }

    /// Append a new entry to the log (leader side) and return its index.
    pub fn append(&mut self, term: Term, command: Vec<u8>) -> LogIndex {
        let index = self.last_index() + 1;
        let entry = LogEntry::new(term, index, command);
        self.entries.push_back(entry);
        index
    }

    /// Append multiple entries (for replication).
    ///
    /// Idempotent with respect to retransmissions: an entry whose index is
    /// already present is skipped when its term matches the stored entry
    /// (callers are expected to truncate genuine conflicts beforehand).
    /// Previously any already-present index was rejected outright, which
    /// made a duplicated AppendEntries RPC fail spuriously.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::LogInconsistency`] when an already-present
    /// index carries a different term, or when an entry would leave a gap
    /// after the current last index.
    pub fn append_entries(&mut self, entries: Vec<LogEntry>) -> RaftResult<()> {
        for entry in entries {
            let last = self.last_index();
            if entry.index <= last {
                // Duplicate of an entry we already hold (e.g. a
                // retransmitted AppendEntries). Accept only if terms agree.
                match self.term_at(entry.index) {
                    Some(existing_term) if existing_term == entry.term => continue,
                    _ => return Err(RaftError::LogInconsistency),
                }
            }
            // New entries must extend the log without gaps.
            if entry.index != last + 1 {
                return Err(RaftError::LogInconsistency);
            }
            self.entries.push_back(entry);
        }
        Ok(())
    }

    /// Truncate log from a given index (delete entries >= index).
    ///
    /// Truncating past the current end is a harmless no-op.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::InvalidLogIndex`] when `index` falls inside
    /// the snapshot-compacted prefix, which can no longer be modified.
    pub fn truncate_from(&mut self, index: LogIndex) -> RaftResult<()> {
        if index <= self.base_index {
            return Err(RaftError::InvalidLogIndex(index));
        }
        let offset = (index - self.base_index - 1) as usize;
        self.entries.truncate(offset);
        Ok(())
    }

    /// Check if log contains an entry at index with the given term.
    ///
    /// Index 0 represents "before the first entry" and always matches;
    /// this is the consistency check used by AppendEntries.
    pub fn matches(&self, index: LogIndex, term: Term) -> bool {
        if index == 0 {
            return true;
        }
        if index == self.base_index {
            return term == self.base_term;
        }
        match self.term_at(index) {
            Some(entry_term) => entry_term == term,
            None => false,
        }
    }

    /// Install a snapshot and compact the log.
    ///
    /// Drops every in-memory entry covered by the snapshot and rebases
    /// the log at the snapshot's last included index/term.
    pub fn install_snapshot(&mut self, snapshot: Snapshot) -> RaftResult<()> {
        let last_index = snapshot.last_included_index;
        let last_term = snapshot.last_included_term;
        // Remove all entries up to and including the snapshot's last index
        while let Some(entry) = self.entries.front() {
            if entry.index <= last_index {
                self.entries.pop_front();
            } else {
                break;
            }
        }
        self.base_index = last_index;
        self.base_term = last_term;
        self.snapshot = Some(snapshot);
        Ok(())
    }

    /// Create a snapshot up to the given index and compact the log.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::InvalidLogIndex`] when `up_to_index` is
    /// already covered by an existing snapshot or lies beyond the last
    /// entry (no term can be resolved for it).
    pub fn create_snapshot(
        &mut self,
        up_to_index: LogIndex,
        data: Vec<u8>,
        configuration: Vec<String>,
    ) -> RaftResult<Snapshot> {
        if up_to_index <= self.base_index {
            return Err(RaftError::InvalidLogIndex(up_to_index));
        }
        let term = self
            .term_at(up_to_index)
            .ok_or(RaftError::InvalidLogIndex(up_to_index))?;
        let snapshot = Snapshot {
            last_included_index: up_to_index,
            last_included_term: term,
            data,
            configuration,
        };
        // Compact the log by removing entries before the snapshot
        self.install_snapshot(snapshot.clone())?;
        Ok(snapshot)
    }

    /// Get the current snapshot, if one has been installed.
    pub fn snapshot(&self) -> Option<&Snapshot> {
        self.snapshot.as_ref()
    }

    /// Get the number of entries held in memory (excludes the snapshot).
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Check if the log is empty (no in-memory entries and no snapshot).
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty() && self.base_index == 0
    }

    /// Get the base index from snapshot (0 if no snapshot).
    pub fn base_index(&self) -> LogIndex {
        self.base_index
    }

    /// Get the base term from snapshot (0 if no snapshot).
    pub fn base_term(&self) -> Term {
        self.base_term
    }
}
impl Default for RaftLog {
    /// Equivalent to [`RaftLog::new`]: an empty log with no snapshot.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Appending assigns sequential 1-based indexes and tracks last term.
    #[test]
    fn test_log_append() {
        let mut log = RaftLog::new();
        assert_eq!(log.last_index(), 0);
        let idx1 = log.append(1, b"cmd1".to_vec());
        assert_eq!(idx1, 1);
        assert_eq!(log.last_index(), 1);
        assert_eq!(log.last_term(), 1);
        let idx2 = log.append(1, b"cmd2".to_vec());
        assert_eq!(idx2, 2);
        assert_eq!(log.last_index(), 2);
    }
    // get() retrieves by index; index 0 and past-the-end return None.
    #[test]
    fn test_log_get() {
        let mut log = RaftLog::new();
        log.append(1, b"cmd1".to_vec());
        log.append(1, b"cmd2".to_vec());
        log.append(2, b"cmd3".to_vec());
        let entry = log.get(2).unwrap();
        assert_eq!(entry.term, 1);
        assert_eq!(entry.command, b"cmd2");
        assert!(log.get(0).is_none());
        assert!(log.get(10).is_none());
    }
    // truncate_from(2) deletes entries 2 and 3, leaving only entry 1.
    #[test]
    fn test_log_truncate() {
        let mut log = RaftLog::new();
        log.append(1, b"cmd1".to_vec());
        log.append(1, b"cmd2".to_vec());
        log.append(2, b"cmd3".to_vec());
        log.truncate_from(2).unwrap();
        assert_eq!(log.last_index(), 1);
        assert!(log.get(2).is_none());
    }
    // matches() is the AppendEntries prev-entry consistency check.
    #[test]
    fn test_log_matches() {
        let mut log = RaftLog::new();
        log.append(1, b"cmd1".to_vec());
        log.append(1, b"cmd2".to_vec());
        log.append(2, b"cmd3".to_vec());
        assert!(log.matches(1, 1));
        assert!(log.matches(2, 1));
        assert!(log.matches(3, 2));
        assert!(!log.matches(3, 1));
        assert!(!log.matches(10, 1));
    }
    // Snapshotting through index 2 rebases the log and drops entries 1-2.
    #[test]
    fn test_snapshot_creation() {
        let mut log = RaftLog::new();
        log.append(1, b"cmd1".to_vec());
        log.append(1, b"cmd2".to_vec());
        log.append(2, b"cmd3".to_vec());
        let snapshot = log
            .create_snapshot(2, b"state".to_vec(), vec!["node1".to_string()])
            .unwrap();
        assert_eq!(snapshot.last_included_index, 2);
        assert_eq!(snapshot.last_included_term, 1);
        assert_eq!(log.base_index(), 2);
        assert_eq!(log.len(), 1); // Only entry 3 remains
    }
    // entries_from() is inclusive of the start index.
    #[test]
    fn test_entries_from() {
        let mut log = RaftLog::new();
        log.append(1, b"cmd1".to_vec());
        log.append(1, b"cmd2".to_vec());
        log.append(2, b"cmd3".to_vec());
        let entries = log.entries_from(2);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].index, 2);
        assert_eq!(entries[1].index, 3);
    }
}

View File

@@ -0,0 +1,631 @@
//! Raft node implementation
//!
//! Coordinates all Raft components:
//! - State machine management
//! - RPC message handling
//! - Log replication
//! - Leader election
//! - Client request processing
use crate::{
election::{ElectionState, VoteValidator},
rpc::{
AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest,
InstallSnapshotResponse, RaftMessage, RequestVoteRequest, RequestVoteResponse,
},
state::{LeaderState, PersistentState, RaftState, VolatileState},
LogIndex, NodeId, RaftError, RaftResult, Term,
};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::{interval, sleep};
use tracing::{debug, error, info, warn};
/// Configuration for a Raft node
///
/// All timing fields are in milliseconds. The election timeout window
/// should comfortably exceed `heartbeat_interval` so a healthy leader is
/// never suspected of failure.
#[derive(Debug, Clone)]
pub struct RaftNodeConfig {
    /// This node's ID
    pub node_id: NodeId,
    /// IDs of all cluster members (including self)
    pub cluster_members: Vec<NodeId>,
    /// Minimum election timeout (milliseconds)
    pub election_timeout_min: u64,
    /// Maximum election timeout (milliseconds)
    pub election_timeout_max: u64,
    /// Heartbeat interval (milliseconds)
    pub heartbeat_interval: u64,
    /// Maximum entries per AppendEntries RPC
    pub max_entries_per_message: usize,
    /// Snapshot chunk size (bytes)
    pub snapshot_chunk_size: usize,
}
impl RaftNodeConfig {
    /// Create a new configuration with defaults
    ///
    /// Timing defaults follow the Raft paper's recommendation of a
    /// 150-300ms election timeout with heartbeats well below the minimum.
    pub fn new(node_id: NodeId, cluster_members: Vec<NodeId>) -> Self {
        const ELECTION_TIMEOUT_MIN_MS: u64 = 150;
        const ELECTION_TIMEOUT_MAX_MS: u64 = 300;
        const HEARTBEAT_INTERVAL_MS: u64 = 50;
        const MAX_ENTRIES_PER_MESSAGE: usize = 100;
        const SNAPSHOT_CHUNK_SIZE: usize = 64 * 1024; // 64KB

        Self {
            node_id,
            cluster_members,
            election_timeout_min: ELECTION_TIMEOUT_MIN_MS,
            election_timeout_max: ELECTION_TIMEOUT_MAX_MS,
            heartbeat_interval: HEARTBEAT_INTERVAL_MS,
            max_entries_per_message: MAX_ENTRIES_PER_MESSAGE,
            snapshot_chunk_size: SNAPSHOT_CHUNK_SIZE,
        }
    }
}
/// Command to apply to the state machine
///
/// The payload is opaque to the consensus layer; it is stored verbatim
/// in the replicated log.
#[derive(Debug, Clone)]
pub struct Command {
    /// Raw command bytes as supplied by the client.
    pub data: Vec<u8>,
}
/// Result of applying a command
///
/// Identifies where the command landed in the leader's log. Note the
/// entry may not yet be committed when this is returned.
#[derive(Debug, Clone)]
pub struct CommandResult {
    /// Log index assigned to the command.
    pub index: LogIndex,
    /// Leader term under which the command was appended.
    pub term: Term,
}
/// Internal messages for the Raft node
///
/// Everything the main loop reacts to — network RPCs, client requests,
/// and timer ticks — is funneled through this single event type so state
/// transitions are handled serially.
#[derive(Debug)]
enum InternalMessage {
    /// RPC message from another node
    Rpc { from: NodeId, message: RaftMessage },
    /// Client command to replicate
    ClientCommand {
        command: Command,
        /// Channel used to deliver the outcome back to the caller.
        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
    },
    /// Election timeout fired
    ElectionTimeout,
    /// Heartbeat timeout fired
    HeartbeatTimeout,
}
/// The Raft consensus node
pub struct RaftNode {
/// Configuration
config: RaftNodeConfig,
/// Persistent state
persistent: Arc<RwLock<PersistentState>>,
/// Volatile state
volatile: Arc<RwLock<VolatileState>>,
/// Current Raft state (Follower, Candidate, Leader)
state: Arc<RwLock<RaftState>>,
/// Leader-specific state (only valid when state is Leader)
leader_state: Arc<RwLock<Option<LeaderState>>>,
/// Election state
election_state: Arc<RwLock<ElectionState>>,
/// Current leader ID (if known)
current_leader: Arc<RwLock<Option<NodeId>>>,
/// Channel for internal messages
internal_tx: mpsc::UnboundedSender<InternalMessage>,
internal_rx: Arc<RwLock<mpsc::UnboundedReceiver<InternalMessage>>>,
}
impl RaftNode {
    /// Create a new Raft node
    ///
    /// The node starts as a follower at term 0; no background work happens
    /// until [`RaftNode::start`] is invoked.
    pub fn new(config: RaftNodeConfig) -> Self {
        // Unbounded channel so timer tasks and RPC ingress never block on send.
        let (internal_tx, internal_rx) = mpsc::unbounded_channel();
        let cluster_size = config.cluster_members.len();
        Self {
            persistent: Arc::new(RwLock::new(PersistentState::new())),
            volatile: Arc::new(RwLock::new(VolatileState::new())),
            state: Arc::new(RwLock::new(RaftState::Follower)),
            leader_state: Arc::new(RwLock::new(None)),
            election_state: Arc::new(RwLock::new(ElectionState::new(
                cluster_size,
                config.election_timeout_min,
                config.election_timeout_max,
            ))),
            current_leader: Arc::new(RwLock::new(None)),
            config,
            internal_tx,
            internal_rx: Arc::new(RwLock::new(internal_rx)),
        }
    }
    /// Start the Raft node
    ///
    /// Spawns both timer tasks and then runs the message loop on the
    /// calling task; this future resolves only when the internal channel
    /// closes.
    pub async fn start(self: Arc<Self>) {
        info!("Starting Raft node: {}", self.config.node_id);
        // Spawn election timer task
        self.clone().spawn_election_timer();
        // Spawn heartbeat timer task (for leaders)
        self.clone().spawn_heartbeat_timer();
        // Main message processing loop
        self.run().await;
    }
    /// Main message processing loop
    ///
    /// All state transitions funnel through this single loop, which
    /// serializes handling of RPCs, client commands, and timer events.
    ///
    /// NOTE(review): the `parking_lot` write guard on `internal_rx` is
    /// held across the `recv().await` point. That guard is not `Send`, so
    /// this future cannot be moved across threads by `tokio::spawn`, and
    /// the lock stays held for the entire wait — consider taking the
    /// receiver out of the struct (e.g. via `Option::take`) instead.
    async fn run(self: Arc<Self>) {
        loop {
            let message = {
                let mut rx = self.internal_rx.write();
                rx.recv().await
            };
            match message {
                Some(InternalMessage::Rpc { from, message }) => {
                    self.handle_rpc_message(from, message).await;
                }
                Some(InternalMessage::ClientCommand {
                    command,
                    response_tx,
                }) => {
                    self.handle_client_command(command, response_tx).await;
                }
                Some(InternalMessage::ElectionTimeout) => {
                    self.handle_election_timeout().await;
                }
                Some(InternalMessage::HeartbeatTimeout) => {
                    self.handle_heartbeat_timeout().await;
                }
                None => {
                    warn!("Internal channel closed, stopping node");
                    break;
                }
            }
        }
    }
    /// Handle RPC message from another node
    ///
    /// Any message carrying a higher term first forces a step-down to
    /// follower before the message itself is dispatched.
    async fn handle_rpc_message(&self, from: NodeId, message: RaftMessage) {
        // Update term if necessary
        let message_term = message.term();
        let current_term = self.persistent.read().current_term;
        if message_term > current_term {
            self.step_down(message_term).await;
        }
        match message {
            RaftMessage::AppendEntriesRequest(req) => {
                let response = self.handle_append_entries(req).await;
                // TODO: Send response back to sender
                debug!("AppendEntries response to {}: {:?}", from, response);
            }
            RaftMessage::AppendEntriesResponse(resp) => {
                self.handle_append_entries_response(from, resp).await;
            }
            RaftMessage::RequestVoteRequest(req) => {
                let response = self.handle_request_vote(req).await;
                // TODO: Send response back to sender
                debug!("RequestVote response to {}: {:?}", from, response);
            }
            RaftMessage::RequestVoteResponse(resp) => {
                self.handle_request_vote_response(from, resp).await;
            }
            RaftMessage::InstallSnapshotRequest(req) => {
                let response = self.handle_install_snapshot(req).await;
                // TODO: Send response back to sender
                debug!("InstallSnapshot response to {}: {:?}", from, response);
            }
            RaftMessage::InstallSnapshotResponse(resp) => {
                self.handle_install_snapshot_response(from, resp).await;
            }
        }
    }
    /// Handle AppendEntries RPC
    ///
    /// Follower side of log replication: term check, prev-entry
    /// consistency check, conflict truncation, append, then commit-index
    /// advancement.
    async fn handle_append_entries(&self, req: AppendEntriesRequest) -> AppendEntriesResponse {
        let mut persistent = self.persistent.write();
        let mut volatile = self.volatile.write();
        // Reply false if term < currentTerm
        if req.term < persistent.current_term {
            return AppendEntriesResponse::failure(persistent.current_term, None, None);
        }
        // Reset election timer
        self.election_state.write().reset_timer();
        *self.current_leader.write() = Some(req.leader_id.clone());
        // Reply false if log doesn't contain an entry at prevLogIndex with prevLogTerm
        if !persistent
            .log
            .matches(req.prev_log_index, req.prev_log_term)
        {
            // Conflict hints let the leader back up next_index faster than
            // one entry per round trip.
            let conflict_index = req.prev_log_index;
            let conflict_term = persistent.log.term_at(conflict_index);
            return AppendEntriesResponse::failure(
                persistent.current_term,
                Some(conflict_index),
                conflict_term,
            );
        }
        // Append new entries
        if !req.entries.is_empty() {
            // Delete conflicting entries and append new ones
            let mut index = req.prev_log_index + 1;
            for entry in &req.entries {
                if let Some(existing_term) = persistent.log.term_at(index) {
                    if existing_term != entry.term {
                        // Conflict found, truncate from here
                        let _ = persistent.log.truncate_from(index);
                    }
                }
                index += 1;
            }
            // Append entries
            // NOTE(review): if `RaftLog::append_entries` insists on strictly
            // sequential indexes, a retransmitted request whose entries
            // already exist (no conflict, so nothing was truncated) will be
            // rejected here — verify the append is idempotent for duplicates.
            if let Err(e) = persistent.log.append_entries(req.entries.clone()) {
                error!("Failed to append entries: {}", e);
                return AppendEntriesResponse::failure(persistent.current_term, None, None);
            }
        }
        // Update commit index
        if req.leader_commit > volatile.commit_index {
            let last_new_entry = if req.entries.is_empty() {
                req.prev_log_index
            } else {
                req.entries.last().unwrap().index
            };
            volatile.update_commit_index(std::cmp::min(req.leader_commit, last_new_entry));
        }
        AppendEntriesResponse::success(persistent.current_term, persistent.log.last_index())
    }
    /// Handle AppendEntries response
    ///
    /// Leader-side bookkeeping: on success, advance the follower's
    /// replication indexes and possibly the commit index; on failure,
    /// back off `next_index` so the next round probes an earlier entry.
    async fn handle_append_entries_response(&self, from: NodeId, resp: AppendEntriesResponse) {
        if !self.state.read().is_leader() {
            return;
        }
        // NOTE(review): `persistent` is only read below — a `read()` lock
        // would suffice and allow more concurrency.
        let persistent = self.persistent.write();
        let mut leader_state_guard = self.leader_state.write();
        if let Some(leader_state) = leader_state_guard.as_mut() {
            if resp.success {
                // Update next_index and match_index
                if let Some(match_index) = resp.match_index {
                    leader_state.update_replication(&from, match_index);
                    // Update commit index
                    let new_commit = leader_state.calculate_commit_index();
                    let mut volatile = self.volatile.write();
                    if new_commit > volatile.commit_index {
                        // Verify the entry is from current term
                        // (Raft 5.4.2: only entries from the leader's own
                        // term may be committed by counting replicas.)
                        if let Some(term) = persistent.log.term_at(new_commit) {
                            if term == persistent.current_term {
                                volatile.update_commit_index(new_commit);
                                info!("Updated commit index to {}", new_commit);
                            }
                        }
                    }
                }
            } else {
                // Decrement next_index and retry
                leader_state.decrement_next_index(&from);
                debug!("Replication failed for {}, decrementing next_index", from);
            }
        }
    }
    /// Handle RequestVote RPC
    ///
    /// By the time this runs, `handle_rpc_message` has already stepped us
    /// down if the candidate's term was higher, so `req.term` is expected
    /// to be at most our current term here.
    async fn handle_request_vote(&self, req: RequestVoteRequest) -> RequestVoteResponse {
        let mut persistent = self.persistent.write();
        // Reply false if term < currentTerm
        if req.term < persistent.current_term {
            return RequestVoteResponse::denied(persistent.current_term);
        }
        let last_log_index = persistent.log.last_index();
        let last_log_term = persistent.log.last_term();
        // Check if we should grant vote
        let should_grant = VoteValidator::should_grant_vote(
            persistent.current_term,
            &persistent.voted_for,
            last_log_index,
            last_log_term,
            &req.candidate_id,
            req.term,
            req.last_log_index,
            req.last_log_term,
        );
        if should_grant {
            persistent.vote_for(req.candidate_id.clone());
            // Granting a vote counts as hearing from a live candidate, so
            // push back our own election timeout.
            self.election_state.write().reset_timer();
            info!("Granted vote to {} for term {}", req.candidate_id, req.term);
            RequestVoteResponse::granted(persistent.current_term)
        } else {
            debug!("Denied vote to {} for term {}", req.candidate_id, req.term);
            RequestVoteResponse::denied(persistent.current_term)
        }
    }
    /// Handle RequestVote response
    ///
    /// Only meaningful while we are still a candidate and the response is
    /// for the term being contested; stale responses are dropped.
    async fn handle_request_vote_response(&self, from: NodeId, resp: RequestVoteResponse) {
        if !self.state.read().is_candidate() {
            return;
        }
        let current_term = self.persistent.read().current_term;
        if resp.term != current_term {
            return;
        }
        if resp.vote_granted {
            let won_election = self.election_state.write().record_vote(from.clone());
            if won_election {
                info!("Won election for term {}", current_term);
                self.become_leader().await;
            }
        }
    }
    /// Handle InstallSnapshot RPC
    ///
    /// Currently a stub: only the term check is performed before
    /// acknowledging.
    async fn handle_install_snapshot(
        &self,
        req: InstallSnapshotRequest,
    ) -> InstallSnapshotResponse {
        // NOTE(review): the guard is never written through — `read()` would do.
        let persistent = self.persistent.write();
        if req.term < persistent.current_term {
            return InstallSnapshotResponse::failure(persistent.current_term);
        }
        // TODO: Implement snapshot installation
        // For now, just acknowledge
        InstallSnapshotResponse::success(persistent.current_term, None)
    }
    /// Handle InstallSnapshot response
    async fn handle_install_snapshot_response(
        &self,
        _from: NodeId,
        _resp: InstallSnapshotResponse,
    ) {
        // TODO: Implement snapshot response handling
    }
    /// Handle client command
    ///
    /// Appends the command to the leader's log and nudges replication.
    ///
    /// NOTE(review): the success response is sent as soon as the entry is
    /// appended locally — before it is replicated or committed. Raft
    /// linearizability requires acknowledging the client only after the
    /// entry commits; confirm whether the early ack here is intentional.
    async fn handle_client_command(
        &self,
        command: Command,
        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
    ) {
        // Only leader can handle client commands
        if !self.state.read().is_leader() {
            let _ = response_tx.send(Err(RaftError::NotLeader)).await;
            return;
        }
        let mut persistent = self.persistent.write();
        let term = persistent.current_term;
        let index = persistent.log.append(term, command.data);
        let result = CommandResult { index, term };
        // NOTE(review): the `persistent` write guard is still held across
        // this `.await`; consider moving the `drop` above the send.
        let _ = response_tx.send(Ok(result)).await;
        // Trigger immediate replication
        drop(persistent);
        let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
    }
    /// Handle election timeout
    ///
    /// Re-checks the timer because timeout messages are generated by a
    /// polling task and may arrive after the timer was already reset.
    async fn handle_election_timeout(&self) {
        if self.state.read().is_leader() {
            return;
        }
        if !self.election_state.read().should_start_election() {
            return;
        }
        info!("Election timeout, starting election");
        self.start_election().await;
    }
    /// Start a new election
    ///
    /// Become candidate, bump the term, vote for self, reset the election
    /// tally/timer, and solicit votes from all peers.
    async fn start_election(&self) {
        // Transition to candidate
        *self.state.write() = RaftState::Candidate;
        // Increment term and vote for self
        let mut persistent = self.persistent.write();
        persistent.increment_term();
        persistent.vote_for(self.config.node_id.clone());
        let term = persistent.current_term;
        // Initialize election state
        self.election_state
            .write()
            .start_election(term, &self.config.node_id);
        let last_log_index = persistent.log.last_index();
        let last_log_term = persistent.log.last_term();
        info!(
            "Starting election for term {} as {}",
            term, self.config.node_id
        );
        // Send RequestVote RPCs to all other nodes
        for member in &self.config.cluster_members {
            if member != &self.config.node_id {
                let _request = RequestVoteRequest::new(
                    term,
                    self.config.node_id.clone(),
                    last_log_index,
                    last_log_term,
                );
                // TODO: Send request to member
                debug!("Would send RequestVote to {}", member);
            }
        }
    }
    /// Become leader after winning election
    ///
    /// Initializes per-follower replication state and immediately queues
    /// a heartbeat round to assert leadership.
    async fn become_leader(&self) {
        info!(
            "Becoming leader for term {}",
            self.persistent.read().current_term
        );
        *self.state.write() = RaftState::Leader;
        *self.current_leader.write() = Some(self.config.node_id.clone());
        let last_log_index = self.persistent.read().log.last_index();
        let other_members: Vec<_> = self
            .config
            .cluster_members
            .iter()
            .filter(|m| *m != &self.config.node_id)
            .cloned()
            .collect();
        *self.leader_state.write() = Some(LeaderState::new(&other_members, last_log_index));
        // Send initial heartbeats
        let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
    }
    /// Step down to follower (when discovering higher term)
    ///
    /// NOTE(review): assumes `PersistentState::update_term` clears
    /// `voted_for` for the new term — verify, otherwise a stale vote
    /// could leak into the new term.
    async fn step_down(&self, term: Term) {
        info!("Stepping down to follower for term {}", term);
        *self.state.write() = RaftState::Follower;
        *self.leader_state.write() = None;
        *self.current_leader.write() = None;
        let mut persistent = self.persistent.write();
        persistent.update_term(term);
    }
    /// Handle heartbeat timeout (for leaders)
    async fn handle_heartbeat_timeout(&self) {
        if !self.state.read().is_leader() {
            return;
        }
        self.send_heartbeats().await;
    }
    /// Send heartbeats to all followers
    ///
    /// Builds empty AppendEntries requests for every peer; actual
    /// transport is still a TODO, so this currently only logs.
    async fn send_heartbeats(&self) {
        let persistent = self.persistent.read();
        let term = persistent.current_term;
        let commit_index = self.volatile.read().commit_index;
        for member in &self.config.cluster_members {
            if member != &self.config.node_id {
                let request = AppendEntriesRequest::heartbeat(
                    term,
                    self.config.node_id.clone(),
                    commit_index,
                );
                // TODO: Send heartbeat to member
                debug!("Would send heartbeat to {}", member);
            }
        }
    }
    /// Spawn election timer task
    ///
    /// Polls every 50ms; the handler re-validates the timeout, so
    /// spurious or duplicate `ElectionTimeout` messages are harmless.
    fn spawn_election_timer(self: Arc<Self>) {
        let node = self.clone();
        tokio::spawn(async move {
            let mut interval = interval(Duration::from_millis(50));
            loop {
                interval.tick().await;
                if node.election_state.read().should_start_election() {
                    let _ = node.internal_tx.send(InternalMessage::ElectionTimeout);
                }
            }
        });
    }
    /// Spawn heartbeat timer task
    ///
    /// Ticks at the configured heartbeat interval but only emits events
    /// while this node believes it is the leader.
    fn spawn_heartbeat_timer(self: Arc<Self>) {
        let node = self.clone();
        tokio::spawn(async move {
            let interval_ms = node.config.heartbeat_interval;
            let mut interval = interval(Duration::from_millis(interval_ms));
            loop {
                interval.tick().await;
                if node.state.read().is_leader() {
                    let _ = node.internal_tx.send(InternalMessage::HeartbeatTimeout);
                }
            }
        });
    }
    /// Submit a command to the Raft cluster
    ///
    /// Enqueues the command for the main loop and awaits the result.
    /// Fails with `RaftError::NotLeader` on followers, and with an
    /// internal error if the node's message loop has stopped.
    pub async fn submit_command(&self, data: Vec<u8>) -> RaftResult<CommandResult> {
        let (tx, mut rx) = mpsc::channel(1);
        let command = Command { data };
        self.internal_tx
            .send(InternalMessage::ClientCommand {
                command,
                response_tx: tx,
            })
            .map_err(|_| RaftError::Internal("Node stopped".to_string()))?;
        rx.recv()
            .await
            .ok_or_else(|| RaftError::Internal("Response channel closed".to_string()))?
    }
    /// Get current state
    pub fn current_state(&self) -> RaftState {
        *self.state.read()
    }
    /// Get current term
    pub fn current_term(&self) -> Term {
        self.persistent.read().current_term
    }
    /// Get current leader
    pub fn current_leader(&self) -> Option<NodeId> {
        self.current_leader.read().clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// A freshly created node starts as a follower at term 0.
    #[test]
    fn test_node_creation() {
        let members: Vec<String> = ["node1", "node2", "node3"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let config = RaftNodeConfig::new("node1".to_string(), members);
        let node = RaftNode::new(config);
        assert_eq!(node.current_state(), RaftState::Follower);
        assert_eq!(node.current_term(), 0);
    }
}

View File

@@ -0,0 +1,442 @@
//! Raft RPC messages
//!
//! Defines the RPC message types for Raft consensus:
//! - AppendEntries (log replication and heartbeat)
//! - RequestVote (leader election)
//! - InstallSnapshot (snapshot transfer)
use crate::{log::LogEntry, log::Snapshot, LogIndex, NodeId, Term};
use serde::{Deserialize, Serialize};
/// AppendEntries RPC request
///
/// Invoked by leader to replicate log entries; also used as heartbeat
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesRequest {
    /// Leader's term
    pub term: Term,
    /// Leader's ID (so followers can redirect clients)
    pub leader_id: NodeId,
    /// Index of log entry immediately preceding new ones
    pub prev_log_index: LogIndex,
    /// Term of prevLogIndex entry
    pub prev_log_term: Term,
    /// Log entries to store (empty for heartbeat)
    pub entries: Vec<LogEntry>,
    /// Leader's commitIndex
    pub leader_commit: LogIndex,
}
impl AppendEntriesRequest {
    /// Create a new AppendEntries request
    pub fn new(
        term: Term,
        leader_id: NodeId,
        prev_log_index: LogIndex,
        prev_log_term: Term,
        entries: Vec<LogEntry>,
        leader_commit: LogIndex,
    ) -> Self {
        Self {
            term,
            leader_id,
            prev_log_index,
            prev_log_term,
            entries,
            leader_commit,
        }
    }
    /// Create a heartbeat (AppendEntries with no entries)
    ///
    /// `prev_log_index`/`prev_log_term` are zeroed, so a heartbeat built
    /// this way carries no log-consistency information.
    pub fn heartbeat(term: Term, leader_id: NodeId, leader_commit: LogIndex) -> Self {
        Self::new(term, leader_id, 0, 0, Vec::new(), leader_commit)
    }
    /// Check if this is a heartbeat message
    pub fn is_heartbeat(&self) -> bool {
        self.entries.is_empty()
    }
    /// Serialize to bytes
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(bincode::serde::Compat(self), bincode::config::standard())
    }
    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        let (wrapper, _len): (bincode::serde::Compat<Self>, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard())?;
        Ok(wrapper.0)
    }
}
/// AppendEntries RPC response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// Current term, for leader to update itself
    pub term: Term,
    /// True if follower contained entry matching prevLogIndex and prevLogTerm
    pub success: bool,
    /// The follower's last log index (for optimization)
    pub match_index: Option<LogIndex>,
    /// Conflict information for faster log backtracking
    pub conflict_index: Option<LogIndex>,
    pub conflict_term: Option<Term>,
}
impl AppendEntriesResponse {
    /// Create a successful response
    pub fn success(term: Term, match_index: LogIndex) -> Self {
        Self {
            term,
            success: true,
            match_index: Some(match_index),
            conflict_index: None,
            conflict_term: None,
        }
    }
    /// Create a failure response
    ///
    /// The conflict fields let the leader skip back over a whole
    /// conflicting term instead of decrementing one entry at a time.
    pub fn failure(
        term: Term,
        conflict_index: Option<LogIndex>,
        conflict_term: Option<Term>,
    ) -> Self {
        Self {
            term,
            success: false,
            match_index: None,
            conflict_index,
            conflict_term,
        }
    }
    /// Serialize to bytes
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(bincode::serde::Compat(self), bincode::config::standard())
    }
    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        let (wrapper, _len): (bincode::serde::Compat<Self>, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard())?;
        Ok(wrapper.0)
    }
}
/// RequestVote RPC request
///
/// Invoked by candidates to gather votes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteRequest {
    /// Candidate's term
    pub term: Term,
    /// Candidate requesting vote
    pub candidate_id: NodeId,
    /// Index of candidate's last log entry
    pub last_log_index: LogIndex,
    /// Term of candidate's last log entry
    pub last_log_term: Term,
}
impl RequestVoteRequest {
    /// Create a new RequestVote request
    ///
    /// `last_log_index`/`last_log_term` let voters apply the up-to-date
    /// log comparison from the Raft election rules.
    pub fn new(
        term: Term,
        candidate_id: NodeId,
        last_log_index: LogIndex,
        last_log_term: Term,
    ) -> Self {
        Self {
            term,
            candidate_id,
            last_log_index,
            last_log_term,
        }
    }
    /// Serialize to bytes
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(bincode::serde::Compat(self), bincode::config::standard())
    }
    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        let (wrapper, _len): (bincode::serde::Compat<Self>, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard())?;
        Ok(wrapper.0)
    }
}
/// RequestVote RPC response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteResponse {
    /// Current term, for candidate to update itself
    pub term: Term,
    /// True means candidate received vote
    pub vote_granted: bool,
}
impl RequestVoteResponse {
    /// Create a vote granted response
    pub fn granted(term: Term) -> Self {
        Self { term, vote_granted: true }
    }
    /// Create a vote denied response
    pub fn denied(term: Term) -> Self {
        Self { term, vote_granted: false }
    }
    /// Serialize to bytes
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(bincode::serde::Compat(self), bincode::config::standard())
    }
    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        let (wrapper, _len): (bincode::serde::Compat<Self>, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard())?;
        Ok(wrapper.0)
    }
}
/// InstallSnapshot RPC request
///
/// Invoked by leader to send chunks of a snapshot to a follower
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotRequest {
    /// Leader's term
    pub term: Term,
    /// Leader's ID (so follower can redirect clients)
    pub leader_id: NodeId,
    /// The snapshot replaces all entries up through and including this index
    pub last_included_index: LogIndex,
    /// Term of lastIncludedIndex
    pub last_included_term: Term,
    /// Byte offset where chunk is positioned in the snapshot file
    pub offset: u64,
    /// Raw bytes of the snapshot chunk, starting at offset
    pub data: Vec<u8>,
    /// True if this is the last chunk
    pub done: bool,
}
impl InstallSnapshotRequest {
    /// Create a new InstallSnapshot request carrying the chunk that starts
    /// at `offset` and holds at most `chunk_size` bytes.
    ///
    /// The copy window is clamped to the snapshot data bounds: an `offset`
    /// at or past the end of the data now yields an empty chunk with
    /// `done = true` instead of panicking on an out-of-range slice (the
    /// previous `data[offset as usize..chunk_end]` panicked for any stale
    /// or overshooting offset, and `offset as usize + chunk_size` could
    /// overflow).
    ///
    /// NOTE(review): `snapshot` is taken by value, so callers clone the
    /// full snapshot for every chunk they send; `&Snapshot` would avoid
    /// that but changes the signature for existing callers.
    pub fn new(
        term: Term,
        leader_id: NodeId,
        snapshot: Snapshot,
        offset: u64,
        chunk_size: usize,
    ) -> Self {
        let data_len = snapshot.data.len();
        // Clamp the start (also guards against u64 -> usize truncation on
        // 32-bit targets) and the end (saturating add avoids overflow).
        let chunk_start = usize::try_from(offset).map_or(data_len, |o| o.min(data_len));
        let chunk_end = chunk_start.saturating_add(chunk_size).min(data_len);
        let chunk = snapshot.data[chunk_start..chunk_end].to_vec();
        Self {
            term,
            leader_id,
            last_included_index: snapshot.last_included_index,
            last_included_term: snapshot.last_included_term,
            offset,
            data: chunk,
            // The transfer is complete once the window reaches the end.
            done: chunk_end >= data_len,
        }
    }
    /// Serialize to bytes
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        use bincode::config;
        bincode::encode_to_vec(bincode::serde::Compat(self), config::standard())
    }
    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        use bincode::config;
        let (compat, _): (bincode::serde::Compat<Self>, _) =
            bincode::decode_from_slice(bytes, config::standard())?;
        Ok(compat.0)
    }
}
/// InstallSnapshot RPC response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotResponse {
    /// Current term, for leader to update itself
    pub term: Term,
    /// True if snapshot was successfully installed
    pub success: bool,
    /// The byte offset for the next chunk (for resume)
    pub next_offset: Option<u64>,
}
impl InstallSnapshotResponse {
    /// Create a successful response
    pub fn success(term: Term, next_offset: Option<u64>) -> Self {
        Self { term, success: true, next_offset }
    }
    /// Create a failure response
    pub fn failure(term: Term) -> Self {
        Self { term, success: false, next_offset: None }
    }
    /// Serialize to bytes
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(bincode::serde::Compat(self), bincode::config::standard())
    }
    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        let (wrapper, _len): (bincode::serde::Compat<Self>, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard())?;
        Ok(wrapper.0)
    }
}
/// RPC message envelope
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RaftMessage {
    AppendEntriesRequest(AppendEntriesRequest),
    AppendEntriesResponse(AppendEntriesResponse),
    RequestVoteRequest(RequestVoteRequest),
    RequestVoteResponse(RequestVoteResponse),
    InstallSnapshotRequest(InstallSnapshotRequest),
    InstallSnapshotResponse(InstallSnapshotResponse),
}
impl RaftMessage {
    /// Get the term from the message
    ///
    /// Every variant's payload carries the sender's term as a `term` field,
    /// so the envelope can expose it uniformly.
    pub fn term(&self) -> Term {
        match self {
            RaftMessage::AppendEntriesRequest(inner) => inner.term,
            RaftMessage::AppendEntriesResponse(inner) => inner.term,
            RaftMessage::RequestVoteRequest(inner) => inner.term,
            RaftMessage::RequestVoteResponse(inner) => inner.term,
            RaftMessage::InstallSnapshotRequest(inner) => inner.term,
            RaftMessage::InstallSnapshotResponse(inner) => inner.term,
        }
    }
    /// Serialize to bytes
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(bincode::serde::Compat(self), bincode::config::standard())
    }
    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        let (wrapper, _len): (bincode::serde::Compat<Self>, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard())?;
        Ok(wrapper.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// A heartbeat carries no log entries.
    #[test]
    fn test_append_entries_heartbeat() {
        let hb = AppendEntriesRequest::heartbeat(1, "leader".to_string(), 10);
        assert!(hb.is_heartbeat());
        assert_eq!(hb.entries.len(), 0);
    }
    /// AppendEntries survives an encode/decode round trip.
    #[test]
    fn test_append_entries_serialization() {
        let original = AppendEntriesRequest::new(1, "leader".to_string(), 10, 1, vec![], 10);
        let encoded = original.to_bytes().unwrap();
        let roundtrip = AppendEntriesRequest::from_bytes(&encoded).unwrap();
        assert_eq!(original.term, roundtrip.term);
        assert_eq!(original.leader_id, roundtrip.leader_id);
    }
    /// RequestVote survives an encode/decode round trip.
    #[test]
    fn test_request_vote_serialization() {
        let original = RequestVoteRequest::new(2, "candidate".to_string(), 15, 2);
        let encoded = original.to_bytes().unwrap();
        let roundtrip = RequestVoteRequest::from_bytes(&encoded).unwrap();
        assert_eq!(original.term, roundtrip.term);
        assert_eq!(original.candidate_id, roundtrip.candidate_id);
    }
    /// Success and failure constructors populate the right fields.
    #[test]
    fn test_response_types() {
        let ok = AppendEntriesResponse::success(1, 10);
        assert!(ok.success);
        assert_eq!(ok.match_index, Some(10));
        let err = AppendEntriesResponse::failure(1, Some(5), Some(1));
        assert!(!err.success);
        assert_eq!(err.conflict_index, Some(5));
    }
    /// Granted/denied constructors set the vote flag accordingly.
    #[test]
    fn test_vote_responses() {
        assert!(RequestVoteResponse::granted(1).vote_granted);
        assert!(!RequestVoteResponse::denied(1).vote_granted);
    }
}

View File

@@ -0,0 +1,317 @@
//! Raft state management
//!
//! Implements the state machine for Raft consensus including:
//! - Persistent state (term, vote, log)
//! - Volatile state (commit index, last applied)
//! - Leader-specific state (next index, match index)
use crate::{log::RaftLog, LogIndex, NodeId, Term};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// The three states a Raft node can be in
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RaftState {
    /// Follower state - responds to RPCs from leaders and candidates
    Follower,
    /// Candidate state - attempts to become leader
    Candidate,
    /// Leader state - handles client requests and replicates log
    Leader,
}
impl RaftState {
    /// Returns true if this node is the leader
    pub fn is_leader(&self) -> bool {
        *self == RaftState::Leader
    }
    /// Returns true if this node is a candidate
    pub fn is_candidate(&self) -> bool {
        *self == RaftState::Candidate
    }
    /// Returns true if this node is a follower
    pub fn is_follower(&self) -> bool {
        *self == RaftState::Follower
    }
}
/// Persistent state on all servers
///
/// Updated on stable storage before responding to RPCs
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistentState {
    /// Latest term server has seen (initialized to 0, increases monotonically)
    pub current_term: Term,
    /// Candidate ID that received vote in current term (or None)
    pub voted_for: Option<NodeId>,
    /// Log entries (each entry contains command and term)
    pub log: RaftLog,
}
impl PersistentState {
    /// Create new persistent state with initial values
    pub fn new() -> Self {
        Self {
            current_term: 0,
            voted_for: None,
            log: RaftLog::new(),
        }
    }
    /// Increment the current term
    ///
    /// Entering a new term invalidates any vote cast in the old one.
    pub fn increment_term(&mut self) {
        self.current_term += 1;
        self.voted_for = None;
    }
    /// Update term if the given term is higher
    ///
    /// Returns true when the term advanced (which also clears the vote).
    pub fn update_term(&mut self, term: Term) -> bool {
        if term <= self.current_term {
            return false;
        }
        self.current_term = term;
        self.voted_for = None;
        true
    }
    /// Vote for a candidate in the current term
    pub fn vote_for(&mut self, candidate_id: NodeId) {
        self.voted_for = Some(candidate_id);
    }
    /// Check if vote can be granted for the given candidate
    ///
    /// Grants when no vote has been cast yet, or when the prior vote went
    /// to this same candidate (re-granting is idempotent).
    pub fn can_vote_for(&self, candidate_id: &NodeId) -> bool {
        self.voted_for.as_ref().map_or(true, |voted| voted == candidate_id)
    }
    /// Serialize state to bytes for persistence
    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(bincode::serde::Compat(self), bincode::config::standard())
    }
    /// Deserialize state from bytes
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        let (wrapper, _len): (bincode::serde::Compat<Self>, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard())?;
        Ok(wrapper.0)
    }
}
impl Default for PersistentState {
    fn default() -> Self {
        Self::new()
    }
}
/// Volatile state on all servers
///
/// Can be reconstructed from persistent state
#[derive(Debug, Clone)]
pub struct VolatileState {
    /// Index of highest log entry known to be committed
    /// (initialized to 0, increases monotonically)
    pub commit_index: LogIndex,
    /// Index of highest log entry applied to state machine
    /// (initialized to 0, increases monotonically)
    pub last_applied: LogIndex,
}
impl VolatileState {
    /// Create new volatile state with initial values
    pub fn new() -> Self {
        Self {
            commit_index: 0,
            last_applied: 0,
        }
    }
    /// Update commit index
    ///
    /// Monotonic: an index at or below the current value is ignored.
    pub fn update_commit_index(&mut self, index: LogIndex) {
        self.commit_index = self.commit_index.max(index);
    }
    /// Advance last_applied index
    ///
    /// Monotonic: only moves forward.
    pub fn apply_entries(&mut self, up_to_index: LogIndex) {
        self.last_applied = self.last_applied.max(up_to_index);
    }
    /// Get the number of entries that need to be applied
    pub fn pending_entries(&self) -> u64 {
        self.commit_index.saturating_sub(self.last_applied)
    }
}
impl Default for VolatileState {
    fn default() -> Self {
        Self::new()
    }
}
/// Volatile state on leaders
///
/// Reinitialized after election
#[derive(Debug, Clone)]
pub struct LeaderState {
    /// For each server, index of the next log entry to send to that server
    /// (initialized to leader last log index + 1)
    pub next_index: HashMap<NodeId, LogIndex>,
    /// For each server, index of highest log entry known to be replicated
    /// (initialized to 0, increases monotonically)
    pub match_index: HashMap<NodeId, LogIndex>,
}
impl LeaderState {
    /// Create new leader state for the given cluster members
    ///
    /// next_index starts optimistically one past the leader's last log
    /// entry; match_index starts at 0.
    /// NOTE(review): `RaftNode` passes only the *other* members here (it
    /// filters out its own id before constructing this); the commit
    /// calculation below relies on that.
    pub fn new(cluster_members: &[NodeId], last_log_index: LogIndex) -> Self {
        let mut next_index = HashMap::new();
        let mut match_index = HashMap::new();
        for member in cluster_members {
            // Initialize next_index to last log index + 1
            next_index.insert(member.clone(), last_log_index + 1);
            // Initialize match_index to 0
            match_index.insert(member.clone(), 0);
        }
        Self {
            next_index,
            match_index,
        }
    }
    /// Update next_index for a follower (decrement on failure)
    ///
    /// One-entry-at-a-time backtracking, floored at 1. (The conflict hints
    /// in AppendEntriesResponse would allow faster, whole-term skips.)
    pub fn decrement_next_index(&mut self, node_id: &NodeId) {
        if let Some(index) = self.next_index.get_mut(node_id) {
            if *index > 1 {
                *index -= 1;
            }
        }
    }
    /// Update both next_index and match_index for successful replication
    pub fn update_replication(&mut self, node_id: &NodeId, match_index: LogIndex) {
        self.match_index.insert(node_id.clone(), match_index);
        self.next_index.insert(node_id.clone(), match_index + 1);
    }
    /// Get the median match_index for determining commit_index
    ///
    /// Sorts the recorded match indices ascending and returns the element
    /// at `len / 2` (the upper median). When the map holds followers only
    /// — as `RaftNode` initializes it — this is the highest index stored
    /// on a majority once the leader's own (always up-to-date) copy is
    /// counted implicitly.
    /// NOTE(review): if the leader itself is also in the map (as in
    /// `test_commit_index_calculation` below), the result is the plain
    /// median instead — confirm the intended membership convention.
    /// NOTE(review): Raft additionally forbids committing entries from
    /// earlier terms by counting replicas; that term check must live in
    /// the caller.
    pub fn calculate_commit_index(&self) -> LogIndex {
        if self.match_index.is_empty() {
            return 0;
        }
        let mut indices: Vec<LogIndex> = self.match_index.values().copied().collect();
        indices.sort_unstable();
        // Return the median (quorum)
        let mid = indices.len() / 2;
        indices.get(mid).copied().unwrap_or(0)
    }
    /// Get next_index for a specific follower
    pub fn get_next_index(&self, node_id: &NodeId) -> Option<LogIndex> {
        self.next_index.get(node_id).copied()
    }
    /// Get match_index for a specific follower
    pub fn get_match_index(&self, node_id: &NodeId) -> Option<LogIndex> {
        self.match_index.get(node_id).copied()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Each state predicate matches its own variant.
    #[test]
    fn test_raft_state_checks() {
        assert!(RaftState::Leader.is_leader());
        assert!(RaftState::Candidate.is_candidate());
        assert!(RaftState::Follower.is_follower());
    }
    /// Terms start at zero, incrementing clears the vote, and higher
    /// terms are adopted.
    #[test]
    fn test_persistent_state_term_management() {
        let mut persistent = PersistentState::new();
        assert_eq!(persistent.current_term, 0);
        persistent.increment_term();
        assert_eq!(persistent.current_term, 1);
        assert!(persistent.voted_for.is_none());
        persistent.update_term(5);
        assert_eq!(persistent.current_term, 5);
    }
    /// A node may re-grant its vote to the same candidate but no other.
    #[test]
    fn test_voting() {
        let mut persistent = PersistentState::new();
        let candidate = "node1".to_string();
        assert!(persistent.can_vote_for(&candidate));
        persistent.vote_for(candidate.clone());
        assert!(persistent.can_vote_for(&candidate));
        assert!(!persistent.can_vote_for(&"node2".to_string()));
    }
    /// Commit index and last_applied advance independently and monotonically.
    #[test]
    fn test_volatile_state() {
        let mut volatile = VolatileState::new();
        assert_eq!(volatile.commit_index, 0);
        assert_eq!(volatile.last_applied, 0);
        volatile.update_commit_index(10);
        assert_eq!(volatile.commit_index, 10);
        assert_eq!(volatile.pending_entries(), 10);
        volatile.apply_entries(5);
        assert_eq!(volatile.last_applied, 5);
        assert_eq!(volatile.pending_entries(), 5);
    }
    /// next_index starts past the leader's log; replication updates both maps.
    #[test]
    fn test_leader_state() {
        let members: Vec<String> = vec!["node1".into(), "node2".into()];
        let mut leader = LeaderState::new(&members, 10);
        assert_eq!(leader.get_next_index(&members[0]), Some(11));
        assert_eq!(leader.get_match_index(&members[0]), Some(0));
        leader.update_replication(&members[0], 10);
        assert_eq!(leader.get_next_index(&members[0]), Some(11));
        assert_eq!(leader.get_match_index(&members[0]), Some(10));
    }
    /// The calculated commit index is the upper median of match indices.
    #[test]
    fn test_commit_index_calculation() {
        let members: Vec<String> = vec!["node1".into(), "node2".into(), "node3".into()];
        let mut leader = LeaderState::new(&members, 10);
        leader.update_replication(&members[0], 5);
        leader.update_replication(&members[1], 8);
        leader.update_replication(&members[2], 3);
        assert_eq!(leader.calculate_commit_index(), 5); // Median of [3, 5, 8]
    }
}