From 774a9e4d47bac3902aa8c744bdcdc1fd88178ef5 Mon Sep 17 00:00:00 2001
From: dierbei
Date: Thu, 15 Aug 2024 07:37:16 +0000
Subject: [PATCH] fix storage: persist log entries with per-entry checksums

Signed-off-by: dierbei
---
 examples/simple_run.rs            |  16 +-
 examples/simulate_add_node.rs     |  38 ++--
 examples/simulate_node_failure.rs |   4 +-
 src/c.rs                          | 148 +++++++++++++
 src/error.rs                      |   4 +-
 src/network.rs                    |   2 +-
 src/server.rs                     | 177 +++++++++-------
 src/storage.rs                    | 333 ++++++++++++++++++++++--------
 8 files changed, 531 insertions(+), 191 deletions(-)
 create mode 100644 src/c.rs

diff --git a/examples/simple_run.rs b/examples/simple_run.rs
index 5e25e5e..71b14ab 100644
--- a/examples/simple_run.rs
+++ b/examples/simple_run.rs
@@ -9,8 +9,6 @@ use slog::{error, info};
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::str::FromStr;
-use std::thread;
-use tokio::runtime::Runtime;
 use tokio::time::Duration;
 
 use raft_rs::network::{NetworkLayer, TCPManager};
@@ -34,7 +32,7 @@ async fn main() {
         .clone()
         .iter()
         .map(|n| ServerConfig {
-            election_timeout: Duration::from_millis(200),
+            election_timeout: Duration::from_millis(1000),
             address: n.address,
             default_leader: Some(1u32),
             leadership_preferences: HashMap::new(),
@@ -47,10 +45,9 @@ async fn main() {
     for (i, config) in configs.into_iter().enumerate() {
         let id = cluster_nodes[i];
         let cc = cluster_config.clone();
-        handles.push(thread::spawn(move || {
-            let rt = Runtime::new().unwrap();
-            let mut server = Server::new(id, config, cc);
-            rt.block_on(server.start());
+        handles.push(tokio::spawn(async move {
+            let mut server = Server::new(id, config, cc).await;
+            server.start().await;
         }));
     }
 
@@ -60,7 +59,7 @@ async fn main() {
     tokio::time::sleep(Duration::from_secs(2)).await;
 
     // Join all server tasks
     for handle in handles {
-        handle.join().unwrap();
+        handle.await.unwrap();
     }
 }
diff --git a/examples/simulate_add_node.rs b/examples/simulate_add_node.rs
index eb036d0..8dfd9fd 100644
--- a/examples/simulate_add_node.rs
+++ b/examples/simulate_add_node.rs
@@ -7,8 +7,6 @@ use slog::error;
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::str::FromStr;
-use std::thread;
-use tokio::runtime::Runtime;
 use tokio::time::Duration;
 
 use raft_rs::network::{NetworkLayer, TCPManager};
@@ -17,7 +15,7 @@ use raft_rs::server::{Server, ServerConfig};
 #[tokio::main]
 async fn main() {
     // Define cluster configuration
-    let mut cluster_nodes = vec![1, 2, 3, 4, 5];
+    let cluster_nodes = vec![1, 2, 3, 4, 5];
     let peers = vec![
         NodeMeta::from((1, SocketAddr::from_str("127.0.0.1:5001").unwrap())),
         NodeMeta::from((2, SocketAddr::from_str("127.0.0.1:5002").unwrap())),
@@ -25,7 +23,7 @@ async fn main() {
         NodeMeta::from((4, SocketAddr::from_str("127.0.0.1:5004").unwrap())),
         NodeMeta::from((5, SocketAddr::from_str("127.0.0.1:5005").unwrap())),
     ];
-    let mut cluster_config = ClusterConfig::new(peers.clone());
+    let cluster_config = ClusterConfig::new(peers.clone());
     // Create server configs
     let configs: Vec<_> = peers
         .clone()
@@ -44,10 +42,9 @@ async fn main() {
     for (i, config) in configs.into_iter().enumerate() {
         let id = cluster_nodes[i];
         let cc = cluster_config.clone();
-        handles.push(thread::spawn(move || {
-            let rt = Runtime::new().unwrap();
-            let mut server = Server::new(id, config, cc);
-            rt.block_on(server.start());
+        handles.push(tokio::spawn(async move {
+            let mut server = Server::new(id, config, cc).await;
+            server.start().await;
         }));
     }
 
@@ -66,10 +65,9 @@ async fn main() {
     };
 
     // Launching a new node
-    handles.push(thread::spawn(move || {
-        let rt = Runtime::new().unwrap();
-        let mut server = Server::new(new_node_id, new_node_conf, cluster_config);
-        rt.block_on(server.start());
+    handles.push(tokio::spawn(async move {
+        let mut server = Server::new(new_node_id, new_node_conf, cluster_config).await;
+        server.start().await;
     }));
 
     // Simulate sending a Raft Join request after a few seconds
@@ -79,15 +79,14 @@ async fn main() {
 
     // Wait for all servers to finish
     for handle in handles {
-        handle.join().unwrap();
+        handle.await.unwrap();
     }
 }
 
 async fn add_node_request(new_node_id: u32, addr: SocketAddr) {
     let log = get_logger();
 
-    let server_address = addr;
-    let network_manager = TCPManager::new(server_address);
+    let network_manager = TCPManager::new(addr);
 
     let request_data = vec![
         new_node_id.to_be_bytes().to_vec(),
@@ -98,7 +98,13 @@ async fn add_node_request(new_node_id: u32, addr: SocketAddr) {
     .concat();
 
     // Let's assume that 5001 is the port of the leader node.
-    if let Err(e) = network_manager.send(&server_address, &request_data).await {
+    if let Err(e) = network_manager
+        .send(
+            &SocketAddr::from_str("127.0.0.1:5001").unwrap(),
+            &request_data,
+        )
+        .await
+    {
         error!(log, "Failed to send client request: {}", e);
     }
 }
diff --git a/examples/simulate_node_failure.rs b/examples/simulate_node_failure.rs
index 98feb24..0c3011f 100644
--- a/examples/simulate_node_failure.rs
+++ b/examples/simulate_node_failure.rs
@@ -45,7 +45,7 @@ async fn main() {
         let id = cluster_nodes[i];
         let cc = cluster_config.clone();
         let server_handle = tokio::spawn(async move {
-            let mut server = Server::new(id, config, cc);
+            let mut server = Server::new(id, config, cc).await;
             server.start().await;
         });
         server_handles.push(server_handle);
@@ -77,7 +77,7 @@ async fn main() {
         };
         let cc = cluster_config.clone();
         let server_handle = tokio::spawn(async move {
-            let mut server = Server::new(server_to_stop.try_into().unwrap(), config, cc);
+            let mut server = Server::new(server_to_stop.try_into().unwrap(), config, cc).await;
             server.start().await;
         });
         server_handles[server_to_stop - 1] = server_handle;
diff --git a/src/c.rs b/src/c.rs
new file mode 100644
index 0000000..d3f087d
--- /dev/null
+++ b/src/c.rs
@@ -0,0 +1,154 @@
+// NOTE: imports reconstructed for this standalone test file; they assume the
+// checksum helpers and CHECKSUM_LEN are visible from src/storage.rs.
+use std::io::{Cursor, Read};
+use std::path::PathBuf;
+
+use rand::{distributions::Alphanumeric, Rng};
+
+use crate::server::{LogCommand, LogEntry};
+use crate::storage::{calculate_checksum, retrieve_checksum, LocalStorage, Storage, CHECKSUM_LEN};
+
+fn get_random_file_path() -> std::io::Result<PathBuf> {
+    let filename: String = rand::thread_rng()
+        .sample_iter(&Alphanumeric)
+        .take(10)
+        .map(char::from)
+        .collect();
+
+    let filepath = PathBuf::from(format!("/tmp/.tmp{}", filename));
+
+    Ok(filepath)
+}
+
+fn file_exists(path: &PathBuf) -> bool {
+    std::path::Path::new(path).exists()
+}
+
+#[test]
+fn test_retrieve_checksum() {
+    let data_str = "Some data followed by a checksum".as_bytes();
+    let calculated_checksum = calculate_checksum(data_str);
+
+    let data = [data_str, calculated_checksum.as_slice()].concat();
+
+    let retrieved_checksum = retrieve_checksum(&data);
+    assert_eq!(calculated_checksum, retrieved_checksum);
+}
+
+#[tokio::test]
+async fn test_store_async() {
+    let random_path = get_random_file_path().unwrap();
+
+    let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(&random_path).await);
+    let payload_data = "Some data to test raft".as_bytes();
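+    // store() appends the payload followed by a 64-byte hex-encoded SHA-256
+    // checksum, so the file should grow by payload_data.len() + CHECKSUM_LEN bytes.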
+    let store_result = storage.store(payload_data).await;
+    assert!(store_result.is_ok());
+
+    let disk_data = storage.retrieve().await.unwrap();
+    assert_eq!(payload_data.len() + CHECKSUM_LEN, disk_data.len());
+
+    let data = &disk_data[..disk_data.len() - CHECKSUM_LEN];
+    assert_eq!(payload_data, data);
+
+    storage.delete().await.unwrap();
+}
+
+#[tokio::test]
+async fn test_delete() {
+    let random_path = get_random_file_path().unwrap();
+
+    let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(&random_path).await);
+    let delete_result = storage.delete().await;
+    assert!(delete_result.is_ok());
+    assert!(!file_exists(&random_path));
+}
+
+#[tokio::test]
+async fn test_compaction_file_lt_max_file_size() {
+    let random_path = get_random_file_path().unwrap();
+
+    let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(&random_path).await);
+    let mock_data = vec![0u8; 1_000_000 /*1 MB*/ - 500];
+    let store_result = storage.store(&mock_data).await;
+    assert!(store_result.is_ok());
+
+    let compaction_result = storage.compaction().await;
+    assert!(compaction_result.is_ok());
+    assert!(file_exists(&random_path));
+
+    storage.delete().await.unwrap();
+}
+
+#[tokio::test]
+async fn test_compaction_file_gt_max_file_size() {
+    let random_path = get_random_file_path().unwrap();
+
+    let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(&random_path).await);
+    let mock_data = vec![0u8; 1_000_000 /*1 MB*/];
+    let store_result = storage.store(&mock_data).await;
+    assert!(store_result.is_ok());
+
+    let compaction_result = storage.compaction().await;
+    assert!(compaction_result.is_ok());
+
+    assert!(!file_exists(&random_path));
+}
+
+#[tokio::test]
+async fn test_retrieve_data() {
+    let random_path = get_random_file_path().unwrap();
+
+    let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(&random_path).await);
+    let log_entry_size = std::mem::size_of::<LogEntry>();
+
+    // Insert the first entry
+    let entry1 = LogEntry {
+        leader_id: 1,
+        server_id: 1,
+        term: 1,
+        command: LogCommand::Set,
+        data: 1,
+    };
+    let serialize_data = bincode::serialize(&entry1).unwrap();
+    storage.store(&serialize_data).await.unwrap();
+    let disk_data = storage.retrieve().await.unwrap();
+    let log_entry_bytes = &disk_data[0..log_entry_size];
+    let disk_entry: LogEntry = bincode::deserialize(log_entry_bytes).unwrap();
+    assert_eq!(entry1, disk_entry);
+
+    // Then insert the second entry
+    let entry2 = LogEntry {
+        leader_id: 2,
+        server_id: 2,
+        term: 2,
+        command: LogCommand::Set,
+        data: 2,
+    };
+    let serialize_data = bincode::serialize(&entry2).unwrap();
+    storage.store(&serialize_data).await.unwrap();
+    let disk_data = storage.retrieve().await.unwrap();
+
+    // Read both entries back and compare them
+    let mut log_entries = vec![];
+    let mut cursor = Cursor::new(&disk_data);
+    loop {
+        let mut bytes_data = vec![0u8; log_entry_size];
+        if cursor.read_exact(&mut bytes_data).is_err() {
+            break;
+        }
+        let struct_data: LogEntry = bincode::deserialize(&bytes_data).unwrap();
+
+        let mut checksum = [0u8; CHECKSUM_LEN];
+        if cursor.read_exact(&mut checksum).is_err() {
+            break;
+        }
+
+        log_entries.push(struct_data);
+    }
+
+    assert_eq!(vec![entry1, entry2], log_entries);
+
+    storage.delete().await.unwrap();
+}
\ No newline at end of file
diff --git a/src/error.rs b/src/error.rs
index 3af5544..c987942 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -33,8 +33,8 @@ pub enum NetworkError {
     ConnectError(SocketAddr),
     #[error("Failed binding to {0}")]
     BindError(SocketAddr),
-    #[error("Broadcast failed")]
-    BroadcastError,
+    #[error("Broadcast failed, errmsg: {0}")]
+    BroadcastError(String),
 }
 
 #[derive(Error, Debug)]
diff --git a/src/network.rs b/src/network.rs
index 50dbfd9..1493496 100644
--- a/src/network.rs
+++ b/src/network.rs
@@ -79,7 +79,7 @@ impl NetworkLayer for TCPManager {
             .into_iter()
             .collect::<Result<(), Error>>()
             // FIXME: We should let client decide what to do with the errors
-            .map_err(|_e| NetworkError::BroadcastError)?;
+            .map_err(|e| NetworkError::BroadcastError(e.to_string()))?;
 
         Ok(())
     }
diff --git a/src/server.rs b/src/server.rs
index a1e3cf0..da59050 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -4,10 +4,12 @@
 use crate::cluster::{ClusterConfig, NodeMeta};
 use crate::log::get_logger;
 use crate::network::{NetworkLayer, TCPManager};
-use crate::storage::{LocalStorage, Storage};
+use crate::storage::{LocalStorage, Storage, CHECKSUM_LEN};
 use serde::{Deserialize, Serialize};
 use slog::{error, info, o};
+use tokio::io::AsyncReadExt;
 use std::collections::{HashMap, VecDeque};
+use std::io::Cursor;
 use std::net::SocketAddr;
 use std::time::{Duration, Instant};
 use tokio::time::sleep;
@@ -51,20 +53,20 @@ struct ServerState {
     votes_received: HashMap<u32, bool>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-enum LogCommand {
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum LogCommand {
     Noop,
     Set,
     Delete,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct LogEntry {
-    leader_id: u32,
-    server_id: u32,
-    term: u32,
-    command: LogCommand,
-    data: u32,
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct LogEntry {
+    pub leader_id: u32,
+    pub server_id: u32,
+    pub term: u32,
+    pub command: LogCommand,
+    pub data: u32,
 }
 
 #[derive(Debug)]
@@ -91,12 +93,12 @@ pub struct Server {
 }
 
 impl Server {
-    pub fn new(id: u32, config: ServerConfig, cluster_config: ClusterConfig) -> Server {
+    pub async fn new(id: u32, config: ServerConfig, cluster_config: ClusterConfig) -> Server {
         let log = get_logger();
         let log = log.new(
-            o!("ip" => config.address.ip().to_string(), "port" => config.address.port(), "default leader" => config.default_leader.unwrap_or(1), "id" => id),
+            o!("ip" => config.address.ip().to_string(), "port" => config.address.port(), "id" => id),
         );
 
         let peer_count = cluster_config.peer_count(id);
         let state = ServerState {
             current_term: 0,
@@ -118,7 +120,7 @@ impl Server {
             Some(location) => location + &format!("server_{}.log", id),
             None => format!("server_{}.log", id),
         };
-        let storage = LocalStorage::new(storage_location);
+        let storage = LocalStorage::new(storage_location).await;
 
         Server {
             id,
@@ -169,54 +171,47 @@ impl Server {
             return;
         }
 
-        let log_byte = self.storage.retrieve().await;
-        if let Ok(log) = log_byte {
-            for entry in log.chunks(std::mem::size_of::<LogEntry>()) {
-                if entry.len() != std::mem::size_of::<LogEntry>() {
-                    break;
-                }
-                let log_entry = self.deserialize_log_entries(entry);
-                if log_entry.term > self.state.current_term {
-                    self.state.current_term = log_entry.term;
-                }
-                self.state.log.push_front(log_entry);
+        let log_byte = self.storage.retrieve().await.unwrap();
+        let log_entry_size = std::mem::size_of::<LogEntry>();
+
+        // Data integrity check failed
+        // try repair the log from other peers
+        if log_byte.len() % (log_entry_size + CHECKSUM_LEN) != 0 {
+            error!(self.log, "Data integrity check failed");
+
+            // step1 delete the log file
+            if let Err(e) = self.storage.delete().await {
+                error!(self.log, "Failed to delete log file: {}", e);
             }
-            info!(
-                self.log,
-                "Log after reading from disk: {:?}", self.state.log
-            );
-            info!(
-                self.log,
-                "Log after reading from disk: {:?}", self.state.log
-            );
-        } else {
-            // Data integrity check failed
-            if log_byte
-                .unwrap_err()
-                .to_string()
-                .contains("Data integrity check failed")
-            {
-                error!(self.log, "Data integrity check failed");
-                // try repair the log from other peers
-                // step1 delete the log file
-                if let Err(e) = self.storage.delete().await {
-                    error!(self.log, "Failed to delete log file: {}", e);
-                }
-                // step2 get the log from other peers
-                // ping all the peers to get the log
-                let addresses: Vec<SocketAddr> = self.peers_address();
-                let data = [
-                    self.id.to_be_bytes(),
-                    0u32.to_be_bytes(),
-                    2u32.to_be_bytes(),
-                ]
-                .concat();
-                let _ = self.network_manager.broadcast(&data, &addresses).await;
-                return;
+            // step2 get the log from other peers
+            // ping all the peers to get the log
+            let addresses: Vec<SocketAddr> = self.peers_address();
+            let data = [
+                self.id.to_be_bytes(),
+                0u32.to_be_bytes(),
+                2u32.to_be_bytes(),
+            ]
+            .concat();
+            self.network_manager.broadcast(&data, &addresses).await.unwrap();
+            return;
+        }
+
+        let mut cursor = Cursor::new(&log_byte);
+        loop {
+            let mut bytes_data = vec![0u8; log_entry_size + CHECKSUM_LEN];
+            if cursor.read_exact(&mut bytes_data).await.is_err() {
+                break;
+            }
+            bytes_data = bytes_data[0..log_entry_size].to_vec();
+
+            let log_entry = self.deserialize_log_entries(&bytes_data);
+            if log_entry.term > self.state.current_term {
+                self.state.current_term = log_entry.term;
             }
-            info!(self.log, "No log entries found on disk");
+            self.state.log.push_front(log_entry);
         }
+        info!(self.log, "Log after reading from disk: {:?}", self.state.log);
 
         self.state.match_index = vec![0; self.peer_count() + 1];
         self.state.next_index = vec![0; self.peer_count() + 1];
@@ -574,24 +569,39 @@ impl Server {
 
         let id = u32::from_be_bytes(data[0..4].try_into().unwrap());
         let leader_term = u32::from_be_bytes(data[4..8].try_into().unwrap());
+        let message_type = u32::from_be_bytes(data[8..12].try_into().unwrap());
+        let prev_log_index = u32::from_be_bytes(data[12..16].try_into().unwrap());
+        let commit_index = u32::from_be_bytes(data[16..20].try_into().unwrap());
+        info!(
+            self.log,
+            "Node {} received append entries request from Node {}, \
+            (term: self={}, receive={}), \
+            (prev_log_index: self={}, receive={}), \
+            (commit_index: self={}, receive={})",
+            self.id,
+            id,
+            self.state.current_term,
+            leader_term,
+            self.state.previous_log_index,
+            prev_log_index,
+            self.state.commit_index,
+            commit_index
+        );
 
         if leader_term < self.state.current_term {
             return;
         }
 
-        let message_type = u32::from_be_bytes(data[8..12].try_into().unwrap());
         if message_type != 2 {
             return;
         }
 
-        let prev_log_index = u32::from_be_bytes(data[12..16].try_into().unwrap());
         if prev_log_index > self.state.previous_log_index {
            self.state.previous_log_index = prev_log_index;
         } else {
             return;
         }
 
-        let commit_index = u32::from_be_bytes(data[16..20].try_into().unwrap());
         if commit_index > self.state.commit_index {
             self.state.commit_index = commit_index;
         } else {
@@ -657,10 +667,22 @@ impl Server {
         let last_log_index = self.state.previous_log_index;
         self.state.match_index[sender_id as usize - 1] = last_log_index;
         self.state.next_index[sender_id as usize - 1] = last_log_index + 1;
 
         let mut match_indices = self.state.match_index.clone();
         match_indices.sort();
         let quorum_index = match_indices[self.peer_count() / 2];
+
+        info!(
+            self.log,
+            "Append entry response received from node {}: (match_index = {}, next_index = {}), current quorum_index: {}",
+            sender_id,
+            self.state.match_index[sender_id as usize - 1],
+            self.state.next_index[sender_id as usize - 1],
+            quorum_index
+        );
+
         if quorum_index >= self.state.commit_index {
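+            // match_indices is sorted ascending, so the entry at peer_count/2 is
+            // already persisted on a majority of nodes and is safe to commit.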
             self.state.commit_index = quorum_index;
             // return client response
@@ -732,21 +752,25 @@ impl Server {
 
         let peer_id = u32::from_be_bytes(data[0..4].try_into().unwrap());
 
-        let log_byte = self.storage.retrieve().await;
-        if log_byte.is_err() {
-            error!(self.log, "Failed to retrieve log entries from disk");
+        let log_byte = self.storage.retrieve().await.unwrap();
+        let log_entry_size = std::mem::size_of::<LogEntry>();
+
+        // Data integrity check failed
+        if log_byte.len() % (log_entry_size + CHECKSUM_LEN) != 0 {
+            error!(self.log, "Data integrity check failed");
             return;
         }
 
-        let log = log_byte.unwrap();
-        let log_entries = log.chunks(std::mem::size_of::<LogEntry>());
+        let mut cursor = Cursor::new(&log_byte);
         let mut repair_data = Vec::new();
-        for entry in log_entries {
-            if entry.len() != std::mem::size_of::<LogEntry>() {
+        loop {
+            let mut bytes_data = vec![0u8; log_entry_size + CHECKSUM_LEN];
+            if cursor.read_exact(&mut bytes_data).await.is_err() {
                 break;
             }
-            repair_data.extend_from_slice(entry);
+            repair_data.extend_from_slice(&bytes_data[0..log_entry_size]);
         }
+        info!(self.log, "Send repair data from {} to {}, log_entry: {:?}", self.id, peer_id, self.state.log);
 
         let mut response = [
             self.id.to_be_bytes(),
@@ -882,7 +906,7 @@ impl Server {
             self.log,
             "Persisting logs to disk from peer: {} to server: {}", id, self.id
         );
-        info!(self.log, "Data: {:?}", data);
+        // info!(self.log, "Data: {:?}", data);
 
         // Log Compaction
         if let Err(e) = self.storage.compaction().await {
@@ -898,12 +922,17 @@ impl Server {
             error!(self.log, "Failed to store log entry to disk: {}", e);
         }
 
-        info!(self.log, "Log after appending: {:?}", self.state.log);
+        info!(
+            self.log,
+            "Log persistence complete, current log count: {}",
+            self.state.log.len()
+        );
+        // info!(self.log, "Log after appending: {:?}", self.state.log);
     }
 
     fn deserialize_log_entries(&self, data: &[u8]) -> LogEntry {
         // convert data to logEntry using bincode
-        info!(self.log, "Deserializing log entry: {:?}", data);
+        // info!(self.log, "Deserializing log entry: {:?}", data);
         bincode::deserialize(data).unwrap()
     }
diff --git a/src/storage.rs b/src/storage.rs
index b97180c..41a897d 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,20 +1,23 @@
 // Organization: SpacewalkHq
 // License: MIT License
 
-use std::io;
+use std::io::{self, Cursor, SeekFrom};
 use std::path::{Path, PathBuf};
+use std::sync::Arc;
 
 use async_trait::async_trait;
 use hex;
 use sha2::{Digest, Sha256};
-use tokio::fs::{self, File};
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::fs::{self, File, OpenOptions};
+use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
+use tokio::sync::Mutex;
 
 use crate::error::StorageError::CorruptFile;
-use crate::error::{Error, Result, StorageError};
+use crate::error::{Error, Result};
+use crate::server::LogEntry;
 
 const MAX_FILE_SIZE: u64 = 1_000_000;
-const CHECKSUM_LEN: usize = 64;
+pub const CHECKSUM_LEN: usize = 64;
 
 #[async_trait]
 pub trait Storage {
@@ -28,16 +31,39 @@ pub trait Storage {
 
 #[derive(Clone)]
 pub struct LocalStorage {
     path: PathBuf,
+    file: Arc<Mutex<File>>,
 }
 
 impl LocalStorage {
-    pub fn new(path: String) -> Self {
-        LocalStorage { path: path.into() }
+    pub async fn new(path: String) -> Self {
+        let file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .create(true)
+            .open(path.clone())
+            .await
+            .unwrap();
+
+        LocalStorage {
+            path: path.into(),
+            file: Arc::new(Mutex::new(file)),
+        }
     }
 
-    pub fn new_from_path(path: &Path) -> Self {
+    pub async fn new_from_path(path: &Path) -> Self {
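+        // Open (or create) the backing file once; all subsequent reads and
+        // writes share this handle behind an async Mutex.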
+        let file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .create(true)
+            .open(path)
+            .await
+            .unwrap();
+
         LocalStorage {
-            path: PathBuf::from(path),
+            path: path.into(),
+            file: Arc::new(Mutex::new(file)),
         }
     }
@@ -53,40 +77,43 @@ impl LocalStorage {
         Ok(())
     }
 
+    /// Asynchronously appends the provided data, followed by its checksum, to the log file.
+    ///
+    /// # Arguments
+    /// * `data` - A slice of bytes representing the data to be stored.
     async fn store_async(&self, data: &[u8]) -> Result<()> {
         let checksum = calculate_checksum(data);
         let data_with_checksum = [data, checksum.as_slice()].concat();
 
-        let mut file = File::create(&self.path).await.map_err(Error::Io)?;
-        file.write_all(&data_with_checksum)
+        let file = Arc::clone(&self.file);
+        let mut locked_file = file.lock().await;
+
+        locked_file.seek(SeekFrom::End(0)).await.unwrap();
+
+        locked_file
+            .write_all(&data_with_checksum)
             .await
             .map_err(Error::Io)?;
-        file.flush().await.map_err(Error::Io)?;
+
+        // Sync file contents and metadata to disk before returning.
+        locked_file.sync_all().await.map_err(Error::Io)?;
+
         Ok(())
     }
 
+    /// Asynchronously retrieves the entire contents of the log file.
     async fn retrieve_async(&self) -> Result<Vec<u8>> {
-        let mut file = File::open(&self.path).await.map_err(Error::Io)?;
-        let mut buffer = Vec::new();
-        file.read_to_end(&mut buffer).await.map_err(Error::Io)?;
-
-        if buffer.is_empty() {
-            return Err(Error::Store(StorageError::EmptyFile));
-        }
-
-        if buffer.len() < CHECKSUM_LEN {
-            return Err(Error::Store(StorageError::CorruptFile));
-        }
+        let file = Arc::clone(&self.file);
+        let mut locked_file = file.lock().await;
+        locked_file.seek(SeekFrom::Start(0)).await.unwrap();
 
-        let data = &buffer[..buffer.len() - 64];
-        let stored_checksum = retrieve_checksum(&buffer);
-        let calculated_checksum = calculate_checksum(data);
-
-        if stored_checksum != calculated_checksum {
-            return Err(Error::Store(StorageError::DataIntegrityError));
-        }
+        let mut buffer = Vec::new();
+        locked_file
+            .read_to_end(&mut buffer)
+            .await
+            .map_err(Error::Io)?;
 
-        Ok(data.to_vec())
+        Ok(buffer)
     }
 
     async fn delete_async(&self) -> Result<()> {
@@ -96,12 +123,26 @@ impl LocalStorage {
 
     async fn compaction_async(&self) -> Result<()> {
         // If file size is greater than 1MB, then compact it
-        let metadata = fs::metadata(&self.path).await.map_err(Error::Io)?;
+        let file = Arc::clone(&self.file);
+        let locked_file = file.lock().await;
+        let metadata = locked_file.metadata().await.map_err(Error::Io)?;
         if metadata.len() > MAX_FILE_SIZE {
             self.delete_async().await?;
         }
         Ok(())
     }
+
+    async fn is_file_size_exceeded(&self) -> Result<()> {
+        let file = Arc::clone(&self.file);
+        let locked_file = file.lock().await;
+
+        let md = locked_file.metadata().await.map_err(Error::Io)?;
+        if md.len() > MAX_FILE_SIZE {
+            return Err(Error::Store(CorruptFile));
+        }
+
+        Ok(())
+    }
 }
 
 #[async_trait]
@@ -123,13 +164,43 @@ impl Storage for LocalStorage {
     }
 
     async fn turned_malicious(&self) -> Result<()> {
-        // Check if the file is tampered with
-        self.retrieve().await?;
-        let metadata = fs::metadata(&self.path).await.map_err(Error::Io)?;
+        self.is_file_size_exceeded().await.unwrap();
 
-        if metadata.len() > MAX_FILE_SIZE {
+        let disk_data = self.retrieve().await?;
+        let log_entry_size = std::mem::size_of::<LogEntry>();
+
+        if disk_data.len() % (log_entry_size + CHECKSUM_LEN) != 0 {
             return Err(Error::Store(CorruptFile));
         }
+
+        let mut cursor = Cursor::new(&disk_data);
+        loop {
+            let mut bytes_data = vec![0u8; log_entry_size];
+            if let Err(err) = cursor.read_exact(&mut bytes_data).await {
+                if err.kind() == std::io::ErrorKind::UnexpectedEof {
+                    break;
+                } else {
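+                    // Anything other than a clean end-of-file is a real I/O error.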
+                    return Err(Error::Io(err));
+                }
+            }
+
+            let byte_data_checksum = calculate_checksum(&bytes_data);
+
+            let mut checksum = [0u8; CHECKSUM_LEN];
+            if let Err(err) = cursor.read_exact(&mut checksum).await {
+                if err.kind() == std::io::ErrorKind::UnexpectedEof {
+                    break;
+                } else {
+                    return Err(Error::Io(err));
+                }
+            }
+
+            if byte_data_checksum.ne(&checksum) {
+                return Err(Error::Store(CorruptFile));
+            }
+        }
+
         Ok(())
     }
 }
@@ -146,28 +216,34 @@ fn calculate_checksum(data: &[u8]) -> [u8; CHECKSUM_LEN] {
     checksum
 }
 
-/// Helper function to extract the checksum from the end of a given byte slice.
-/// It assumes that the checksum is of a fixed length `CHECKSUM_LEN` and is located
-/// at the end of the provided data slice.
-///
-/// This function will panic if the length of the provided data slice is less than `CHECKSUM_LEN`.
-fn retrieve_checksum(data: &[u8]) -> [u8; CHECKSUM_LEN] {
-    assert!(data.len() >= CHECKSUM_LEN);
-    let mut op = [0; 64];
-    op.copy_from_slice(&data[data.len() - CHECKSUM_LEN..]);
-    op
-}
-
 #[cfg(test)]
 mod tests {
-    use std::io::{Read, Seek, SeekFrom, Write};
+
+    use std::io::{Cursor, SeekFrom};
     use tempfile::NamedTempFile;
+    use tokio::{
+        fs::OpenOptions,
+        io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
+    };
 
-    use crate::storage::{
-        calculate_checksum, retrieve_checksum, LocalStorage, Storage, CHECKSUM_LEN,
+    use crate::{
+        server::{LogCommand, LogEntry},
+        storage::{calculate_checksum, LocalStorage, Storage, CHECKSUM_LEN},
     };
 
+    /// Helper function to extract the checksum from the end of a given byte slice.
+    /// It assumes that the checksum is of a fixed length `CHECKSUM_LEN` and is located
+    /// at the end of the provided data slice.
+    ///
+    /// This function will panic if the length of the provided data slice is less than `CHECKSUM_LEN`.
+    fn retrieve_checksum(data: &[u8]) -> [u8; CHECKSUM_LEN] {
+        assert!(data.len() >= CHECKSUM_LEN);
+        let mut op = [0; 64];
+        op.copy_from_slice(&data[data.len() - CHECKSUM_LEN..]);
+        op
+    }
+
     #[test]
     fn test_retrieve_checksum() {
         let data_str = "Some data followed by a checksum".as_bytes();
         let calculated_checksum = calculate_checksum(data_str);
@@ -181,27 +257,34 @@ mod tests {
 
     #[tokio::test]
     async fn test_store_async() {
-        let mut tmp_file = NamedTempFile::new().unwrap();
-        let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(tmp_file.path()));
+        let tmp_file = NamedTempFile::new().unwrap();
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
+
         let payload_data = "Some data to test raft".as_bytes();
         let store_result = storage.store(payload_data).await;
         assert!(store_result.is_ok());
 
-        tmp_file.as_file().sync_all().unwrap();
+        let buffer = storage.retrieve().await.unwrap();
 
-        let mut buffer = vec![];
-        tmp_file.read_to_end(&mut buffer).unwrap();
-
-        assert_eq!(payload_data.len() + CHECKSUM_LEN, buffer.len());
+        // Verify the length of the stored data (original data + checksum).
+        assert_eq!(
+            payload_data.len() + CHECKSUM_LEN,
+            buffer.len(),
+            "Stored data length mismatch"
+        );
 
-        let data = &buffer[..buffer.len() - CHECKSUM_LEN];
-        assert_eq!(payload_data, data);
+        let stored_data = &buffer[..buffer.len() - CHECKSUM_LEN];
+        // Verify the original data matches the input data.
+        assert_eq!(payload_data, stored_data, "Stored data mismatch");
     }
 
     #[tokio::test]
     async fn test_delete() {
         let tmp_file = NamedTempFile::new().unwrap();
-        let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(tmp_file.path()));
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
+
         let delete_result = storage.delete().await;
         assert!(delete_result.is_ok());
         assert!(!tmp_file.path().exists());
@@ -210,13 +293,13 @@ mod tests {
 
     #[tokio::test]
     async fn test_compaction_file_lt_max_file_size() {
         let tmp_file = NamedTempFile::new().unwrap();
-        let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(tmp_file.path()));
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
         let mock_data = vec![0u8; 1_000_000 /*1 MB*/ - 500];
+
         let store_result = storage.store(&mock_data).await;
         assert!(store_result.is_ok());
 
-        tmp_file.as_file().sync_all().unwrap();
-
         let compaction_result = storage.compaction().await;
         assert!(compaction_result.is_ok());
 
@@ -226,13 +309,13 @@ mod tests {
 
     #[tokio::test]
     async fn test_compaction_file_gt_max_file_size() {
         let tmp_file = NamedTempFile::new().unwrap();
-        let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(tmp_file.path()));
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
         let mock_data = vec![0u8; 1_000_000 /*1 MB*/];
+
        let store_result = storage.store(&mock_data).await;
         assert!(store_result.is_ok());
 
-        tmp_file.as_file().sync_all().unwrap();
-
         let compaction_result = storage.compaction().await;
         assert!(compaction_result.is_ok());
 
@@ -242,31 +325,93 @@ mod tests {
 
     #[tokio::test]
     async fn test_retrieve() {
         let tmp_file = NamedTempFile::new().unwrap();
-        let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(tmp_file.path()));
-        let test_data = "Some mocked data".as_bytes();
-
-        storage.store(test_data).await.unwrap();
-
-        let retrieved_result = storage.retrieve().await;
-        assert!(retrieved_result.is_ok());
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
+        let log_entry_size = std::mem::size_of::<LogEntry>();
+
+        // Insert the first entry
+        let entry1 = LogEntry {
+            leader_id: 1,
+            server_id: 1,
+            term: 1,
+            command: LogCommand::Set,
+            data: 1,
+        };
+        let serialize_data = bincode::serialize(&entry1).unwrap();
+        storage.store(&serialize_data).await.unwrap();
+        let disk_data = storage.retrieve().await.unwrap();
+        let log_entry_bytes = &disk_data[0..log_entry_size];
+        let disk_entry: LogEntry = bincode::deserialize(log_entry_bytes).unwrap();
+        assert_eq!(entry1, disk_entry);
+
+        // Then insert the second entry
+        let entry2 = LogEntry {
+            leader_id: 2,
+            server_id: 2,
+            term: 2,
+            command: LogCommand::Set,
+            data: 2,
+        };
+        let serialize_data = bincode::serialize(&entry2).unwrap();
+        storage.store(&serialize_data).await.unwrap();
+        let disk_data = storage.retrieve().await.unwrap();
+
+        // Read both entries back and compare them
+        let mut log_entries = vec![];
+        let mut cursor = Cursor::new(&disk_data);
+        loop {
+            let mut bytes_data = vec![0u8; log_entry_size];
+            if cursor.read_exact(&mut bytes_data).await.is_err() {
+                break;
+            }
+            let struct_data: LogEntry = bincode::deserialize(&bytes_data).unwrap();
+
+            let mut checksum = [0u8; CHECKSUM_LEN];
+            if cursor.read_exact(&mut checksum).await.is_err() {
+                break;
+            }
+
+            log_entries.push(struct_data);
+        }
 
-        assert_eq!(test_data, retrieved_result.unwrap())
+        assert_eq!(vec![entry1, entry2], log_entries);
     }
 
     #[tokio::test]
     async fn test_turned_malicious_file_corrupted() {
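+        // Overwriting the first bytes below breaks the (LogEntry, checksum)
+        // framing, so turned_malicious() must report a corrupt file.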
-        let mut tmp_file = NamedTempFile::new().unwrap();
-        let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(tmp_file.path()));
-        storage.store("Java is awesome".as_bytes()).await.unwrap();
-
-        tmp_file.as_file().sync_all().unwrap();
+        let tmp_file = NamedTempFile::new().unwrap();
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
+
+        // Write one valid entry first
+        let entry1 = LogEntry {
+            leader_id: 1,
+            server_id: 1,
+            term: 1,
+            command: LogCommand::Set,
+            data: 1,
+        };
+        let serialize_data = bincode::serialize(&entry1).unwrap();
+        let store_result = storage.store(&serialize_data).await;
+        assert!(store_result.is_ok());
 
-        // corrupt the file
-        tmp_file.seek(SeekFrom::Start(0)).unwrap();
-        tmp_file.write_all("Raft".as_bytes()).unwrap();
+        // Simulate corruption: overwrite the start of the file so it no longer
+        // matches the [(LogEntry, checksum), ...] framing
+        let mut file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .open(tmp_file.path())
+            .await
+            .unwrap();
+        file.seek(SeekFrom::Start(0)).await.unwrap();
+        file.write_all("Raft".as_bytes()).await.unwrap();
+        file.sync_all().await.unwrap();
 
-        tmp_file.as_file().sync_all().unwrap();
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
         let result = storage.turned_malicious().await;
         assert!(result.is_err());
     }
@@ -274,10 +421,20 @@ mod tests {
     #[tokio::test]
     async fn test_turned_malicious_happy_case() {
         let tmp_file = NamedTempFile::new().unwrap();
-        let storage: Box<dyn Storage> = Box::new(LocalStorage::new_from_path(tmp_file.path()));
-        storage.store("Java is awesome".as_bytes()).await.unwrap();
-
-        tmp_file.as_file().sync_all().unwrap();
+        let storage: Box<dyn Storage> =
+            Box::new(LocalStorage::new_from_path(tmp_file.path()).await);
+
+        // Write one valid entry
+        let entry1 = LogEntry {
+            leader_id: 1,
+            server_id: 1,
+            term: 1,
+            command: LogCommand::Set,
+            data: 1,
+        };
+        let serialize_data = bincode::serialize(&entry1).unwrap();
+        let store_result = storage.store(&serialize_data).await;
+        assert!(store_result.is_ok());
 
         let result = storage.turned_malicious().await;
         assert!(result.is_ok());