2024-11-05 nightly release (786bb1e)
pytorchbot committed Nov 5, 2024
1 parent 623ece0 commit 5e46ad6
Showing 4 changed files with 115 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release_build.yml
@@ -132,7 +132,7 @@ jobs:
             # Pulled from instance metadata endpoint for EC2
             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
             category=$1
-            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
+            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
           }
           echo "ami-id: $(get_ec2_metadata ami-id)"
           echo "instance-id: $(get_ec2_metadata instance-id)"
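
For background, this replaces an IMDSv1-style metadata read with IMDSv2: the inner curl first obtains a short-lived session token via a PUT to /latest/api/token, and the outer curl then presents that token in the X-aws-ec2-metadata-token header when reading the requested metadata category. A minimal Python sketch of the same two-step flow, for illustration only (the function name and timeouts are assumptions, and it only works from inside an EC2 instance):

import urllib.request

IMDS = "http://169.254.169.254/latest"

def get_ec2_metadata(category: str, token_ttl_seconds: int = 30) -> str:
    # Step 1: request a short-lived IMDSv2 session token via PUT.
    token_request = urllib.request.Request(
        f"{IMDS}/api/token",
        method="PUT",
        headers={"X-aws-ec2-metadata-token-ttl-seconds": str(token_ttl_seconds)},
    )
    with urllib.request.urlopen(token_request, timeout=2) as response:
        token = response.read().decode()

    # Step 2: read the metadata category, presenting the token in a header.
    metadata_request = urllib.request.Request(
        f"{IMDS}/meta-data/{category}",
        headers={"X-aws-ec2-metadata-token": token},
    )
    with urllib.request.urlopen(metadata_request, timeout=2) as response:
        return response.read().decode()

# Example (EC2 only): print(get_ec2_metadata("ami-id"))
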
19 changes: 3 additions & 16 deletions torchrec/modules/utils.py
@@ -40,14 +40,6 @@ def _fx_to_list(tensor: torch.Tensor) -> List[int]:
     return tensor.long().tolist()


-@torch.fx.wrap
-def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
-    """
-    Unflatten lengths tensor from [F * B] to [F, B].
-    """
-    return lengths.view(num_features, -1)
-
-
 @torch.fx.wrap
 def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor:
     """
@@ -311,23 +303,18 @@ def construct_jagged_tensors_inference(
     remove_padding: bool = False,
 ) -> Dict[str, JaggedTensor]:
     with record_function("## construct_jagged_tensors_inference ##"):
-        # [F * B] -> [F, B]
-        unflattened_lengths = _get_unflattened_lengths(lengths, len(embedding_names))
-
         if reverse_indices is not None:
             embeddings = torch.index_select(
                 embeddings, 0, reverse_indices.to(torch.int32)
             )
         elif remove_padding:
-            embeddings = _slice_1d_tensor(
-                embeddings, 0, unflattened_lengths.sum().item()
-            )
+            embeddings = _slice_1d_tensor(embeddings, 0, lengths.sum().item())

         ret: Dict[str, JaggedTensor] = {}

-        length_per_key: List[int] = _fx_to_list(torch.sum(unflattened_lengths, dim=1))
+        length_per_key: List[int] = _fx_to_list(torch.sum(lengths, dim=1))

-        lengths_tuple = torch.unbind(unflattened_lengths, dim=0)
+        lengths_tuple = torch.unbind(lengths, dim=0)

         embeddings_list = torch.split(embeddings, length_per_key, dim=0)
         values_list = torch.split(values, length_per_key) if need_indices else None
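
Net effect of this file's changes: construct_jagged_tensors_inference no longer reshapes the lengths tensor itself and now expects the caller to pass lengths already unflattened to [F, B]; the _get_unflattened_lengths helper moves to torchrec/quant/embedding_modules.py (next file). A toy sketch of the shapes involved, with made-up sizes:

import torch

# Hypothetical sizes: F = 2 features, B = 3 samples per batch.
num_features = 2

# Flattened per-(feature, sample) bag lengths, shape [F * B].
flat_lengths = torch.tensor([2, 0, 1, 3, 1, 2])

# What _get_unflattened_lengths does for the caller: reshape [F * B] -> [F, B].
lengths = flat_lengths.view(num_features, -1)
# tensor([[2, 0, 1],
#         [3, 1, 2]])

# construct_jagged_tensors_inference then sums along dim=1 for the number of
# embedding rows per feature, and unbinds dim=0 for per-feature lengths.
length_per_key = torch.sum(lengths, dim=1).tolist()  # [3, 6]
lengths_tuple = torch.unbind(lengths, dim=0)

# Embedding rows arrive stacked feature by feature: [sum(lengths), embedding_dim].
embeddings = torch.randn(sum(length_per_key), 16)
embeddings_list = torch.split(embeddings, length_per_key, dim=0)
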
20 changes: 17 additions & 3 deletions torchrec/quant/embedding_modules.py
@@ -81,7 +81,9 @@

 MODULE_ATTR_REMOVE_STBE_PADDING_BOOL: str = "__remove_stbe_padding"

-MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT_BOOL: str = "__use_batching_hinted_output"
+MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING: str = (
+    "__use_unflattened_lengths_for_batching"
+)

 DEFAULT_ROW_ALIGNMENT = 16

@@ -108,6 +110,14 @@ def _cat_embeddings(embeddings: List[Tensor]) -> Tensor:
     return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)


+@torch.fx.wrap
+def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
+    """
+    Unflatten lengths tensor from [F * B] to [F, B].
+    """
+    return lengths.view(num_features, -1)
+
+
 def for_each_module_of_type_do(
     module: nn.Module,
     module_types: List[Type[torch.nn.Module]],
@@ -893,14 +903,18 @@ def forward(
             f = kjts_per_key[i]
             lengths = _get_feature_length(f)
             indices, offsets = _fx_trec_unwrap_kjt(f)
+            embedding_names = self._embedding_names_by_batched_tables[key]
             lookup = (
                 emb_module(indices=indices, offsets=offsets)
                 if self.register_tbes
                 else emb_module.forward(indices=indices, offsets=offsets)
             )
-            if getattr(self, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT_BOOL, True):
+            if getattr(self, MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING, False):
+                lengths = _get_unflattened_lengths(lengths, len(embedding_names))
                 lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
-            embedding_names = self._embedding_names_by_batched_tables[key]
+            else:
+                lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
+                lengths = _get_unflattened_lengths(lengths, len(embedding_names))
             jt = construct_jagged_tensors_inference(
                 embeddings=lookup,
                 lengths=lengths,
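
The new module attribute only changes the order of the two steps above: when it is set, forward() unflattens lengths to [F, B] before calling _get_batching_hinted_output; when it is absent, getattr falls back to False and rebatching still runs on the flattened [F * B] lengths, with the reshape applied afterwards. A minimal sketch of opting in on an already-quantized collection (the helper function name here is an assumption, not part of the diff):

from torchrec.quant.embedding_modules import (
    EmbeddingCollection as QuantEmbeddingCollection,
    MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING,
)

def enable_unflattened_lengths_for_batching(qec: QuantEmbeddingCollection) -> None:
    # forward() reads this attribute with getattr(..., False), so modules that
    # never set it keep the default path.
    setattr(qec, MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING, True)
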
94 changes: 94 additions & 0 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -29,8 +29,11 @@
 )
 from torchrec.quant.embedding_modules import (
     _fx_trec_unwrap_kjt,
+    _get_batching_hinted_output,
+    _get_unflattened_lengths,
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     EmbeddingCollection as QuantEmbeddingCollection,
+    MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING,
     quant_prep_enable_quant_state_dict_split_scale_bias,
 )
 from torchrec.sparse.jagged_tensor import (
@@ -863,3 +866,94 @@ def test_fx_unwrap_unsharded_vs_sharded_in_sync(

         self.assertEqual(indices.dtype, sharded_indices.dtype)
         self.assertEqual(offsets.dtype, sharded_offsets.dtype)
+
+    def test_using_flattened_or_unflattened_length_rebatching(self) -> None:
+        data_type = DataType.FP16
+        quant_type = torch.half
+        output_type = torch.half
+
+        ec1_config = EmbeddingConfig(
+            name="t1",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f1", "f2"],
+            data_type=data_type,
+        )
+        ec2_config = EmbeddingConfig(
+            name="t2",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f3", "f4"],
+            data_type=data_type,
+        )
+
+        ec = EmbeddingCollection(tables=[ec1_config, ec2_config])
+        ec.qconfig = torch.quantization.QConfig(
+            activation=torch.quantization.PlaceholderObserver.with_args(
+                dtype=output_type
+            ),
+            weight=torch.quantization.PlaceholderObserver.with_args(dtype=quant_type),
+        )
+
+        qec = QuantEmbeddingCollection.from_float(ec)
+
+        import copy
+
+        from torchrec.fx import symbolic_trace
+
+        # test using flattened lengths for rebatching (default)
+
+        gm = symbolic_trace(copy.deepcopy(qec))
+
+        found_get_unflattened_lengths_func = False
+
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.name == _get_unflattened_lengths.__name__
+            ):
+                found_get_unflattened_lengths_func = True
+                for user in node.users:
+                    if (
+                        user.op == "call_function"
+                        and user.name == _get_batching_hinted_output.__name__
+                    ):
+                        self.assertTrue(
+                            False,
+                            "Should not call _get_batching_hinted_output after _get_unflattened_lengths",
+                        )
+
+        self.assertTrue(
+            found_get_unflattened_lengths_func,
+            "_get_unflattened_lengths must exist in the graph",
+        )
+
+        # test using unflattened lengths for rebatching
+
+        setattr(qec, MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING, True)
+
+        gm = symbolic_trace(qec)
+
+        found_get_unflattened_lengths_func = False
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.name == _get_unflattened_lengths.__name__
+            ):
+                found_get_unflattened_lengths_func = True
+                found_get_batching_hinted_output_func = False
+                for user in node.users:
+                    if (
+                        user.op == "call_function"
+                        and user.name == _get_batching_hinted_output.__name__
+                    ):
+                        found_get_batching_hinted_output_func = True
+                self.assertTrue(
+                    found_get_batching_hinted_output_func,
+                    "Should call _get_batching_hinted_output after _get_unflattened_lengths",
+                )
+
+        self.assertTrue(
+            found_get_unflattened_lengths_func,
+            "_get_unflattened_lengths must exist in the graph",
+        )
