Skip to content

Commit

Permalink
Update base for Update on "make sure dynamo doesn't inline DTensor __…
Browse files Browse the repository at this point in the history
…new__ or __torch_dispatch__"

Fixes #122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above that start some logs showing that dynamo is inlining `__new__`.

I noticed that putting `torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining __new__ ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node examples values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the __new__ at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with anijain2305, he explained that with code like this:
```
torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo




[ghstack-poisoned]
  • Loading branch information
bdhirsh committed Apr 12, 2024
2 parents 40992df + d0ccf59 commit 32bb41c
Show file tree
Hide file tree
Showing 206 changed files with 4,959 additions and 2,086 deletions.
5 changes: 3 additions & 2 deletions .ci/docker/requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ librosa>=0.6.2 ; python_version < "3.11"
#Pinned versions:
#test that import:

mypy==1.8.0
mypy==1.9.0
# Pin MyPy version because new errors are likely to appear with each release
#Description: linter
#Pinned versions: 1.8.0
#Pinned versions: 1.9.0
#test that import: test_typing.py, test_type_hints.py

networkx==2.8.8
Expand Down Expand Up @@ -231,6 +231,7 @@ scikit-image==0.20.0 ; python_version >= "3.10"
scipy==1.6.3 ; python_version < "3.10"
scipy==1.8.1 ; python_version == "3.10"
scipy==1.10.1 ; python_version == "3.11"
scipy==1.12.0 ; python_version == "3.12"
# Pin SciPy because of failing distribution tests (see #60347)
#Description: scientific python
#Pinned versions: 1.6.3
Expand Down
4 changes: 4 additions & 0 deletions .ci/onnx/common.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#!/bin/bash

set -ex

source "$(dirname "${BASH_SOURCE[0]}")/../pytorch/common_utils.sh"

LOCAL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
ROOT_DIR=$(cd "$LOCAL_DIR"/../.. && pwd)
TEST_DIR="$ROOT_DIR/test"
Expand Down
14 changes: 14 additions & 0 deletions .ci/onnx/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@
# shellcheck source=./common.sh
source "$(dirname "${BASH_SOURCE[0]}")/common.sh"

# Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96)
WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace")
cleanup_workspace() {
echo "sudo may print the following warning message that can be ignored. The chown command will still run."
echo " sudo: setrlimit(RLIMIT_STACK): Operation not permitted"
echo "For more details refer to https://github.com/sudo-project/sudo/issues/42"
sudo chown -R "$WORKSPACE_ORIGINAL_OWNER_ID" /var/lib/jenkins/workspace
}
# Disable shellcheck SC2064 as we want to parse the original owner immediately.
# shellcheck disable=SC2064
trap_add cleanup_workspace EXIT
sudo chown -R jenkins /var/lib/jenkins/workspace
git config --global --add safe.directory /var/lib/jenkins/workspace

