From 434e5dc1b6fbfb70b16b06f5acbb5974ed2ee388 Mon Sep 17 00:00:00 2001
From: Dark Knight
Date: Wed, 22 Jan 2025 22:28:25 -0800
Subject: [PATCH 1/2] Revert D66521351

Summary:
This diff reverts D66521351

Need to revert this to fix lowering import error breaking aps tests

Reviewed By: PoojaAg18

Differential Revision: D68528333
---
 torchrec/distributed/embedding.py               | 13 +-----
 .../distributed/test_utils/test_sharding.py     | 32 +++------------
 .../tests/test_sequence_model_parallel.py       | 41 -------------------
 torchrec/modules/embedding_modules.py           |  8 +---
 4 files changed, 10 insertions(+), 84 deletions(-)

diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index feb77a72a..93773cc1f 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -26,7 +26,6 @@
 )

 import torch
-from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -91,7 +90,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1200,15 +1198,8 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: TypeUnion[KeyedJaggedTensor, TensorDict],
+        features: KeyedJaggedTensor,
     ) -> Awaitable[Awaitable[KJTList]]:
-        need_permute: bool = True
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if self._features_order:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-            need_permute = False
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
@@ -1218,7 +1209,7 @@ def input_dist(
             unpadded_features = features
             features = pad_vbe_kjt_lengths(unpadded_features)

-        if need_permute and self._features_order:
+        if self._features_order:
             features = features.permute(
                 self._features_order,
                 # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index 48b9a90ab..f2b65a833 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -147,7 +147,6 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +177,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
-        for _ in range(num_inputs):
-            inputs.append(
+    for _ in range(num_inputs):
+        inputs.append(
+            (
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,26 +188,8 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-            )
-    elif generate == ModelInput.generate:
-        for _ in range(num_inputs):
-            inputs.append(
-                ModelInput.generate(
-                    world_size=world_size,
-                    tables=tables,
-                    dedup_tables=dedup_tables,
-                    weighted_tables=weighted_tables or [],
-                    num_float_features=num_float_features,
-                    variable_batch_size=variable_batch_size,
-                    batch_size=batch_size,
-                    long_indices=long_indices,
-                    input_type=input_type,
-                )
-            )
-    else:
-        for _ in range(num_inputs):
-            inputs.append(
-                cast(ModelInputCallable, generate)(
+                if generate == ModelInput.generate_variable_batch_input
+                else cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -219,6 +200,7 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
+        )
     return (model, inputs)


@@ -315,7 +297,6 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -338,7 +319,6 @@ def sharding_single_rank_test(
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
-            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py
index d13d819c3..aec092354 100644
--- a/torchrec/distributed/tests/test_sequence_model_parallel.py
+++ b/torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,44 +376,3 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
-
-
-@skip_if_asan_class
-class TDSequenceModelParallelTest(SequenceModelParallelTest):
-
-    def test_sharding_variable_batch(self) -> None:
-        pass
-
-    def _test_sharding(
-        self,
-        sharders: List[TestEmbeddingCollectionSharder],
-        backend: str = "gloo",
-        world_size: int = 2,
-        local_size: Optional[int] = None,
-        constraints: Optional[Dict[str, ParameterConstraints]] = None,
-        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
-        qcomms_config: Optional[QCommsConfig] = None,
-        apply_optimizer_in_backward_config: Optional[
-            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
-        ] = None,
-        variable_batch_size: bool = False,
-        variable_batch_per_feature: bool = False,
-    ) -> None:
-        self._run_multi_process_test(
-            callable=sharding_single_rank_test,
-            world_size=world_size,
-            local_size=local_size,
-            model_class=model_class,
-            tables=self.tables,
-            embedding_groups=self.embedding_groups,
-            sharders=sharders,
-            optim=EmbOptimType.EXACT_SGD,
-            backend=backend,
-            constraints=constraints,
-            qcomms_config=qcomms_config,
-            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=variable_batch_size,
-            variable_batch_per_feature=variable_batch_per_feature,
-            global_constant_batch=True,
-            input_type="td",
-        )
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index d110fd57f..4ade3df2f 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -219,10 +219,7 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()

-    def forward(
-        self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
-    ) -> KeyedTensor:
+    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
         and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -453,7 +450,7 @@ def __init__(  # noqa C901

     def forward(
         self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
+        features: KeyedJaggedTensor,
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -466,7 +463,6 @@ def forward(
             Dict[str, JaggedTensor]
         """

-        features = maybe_td_to_kjt(features, None)
         feature_embeddings: Dict[str, JaggedTensor] = {}
         jt_dict: Dict[str, JaggedTensor] = features.to_dict()
         for i, emb_module in enumerate(self.embeddings.values()):

From 771f14212049ab5e1d8653cb6d5386734a9b9d26 Mon Sep 17 00:00:00 2001
From: Dark Knight
Date: Wed, 22 Jan 2025 22:28:25 -0800
Subject: [PATCH 2/2] Revert D65103519

Summary:
This diff reverts D65103519

Depends on D68528333

Need to revert this to fix lowering import error breaking aps tests

Reviewed By: PoojaAg18

Differential Revision: D68528363
---
 torchrec/distributed/embeddingbag.py              | 16 ++++------------
 .../train_pipeline/tests/pipeline_benchmarks.py   |  4 ++--
 torchrec/modules/embedding_modules.py             |  2 --
 3 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index de3d495f2..8cfd16ae9 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,7 +27,6 @@

 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -95,7 +94,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -658,7 +656,9 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            True
+            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            else False
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices(

     # pyre-ignore [14]
     def input_dist(
-        self,
-        ctx: EmbeddingBagCollectionContext,
-        features: Union[KeyedJaggedTensor, TensorDict],
+        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if len(self._features_order) > 0:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-                self._has_features_permute = False  # feature_keys are in order
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index fdb900fe0..e8dc5eccb 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -160,7 +160,7 @@ def main(

     tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@ def main(
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 4ade3df2f..307d66639 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -19,7 +19,6 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt


 @torch.fx.wrap
@@ -230,7 +229,6 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
-        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
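
Note: the sketch below is not part of either patch; it is a minimal illustration of the KeyedJaggedTensor-only call pattern these reverts restore, using public TorchRec APIs referenced in the diffs above. The table name "t1", feature name "f1", and the id/length values are made up for the example.

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # One small unsharded table with a single feature.
    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=8,
                num_embeddings=100,
                feature_names=["f1"],
            )
        ],
        device=torch.device("cpu"),
    )

    # Two samples for feature "f1": ids [10, 20] for the first, [30] for the second.
    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([10, 20, 30]),
        lengths=torch.tensor([2, 1]),
    )

    # After the reverts, forward() takes only a KeyedJaggedTensor (no TensorDict path).
    pooled = ebc(features)  # KeyedTensor; pooled.to_dict()["f1"] has shape (2, 8)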