Skip to content

Commit

Permalink
Merge pull request #268 from mgxd/enh/runwise-reference-image
Browse files Browse the repository at this point in the history
ENH: Runwise bold reference generation
  • Loading branch information
mgxd authored Jan 19, 2023
2 parents 737e62c + 7718b99 commit b7c4e3c
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 112 deletions.
99 changes: 32 additions & 67 deletions nibabies/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,80 +408,44 @@ def init_single_subject_wf(subject_id, session_id=None):

# Append the functional section to the existing anatomical exerpt
# That way we do not need to stream down the number of bold datasets
anat_preproc_wf.__postdesc__ = (
(anat_preproc_wf.__postdesc__ if hasattr(anat_preproc_wf, "__postdesc__") else "")
+ f"""
anat_preproc_wf.__postdesc__ = getattr(anat_preproc_wf, '__postdesc__') or ''
func_pre_desc = f"""
Functional data preprocessing
: For each of the {len(subject_data['bold'])} BOLD runs found per subject (across all
tasks and sessions), the following preprocessing was performed.
"""
)

# calculate reference image(s) for BOLD images
# group all BOLD files based on same:
# 1) session
# 2) PE direction
# 3) total readout time
from niworkflows.workflows.epi.refmap import init_epi_reference_wf

bold_groupings = group_bolds_ref(
layout=config.execution.layout,
subject=subject_id,
sessions=[session_id],
)
tasks and sessions), the following preprocessing was performed."""

func_preproc_wfs = []
has_fieldmap = bool(fmap_estimators)
for idx, grouping in enumerate(bold_groupings.values()):
bold_ref_wf = init_epi_reference_wf(
auto_bold_nss=True,
name=f"bold_reference_wf{idx}",
omp_nthreads=config.nipype.omp_nthreads,
)
bold_files = grouping.files
bold_ref_wf.inputs.inputnode.in_files = grouping.files

if grouping.multiecho_id is not None:
bold_files = [bold_files]
for idx, bold_file in enumerate(bold_files):
func_preproc_wf = init_func_preproc_wf(
bold_file,
has_fieldmap=has_fieldmap,
existing_derivatives=derivatives,
)
# fmt: off
workflow.connect([
(bold_ref_wf, func_preproc_wf, [
('outputnode.epi_ref_file', 'inputnode.bold_ref'),
(
('outputnode.xfm_files', _select_iter_idx, idx),
'inputnode.bold_ref_xfm'),
(
('outputnode.n_dummy', _select_iter_idx, idx),
'inputnode.n_dummy_scans'),
]),
(anat_preproc_wf, func_preproc_wf, [
('outputnode.anat_preproc', 'inputnode.anat_preproc'),
('outputnode.anat_mask', 'inputnode.anat_mask'),
('outputnode.anat_brain', 'inputnode.anat_brain'),
('outputnode.anat_dseg', 'inputnode.anat_dseg'),
('outputnode.anat_aseg', 'inputnode.anat_aseg'),
('outputnode.anat_aparc', 'inputnode.anat_aparc'),
('outputnode.anat_tpms', 'inputnode.anat_tpms'),
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
# Undefined if --fs-no-reconall, but this is safe
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.t1w2fsnative_xfm', 'inputnode.t1w2fsnative_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
]),
])
# fmt: on
func_preproc_wfs.append(func_preproc_wf)
for bold_file in subject_data['bold']:
func_preproc_wf = init_func_preproc_wf(bold_file, has_fieldmap=has_fieldmap)
if func_preproc_wf is None:
continue

func_preproc_wf.__desc__ = func_pre_desc + (getattr(func_preproc_wf, '__desc__') or '')
# fmt:off
workflow.connect([
(anat_preproc_wf, func_preproc_wf, [
('outputnode.anat_preproc', 'inputnode.anat_preproc'),
('outputnode.anat_mask', 'inputnode.anat_mask'),
('outputnode.anat_brain', 'inputnode.anat_brain'),
('outputnode.anat_dseg', 'inputnode.anat_dseg'),
('outputnode.anat_aseg', 'inputnode.anat_aseg'),
('outputnode.anat_aparc', 'inputnode.anat_aparc'),
('outputnode.anat_tpms', 'inputnode.anat_tpms'),
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
# Undefined if --fs-no-reconall, but this is safe
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.t1w2fsnative_xfm', 'inputnode.t1w2fsnative_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
]),
])
# fmt:on
func_preproc_wfs.append(func_preproc_wf)

if not has_fieldmap:
config.loggers.workflow.warning(
Expand All @@ -506,6 +470,7 @@ def init_single_subject_wf(subject_id, session_id=None):
subject=subject_id,
)
fmap_wf.__desc__ = f"""
Preprocessing of B<sub>0</sub> inhomogeneity mappings
: A total of {len(fmap_estimators)} fieldmaps were found available within the input
Expand Down
71 changes: 39 additions & 32 deletions nibabies/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from ...interfaces.reports import FunctionalSummary
from ...utils.bids import extract_entities
from ...utils.misc import combine_meepi_source
from .boldref import init_infant_epi_reference_wf

# BOLD workflows
from .confounds import init_bold_confs_wf, init_carpetplot_wf
Expand Down Expand Up @@ -127,12 +128,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
LTA-style affine matrix translating from T1w to FreeSurfer-conformed subject space
fsnative2t1w_xfm
LTA-style affine matrix translating from FreeSurfer-conformed subject space to T1w
bold_ref
BOLD reference file
bold_ref_xfm
Transform file in LTA format from bold to reference
n_dummy_scans
Number of nonsteady states at the beginning of the BOLD run
Outputs
-------
Expand Down Expand Up @@ -177,6 +172,7 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
"""
from niworkflows.engine.workflows import LiterateWorkflow as Workflow
from niworkflows.interfaces.bold import NonsteadyStatesDetector
from niworkflows.interfaces.nibabel import ApplyMask
from niworkflows.interfaces.utility import DictMerge, KeySelect
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
Expand Down Expand Up @@ -244,9 +240,14 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
)

