diff --git a/src/md/trajectory/traj.rs b/src/md/trajectory/traj.rs index 0609fea9..a28ffe9f 100644 --- a/src/md/trajectory/traj.rs +++ b/src/md/trajectory/traj.rs @@ -595,55 +595,53 @@ where DefaultAllocator: Allocator + Allocator + Allocator, { - type Output = Traj; + type Output = Result, NyxError>; /// Add one trajectory to another. If they do not overlap to within 10ms, a warning will be printed. fn add(self, other: Traj) -> Self::Output { - self + &other + &self + &other } } -impl ops::Add<&Traj> for Traj +impl ops::Add<&Traj> for &Traj where DefaultAllocator: Allocator + Allocator + Allocator, { - type Output = Traj; + type Output = Result, NyxError>; - /// Add one trajectory to another. If they do not overlap to within 10ms, a warning will be printed. + /// Add one trajectory to another, returns an error if the frames don't match fn add(self, other: &Traj) -> Self::Output { - let (first, second) = if self.first().epoch() < other.first().epoch() { - (&self, other) + if self.first().frame() != other.first().frame() { + Err(NyxError::Trajectory(TrajError::CreationError(format!( + "Frame mismatch in add operation: {} != {}", + self.first().frame(), + other.first().frame() + )))) } else { - (other, &self) - }; + if self.last().epoch() < other.first().epoch() { + let gap = other.first().epoch() - self.last().epoch(); + warn!( + "Resulting merged trajectory will have a time-gap of {} starting at {}", + gap, + self.last().epoch() + ); + } - if first.last().epoch() < second.first().epoch() { - let gap = second.first().epoch() - first.last().epoch(); - warn!( - "Resulting merged trajectory will have a time-gap of {} starting at {}", - gap, - first.last().epoch() - ); - } + let mut me = self.clone(); + // Now start adding the other segments while correcting the index + for state in &other + .states + .iter() + .filter(|s| s.epoch() > self.last().epoch()) + .collect::>() + { + me.states.push(**state); + } + me.finalize(); - let mut me = self.clone(); - // Now start adding the other segments while correcting the index - for state in &second.states { - me.states.push(*state); + Ok(me) } - me.finalize(); - me - } -} - -impl ops::AddAssign for Traj -where - DefaultAllocator: - Allocator + Allocator + Allocator, -{ - fn add_assign(&mut self, rhs: Self) { - *self = self.clone() + rhs; } } @@ -652,8 +650,13 @@ where DefaultAllocator: Allocator + Allocator + Allocator, { + /// Attempt to add two trajectories together and assign it to `self` + /// + /// # Warnings + /// 1. This will panic if the frames mismatch! + /// 2. This is inefficient because both `self` and `rhs` are cloned. fn add_assign(&mut self, rhs: &Self) { - *self = self.clone() + rhs; + *self = (self.clone() + rhs.clone()).unwrap(); } } diff --git a/src/python/mission_design/orbit_trajectory.rs b/src/python/mission_design/orbit_trajectory.rs index f1129398..04ff188a 100644 --- a/src/python/mission_design/orbit_trajectory.rs +++ b/src/python/mission_design/orbit_trajectory.rs @@ -167,6 +167,24 @@ impl OrbitTraj { } } + /// Allows converting the source trajectory into the (almost) equivalent trajectory in another frame. + /// This simply converts each state into the other frame and may lead to aliasing due to the Nyquist–Shannon sampling theorem. + fn to_frame(&self, new_frame: String) -> Result { + let cosm = Cosm::de438(); + + let frame = cosm.try_frame(&new_frame)?; + + let conv_traj = self.inner.to_frame(frame, cosm)?; + + Ok(Self { inner: conv_traj }) + } + + fn __add__(&self, rhs: &Self) -> Result { + let inner = (self.inner.clone() + rhs.inner.clone())?; + + Ok(Self { inner }) + } + fn __str__(&self) -> String { format!("{}", self.inner) } diff --git a/src/python/mission_design/sc_trajectory.rs b/src/python/mission_design/sc_trajectory.rs index 3f7f8e3f..7c92da5a 100644 --- a/src/python/mission_design/sc_trajectory.rs +++ b/src/python/mission_design/sc_trajectory.rs @@ -172,6 +172,24 @@ impl SpacecraftTraj { } } + /// Allows converting the source trajectory into the (almost) equivalent trajectory in another frame. + /// This simply converts each state into the other frame and may lead to aliasing due to the Nyquist–Shannon sampling theorem. + fn to_frame(&self, new_frame: String) -> Result { + let cosm = Cosm::de438(); + + let frame = cosm.try_frame(&new_frame)?; + + let conv_traj = self.inner.to_frame(frame, cosm)?; + + Ok(Self { inner: conv_traj }) + } + + fn __add__(&self, rhs: &Self) -> Result { + let inner = (self.inner.clone() + rhs.inner.clone())?; + + Ok(Self { inner }) + } + fn __str__(&self) -> String { format!("{}", self.inner) } diff --git a/tests/python/test_mission_design.py b/tests/python/test_mission_design.py index bd4fc20f..6ce18727 100644 --- a/tests/python/test_mission_design.py +++ b/tests/python/test_mission_design.py @@ -155,7 +155,12 @@ def test_build_spacecraft(): traj_orbit = traj.to_orbit_traj() traj_orbit_dc = traj_sc.downcast() # Check that we can query it (will raise an exception if we can't, thereby failing the test) - ts = TimeSeries(Epoch("2020-06-01T12:00:00.000000"), Epoch("2020-06-01T13:00:00.000000"), step=Unit.Minute*17 + Unit.Second*13.8, inclusive=True) + ts = TimeSeries( + Epoch("2020-06-01T12:00:00.000000"), + Epoch("2020-06-01T13:00:00.000000"), + step=Unit.Minute * 17 + Unit.Second * 13.8, + inclusive=True, + ) for epoch in ts: orbit = traj_orbit.at(epoch) dc_orbit = traj_orbit_dc.at(epoch) @@ -204,11 +209,11 @@ def test_two_body(): ) # And propagate in parallel using a single duration - proped_orbits = two_body(orbits, durations=[Unit.Day*531.5]) + proped_orbits = two_body(orbits, durations=[Unit.Day * 531.5]) assert len(proped_orbits) == len(orbits) # And propagate in parallel using many epochs - ts = TimeSeries(e, e + Unit.Day * 1000, step=Unit.Day*1, inclusive=False) + ts = TimeSeries(e, e + Unit.Day * 1000, step=Unit.Day * 1, inclusive=False) epochs = [e for e in ts] proped_orbits = two_body(orbits, new_epochs=epochs) # Allow up to two to fail @@ -217,5 +222,42 @@ def test_two_body(): timing = timeit(lambda: two_body(orbits, new_epochs=epochs), number=1) print(f"two body propagation of {len(orbits)} orbits in {timing} s") + +def test_merge_traj(): + # Initialize logging + FORMAT = "%(levelname)s %(name)s %(filename)s:%(lineno)d\t%(message)s" + logging.basicConfig(format=FORMAT) + logging.getLogger().setLevel(logging.INFO) + + # Base path + root = Path(__file__).joinpath("../../../").resolve() + + config_path = root.joinpath("./data/tests/config/") + + sc1 = Spacecraft.load(str(config_path.joinpath("spacecraft.yaml"))) + + dynamics = SpacecraftDynamics.load_named( + str(config_path.joinpath("dynamics.yaml")) + )["lofi"] + + sc2, traj1 = propagate(sc1, dynamics, Unit.Day * 5) + # And propagate again + sc3, traj2 = propagate(sc2, dynamics, Unit.Day * 5) + # Add the trajectories + traj = traj1 + traj2 + + assert traj.last().epoch == sc3.epoch, f"{traj.last()} != {sc3}" + + # Convert into another frame and try to add them too. + # We only check the epoch this time. + traj1_moon = traj1.to_frame("Moon J2000") + traj2_moon = traj2.to_frame("Moon J2000") + + traj_moon = traj1_moon + traj2_moon + + assert traj_moon.last().epoch == sc3.epoch + print(traj_moon) + + if __name__ == "__main__": - test_two_body() + test_merge_traj()