FSDP + SP does not work with --compile #61

Closed
tianyu-l opened this issue Feb 16, 2024 · 9 comments
Labels
bug Something isn't working

Comments

@tianyu-l
Contributor

FSDP + SP works fine when compile is off, but I get the following error when compile is on:

error log

SP=2 ./run_llama_train.sh
+ TRAINER_DIR=/home/lty/local/torchtrain
+ MODEL=llama
+ MODEL_CONF=debugmodel
+ NGPU=8
+ PP=1
+ SP=2
+ DP=-1
+ LOG_RANK=0
+ CHECKPOINT_FOLDER=
+ CHECKPOINT_INTERVAL=5
+ torchrun --nproc_per_node=8 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0 --role rank --tee 3 train.py --steps 10 --model llama --model_conf debugmodel --pp_degree 1 --sp_degree 2 --dp_degree -1 --compile --checkpoint-folder= --checkpoint-interval=5
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 *****************************************
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 *****************************************
[rank0]:2024-02-15 17:38:20,132 - torchtrain.parallelisms - INFO - Building 2-D device mesh with ('dp', 'sp'), [4, 2]
[rank0]:2024-02-15 17:38:28,308 - root - INFO - Building llama
[rank0]:2024-02-15 17:38:28,325 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model
[rank0]:2024-02-15 17:38:28,325 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model fully initialized via reset_params
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model built with: ModelArgs(dim=256, n_layers=2, n_heads=16, n_kv_heads=None, vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=32768)
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model llama debugmodel size: 18,089,216 total parameters
[rank0]:2024-02-15 17:38:31,663 - root - INFO - GPU memory usage: NVIDIA PG509-210 (0): 79.1537 GB capacity, 0.0 GB in-use, 0.0% in-use
[rank0]:NCCL version 2.19.3+cuda12.0
[rank0]:2024-02-15 17:38:36,274 - root - INFO - Applied Sequence Parallelism to the model...
[rank0]:2024-02-15 17:38:36,575 - root - INFO - Applied FSDP to the model...
[rank0]:2024-02-15 17:38:36,579 - root - INFO - Gradient scaling not enabled.
[rank0]:2024-02-15 17:38:36,579 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./torchtrain/outputs/tb/20240215-1738.
[rank0]:2024-02-15 17:38:36,580 - root - INFO - Compiling model llama with torch.compile...
[rank0]:2024-02-15 17:38:40,957 - root - INFO - Profiling active. Traces will be saved at ./torchtrain/outputs/profiling/traces
[rank0]:[rank0]:W0215 17:38:41.362000 139938524181632 torch/_logging/_internal.py:873 [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:/home/lty/pytorch/torch/_inductor/lowering.py:1704: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]: File "/home/lty/torchtrain/train.py", line 349, in <module>
[rank0]:[rank0]: main(args)
[rank0]:[rank0]: File "/home/lty/torchtrain/train.py", line 179, in main
[rank0]:[rank0]: pred = model(input_ids)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/eval_frame.py", line 455, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 482, in forward
[rank0]:[rank0]: def forward(self, tokens: torch.Tensor):
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 498, in torch_dynamo_resume_in_forward_at_493
[rank0]:[rank0]: h = layer(h, freqs_cis)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 912, in catch_errors
[rank0]:[rank0]: return callback(frame, cache_entry, hooks, frame_state, skip=1)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 777, in _convert_frame
[rank0]:[rank0]: result = inner_convert(
[rank0]:[rank0]: ^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 398, in _convert_frame_assert
[rank0]:[rank0]: return _compile(
[rank0]:[rank0]: ^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/.conda/envs/pytorch-3.11/lib/python3.11/contextlib.py", line 81, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 669, in _compile
[rank0]:[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 542, in compile_inner
[rank0]:[rank0]: out_code = transform_code_object(code, transform)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
[rank0]:[rank0]: transformations(instructions, code_options)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 163, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 507, in transform
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2130, in run
[rank0]:[rank0]: super().run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1243, in CALL_FUNCTION_EX
[rank0]:[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 734, in call_function
[rank0]:[rank0]: return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1392, in call_function
[rank0]:[rank0]: ) = self.create_wrapped_node(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1204, in create_wrapped_node
[rank0]:[rank0]: ) = speculate_subgraph(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 396, in speculate_subgraph
[rank0]:[rank0]: output = f.call_function(tx, args, sub_kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/nn_module.py", line 716, in call_function
[rank0]:[rank0]: return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1243, in CALL_FUNCTION_EX
[rank0]:[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/nn_module.py", line 716, in call_function
[rank0]:[rank0]: return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
[rank0]:[rank0]: return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/misc.py", line 547, in call_function
[rank0]:[rank0]: return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 388, in call_method
[rank0]:[rank0]: result = handler_method(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 730, in method_redistribute
[rank0]:[rank0]: return wrap_fx_proxy(
[rank0]:[rank0]: ^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/builder.py", line 1273, in wrap_fx_proxy
[rank0]:[rank0]: return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/builder.py", line 1358, in wrap_fx_proxy_cls
[rank0]:[rank0]: example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1683, in get_fake_value
[rank0]:[rank0]: raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1629, in get_fake_value
[rank0]:[rank0]: ret_val = wrap_fake_exception(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1165, in wrap_fake_exception
[rank0]:[rank0]: return fn()
[rank0]:[rank0]: ^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1630, in <lambda>
[rank0]:[rank0]: lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1750, in run_node
[rank0]:[rank0]: raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1729, in run_node
[rank0]:[rank0]: return node.target(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 723, in redistribute_fn_with_prim_types
[rank0]:[rank0]: return x.redistribute(*args_as_value, **kwargs_as_value)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/api.py", line 467, in redistribute
[rank0]:[rank0]: return Redistribute.apply(self, device_mesh, placements)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/autograd/function.py", line 572, in apply
[rank0]:[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/redistribute.py", line 263, in forward
[rank0]:[rank0]: output = redistribute_local_tensor(local_tensor, current_spec, target_spec)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/redistribute.py", line 164, in redistribute_local_tensor
[rank0]:[rank0]: transform_infos = _gen_transform_infos(current_spec, target_spec)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/placement_types.py", line 441, in __hash__
[rank0]:[rank0]: self._hash = self._hash_impl()
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/placement_types.py", line 424, in _hash_impl
[rank0]:[rank0]: return hash(
[rank0]:[rank0]: ^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/__init__.py", line 309, in __hash__
[rank0]:[rank0]: raise TypeError("unhashable type: non-singleton SymInt")
[rank0]:[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function .redistribute_fn_with_prim_types at 0x7f45431c1b20>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('sp',)), placements=(Shard(dim=0),)),), **{}):
[rank0]:[rank0]: unhashable type: non-singleton SymInt
[rank0]:
[rank0]:[rank0]: from user code:
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
[rank0]:[rank0]: return self.checkpoint_fn( # type: ignore[misc]
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 413, in forward
[rank0]:[rank0]: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1568, in _call_impl
[rank0]:[rank0]: args_result = hook(self, args)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 323, in <lambda>
[rank0]:[rank0]: module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg]
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 316, in _prepare_input_fn
[rank0]:[rank0]: dt_inp = dt_inp.redistribute(placements=(desired_layout,))
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]:
[rank0]:
[rank0]:[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:[rank0]: import torch._dynamo
[rank0]:[rank0]: torch._dynamo.config.suppress_errors = True
[rank0]:
W0215 17:39:06.601000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321633 closing signal SIGTERM
W0215 17:39:06.602000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321634 closing signal SIGTERM
W0215 17:39:06.603000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321636 closing signal SIGTERM
W0215 17:39:06.604000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321637 closing signal SIGTERM
W0215 17:39:06.605000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321638 closing signal SIGTERM
W0215 17:39:06.606000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321639 closing signal SIGTERM
W0215 17:39:06.608000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321641 closing signal SIGTERM
E0215 17:39:09.856000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:669 failed (exitcode: 1) local_rank: 0 (pid: 2321629) of binary: /home/lty/.conda/envs/pytorch-3.11/bin/python
Traceback (most recent call last):
File "/home/lty/.conda/envs/pytorch-3.11/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/run.py", line 834, in main
run(args)
File "/home/lty/pytorch/torch/distributed/run.py", line 825, in run
elastic_launch(
File "/home/lty/pytorch/torch/distributed/launcher/api.py", line 137, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/launcher/api.py", line 271, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-02-15_17:39:06
host : devgpu051.cln3.facebook.com
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 2321629)
error_file:
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
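For reading the log above: with NGPU=8, SP=2 and DP=-1 the trainer reports "Building 2-D device mesh with ('dp', 'sp'), [4, 2]", and the failing DTensor lives on rank 0's SP submesh, DeviceMesh([0, 1], mesh_dim_names=('sp',)). Below is a minimal pure-Python sketch of that layout. It assumes that DP=-1 means "use every GPU not claimed by SP" and that ranks are laid out row-major on the ('dp', 'sp') mesh; both assumptions match the logged shapes but are not taken from the torchtrain source, and `infer_dp_degree` / `mesh_groups` are hypothetical helpers.

```python
from typing import Dict, List


def infer_dp_degree(world_size: int, sp_degree: int, dp_degree: int = -1) -> int:
    """Resolve dp_degree=-1 to 'all GPUs not used by SP' (assumed semantics)."""
    if world_size % sp_degree != 0:
        raise ValueError("world_size must be divisible by sp_degree")
    if dp_degree == -1:
        dp_degree = world_size // sp_degree
    if dp_degree * sp_degree != world_size:
        raise ValueError("dp_degree * sp_degree must equal world_size")
    return dp_degree


def mesh_groups(world_size: int, sp_degree: int, dp_degree: int = -1) -> Dict[str, List[List[int]]]:
    """Lay ranks out row-major on a ('dp', 'sp') mesh and list the groups."""
    dp = infer_dp_degree(world_size, sp_degree, dp_degree)
    # mesh[i][j] is the rank at dp index i and sp index j.
    mesh = [[i * sp_degree + j for j in range(sp_degree)] for i in range(dp)]
    # Ranks in the same row form one SP group.
    sp_groups = [row[:] for row in mesh]
    # Ranks in the same column form one DP (FSDP) group.
    dp_groups = [[mesh[i][j] for i in range(dp)] for j in range(sp_degree)]
    return {"mesh": mesh, "sp": sp_groups, "dp": dp_groups}


if __name__ == "__main__":
    groups = mesh_groups(world_size=8, sp_degree=2)  # NGPU=8, SP=2, DP=-1
    print(groups["mesh"])   # [[0, 1], [2, 3], [4, 5], [6, 7]], i.e. shape [4, 2]
    print(groups["sp"][0])  # [0, 1], rank 0's SP group
    print(groups["dp"][0])  # [0, 2, 4, 6], rank 0's DP group
```

Under these assumptions, rank 0's SP group is [0, 1], which agrees with the SP submesh printed in the error, and each DP group strides across the SP pairs.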
@wanchaol
Contributor

Yeah, this still doesn't work at the moment. We compile outside of the FSDP wrapping, and I think that triggers a couple of issues:

  1. This issue is specifically about dynamic shapes: after a graph break we hit dynamic shapes in each subgraph, which is not ideal.
  2. If I set dynamic=False, I then hit a new issue where Dynamo incorrectly traces into DTensor somewhere.
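Point 1 lines up with the bottom of the traceback in the original report: the DTensor's local tensor has a symbolic leading size (`size=(s0, 256)`), the spec's `__hash__` is triggered directly at the `_gen_transform_infos(current_spec, target_spec)` call (with no frame inside that function, which is consistent with a cache decorator hashing its arguments), and `SymInt.__hash__` in `torch/__init__.py` raises for non-singleton symbols. The sketch below models only that failure mode in plain Python; `SymbolicDim`, `Spec` and `plan_redistribute` are hypothetical stand-ins, not the DTensor API, and the caching is an assumption.

```python
import functools
from dataclasses import dataclass
from typing import Tuple, Union


class SymbolicDim:
    """Hypothetical stand-in for a non-singleton SymInt such as s0.

    Like the SymInt in the traceback, it refuses to be hashed.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name

    def __hash__(self) -> int:
        raise TypeError("unhashable type: non-singleton SymInt")


Dim = Union[int, SymbolicDim]


@dataclass(frozen=True)
class Spec:
    """Hypothetical, heavily simplified stand-in for a DTensor spec."""

    shape: Tuple[Dim, ...]
    placement: str  # e.g. "Shard(0)" or "Replicate()"


@functools.lru_cache(maxsize=None)
def plan_redistribute(src: Spec, dst: Spec) -> Tuple[str, str]:
    # lru_cache hashes (src, dst) to build its cache key; the generated
    # dataclass __hash__ covers every field, including each entry of shape.
    return (src.placement, dst.placement)


if __name__ == "__main__":
    # Integer shapes hash fine, so the cached lookup succeeds.
    print(plan_redistribute(Spec((8, 256), "Shard(0)"), Spec((8, 256), "Replicate()")))

    # A symbolic leading dim, as in size=(s0, 256): hashing fails before the body runs.
    s0 = SymbolicDim("s0")
    try:
        plan_redistribute(Spec((s0, 256), "Shard(0)"), Spec((s0, 256), "Replicate()"))
    except TypeError as err:
        print(f"TypeError: {err}")
```

Running it prints `('Shard(0)', 'Replicate()')` followed by `TypeError: unhashable type: non-singleton SymInt`, mirroring how a cache keyed on shape-carrying specs breaks once Dynamo makes a dimension symbolic.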

cc @bdhirsh: we probably need to look into the dynamic shape issue if we want 2D parallelism to work with torch.compile in its "default" setting.

@wanchaol
Contributor

@bdhirsh, steps to repro:

  1. After setting up the repo, on a devgpu with 8 GPUs, change the SP degree to 2 or 4: https://github.com/pytorch-labs/torchtrain/blob/main/run_llama_train.sh#L15 (or set it inline, e.g. SP=2 ./run_llama_train.sh, as in the original error log).
  2. Run ./run_llama_train.sh

You should then be able to hit the issues.

@drisspg
Contributor

drisspg commented Mar 20, 2024

Re-pinging here. I am seeing:

[rank0]:WARNING: Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:NCCL version 2.20.5+cuda12.4
[rank0]:/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_inductor/lowering.py:1789: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/home/drisspg/meta/torchtrain/train.py", line 361, in <module>
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/home/drisspg/meta/torchtrain/train.py", line 247, in main
[rank0]:[rank0]:     pred = model(input_ids)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 390, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 857, in forward
[rank0]:[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/meta/torchtrain/torchtrain/models/llama/model.py", line 504, in forward
[rank0]:[rank0]:     h, freqs_cis = self.embeddings(tokens)
[rank0]:[rank0]:   File "/home/drisspg/meta/torchtrain/torchtrain/models/llama/model.py", line 515, in torch_dynamo_resume_in_forward_at_504
[rank0]:[rank0]:     h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 279, in __torch_dispatch__
[rank0]:[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 229, in dispatch
[rank0]:[rank0]:     return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 368, in wrap
[rank0]:[rank0]:     return dtensor.DTensor(
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 229, in __new__
[rank0]:[rank0]:     r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 939, in catch_errors
[rank0]:[rank0]:     return callback(frame, cache_entry, hooks, frame_state, skip=1)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 802, in _convert_frame
[rank0]:[rank0]:     result = inner_convert(
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
[rank0]:[rank0]:     return _compile(
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]:     return func(*args, **kwds)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 713, in _compile
[rank0]:[rank0]:     raise InternalTorchDynamoError(str(e)).with_traceback(
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile
[rank0]:[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 264, in time_wrapper
[rank0]:[rank0]:     r = func(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 541, in compile_inner
[rank0]:[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
[rank0]:[rank0]:     transformations(instructions, code_options)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 503, in transform
[rank0]:[rank0]:     tracer.run()
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2152, in run
[rank0]:[rank0]:     super().run()
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 850, in run
[rank0]:[rank0]:     while self.step():
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 764, in step
[rank0]:[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 919, in STORE_FAST
[rank0]:[rank0]:     loaded_vt.set_name_hint(name)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 91, in realize_and_forward
[rank0]:[rank0]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 57, in realize
[rank0]:[rank0]:     self._cache.realize()
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 24, in realize
[rank0]:[rank0]:     self.vt = VariableBuilder(tx, self.source)(self.value)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 274, in __call__
[rank0]:[rank0]:     vt = self._wrap(value)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 424, in _wrap
[rank0]:[rank0]:     return self.wrap_tensor(value)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1047, in wrap_tensor
[rank0]:[rank0]:     self.assert_not_wrapped_by_this_graph(value)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 978, in assert_not_wrapped_by_this_graph
[rank0]:[rank0]:     if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 123, in is_fake
[rank0]:[rank0]:     attrs, _ = type(x).__tensor_flatten__(x)
[rank0]:[rank0]:   File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 256, in __tensor_flatten__
[rank0]:[rank0]:     return ["_local_tensor"], (self._spec, self.requires_grad)
[rank0]:[rank0]: torch._dynamo.exc.InternalTorchDynamoError: 'DTensor' object has no attribute '_spec'
[rank0]:
[rank0]:[rank0]: from user code:
[rank0]:[rank0]:    File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 229, in torch_dynamo_resume_in___new___at_229
[rank0]:[rank0]:     r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]:
[rank0]:
[rank0]:[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:[rank0]:     import torch._dynamo
[rank0]:[rank0]:     torch._dynamo.config.suppress_errors = True

cc @bdhirsh

@bdhirsh

bdhirsh commented Apr 4, 2024

The stack at pytorch/pytorch#123347 looks like it's finally enough to get the torchtrain repro working, with these changes:

diff --git a/train.py b/train.py
index 849ae78..171842a 100644
--- a/train.py
+++ b/train.py
@@ -221,7 +221,7 @@ def main(job_config: JobConfig):
                 True
             )
         logger.info("Compiling model with torch.compile")
-        model = torch.compile(model)
+        model = torch.compile(model, backend='inductor', dynamic=False)

     train_state = TrainState()

diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index c84407c..88192fc 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -37,10 +37,10 @@ warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 pipeline_parallel_degree = 1
 fp8_linear = ""
-compile = false
+compile = true
 dataset = "alpaca"   # supported datasets = alpaca (52K), minipile (1M), c4 (177M)

 [activation_checkpoint]

I had to turn off dynamic shapes. I spent some time fixing a few dynamic shape issues with DTensor in this PR: pytorch/pytorch#123349, but there are more. So for now, maybe we can run all of our torch.compile testing (e.g. with Float8, cc @wanchaol @drisspg @vkuzo) with dynamic=False, and kick the can on dynamic shapes a bit longer (hopefully I'll have more time to keep looking at this).
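A toy model of the trade-off behind `dynamic=False`, in plain Python. This is not dynamo's real guard system, and `ToyCompileCache` is a hypothetical name. With static shapes, every distinct input shape acts as a new cache key and forces a recompile, which typically costs little when every training step uses the same batch size and sequence length. With dynamic shapes, one symbolic graph can serve every size, but only if everything traced (DTensor included) copes with symbolic sizes:

```python
class ToyCompileCache:
    """Hypothetical stand-in for a compiler's graph cache (not dynamo's real guards)."""

    def __init__(self, dynamic: bool):
        self.dynamic = dynamic
        self.graphs = {}   # cache key -> label of the "compiled graph"
        self.compiles = 0  # number of (re)compilations so far

    def _key(self, shape):
        # Static: guard on the exact shape.
        # Dynamic: guard only on rank, treating every dimension as symbolic.
        return ("rank", len(shape)) if self.dynamic else ("shape", tuple(shape))

    def run(self, shape):
        key = self._key(shape)
        if key not in self.graphs:
            self.compiles += 1
            self.graphs[key] = f"graph#{self.compiles}"
        return self.graphs[key]


shapes = [(4, 128), (4, 256), (4, 128), (2, 256)]
static = ToyCompileCache(dynamic=False)
dynamic = ToyCompileCache(dynamic=True)
for s in shapes:
    static.run(s)
    dynamic.run(s)

print(static.compiles, dynamic.compiles)  # 3 1
```

With fixed training shapes the static cache compiles once and stays warm, which is why `dynamic=False` is a reasonable default for testing while the DTensor dynamic-shape issues are worked out.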

bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 10, 2024
…es in a few places"

This was the result of me slogging through errors from running this repro: pytorch/torchtitan#61.

This doesn't get dynamic shapes fully working, but it does fix some obvious errors I hit.




[ghstack-poisoned]
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 10, 2024
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 10, 2024
…new__ or __torch_dispatch__"

Fixes #122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above, and the logs show that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining `__new__` ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.
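A minimal pure-Python sketch of that failure mode, assuming a hypothetical `ToySubclass` whose `flatten` method plays the role of `__tensor_flatten__`. If anything inspects the object inside `__new__` before its fields are assigned, the attribute lookup fails in the same way as the `'DTensor' object has no attribute '_spec'` error reported earlier in this issue:

```python
class ToySubclass:
    """Hypothetical stand-in for a tensor subclass such as DTensor."""

    def __new__(cls, local, spec, observer=None):
        obj = object.__new__(cls)  # the object exists, but has no fields yet
        if observer is not None:
            observer(obj)  # e.g. a tracer inspecting the half-built object
        obj._local_tensor = local
        obj._spec = spec
        return obj

    def flatten(self):
        # Mirrors the return shape of DTensor.__tensor_flatten__ from the traceback.
        return ["_local_tensor"], self._spec


ok = ToySubclass("local", "spec")
print(ok.flatten())  # (['_local_tensor'], 'spec')

try:
    ToySubclass("local", "spec", observer=lambda o: o.flatten())
except AttributeError as e:
    print(e)  # 'ToySubclass' object has no attribute '_spec'
```

Once construction finishes, the same call succeeds; the problem is only that an observer ran in the window between "created" and "initialized".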

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node example values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the `__new__` at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with anijain2305, he explained that with code like this:
```
import torch

@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)  # SubclassConstructor: any tensor subclass, e.g. DTensor
    return out
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands control back to CPython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor.__new__`
(4) this is a new frame, which CPython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`.
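The four steps above can be mimicked in plain Python. This is only an analogy: dynamo intercepts frames through CPython's frame-evaluation hook (PEP 523), while this sketch uses `sys.setprofile`, which likewise fires once for every new Python frame. `disable`, `Plain`, `Guarded`, `make_plain`, `make_guarded` and `frames_seen` are hypothetical names. Skipping the frame for the outer function does not hide the separate `__new__` frame that the constructor call opens; only marking `__new__` itself does:

```python
import sys

_SKIPPED = set()  # code objects the toy hook must ignore


def disable(fn):
    """Toy analogue of disable(recursive=False): skip only fn's own frame."""
    _SKIPPED.add(fn.__code__)
    return fn


class Plain:
    def __new__(cls, x):
        obj = super().__new__(cls)
        obj.x = x
        return obj


class Guarded:
    @disable  # analogous to disabling DTensor.__new__ directly
    def __new__(cls, x):
        obj = super().__new__(cls)
        obj.x = x
        return obj


@disable
def make_plain(x):
    return Plain(x)  # runs eagerly, but still opens a new __new__ frame


@disable
def make_guarded(x):
    return Guarded(x)


def frames_seen(fn, *args):
    """Names of Python frames the toy hook would try to 'compile'."""
    seen = []

    def hook(frame, event, arg):
        # "call" fires only for new Python frames; builtins report "c_call".
        if event == "call" and frame.f_code not in _SKIPPED:
            seen.append(frame.f_code.co_name)

    sys.setprofile(hook)
    try:
        fn(*args)
    finally:
        sys.setprofile(None)
    return seen


print(frames_seen(make_plain, 1))    # ['__new__']
print(frames_seen(make_guarded, 1))  # []
```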

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo.




[ghstack-poisoned]
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 10, 2024
…_dispatch__"

bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 11, 2024
…new__ or __torch_dispatch__"

bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 11, 2024
…_dispatch__"

bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 12, 2024
…new__ or __torch_dispatch__"

bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 12, 2024
…_dispatch__"

Fixes #122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above that start some logs showing that dynamo is inlining `__new__`.

I noticed that putting `torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining __new__ ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node examples values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the __new__ at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with anijain2305, he explained that with code like this:
```
torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`). In theory, if you were to inline a `torch.*` op where one of the inputs is a subclass, the next frame would likely be `__torch_dispatch__`. However, dynamo always treats `torch.*` operations as non-inlinable, so we should never see dynamo inline `__torch_dispatch__`. Empirically, though, adding a `_dynamo.disable()` fixes the problem.

I asked Animesh whether dynamo could apply this behavior to subclasses automatically instead of requiring it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo.




[ghstack-poisoned]
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Apr 15, 2024
…#123347)

Fixes #122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above, and the logs leading up to them show dynamo inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.
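Why decorating `__new__` itself (and not only its caller) helps can be illustrated with a toy frame filter. Everything here is hypothetical: `toy_disable` marks one function's code object as skipped, mimicking a non-recursive disable (the real `torch._dynamo.disable` defaults to `recursive=True`, which also skips nested frames), and the profiler again stands in for dynamo's frame hook.

```python
import sys

# Code objects the toy tracer refuses to trace (hypothetical, not a torch API).
_SKIPPED_CODE = set()


def toy_disable(fn):
    """Non-recursive stand-in: only fn's own frame is skipped."""
    _SKIPPED_CODE.add(fn.__code__)
    return fn


def frames_traced(fn, *args):
    """Run fn(*args); return names of entered frames the toy tracer would trace."""
    traced = []

    def hook(frame, event, arg):
        if event == "call" and frame.f_code not in _SKIPPED_CODE:
            traced.append(frame.f_code.co_name)

    sys.setprofile(hook)
    try:
        fn(*args)
    finally:
        sys.setprofile(None)
    return traced


class Plain:
    def __new__(cls, inner):
        obj = super().__new__(cls)
        obj.inner = inner
        return obj


class Guarded:
    @toy_disable  # analogous to decorating DTensor.__new__
    def __new__(cls, inner):
        obj = super().__new__(cls)
        obj.inner = inner
        return obj


@toy_disable
def f_plain(x):
    return Plain(x)


@toy_disable
def f_guarded(x):
    return Guarded(x)


print(frames_traced(f_plain, 1))    # ['__new__']
print(frames_traced(f_guarded, 1))  # []
```

Skipping the caller alone still leaves `__new__` visible; tagging `__new__` removes it from tracing.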

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining `__new__` ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node example values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace `__new__` at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor.__new__`
(4) this is a new frame, which cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`.

None of the above explains the story for `__torch_dispatch__`, though. Empirically, I have a repro in torchtrain where the dynamo logs show dynamo trying to inline `__torch_dispatch__`:
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`). In theory, if you were to inline a `torch.*` op where one of the inputs is a subclass, the next frame would likely be `__torch_dispatch__`. However, dynamo always treats `torch.*` operations as non-inlinable, so we should never see dynamo inline `__torch_dispatch__`. Empirically, though, adding a `_dynamo.disable()` fixes the problem.

I asked Animesh whether dynamo could apply this behavior to subclasses automatically instead of requiring it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo.

Pull Request resolved: #123347
Approved by: https://github.com/zou3519
ghstack dependencies: #122502, #122751, #123348
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 18, 2024
…es in a few places"

This was the result of me slogging through errors from running this repro: pytorch/torchtitan#61.

This doesn't get dynamic shapes fully working, but it does fix some obvious errors I hit.




[ghstack-poisoned]
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 18, 2024
This was the result of me slogging through errors from running this repro: pytorch/torchtitan#61.

This doesn't get dynamic shapes fully working, but it does fix some obvious errors I hit.




[ghstack-poisoned]
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this issue Apr 22, 2024
…pytorch#123347)

Fixes pytorch#122459, pytorch/torchtitan#61

Pull Request resolved: pytorch#123347
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#122502, pytorch#122751, pytorch#123348
@tianyu-l tianyu-l added the bug Something isn't working label May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this issue May 3, 2024
…pytorch#123347)

Fixes pytorch#122459, pytorch/torchtitan#61

Pull Request resolved: pytorch#123347
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#122502, pytorch#122751, pytorch#123348
@tianyu-l
Contributor Author

closing as #268 landed -- we are using per-TransformerBlock compilation.
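Per-TransformerBlock compilation means each block becomes its own compiled unit, rather than compiling the whole model in one go. A framework-agnostic sketch of that shape follows; `compile_per_block`, `run`, `fake_compile` and the toy blocks are hypothetical names, and this is not the code from #268.

```python
from typing import Callable, Dict

Block = Callable[[float], float]


def compile_per_block(
    blocks: Dict[str, Block], compile_fn: Callable[[Block], Block]
) -> Dict[str, Block]:
    """Wrap each block with compile_fn on its own, instead of the whole model."""
    return {name: compile_fn(block) for name, block in blocks.items()}


def run(blocks: Dict[str, Block], x: float) -> float:
    # The loop that chains the blocks stays in plain, uncompiled Python.
    for block in blocks.values():
        x = block(x)
    return x


wrapped = []


def fake_compile(block: Block) -> Block:
    """Hypothetical compile_fn: records what it was given, returns it unchanged."""
    wrapped.append(block.__name__)
    return block


def block0(x: float) -> float:
    return x + 1


def block1(x: float) -> float:
    return x * 2


compiled = compile_per_block({"0": block0, "1": block1}, fake_compile)
print(run(compiled, 3.0))  # 8.0
print(wrapped)  # ['block0', 'block1']
```

With PyTorch, `compile_fn` would typically be `torch.compile` applied to each block module, with the compiled module registered back on the model in place of the original; how that registration looks depends on how the model stores its layers.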

@nighting0le01

nighting0le01 commented Jan 13, 2025

@tianyu-l hey, I see something like this for my modified model when running with torch.compile, async TP, and variable input lengths:


[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "<SOME_CODE with transformer blocks>.py", line 900, in forward
[rank0]:     x, context = block(
[rank0]:                  ^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1278, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
[rank0]:     return _compile(
[rank0]:            ^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
[rank0]:     super().run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
[rank0]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 906, in call_function
[rank0]:     return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1858, in call_function
[rank0]:     ) = self.create_wrapped_node(
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1515, in create_wrapped_node
[rank0]:     ) = speculate_subgraph(
[rank0]:         ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 462, in speculate_subgraph
[rank0]:     output = f.call_function(tx, args, sub_kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
[rank0]:     return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 326, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 326, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 387, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 326, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 387, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 326, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 1024, in call_function
[rank0]:     return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 527, in call_method
[rank0]:     result = handler_method(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 919, in method_redistribute
[rank0]:     return wrap_fx_proxy(
[rank0]:            ^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2043, in wrap_fx_proxy
[rank0]:     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2130, in wrap_fx_proxy_cls
[rank0]:     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2080, in get_fake_value
[rank0]:     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2015, in get_fake_value
[rank0]:     ret_val = wrap_fake_exception(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
[rank0]:     return fn()
[rank0]:            ^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2016, in <lambda>
[rank0]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2148, in run_node
[rank0]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2130, in run_node
[rank0]:     return node.target(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 912, in redistribute_fn_with_prim_types
[rank0]:     return x.redistribute(*args_as_value, **kwargs_as_value)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 538, in redistribute
[rank0]:     return Redistribute.apply(self, device_mesh, placements, async_op)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py", line 294, in forward
[rank0]:     output = redistribute_local_tensor(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py", line 177, in redistribute_local_tensor
[rank0]:     transform_infos = _gen_transform_infos(current_spec, target_spec)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py", line 68, in __hash__
[rank0]:     self._hash = self._hash_impl()
[rank0]:                  ^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py", line 51, in _hash_impl
[rank0]:     return hash(
[rank0]:            ^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/__init__.py", line 531, in __hash__
[rank0]:     raise TypeError("unhashable type: non-nested SymInt")
[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function TensorVariable.method_redistribute.<locals>.redistribute_fn_with_prim_types at 0x7f946c1cb7e0>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(4, s2, 1024), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=('tensor_parallel',)), placements=(Replicate(),)),), **{}):
[rank0]: unhashable type: non-nested SymInt

[rank0]: from user code:
[rank0]:    File "/opt/conda/lib/python3.11/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 170, in forward
[rank0]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/parallel/style.py", line 528, in <lambda>
[rank0]:     module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/parallel/style.py", line 501, in _prepare_input_fn
[rank0]:     self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/parallel/style.py", line 479, in _prepare_input_arg
[rank0]:     dt_inp = dt_inp.redistribute(placements=(desired_layout,))

@tianyu-l
Contributor Author

@nighting0le01 Thanks for reporting the issue.

Could you please provide a repro or config so we can help you better? In particular, can you explain how you use variable-length input? cc @yifuwang
Btw, I think it might be better to create a new dedicated issue to track this problem.

@nighting0le01

@tianyu-l Hi, I can create a new issue, but you can think of it as frames with different resolutions for a multimodal model.
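In the traceback, the sequence dimension appears as a symbolic size (`s2` in `size=(4, s2, 1024)`), and the error is raised when DTensor hashes a spec that contains that SymInt. Assuming the dynamic sequence length is what triggers the failure (not confirmed in this thread), one possible workaround is to keep shapes static. You could either compile with `torch.compile(model, dynamic=False)`, which makes Dynamo specialize on each concrete shape at the cost of recompiling for every new shape, or pad inputs to a small fixed set of lengths. The plain-Python sketch below shows the padding approach. The bucket sizes and the helper names `pick_bucket`, `pad_to_bucket`, and `pad_batch` are hypothetical and are not part of torchtitan or PyTorch.

```python
from bisect import bisect_left
from typing import List, Sequence

# Hypothetical bucket lengths. With sequence parallelism it is probably simplest
# to pick lengths divisible by the SP degree so the sequence dim shards evenly.
BUCKETS = (256, 512, 1024, 2048)


def pick_bucket(seq_len: int, buckets: Sequence[int] = BUCKETS) -> int:
    """Return the smallest bucket length that can hold seq_len tokens."""
    idx = bisect_left(buckets, seq_len)
    if idx == len(buckets):
        raise ValueError(
            f"sequence of length {seq_len} exceeds largest bucket {buckets[-1]}"
        )
    return buckets[idx]


def pad_to_bucket(tokens: List[int], pad_id: int,
                  buckets: Sequence[int] = BUCKETS) -> List[int]:
    """Right-pad one token list up to its bucket length."""
    target = pick_bucket(len(tokens), buckets)
    return tokens + [pad_id] * (target - len(tokens))


def pad_batch(batch: List[List[int]], pad_id: int,
              buckets: Sequence[int] = BUCKETS) -> List[List[int]]:
    """Pad every sequence to the bucket of the longest one, so the compiled
    model only ever sees len(buckets) distinct sequence lengths."""
    target = pick_bucket(max(len(s) for s in batch), buckets)
    return [s + [pad_id] * (target - len(s)) for s in batch]
```

For example, `pad_batch([[1, 2, 3], [4] * 300], pad_id=0)` returns two sequences of length 512. Padded positions should be excluded from the loss (for instance via an ignore index), otherwise they will affect training.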

@tianyu-l
Contributor Author

@nighting0le01
Btw, do you still see the issue after disabling async TP (and possibly other "advanced" techniques besides torch.compile)?
