diff --git a/src/bin/server.rs b/src/bin/server.rs index 5e1472a..35b9b20 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -6,7 +6,7 @@ use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, sync::{Mutex, RwLock}, - time::Instant, + time::{timeout, Instant}, }; use tracing::{debug, error, info, trace, warn}; @@ -82,7 +82,9 @@ async fn main() { tokio::spawn(async move { loop { tokio::time::sleep(Duration::from_millis(100)).await; - server_clone.lock().await.broadcast_current_log().await; + if let Ok(mut server) = timeout(Duration::from_millis(200), server_clone.lock()).await { + server.broadcast_current_log().await; + } } }); @@ -94,10 +96,12 @@ async fn main() { loop { let election_timeout = Duration::from_millis(election_timeout as u64); tokio::time::sleep(election_timeout).await; - let mut server = server_clone.lock().await; - if no_hearbeats_received_from_leader(election_timeout, &server).await { - warn!("No heartbeats from leader, starting a new election"); - server.start_election().await; + if let Ok(mut server) = timeout(Duration::from_millis(300), server_clone.lock()).await { + trace!("Election timer off"); + if server.no_hearbeats_received_from_leader().await { + warn!("No heartbeats from leader, starting a new election"); + server.start_election().await; + } } } }); @@ -105,17 +109,6 @@ async fn main() { handle_connections(server, listener, logs, log_segments).await; } -/// Checks if the last heartbeat received from the leader(if there is -/// a last heartbeat) has passed the election timeout. -async fn no_hearbeats_received_from_leader(election_timeout: Duration, server: &Server) -> bool { - let time_elapsed = Instant::now() - election_timeout; - let last_heartbeat = server.last_heartbeat().await; - (last_heartbeat - .is_some_and(|heartbeat| heartbeat.duration_since(time_elapsed) > election_timeout) - || last_heartbeat.is_none()) - && !server.is_leader().await -} - async fn handle_connections( server: Arc>, listener: TcpListener, @@ -377,13 +370,17 @@ async fn request_vote( last_term, }; debug!("Receiving vote: {:?}", vote_request); - let vote_response = server.lock().await.receive_vote(vote_request).await; - let encoded_vote_response = bincode::serialize(&vote_response).unwrap(); - let mut buf = Vec::new(); - buf.extend((0_u8).to_be_bytes()); - buf.extend(encoded_vote_response.len().to_be_bytes()); - buf.extend(encoded_vote_response); + if let Ok(mut server) = timeout(Duration::from_millis(50), server.lock()).await { + let vote_response = server.receive_vote(vote_request).await; + let encoded_vote_response = bincode::serialize(&vote_response).unwrap(); + buf.extend((0_u8).to_be_bytes()); + buf.extend(encoded_vote_response.len().to_be_bytes()); + buf.extend(encoded_vote_response); + } else { + buf.extend((1_u8).to_be_bytes()); + } + buf } @@ -404,13 +401,18 @@ async fn log_request( leader_commit, suffix, }; - let log_response = server.lock().await.receive_log_request(log_request).await; - let encoded_log_response = bincode::serialize(&log_response).unwrap(); let mut buf = Vec::new(); - buf.extend((0_u8).to_be_bytes()); - buf.extend(encoded_log_response.len().to_be_bytes()); - buf.extend(encoded_log_response); + if let Ok(mut server) = timeout(Duration::from_millis(50), server.lock()).await { + let log_response = server.receive_log_request(log_request).await; + let encoded_log_response = bincode::serialize(&log_response).unwrap(); + buf.extend((0_u8).to_be_bytes()); + buf.extend(encoded_log_response.len().to_be_bytes()); + buf.extend(encoded_log_response); + } else { + buf.extend((1_u8).to_be_bytes()); + } + buf } diff --git a/src/raft.rs b/src/raft.rs index 448784d..4f0cfdc 100644 --- a/src/raft.rs +++ b/src/raft.rs @@ -1,8 +1,4 @@ -use std::{ - cmp, - collections::{HashMap, HashSet}, - sync::{atomic::AtomicU64, Arc}, -}; +use std::{cmp, collections::HashMap, sync::atomic::AtomicU64, time::Duration}; use crossbeam_skiplist::{SkipMap, SkipSet}; use tracing::{debug, error, info, trace}; @@ -14,7 +10,7 @@ use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, sync::RwLock, - time, + time::{self, Instant}, }; /// The cluster must have at least 5 servers. @@ -46,8 +42,8 @@ enum State { pub struct Server { id: NodeId, // Need to be stored on disk - current_term: Arc, - voted_for: Arc>>, + current_term: AtomicU64, + voted_for: RwLock>, log: RwLock>, // Can be stored in-memory state: RwLock, @@ -76,8 +72,8 @@ impl Server { pub fn new(id: u64, nodes: HashMap) -> Server { Server { id, - current_term: Arc::new(AtomicU64::new(0)), - voted_for: Arc::new(RwLock::new(None)), + current_term: AtomicU64::new(0), + voted_for: RwLock::new(None), log: RwLock::new(Vec::new()), state: RwLock::new(State::Follower), commit_length: AtomicU64::new(0), @@ -92,26 +88,21 @@ impl Server { } fn current_term(&self) -> u64 { - self.current_term - .clone() - .load(std::sync::atomic::Ordering::SeqCst) + self.current_term.load(std::sync::atomic::Ordering::SeqCst) } fn increment_current_term(&self) { self.current_term - .clone() .fetch_add(1, std::sync::atomic::Ordering::SeqCst); } fn decrement_current_term(&self) { self.current_term - .clone() .fetch_sub(1, std::sync::atomic::Ordering::SeqCst); } fn update_current_term(&self, value: u64) { self.current_term - .clone() .store(value, std::sync::atomic::Ordering::SeqCst) } @@ -191,6 +182,18 @@ impl Server { self.log.read().await.len() } + /// Checks if the last heartbeat received from the leader(if there + /// is a last heartbeat) has passed the election timeout. + pub async fn no_hearbeats_received_from_leader(&self) -> bool { + let election_timeout = Duration::from_millis(self.election_timeout() as u64); + let time_elapsed = Instant::now() - election_timeout; + let last_heartbeat = self.last_heartbeat().await; + (last_heartbeat + .is_some_and(|heartbeat| time_elapsed.duration_since(heartbeat) > election_timeout) + || last_heartbeat.is_none()) + && !self.is_leader().await + } + pub async fn start_election(&mut self) { if self.is_leader().await { return; @@ -676,3 +679,65 @@ pub struct LogResponse { ack: u64, successful: bool, } + +#[cfg(test)] +mod tests { + + use crate::raft::*; + + #[tokio::test] + async fn should_start_a_new_election_with_no_hearbeats() { + let server = Server::new(1, HashMap::new()); + + let result = server.no_hearbeats_received_from_leader().await; + + assert!(result); + } + + #[tokio::test] + async fn should_start_a_new_election_with_outdated_hearbeats() { + let one_second_ago = Instant::now() - Duration::from_secs(1); + let server = Server { + id: 1, + current_term: AtomicU64::new(0), + voted_for: RwLock::new(None), + log: RwLock::new(Vec::new()), + state: RwLock::new(State::Follower), + commit_length: AtomicU64::new(0), + election_timeout: rand::thread_rng().gen_range(150..300), + current_leader: AtomicU64::new(0), + votes_received: SkipSet::new(), + sent_length: SkipMap::new(), + acked_length: SkipMap::new(), + nodes: RwLock::new(HashMap::new()), + last_heartbeat: RwLock::new(Option::Some(one_second_ago)), + }; + + let result = server.no_hearbeats_received_from_leader().await; + + assert!(result); + } + + #[tokio::test] + async fn should_not_start_a_new_election() { + let server = Server { + id: 1, + current_term: AtomicU64::new(0), + voted_for: RwLock::new(None), + log: RwLock::new(Vec::new()), + state: RwLock::new(State::Follower), + commit_length: AtomicU64::new(0), + election_timeout: rand::thread_rng().gen_range(150..300), + current_leader: AtomicU64::new(0), + votes_received: SkipSet::new(), + sent_length: SkipMap::new(), + acked_length: SkipMap::new(), + nodes: RwLock::new(HashMap::new()), + last_heartbeat: RwLock::new(Option::Some(Instant::now())), + }; + + let result = server.no_hearbeats_received_from_leader().await; + + assert!(!result); + } +}