diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 27f8c1b42..861e6e868 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -94,6 +94,7 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: SSDTableBatchedEmbeddingBags.__init__ ).parameters.keys() invalid_keys: List[str] = [] + for key, value in fused_params.items(): if key not in ssd_tbe_signature: invalid_keys.append(key) @@ -151,6 +152,21 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: weights_precision = data_type_to_sparse_type(config.data_type) ssd_tbe_params["weights_precision"] = weights_precision + if "max_l1_cache_size" in fused_params: + l1_cache_size = fused_params.get("max_l1_cache_size") * 1024 * 1024 + max_dim: int = max(table.local_cols for table in config.embedding_tables) + weight_precision_bytes = ssd_tbe_params["weights_precision"].bit_rate() / 8 + max_cache_sets = ( + l1_cache_size / ASSOC / weight_precision_bytes / max_dim + ) # 100MB + + if ssd_tbe_params["cache_sets"] > int(max_cache_sets): + logger.warning( + f"cache_sets {ssd_tbe_params['cache_sets']} is larger than max_cache_sets {max_cache_sets} calculated " + "by max_l1_cache_size, cap at max_cache_sets instead" + ) + ssd_tbe_params["cache_sets"] = int(max_cache_sets) + return ssd_tbe_params diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index b7a886773..44461752a 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -649,7 +649,8 @@ class KeyValueParams: gather_ssd_cache_stats: Optional[bool] = None stats_reporter_config: Optional[TBEStatsReporterConfig] = None use_passed_in_path: bool = True - l2_cache_size: Optional[int] = None + l2_cache_size: Optional[int] = None # size in GB + max_l1_cache_size: Optional[int] = None # size in MB enable_async_update: Optional[bool] = None # Parameter Server (PS) Attributes @@ -673,6 +674,7 @@ def __hash__(self) -> int: self.gather_ssd_cache_stats, self.stats_reporter_config, self.l2_cache_size, + self.max_l1_cache_size, self.enable_async_update, ) )