# Find associated sbref, if possible
entities["suffix"] = "sbref"
entities["extension"] = [".nii", ".nii.gz"] # Overwrite extensions
sbref_files = layout.get(scope="raw", return_type="file", **entities)
overrides = {
"suffix": "sbref",
"extension": [".nii", ".nii.gz"],
}
if config.execution.bids_filters:
overrides.update(config.execution.bids_filters.get('sbref', {}))
sb_ents = {**entities, **overrides}
sbref_files = layout.get(return_type="file", **sb_ents)

sbref_msg = f"No single-band-reference found for {os.path.basename(ref_file)}."
if sbref_files and "sbref" in config.workflow.ignore:
Expand Down Expand Up @@ -319,10 +320,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
"anat2std_xfm",
"std2anat_xfm",
"template",
# from bold reference workflow
"bold_ref",
"bold_ref_xfm",
"n_dummy_scans",
# from sdcflows (optional)
"fmap",
"fmap_ref",
Expand Down Expand Up @@ -514,12 +511,21 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
)
bold_confounds_wf.get_node("inputnode").inputs.t1_transform_flags = [False]

dummy_buffer = pe.Node(niu.IdentityInterface(fields=['n_dummy']), name='dummy_buffer')
if (dummy := config.workflow.dummy_scans) is not None:
dummy_buffer.inputs.n_dummy = dummy
else:
# Detect dummy scans
nss_detector = pe.Node(NonsteadyStatesDetector(), name='nss_detector')
nss_detector.inputs.in_file = ref_file
workflow.connect(nss_detector, 'n_dummy', dummy_buffer, 'n_dummy')

