Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to vectorize connectivity matrices #1308

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions xcp_d/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,21 @@ def _build_parser():
"Default is 0.5."
),
)
g_parcellation.add_argument(
"--flatten-conmats",
"--flatten_conmats",
dest="flatten_conmats",
nargs="?",
const=None,
default="auto",
choices=["y", "n"],
action=parser_utils.YesNoAction,
help=(
"Flatten connectivity matrices to 1D arrays. "
"This will reduce the dimensionality of the connectivity matrices, "
"which can be useful for some analyses."
),
)

g_dcan = parser.add_argument_group("abcd/hbcd mode options")
g_dcan.add_argument(
Expand Down Expand Up @@ -939,6 +954,7 @@ def _validate_parameters(opts, build_log, parser):
assert opts.combine_runs in (True, False, "auto")
assert opts.despike in (True, False, "auto")
assert opts.file_format in ("nifti", "cifti", "auto")
assert opts.flatten_conmats in (True, False, "auto")
assert opts.linc_qc in (True, False, "auto")
assert opts.mode in ("abcd", "hbcd", "linc", "none"), f"Unsupported mode '{opts.mode}'."
assert opts.output_type in ("censored", "interpolated", "auto")
Expand Down Expand Up @@ -967,6 +983,7 @@ def _validate_parameters(opts, build_log, parser):
opts.despike = True if (opts.despike == "auto") else opts.despike
opts.fd_thresh = 0.3 if (opts.fd_thresh == "auto") else opts.fd_thresh
opts.file_format = "cifti" if (opts.file_format == "auto") else opts.file_format
opts.flatten_conmats = True if (opts.flatten_conmats == "auto") else opts.flatten_conmats
opts.input_type = "fmriprep" if opts.input_type == "auto" else opts.input_type
opts.linc_qc = True if (opts.linc_qc == "auto") else opts.linc_qc
if opts.motion_filter_type is None:
Expand All @@ -993,6 +1010,7 @@ def _validate_parameters(opts, build_log, parser):
opts.despike = True if (opts.despike == "auto") else opts.despike
opts.fd_thresh = 0.3 if (opts.fd_thresh == "auto") else opts.fd_thresh
opts.file_format = "cifti" if (opts.file_format == "auto") else opts.file_format
opts.flatten_conmats = True if (opts.flatten_conmats == "auto") else opts.flatten_conmats
opts.input_type = "nibabies" if opts.input_type == "auto" else opts.input_type
opts.linc_qc = True if (opts.linc_qc == "auto") else opts.linc_qc
if opts.motion_filter_type is None:
Expand All @@ -1016,6 +1034,7 @@ def _validate_parameters(opts, build_log, parser):
opts.despike = True if (opts.despike == "auto") else opts.despike
opts.fd_thresh = 0 if (opts.fd_thresh == "auto") else opts.fd_thresh
opts.file_format = "cifti" if (opts.file_format == "auto") else opts.file_format
opts.flatten_conmats = False if (opts.flatten_conmats == "auto") else opts.flatten_conmats
opts.input_type = "fmriprep" if opts.input_type == "auto" else opts.input_type
opts.linc_qc = True if (opts.linc_qc == "auto") else opts.linc_qc
opts.output_correlations = True
Expand Down Expand Up @@ -1052,6 +1071,9 @@ def _validate_parameters(opts, build_log, parser):
if opts.file_format == "auto":
error_messages.append("'--file-format' is required for 'none' mode.")

if opts.flatten_conmats == "auto":
error_messages.append("'--flatten-conmats' (y or n) is required for 'none' mode.")

if opts.input_type == "auto":
error_messages.append("'--input-type' is required for 'none' mode.")

Expand Down
2 changes: 2 additions & 0 deletions xcp_d/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,8 @@ class workflow(_Config):
"""Coverage threshold to apply to parcels in each atlas."""
dcan_correlation_lengths = None
"""Produce correlation matrices limited to each requested amount of time."""
flatten_conmats = None
"""Flatten the correlation matrices."""
process_surfaces = None
"""Warp fsnative-space surfaces to the MNI space."""
abcc_qc = None
Expand Down
118 changes: 115 additions & 3 deletions xcp_d/interfaces/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,16 @@
mandatory=False,
desc="Temporal mask, after dummy scan removal.",
)
flatten = traits.Bool(
False,
usedefault=True,
desc="Flatten the correlation matrix to a TSV file.",
)


