
Instructions on using the pretrained model #3

Open
Bucanero06 opened this issue May 23, 2024 · 2 comments

@Bucanero06

Hi, I have prepared the Python environment and downloaded the data from PDEBench, as well as the model weights from Google Drive. I can train the model from scratch on the data I currently have available, but I would like instructions on how to load the pre-trained model. Thank you for your time and contribution.

P.S.
I also have questions about the MPP 2023 paper. What is your preferred method of communication about the project?

@Bucanero06
Author

I have added the path of the MPP_AViT_S tar file to the YAML setting pretrained_ckpt_path. From there I get an AttributeError when accessing self.model.module:
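For reference, the relevant lines of my config (a sketch; the finetune section name comes from --config finetune, and all other keys in config/mpp_avit_s_config.yaml are omitted):

    finetune:
      pretrained: True
      pretrained_ckpt_path: 'weights/MPP_AViT_S'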

python train_basic.py --run_name first_test_run --config finetune --yaml_config config/mpp_avit_s_config.yaml 
Loading configuration file: multiple_physics_pretraining/config/mpp_avit_s_config.yaml
Configuration name: finetune
Initializing data on rank 0
Initializing model on rank 0
Model parameter count: 28979436
Starting from pretrained model at weights/MPP_AViT_S

Traceback (most recent call last):
  File "multiple_physics_pretraining/train_basic.py", line 547, in <module>
    trainer = Trainer(params, global_rank, local_rank, device, sweep_id=args.sweep_id)
  File "multiple_physics_pretraining/train_basic.py", line 80, in __init__
    self.restore_checkpoint(params.pretrained_ckpt_path)
  File "multiple_physics_pretraining/train_basic.py", line 201, in restore_checkpoint
    self.model.module.unfreeze()
  File "multiple_physics_pretraining/multiple_physics_pretrained_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'AViT' object has no attribute 'module'. Did you mean: 'modules'?

I will look into it; just updating the thread here.
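A minimal reproduction of the failing attribute lookup outside the trainer (nn.Linear standing in for AViT):

    import torch.nn as nn

    model = nn.Linear(4, 4)
    print(hasattr(model, 'module'))  # False -- a plain module defines no 'module'
    model.module  # AttributeError: 'Linear' object has no attribute 'module'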

@Bucanero06
Author

The issue I was encountering is consistent with accessing self.model.module when self.model is not wrapped in a DistributedDataParallel (DDP) object; the module attribute is added by the DDP wrapper.
It seems the original method assumed the presence of the module attribute even when not using DDP. The updated restore_checkpoint method checks whether self.model is an instance of DistributedDataParallel (DDP) and only then accesses the module attribute, both when loading the state dict and under the self.params.pretrained branch. These are minor changes, but it is working on my single-GPU workstation.

    def restore_checkpoint(self, checkpoint_path):
        """ Load model/opt from path (expects OrderedDict from collections and
        DDP from torch.nn.parallel to be imported) """
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        if 'model_state' in checkpoint:
            model_state = checkpoint['model_state']
        else:
            model_state = checkpoint
        try:
            # Try a direct load first; this works when the checkpoint and the
            # current model have matching wrapping
            self.model.load_state_dict(model_state)
        except Exception:
            if isinstance(self.model, DDP):
                # Model is DDP-wrapped but the checkpoint is not: load into
                # the underlying module
                self.model.module.load_state_dict(model_state)
            else:
                # Checkpoint came from a DDP run but the model is plain:
                # strip the 'module.' prefix (7 characters) from every key
                new_state_dict = OrderedDict()
                for key, val in model_state.items():
                    name = key[7:]
                    new_state_dict[name] = val
                self.model.load_state_dict(new_state_dict)

        # restore_checkpoint is used for finetuning as well as resuming. If
        # finetuning (i.e., not resuming), it does not load the optimizer
        # state and instead uses the config-specified lr.
        if self.params.resuming:
            self.iters = checkpoint['iters']
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.startEpoch = checkpoint['epoch']
            self.epoch = self.startEpoch
        else:
            self.iters = 0

        if self.params.pretrained:
            if isinstance(self.model, DDP):
                model_to_modify = self.model.module
            else:
                model_to_modify = self.model

            if self.params.freeze_middle:
                model_to_modify.freeze_middle()
            elif self.params.freeze_processor:
                model_to_modify.freeze_processor()
            else:
                model_to_modify.unfreeze()

            # See how much we need to expand the projections
            exp_proj = 0
            # Iterate through the appended datasets and add on enough embeddings for all of them.
            for add_on in self.params.append_datasets:
                exp_proj += len(DSET_NAME_TO_OBJECT[add_on]._specifics()[2])
            model_to_modify.expand_projections(exp_proj)

        checkpoint = None
        self.model = self.model.to(self.device)
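A quick way to check which case a checkpoint falls into before loading (a diagnostic sketch; 'weights/MPP_AViT_S' is the path from my run, and the 'model_state' key matches what restore_checkpoint expects):

    import torch

    ckpt = torch.load('weights/MPP_AViT_S', map_location='cpu')
    state = ckpt.get('model_state', ckpt)
    # Keys starting with 'module.' mean the checkpoint was saved from a
    # DDP-wrapped model and needs the prefix stripped for a plain model
    print(next(iter(state)))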

@Bucanero06 changed the title from "Instructions on how to load the downloaded weights or where can the ckpt file be found for pretrained model" to "Instructions on using the pretrained model" on May 23, 2024