From 2e5cddadce135eb24b3ca9465fac412d5a24e23f Mon Sep 17 00:00:00 2001 From: Prathap Sridharan Date: Fri, 3 May 2024 11:32:47 -0700 Subject: [PATCH] Replace np.array_split with a list implementation --- .../experimental/ml/pytorch.py | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index b84cbd0f0..ad0a98e18 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -129,11 +129,17 @@ def __init__( self.obs_column_names = obs_column_names if shuffle_chunk_count: assert shuffle_rng is not None - chunk_count = len(obs_joinids_chunked) - grouped_chunks_count = chunk_count // min(chunk_count, shuffle_chunk_count) + + # At the start of this step, `obs_joinids_chunked` is a list of one dimensional + # numpy arrays. Each numpy array corresponds to a chunk of contiguous rows in `obs`. + # Critically, `obs_joinids_chunked` is randomly ordered where each chunk is + # from a random section of `obs`. + # We then take `shuffle_chunk_count` of these in order, concatenate them into + # a larger numpy array and shuffle this larger numpy array. + # The result is again a list of numpy arrays. 
self.obs_joinids_chunks_iter = ( shuffle_rng.permutation(np.concatenate(grouped_chunks)) - for grouped_chunks in np.array_split(obs_joinids_chunked, grouped_chunks_count) + for grouped_chunks in list_split(obs_joinids_chunked, shuffle_chunk_count) ) else: self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) @@ -185,6 +191,21 @@ def __next__(self) -> _SOMAChunk: return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) +def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]: + """Splits a Python list into a list of sublists where each sublist has at most `sublist_len` elements (the last sublist may be shorter).""" + i = 0 + result = [] + while i < len(arr_list): + if (i + sublist_len) >= len(arr_list): + result.append(arr_list[i:]) + else: + result.append(arr_list[i : i + sublist_len]) + + i += sublist_len + + return result + + def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any]]: # noqa: D103 proc = psutil.Process(os.getpid()) @@ -455,7 +476,8 @@ def __init__( parallel with client-side processing of the SOMA data, potentially improving overall performance at the cost of doubling memory utilization. Defaults to ``True``. shuffle_chunk_count: - TODO + The number of contiguous blocks (chunks) of rows that are read at random, concatenated, and shuffled. + Larger values of `shuffle_chunk_count` correspond to more randomness in the shuffling. Lifecycle: experimental