class _TSVConnectOutputSpec(TraitedSpec):
correlations = File(exists=True, desc="Correlation matrix file.")
correlations_square = File(exists=True, desc="Square correlation matrix file.")
correlations_exact = traits.Either(
None,
traits.List(File(exists=True)),
Expand Down Expand Up @@ -256,6 +262,17 @@
return correlations_df, correlations_exact


def flatten_conmat(df):
"""Flatten a correlation matrix."""
df = df.where(np.triu(np.ones(df.shape[0])).astype(bool))
df = df.stack().reset_index()
df.columns = ["Row", "Column", "Value"]
df["Edge"] = df["Row"] + "-" + df["Column"]
df = df.set_index("Edge")
df = df[["Value"]].T
return df

Check warning on line 273 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L267-L273

Added lines #L267 - L273 were not covered by tests


class TSVConnect(SimpleInterface):
"""Extract timeseries and compute connectivity matrices.

Expand All @@ -273,6 +290,23 @@
temporal_mask=self.inputs.temporal_mask,
)

self._results["correlations_square"] = fname_presuffix(
"correlations_square.tsv",
newpath=runtime.cwd,
use_ext=True,
)
correlations_df.to_csv(
self._results["correlations_square"],
sep="\t",
na_rep="n/a",
index_label="Node",
)
if self.inputs.flatten:
correlations_df = flatten_conmat(correlations_df)
kwargs = {"index": False}

Check warning on line 306 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L305-L306

Added lines #L305 - L306 were not covered by tests
else:
kwargs = {"index_label": "Node"}

self._results["correlations"] = fname_presuffix(
"correlations.tsv",
newpath=runtime.cwd,
Expand All @@ -282,7 +316,7 @@
self._results["correlations"],
sep="\t",
na_rep="n/a",
index_label="Node",
**kwargs,
)
del correlations_df
gc.collect()
Expand All @@ -298,11 +332,14 @@
newpath=runtime.cwd,
use_ext=True,
)
if self.inputs.flatten:
exact_correlations_df = flatten_conmat(exact_correlations_df)

Check warning on line 336 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L336

Added line #L336 was not covered by tests

exact_correlations_df.to_csv(
exact_correlations_file,
sep="\t",
na_rep="n/a",
index_label="Node",
**kwargs,
)
self._results["correlations_exact"].append(exact_correlations_file)

Expand Down Expand Up @@ -539,10 +576,16 @@
desc="Parcellated CIFTI file to extract into a TSV.",
)
atlas_labels = File(exists=True, mandatory=True, desc="atlas labels file")
flatten = traits.Bool(
False,
usedefault=True,
desc="Flatten the correlation matrix to a TSV file.",
)


class _CiftiToTSVOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="Parcellated data TSV file.")
correlations_square = File(desc="Square correlation matrix TSV file.")


class CiftiToTSV(SimpleInterface):
Expand Down Expand Up @@ -657,7 +700,25 @@
)

if in_file.endswith(".pconn.nii"):
df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", index_label="Node")
self._results["correlations_square"] = fname_presuffix(
"correlations_square.tsv",
newpath=runtime.cwd,
use_ext=True,
)
df.to_csv(
self._results["correlations_square"],
sep="\t",
na_rep="n/a",
index_label="Node",
)

if self.inputs.flatten:
df = flatten_conmat(df)
kwargs = {"index": False}

Check warning on line 717 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L716-L717

Added lines #L716 - L717 were not covered by tests
else:
kwargs = {"index_label": "Node"}

df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", **kwargs)
else:
df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", index=False)

Expand Down Expand Up @@ -769,3 +830,54 @@
write_ndata(vertex_weights_arr, template=data_file, filename=self._results["mask_file"])

return runtime


class _FlattenTSVInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, desc="Correlation matrix or labels TSV file.")
kind = traits.Enum("conmat", "labels", mandatory=True, desc="Input format.")


class _FlattenTSVOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="Flattened time series TSV file.")


class FlattenTSV(SimpleInterface):
"""Flatten a correlation matrix TSV file."""

input_spec = _FlattenTSVInputSpec
output_spec = _FlattenTSVOutputSpec

def _run_interface(self, runtime):
in_file = self.inputs.in_file

Check warning on line 851 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L851

Added line #L851 was not covered by tests

if self.inputs.kind == "conmat":
df = pd.read_table(in_file, index_col="Node")
df = df.where(np.triu(np.ones(df.shape[0])).astype(bool))
df = df.stack().reset_index()
df.columns = ["Row", "Column", "Value"]
df["Edge"] = df["Row"] + "-" + df["Column"]
df = df.set_index("Edge")
df = df[["Value"]].T

Check warning on line 860 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L854-L860

Added lines #L854 - L860 were not covered by tests
elif self.inputs.kind == "labels":
df = pd.read_table(in_file)
df = pd.DataFrame(

Check warning on line 863 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L862-L863

Added lines #L862 - L863 were not covered by tests
columns=df["label"].tolist(),
index=df["label"].tolist(),
data=np.ones((df.shape[0], df.shape[0])),
)
df = df.where(np.triu(np.ones(df.shape[0])).astype(bool))
df = df.stack().reset_index()
df.columns = ["Source", "Target", "Value"]
df["Edge"] = df["Source"] + "-" + df["Target"]
df = df[["Edge", "Source", "Target"]]

Check warning on line 872 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L868-L872

Added lines #L868 - L872 were not covered by tests

# Save out the TSV
self._results["out_file"] = fname_presuffix(

Check warning on line 875 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L875

Added line #L875 was not covered by tests
in_file,
prefix="flattened_",
newpath=runtime.cwd,
use_ext=True,
)
df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", index=False)

Check warning on line 881 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L881

Added line #L881 was not covered by tests

return runtime

Check warning on line 883 in xcp_d/interfaces/connectivity.py

View check run for this annotation

Codecov / codecov/patch

xcp_d/interfaces/connectivity.py#L883

Added line #L883 was not covered by tests
2 changes: 1 addition & 1 deletion xcp_d/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ def test_ds001419_nifti(data_dir, output_dir, working_dir):
"--band-stop-min=6",
"--skip-parcellation",
"--min-time=100",
"--combine-runs",
"--output-type=censored",
"--combine-runs=y",
"--linc-qc=y",
"--abcc-qc=n",
"--despike=n",
"--flatten-conmats=y",
"--file-format=nifti",
"--input-type=fmriprep",
"--warp-surfaces-native2std=n",
Expand Down
6 changes: 6 additions & 0 deletions xcp_d/tests/test_cli_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def base_opts():
"motion_filter_order": None,
"process_surfaces": "auto",
"atlases": ["Glasser"],
"flatten_conmats": "auto",
"dcan_correlation_lengths": None,
"despike": "auto",
"abcc_qc": "auto",
Expand Down Expand Up @@ -276,6 +277,7 @@ def test_validate_parameters_linc_mode(base_opts, base_parser, capsys):
assert opts.abcc_qc is False
assert opts.linc_qc is True
assert opts.file_format == "cifti"
assert opts.flatten_conmats is False

# --create-matrices is not supported
opts.dcan_correlation_lengths = [300]
Expand All @@ -302,6 +304,7 @@ def test_validate_parameters_abcd_mode(base_opts, base_parser, capsys):
assert opts.despike is True
assert opts.fd_thresh == 0.3
assert opts.file_format == "cifti"
assert opts.flatten_conmats is True
assert opts.input_type == "fmriprep"
assert opts.linc_qc is True
assert opts.output_correlations is False
Expand Down Expand Up @@ -337,6 +340,7 @@ def test_validate_parameters_hbcd_mode(base_opts, base_parser, capsys):
assert opts.despike is True
assert opts.fd_thresh == 0.3
assert opts.file_format == "cifti"
assert opts.flatten_conmats is True
assert opts.input_type == "nibabies"
assert opts.linc_qc is True
assert opts.output_correlations is False
Expand Down Expand Up @@ -370,6 +374,7 @@ def test_validate_parameters_none_mode(base_opts, base_parser, capsys):
assert "'--despike' (y or n) is required for 'none' mode." in stderr
assert "'--fd-thresh' is required for 'none' mode." in stderr
assert "'--file-format' is required for 'none' mode." in stderr
assert "'--flatten-conmats' (y or n) is required for 'none' mode." in stderr
assert "'--input-type' is required for 'none' mode." in stderr
assert "'--linc-qc' (y or n) is required for 'none' mode." in stderr
assert "'--motion-filter-type' is required for 'none' mode." in stderr
Expand All @@ -383,6 +388,7 @@ def test_validate_parameters_none_mode(base_opts, base_parser, capsys):
opts.despike = False
opts.fd_thresh = 0
opts.file_format = "nifti"
opts.flatten_conmats = False
opts.input_type = "fmriprep"
opts.linc_qc = False
opts.motion_filter_type = "none"
Expand Down
2 changes: 2 additions & 0 deletions xcp_d/tests/test_workflows_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def test_init_functional_connectivity_nifti_wf(ds001419_data, tmp_path_factory):
config.execution.output_dir = tmpdir
config.workflow.bandpass_filter = False
config.workflow.min_coverage = 0.5
config.workflow.flatten_conmats = False
config.nipype.omp_nthreads = 2
config.execution.atlases = atlas_names
config.workflow.output_correlations = True
Expand Down Expand Up @@ -305,6 +306,7 @@ def test_init_functional_connectivity_cifti_wf(ds001419_data, tmp_path_factory):
with mock_config():
config.execution.output_dir = tmpdir
config.workflow.bandpass_filter = False
config.workflow.flatten_conmats = False
config.workflow.min_coverage = 0.5
config.nipype.omp_nthreads = 2
config.execution.atlases = atlas_names
Expand Down
12 changes: 7 additions & 5 deletions xcp_d/workflows/bold/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"):

if config.workflow.output_correlations:
functional_connectivity = pe.MapNode(
TSVConnect(),
TSVConnect(flatten=config.workflow.flatten_conmats),
name="functional_connectivity",
iterfield=["timeseries"],
mem_gb=mem_gb["timeseries"],
Expand All @@ -156,7 +156,9 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"):
("atlases", "atlases"),
("atlas_labels_files", "atlas_tsvs"),
]),
(functional_connectivity, connectivity_plot, [("correlations", "correlations_tsv")]),
(functional_connectivity, connectivity_plot, [
("correlations_square", "correlations_tsv"),
]),
]) # fmt:skip

ds_report_connectivity_plot = pe.Node(
Expand Down Expand Up @@ -420,7 +422,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit

# Convert correlation pconn file to TSV
dconn_to_tsv = pe.MapNode(
CiftiToTSV(),
CiftiToTSV(flatten=config.workflow.flatten_conmats),
name="dconn_to_tsv",
iterfield=["in_file", "atlas_labels"],
)
Expand All @@ -441,7 +443,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit
("atlases", "atlases"),
("atlas_labels_files", "atlas_tsvs"),
]),
(dconn_to_tsv, connectivity_plot, [("out_file", "correlations_tsv")]),
(dconn_to_tsv, connectivity_plot, [("correlations_square", "correlations_tsv")]),
]) # fmt:skip

ds_report_connectivity = pe.Node(
Expand Down Expand Up @@ -499,7 +501,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit

# Convert correlation pconn file to TSV
exact_dconn_to_tsv = pe.MapNode(
CiftiToTSV(),
CiftiToTSV(flatten=config.workflow.flatten_conmats),
name=f"dconn_to_tsv_{exact_scan}volumes",
iterfield=["in_file", "atlas_labels"],
)
Expand Down