diff --git a/dypac/bascpp.py b/dypac/bascpp.py
index 18b8ace..af4e06a 100644
--- a/dypac/bascpp.py
+++ b/dypac/bascpp.py
@@ -15,8 +15,7 @@
 def _select_subsample(y, subsample_size, start=None):
-    """ Select a random subsample in a data array
-    """
+    """Select a random subsample in a data array."""
     n_samples = y.shape[1]
     subsample_size = np.min([subsample_size, n_samples])
     max_start = n_samples - subsample_size
@@ -30,8 +29,8 @@
 def _part2onehot(part, n_clusters=0):
-    """ Convert a series of partition (one per row) with integer clusters into
-        a series of one-hot encoding vectors (one per row and cluster).
+    """Convert a series of partition (one per row) with integer clusters into
+    a series of one-hot encoding vectors (one per row and cluster).
     """
     if n_clusters == 0:
         n_clusters = np.max(part) + 1
@@ -50,8 +49,7 @@
 def _start_window(n_time, n_replications, subsample_size):
-    """ Get a list of the starting points of sliding windows.
-    """
+    """Get a list of the starting points of sliding windows."""
     max_replications = n_time - subsample_size + 1
     n_replications = np.min([max_replications, n_replications])
     list_start = np.linspace(0, max_replications, n_replications)
@@ -61,13 +59,12 @@
 def _trim_states(onehot, states, n_states, verbose, threshold_sim):
-    """Trim the states clusters to exclude outliers
-    """
+    """Trim the states clusters to exclude outliers."""
     for ss in tqdm(range(n_states), disable=not verbose, desc="Trimming states"):
         [ix, iy, val] = find(onehot[states == ss, :])
         size_onehot = np.array(onehot[states == ss, :].sum(axis=1)).flatten()
         ref_cluster = np.array(onehot[states == ss, :].mean(dtype="float", axis=0))
-        avg_stab = np.bincount(ix, weights=ref_cluster[0, iy].flatten())
+        avg_stab = np.bincount(ix, weights=ref_cluster[0,iy].flatten())
         avg_stab = np.divide(avg_stab, size_onehot)
         tmp = states[states == ss]
         tmp[avg_stab < threshold_sim] = n_states
@@ -76,18 +73,9 @@
 def replicate_clusters(
-    y,
-    subsample_size,
-    n_clusters,
-    n_replications,
-    max_iter,
-    n_init,
-    verbose,
-    embedding=np.array([]),
-    desc="",
-    normalize=False,
+    y, subsample_size, n_clusters, n_replications, max_iter, n_init, random_state=None, verbose=False, embedding=np.array([]), desc="", normalize=False
 ):
-    """ Replicate a clustering on random subsamples
+    """Replicate a clustering on random subsamples

     Parameters
     ----------
@@ -138,53 +126,51 @@
         init="k-means++",
         max_iter=max_iter,
         n_init=n_init,
+        random_state=random_state,
     )
     return _part2onehot(part, n_clusters)


-def find_states(
-    onehot,
-    n_states=10,
-    max_iter=30,
-    threshold_sim=0.3,
-    n_batch=0,
-    n_init=10,
-    verbose=False,
-):
-    """Find dynamic states based on the similarity of clusters over time
-    """
+def find_states(onehot, n_states=10, max_iter=30, threshold_sim=0.3, n_batch=0, n_init=10, random_state=None, verbose=False):
+    """Find dynamic states based on the similarity of clusters over time."""
     if verbose:
         print("Consensus clustering.")
     cent, states, inert = k_means(
-        onehot, n_clusters=n_states, init="k-means++", max_iter=max_iter, n_init=n_init,
+        onehot,
+        n_clusters=n_states,
+        init="k-means++",
+        max_iter=max_iter,
+        random_state=random_state,
+        n_init=n_init,
     )
     states = _trim_states(onehot, states, n_states, verbose, threshold_sim)
     return states


-def stab_maps(onehot, states, n_replications, n_states):
-    """Generate stability maps associated with different states
-    """
-
+def stab_maps(onehot, states, n_replications, n_states, weights=None):
+    """Generate stability maps associated with different states."""
     dwell_time = np.zeros(n_states)
     val = np.array([])
     col_ind = np.array([])
     row_ind = np.array([])
     for ss in range(0, n_states):
-        dwell_time[ss] = np.sum(states == ss) / n_replications
+        if np.any(weights == None):
+            dwell_time[ss] = np.sum(states == ss) / n_replications
+        else:
+            dwell_time[ss] = np.mean(weights[states == ss])
         if np.any(states == ss):
             stab_map = onehot[states == ss, :].mean(dtype="float", axis=0)
             mask = stab_map > 0
-            col_ind = np.append(col_ind, np.repeat(ss, np.sum(mask)))
-            row_ind = np.append(row_ind, np.nonzero(mask)[1])
+            row_ind = np.append(row_ind, np.repeat(ss, np.sum(mask)))
+            col_ind = np.append(col_ind, np.nonzero(mask)[1])
             val = np.append(val, stab_map[mask])
-    stab_maps = csr_matrix((val, (row_ind, col_ind)), shape=[onehot.shape[1], n_states])
+    stab_maps = csr_matrix((val, (row_ind, col_ind)), shape=[n_states, onehot.shape[1]])

     # Re-order stab maps by descending dwell time
     indsort = np.argsort(-dwell_time)
-    stab_maps = stab_maps[:, indsort]
+    stab_maps = stab_maps[indsort, :]
     dwell_time = dwell_time[indsort]

     return stab_maps, dwell_time
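The bascpp.py changes above thread a random_state through every k-means call and flip the orientation of the stability maps: stab_maps now returns an (n_states, n_voxels) sparse matrix, and dwell times are either the fraction of replications spent in a state or, when weights is given, the mean of the supplied weights. A minimal, self-contained sketch of that computation on toy arrays (illustrative only; it mirrors the logic with plain numpy/scipy instead of calling the module, and the toy values are made up):

import numpy as np
from scipy.sparse import csr_matrix

# Toy data: 6 one-hot cluster indicators over 4 voxels, assigned to 2 states.
onehot = csr_matrix(np.array([
    [1, 1, 0, 0],
    [1, 0, 0, 0],
    [1, 1, 1, 0],
    [0, 0, 1, 1],
    [0, 0, 0, 1],
    [0, 0, 1, 1],
]))
states = np.array([0, 0, 0, 1, 1, 1])
n_states, n_replications = 2, 3

stab = np.zeros((n_states, onehot.shape[1]))  # one row per state, one column per voxel
dwell = np.zeros(n_states)
for ss in range(n_states):
    # Average the one-hot rows assigned to this state, as stab_maps does.
    stab[ss, :] = np.asarray(onehot[states == ss, :].mean(axis=0)).ravel()
    # Without weights, dwell time is the fraction of replications spent in the state.
    dwell[ss] = np.sum(states == ss) / n_replications
print(stab)   # state-by-voxel layout, matching the new csr_matrix shape
print(dwell)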
diff --git a/dypac/dypac.py b/dypac/dypac.py
index be67ec1..5914a85 100644
--- a/dypac/dypac.py
+++ b/dypac/dypac.py
@@ -1,30 +1,25 @@
-"""
-Dynamic Parcel Aggregation with Clustering (dypac)
-"""
+"""Dynamic Parcel Aggregation with Clustering (dypac)."""
 # Authors: Pierre Bellec, Amal Boukhdir
 # License: BSD 3 clause

 import glob
 import itertools
+import warnings

-from tqdm import tqdm
-
-import bascpp as bpp
-from scipy.sparse import csr_matrix, vstack, find
+from scipy.sparse import vstack
 import numpy as np

-#from .bascpp import replicate_clusters, find_states, stab_maps
+from sklearn.cluster import k_means
 from sklearn.utils import check_random_state
 from sklearn.linear_model import LinearRegression
-from sklearn.preprocessing import scale

 from nilearn import EXPAND_PATH_WILDCARDS
-from nilearn._utils.compat import Memory, Parallel, delayed, _basestring
-from nilearn._utils.niimg import _safe_get_data
+from nilearn._utils.compat import Memory, _basestring
 from nilearn._utils.niimg_conversions import _resolve_globbing
 from nilearn.input_data.masker_validation import check_embedded_nifti_masker
 from nilearn.decomposition.base import BaseDecomposition
-from nilearn.image import new_img_like
+
+import bascpp as bpp


 class dypac(BaseDecomposition):
@@ -225,7 +220,6 @@ def fit(self, imgs, confounds=None):
         """
         # Base fit for decomposition estimators : compute the embedded masker
-
         if isinstance(imgs, _basestring):
             if EXPAND_PATH_WILDCARDS and glob.has_magic(imgs):
                 imgs = _resolve_globbing(imgs)
@@ -253,31 +247,60 @@
         self.masker_.fit()
         self.mask_img_ = self.masker_.mask_img_
+        # if no confounds have been specified, match length of imgs
+        if confounds is None:
+            confounds = list(itertools.repeat(confounds, len(imgs)))
+
+        # Control random number generation
+        self.random_state = check_random_state(self.random_state)
+
+        # Check that number of batches is reasonable
+        if self.n_batch > len(imgs):
+            warnings.warn("{0} batches were requested, but only {1} datasets "
+                          "available. Using one dataset per batch instead.".format(self.n_batch, len(imgs)))
+            self.n_batch = len(imgs)
+
         # mask_and_reduce step
-        if self.verbose:
-            print("[{0}] Loading data".format(self.__class__.__name__))
-        onehot = self._mask_and_reduce(imgs, confounds)
+        if (self.n_batch > 1):
+            stab_maps, dwell_time = self._mask_and_reduce_batch(imgs, confounds)
+        else:
+            stab_maps, dwell_time = self._mask_and_reduce(imgs, confounds)

-        # find the states
-        states = bpp.find_states(
-            onehot,
-            n_states=self.n_states,
+        # Return components
+        self.components_ = stab_maps
+        self.dwell_time_ = dwell_time
+        return self
+
+
+    def _mask_and_reduce_batch(self, imgs, confounds=None):
+        """Iterate dypac on batches of files."""
+        for bb in range(self.n_batch):
+            slice_batch = slice(bb, len(imgs), self.n_batch)
+            if self.verbose:
+                print("[{0}] Processing batch {1}".format(self.__class__.__name__, bb))
+            stab_maps, dwell_time = self._mask_and_reduce(imgs[slice_batch], confounds[slice_batch])
+            if bb == 0:
+                stab_maps_all = stab_maps
+                dwell_time_all = dwell_time
+            else:
+                stab_maps_all = vstack([stab_maps_all, stab_maps])
+                dwell_time_all = np.concatenate([dwell_time_all, dwell_time])
+
+        # Consensus clustering step
+        _, states_all, _ = k_means(
+            stab_maps_all,
+            n_clusters=self.n_states,
+            init="k-means++",
             max_iter=self.max_iter,
-            threshold_sim=self.threshold_sim,
-            n_batch=self.n_batch,
+            random_state=self.random_state,
             n_init=self.n_init,
-            verbose=self.verbose,
         )

-        # Generate the stability maps
-        stab_maps, dwell_time = bpp.stab_maps(
-            onehot, states, self.n_replications, self.n_states
-        )
+        # average stability maps and dwell times across consensus states
+        stab_maps_cons, dwell_time_cons = bpp.stab_maps(stab_maps_all, states_all,
+            self.n_replications, self.n_states, dwell_time_all)

-        # Return components
-        self.components_ = stab_maps.transpose()
-        self.dwell_time_ = dwell_time
-        return self
+        return stab_maps_cons, dwell_time_cons
@@ -285,44 +308,50 @@ def _mask_and_reduce(self, imgs, confounds=None):
         """Uses cluster aggregation over sliding windows to estimate

         Returns
         ------
-        stab_maps: ndarray or memorymap
+        stab_maps: ndarray
             Concatenation of dynamic parcels across all datasets.
         """
-        if not hasattr(imgs, "__iter__"):
-            imgs = [imgs]
-
-        if confounds is None:
-            confounds = itertools.repeat(confounds)
-        onehot = csr_matrix([0,])
         for ind, img, confound in zip(range(len(imgs)), imgs, confounds):
-            if ind > 0:
-                onehot = vstack(
-                    [onehot, self._mask_and_cluster_single(img=img, confound=confound, ind=ind)]
-                )
+            this_data = self.masker_.transform(img, confound)
+            # Now get rid of the img as fast as possible, to free a
+            # reference count on it, and possibly free the corresponding
+            # data
+            del img
+            onehot = bpp.replicate_clusters(
+                this_data.transpose(),
+                subsample_size=self.subsample_size,
+                n_clusters=self.n_clusters,
+                n_replications=self.n_replications,
+                max_iter=self.max_iter,
+                n_init=self.n_init,
+                random_state=self.random_state,
+                desc="Replicating clusters in data #{0}".format(ind),
+                verbose=self.verbose,
+            )
+            if ind == 0:
+                onehot_all = onehot
             else:
-                onehot = self._mask_and_cluster_single(img=img, confound=confound, ind=ind)
-        return onehot
+                onehot_all = vstack([onehot_all, onehot])

-    def _mask_and_cluster_single(self, img, confound, ind):
-        """Utility function for _mask_and_reduce"""
-        this_data = self.masker_.transform(img, confound)
-        # Now get rid of the img as fast as possible, to free a
-        # reference count on it, and possibly free the corresponding
-        # data
-        del img
-        random_state = check_random_state(self.random_state)
-        onehot = bpp.replicate_clusters(
-            this_data.transpose(),
-            subsample_size=self.subsample_size,
-            n_clusters=self.n_clusters,
-            n_replications=self.n_replications,
+        # find the states
+        states = bpp.find_states(
+            onehot,
+            n_states=self.n_states,
             max_iter=self.max_iter,
+            threshold_sim=self.threshold_sim,
+            n_batch=self.n_batch,
+            random_state=self.random_state,
             n_init=self.n_init,
-            desc="Replicating clusters in data #{0}".format(ind),
             verbose=self.verbose,
         )
-        return onehot
+
+        # Generate the stability maps
+        stab_maps, dwell_time = bpp.stab_maps(
+            onehot, states, self.n_replications, self.n_states
+        )
+
+        return stab_maps, dwell_time

     def transform_sparse(self, img, confound=None):
         """Transform a 4D dataset in a component space"""
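One detail of the new _mask_and_reduce_batch above that is easy to miss: slice(bb, len(imgs), self.n_batch) assigns datasets to batches in an interleaved (strided) fashion rather than in contiguous blocks. A toy illustration of that indexing (the run names are made up):

imgs = ["sub-01", "sub-02", "sub-03", "sub-04", "sub-05"]
n_batch = 2
for bb in range(n_batch):
    # Batch bb takes every n_batch-th dataset, starting at offset bb.
    print(bb, imgs[slice(bb, len(imgs), n_batch)])
# 0 ['sub-01', 'sub-03', 'sub-05']
# 1 ['sub-02', 'sub-04']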
""" - if not hasattr(imgs, "__iter__"): - imgs = [imgs] - - if confounds is None: - confounds = itertools.repeat(confounds) - onehot = csr_matrix([0,]) for ind, img, confound in zip(range(len(imgs)), imgs, confounds): - if ind > 0: - onehot = vstack( - [onehot, self._mask_and_cluster_single(img=img, confound=confound, ind=ind)] - ) + this_data = self.masker_.transform(img, confound) + # Now get rid of the img as fast as possible, to free a + # reference count on it, and possibly free the corresponding + # data + del img + onehot = bpp.replicate_clusters( + this_data.transpose(), + subsample_size=self.subsample_size, + n_clusters=self.n_clusters, + n_replications=self.n_replications, + max_iter=self.max_iter, + n_init=self.n_init, + random_state=self.random_state, + desc="Replicating clusters in data #{0}".format(ind), + verbose=self.verbose, + ) + if ind == 0: + onehot_all = onehot else: - onehot = self._mask_and_cluster_single(img=img, confound=confound, ind=ind) - return onehot + onehot_all = vstack([onehot_all, onehot]) - def _mask_and_cluster_single(self, img, confound, ind): - """Utility function for _mask_and_reduce""" - this_data = self.masker_.transform(img, confound) - # Now get rid of the img as fast as possible, to free a - # reference count on it, and possibly free the corresponding - # data - del img - random_state = check_random_state(self.random_state) - onehot = bpp.replicate_clusters( - this_data.transpose(), - subsample_size=self.subsample_size, - n_clusters=self.n_clusters, - n_replications=self.n_replications, + # find the states + states = bpp.find_states( + onehot, + n_states=self.n_states, max_iter=self.max_iter, + threshold_sim=self.threshold_sim, + n_batch=self.n_batch, + random_state=self.random_state, n_init=self.n_init, - desc="Replicating clusters in data #{0}".format(ind), verbose=self.verbose, ) - return onehot + + # Generate the stability maps + stab_maps, dwell_time = bpp.stab_maps( + onehot, states, self.n_replications, self.n_states + ) + + return stab_maps, dwell_time def transform_sparse(self, img, confound=None): """Transform a 4D dataset in a component space""" diff --git a/dypac/test_bascpp.py b/dypac/test_bascpp.py new file mode 100644 index 0000000..c222b4f --- /dev/null +++ b/dypac/test_bascpp.py @@ -0,0 +1,20 @@ +import numpy as np +import bascpp as bpp + + +def test_propagate_part(): + + # examples of batch-level and consensus-level partitions + # The following two batches are integer partitions on time windows + part_batch = np.array([0, 1, 1, 2, 2, 0, 0, 0, 1, 2]) + part_cons = np.array([2, 1, 0, 2, 0, 1]) + + # indices defining the batch + n_batch = 2 + index_cons = [0, 3, 6] + + # Manually derived solution + part_ground_truth = np.array([2, 1, 1, 0, 0, 2, 2, 2, 0, 1]) + + part = bpp._propagate_part(part_batch, part_cons, n_batch, index_cons) + assert np.min(part == part_ground_truth) diff --git a/dypac/test_dypac.py b/dypac/test_dypac.py index 48775fb..6e4a82a 100644 --- a/dypac/test_dypac.py +++ b/dypac/test_dypac.py @@ -1,10 +1,9 @@ import numpy as np -import pytest import dypac as dp def _simu_tseries(n_time, n_roi, n_clusters, alpha): - """ Simulate time series with a cluster structure for multiple ROIs. + """Simulate time series with a cluster structure for multiple ROIs. 
diff --git a/dypac/test_dypac.py b/dypac/test_dypac.py
index 48775fb..6e4a82a 100644
--- a/dypac/test_dypac.py
+++ b/dypac/test_dypac.py
@@ -1,10 +1,9 @@
 import numpy as np
-import pytest

 import dypac as dp


 def _simu_tseries(n_time, n_roi, n_clusters, alpha):
-    """ Simulate time series with a cluster structure for multiple ROIs.
+    """Simulate time series with a cluster structure for multiple ROIs.

     Returns:
         y (n_roi x n_time) the time series
         gt (n_roi) the ground truth partition
     """
@@ -19,21 +18,3 @@
         y[cluster, :] = noise[cluster, :] + alpha * np.repeat(sig, ind[cc + 1] - ind[cc],0) # y = noise + a * signal
         gt[cluster] = cc # Adding the label for cluster in ground truth
     return y, gt
-
-
-def test_propagate_part():
-
-    # examples of batch-level and consensus-level partitions
-    # The following two batches are integer partitions on time windows
-    part_batch = np.array([0, 1, 1, 2, 2, 0, 0, 0, 1, 2])
-    part_cons = np.array([2, 1, 0, 2, 0, 1])
-
-    # indices defining the batch
-    n_batch = 2
-    index_cons = [0, 3, 6]
-
-    # Manually derived solution
-    part_ground_truth = np.array([2, 1, 1, 0, 0, 2, 2, 2, 0, 1])
-
-    part = dp._propagate_part(part_batch, part_cons, n_batch, index_cons)
-    assert np.min(part == part_ground_truth)
diff --git a/examples/.gitignore b/examples/.gitignore
new file mode 100644
index 0000000..763513e
--- /dev/null
+++ b/examples/.gitignore
@@ -0,0 +1 @@
+.ipynb_checkpoints
diff --git a/examples/dypac_fmri_compression.ipynb b/examples/dypac_fmri_compression.ipynb
index a7dac1d..ce03162 100644
--- a/examples/dypac_fmri_compression.ipynb
+++ b/examples/dypac_fmri_compression.ipynb
@@ -35,8 +35,8 @@
    }
   ],
   "source": [
-    "adhd_dataset = datasets.fetch_adhd(n_subjects=1)\n",
-    "epi_filename = adhd_dataset.func[0]"
+    "adhd_dataset = datasets.fetch_adhd(n_subjects=10)\n",
+    "epi_filename = adhd_dataset.func"
   ]
  },
 {
@@ -48,10 +48,10 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "[MultiNiftiMasker.fit] Loading data from [/home/pbellec/nilearn_data/adhd/data/0010042/0010042_rest_tshift_RPI_voreg_mni.nii.gz]\n",
+    "[MultiNiftiMasker.fit] Loading data from [/home/pbellec/nilearn_data/adhd/data/0010042/0010042_rest_tshift_RPI_voreg_mni.nii.gz, /home/pbellec/nilearn_data/adhd/data/0010064/0010064_rest_tshift_RPI_voreg_mni.nii.gz, /home/pbellec/nilearn_dat\n",
    "[MultiNiftiMasker.fit] Computing mask\n",
    "[MultiNiftiMasker.transform] Resampling mask\n",
-    "[dypac] Loading data\n",
+    "[dypac] Processing batch 0\n",
    "[MultiNiftiMasker.transform_single_imgs] Loading data from Nifti1Image('/home/pbellec/nilearn_data/adhd/data/0010042/0010042_rest_tshift_RPI_voreg_mni.nii.gz')\n",
    "[MultiNiftiMasker.transform_single_imgs] Smoothing images\n",
    "[MultiNiftiMasker.transform_single_imgs] Extracting region signals\n",
@@ -62,7 +62,201 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "Replicating clusters in data #0: 100%|██████████| 20/20 [00:04<00:00, 4.08it/s]\n"
+    "Replicating clusters in data #0: 100%|██████████| 30/30 [00:08<00:00, 3.50it/s]\n"
    ]
   },
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "[MultiNiftiMasker.transform_single_imgs] Loading data from Nifti1Image('/home/pbellec/nilearn_data/adhd/data/0010128/0010128_rest_tshift_RPI_voreg_mni.nii.gz')\n",
+    "[MultiNiftiMasker.transform_single_imgs] Smoothing images\n",
+    "[MultiNiftiMasker.transform_single_imgs] Extracting region signals\n",
+    "[MultiNiftiMasker.transform_single_imgs] Cleaning extracted signals\n"
+   ]
+  },
+  {
+   "name": "stderr",
+   "output_type": "stream",
+   "text": [
+    "Replicating clusters in data #1: 100%|██████████| 30/30 [00:09<00:00, 3.16it/s]\n"
+   ]
+  },
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "[MultiNiftiMasker.transform_single_imgs] Loading data from Nifti1Image('/home/pbellec/nilearn_data/adhd/data/0023008/0023008_rest_tshift_RPI_voreg_mni.nii.gz')\n",
+    "[MultiNiftiMasker.transform_single_imgs] Smoothing images\n"
+   ]
+  },
+  {
+   "name": "stderr",
+   "output_type": "stream",
+   "text": [
+    "Replicating clusters in data #2: 0%| | 0/30 [00:00"
     ],
     "text/plain": [
      ""
     ]
    },
    "execution_count": 29,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [
   "num_comp = 18\n",
   "comp = model.masker_.inverse_transform(model.components_[num_comp,:].todense())\n",
   "plotting.view_img(comp, threshold=0.1, vmax=1, title=\"Dwell time: {dt}\".format(dt=model.dwell_time_[num_comp]))"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "model.components_.shape"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -219,7 +503,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.6.9"
+   "version": "3.7.5rc1"
  }
 },
 "nbformat": 4,
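Putting the notebook changes together, a rough usage sketch; only the fetch_adhd call, the batch processing log, and the final plotting cell appear in the diff above, and the dypac constructor arguments below are hypothetical placeholders rather than the notebook's actual values:

from nilearn import datasets, plotting
from dypac import dypac  # assuming the class is importable this way

adhd_dataset = datasets.fetch_adhd(n_subjects=10)

# Hypothetical parameters; the notebook cell that builds the model is not shown in the diff.
model = dypac(n_clusters=10, n_states=30, n_batch=2, verbose=1)
model.fit(adhd_dataset.func)  # list of runs, processed in interleaved batches

# Components are stored state-by-voxel, so one map is a row of components_.
num_comp = 18
comp = model.masker_.inverse_transform(model.components_[num_comp, :].todense())
plotting.view_img(comp, threshold=0.1, vmax=1,
                  title="Dwell time: {dt}".format(dt=model.dwell_time_[num_comp]))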