From a8e10ebf316f80bbf24aea5e783066a0f90e42e8 Mon Sep 17 00:00:00 2001 From: Robert Smith Date: Thu, 18 Jun 2020 13:16:08 +1000 Subject: [PATCH] dwi2mask: Refactor image input & output - Centralise code for generating a mean b=0 image, and allow individual algorithms to request or decline generation of that method, rather than duplicating the code for generating such. - Have the execute() function of each algorithm return the filesystem path of the final generated mask, so that the code responsible for exporting the mask to the user-specified output location can be centralised rather than duplicated in each algorithm. --- bin/dwi2mask | 24 +++++++++++++++++++++++- lib/mrtrix3/dwi2mask/ants.py | 21 +++++++-------------- lib/mrtrix3/dwi2mask/fslbet.py | 32 ++++++++++++++------------------ lib/mrtrix3/dwi2mask/hdbet.py | 22 +++++++--------------- lib/mrtrix3/dwi2mask/legacy.py | 12 +++++++----- lib/mrtrix3/dwi2mask/template.py | 25 +++++++------------------ 6 files changed, 65 insertions(+), 71 deletions(-) diff --git a/bin/dwi2mask b/bin/dwi2mask index 549f55fd5b..771b6b468e 100755 --- a/bin/dwi2mask +++ b/bin/dwi2mask @@ -59,8 +59,30 @@ def execute(): #pylint: disable=unused-variable app.goto_scratch_dir() + # Generate a mean b=0 image (common task in many algorithms) + if alg.needs_mean_bzero(): + run.command('dwiextract input.mif -bzero - | ' + 'mrmath - mean - -axis 3 | ' + 'mrconvert - bzero.nii -strides +1,+2,+3') + # From here, the script splits depending on what algorithm is being used - alg.execute() + # The return value of the execute() function should be the name of the + # image in the scratch directory that is to be exported + mask_path = alg.execute() + + # Export the mask image + # Make relative strides of three spatial axes of output mask equivalent + # to input DWI; this may involve decrementing magnitude of stride + # if the input DWI is volume-contiguous + strides = image.Header('input.mif').strides()[0:3] + strides = [(abs(value) + 1 - min(abs(v) for v in strides)) * (-1 if value < 0 else 1) for value in strides] + run.command('mrconvert ' + + mask_path + + ' ' + + path.from_user(app.ARGS.output) + + ' -strides ' + ','.join(str(value) for value in strides), + mrconvert_keyval=path.from_user(app.ARGS.input, False), + force=app.FORCE_OVERWRITE) diff --git a/lib/mrtrix3/dwi2mask/ants.py b/lib/mrtrix3/dwi2mask/ants.py index c85a48a208..c6f23e0d91 100644 --- a/lib/mrtrix3/dwi2mask/ants.py +++ b/lib/mrtrix3/dwi2mask/ants.py @@ -16,7 +16,7 @@ import os from distutils.spawn import find_executable from mrtrix3 import MRtrixError -from mrtrix3 import app, image, path, run +from mrtrix3 import app, path, run ANTS_BRAIN_EXTRACTION_CMD = 'antsBrainExtraction.sh' @@ -45,6 +45,11 @@ def get_inputs(): #pylint: disable=unused-variable +def needs_mean_bzero(): #pylint: disable=unused-variable + return True + + + def execute(): #pylint: disable=unused-variable ants_path = os.environ.get('ANTSPATH', '') if not ants_path: @@ -55,11 +60,6 @@ def execute(): #pylint: disable=unused-variable + ANTS_BRAIN_EXTRACTION_CMD + '"; please check ANTs installation') - # Produce mean b=0 image - run.command('dwiextract input.mif -bzero - | ' - 'mrmath - mean - -axis 3 | ' - 'mrconvert - bzero.nii -strides +1,+2,+3') - run.command(ANTS_BRAIN_EXTRACTION_CMD + ' -d 3' + ' -c 3x3x2x1' @@ -70,11 +70,4 @@ def execute(): #pylint: disable=unused-variable + ('' if app.DO_CLEANUP else ' -k 1') + (' -z' if app.VERBOSITY >= 3 else '')) - strides = image.Header('input.mif').strides()[0:3] - strides = [(abs(value) + 1 - min(abs(v) for v in strides)) * (-1 if value < 0 else 1) for value in strides] - - run.command('mrconvert outBrainExtractionMask.nii.gz ' - + path.from_user(app.ARGS.output) - + ' -strides ' + ','.join(str(value) for value in strides), - mrconvert_keyval=path.from_user(app.ARGS.input, False), - force=app.FORCE_OVERWRITE) + return 'outBrainExtractionMask.nii.gz' diff --git a/lib/mrtrix3/dwi2mask/fslbet.py b/lib/mrtrix3/dwi2mask/fslbet.py index cd20abd4e4..ce8f8814df 100644 --- a/lib/mrtrix3/dwi2mask/fslbet.py +++ b/lib/mrtrix3/dwi2mask/fslbet.py @@ -15,7 +15,7 @@ import os from mrtrix3 import MRtrixError -from mrtrix3 import app, fsl, image, path, run +from mrtrix3 import app, fsl, image, run @@ -40,23 +40,23 @@ def get_inputs(): #pylint: disable=unused-variable +def needs_mean_bzero(): #pylint: disable=unused-variable + return True + + + def execute(): #pylint: disable=unused-variable if not os.environ.get('FSLDIR', ''): raise MRtrixError('Environment variable FSLDIR is not set; please run appropriate FSL configuration script') bet_cmd = fsl.exe_name('bet') - # Calculating mean b=0 image for BET - run.command('dwiextract input.mif - -bzero | ' - 'mrmath - mean - -axis 3 | ' - 'mrconvert - mean_bzero.nii -strides +1,+2,+3') - # Starting brain masking using BET if app.ARGS.rescale: - run.command('mrconvert mean_bzero.nii mean_bzero_rescaled.nii -vox 1,1,1') - vox = image.Header('mean_bzero.nii').spacing() - b0_image = 'mean_bzero_rescaled.nii' + run.command('mrconvert bzero.nii bzero_rescaled.nii -vox 1,1,1') + vox = image.Header('bzero.nii').spacing() + b0_image = 'bzero_rescaled.nii' else: - b0_image = 'mean_bzero.nii' + b0_image = 'bzero.nii' cmd_string = bet_cmd + ' ' + b0_image + ' DWI_BET -R -m' @@ -73,11 +73,7 @@ def execute(): #pylint: disable=unused-variable run.command(cmd_string) mask = fsl.find_image('DWI_BET_mask') - strides = image.Header('input.mif').strides()[0:3] - strides = [(abs(value) + 1 - min(abs(v) for v in strides)) * (-1 if value < 0 else 1) for value in strides] - run.command('mrconvert ' + mask + ' ' + path.from_user(app.ARGS.output) - + (' -vox ' + ','.join(str(value) for value in vox) if app.ARGS.rescale else '') - + ' -strides ' + ','.join(str(value) for value in strides) - + ' -datatype bit', - mrconvert_keyval=path.from_user(app.ARGS.input, False), - force=app.FORCE_OVERWRITE) + if app.ARGS.rescale: + run.command('mrconvert ' + mask + ' mask_rescaled.nii -vox ' + ','.join(str(value) for value in vox)) + return 'mask_rescaled.nii' + return mask diff --git a/lib/mrtrix3/dwi2mask/hdbet.py b/lib/mrtrix3/dwi2mask/hdbet.py index 85650b2b22..03d158d6c2 100644 --- a/lib/mrtrix3/dwi2mask/hdbet.py +++ b/lib/mrtrix3/dwi2mask/hdbet.py @@ -15,7 +15,7 @@ from distutils.spawn import find_executable from mrtrix3 import MRtrixError -from mrtrix3 import app, image, path, run +from mrtrix3 import app, run @@ -34,16 +34,16 @@ def get_inputs(): #pylint: disable=unused-variable +def needs_mean_bzero(): #pylint: disable=unused-variable + return True + + + def execute(): #pylint: disable=unused-variable hdbet_cmd = find_executable('hd-bet') if not hdbet_cmd: raise MRtrixError('Unable to locate "hd-bet" executable; check installation') - # Produce mean b=0 image - run.command('dwiextract input.mif -bzero - | ' - 'mrmath - mean - -axis 3 | ' - 'mrconvert - bzero.nii -strides +1,+2,+3') - # GPU version is not guaranteed to work; # attempt CPU version if that is the case try: @@ -59,12 +59,4 @@ def execute(): #pylint: disable=unused-variable exception_stderr = gpu_header + e_gpu.stderr + '\n\n' + cpu_header + e_cpu.stderr + '\n\n' raise run.MRtrixCmdError('hd-bet', 1, exception_stdout, exception_stderr) - strides = image.Header('input.mif').strides()[0:3] - strides = [(abs(value) + 1 - min(abs(v) for v in strides)) * (-1 if value < 0 else 1) for value in strides] - - run.command('mrconvert bzero_bet_mask.nii.gz ' - + path.from_user(app.ARGS.output) - + ' -strides ' + ','.join(str(value) for value in strides) - + ' -datatype bit', - mrconvert_keyval=path.from_user(app.ARGS.input, False), - force=app.FORCE_OVERWRITE) + return 'bzero_bet_mask.nii.gz' diff --git a/lib/mrtrix3/dwi2mask/legacy.py b/lib/mrtrix3/dwi2mask/legacy.py index ebe3d66e4d..74ac1855f0 100644 --- a/lib/mrtrix3/dwi2mask/legacy.py +++ b/lib/mrtrix3/dwi2mask/legacy.py @@ -13,7 +13,7 @@ # # For more details, see http://www.mrtrix.org/. -from mrtrix3 import app, path, run +from mrtrix3 import app, run DEFAULT_CLEAN_SCALE = 2 @@ -39,6 +39,11 @@ def get_inputs(): #pylint: disable=unused-variable +def needs_mean_bzero(): #pylint: disable=unused-variable + return False + + + def execute(): #pylint: disable=unused-variable run.command('mrcalc input.mif 0 -max input_nonneg.mif') @@ -54,7 +59,4 @@ def execute(): #pylint: disable=unused-variable run.command('mrmath input_nonneg.mif max -axis 3 - | ' 'mrcalc - 0.0 -gt init_mask.mif -mult final_mask.mif -datatype bit') - run.command('mrconvert final_mask.mif ' - + path.from_user(app.ARGS.output), - mrconvert_keyval=path.from_user(app.ARGS.input, False), - force=app.FORCE_OVERWRITE) + return 'final_mask.mif' diff --git a/lib/mrtrix3/dwi2mask/template.py b/lib/mrtrix3/dwi2mask/template.py index 0cafa05bdd..b4dcf69512 100644 --- a/lib/mrtrix3/dwi2mask/template.py +++ b/lib/mrtrix3/dwi2mask/template.py @@ -16,7 +16,7 @@ import os from distutils.spawn import find_executable from mrtrix3 import MRtrixError -from mrtrix3 import app, fsl, image, path, run +from mrtrix3 import app, fsl, path, run SOFTWARES = ['ants', 'fsl'] @@ -54,6 +54,11 @@ def get_inputs(): #pylint: disable=unused-variable +def needs_mean_bzero(): #pylint: disable=unused-variable + return True + + + def execute(): #pylint: disable=unused-variable # What image to generate here depends on the template: # - If a good T2-weighted template is found, use the mean b=0 image @@ -97,11 +102,6 @@ def execute(): #pylint: disable=unused-variable else: assert False - # Produce mean b=0 image - run.command('dwiextract input.mif -bzero - | ' - 'mrmath - mean - -axis 3 | ' - 'mrconvert - bzero.nii -strides +1,+2,+3') - if reg_software == 'ants': # Use ANTs SyN for registration to template @@ -175,15 +175,4 @@ def execute(): #pylint: disable=unused-variable run.command('mrthreshold ' + transformed_path + ' mask.mif -abs 0.5') - - # Make relative strides of three spatial axes of output mask equivalent - # to input DWI; this may involve decrementing magnitude of stride - # if the input DWI is volume-contiguous - strides = image.Header('input.mif').strides()[0:3] - strides = [(abs(value) + 1 - min(abs(v) for v in strides)) * (-1 if value < 0 else 1) for value in strides] - - run.command('mrconvert mask.mif ' - + path.from_user(app.ARGS.output) - + ' -strides ' + ','.join(str(value) for value in strides), - mrconvert_keyval=path.from_user(app.ARGS.input, False), - force=app.FORCE_OVERWRITE) + return 'mask.mif'