if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# TODO: This can be removed later once vision is also part of the Docker image
pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
Expand Down
1 change: 1 addition & 0 deletions .ci/pytorch/python_doc_push_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ if [ "$is_main_doc" = true ]; then
echo undocumented objects found:
cat build/coverage/python.txt
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1
fi
else
Expand Down
2 changes: 1 addition & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ init_command = [
'numpy==1.24.3 ; python_version == "3.8"',
'numpy==1.26.0 ; python_version >= "3.9"',
'expecttest==0.1.6',
'mypy==1.8.0',
'mypy==1.9.0',
'sympy==1.11.1',
'types-requests==2.27.25',
'types-PyYAML==6.0.7',
Expand Down
47 changes: 47 additions & 0 deletions aten/src/ATen/FunctionalStorageImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,31 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {

~FunctionalStorageImpl() override = default;

void mark_mutation() {
mutation_counter_++;
}
void mark_mutation_during_no_grad_or_inference_mode() {
mutation_counter_during_no_grad_or_inference_mode_++;
}
void mark_mutation_hidden_from_autograd() {
mutation_counter_hidden_from_autograd_++;
}

bool are_all_mutations_under_no_grad_or_inference_mode() const {
auto non_autograd_mutations =
mutation_counter_during_no_grad_or_inference_mode_ +
mutation_counter_hidden_from_autograd_;
// The <= is because both counters will technically be incremented, if we
// perform e.g. a triton kernel mutation under no_grad
return mutation_counter_ <= non_autograd_mutations;
}

bool are_all_mutations_hidden_from_autograd() const {
// mutations under no_grad / inference_mode are technically not hidden from
// autograd - they change the version counter
return mutation_counter_ <= mutation_counter_hidden_from_autograd_;
}

private:
// NB: base_ should always point to a tensor BELOW the current
// functionalization layer. This is mainly to avoid reference cycles. e.g.
Expand All @@ -125,6 +150,28 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
// If frozen, no more mutations are allowed on this storage. Once frozen, a
// storage cannot be unfrozen.
bool frozen_ = false;

// These mutation counters are bumped on the storage
// whenever a FunctionalTensorWrapper experiences a mutation.
// When the mutation is under no_grad, or comes from a triton kernel, we also
// bump the corresponding during_no_grad or hidden_from_autograd counters. Why
// do we need to detect these two situations separately from "normal" input
// mutations? (1) "normal" input mutations can mutate autograd metadata like
// .grad_fn,
// in which case they need to be replayed outside of the compiled graph
// (2) "no_grad" input mutations are generally safe to keep in the graph (and
// compile),
// but they bump the tensor's VC, so we need to mark_dirty() on the inputs
// in torch.compile
// (3) mutations that are fully hidden from autograd (e.g. from a triton
// kernel)
// do not mutate any autograd state, and be fully kept in the graph
// When we detect that an input was mutated, we need to be able to tell if:
// (1) all of the mutations were from triton kernels
// (2) all of the mutations were under no_grad
uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0;
uint64_t mutation_counter_ = 0;
uint64_t mutation_counter_hidden_from_autograd_ = 0;
};

} // namespace at::functionalization
24 changes: 15 additions & 9 deletions aten/src/ATen/FunctionalTensorWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::View
// In the above, tmp is a batched tensor (because adding a normal tensor to a batched tensor does broadcasting and creates a batched tensor).
// But we can't just replace the underlying memory backing `tensor` with `tmp` - a batched tensor takes up more space!
// Instead, every input, intermediate and output of the program is wrapped in a FunctionalTensorImpl, which wraps the underlying tensor.
void FunctionalTensorWrapper::replace_(const Tensor& other) {
void FunctionalTensorWrapper::replace_(const Tensor& other, bool from_lazy_regenerate) {
// TODO: going to need to change this if we want nested functionalize() transforms.
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
value_ = other;
Expand All @@ -231,10 +231,19 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
mutation_counter_++;
if (!at::GradMode::is_enabled() || InferenceMode::is_enabled()) {
// This mutation happened under no_grad or inference_mode
mark_mutation_during_no_grad_or_inference_mode();
// might not be until after the no_grad region is exited.
// Therefore, replace_() is not unconditionally safe to check the current no_grad state.
// If this is a lazy regeneration, then it is guaranteed that we have already
// done the mutation for the storage alias (when we originally performed the mutation),
// so no counter update may be needed.
// Example: if a mutation happens to a view under a no_grad,
// we won't call replace_() on the other alias until the alias is later used, which
if (!from_lazy_regenerate) {
mark_mutation();
if (!at::GradMode::is_enabled() || InferenceMode::is_enabled()) {
// This mutation happened under no_grad or inference_mode
mark_mutation_during_no_grad_or_inference_mode();
}
}
}

Expand Down Expand Up @@ -338,7 +347,7 @@ void FunctionalTensorWrapper::regenerate_from_base() {
t = view_meta.forward_fn(t, view_meta.out_index);
}
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
replace_(t);
replace_(t, /*from_lazy_regenerate=*/true);
generation_ = storage_impl->generation();
}

Expand Down Expand Up @@ -366,9 +375,6 @@ void FunctionalTensorWrapper::copy_tensor_metadata(
// FunctionalTensorWrapper-specific fields.
dest_impl->value_ = src_impl->value_;
dest_impl->level_ = src_impl->level_;
dest_impl->mutation_counter_ = src_impl->mutation_counter_;
dest_impl->mutation_hidden_from_autograd_counter_ = src_impl->mutation_hidden_from_autograd_counter_;
dest_impl->mutation_during_no_grad_or_inference_mode_ = src_impl->mutation_during_no_grad_or_inference_mode_;
dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
Expand Down
19 changes: 9 additions & 10 deletions aten/src/ATen/FunctionalTensorWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,26 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
return has_metadata_mutation_;
};

void mark_mutation() {
functional_storage_impl()->mark_mutation();
}
// Denotes a mutation that's hidden from autograd,
// e.g. for the purposes of passing a tensor to a triton kernel
void mark_mutation_hidden_from_autograd() {
mutation_hidden_from_autograd_counter_++;
functional_storage_impl()->mark_mutation_hidden_from_autograd();
}
void mark_mutation_during_no_grad_or_inference_mode() {
mutation_during_no_grad_or_inference_mode_++;
functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
}
// Are all the mutations happening to the tensor hidden from autograd
bool are_all_mutations_hidden_from_autograd() const {
return mutation_hidden_from_autograd_counter_ == mutation_counter_;
return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
}
// Did all mutations happen under no_grad or inference_mode
// (We also need to ignore mutations fully hidden from autograd here)
bool are_all_mutations_under_no_grad_or_inference_mode() const {
return mutation_hidden_from_autograd_counter_ +
mutation_during_no_grad_or_inference_mode_ ==
mutation_counter_;
return functional_storage_impl()
->are_all_mutations_under_no_grad_or_inference_mode();
}

// Sync's the underlying tensor with its alias, if it's out of date. This
Expand Down Expand Up @@ -156,7 +158,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// a.replace_(tmp)
//
// replace_() swaps out the wrapped tensor, value_, with tmp.
void replace_(const Tensor& other);
void replace_(const Tensor& other, bool from_lazy_regenerate = false);

bool is_multi_output_view() {
return is_multi_output_view_;
Expand Down Expand Up @@ -227,9 +229,6 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// not. If we have an input mutation that is hidden from autograd, then once
// we convert the input mutation to a copy_() we know it will be safe to hide
// the copy_() from autograd as well.
uint64_t mutation_counter_ = 0;
uint64_t mutation_hidden_from_autograd_counter_ = 0;
uint64_t mutation_during_no_grad_or_inference_mode_ = 0;
bool has_metadata_mutation_ = false;
bool is_multi_output_view_ = false;
// Did the tensor experience a set_() call.
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/FunctionalizeFallbackKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,15 @@ static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src),
"set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed");

TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
"set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");

// nop case
if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) {
at::AutoDispatchSkipFunctionalize guard;
return self.set_(src);
}

TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
"set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");

TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
Expand Down
30 changes: 30 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vec256_convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,36 @@ struct VecConvert<int32_t, 1, uint8_t, 1> {
}
};

template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
1,
typename std::enable_if_t<
std::is_same_v<dst_t, unsigned char> || std::is_same_v<dst_t, signed char>,
void>> {
static inline VectorizedN<dst_t, 1> apply(
const VectorizedN<float, 1>& src) {
return convert_float_to_int8<dst_t>(src[0]);
}
};

template <typename src_t>
struct VecConvert<
float,
1,
src_t,
1,
typename std::enable_if_t<
std::is_same_v<src_t, unsigned char> || std::is_same_v<src_t, signed char>,
void>> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<src_t, 1>& src) {
return convert_int8_to_float<src_t>(src[0]);
}
};

template <typename dst_t>
struct VecConvert<
dst_t,
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,9 @@ class Vectorized<c10::Half> {
}

Vectorized<c10::Half> operator!=(const Vectorized<c10::Half>& other) const {
float32x4_t r0 = vreinterpretq_f16_u16(
float16x8_t r0 = vreinterpretq_f16_u16(
vmvnq_u16(vceqq_f16(values.val[0], other.values.val[0])));
float32x4_t r1 = vreinterpretq_f16_u16(
float16x8_t r1 = vreinterpretq_f16_u16(
vmvnq_u16(vceqq_f16(values.val[1], other.values.val[1])));
return Vectorized<c10::Half>(r0, r1);
}
Expand Down
30 changes: 30 additions & 0 deletions aten/src/ATen/cpu/vec/vec512/vec512_convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,36 @@ struct VecConvert<int32_t, 1, uint8_t, 1> {
}
};

template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
1,
typename std::enable_if_t<
std::is_same_v<dst_t, unsigned char> || std::is_same_v<dst_t, signed char>,
void>> {
static inline VectorizedN<dst_t, 1> apply(
const VectorizedN<float, 1>& src) {
return convert_float_to_int8<dst_t>(src[0]);
}
};

template <typename src_t>
struct VecConvert<
float,
1,
src_t,
1,
typename std::enable_if_t<
std::is_same_v<src_t, unsigned char> || std::is_same_v<src_t, signed char>,
void>> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<src_t, 1>& src) {
return convert_int8_to_float<src_t>(src[0]);
}
};

template <typename dst_t>
struct VecConvert<
dst_t,
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/cuda/cub-RadixSortKeys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,8 @@ void radix_sort_keys(
int64_t end_bit);

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES)
AT_INSTATIATE_CUB_TEMPLATES(uint16_t, UInt16)
AT_INSTATIATE_CUB_TEMPLATES(uint32_t, UInt32)
AT_INSTATIATE_CUB_TEMPLATES(uint64_t, UInt64)

} // namespace at::cuda::cub
3 changes: 3 additions & 0 deletions aten/src/ATen/cuda/cub-RadixSortPairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)
AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8)

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8)

// BFloat16 Radix sort is supported from ROCm 4.5 onwards
#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)
Expand Down
Loading

0 comments on commit 32bb41c

Please sign in to comment.