torch.compile each TransformerBlock instead of the whole model #268
Conversation
Sounds good to me!
This way we could temporarily enable 2-D parallel compile, and it makes more sense to do transformer block compile in the future with PP anyway. We should figure out the dynamic shape issue, though.
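For context, a minimal sketch (not the exact PR diff) of what per-TransformerBlock compilation could look like, assuming a Llama-style model that keeps its blocks in `model.layers`; the helper name is made up for illustration:

```python
import torch
import torch.nn as nn


def compile_transformer_blocks(model: nn.Module) -> nn.Module:
    """Compile each TransformerBlock individually instead of the whole model."""
    for layer_id, transformer_block in model.layers.named_children():
        # Compiling block-by-block keeps graph breaks (e.g. from parallelism
        # wrappers) outside the per-block graphs instead of fragmenting one
        # big whole-model graph.
        compiled_block = torch.compile(transformer_block)
        model.layers.register_module(layer_id, compiled_block)
    return model
```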
lgtm! Some nit comments on how to organize things.
train.py
Outdated
    if job_config.training.compile:
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.activation_checkpoint.selective_ac_option == "op"
        ):
            # some flags for torch.compile enablement
            torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
                True
            )
        logger.info("Compiling model with torch.compile")
        model = torch.compile(model)
        logger.info("Compiling each TransformerBlock with torch.compile")
nit: Since we put parallelization, AC, and compile all in parallelize_llama.py, it might make more sense to put this all in that file. This will simplify train.py.
Concretely, we can move this block to parallelize_llama.py, right before logger.info("Applied FSDP to the model"), and change the wording from "Compiling each ..." to "Compiled each ...".
train.py
Outdated
@@ -219,17 +219,16 @@ def loss_fn(pred, labels):

    metric_logger = build_metric_logger(job_config)

    # torch.compile model for improved performance
    if job_config.training.compile:
nit: since we are moving compilation to parallelize_llama, let's add a comment on L201.
logger.info(f"Applied {ac_mode} activation checkpointing to the model") | ||
|
||
if job_config.training.compile: | ||
# turn on per-transformer block compile after AC wrappnig and before FSDP |
typo: wrappnig -> wrapping
        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
            True
        )
    if enable_compile:
nit: Can this error check be moved up to line 216?
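For reference, a sketch of how that experimental flag and the per-block compile fit together (`enable_compile` and the AC options come from the snippets above; the `ac_config` name and exact structure are assumptions):

```python
# The experimental dynamo flag is only needed when "op"-level selective AC is
# used, since that path relies on a context_fn in torch.utils.checkpoint which
# (judging by the flag's name) torch.compile does not support without it.
if enable_compile:
    if ac_config.mode == "selective" and ac_config.selective_ac_option == "op":
        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True
    transformer_block = torch.compile(transformer_block, dynamic=False)
```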
Should we be able to close #61 after this PR? Also, do we need to run end-to-end numerics testing?
I'll break up some other changes to land them first.
When PP is present, we may torch.compile the whole stage module, which is bigger than a transformer block.
It would also allow the code to be more model-agnostic -- there is no ...
In my case, enabling compilation in this way (per-layer) causes a memory leak.
🤔 Interesting, how did you observe that? FWIW this doesn't work out of the box, as it triggers some non-trivial numeric issues; I'm going to leave this PR here until I resolve them. Opening a new PR to turn dynamic shapes off so that it works for both 1D and 2D compile.
With each iteration, the memory usage increases and eventually results in OOM. However, just to be clear, I haven't tested this on your entire code, only on a part of it. Adding per-layer compilation causes a memory leak with each iteration. I know that the memory leak might be related to my implementation, so I just wanted to bring this issue to your attention. If you don't observe this in your code, then it's likely an issue on my end.
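(For anyone trying to reproduce this, one way to watch for per-iteration growth is to log allocator stats after each step; the helper below is illustrative and not part of the repo:)

```python
import torch
from typing import Callable


def log_cuda_memory(train_step: Callable[[], None], num_steps: int = 20) -> None:
    """Run `train_step` repeatedly and print CUDA memory stats after each step.

    A steady climb in allocated memory across iterations (rather than a plateau
    after the first few warm-up steps) points at a leak.
    """
    for step in range(num_steps):
        train_step()
        torch.cuda.synchronize()
        alloc_gb = torch.cuda.memory_allocated() / 1e9
        reserved_gb = torch.cuda.memory_reserved() / 1e9
        print(f"step {step}: allocated={alloc_gb:.2f} GB, reserved={reserved_gb:.2f} GB")
```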
    if enable_compile:
        # turn on per-transformer block compile after AC wrapping and before FSDP
        # TODO: dynamic shapes have some issues so we turn them off for now.
        transformer_block = torch.compile(transformer_block, dynamic=False)
Curious how far we are from being able to enable fullgraph=True?
I think fullgraph=True should work already when we move to per-TransformerBlock compile; maybe I can add that flag too.
I ended up just turning dynamic=False in the current full-model compile (#297). In this case we can't set fullgraph=True yet, as FSDP still graph-breaks, but for that case we should already be capturing each TransformerBlock as a full graph.
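So the per-block call ends up roughly like this, with `fullgraph=True` as a possible follow-up rather than part of this PR:

```python
# what this PR does: compile each block with dynamic shapes disabled
transformer_block = torch.compile(transformer_block, dynamic=False)

# possible follow-up once per-block graphs are confirmed to be break-free:
# transformer_block = torch.compile(transformer_block, dynamic=False, fullgraph=True)
```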
going to merge this given that:
This way we could temporarily enable 2-D parallel compile, and it might make sense to do transformer block compile in the future with PP (which we'll see). We should figure out:
1. the dynamic shape issue when turning on 2-D parallel
2. the full-model compile issue for 2-D parallel compile
3. cache reuse, which currently does not work; enable it later