diff --git a/src/mvesuvio/analysis_fitting.py b/src/mvesuvio/analysis_fitting.py index ca2cf61..efcc4d1 100644 --- a/src/mvesuvio/analysis_fitting.py +++ b/src/mvesuvio/analysis_fitting.py @@ -13,7 +13,7 @@ from mantid.simpleapi import AnalysisDataService from mvesuvio.util import handle_config -from mvesuvio.util.analysis_helpers import print_table_workspace, passDataIntoWS +from mvesuvio.util.analysis_helpers import print_table_workspace, pass_data_into_ws repoPath = Path(__file__).absolute().parent # Path to the repository @@ -207,7 +207,7 @@ def replaceZerosWithNCP(ws, ncp): ] # mask of ncp adjusted for last col present or not wsMasked = CloneWorkspace(ws, OutputWorkspace=ws.name() + "_NCPMasked") - passDataIntoWS(dataX, dataY, dataE, wsMasked) + pass_data_into_ws(dataX, dataY, dataE, wsMasked) SumSpectra(wsMasked, OutputWorkspace=wsMasked.name() + "_Sum") return wsMasked @@ -257,7 +257,7 @@ def dataXBining(ws, xp): dataE[dataY == 0] = 0 wsXBins = CloneWorkspace(ws, OutputWorkspace=ws.name() + "_XBinned") - wsXBins = passDataIntoWS(dataX, dataY, dataE, wsXBins) + wsXBins = pass_data_into_ws(dataX, dataY, dataE, wsXBins) return wsXBins @@ -437,7 +437,7 @@ def symmetrizeWs(avgYSpace): dataYS, dataES = weightedSymArr(dataY, dataE) wsSym = CloneWorkspace(avgYSpace, OutputWorkspace=avgYSpace.name() + "_sym") - wsSym = passDataIntoWS(dataX, dataYS, dataES, wsSym) + wsSym = pass_data_into_ws(dataX, dataYS, dataES, wsSym) return wsSym @@ -1880,5 +1880,5 @@ def plotGlobalFit(dataX, dataY, dataE, mObj, totCost, wsName, yFitIC): def save_workspaces(yFitIC): for ws_name in mtd.getObjectNames(): - if ws_name.endswith('Parameters') or ws_name.endswith('Workspace'): + if ws_name.endswith('Parameters') or ws_name.endswith('parameters') or ws_name.endswith('Workspace'): SaveAscii(ws_name, str(yFitIC.figSavePath.parent / "output_files" / ws_name)) diff --git a/src/mvesuvio/util/analysis_helpers.py b/src/mvesuvio/util/analysis_helpers.py index fc0bfaa..461fe26 100644 --- a/src/mvesuvio/util/analysis_helpers.py +++ b/src/mvesuvio/util/analysis_helpers.py @@ -11,7 +11,7 @@ import ntpath -def passDataIntoWS(dataX, dataY, dataE, ws): +def pass_data_into_ws(dataX, dataY, dataE, ws): "Modifies ws data to input data" for i in range(ws.getNumberHistograms()): ws.dataX(i)[:] = dataX[i, :] @@ -24,11 +24,11 @@ def print_table_workspace(table, precision=3): table_dict = table.toDict() # Convert floats into strings for key, values in table_dict.items(): - new_column = [int(item) if (not isinstance(item, str) and item.is_integer()) else item for item in values] + new_column = [int(item) if (isinstance(item, float) and item.is_integer()) else item for item in values] table_dict[key] = [f"{item:.{precision}f}" if isinstance(item, float) else str(item) for item in new_column] max_spacing = [max([len(item) for item in values] + [len(key)]) for key, values in table_dict.items()] - header = "|" + "|".join(f"{item}{' '*(spacing-len(item))}" for item, spacing in zip(table.keys(), max_spacing)) + "|" + header = "|" + "|".join(f"{item}{' '*(spacing-len(item))}" for item, spacing in zip(table_dict.keys(), max_spacing)) + "|" logger.notice(f"Table {table.name()}:") logger.notice(' '+'-'*(len(header)-2)+' ') logger.notice(header) @@ -255,7 +255,7 @@ def mask_time_of_flight_bins_with_zeros(ws, maskTOFRange): dataY[mask] = 0 - passDataIntoWS(dataX, dataY, dataE, ws) + pass_data_into_ws(dataX, dataY, dataE, ws) return diff --git a/tests/unit/analysis/test_analysis_helpers.py b/tests/unit/analysis/test_analysis_helpers.py index 84c7280..79652f8 100644 --- a/tests/unit/analysis/test_analysis_helpers.py +++ b/tests/unit/analysis/test_analysis_helpers.py @@ -3,10 +3,10 @@ import scipy import dill import numpy.testing as nptest -from mock import MagicMock +from mock import MagicMock, patch, call from mvesuvio.util.analysis_helpers import extractWS, _convert_dict_to_table, \ fix_profile_parameters, calculate_h_ratio, extend_range_of_array, numerical_third_derivative, \ - mask_time_of_flight_bins_with_zeros + mask_time_of_flight_bins_with_zeros, pass_data_into_ws, print_table_workspace from mantid.simpleapi import CreateWorkspace, DeleteWorkspace @@ -178,5 +178,54 @@ def test_mask_time_of_flight_bins_with_zeros(self): np.testing.assert_allclose(actual_data_y, expected_data_y) + def test_pass_data_into_ws(self): + + dataX = np.arange(20).reshape(4, 5) + dataY = np.arange(20, 40).reshape(4, 5) + dataE = np.arange(40, 60).reshape(4, 5) + + dataX_mock = np.zeros_like(dataX) + dataY_mock = np.zeros_like(dataY) + dataE_mock = np.zeros_like(dataE) + + ws_mock = MagicMock( + dataY=lambda row: dataY_mock[row], + dataX=lambda row: dataX_mock[row], + dataE=lambda row: dataE_mock[row], + getNumberHistograms=MagicMock(return_value=4) + ) + + pass_data_into_ws(dataX, dataY, dataE, ws_mock) + + np.testing.assert_allclose(dataX_mock, dataX) + np.testing.assert_allclose(dataY_mock, dataY) + np.testing.assert_allclose(dataE_mock, dataE) + + + @patch('mantid.kernel.logger.notice') + def test_print_table_workspace(self, mock_notice): + mock_table = MagicMock() + mock_table.name.return_value = "my_table" + mock_table.rowCount.return_value = 3 + mock_table.toDict.return_value = { + "names": ["1.0", "12.0", "16.0"], + "mass": [1, 12.0, 16.00000], + "width": [5, 10.3456, 15.23], + "bounds": ["[3, 6]", "[8, 13]", "[9, 17]"] + } + + print_table_workspace(mock_table, precision=2) + + mock_notice.assert_has_calls( + [call('Table my_table:'), + call(' ------------------------ '), + call('|names|mass|width|bounds |'), + call('|1.0 |1 |5 |[3, 6] |'), + call('|12.0 |12 |10.35|[8, 13]|'), + call('|16.0 |16 |15.23|[9, 17]|'), + call(' ------------------------ ')] + ) + + if __name__ == "__main__": unittest.main()