diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py
index 330a560d9..171fe4a64 100644
--- a/torchrec/distributed/tests/test_utils.py
+++ b/torchrec/distributed/tests/test_utils.py
@@ -173,6 +173,7 @@ def block_bucketize_ref(
     keyed_jagged_tensor: KeyedJaggedTensor,
     trainers_size: int,
     block_sizes: torch.Tensor,
+    device: str = "cuda",
 ) -> KeyedJaggedTensor:
     lengths_list = keyed_jagged_tensor.lengths().view(-1).tolist()
     indices_list = keyed_jagged_tensor.values().view(-1).tolist()
@@ -228,21 +229,32 @@ def block_bucketize_ref(
     expected_keys = [
         key for index in range(trainers_size) for key in keyed_jagged_tensor.keys()
     ]
-
-    return KeyedJaggedTensor(
-        keys=expected_keys,
-        lengths=torch.tensor(
-            translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype
+    if device == "cuda":
+        return KeyedJaggedTensor(
+            keys=expected_keys,
+            lengths=torch.tensor(
+                translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype
+            )
+            .view(-1)
+            .cuda(),
+            values=torch.tensor(
+                translated_indices, dtype=keyed_jagged_tensor.values().dtype
+            ).cuda(),
+            weights=torch.tensor(translated_weights).float().cuda()
+            if weights_list
+            else None,
+        )
+    else:
+        return KeyedJaggedTensor(
+            keys=expected_keys,
+            lengths=torch.tensor(
+                translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype
+            ).view(-1),
+            values=torch.tensor(
+                translated_indices, dtype=keyed_jagged_tensor.values().dtype
+            ),
+            weights=torch.tensor(translated_weights).float() if weights_list else None,
         )
-        .view(-1)
-        .cuda(),
-        values=torch.tensor(
-            translated_indices, dtype=keyed_jagged_tensor.values().dtype
-        ).cuda(),
-        weights=torch.tensor(translated_weights).float().cuda()
-        if weights_list
-        else None,
-    )


 class KJTBucketizeTest(unittest.TestCase):
@@ -332,6 +344,97 @@ def test_kjt_bucketize_before_all2all(
             )
         )

+    # pyre-ignore[56]
+    @given(
+        index_type=st.sampled_from([torch.int, torch.long]),
+        offset_type=st.sampled_from([torch.int, torch.long]),
+        world_size=st.integers(1, 129),
+        num_features=st.integers(1, 15),
+        batch_size=st.integers(1, 15),
+        variable_bucket_pos=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
+    def test_kjt_bucketize_before_all2all_cpu(
+        self,
+        index_type: torch.dtype,
+        offset_type: torch.dtype,
+        world_size: int,
+        num_features: int,
+        batch_size: int,
+        variable_bucket_pos: bool,
+    ) -> None:
+        MAX_BATCH_SIZE = 15
+        MAX_LENGTH = 10
+        # max number of rows needed for a given feature to have unique row index
+        MAX_ROW_COUNT = MAX_LENGTH * MAX_BATCH_SIZE
+
+        lengths_list = [
+            random.randrange(MAX_LENGTH + 1) for _ in range(num_features * batch_size)
+        ]
+        keys_list = [f"feature_{i}" for i in range(num_features)]
+        # for each feature, generate unrepeated row indices
+        indices_lists = [
+            random.sample(
+                range(MAX_ROW_COUNT),
+                # number of indices needed is the length sum of all batches for a feature
+                sum(
+                    lengths_list[
+                        feature_offset * batch_size : (feature_offset + 1) * batch_size
+                    ]
+                ),
+            )
+            for feature_offset in range(num_features)
+        ]
+        indices_list = list(itertools.chain(*indices_lists))
+
+        weights_list = [random.randint(1, 100) for _ in range(len(indices_list))]
+
+        # for each feature, calculate the minimum block size needed to
+        # distribute all rows to the available trainers
+        block_sizes_list = [
+            math.ceil((max(feature_indices_list) + 1) / world_size)
+            if feature_indices_list
+            else 1
+            for feature_indices_list in indices_lists
+        ]
+        block_bucketize_row_pos = [] if variable_bucket_pos else None
+        if variable_bucket_pos:
+            for block_size in block_sizes_list:
+                # pyre-ignore
+                block_bucketize_row_pos.append(
+                    torch.tensor(
+                        [w * block_size for w in range(world_size + 1)],
+                        dtype=index_type,
+                    )
+                )
+
+        kjt = KeyedJaggedTensor(
+            keys=keys_list,
+            lengths=torch.tensor(lengths_list, dtype=offset_type).view(
+                num_features * batch_size
+            ),
+            values=torch.tensor(indices_list, dtype=index_type),
+            weights=torch.tensor(weights_list, dtype=torch.float),
+        )
+        """
+        each entry in block_sizes identifies how many hashes of each feature go
+        to each rank; this test generates `num_features` features
+        """
+        block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
+        block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
+            kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
+        )
+
+        expected_block_bucketized_kjt = block_bucketize_ref(
+            kjt, world_size, block_sizes, "cpu"
+        )
+
+        self.assertTrue(
+            keyed_jagged_tensor_equals(
+                block_bucketized_kjt, expected_block_bucketized_kjt
+            )
+        )
+

 class MergeFusedParamsTest(unittest.TestCase):
     def test_merge_fused_params(self) -> None:
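For context, a minimal sketch (not part of the patch) of how the new `device` argument to `block_bucketize_ref` can be exercised on a host-only machine; the tiny KJT and block sizes below are illustrative, and the import path assumes the helper stays in `torchrec/distributed/tests/test_utils.py`:

```python
import torch

from torchrec.distributed.tests.test_utils import block_bucketize_ref
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features with a batch size of 2: lengths are laid out per feature, per batch,
# so feature_0 owns values [3, 0, 5] and feature_1 owns values [2, 1].
kjt = KeyedJaggedTensor(
    keys=["feature_0", "feature_1"],
    values=torch.tensor([3, 0, 5, 2, 1], dtype=torch.int),
    lengths=torch.tensor([1, 2, 1, 1], dtype=torch.int),
    weights=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
)

# One block size per feature: rows [0, block_size) go to rank 0, the next block to rank 1, ...
block_sizes = torch.tensor([3, 2], dtype=torch.int)

# Passing "cpu" keeps the reference output on the host; the default "cuda" preserves
# the previous behaviour of calling .cuda() on the lengths/values/weights tensors.
ref_kjt = block_bucketize_ref(kjt, 2, block_sizes, "cpu")
```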