Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
GuiMacielPereira committed Jan 7, 2025
1 parent 51e2777 commit dc9b2b7
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
10 changes: 5 additions & 5 deletions src/mvesuvio/analysis_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mantid.simpleapi import AnalysisDataService

from mvesuvio.util import handle_config
from mvesuvio.util.analysis_helpers import print_table_workspace, passDataIntoWS
from mvesuvio.util.analysis_helpers import print_table_workspace, pass_data_into_ws

repoPath = Path(__file__).absolute().parent # Path to the repository

Expand Down Expand Up @@ -207,7 +207,7 @@ def replaceZerosWithNCP(ws, ncp):
] # mask of ncp adjusted for last col present or not

wsMasked = CloneWorkspace(ws, OutputWorkspace=ws.name() + "_NCPMasked")
passDataIntoWS(dataX, dataY, dataE, wsMasked)
pass_data_into_ws(dataX, dataY, dataE, wsMasked)
SumSpectra(wsMasked, OutputWorkspace=wsMasked.name() + "_Sum")
return wsMasked

Expand Down Expand Up @@ -257,7 +257,7 @@ def dataXBining(ws, xp):
dataE[dataY == 0] = 0

wsXBins = CloneWorkspace(ws, OutputWorkspace=ws.name() + "_XBinned")
wsXBins = passDataIntoWS(dataX, dataY, dataE, wsXBins)
wsXBins = pass_data_into_ws(dataX, dataY, dataE, wsXBins)
return wsXBins


Expand Down Expand Up @@ -437,7 +437,7 @@ def symmetrizeWs(avgYSpace):
dataYS, dataES = weightedSymArr(dataY, dataE)

wsSym = CloneWorkspace(avgYSpace, OutputWorkspace=avgYSpace.name() + "_sym")
wsSym = passDataIntoWS(dataX, dataYS, dataES, wsSym)
wsSym = pass_data_into_ws(dataX, dataYS, dataES, wsSym)
return wsSym


Expand Down Expand Up @@ -1880,5 +1880,5 @@ def plotGlobalFit(dataX, dataY, dataE, mObj, totCost, wsName, yFitIC):

def save_workspaces(yFitIC):
for ws_name in mtd.getObjectNames():
if ws_name.endswith('Parameters') or ws_name.endswith('Workspace'):
if ws_name.endswith('Parameters') or ws_name.endswith('parameters') or ws_name.endswith('Workspace'):
SaveAscii(ws_name, str(yFitIC.figSavePath.parent / "output_files" / ws_name))
8 changes: 4 additions & 4 deletions src/mvesuvio/util/analysis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ntpath


def passDataIntoWS(dataX, dataY, dataE, ws):
def pass_data_into_ws(dataX, dataY, dataE, ws):
"Modifies ws data to input data"
for i in range(ws.getNumberHistograms()):
ws.dataX(i)[:] = dataX[i, :]
Expand All @@ -24,11 +24,11 @@ def print_table_workspace(table, precision=3):
table_dict = table.toDict()
# Convert floats into strings
for key, values in table_dict.items():
new_column = [int(item) if (not isinstance(item, str) and item.is_integer()) else item for item in values]
new_column = [int(item) if (isinstance(item, float) and item.is_integer()) else item for item in values]
table_dict[key] = [f"{item:.{precision}f}" if isinstance(item, float) else str(item) for item in new_column]

max_spacing = [max([len(item) for item in values] + [len(key)]) for key, values in table_dict.items()]
header = "|" + "|".join(f"{item}{' '*(spacing-len(item))}" for item, spacing in zip(table.keys(), max_spacing)) + "|"
header = "|" + "|".join(f"{item}{' '*(spacing-len(item))}" for item, spacing in zip(table_dict.keys(), max_spacing)) + "|"
logger.notice(f"Table {table.name()}:")
logger.notice(' '+'-'*(len(header)-2)+' ')
logger.notice(header)
Expand Down Expand Up @@ -255,7 +255,7 @@ def mask_time_of_flight_bins_with_zeros(ws, maskTOFRange):

dataY[mask] = 0

passDataIntoWS(dataX, dataY, dataE, ws)
pass_data_into_ws(dataX, dataY, dataE, ws)
return


Expand Down
53 changes: 51 additions & 2 deletions tests/unit/analysis/test_analysis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import scipy
import dill
import numpy.testing as nptest
from mock import MagicMock
from mock import MagicMock, patch, call
from mvesuvio.util.analysis_helpers import extractWS, _convert_dict_to_table, \
fix_profile_parameters, calculate_h_ratio, extend_range_of_array, numerical_third_derivative, \
mask_time_of_flight_bins_with_zeros
mask_time_of_flight_bins_with_zeros, pass_data_into_ws, print_table_workspace
from mantid.simpleapi import CreateWorkspace, DeleteWorkspace


Expand Down Expand Up @@ -178,5 +178,54 @@ def test_mask_time_of_flight_bins_with_zeros(self):
np.testing.assert_allclose(actual_data_y, expected_data_y)


def test_pass_data_into_ws(self):

dataX = np.arange(20).reshape(4, 5)
dataY = np.arange(20, 40).reshape(4, 5)
dataE = np.arange(40, 60).reshape(4, 5)

dataX_mock = np.zeros_like(dataX)
dataY_mock = np.zeros_like(dataY)
dataE_mock = np.zeros_like(dataE)

ws_mock = MagicMock(
dataY=lambda row: dataY_mock[row],
dataX=lambda row: dataX_mock[row],
dataE=lambda row: dataE_mock[row],
getNumberHistograms=MagicMock(return_value=4)
)

pass_data_into_ws(dataX, dataY, dataE, ws_mock)

np.testing.assert_allclose(dataX_mock, dataX)
np.testing.assert_allclose(dataY_mock, dataY)
np.testing.assert_allclose(dataE_mock, dataE)


@patch('mantid.kernel.logger.notice')
def test_print_table_workspace(self, mock_notice):
mock_table = MagicMock()
mock_table.name.return_value = "my_table"
mock_table.rowCount.return_value = 3
mock_table.toDict.return_value = {
"names": ["1.0", "12.0", "16.0"],
"mass": [1, 12.0, 16.00000],
"width": [5, 10.3456, 15.23],
"bounds": ["[3, 6]", "[8, 13]", "[9, 17]"]
}

print_table_workspace(mock_table, precision=2)

mock_notice.assert_has_calls(
[call('Table my_table:'),
call(' ------------------------ '),
call('|names|mass|width|bounds |'),
call('|1.0 |1 |5 |[3, 6] |'),
call('|12.0 |12 |10.35|[8, 13]|'),
call('|16.0 |16 |15.23|[9, 17]|'),
call(' ------------------------ ')]
)


if __name__ == "__main__":
unittest.main()

0 comments on commit dc9b2b7

Please sign in to comment.