# SLICE-TIME CORRECTION (or bypass) #############################################
if run_stc:
bold_stc_wf = init_bold_stc_wf(name="bold_stc_wf", metadata=metadata)
# fmt:off
workflow.connect([
(inputnode, bold_stc_wf, [('n_dummy_scans', 'inputnode.skip_vols')]),
(dummy_buffer, bold_stc_wf, [('n_dummy', 'inputnode.skip_vols')]),
(select_bold, bold_stc_wf, [("out", 'inputnode.bold_file')]),
(bold_stc_wf, boldbuffer, [('outputnode.stc_file', 'bold_file')]),
])
Expand Down Expand Up @@ -577,8 +583,11 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
name="bold_final",
)

# Mask input BOLD reference image
initial_boldref_mask = pe.Node(BrainExtraction(), name="initial_boldref_mask")
# Create a reference image for the bold run
initial_boldref_wf = init_infant_epi_reference_wf(omp_nthreads, is_sbref=bool(sbref_files))
initial_boldref_wf.inputs.inputnode.epi_file = (
pop_file(sbref_files) if sbref_files else ref_file
)

# This final boldref will be calculated after bold_bold_trans_wf, which includes one or more:
# HMC (head motion correction)
Expand All @@ -602,8 +611,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
# BOLD buffer has slice-time corrected if it was run, original otherwise
(boldbuffer, bold_split, [('bold_file', 'in_file')]),
# HMC
(inputnode, bold_hmc_wf, [
('bold_ref', 'inputnode.raw_ref_image')]),
(initial_boldref_wf, bold_hmc_wf, [
('outputnode.boldref_file', 'inputnode.raw_ref_image')]),
(validate_bolds, bold_hmc_wf, [
(("out_file", pop_file), 'inputnode.bold_file')]),
(bold_hmc_wf, outputnode, [
Expand Down Expand Up @@ -659,8 +668,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
('outputnode.rmsd_file', 'inputnode.rmsd_file')]),
(bold_reg_wf, bold_confounds_wf, [
('outputnode.itk_t1_to_bold', 'inputnode.t1_bold_xform')]),
(inputnode, bold_confounds_wf, [
('n_dummy_scans', 'inputnode.skip_vols')]),
(dummy_buffer, bold_confounds_wf, [
('n_dummy', 'inputnode.skip_vols')]),
(bold_final, bold_confounds_wf, [
('bold', 'inputnode.bold'),
('mask', 'inputnode.bold_mask'),
Expand All @@ -672,7 +681,7 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
('outputnode.tcompcor_mask', 'tcompcor_mask'),
]),
# Summary
(inputnode, summary, [('n_dummy_scans', 'algo_dummy_scans')]),
(dummy_buffer, summary, [('n_dummy', 'algo_dummy_scans')]),
(bold_reg_wf, summary, [('outputnode.fallback', 'fallback')]),
(outputnode, summary, [('confounds', 'confounds_file')]),
# Select echo indices for original/validated BOLD files
Expand Down Expand Up @@ -874,8 +883,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
('bold_file', 'inputnode.name_source')]),
(bold_hmc_wf, ica_aroma_wf, [
('outputnode.movpar_file', 'inputnode.movpar_file')]),
(inputnode, ica_aroma_wf, [
('n_dummy_scans', 'inputnode.skip_vols')]),
(dummy_buffer, ica_aroma_wf, [
('n_dummy', 'inputnode.skip_vols')]),
(bold_confounds_wf, join, [
('outputnode.confounds_file', 'in_file')]),
(bold_confounds_wf, mrg_conf_metadata,
Expand Down Expand Up @@ -1051,9 +1060,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
("outputnode.bold", "inputnode.in_files"),
]),
] if not multiecho else [
(inputnode, initial_boldref_mask, [('bold_ref', 'in_file')]),
(initial_boldref_mask, bold_t2s_wf, [
("out_mask", "inputnode.bold_mask"),
(initial_boldref_wf, bold_t2s_wf, [
("outputnode.boldref_mask", "inputnode.bold_mask"),
]),
(bold_bold_trans_wf, join_echos, [
("outputnode.bold", "bold_files"),
Expand Down Expand Up @@ -1125,14 +1133,13 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
("fmap_coeff", "inputnode.fmap_coeff"),
("fmap_mask", "inputnode.fmap_mask")]),
(output_select, summary, [("sdc_method", "distortion_correction")]),
(inputnode, initial_boldref_mask, [('bold_ref', 'in_file')]),
(inputnode, coeff2epi_wf, [
("bold_ref", "inputnode.target_ref")]),
(initial_boldref_mask, coeff2epi_wf, [
("out_mask", "inputnode.target_mask")]), # skull-stripped brain
(initial_boldref_wf, coeff2epi_wf, [
("outputnode.boldref_file", "inputnode.target_ref")]),
(initial_boldref_wf, coeff2epi_wf, [
("outputnode.boldref_mask", "inputnode.target_mask")]), # skull-stripped brain
(coeff2epi_wf, unwarp_wf, [
("outputnode.fmap_coeff", "inputnode.fmap_coeff")]),
(inputnode, sdc_report, [("bold_ref", "before")]),
(initial_boldref_wf, sdc_report, [("outputnode.boldref_file", "before")]),
(bold_hmc_wf, unwarp_wf, [
("outputnode.xforms", "inputnode.hmc_xforms")]),
(bold_split, unwarp_wf, [
Expand Down
97 changes: 97 additions & 0 deletions nibabies/workflows/bold/boldref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import nipype.interfaces.utility as niu
import nipype.pipeline.engine as pe


def init_infant_epi_reference_wf(
omp_nthreads: int,
is_sbref: bool = False,
start_frame: int = 17,
name: str = 'infant_epi_reference_wf',
) -> pe.Workflow:
"""
Workflow to generate a reference map from one or more infant EPI images.
If any single-band references are provided, the reference map will be calculated from those.
If no single-band references are provided, the BOLD files are used.
To account for potential increased motion on the start of image acquisition, this
workflow discards a bigger chunk of the initial frames.
Parameters
----------
omp_nthreads
Maximum number of threads an individual process may use
has_sbref
A single-band reference is provided.
start_frame
BOLD frame to start creating the reference map from. Any earlier frames are discarded.
Inputs
------
bold_file
BOLD EPI file
sbref_file
single-band reference EPI
Outputs
-------
boldref_file
The generated reference map
boldref_mask
Binary brain mask of the ``boldref_file``
boldref_xfm
Rigid-body transforms in LTA format
"""
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
from sdcflows.interfaces.brainmask import BrainExtraction

wf = pe.Workflow(name=name)

inputnode = pe.Node(
niu.IdentityInterface(fields=['epi_file']),
name='inputnode',
)
outputnode = pe.Node(
niu.IdentityInterface(fields=['boldref_file', 'boldref_mask']),
name='outputnode',
)

epi_reference_wf = init_epi_reference_wf(omp_nthreads)

boldref_mask = pe.Node(BrainExtraction(), name='boldref_mask')

# fmt:off
wf.connect([
(inputnode, epi_reference_wf, [('epi_file', 'inputnode.in_files')]),
(epi_reference_wf, boldref_mask, [('outputnode.epi_ref_file', 'in_file')]),
(epi_reference_wf, outputnode, [('outputnode.epi_ref_file', 'boldref_file')]),
(boldref_mask, outputnode, [('out_mask', 'boldref_mask')]),
])
# fmt:on
if not is_sbref:
select_frames = pe.Node(
niu.Function(function=_select_frames, output_names=['t_masks']),
name='select_frames',
)
select_frames.inputs.start_frame = start_frame
# fmt:off
wf.connect([
(inputnode, select_frames, [('epi_file', 'in_file')]),
(select_frames, epi_reference_wf, [('t_masks', 'inputnode.t_masks')]),
])
# fmt:on
return wf


def _select_frames(in_file: str, start_frame: int) -> list:
import nibabel as nb
import numpy as np

img = nb.load(in_file)
img_len = img.shape[3]
if start_frame >= img_len:
start_frame = img_len - 1
t_mask = np.array([False] * img_len, dtype=bool)
t_mask[start_frame:] = True
return list(t_mask)
Loading

0 comments on commit b7c4e3c

Please sign in to comment.