diff --git a/src/lib/drivers/tlog/writer.rs b/src/lib/drivers/tlog/writer.rs index 43ff481..1552490 100644 --- a/src/lib/drivers/tlog/writer.rs +++ b/src/lib/drivers/tlog/writer.rs @@ -1,6 +1,6 @@ use std::{path::PathBuf, sync::Arc}; -use anyhow::Result; +use anyhow::{Context, Result}; use tokio::{ io::{AsyncWriteExt, BufWriter}, sync::{broadcast, RwLock}, @@ -113,6 +113,29 @@ impl TlogWriter { writer.write_all(×tamp.to_be_bytes()).await?; writer.write_all(raw_bytes).await?; writer.flush().await?; + + if let FileCreationCondition::OnArm(ExpectedOrigin { + system_id, + component_id, + }) = &self.file_creation_condition + { + let system_id = system_id.read().await.expect( + "System ID should always be Some at this point because it was replaced when armed, which is the condition to reach this part", + ); + + if *message.system_id() != system_id + || message.component_id() != component_id + { + continue; + } + + if let Some(ArmState::Disarmed) = check_arm_state(&message) { + debug!( + "Vehicle disarmed, finishing tlog file writer until next arm..." + ); + break; + } + } } Err(error) => { error!("Failed to receive message from hub: {error:?}"); @@ -130,11 +153,56 @@ impl TlogWriter { impl Driver for TlogWriter { #[instrument(level = "debug", skip(self, hub_sender))] async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { - let file = tokio::fs::File::create(self.path.clone()).await?; - let writer = tokio::io::BufWriter::with_capacity(1024, file); - let hub_receiver = hub_sender.subscribe(); + let mut armed = false; + + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1)); + let mut first = true; + loop { + if first { + first = false; + } else { + interval.tick().await; + } + + if !armed { + if let FileCreationCondition::OnArm(expected_origin) = &self.file_creation_condition + { + debug!( + "TlogWriter waiting for its arm condition as {:?}", + self.file_creation_condition + ); + let hub_receiver = hub_sender.subscribe(); + if let Err(error) = wait_for_arm(hub_receiver, expected_origin).await { + warn!("Failed waiting for arm: {error:?}"); + continue; + } + debug!( + "TlogWriter has reached its arm condition as {:?}", + self.file_creation_condition + ); + armed = true; + } + } + + let file = match create_tlog_file(self.path.clone()).await { + Ok(file) => file, + Err(error) => { + warn!("Failed creating tlog file: {error:?}"); + continue; + } + }; + + debug!("Writing to tlog file: {file:?}"); + + let writer = tokio::io::BufWriter::with_capacity(1024, file); + let hub_receiver = hub_sender.subscribe(); + + if let Err(error) = TlogWriter::handle_client(self, writer, hub_receiver).await { + debug!("TlogWriter client ended with an error: {error:?}"); + } - TlogWriter::handle_client(self, writer, hub_receiver).await + armed = false; + } } #[instrument(level = "debug", skip(self))] @@ -246,3 +314,121 @@ impl DriverInfo for TlogWriterInfo { )) } } + +async fn create_tlog_file(path: PathBuf) -> Result { + let file_path: PathBuf = if path.extension().and_then(|ext| ext.to_str()) == Some("tlog") { + path.clone() + } else { + if !std::path::Path::new(&path).exists() { + tokio::fs::create_dir_all(&path).await?; + } + + let sequence = get_sequence(&path).await.unwrap_or_default(); + let timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S"); + let file_name = format!("{sequence:05}-{timestamp}.tlog"); + + let mut file_path = path.clone(); + file_path.push(file_name); + file_path + }; + + tokio::fs::File::create(file_path) + .await + .map_err(anyhow::Error::msg) +} + +async fn get_sequence(path: &PathBuf) -> Result { + let re = regex::Regex::new(r"^(\d{5})-\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.tlog$") + .expect("Failed to compile regex"); + + let mut max_seq: u32 = 0; + let mut files_in_dir = tokio::fs::read_dir(&path).await?; + + while let Some(entry) = files_in_dir.next_entry().await? { + let entry_path = entry.path(); + + if let Some(sequence) = entry_path + .file_name() + .and_then(|name| name.to_str()) + .and_then(|file_name| re.captures(file_name)) + .and_then(|captures| captures.get(1)) + .and_then(|seq_match| seq_match.as_str().parse::().ok()) + { + max_seq = max_seq.max(sequence); + } + } + + Ok(max_seq + 1) +} + +async fn wait_for_arm( + mut hub_receiver: broadcast::Receiver>, + ExpectedOrigin { + system_id, + component_id, + }: &ExpectedOrigin, +) -> Result<()> { + loop { + let message = hub_receiver.recv().await?; + + if message.component_id() == component_id + && matches!(check_arm_state(&message), Some(ArmState::Armed)) + { + debug!( + "Received arm from system {:?}. Current: {system_id:?}", + message.system_id() + ); + + let current_system_id = system_id.read().await.clone(); + + let system_id = match current_system_id { + Some(system_id) => system_id, + None => { + system_id + .write() + .await + .replace(*message.system_id()) + .context("This should always be None") + .expect_err("This should never be Ok"); + + debug!("Expected System ID updated to {system_id:?}"); + + *message.system_id() + } + }; + + if *message.system_id() == system_id { + break; + } + } + } + + Ok(()) +} + +enum ArmState { + Disarmed, + Armed, +} + +/// A performant way of checking if the vehicle is armed without parsing the message into a heartbeat message type (from Mavlink crate) +fn check_arm_state(message: &Arc) -> Option { + use mavlink::MessageData; + + if message.message_id() != mavlink::common::HEARTBEAT_DATA::ID { + return None; + } + + const BASE_MODE_BYTE: usize = 6; // From: https://mavlink.io/en/messages/common.html#HEARTBEAT + + let base_mode = message + .payload() + .get(BASE_MODE_BYTE) + .cloned() + .unwrap_or_else(|| mavlink::common::MavModeFlag::empty().bits()); + + match base_mode & mavlink::common::MavModeFlag::MAV_MODE_FLAG_SAFETY_ARMED.bits() { + 0 => Some(ArmState::Disarmed), + _ => Some(ArmState::Armed), + } +}