Skip to content

Commit

Permalink
2025-01-31 nightly release (b3e19e2)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Jan 31, 2025
1 parent cc3cdc7 commit 7fca948
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import itertools
import logging
import tempfile
from collections import OrderedDict
from dataclasses import dataclass
from typing import (
Any,
Expand Down Expand Up @@ -216,6 +215,7 @@ def __init__( # noqa C901
pg: Optional[dist.ProcessGroup] = None,
create_for_table: Optional[str] = None,
param_weight_for_table: Optional[nn.Parameter] = None,
embedding_weights_by_table: Optional[List[torch.Tensor]] = None,
) -> None:
"""
Implementation of a FusedOptimizer. Designed as a base class Embedding kernels
Expand Down Expand Up @@ -391,7 +391,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
# that state_dict look identical to no-fused version.
table_to_shard_params: Dict[str, ShardParams] = {}

embedding_weights_by_table = emb_module.split_embedding_weights()
embedding_weights_by_table = (
embedding_weights_by_table or emb_module.split_embedding_weights()
)

all_optimizer_states = emb_module.get_optimizer_state()
optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
Expand Down Expand Up @@ -674,6 +676,8 @@ def _gen_named_parameters_by_table_fused(
pg: Optional[dist.ProcessGroup] = None,
) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]:
# TODO: move logic to FBGEMM to avoid accessing fbgemm internals
# Cache embedding_weights_by_table
embedding_weights_by_table = emb_module.split_embedding_weights()
for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs):
table_name = config.embedding_tables[t_idx].name
if table_name not in table_name_to_count:
Expand Down Expand Up @@ -709,6 +713,7 @@ def _gen_named_parameters_by_table_fused(
pg=pg,
create_for_table=table_name,
param_weight_for_table=weight,
embedding_weights_by_table=embedding_weights_by_table,
)
]
yield (table_name, weight)
Expand Down

0 comments on commit 7fca948

Please sign in to comment.