diff --git a/examples/simple_run.rs b/examples/simple_run.rs index c367c92..9e4f0a8 100644 --- a/examples/simple_run.rs +++ b/examples/simple_run.rs @@ -46,7 +46,7 @@ async fn main() { let id = cluster_nodes[i]; let cc = cluster_config.clone(); handles.push(tokio::spawn(async move { - let mut server = Server::new(id, config, cc).await; + let mut server = Server::new(id, config, cc, None).await; server.start().await; })); } diff --git a/examples/simulate_add_node.rs b/examples/simulate_add_node.rs index a164c87..f005fcb 100644 --- a/examples/simulate_add_node.rs +++ b/examples/simulate_add_node.rs @@ -43,7 +43,7 @@ async fn main() { let id = cluster_nodes[i]; let cc = cluster_config.clone(); handles.push(tokio::spawn(async move { - let mut server = Server::new(id, config, cc).await; + let mut server = Server::new(id, config, cc, None).await; server.start().await; })); } @@ -64,7 +64,7 @@ async fn main() { // Launching a new node handles.push(tokio::spawn(async move { - let mut server = Server::new(new_node_id, new_node_conf, cluster_config).await; + let mut server = Server::new(new_node_id, new_node_conf, cluster_config, None).await; server.start().await; })); diff --git a/examples/simulate_node_failure.rs b/examples/simulate_node_failure.rs index 0c3011f..2456e4a 100644 --- a/examples/simulate_node_failure.rs +++ b/examples/simulate_node_failure.rs @@ -45,7 +45,7 @@ async fn main() { let id = cluster_nodes[i]; let cc = cluster_config.clone(); let server_handle = tokio::spawn(async move { - let mut server = Server::new(id, config, cc).await; + let mut server = Server::new(id, config, cc, None).await; server.start().await; }); server_handles.push(server_handle); @@ -77,7 +77,8 @@ async fn main() { }; let cc = cluster_config.clone(); let server_handle = tokio::spawn(async move { - let mut server = Server::new(server_to_stop.try_into().unwrap(), config, cc).await; + let mut server = + Server::new(server_to_stop.try_into().unwrap(), config, cc, None).await; server.start().await; }); server_handles[server_to_stop - 1] = server_handle; diff --git a/examples/simulate_replica_repair.rs b/examples/simulate_replica_repair.rs index 0ff62e2..f2b4faa 100644 --- a/examples/simulate_replica_repair.rs +++ b/examples/simulate_replica_repair.rs @@ -57,7 +57,7 @@ async fn main() { warn!(get_logger(), "Storage for server {} is corrupted", id); } - let mut server = Server::new(id, config, cc).await; + let mut server = Server::new(id, config, cc, None).await; server.start().await; }); server_handles.push(server_handle); @@ -98,7 +98,8 @@ async fn main() { leadership_preferences: HashMap::new(), storage_location: Some(storage_path.clone()), }; - let mut server = Server::new(server_to_fail.try_into().unwrap(), config, cc).await; + let mut server = + Server::new(server_to_fail.try_into().unwrap(), config, cc, None).await; server.start().await; // Handle recovery of corrupted storage info!( diff --git a/src/error.rs b/src/error.rs index c987942..b90334c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -21,6 +21,9 @@ pub enum Error { /// Some other error occurred. 
#[error("unknown error {0}")] Unknown(#[from] Box), + /// To handle all bincode error + #[error("Bincode error {0}")] + BincodeError(#[from] bincode::Error), } #[derive(Error, Debug)] @@ -39,6 +42,8 @@ pub enum NetworkError { #[derive(Error, Debug)] pub enum StorageError { + #[error("Path not found")] + PathNotFound, #[error("File is empty")] EmptyFile, #[error("File is corrupted")] diff --git a/src/lib.rs b/src/lib.rs index 2320d6e..35f4313 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,4 +6,5 @@ pub mod error; pub mod log; pub mod network; pub mod server; +pub mod state_mechine; pub mod storage; diff --git a/src/network.rs b/src/network.rs index 1493496..64f17a2 100644 --- a/src/network.rs +++ b/src/network.rs @@ -77,7 +77,7 @@ impl NetworkLayer for TCPManager { join_all(futures) .await .into_iter() - .collect::>() + .collect::>>() // FIXME: We should let client decide what to do with the errors .map_err(|e| NetworkError::BroadcastError(e.to_string()))?; Ok(()) diff --git a/src/server.rs b/src/server.rs index 8d137c3..e372a14 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,14 +4,18 @@ use crate::cluster::{ClusterConfig, NodeMeta}; use crate::log::get_logger; use crate::network::{NetworkLayer, TCPManager}; +use crate::state_mechine::{self, StateMachine}; use crate::storage::{LocalStorage, Storage, CHECKSUM_LEN}; use serde::{Deserialize, Serialize}; use slog::{error, info, o}; -use std::collections::{HashMap, VecDeque}; +use std::collections::HashMap; use std::io::Cursor; use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::io::AsyncReadExt; +use tokio::sync::Mutex; use tokio::time::sleep; #[derive(Debug, Clone, PartialEq)] @@ -36,20 +40,59 @@ enum MessageType { // dynamic membership changes JoinRequest, JoinResponse, + + BatchAppendEntries, + BatchAppendEntriesResponse, } #[derive(Debug)] +/// Represents the state of a Raft server, which includes information +/// about the current term, election state, and log entries. struct ServerState { + /// The current term number, which increases monotonically. + /// It is used to identify the latest term known to this server. current_term: u32, + + /// The current state of the server in the Raft protocol + /// (e.g., Leader, Follower, or Candidate). state: RaftState, + + /// The candidate ID that this server voted for in the current term. + /// It is `None` if the server hasn't voted for anyone in this term. voted_for: Option, - log: VecDeque, + + /// A deque of log entries that are replicated to the Raft cluster. + // log: VecDeque, + state_machine: Arc>>, + + /// The index of the highest log entry known to be committed. + /// This indicates the index up to which the state machine is consistent. commit_index: u32, + + /// The index of the previous log entry used for consistency checks. + /// Typically used during the append entries process. previous_log_index: u32, + + /// For each follower, the next log entry to send to that follower. + /// This is used by the leader to keep track of what entries have been + /// sent to each follower. next_index: Vec, + + /// For each follower, the highest log entry index that is known to + /// be replicated on that follower. match_index: Vec, + + /// The election timeout duration. If this time passes without receiving + /// a valid heartbeat or a vote request, the server will trigger an election. election_timeout: Duration, + + /// The time when the last heartbeat from the current leader was received. 
+ /// Used by followers to detect if the leader has failed. last_heartbeat: Instant, + + /// A map of received votes in the current election term. The key is the + /// peer's ID and the value is a boolean indicating whether the vote was + /// granted. votes_received: HashMap, } @@ -93,18 +136,47 @@ pub struct Server { } impl Server { - pub async fn new(id: u32, config: ServerConfig, cluster_config: ClusterConfig) -> Server { + pub async fn new( + id: u32, + config: ServerConfig, + cluster_config: ClusterConfig, + state_machine: Option>, + ) -> Server { let log = get_logger(); let log = log.new( o!("ip" => config.address.ip().to_string(), "port" => config.address.port(), "id" => id), ); + // if storage location is provided, use it else set empty string to use default location + let storage_location = match config.storage_location.clone() { + Some(location) => location + &format!("server_{}.log", id), + None => format!("server_{}.log", id), + }; + let storage = LocalStorage::new(storage_location.clone()).await; + let parent_path = PathBuf::from(storage_location) + .parent() // This returns Option<&Path> + .map(|p| p.to_path_buf()) // Convert &Path to PathBuf + .unwrap_or_else(|| PathBuf::from("logs")); // Provide default path + + // Use the provided state_machine or default to FileStateMachine if none is provided + let state_machine = state_machine.unwrap_or_else(|| { + // Default FileStateMachine initialization + let snapshot_path = parent_path.join(format!("server_{}_snapshot.log", id)); + + Box::new(state_mechine::FileStateMachine::new( + &snapshot_path, + Duration::from_secs(60 * 60), + )) + }); + + let state_machine = Arc::new(Mutex::new(state_machine)); + let peer_count = cluster_config.peer_count(id); let state = ServerState { current_term: 0, state: RaftState::Follower, voted_for: None, - log: VecDeque::new(), + state_machine, commit_index: 0, previous_log_index: 0, next_index: vec![0; peer_count], @@ -115,13 +187,6 @@ impl Server { }; let network_manager = TCPManager::new(config.address); - // if storage location is provided, use it else set empty string to use default location - let storage_location = match config.storage_location.clone() { - Some(location) => location + &format!("server_{}.log", id), - None => format!("server_{}.log", id), - }; - let storage = LocalStorage::new(storage_location).await; - Server { id, state, @@ -200,6 +265,8 @@ impl Server { return; } + let state_machine = Arc::clone(&self.state.state_machine); + // Attempting to recover LogEntry from a disk file let mut cursor = Cursor::new(&log_byte); loop { let mut bytes_data = vec![0u8; log_entry_size + CHECKSUM_LEN]; @@ -212,11 +279,25 @@ impl Server { if log_entry.term > self.state.current_term { self.state.current_term = log_entry.term; } - self.state.log.push_front(log_entry); + state_machine + .lock() + .await + .apply_log_entry( + self.state.current_term, + self.state.commit_index, + log_entry.clone(), + ) + .await; + + // After restoring the LogEntry, the node's state information should be updated + self.state.current_term = log_entry.term; } + info!( self.log, - "Log after reading from disk: {:?}", self.state.log + "Log after reading from disk: {:?}, current term: {}", + state_machine.lock().await.get_log_entry().await, + self.state.current_term ); self.state.match_index = vec![0; self.peer_count() + 1]; @@ -233,7 +314,29 @@ impl Server { } } } + + let state_machine = Arc::clone(&self.state.state_machine); loop { + if state_machine.lock().await.need_create_snapshot().await { + let state_machine_clone = 
Arc::clone(&self.state.state_machine); + let log_clone = self.log.clone(); + let node_id_clone = self.id; + tokio::spawn(async move { + let mut state_machine_lock = state_machine_clone.lock().await; + if let Err(e) = state_machine_lock.create_snapshot().await { + error!( + log_clone, + "Node: {}, failed to create snapshot: {:?}", node_id_clone, e + ); + } else { + info!( + log_clone, + "Node: {}, snapshot created successfully.", node_id_clone + ); + } + }); + } + let timeout_duration = self.state.election_timeout; let timeout_future = async { @@ -314,11 +417,35 @@ impl Server { if self.state.state != RaftState::Leader { return; } - info!(self.log, "Server {} is the leader", self.id); + info!( + self.log, + "Server {} is the leader, term: {}", self.id, self.state.current_term + ); let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(300)); + let state_machine = Arc::clone(&self.state.state_machine); loop { + if state_machine.lock().await.need_create_snapshot().await { + let state_machine_clone = Arc::clone(&self.state.state_machine); + let log_clone = self.log.clone(); + let node_id_clone = self.id; + tokio::spawn(async move { + let mut state_machine_lock = state_machine_clone.lock().await; + if let Err(e) = state_machine_lock.create_snapshot().await { + error!( + log_clone, + "Node: {}, failed to create snapshot: {:?}", node_id_clone, e + ); + } else { + info!( + log_clone, + "Node: {}, snapshot created successfully.", node_id_clone + ); + } + }); + } + let rpc_future = self.receive_rpc(); tokio::select! { _ = heartbeat_interval.tick() => { @@ -420,6 +547,8 @@ impl Server { 9 => MessageType::RepairResponse, 10 => MessageType::JoinRequest, 11 => MessageType::JoinResponse, + 12 => MessageType::BatchAppendEntries, + 13 => MessageType::BatchAppendEntriesResponse, _ => return, }; @@ -472,6 +601,12 @@ impl Server { MessageType::JoinResponse => { self.handle_join_response(&data).await; } + MessageType::BatchAppendEntries => { + self.handle_batch_append_entries(&data).await; + } + MessageType::BatchAppendEntriesResponse => { + self.handle_batch_append_entries_response(&data).await; + } } } @@ -480,23 +615,28 @@ impl Server { return; } - let term = self.state.current_term; + self.state.previous_log_index += 1; + self.state.commit_index += 1; + self.state.current_term += 1; + let command = LogCommand::Set; let data = u32::from_be_bytes(data[12..16].try_into().unwrap()); let entry = LogEntry { leader_id: self.id, server_id: self.id, - term, + term: self.state.current_term, command, data, }; info!(self.log, "Received client request: {:?}", entry); self.write_buffer.push(entry.clone()); - self.state.log.push_front(entry); - self.state.previous_log_index += 1; - self.state.commit_index += 1; - self.state.current_term += 1; + let state_machine = Arc::clone(&self.state.state_machine); + state_machine + .lock() + .await + .apply_log_entry(self.state.current_term, self.state.commit_index, entry) + .await; } async fn handle_request_vote(&mut self, data: &[u8]) { @@ -765,6 +905,7 @@ impl Server { let mut cursor = Cursor::new(&log_byte); let mut repair_data = Vec::new(); + let state_machine = Arc::clone(&self.state.state_machine); loop { let mut bytes_data = vec![0u8; log_entry_size + CHECKSUM_LEN]; if cursor.read_exact(&mut bytes_data).await.is_err() { @@ -774,7 +915,10 @@ impl Server { } info!( self.log, - "Send repair data from {} to {}, log_entry: {:?}", self.id, peer_id, self.state.log + "Send repair data from {} to {}, log_entry: {:?}", + self.id, + peer_id, + 
state_machine.lock().await.get_log_entry().await ); let mut response = [ @@ -849,6 +993,7 @@ impl Server { return; } + // Add the new node's information to the cluster, ready to receive subsequent data self.cluster_config .add_server((node_id, node_ip_address.parse::().unwrap()).into()); @@ -866,6 +1011,42 @@ impl Server { if let Err(e) = self.network_manager.send(&peer_address, &response).await { error!(self.log, "Failed to send join response: {}", e); } + + // Here we will send the snapshot data to the new node + let state_machine = Arc::clone(&self.state.state_machine); + let log_entrys = if let Ok(data) = state_machine.lock().await.get_log_entry().await { + data + } else { + error!(self.log, "Failed to get log entrys from state machine."); + return; + }; + info!(self.log, "Sending log entrys to new node: {:?}", log_entrys); + + let log_entry_bytes = if let Ok(b) = bincode::serialize(&log_entrys) { + b + } else { + error!(self.log, "Failed to serialize log entrys."); + return; + }; + + let mut batch_append_entry_request: Vec = Vec::new(); + batch_append_entry_request.extend_from_slice(&self.id.to_be_bytes()); + batch_append_entry_request + .extend_from_slice(&state_machine.lock().await.get_term().await.to_be_bytes()); + batch_append_entry_request.extend_from_slice(&12u32.to_be_bytes()); + batch_append_entry_request + .extend_from_slice(&state_machine.lock().await.get_index().await.to_be_bytes()); + batch_append_entry_request.extend_from_slice(&log_entry_bytes); + if let Err(e) = self + .network_manager + .send(&peer_address, &batch_append_entry_request) + .await + { + error!( + self.log, + "Failed send batch append entry request to {}, err: {}", peer_address, e + ); + } } async fn handle_join_response(&mut self, data: &[u8]) { @@ -906,6 +1087,55 @@ impl Server { ); } + async fn handle_batch_append_entries(&mut self, data: &[u8]) { + let leader_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); + let last_included_term = u32::from_be_bytes(data[4..8].try_into().unwrap()); + let last_included_index = u32::from_be_bytes(data[12..16].try_into().unwrap()); + let log_entrys = if let Ok(data) = bincode::deserialize::>(&data[16..]) { + data + } else { + info!(self.log, "Failed to deserialize log entrys."); + return; + }; + + self.state.current_term = last_included_term; + self.state.commit_index = last_included_index; + self.state.previous_log_index = last_included_index; + let state_machine = Arc::clone(&self.state.state_machine); + state_machine + .lock() + .await + .apply_log_entrys(last_included_term, last_included_index, log_entrys) + .await; + + let response = [ + self.id.to_be_bytes(), + self.state.current_term.to_be_bytes(), + 13u32.to_be_bytes(), + ] + .concat(); + let peer_address = if let Some(addr) = self.cluster_config.address(leader_id) { + addr + } else { + info!(self.log, "Failed to get peer address."); + return; + }; + if let Err(e) = self.network_manager.send(&peer_address, &response).await { + error!(self.log, "Failed to send join response: {}", e); + } + } + + async fn handle_batch_append_entries_response(&mut self, data: &[u8]) { + if self.state.state != RaftState::Leader { + return; + } + + let peer_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); + let last_included_term = u32::from_be_bytes(data[4..8].try_into().unwrap()); + + info!(self.log, "Received batch append entries response from peer: {}, current peer last_included_term: {}", peer_id, last_included_term); + } + async fn persist_to_disk(&mut self, id: u32, data: &[u8]) { info!( self.log, @@ -917,10 
+1147,16 @@ impl Server {
             error!(self.log, "Failed to do compaction on disk: {}", e);
         }
+        let state_machine = Arc::clone(&self.state.state_machine);
+        // let mut state_machine_lock = state_machine.lock().await;
         if self.state.state == RaftState::Follower {
             // deserialize log entries and append to log
             let log_entry = self.deserialize_log_entries(data);
-            self.state.log.push_front(log_entry);
+            state_machine
+                .lock()
+                .await
+                .apply_log_entry(self.state.current_term, self.state.commit_index, log_entry)
+                .await;
         }
         if let Err(e) = self.storage.store(data).await {
             error!(self.log, "Failed to store log entry to disk: {}", e);
@@ -929,7 +1165,13 @@
         info!(
             self.log,
             "Log persistence complete, current log count: {}",
-            self.state.log.len()
+            state_machine
+                .lock()
+                .await
+                .get_log_entry()
+                .await
+                .unwrap()
+                .len()
         );
     }
diff --git a/src/state_mechine.rs b/src/state_mechine.rs
new file mode 100644
index 0000000..a64ccca
--- /dev/null
+++ b/src/state_mechine.rs
@@ -0,0 +1,344 @@
+use std::path::PathBuf;
+use std::time::Duration;
+use std::{fmt::Debug, path::Path};
+
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+use tokio::fs;
+use tokio::io::AsyncReadExt;
+use tokio::{fs::OpenOptions, io::AsyncWriteExt, time::Instant};
+
+use crate::error::StorageError::PathNotFound;
+use crate::{error::Error, error::Result, server::LogEntry};
+
+#[async_trait]
+pub trait StateMachine: Debug + Send + Sync {
+    // Retrieve the current term stored in the state machine
+    async fn get_term(&self) -> u32;
+
+    // Retrieve the current log index stored in the state machine
+    async fn get_index(&self) -> u32;
+
+    // Apply a single log entry to the state machine, updating term, index, and log entries
+    async fn apply_log_entry(
+        &mut self,
+        last_included_term: u32,
+        last_included_index: u32,
+        log_entry: LogEntry,
+    );
+
+    // Apply multiple log entries to the state machine in bulk
+    async fn apply_log_entrys(
+        &mut self,
+        last_included_term: u32,
+        last_included_index: u32,
+        mut log_entrys: Vec<LogEntry>,
+    );
+
+    // Retrieve all log entries currently stored in the state machine
+    async fn get_log_entry(&mut self) -> Result<Vec<LogEntry>>;
+
+    // Create a snapshot of the current state to the file system, and clear the log data after
+    async fn create_snapshot(&mut self) -> Result<()>;
+
+    // Check if a snapshot is required based on the interval since the last snapshot
+    async fn need_create_snapshot(&mut self) -> bool;
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+pub struct FileStateMachine {
+    last_included_term: u32,
+    last_included_index: u32,
+    data: Vec<LogEntry>,
+
+    #[serde(skip)]
+    snapshot_path: Option<Box<Path>>,
+
+    /// The time interval between snapshots.
+ #[serde(skip)] + snapshot_interval: Duration, + + /// The time when snapshot generation started + #[serde(skip)] + snapshot_start_time: Option, + + /// Whether the state machine is currently generating a snapshot + #[serde(skip)] + is_snapshotting: bool, + + /// The time when the last snapshot was completed + #[serde(skip)] + last_snapshot_complete_time: Option, +} + +impl FileStateMachine { + /// Create a new FileStateMachine with initial values + pub fn new(snapshot_path: &Path, snapshot_interval: Duration) -> Self { + let current_time = Instant::now(); + Self { + snapshot_path: Some(PathBuf::from(snapshot_path).into_boxed_path()), + last_included_term: 0, + last_included_index: 0, + data: Vec::new(), + snapshot_interval, + snapshot_start_time: Some(current_time), + is_snapshotting: false, + last_snapshot_complete_time: Some(current_time), + } + } +} + +/// Implement the StateMachine trait for FileStateMachine +/// Generate snapshots based on time intervals as I record start, end time +/// The data in memory is cleared after generating a snapshot, to save memory +#[async_trait] +impl StateMachine for FileStateMachine { + async fn get_term(&self) -> u32 { + self.last_included_term + } + + async fn get_index(&self) -> u32 { + self.last_included_index + } + + async fn apply_log_entry( + &mut self, + last_included_term: u32, + last_included_index: u32, + log_entry: LogEntry, + ) { + self.last_included_term = last_included_term; + self.last_included_index = last_included_index; + self.data.push(log_entry); + } + + async fn apply_log_entrys( + &mut self, + last_included_term: u32, + last_included_index: u32, + mut log_entrys: Vec, + ) { + self.last_included_term = last_included_term; + self.last_included_index = last_included_index; + self.data.append(&mut log_entrys); + } + + async fn create_snapshot(&mut self) -> Result<()> { + let snapshot_path = if let Some(ref path) = self.snapshot_path { + path + } else { + return Err(Error::Store(PathNotFound)); + }; + + self.snapshot_start_time = Some(Instant::now()); + self.is_snapshotting = true; + + // Step 1: Read the existing snapshot from file if it exists + let mut existing_fsm = FileStateMachine::new(snapshot_path, Duration::from_secs(0)); + if fs::metadata(snapshot_path).await.is_ok() { + let mut file = OpenOptions::new().read(true).open(snapshot_path).await?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer).await.map_err(Error::Io)?; + + if !buffer.is_empty() { + existing_fsm = bincode::deserialize(&buffer).map_err(Error::BincodeError)?; + } + } + + // Step 2: Merge existing snapshot data + self.data.splice(0..0, existing_fsm.data.drain(..)); + + // Step 3: Write the merged state back to the snapshot file + let mut file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(true) + .open(snapshot_path) + .await?; + let bytes = bincode::serialize(&self).map_err(Error::BincodeError)?; + file.write_all(&bytes).await?; + file.sync_all().await.map_err(Error::Io)?; + + // Step 4: Clear `data` after snapshot is created successfully + self.data.clear(); + + self.last_snapshot_complete_time = Some(Instant::now()); + self.is_snapshotting = false; + + Ok(()) + } + + /// Check if a new snapshot is needed. 
+ async fn need_create_snapshot(&mut self) -> bool { + if self.is_snapshotting { + return false; // If we are currently snapshotting, we don't need another snapshot + } + + if let Some(last_snapshot_time) = self.last_snapshot_complete_time { + // Calculate the time since the last snapshot was completed + let time_since_last_snapshot = Instant::now().duration_since(last_snapshot_time); + + // If the time since the last snapshot is greater than the snapshot interval, return true + if time_since_last_snapshot >= self.snapshot_interval { + // self.last_snapshot_complete_time = Some(Instant::now()); + return true; + } + } else { + // If we never completed a snapshot, we need to create one + return true; + } + + false + } + + async fn get_log_entry(&mut self) -> Result> { + let snapshot_path = if let Some(ref path) = self.snapshot_path { + path + } else { + return Err(Error::Store(PathNotFound)); + }; + + let mut existing_fsm = FileStateMachine::new(snapshot_path, Duration::from_secs(0)); + if fs::metadata(snapshot_path).await.is_ok() { + let mut file = OpenOptions::new().read(true).open(snapshot_path).await?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer).await.map_err(Error::Io)?; + + if !buffer.is_empty() { + existing_fsm = bincode::deserialize(&buffer).map_err(Error::BincodeError)?; + } + } + + self.data.splice(0..0, existing_fsm.data.drain(..)); + + Ok(self.data.clone()) + } +} + +#[cfg(test)] +mod tests { + use crate::server::LogCommand; + + use super::*; + use tempfile::NamedTempFile; + use tokio::time::sleep; + + #[tokio::test] + async fn test_apply_log_entry() { + let tmp_file = NamedTempFile::new().unwrap(); + let snapshot_path = tmp_file.path().to_str().unwrap(); + + let mut fsm = FileStateMachine { + snapshot_path: Option::Some(PathBuf::from(snapshot_path).into_boxed_path()), + last_included_term: 0, + last_included_index: 0, + data: vec![], + snapshot_interval: Duration::from_secs(300), + snapshot_start_time: None, + is_snapshotting: false, + last_snapshot_complete_time: None, + }; + + let log_entry = LogEntry { + term: 1, + command: LogCommand::Set, + leader_id: 1, + server_id: 1, + data: 1, + }; + + fsm.apply_log_entry(1, 1, log_entry.clone()).await; + + let log_entries = fsm.get_log_entry().await.unwrap(); + assert_eq!(log_entries.len(), 1); + assert_eq!(log_entries[0], log_entry); + assert_eq!(fsm.last_included_term, 1); + assert_eq!(fsm.last_included_index, 1); + } + + #[tokio::test] + async fn test_need_create_snapshot() { + let mut fsm = FileStateMachine { + snapshot_path: None, + last_included_term: 0, + last_included_index: 0, + data: vec![], + snapshot_interval: Duration::from_secs(1), + snapshot_start_time: None, + is_snapshotting: false, + last_snapshot_complete_time: Some(Instant::now()), + }; + + // Immediately after completing snapshot, no snapshot should be needed + assert!(!fsm.need_create_snapshot().await); + + // Wait for more than the interval and check again + sleep(Duration::from_secs(2)).await; + assert!(fsm.need_create_snapshot().await); + } + + #[tokio::test] + async fn test_create_snapshot() { + let tmp_file = NamedTempFile::new().unwrap(); + let snapshot_path = tmp_file.path().to_str().unwrap(); + + // Create a FileStateMachine with some data + let mut fsm = FileStateMachine { + snapshot_path: Some(PathBuf::from(snapshot_path).into_boxed_path()), + last_included_term: 1, + last_included_index: 1, + data: vec![ + LogEntry { + term: 1, + command: LogCommand::Set, + leader_id: 1, + server_id: 1, + data: 1, + }, + LogEntry { + term: 2, + 
command: LogCommand::Set, + leader_id: 2, + server_id: 2, + data: 2, + }, + ], + snapshot_interval: Duration::from_secs(300), + snapshot_start_time: None, + is_snapshotting: false, + last_snapshot_complete_time: None, + }; + + // Call create_snapshot and check result + let result = fsm.create_snapshot().await; + assert!(result.is_ok(), "Snapshot creation failed"); + + // Check that the snapshot file is created + let metadata = std::fs::metadata(snapshot_path); + assert!(metadata.is_ok(), "Snapshot file was not created"); + + // Read the file back and deserialize it to check contents + let snapshot_data = std::fs::read(snapshot_path).unwrap(); + let deserialized_fsm: FileStateMachine = bincode::deserialize(&snapshot_data).unwrap(); + + // Check that the deserialized data matches the original state machine + assert_eq!(deserialized_fsm.data.len(), 2); + + // Check that the snapshot start and complete times were set + assert!( + fsm.snapshot_start_time.is_some(), + "Snapshot start time was not set" + ); + assert!( + fsm.last_snapshot_complete_time.is_some(), + "Snapshot complete time was not set" + ); + + // Check that is_snapshotting was set to false after completion + assert!( + !fsm.is_snapshotting, + "is_snapshotting should be false after snapshot completion" + ); + } +}
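
Note on the new API: `Server::new` now takes a fourth argument, an optional boxed state machine (`Option<Box<dyn StateMachine>>`). Passing `None`, as the updated examples do, falls back to a `FileStateMachine` whose snapshot file is created in the same directory as the server's log file (or under `logs/` when that path has no parent). The sketch below is illustrative only and is not part of this change; it shows what a minimal in-memory implementation of the `state_mechine::StateMachine` trait could look like, assuming it lives inside the crate so the `crate::` paths resolve.

// Illustrative only: a minimal in-memory StateMachine (names are hypothetical).
use async_trait::async_trait;

use crate::error::Result;
use crate::server::LogEntry;
use crate::state_mechine::StateMachine;

#[derive(Debug, Default)]
pub struct InMemoryStateMachine {
    term: u32,
    index: u32,
    entries: Vec<LogEntry>,
}

#[async_trait]
impl StateMachine for InMemoryStateMachine {
    async fn get_term(&self) -> u32 {
        self.term
    }

    async fn get_index(&self) -> u32 {
        self.index
    }

    async fn apply_log_entry(&mut self, last_included_term: u32, last_included_index: u32, log_entry: LogEntry) {
        // Track the latest term/index and keep the entry in memory.
        self.term = last_included_term;
        self.index = last_included_index;
        self.entries.push(log_entry);
    }

    async fn apply_log_entrys(&mut self, last_included_term: u32, last_included_index: u32, mut log_entrys: Vec<LogEntry>) {
        self.term = last_included_term;
        self.index = last_included_index;
        self.entries.append(&mut log_entrys);
    }

    async fn get_log_entry(&mut self) -> Result<Vec<LogEntry>> {
        Ok(self.entries.clone())
    }

    async fn create_snapshot(&mut self) -> Result<()> {
        // Nothing is persisted in this sketch; a real implementation would write
        // `self.entries` somewhere durable and then clear it, as FileStateMachine does.
        Ok(())
    }

    async fn need_create_snapshot(&mut self) -> bool {
        false
    }
}

A server would then be started with something like `Server::new(id, config, cluster_config, Some(Box::new(InMemoryStateMachine::default()))).await` instead of passing `None`.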
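
For reference, the BatchAppendEntries message assembled in `handle_join_request` and decoded in `handle_batch_append_entries` keeps the same fixed 16-byte header as the other RPCs, followed by a bincode-encoded `Vec<LogEntry>`: bytes 0..4 carry the sender id, 4..8 the last included term, 8..12 the message type (12), and 12..16 the last included index. A hedged sketch of an encode/decode pair that mirrors this framing (the helper names are hypothetical, not functions in the PR):

// Illustrative helpers mirroring the BatchAppendEntries framing used above.
use crate::server::LogEntry;

fn encode_batch_append_entries(
    sender_id: u32,
    last_included_term: u32,
    last_included_index: u32,
    entries: &[LogEntry],
) -> bincode::Result<Vec<u8>> {
    let mut buf = Vec::new();
    buf.extend_from_slice(&sender_id.to_be_bytes());
    buf.extend_from_slice(&last_included_term.to_be_bytes());
    buf.extend_from_slice(&12u32.to_be_bytes()); // MessageType::BatchAppendEntries
    buf.extend_from_slice(&last_included_index.to_be_bytes());
    buf.extend_from_slice(&bincode::serialize(entries)?);
    Ok(buf)
}

fn decode_batch_append_entries(data: &[u8]) -> bincode::Result<(u32, u32, u32, Vec<LogEntry>)> {
    let sender_id = u32::from_be_bytes(data[0..4].try_into().unwrap());
    let last_included_term = u32::from_be_bytes(data[4..8].try_into().unwrap());
    // bytes 8..12 hold the message type and are skipped here, as in the handler
    let last_included_index = u32::from_be_bytes(data[12..16].try_into().unwrap());
    let entries = bincode::deserialize::<Vec<LogEntry>>(&data[16..])?;
    Ok((sender_id, last_included_term, last_included_index, entries))
}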