Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing usages of index(t) to a more robust isclose() #23

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
34 changes: 32 additions & 2 deletions lindbladmpo/plot_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,32 @@ def _save_fig(b_save_figures, s_file_prefix, s_file_label):
plt.savefig(s_file_prefix + s_file_label + ".png")


def find_index_nearest_time_within_tolerance(arr, target_time, rtol):
"""
Find the index of the nearest time entry in an array of time values.

This function computes the index of the time entry in the given array that is
closest to the specified target time. It takes into account the relative tolerance
(`rtol`) for determining closeness.

Args:
arr (sequence of float): An array of time values.
target_time (float): The target time to compare against.
rtol (float): The relative tolerance for considering time values as close.

Returns:
int: The index of the closest time entry in the array. If no time entry is found
that meets the rtol criteria, the function still returns the index of
the time entry closest to the target time.
"""
diff = np.abs(np.array(arr) - target_time)
closest_index = np.argmin(diff)
if np.isclose(diff[closest_index], 0, rtol=rtol, atol=0):
return closest_index
else:
return closest_index


def prepare_time_data(
parameters: dict,
n_t_ticks=10,
Expand Down Expand Up @@ -235,7 +261,9 @@ def prepare_2q_correlation_matrix(
# arrays are identical, if they are equal in number. Verifying the time array lengths
# will avoid crashes due to interrupted simulations with incomplete data files.
try:
t_index = obs_0[0].index(t)
t_index = find_index_nearest_time_within_tolerance(
obs_0[0], t, 0.01
)
obs_data[i, j] = (
obs_2[1][t_index] - obs_0[1][t_index] * obs_1[1][t_index]
)
Expand Down Expand Up @@ -284,7 +312,9 @@ def prepare_xy_current_data(
# arrays are identical, if they are equal in number. Verifying the time array lengths
# will avoid crashes due to interrupted simulations with incomplete data files.
try:
t_index = obs_2[0].index(t)
t_index = find_index_nearest_time_within_tolerance(
obs_2[0], t, 0.01
)
obs_data[i_bond] = 0.5 * (obs_1[1][t_index] - obs_2[1][t_index])
except ValueError:
pass
Expand Down
Empty file added test/plot/__init__.py
Empty file.
55 changes: 55 additions & 0 deletions test/plot/test_plot_routines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Tests of the plot routines.
"""
import unittest

from lindbladmpo.plot_routines import find_index_nearest_time_within_tolerance


class TestFindIndexNearestTimeWithinTolerance(unittest.TestCase):
"""The class testing the plot routines"""

def test_nearest_time_within_tolerance(self):
"""
Test find_index_nearest_time_within_tolerance when the nearest
time entry is within the specified tolerance.
"""
time_values = [0.1, 0.2, 0.3, 0.4]
target = 0.25
tolerance = 0.05
index = find_index_nearest_time_within_tolerance(time_values, target, tolerance)
self.assertEqual(index, 1) # Match at index 1 is the closest

def test_nearest_time_outside_tolerance(self):
"""
Test find_index_nearest_time_within_tolerance when the nearest
time entry is outside the specified tolerance.
"""
time_values = [0.1, 0.2, 0.3, 0.4]
target = 0.25
tolerance = 0.01
index = find_index_nearest_time_within_tolerance(time_values, target, tolerance)
self.assertEqual(index, 1) # Match at index 1 is the closest

def test_exact_time_entry(self):
"""
Test find_index_nearest_time_within_tolerance when the target
time is an exact match to one of the time entries.
"""
time_values = [0.1, 0.2, 0.3, 0.4]
target = 0.2
tolerance = 0.05
index = find_index_nearest_time_within_tolerance(time_values, target, tolerance)
self.assertEqual(index, 1) # Match at index 1


if __name__ == "__main__":
unittest.main()
Loading