
Unable to do transfer learning #391

Open
ravitejarj opened this issue May 11, 2021 · 2 comments
ravitejarj commented May 11, 2021

Hi @zhanghang1989,
I am unable to do transfer learning with this model.
I downloaded the pretrained model via get_deeplab_resnest101_ade, and
when I changed the number of classes in ade20k.py (to 8 classes),
the pretrained model would no longer load (I get an error).
So I changed the code for transfer learning as follows.

Code changes:

  1. In deeplab.py, in the get_deeplab_resnest101_ade function, I changed

     # model = DeepLabV3(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)

     to

     no_of_classes = 150
     model = DeepLabV3(no_of_classes, backbone=backbone, root=root, **kwargs)

     so that I can load the pretrained model with its original 150 classes.

  2. In train_dist.py:

Model loading

import torch.nn as nn

# Load the pretrained 150-class ADE20K model and freeze all of its parameters
model_ft = get_deeplab_resnest101_ade(pretrained=True)
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the last block of the segmentation head with an 8-class classifier
model_ft.head.block = nn.Sequential(
    nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.1, inplace=False),
    nn.Conv2d(256, 8, kernel_size=1, stride=1),
)

# Train only the new head
for param in model_ft.head.parameters():
    param.requires_grad = True
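
As a quick sanity check on the new head before training (just a sketch, not code from the repo; the input resolution here is arbitrary):

import torch

model_ft.eval()  # avoid updating BatchNorm running statistics during the check
with torch.no_grad():
    out = model_ft(torch.randn(1, 3, 480, 480))

# Depending on how the model was built, forward may return a single tensor
# or a tuple (main output, auxiliary output); inspect the main output.
main = out[0] if isinstance(out, (tuple, list)) else out
print(main.shape)  # the channel dimension should now be 8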

Training

python train_dist.py --dataset ade20k --model deeplab --aux --backbone resnest101 --ft --epochs 100

After training finishes successfully,
I am getting a 150-class output, not an 8-class output (even though I set 8 classes in the last layer).
I need an 8-class output.

Can you help me with this?

@zhanghang1989 (Owner)

An easy solution is to set strict=False when loading the pretrained model:
https://pytorch.org/docs/master/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict
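
For anyone hitting the same problem, a minimal sketch of one way to apply this (assuming ade20k.py is set to 8 classes, the original factory line is restored, the checkpoint file stores a plain state_dict, and the local path below is hypothetical; note that load_state_dict still raises on size mismatches even with strict=False, so the 150-class head weights are filtered out first):

import torch

# Build the model with the new number of classes (8) configured in ade20k.py,
# without letting the factory load the 150-class checkpoint itself.
model = get_deeplab_resnest101_ade(pretrained=False)

# Downloaded 150-class ADE20K checkpoint (hypothetical local path),
# assumed to store a plain state_dict.
checkpoint = torch.load('deeplab_resnest101_ade.pth', map_location='cpu')

# Drop entries whose shapes no longer match the 8-class model (the final
# classifier convolutions); load_state_dict reports size mismatches even
# when strict=False, so they must be removed before loading.
model_state = model.state_dict()
filtered = {k: v for k, v in checkpoint.items()
            if k in model_state and v.shape == model_state[k].shape}

# strict=False ignores the keys that were just dropped.
model.load_state_dict(filtered, strict=False)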

@ravitejarj (Author)

Thank you
