Skip to content

Commit

Permalink
Refactor: Add methods for easier indexing related to Orientation Mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
CSSFrancis committed Feb 14, 2024
1 parent 7efd4ee commit 980ae97
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
54 changes: 52 additions & 2 deletions diffsims/simulations/simulation2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np
import matplotlib.pyplot as plt
import math
from orix.crystal_map import Phase
from orix.quaternion import Rotation
from orix.vector import Vector3d
Expand Down Expand Up @@ -182,6 +183,31 @@ def __init__(
self.iphase = PhaseGetter(self)
self.irot = RotationGetter(self)

def get_simulation(self, item):
"""Return the rotation and the phase index of the simulation"""
if self.has_multiple_phases and self.has_multiple_rotations:
cumsum = np.cumsum(self._num_rotations())
ind = np.searchsorted(cumsum, item, side="right")
cumsum = np.insert(cumsum, 0, 0)
num_rot = cumsum[ind]
return (
self.rotations[ind][item - num_rot],
ind,
self.coordinates[ind][item - num_rot],
)
elif self.has_multiple_phases:
return self.rotations[item], item, self.coordinates[item]
elif self.has_multiple_rotations:
return self.rotations[item], 0, self.coordinates[item]
else:
return self.rotations[item], 0, self.coordinates

def _num_rotations(self):
if self.has_multiple_phases:
return [r.size for r in self.rotations]
else:
return self.rotations.size

def __iter__(self):
return self

Expand Down Expand Up @@ -268,7 +294,7 @@ def rotate_shift_coordinates(
)
return coords_new

def polar_flatten_simulations(self):
def polar_flatten_simulations(self, radial_axes=None, azimuthal_axes=None):
"""Flattens the simulations into polar coordinates for use in template matching.
The resulting arrays are of shape (n_simulations, n_spots) where n_spots is the
maximum number of spots in any simulation.
Expand All @@ -285,12 +311,19 @@ def polar_flatten_simulations(self):
r_templates = np.zeros((len(flattened_vectors), max_num_spots))
theta_templates = np.zeros((len(flattened_vectors), max_num_spots))
intensities_templates = np.zeros((len(flattened_vectors), max_num_spots))

for i, v in enumerate(flattened_vectors):
r, t, _ = v.to_polar()
if radial_axes is not None and azimuthal_axes is not None:
r = get_closest(radial_axes, r)
t = get_closest(azimuthal_axes, t)
r = r[r < len(radial_axes)]
t = t[t < len(azimuthal_axes)]
r_templates[i, : len(r)] = r
theta_templates[i, : len(t)] = t
intensities_templates[i, : len(v.intensity)] = v.intensity
if radial_axes is not None and azimuthal_axes is not None:
r_templates = np.array(r_templates, dtype=int)
theta_templates = np.array(theta_templates, dtype=int)

return r_templates, theta_templates, intensities_templates

Expand Down Expand Up @@ -554,3 +587,20 @@ def plot(
ax.set_xlabel("pixels")
ax.set_ylabel("pixels")
return ax, sp


def get_closest(array, values):
# make sure array is a numpy array
array = np.array(array)

# get insert positions
idxs = np.searchsorted(array, values, side="left")

# find indexes where previous index is closer
prev_idx_is_less = (idxs == len(array)) | (
np.fabs(values - array[np.maximum(idxs - 1, 0)])
< np.fabs(values - array[np.minimum(idxs, len(array) - 1)])
)
idxs[prev_idx_is_less] -= 1

return idxs
16 changes: 16 additions & 0 deletions diffsims/tests/simulations/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ def multi_simulation(self, al_phase):
)
return sim

def test_get_simulation(self, multi_simulation):
for i in range(4):
rotation, phase = multi_simulation.get_simulation(i)
assert isinstance(rotation, Rotation)
assert phase == 0

def test_get_current_rotation(self, multi_simulation):
rot = multi_simulation.get_current_rotation()
np.testing.assert_array_equal(rot, multi_simulation.rotations[0].to_matrix()[0])
Expand Down Expand Up @@ -285,6 +291,16 @@ def test_init(self, multi_simulation):
assert isinstance(multi_simulation.rotations, np.ndarray)
assert isinstance(multi_simulation.coordinates, np.ndarray)

def test_get_simulation(self, multi_simulation):
for i in range(4):
rotation, phase = multi_simulation.get_simulation(i)
assert isinstance(rotation, Rotation)
assert phase == 0
for i in range(4, 8):
rotation, phase = multi_simulation.get_simulation(i)
assert isinstance(rotation, Rotation)
assert phase == 1

def test_iphase(self, multi_simulation):
phase_slic = multi_simulation.iphase[0]
assert isinstance(phase_slic, Simulation2D)
Expand Down

0 comments on commit 980ae97

Please sign in to comment.