
torch.compile each TransformerBlock instead of the whole model #268

Merged: 20 commits into main, May 22, 2024

Conversation

@wanchaol (Contributor) commented Apr 25, 2024

This way we can temporarily enable 2-D parallel compile, and it might make sense to do per-TransformerBlock compile in the future with PP anyway (we'll see).

We should figure out the dynamic shape issue, though.
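
For context, a minimal sketch of the change being described, assuming a Llama-style model that keeps its TransformerBlocks in an nn.ModuleList at model.layers (the attribute name and the helper are illustrative, not the exact torchtitan code):

import torch
import torch.nn as nn

def compile_each_block(model: nn.Module) -> nn.Module:
    """Sketch: compile each TransformerBlock individually instead of
    calling torch.compile(model) on the whole model."""
    for layer_id, transformer_block in enumerate(model.layers):
        # each block becomes its own compile unit / graph
        model.layers[layer_id] = torch.compile(transformer_block)
    return model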

@facebook-github-bot added the CLA Signed label Apr 25, 2024
@awgu (Contributor) left a comment

Sounds good to me!

@wanchaol wanchaol requested a review from tianyu-l April 26, 2024 00:15
@tianyu-l (Contributor) left a comment

lgtm! Some nit comments on how to organize things.

train.py Outdated
Comment on lines 222 to 231
  if job_config.training.compile:
      if (
          job_config.activation_checkpoint.mode == "selective"
          and job_config.activation_checkpoint.selective_ac_option == "op"
      ):
          # some flags for torch.compile enablement
          torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
              True
          )
-     logger.info("Compiling model with torch.compile")
-     model = torch.compile(model)
+     logger.info("Compiling each TransformerBlock with torch.compile")

nit: Since we put parallelization, ac, and compile all in parallelize_llama.py, it might make more sense to put this all in that file. This will simplify train.py.

Concretely, we can put this block in parallelize_llama.py, right before logger.info("Applied FSDP to the model"), and change the wording from "Compiling each ..." to "Compiled each ...".
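
To illustrate the suggested ordering (per-block AC wrapping, then per-block compile, with FSDP applied afterwards), a rough sketch follows; the helper name, the model.layers container, and the use of checkpoint_wrapper are assumptions for illustration, not the exact code in this PR:

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def apply_ac_and_compile(model, enable_ac: bool, enable_compile: bool):
    """Hypothetical helper for parallelize_llama.py: wrap each TransformerBlock
    with activation checkpointing, then compile it. FSDP wrapping happens after
    this returns, so FSDP hooks stay outside the compiled regions."""
    for layer_id, transformer_block in enumerate(model.layers):
        if enable_ac:
            transformer_block = checkpoint_wrapper(transformer_block)
        if enable_compile:
            # dynamic=False to sidestep the dynamic-shape issues noted above
            transformer_block = torch.compile(transformer_block, dynamic=False)
        model.layers[layer_id] = transformer_block
    return model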

train.py Outdated
@@ -219,17 +219,16 @@ def loss_fn(pred, labels):

metric_logger = build_metric_logger(job_config)

# torch.compile model for improved performance
if job_config.training.compile:

nit: since we are moving compilation to parallelize_llama, let's add a comment on L201.

logger.info(f"Applied {ac_mode} activation checkpointing to the model")

if job_config.training.compile:
# turn on per-transformer block compile after AC wrappnig and before FSDP
typo: wrappnig -> wrapping

torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
True
)
if enable_compile:
nit: Can this error check be moved up to line 216?

@awgu (Contributor) commented Apr 26, 2024

Should we be able to close #61 after this PR?

Also, do we need to run end-to-end numerics testing?

@wanchaol (Contributor, Author)

> Should we be able to close #61 after this PR?
>
> Also, do we need to run end-to-end numerics testing?

@awgu yeah, I think it should resolve that. I'll do some e2e benchmarking before landing, so this will likely take a while.

@wanchaol (Contributor, Author)

I'll break up some other changes to land them first

@kwen2501 (Contributor) commented May 1, 2024

When PP is present, we may torch.compile the whole stage module, which is bigger than a transformer block, e.g.:

# (assuming the torch.distributed.pipelining APIs)
from torch.distributed.pipelining import PipelineStage, pipeline

pipe = pipeline(model, ...)                   # split the model into pipeline stages
stage_mod = pipe.get_stage_module(stage_idx)  # the module for this rank's stage
stage_mod = torch.compile(stage_mod)          # compile the whole stage module
stage = PipelineStage(stage_mod, ...)         # wrap it for pipeline execution

It would also allow the code to be more model-agnostic -- there is no transformer_block, layer_id or model.layers here.

@chrisociepa

In my case, enabling compilation in this way (per-layer) causes a memory leak

@wanchaol (Contributor, Author) commented May 2, 2024

> In my case, enabling compilation in this way (per-layer) causes a memory leak

🤔 interesting, how did you observe that?

FWIW this doesn't work out of the box, as it triggers some non-trivial numeric issues; I'm going to leave this PR here until I resolve them. Opening a new PR to turn dynamic shapes off so that it works for both 1D and 2D compile.

@chrisociepa

With each iteration, the memory usage increases and eventually results in OOM. However, just to be clear, I haven't tested this on your entire code, only on a part of it. Adding per-layer compilation causes a memory leak with each iteration. I know that the memory leak might be related to my implementation, so I just wanted to bring this issue to your attention. If you don't observe this in your code, then it's likely an issue on my end.

Meanwhile, I didn't observe the numerical issues you mentioned, unless they are a direct consequence of the memory leak. The loss appears practically the same with layer compilation both enabled and disabled.
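
As a side note, one simple way to check for this kind of per-iteration growth (a hedged sketch using standard torch.cuda memory counters, not code from this thread) is to log device memory once per training step and watch for monotonic growth:

import torch

def log_cuda_memory(step: int) -> None:
    """Log allocated/reserved CUDA memory; steady growth across steps is a
    strong hint of a leak (vs. one-time caching-allocator warm-up)."""
    torch.cuda.synchronize()
    alloc_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"step {step}: allocated={alloc_gb:.2f} GB, reserved={reserved_gb:.2f} GB")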

if enable_compile:
    # turn on per-transformer block compile after AC wrapping and before FSDP
    # TODO: dynamic shapes have some issues, so we turn them off for now.
    transformer_block = torch.compile(transformer_block, dynamic=False)

Curious how far we are from being able to enable fullgraph=True?

@wanchaol (Contributor, Author)

I think fullgraph=True should already work once we move to per-TransformerBlock compile; maybe I can add that flag too.

I ended up just turning dynamic=False on in the current full-model compile in #297. In that case we can't use fullgraph=True yet, as FSDP still graph-breaks, but we should already be capturing each TransformerBlock as a full graph.
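
A hedged sketch of the two configurations being compared (the wrapper functions and flag placement are illustrative, not the exact code in this PR or #297):

import torch
import torch.nn as nn

# Per-TransformerBlock compile: FSDP wraps the already-compiled blocks, so the
# compiled region contains no FSDP hooks and each block can be a full graph.
def compile_block(transformer_block: nn.Module) -> nn.Module:
    return torch.compile(transformer_block, dynamic=False, fullgraph=True)

# Whole-model compile (the #297 route): fullgraph=True is not usable yet,
# because FSDP hooks inside the model still cause graph breaks.
def compile_whole_model(model: nn.Module) -> nn.Module:
    return torch.compile(model, dynamic=False)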

@wanchaol (Contributor, Author) commented May 22, 2024

Going to merge this, given that:

  1. 2D compile is currently broken; this PR works around it so that TP can be compiled again (we should separately figure out the full-model compile issue). cc @bdhirsh
  2. Per-TransformerBlock compile gives us potential upside later: once cache reuse in torch.compile is enabled, it would drastically improve compile time (cold start and warm start). cc @anijain2305

@wanchaol wanchaol merged commit 60810a9 into main May 22, 2024
4 checks passed
@wanchaol wanchaol deleted the compile_2d branch May 22, 2024 05:10
tianyu-l pushed a commit that referenced this pull request May 28, 2024
This way we could temporarily enable 2-D parallel compile, and it might
make sense to do transformer block compile in the future with PP (which
we'll see).

We should figure out:
1. dynamic shape issue when turning on 2D parallel
2. full model compile issue for 2D parallel compile
3. cache reusing currently does not work, enable it later
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024

tianyu-l pushed a commit that referenced this pull request Aug 16, 2024

philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024