🐛[BUG]: CorrDiff - Double nested model_kwargs passed to U-Net constructor #536

Open
stathius opened this issue May 30, 2024 · 2 comments
Labels: `? - Needs Triage` (Need team to review and classify), `bug` (Something isn't working)


@stathius
Contributor

Version

0.7.0a

On which installation method(s) does this occur?

Docker, Source

Describe the issue

Analysis: The error occurs because `model_kwargs` is double-nested when it is passed to the U-Net constructor: `{'model_kwargs': {'embedding_type': 'zero', 'encoder_type': 'standard', 'decoder_type': 'standard', 'channel_mult_noise': 1, 'resample_filter': [1, 1], 'model_channels': 128, 'channel_mult': [1, 2, 2, 2, 2], 'attn_resolutions': [28], 'dropout': 0.13}}`

This may be caused by how the regression model's metadata is saved to disk.
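A minimal reproduction of the failure outside Modulus, assuming the inner U-Net accepts only explicit keyword arguments. `SongUNetStub` and `build_inner` are hypothetical stand-ins; validating with `inspect.Signature.bind_partial` mirrors the frame shown in the traceback, which raises before the constructor body ever runs.

```python
import inspect


class SongUNetStub:
    """Hypothetical stand-in for the inner U-Net; accepts only explicit keyword args."""

    def __init__(self, embedding_type="positional", model_channels=64, dropout=0.0):
        self.embedding_type = embedding_type
        self.model_channels = model_channels
        self.dropout = dropout


def build_inner(model_kwargs):
    # Validate the kwargs against the constructor signature first, similar to
    # the sig.bind_partial(...) call in Module.__new__ seen in the traceback.
    inspect.signature(SongUNetStub).bind_partial(**model_kwargs)
    return SongUNetStub(**model_kwargs)


flat = {"embedding_type": "zero", "model_channels": 128, "dropout": 0.13}
nested = {"model_kwargs": flat}  # the shape reported in this issue

net = build_inner(flat)  # binds and constructs normally
try:
    build_inner(nested)
except TypeError as err:
    # bind_partial rejects the unknown keyword 'model_kwargs'
    print(err)
```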

Proposed fix:

  1. Pass `model_kwargs["model_kwargs"]` instead of `model_kwargs`.
  2. Investigate and change how the regression model's metadata is saved.
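Fix (1) could be applied defensively at load time with a small helper like the sketch below. `unwrap_model_kwargs` is a hypothetical name, not an existing Modulus function; it peels off any number of `{"model_kwargs": {...}}` layers and lets explicit outer keys win on conflict.

```python
from typing import Any, Dict


def unwrap_model_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: flatten {'model_kwargs': {...}} nesting, any depth."""
    result = dict(kwargs)
    # Keep unwrapping while the value under "model_kwargs" is itself a dict.
    while isinstance(result.get("model_kwargs"), dict):
        inner = result.pop("model_kwargs")
        # Explicit keys at the outer level take precedence over nested ones.
        result = {**inner, **result}
    return result


print(unwrap_model_kwargs({"model_kwargs": {"model_channels": 128, "dropout": 0.13}}))
```

This only treats the symptom: checkpoints written with nested metadata would keep loading, but fix (2) is still needed so that new checkpoints are saved flat.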

Minimum reproducible example

```
python3 train.py --config-name=config_train_diffusion.yaml
```

with the following config values:

```yaml
arch: ddpmpp-cwb
precond: edmv1
task: diffusion
```

Relevant log output

```
Traceback (most recent call last):
  File "/code/modulus/examples/generative/corrdiff/train.py", line 344, in main
    training_loop.training_loop(
  File "/code/modulus/examples/generative/corrdiff/training/training_loop.py", line 176, in training_loop   <--------------
    net_reg = Module.from_checkpoint(regression_checkpoint_path)
  File "/code/modulus/modulus/models/module.py", line 357, in from_checkpoint
    model = cls.instantiate(args)
  File "/code/modulus/modulus/models/module.py", line 175, in instantiate
    return _cls(**arg_dict["__args__"])
  File "/code/modulus/modulus/models/diffusion/unet.py", line 111, in __init__      <--------------
    self.model = model_class(
  File "/code/modulus/modulus/models/module.py", line 65, in __new__
    bound_args = sig.bind_partial(
  File "/usr/lib/python3.10/inspect.py", line 3193, in bind_partial
    return self._bind(args, kwargs, partial=True)
  File "/usr/lib/python3.10/inspect.py", line 3175, in _bind
    raise TypeError(
TypeError: got an unexpected keyword argument 'model_kwargs'
```

Environment details

python version: 3.10
modulus commit: `c07fa25321c48a1d71efca12b67d056adbca8bd4`
@DavidLandup0
Contributor

From a cursory look, it appears that these should be `network_kwargs`, not `model_kwargs`, and the source code does not seem to nest them. They are constructed as part of:

```python
c.network_kwargs.update(
    channel_mult_noise=1,
    resample_filter=[1, 1],
    model_channels=128,
    channel_mult=[1, 2, 2, 2, 2],
    attn_resolutions=[28],
)  # era5-cwb, 448x448
```

And passed into the training loop via:

```python
training_loop.training_loop(
    dataset, dataset_iter, valid_dataset, valid_dataset_iter, **c
)
```

And matched to the keyword-only `network_kwargs` parameter of:

```python
def training_loop(
    dataset,
    dataset_iterator,
    validation_dataset,
    validation_dataset_iterator,
    *,
    task,
    run_dir=".",  # Output directory.
    network_kwargs={},  # Options for model and preconditioning.
    ...
```
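To check the point above that the config path itself does not nest anything, the sketch below replays the `**c` call against a stub with the same keyword-only layout. Assumptions: `c` is modeled as a plain dict (the real config object may be an attribute-style dict), and the stub uses a `None` default instead of `{}` to avoid a shared mutable default.

```python
def training_loop(
    dataset,
    dataset_iterator,
    validation_dataset,
    validation_dataset_iterator,
    *,
    task,
    run_dir=".",
    network_kwargs=None,
):
    """Stub: returns network_kwargs exactly as it arrives."""
    return dict(network_kwargs or {})


# Hypothetical config modeled as a plain dict.
c = {"task": "diffusion", "network_kwargs": {}}
c["network_kwargs"].update(
    channel_mult_noise=1,
    resample_filter=[1, 1],
    model_channels=128,
    channel_mult=[1, 2, 2, 2, 2],
    attn_resolutions=[28],
)  # era5-cwb, 448x448

# `**c` unpacks exactly one level, so network_kwargs arrives as the flat dict.
received = training_loop(None, None, None, None, **c)
print(sorted(received))
```

Since this path yields a flat dict, the nesting seen in the traceback more likely comes from the checkpoint loading path (`Module.from_checkpoint`).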

@stathius
Contributor Author

stathius commented Jun 12, 2024

I think the problem comes from how the regression model is saved to disk. It may already have been fixed, but some of the checkpoints I have been working with had this issue.
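One plausible way saved metadata could end up double-nested (an assumption, not confirmed against the Modulus `Module` source): if a wrapper's constructor arguments are recorded from `inspect.Signature.bind_partial(...).arguments`, the `**model_kwargs` catch-all is stored under its own parameter name, and replaying `cls(**saved)` on load wraps it once more. `UNetWrapper`, `record_args_naive`, and `record_args_flat` are hypothetical; the JSON round-trip stands in for writing checkpoint metadata.

```python
import inspect
import json


class UNetWrapper:
    """Hypothetical wrapper that forwards **model_kwargs to an inner network."""

    def __init__(self, img_channels=1, model_type="SongUNet", **model_kwargs):
        self.img_channels = img_channels
        self.model_type = model_type
        self.model_kwargs = model_kwargs


def record_args_naive(cls, **kwargs):
    # BoundArguments.arguments keys the **kwargs dict by its parameter name.
    return dict(inspect.signature(cls).bind_partial(**kwargs).arguments)


def record_args_flat(cls, **kwargs):
    # Expand the VAR_KEYWORD entry so that cls(**saved) reproduces the call.
    sig = inspect.signature(cls)
    saved = {}
    for name, value in sig.bind_partial(**kwargs).arguments.items():
        if sig.parameters[name].kind is inspect.Parameter.VAR_KEYWORD:
            saved.update(value)
        else:
            saved[name] = value
    return saved


call_kwargs = {"model_channels": 128, "dropout": 0.13}

# Naive save/load: the catch-all gains an extra {"model_kwargs": ...} layer.
naive = json.loads(json.dumps(record_args_naive(UNetWrapper, img_channels=4, **call_kwargs)))
restored_naive = UNetWrapper(**naive)
print(restored_naive.model_kwargs)

# Flat save/load: the round trip reproduces the original call.
flat = json.loads(json.dumps(record_args_flat(UNetWrapper, img_channels=4, **call_kwargs)))
restored_flat = UNetWrapper(**flat)
print(restored_flat.model_kwargs)
```

If this is the mechanism, checkpoints already on disk would still need an unwrap at load time, while flattening at save time would keep new checkpoints clean.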
