Saving and loading escnn models #1

Closed
ishaanb92 opened this issue May 3, 2022 · 4 comments

@ishaanb92

ishaanb92 commented May 3, 2022

Hi,

Firstly, thanks for making such an accessible library for implementing equivariant models, and for the informative documentation!

I wanted to train a toy model on MNIST before moving on to a bigger architecture, so I chose the model provided in the model.ipynb notebook in the 'examples' folder. After plugging it into my training script, I saved it using the regular PyTorch save procedure:

```python
torch.save(model.state_dict(), 'mnist_model_e2cnn_{}.pt'.format(n_orientations))
```

In my test script, I load the model using:

```python
model.load_state_dict(torch.load('mnist_model_e2cnn_{}.pt'.format(n_orientations), map_location='cpu'))
```

However, loading the model throws the following error:

```
RuntimeError: Error(s) in loading state_dict for MNISTE2CNN:
Missing key(s) in state_dict: "block1.1.filter", "block2.0.filter", "block3.0.filter", "block4.0.filter", "block5.0.filter", "block6.0.filter".
```

I'm not sure what I'm doing wrong. Is there a special procedure for saving models that use escnn.nn.SequentialModule to stack operations?

EDIT: I am using torch 1.7.0.

Cheers,
Ishaan

@Gabri95
Collaborator

Gabri95 commented May 3, 2022

Hi @kilgore92

Thanks for opening the first issue! :)

You can probably solve this issue by calling model.eval() both before saving and before loading the model's state dict.
The reason behind this behaviour is explained in the first warning block here.
Does it solve your problem?

Best,
Gabriele
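
The mechanism behind the missing filter keys can be reproduced without escnn. The sketch below assumes, based on the error message above, that an escnn-style convolution registers a cached "filter" buffer only while in eval mode. CachedFilterConv and roundtrip are hypothetical names, and the module is a toy stand-in rather than escnn's actual implementation. Saving in train mode and loading into an eval-mode model reproduces the missing key, while calling eval() on both sides makes the keys match.

```python
import io

import torch
import torch.nn as nn


class CachedFilterConv(nn.Module):
    """Hypothetical toy module mimicking an escnn-style layer: the learnable
    parameter is always present, but a derived "filter" buffer exists only
    in eval mode (cached once instead of being rebuilt on every forward)."""

    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(4))

    def train(self, mode: bool = True):
        if mode:
            # Switching to train mode: drop the cached buffer, if any.
            if hasattr(self, "filter"):
                del self.filter
        elif self.training:
            # Switching from train to eval mode: cache the expanded filter.
            self.register_buffer("filter", self.weights.detach().repeat(2))
        return super().train(mode)


def roundtrip(src: nn.Module, dst: nn.Module) -> None:
    """Save src's state dict to an in-memory file and load it into dst."""
    buffer = io.BytesIO()
    torch.save(src.state_dict(), buffer)
    buffer.seek(0)
    dst.load_state_dict(torch.load(buffer, map_location="cpu"))


if __name__ == "__main__":
    # Reproduces the issue: saved in train mode, loaded in eval mode.
    trained = nn.Sequential(CachedFilterConv())
    try:
        roundtrip(trained, nn.Sequential(CachedFilterConv()).eval())
    except RuntimeError as err:
        print("load failed:", err)

    # The suggested fix: eval mode on both sides, so both state dicts
    # contain the same keys ("0.weights" and "0.filter").
    trained.eval()
    restored = nn.Sequential(CachedFilterConv()).eval()
    roundtrip(trained, restored)
    print("loaded keys:", list(restored.state_dict().keys()))
```

With a real escnn model, the same pattern applies: call model.eval() before torch.save(model.state_dict(), ...) and again before model.load_state_dict(...).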

@ishaanb92
Author

Hi @Gabri95,

Thanks! That did the trick :)

Cheers,
Ishaan

Gabri95 closed this as completed May 6, 2022

@Peter010103

Hi, I want to save my escnn model with torch.save(model, PATH) rather than saving model.state_dict().

I think saving it directly this way runs into problems because the library uses its own GeometricTensor datatype. Is there a way to save the entire model directly rather than just the state dict?

@kalekundert
Contributor

No, there's no way to save the whole model right now. The issue is that torch.save() basically just pickles whatever you give it, and ESCNN models are not pickleable. The specific reason doesn't have anything to do with geometric tensors; it has to do with the group, representation, and gspace objects that ESCNN models contain. See #37 and #78 for attempts to fix this. It's not an easy problem.
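
Since the whole model cannot be pickled, one common PyTorch workaround is to save the state dict together with the plain hyperparameters needed to rebuild the model, and then reconstruct the model from code when loading. This is not an escnn feature. The sketch below is a minimal illustration under that assumption. build_model, save_checkpoint and load_checkpoint are hypothetical names, and a plain Conv2d stands in for the escnn layers so the code runs without escnn. With escnn, build_model is where the gspace, field types and equivariant layers would be created.

```python
import io

import torch
import torch.nn as nn


def build_model(n_orientations: int, channels: int) -> nn.Module:
    """Hypothetical model factory. With escnn this would construct the
    gspace, field types and equivariant layers; a plain Conv2d stands in
    here so the sketch runs on its own."""
    return nn.Sequential(nn.Conv2d(1, channels * n_orientations, 3), nn.ReLU())


def save_checkpoint(model: nn.Module, config: dict, f) -> None:
    """Store only picklable data: plain hyperparameters plus tensors."""
    model.eval()  # save in eval mode, as suggested earlier in this thread
    torch.save({"config": config, "state_dict": model.state_dict()}, f)


def load_checkpoint(f) -> nn.Module:
    """Rebuild the model from its config, then restore its weights."""
    checkpoint = torch.load(f, map_location="cpu")
    model = build_model(**checkpoint["config"])
    model.eval()  # load in the same mode the state dict was saved in
    model.load_state_dict(checkpoint["state_dict"])
    return model


if __name__ == "__main__":
    config = {"n_orientations": 8, "channels": 4}
    buffer = io.BytesIO()  # stands in for a file path such as "model.pt"
    save_checkpoint(build_model(**config), config, buffer)
    buffer.seek(0)
    model = load_checkpoint(buffer)
    print(model)
```

The checkpoint holds only a dict of ints and a dict of tensors, so nothing in it depends on pickling escnn's group or representation objects. The model-building code itself has to be available when loading.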
