Skip to content

Commit

Permalink
Uneven sharding for TGIF DI publish flow (#1518)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #1518

Reviewed By: jiayisuse

Differential Revision:
D51219009

Privacy Context Container: L1138451

fbshipit-source-id: 87eb16dce7a050511ea6d99699704fb9a81db1a3
  • Loading branch information
tissue3 authored and facebook-github-bot committed Dec 6, 2023
1 parent 1babe8e commit 9b9f3ca
Showing 1 changed file with 117 additions and 14 deletions.
131 changes: 117 additions & 14 deletions torchrec/distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def block_bucketize_ref(
keyed_jagged_tensor: KeyedJaggedTensor,
trainers_size: int,
block_sizes: torch.Tensor,
device: str = "cuda",
) -> KeyedJaggedTensor:
lengths_list = keyed_jagged_tensor.lengths().view(-1).tolist()
indices_list = keyed_jagged_tensor.values().view(-1).tolist()
Expand Down Expand Up @@ -228,21 +229,32 @@ def block_bucketize_ref(
expected_keys = [
key for index in range(trainers_size) for key in keyed_jagged_tensor.keys()
]

return KeyedJaggedTensor(
keys=expected_keys,
lengths=torch.tensor(
translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype
if device == "cuda":
return KeyedJaggedTensor(
keys=expected_keys,
lengths=torch.tensor(
translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype
)
.view(-1)
.cuda(),
values=torch.tensor(
translated_indices, dtype=keyed_jagged_tensor.values().dtype
).cuda(),
weights=torch.tensor(translated_weights).float().cuda()
if weights_list
else None,
)
else:
return KeyedJaggedTensor(
keys=expected_keys,
lengths=torch.tensor(
translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype
).view(-1),
values=torch.tensor(
translated_indices, dtype=keyed_jagged_tensor.values().dtype
),
weights=torch.tensor(translated_weights).float() if weights_list else None,
)
.view(-1)
.cuda(),
values=torch.tensor(
translated_indices, dtype=keyed_jagged_tensor.values().dtype
).cuda(),
weights=torch.tensor(translated_weights).float().cuda()
if weights_list
else None,
)


class KJTBucketizeTest(unittest.TestCase):
Expand Down Expand Up @@ -332,6 +344,97 @@ def test_kjt_bucketize_before_all2all(
)
)

# pyre-ignore[56]
@given(
    index_type=st.sampled_from([torch.int, torch.long]),
    offset_type=st.sampled_from([torch.int, torch.long]),
    world_size=st.integers(1, 129),
    num_features=st.integers(1, 15),
    batch_size=st.integers(1, 15),
    variable_bucket_pos=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
def test_kjt_bucketize_before_all2all_cpu(
    self,
    index_type: torch.dtype,
    offset_type: torch.dtype,
    world_size: int,
    num_features: int,
    batch_size: int,
    variable_bucket_pos: bool,
) -> None:
    """Compare ``bucketize_kjt_before_all2all`` with ``block_bucketize_ref``.

    A random weighted ``KeyedJaggedTensor`` is built from plain
    ``torch.tensor`` calls (no device is requested), bucketized for
    ``world_size`` buckets, and the result is compared with the reference
    implementation invoked with ``"cpu"``.

    Args:
        index_type: dtype of the KJT values, the block sizes and (when
            generated) the per-feature bucket boundaries.
        offset_type: dtype of the KJT lengths.
        world_size: number of buckets to distribute rows over.
        num_features: number of generated features (keys).
        batch_size: number of length entries generated per feature.
        variable_bucket_pos: when True, explicit per-feature bucket
            boundaries are passed to the bucketizer; otherwise ``None``.
    """
    MAX_BATCH_SIZE = 15
    MAX_LENGTH = 10
    # max number of rows needed for a given feature to have unique row index
    MAX_ROW_COUNT = MAX_LENGTH * MAX_BATCH_SIZE

    lengths_list = [
        random.randrange(MAX_LENGTH + 1) for _ in range(num_features * batch_size)
    ]
    keys_list = [f"feature_{i}" for i in range(num_features)]
    # for each feature, generate unrepeated row indices
    indices_lists = [
        random.sample(
            range(MAX_ROW_COUNT),
            # number of indices needed is the length sum of all batches for a feature
            sum(
                lengths_list[
                    feature_offset * batch_size : (feature_offset + 1) * batch_size
                ]
            ),
        )
        for feature_offset in range(num_features)
    ]
    indices_list = list(itertools.chain.from_iterable(indices_lists))

    # one random weight per generated index
    weights_list = [random.randint(1, 100) for _ in indices_list]

    # for each feature, calculate the minimum block size needed to
    # distribute all rows to the available trainers
    block_sizes_list = [
        math.ceil((max(feature_indices_list) + 1) / world_size)
        if feature_indices_list
        else 1
        for feature_indices_list in indices_lists
    ]

    # Optional explicit bucket boundaries: for every feature, world_size + 1
    # evenly spaced row positions (multiples of that feature's block size).
    block_bucketize_row_pos = (
        [
            torch.tensor(
                [w * block_size for w in range(world_size + 1)],
                dtype=index_type,
            )
            for block_size in block_sizes_list
        ]
        if variable_bucket_pos
        else None
    )

    kjt = KeyedJaggedTensor(
        keys=keys_list,
        lengths=torch.tensor(lengths_list, dtype=offset_type).view(
            num_features * batch_size
        ),
        values=torch.tensor(indices_list, dtype=index_type),
        weights=torch.tensor(weights_list, dtype=torch.float),
    )
    # each entry in block_sizes identifies how many hashes of the
    # corresponding feature go to every rank; there is one entry per
    # generated feature
    block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
    # NOTE(review): the two positional False flags are defined by
    # bucketize_kjt_before_all2all, which is not visible here.
    block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
        kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
    )

    expected_block_bucketized_kjt = block_bucketize_ref(
        kjt, world_size, block_sizes, "cpu"
    )

    self.assertTrue(
        keyed_jagged_tensor_equals(
            block_bucketized_kjt, expected_block_bucketized_kjt
        )
    )


class MergeFusedParamsTest(unittest.TestCase):
def test_merge_fused_params(self) -> None:
Expand Down

0 comments on commit 9b9f3ca

Please sign in to comment.