
Save ckpt error #18

Open
SxJyJay opened this issue Aug 23, 2024 · 3 comments

Comments

SxJyJay commented Aug 23, 2024

During training, I found that the training procedure crashes when running

consolidated_model_state_dict = {key: val.to(save_dtype) for key, val in model.state_dict().items()}

And the error is:
AssertionError: FSDP assumes model.norm.weight is in the state_dict but the state_dict only has odict_keys
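For context, the consolidation step above follows the usual FSDP full-state-dict pattern. A minimal sketch of that pattern (standard PyTorch FSDP APIs only; model is assumed to be already FSDP-wrapped and save_dtype stands in for the configured dtype, so this is not the repo's exact saving code):

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

save_dtype = torch.bfloat16  # placeholder for the repo's configured save_dtype

# Gather the full (unsharded) state dict, offloaded to CPU on rank 0.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model.state_dict()  # this is where the assertion above fires

# Cast and save the consolidated weights.
consolidated_model_state_dict = {key: val.to(save_dtype) for key, val in state_dict.items()}
torch.save(consolidated_model_state_dict, "consolidated.pth")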

SxJyJay (Author) commented Aug 23, 2024

I found that there may be a bug in

modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens]

Because the full Chameleon model consists of:

ChameleonModel(
  (embed_tokens): Embedding(65536, 4096)
  (layers): ModuleList(
    (0-31): 32 x ChameleonDecoderLayer(
      (self_attn): ChameleonSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_norm): ChameleonLayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (k_norm): ChameleonLayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (rotary_emb): ChameleonRotaryEmbedding()
      )
      (mlp): ChameleonMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): ChameleonRMSNorm((4096,), eps=1e-05)
      (post_attention_layernorm): ChameleonRMSNorm((4096,), eps=1e-05)
      (dropout): Dropout(p=0.05, inplace=False)
    )
  )
  (norm): ChameleonRMSNorm((4096,), eps=1e-05)
)

However, modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens] ignores the self.model.norm part. This causes the above error when saving the checkpoint. After I modified this line to:

modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens, self.model.norm]

the training can proceed.
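For reference, the change amounts to something like the following (a sketch of the wrap-list method; it assumes the class exposes the Chameleon backbone as self.model and the output head as self.lm_head, as in the printout above):

def get_fsdp_wrap_module_list(self):
    # Before: the final ChameleonRMSNorm (self.model.norm) was not listed.
    # modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens]

    # After: include self.model.norm as well.
    modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens, self.model.norm]
    return modules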

ChrisLiu6 (Contributor) commented Aug 24, 2024

That's weird. The "get_fsdp_wrap_module_list" method is used for the auto_wrap_policy argument in the FSDP call:

model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        lambda_auto_wrap_policy,
        lambda_fn=lambda m: m in model.get_fsdp_wrap_module_list(),
    ),
    process_group=fs_init.get_data_parallel_group(),
    ...
)

Note that FSDP wrapping is a recursive process, which means that not only the outermost model but also some of the inner submodules are wrapped into FSDP modules. Operations like parameter sharding, gathering, and flattening are then conducted at the FSDP-module level.

Importantly, the auto_wrap_policy argument defines "which sub-modules should be independently wrapped into new FSDP modules", rather than "which modules should be considered part of the model". So the fact that self.model.norm is absent from the list merely means that it won't become an independently wrapped FSDP module; it will still be included in the outermost FSDP module.
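To illustrate the distinction, here is a minimal, self-contained sketch with a toy model in place of Chameleon (standard PyTorch APIs only; a single-process gloo group stands in for the real multi-GPU NCCL setup):

import functools

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

# Single-process CPU group, just so FSDP can be constructed for illustration.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

# Toy model: the two Linear layers are in the wrap list, the LayerNorm is not.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16), nn.LayerNorm(16))
wrap_list = [model[0], model[1]]

fsdp_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        lambda_auto_wrap_policy,
        lambda_fn=lambda m: m in wrap_list,
    ),
)

# The LayerNorm is not an independent FSDP unit, but its parameters are still
# owned by the root FSDP module and still appear in fsdp_model.state_dict().
print(fsdp_model)
print(list(fsdp_model.state_dict().keys()))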

Therefore, in our experience, the issue you mentioned might not be the real cause of the error you encountered. Have you made any other modifications to the code? Also, what's your PyTorch version?

SxJyJay (Author) commented Aug 24, 2024

Thanks for your response! I used 1 GPU to debug the code. The only modification I made is defining a get_trainable_params method in the ChameleonXLLMXForConditionalGeneration class so that only part of the parameters are trainable, in order to save memory. After making the aforementioned modification, the code works fine on both 1 GPU and 8 GPUs. I wonder whether this modification will affect the model's performance?
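A rough sketch of such a selective-training method (the parameter-name filter below is purely illustrative, not the actual filter used here):

def get_trainable_params(self):
    # Hypothetical filter: keep only the lm_head and the final norm trainable,
    # and freeze everything else to save memory.
    trainable = {}
    for name, param in self.named_parameters():
        if name.startswith("lm_head") or name.startswith("model.norm"):
            trainable[name] = param
        else:
            param.requires_grad = False
    return trainable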

BTW, my PyTorch version is 2.3.0.

Best regards.
