diff --git a/diffsims/simulations/simulation2d.py b/diffsims/simulations/simulation2d.py index 313b912d..a081f6ee 100644 --- a/diffsims/simulations/simulation2d.py +++ b/diffsims/simulations/simulation2d.py @@ -21,6 +21,7 @@ import numpy as np import matplotlib.pyplot as plt +import math from orix.crystal_map import Phase from orix.quaternion import Rotation from orix.vector import Vector3d @@ -182,6 +183,31 @@ def __init__( self.iphase = PhaseGetter(self) self.irot = RotationGetter(self) + def get_simulation(self, item): + """Return the rotation and the phase index of the simulation""" + if self.has_multiple_phases and self.has_multiple_rotations: + cumsum = np.cumsum(self._num_rotations()) + ind = np.searchsorted(cumsum, item, side="right") + cumsum = np.insert(cumsum, 0, 0) + num_rot = cumsum[ind] + return ( + self.rotations[ind][item - num_rot], + ind, + self.coordinates[ind][item - num_rot], + ) + elif self.has_multiple_phases: + return self.rotations[item], item, self.coordinates[item] + elif self.has_multiple_rotations: + return self.rotations[item], 0, self.coordinates[item] + else: + return self.rotations[item], 0, self.coordinates + + def _num_rotations(self): + if self.has_multiple_phases: + return [r.size for r in self.rotations] + else: + return self.rotations.size + def __iter__(self): return self @@ -268,7 +294,7 @@ def rotate_shift_coordinates( ) return coords_new - def polar_flatten_simulations(self): + def polar_flatten_simulations(self, radial_axes=None, azimuthal_axes=None): """Flattens the simulations into polar coordinates for use in template matching. The resulting arrays are of shape (n_simulations, n_spots) where n_spots is the maximum number of spots in any simulation. @@ -285,12 +311,19 @@ def polar_flatten_simulations(self): r_templates = np.zeros((len(flattened_vectors), max_num_spots)) theta_templates = np.zeros((len(flattened_vectors), max_num_spots)) intensities_templates = np.zeros((len(flattened_vectors), max_num_spots)) - for i, v in enumerate(flattened_vectors): r, t, _ = v.to_polar() + if radial_axes is not None and azimuthal_axes is not None: + r = get_closest(radial_axes, r) + t = get_closest(azimuthal_axes, t) + r = r[r < len(radial_axes)] + t = t[t < len(azimuthal_axes)] r_templates[i, : len(r)] = r theta_templates[i, : len(t)] = t intensities_templates[i, : len(v.intensity)] = v.intensity + if radial_axes is not None and azimuthal_axes is not None: + r_templates = np.array(r_templates, dtype=int) + theta_templates = np.array(theta_templates, dtype=int) return r_templates, theta_templates, intensities_templates @@ -554,3 +587,20 @@ def plot( ax.set_xlabel("pixels") ax.set_ylabel("pixels") return ax, sp + + +def get_closest(array, values): + # make sure array is a numpy array + array = np.array(array) + + # get insert positions + idxs = np.searchsorted(array, values, side="left") + + # find indexes where previous index is closer + prev_idx_is_less = (idxs == len(array)) | ( + np.fabs(values - array[np.maximum(idxs - 1, 0)]) + < np.fabs(values - array[np.minimum(idxs, len(array) - 1)]) + ) + idxs[prev_idx_is_less] -= 1 + + return idxs diff --git a/diffsims/tests/simulations/test_simulation.py b/diffsims/tests/simulations/test_simulation.py index d93881ab..b8d75272 100644 --- a/diffsims/tests/simulations/test_simulation.py +++ b/diffsims/tests/simulations/test_simulation.py @@ -184,6 +184,12 @@ def multi_simulation(self, al_phase): ) return sim + def test_get_simulation(self, multi_simulation): + for i in range(4): + rotation, phase = multi_simulation.get_simulation(i) + assert isinstance(rotation, Rotation) + assert phase == 0 + def test_get_current_rotation(self, multi_simulation): rot = multi_simulation.get_current_rotation() np.testing.assert_array_equal(rot, multi_simulation.rotations[0].to_matrix()[0]) @@ -285,6 +291,16 @@ def test_init(self, multi_simulation): assert isinstance(multi_simulation.rotations, np.ndarray) assert isinstance(multi_simulation.coordinates, np.ndarray) + def test_get_simulation(self, multi_simulation): + for i in range(4): + rotation, phase = multi_simulation.get_simulation(i) + assert isinstance(rotation, Rotation) + assert phase == 0 + for i in range(4, 8): + rotation, phase = multi_simulation.get_simulation(i) + assert isinstance(rotation, Rotation) + assert phase == 1 + def test_iphase(self, multi_simulation): phase_slic = multi_simulation.iphase[0] assert isinstance(phase_slic, Simulation2D)