2024-11-19 nightly release (1f955b5)
pytorchbot committed Nov 19, 2024
1 parent 3183564 commit 8c79542
Showing 20 changed files with 301 additions and 87 deletions.
7 changes: 4 additions & 3 deletions .github/scripts/filter.py
@@ -11,8 +11,8 @@

def main():
"""
Since FBGEMM doesn't publish CUDA 12 binaries, torchrec will not work with
CUDA 12. As a result, we filter out CUDA 12 from the build matrix that
Since FBGEMM doesn't publish CUDA 12.6 binaries yet, torchrec will not work with
CUDA 12.6. As a result, we filter out CUDA 12.6 from the build matrix that
determines which nightly builds are run.
"""

@@ -22,7 +22,8 @@ def main():
new_matrix_entries = []

for entry in full_matrix["include"]:
new_matrix_entries.append(entry)
if entry["gpu_arch_version"] != "12.6":
new_matrix_entries.append(entry)

new_matrix = {"include": new_matrix_entries}
print(json.dumps(new_matrix))
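For illustration, a minimal sketch of what the filter now produces; the matrix below is hypothetical (the real one is supplied by the CI pipeline), but the key name matches the script.

# Illustrative sketch, not part of the diff: a hypothetical build matrix
# filtered the same way filter.py now filters out CUDA 12.6.
import json

full_matrix = {
    "include": [
        {"python_version": "3.9", "gpu_arch_version": "12.4"},
        {"python_version": "3.9", "gpu_arch_version": "12.6"},
    ]
}

new_matrix_entries = [
    entry for entry in full_matrix["include"]
    if entry["gpu_arch_version"] != "12.6"
]
print(json.dumps({"include": new_matrix_entries}))
# Only the 12.4 entry survives; 12.6 is dropped until FBGEMM publishes CUDA 12.6 binaries.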
7 changes: 4 additions & 3 deletions .github/workflows/build_dynamic_embedding_wheels.yml
@@ -20,8 +20,8 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-latest ]
pyver: [ cp39, cp310 ]
cuver: [ "11.8" ]
pyver: [ cp39, cp310, cp311, cp312 ]
cuver: [ "12.1", "12.4"]

steps:
-
@@ -41,13 +41,14 @@
with:
submodules: recursive

- uses: pypa/cibuildwheel@v2.19.2
- uses: pypa/cibuildwheel@v2.20.0
with:
package-dir: contrib/dynamic_embedding
env:
CIBW_BEFORE_BUILD: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/before_linux_build.sh"
CIBW_BUILD: "${{ matrix.pyver }}-manylinux_x86_64"
CIBW_REPAIR_WHEEL_COMMAND: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/repair_wheel.sh {wheel} {dest_dir}"
CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28"

- name: Verify clean directory
run: git diff --exit-code
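A rough sketch of how the widened matrix fans out; the job listing below is illustrative, but the environment-variable wiring mirrors the workflow.

# Illustrative sketch, not part of the diff: the expanded matrix yields
# 4 Python versions x 2 CUDA versions = 8 cibuildwheel jobs per run.
from itertools import product

pyvers = ["cp39", "cp310", "cp311", "cp312"]
cuvers = ["12.1", "12.4"]

for pyver, cuver in product(pyvers, cuvers):
    print(f"CUDA_VERSION={cuver} CIBW_BUILD={pyver}-manylinux_x86_64")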
1 change: 1 addition & 0 deletions contrib/dynamic_embedding/tools/before_linux_build.sh
@@ -14,6 +14,7 @@ CUDA_VERSION="${CUDA_VERSION:-11.8}"
CUDA_MAJOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $1}')
CUDA_MINOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $2}')

yum install -y yum-utils
yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-$distro.repo
yum install -y \
cuda-toolkit-"${CUDA_MAJOR_VERSION}"-"${CUDA_MINOR_VERSION}" \
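The major/minor split above drives the yum package name; here is a Python sketch of the same transformation (the version value is illustrative).

# Illustrative sketch, not part of the diff: how CUDA_VERSION maps to the
# cuda-toolkit package name installed by the script.
cuda_version = "12.4"                      # e.g. from the workflow matrix
major, minor = cuda_version.split(".")[:2]
print(f"cuda-toolkit-{major}-{minor}")     # -> cuda-toolkit-12-4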
2 changes: 2 additions & 0 deletions contrib/dynamic_embedding/tools/build_wheels.sh
@@ -6,6 +6,8 @@ export CIBW_BEFORE_BUILD="tools/before_linux_build.sh"
# all kinds of CPython.
export CIBW_BUILD=${CIBW_BUILD:-"cp39-manylinux_x86_64"}

export CIBW_MANYLINUX_X86_64_IMAGE=${CIBW_MANYLINUX_X86_64_IMAGE:-"manylinux_2_28"}

# Do not run auditwheel since tde uses torch's shared libraries.
export CIBW_REPAIR_WHEEL_COMMAND="tools/repair_wheel.sh {wheel} {dest_dir}"

1 change: 1 addition & 0 deletions install-requirements.txt
@@ -1,4 +1,5 @@
fbgemm-gpu
tensordict
torchmetrics==1.0.3
tqdm
pyre-extensions
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,6 +7,7 @@ numpy
pandas
pyre-extensions
scikit-build
tensordict
torchmetrics==1.0.3
torchx
tqdm
@@ -9,6 +9,8 @@

#!/usr/bin/env python3

from typing import Dict, List

import click

import torch
@@ -82,9 +84,10 @@ def op_bench(
)

def _func_to_benchmark(
kjt: KeyedJaggedTensor,
kjts: List[Dict[str, KeyedJaggedTensor]],
model: torch.nn.Module,
) -> torch.Tensor:
kjt = kjts[0]["feature"]
return model.forward(kjt.values(), kjt.offsets())

# breakpoint() # import fbvscode; fbvscode.set_trace()
@@ -108,8 +111,8 @@ def _func_to_benchmark(

result = benchmark_func(
name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
bench_inputs=inputs, # pyre-ignore
prof_inputs=inputs, # pyre-ignore
bench_inputs=[{"feature": inputs}],
prof_inputs=[{"feature": inputs}],
num_benchmarks=10,
num_profiles=10,
profile_dir=".",
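The hunks above change the benchmark's calling convention: inputs are now a list of dicts keyed by feature name rather than a bare KeyedJaggedTensor. A minimal sketch of the new shape (the KJT contents below are invented):

# Illustrative sketch, not part of the diff: the shape of bench_inputs after
# this change and how _func_to_benchmark unpacks it. Values are made up.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([10, 11, 12]),
    lengths=torch.tensor([1, 2]),
)
bench_inputs = [{"feature": kjt}]

def _func_to_benchmark(kjts, model):
    kjt = kjts[0]["feature"]
    return model.forward(kjt.values(), kjt.offsets())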
5 changes: 4 additions & 1 deletion torchrec/distributed/benchmark/benchmark_utils.py
@@ -374,11 +374,14 @@ def get_inputs(

if train:
sparse_features_by_rank = [
model_input.idlist_features for model_input in model_input_by_rank
model_input.idlist_features
for model_input in model_input_by_rank
if isinstance(model_input.idlist_features, KeyedJaggedTensor)
]
inputs_batch.append(sparse_features_by_rank)
else:
sparse_features = model_input_by_rank[0].idlist_features
assert isinstance(sparse_features, KeyedJaggedTensor)
inputs_batch.append([sparse_features])

# Transpose if train, as inputs_by_rank is currently in [B X R] format
19 changes: 17 additions & 2 deletions torchrec/distributed/embedding_sharding.py
@@ -446,6 +446,15 @@ def _prefetch_and_cached(
)


def _all_tables_are_quant_kernel(
tables: List[ShardedEmbeddingTable],
) -> bool:
"""
Returns True if all tables use the quant compute kernel.
"""
return all(table.compute_kernel == EmbeddingComputeKernel.QUANT for table in tables)


# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -489,6 +498,8 @@ def _group_tables_per_rank(
# Collect groups
groups = defaultdict(list)
grouping_keys = []
# Assumes all compute kernels within tables are the same
is_inference = _all_tables_are_quant_kernel(embedding_tables)
for table in embedding_tables:
bucketer = (
prefetch_cached_dim_bucketer
@@ -499,12 +510,16 @@
_get_grouping_fused_params(table.fused_params, table.name) or {}
)
grouping_key = (
table.data_type,
table.data_type if not is_inference else None,
table.pooling,
table.has_feature_processor,
tuple(sorted(group_fused_params.items())),
_get_compute_kernel_type(table.compute_kernel),
bucketer.get_bucket(table.local_cols, table.data_type),
# TODO: Unit test to check if table.data_type affects table grouping
bucketer.get_bucket(
table.local_cols,
table.data_type,
),
_prefetch_and_cached(table),
)
# micromanage the order in which we traverse the groups to ensure backwards compatibility
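The effect of the new _all_tables_are_quant_kernel check: in the all-QUANT (inference) case, data_type is dropped from the grouping key, so tables differing only in data type can share a group. A simplified, hypothetical sketch of that keying:

# Illustrative sketch, not part of the diff: simplified grouping key showing
# why dropping data_type in the inference case merges groups. Table attributes
# here are invented.
from collections import defaultdict

tables = [
    {"name": "t0", "data_type": "FP16", "pooling": "SUM", "kernel": "QUANT"},
    {"name": "t1", "data_type": "INT8", "pooling": "SUM", "kernel": "QUANT"},
]
is_inference = all(t["kernel"] == "QUANT" for t in tables)

groups = defaultdict(list)
for t in tables:
    key = (t["data_type"] if not is_inference else None, t["pooling"])
    groups[key].append(t["name"])
print(dict(groups))  # {(None, 'SUM'): ['t0', 't1']} -> a single group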
34 changes: 22 additions & 12 deletions torchrec/distributed/embeddingbag.py
@@ -29,7 +29,7 @@
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_sharding import (
@@ -67,6 +67,7 @@
ShardingEnv,
ShardingType,
ShardMetadata,
TensorProperties,
)
from torchrec.distributed.utils import (
add_params_from_parameter_sharding,
@@ -98,13 +99,6 @@
pass


# OSS
try:
pass
except ImportError:
pass


def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
return (
tensor
@@ -938,11 +932,27 @@ def _initialize_torch_state(self) -> None:  # noqa
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
sharding_spec = none_throws(
self.module_sharding_plan[table_name].sharding_spec
)
metadata = sharding_spec.build_metadata(
tensor_sizes=self._name_to_table_size[table_name],
tensor_properties=(
TensorProperties(
dtype=local_shards[0].tensor.dtype,
layout=local_shards[0].tensor.layout,
requires_grad=local_shards[0].tensor.requires_grad,
)
if local_shards
else TensorProperties()
),
)

self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=local_shards,
sharded_tensor_metadata=metadata,
process_group=none_throws(self._env.process_group),
)
)

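The switch from _init_from_local_shards to _init_from_local_shards_and_global_metadata means the global metadata is built locally from the table's sharding spec rather than derived from the local shards at init time. A minimal, hypothetical sketch of that metadata construction (spec, sizes, and dtype below are invented):

# Illustrative sketch, not part of the diff: building ShardedTensor metadata
# from a sharding spec plus TensorProperties, with no collective needed.
import torch
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
props = TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False)
metadata = spec.build_metadata(
    tensor_sizes=torch.Size([8, 4]),
    tensor_properties=props,
)
print([s.shard_sizes for s in metadata.shards_metadata])  # [[4, 4], [4, 4]]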
43 changes: 23 additions & 20 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -216,27 +216,30 @@ def __init__(
self._runtime_device: torch.device = _get_runtime_device(device, config)
# 16 for CUDA, 1 for others like CPU and MTIA.
self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
embedding_specs = []
for local_rows, local_cols, table, location in zip(
self._local_rows,
self._local_cols,
config.embedding_tables,
managed,
):
embedding_specs.append(
(
table.name,
local_rows,
(
local_cols
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(table.data_type),
location,
)
)

self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(
table.name,
local_rows,
(
local_cols
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(config.data_type),
location,
)
for local_rows, local_cols, table, location in zip(
self._local_rows,
self._local_cols,
config.embedding_tables,
managed,
)
],
embedding_specs=embedding_specs,
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
@@ -411,7 +414,7 @@ def __init__(
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(config.data_type),
data_type_to_sparse_type(table.data_type),
location,
)
for local_rows, local_cols, table, location in zip(
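Both hunks in this file make the sparse type come from each table's own data_type rather than the module-wide config, which matters when tables in one group carry different data types. A small sketch of the mapping involved (table names and types are invented):

# Illustrative sketch, not part of the diff: per-table data types now map to
# per-table sparse types in the embedding specs.
from torchrec.modules.embedding_configs import DataType, data_type_to_sparse_type

table_data_types = {"table_int8": DataType.INT8, "table_fp16": DataType.FP16}
for name, data_type in table_data_types.items():
    print(name, data_type_to_sparse_type(data_type))
# With the old config.data_type, both tables would have received the same sparse type.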
4 changes: 3 additions & 1 deletion torchrec/distributed/test_utils/infer_utils.py
@@ -262,6 +262,7 @@ def model_input_to_forward_args_kjt(
Optional[torch.Tensor],
]:
kjt = mi.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
return (
kjt._keys,
kjt._values,
@@ -289,7 +290,8 @@ def model_input_to_forward_args(
]:
idlist_kjt = mi.idlist_features
idscore_kjt = mi.idscore_features
assert idscore_kjt is not None
assert isinstance(idlist_kjt, KeyedJaggedTensor)
assert isinstance(idscore_kjt, KeyedJaggedTensor)
return (
mi.float_features,
idlist_kjt._keys,