This is a PyTorch implementation of the SimSiam paper:
```
@Article{chen2020simsiam,
  author  = {Xinlei Chen and Kaiming He},
  title   = {Exploring Simple Siamese Representation Learning},
  journal = {arXiv preprint arXiv:2011.10566},
  year    = {2020},
}
```
Install PyTorch and download the ImageNet dataset following the official PyTorch ImageNet training code. Like MoCo, this release makes minimal modifications to that code for both unsupervised pre-training and linear classification.
In addition, install apex for the LARS implementation needed for linear classification.
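LARS (layer-wise adaptive rate scaling) rescales each layer's gradient by a "trust ratio" proportional to ||w|| / ||g||, so every layer takes a step of similar relative size regardless of its raw gradient magnitude; this is what makes large batches such as 4096 practical for the linear classifier. Below is a minimal pure-Python sketch of that rescaling for one flattened layer, assuming the unclipped form used by wrapper-style implementations such as apex's LARC. It is not apex's code: the function names and the `trust_coefficient` value are hypothetical, and momentum is omitted.

```python
import math
from typing import List


def l2_norm(values: List[float]) -> float:
    """Euclidean norm of a flattened parameter or gradient tensor."""
    return math.sqrt(sum(v * v for v in values))


def lars_scaled_grad(
    param: List[float],
    grad: List[float],
    weight_decay: float = 0.0,
    trust_coefficient: float = 0.001,  # illustrative value (assumption)
    eps: float = 1e-8,
) -> List[float]:
    """Rescale one layer's gradient by its LARS trust ratio (unclipped form).

    trust_ratio = trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)
    """
    w_norm = l2_norm(param)
    g_norm = l2_norm(grad)
    if w_norm == 0.0 or g_norm == 0.0:
        # Leave the gradient untouched when either norm is zero.
        return list(grad)
    trust_ratio = trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)
    # Fold weight decay into the gradient, then apply the layer-wise ratio.
    return [(g + weight_decay * w) * trust_ratio for w, g in zip(param, grad)]


# Usage: a 100x larger gradient yields (almost) the same rescaled update,
# because the trust ratio divides by ||g||.
w = [3.0, 4.0]                             # ||w|| = 5
small = lars_scaled_grad(w, [0.6, 0.8])    # ||g|| = 1
large = lars_scaled_grad(w, [60.0, 80.0])  # ||g|| = 100
print(small, large)  # both ≈ [0.003, 0.004]
```

The outer optimizer (SGD in this sketch's framing) then multiplies the rescaled gradient by its usual learning rate.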
Only multi-GPU, DistributedDataParallel training is supported; single-GPU and DataParallel training are not supported.
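With `--multiprocessing-distributed`, the PyTorch ImageNet example that this code builds on spawns one process per GPU. It treats `--world-size` as the number of machines and `--rank` as the machine index, then derives each process's global rank and splits the batch size (the total across the GPUs of one node) evenly among the local GPUs. The pure-Python sketch below reproduces only that bookkeeping under this assumption; `plan_workers` and `WorkerConfig` are hypothetical names, not part of the release.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class WorkerConfig:
    """What one spawned process (one GPU) would see. Hypothetical structure."""
    local_gpu: int
    global_rank: int
    world_size: int
    per_gpu_batch_size: int


def plan_workers(num_nodes: int, node_rank: int, ngpus_per_node: int,
                 batch_size: int) -> List[WorkerConfig]:
    """Mirror the rank/batch bookkeeping of --multiprocessing-distributed.

    num_nodes  -> the --world-size flag (number of machines)
    node_rank  -> the --rank flag (index of this machine)
    batch_size -> total batch size across the GPUs of this node
    """
    # Each node contributes one process per GPU.
    world_size = ngpus_per_node * num_nodes
    # The node's batch is split evenly across its local GPUs.
    per_gpu = batch_size // ngpus_per_node
    return [
        WorkerConfig(
            local_gpu=gpu,
            global_rank=node_rank * ngpus_per_node + gpu,
            world_size=world_size,
            per_gpu_batch_size=per_gpu,
        )
        for gpu in range(ngpus_per_node)
    ]


# Usage: one 8-GPU machine (--world-size 1 --rank 0) with a total batch of 512.
for worker in plan_workers(num_nodes=1, node_rank=0, ngpus_per_node=8, batch_size=512):
    print(worker.global_rank, worker.per_gpu_batch_size)
```

Under this convention a single 8-GPU machine becomes a world of 8 processes with ranks 0 to 7, each seeing 64 images per step.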
For unsupervised pre-training of a ResNet-50 model on ImageNet using an 8-GPU machine, run:
```
python main_simsiam.py \
  -a resnet50 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  --fix-pred-lr \
  [your imagenet-folder with train and val folders]
```
The script uses all the default hyper-parameters described in the paper, together with the default augmentation recipe from MoCo v2.
The above command performs pre-training for 100 epochs with a non-decaying predictor learning rate (`--fix-pred-lr`), corresponding to the last row of Table 1 in the paper.
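The paper scales the learning rate linearly with batch size (lr = base lr × BatchSize / 256, with base lr 0.05 and batch size 512 by default) and decays it with a cosine schedule. A non-decaying predictor learning rate means the predictor keeps the initial rate while the encoder decays. The sketch below assumes the schedule is updated once per epoch with epochs counted from 0; the function names are hypothetical, and whether the script's flag defaults match these paper values exactly is an assumption here.

```python
import math
from typing import Dict


def scaled_lr(base_lr: float, batch_size: int) -> float:
    """Linear scaling rule from the paper: lr = base_lr * BatchSize / 256."""
    return base_lr * batch_size / 256


def lrs_for_epoch(epoch: int, total_epochs: int, base_lr: float = 0.05,
                  batch_size: int = 512, fix_pred_lr: bool = True) -> Dict[str, float]:
    """Learning rates for one epoch (epoch counts from 0).

    The encoder follows a half-cosine decay from the scaled initial lr;
    with fix_pred_lr the predictor keeps the initial lr for every epoch.
    """
    init_lr = scaled_lr(base_lr, batch_size)
    cosine_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))
    return {
        "encoder": cosine_lr,
        "predictor": init_lr if fix_pred_lr else cosine_lr,
    }


# Usage: start, middle, and last epoch of a 100-epoch run at batch size 512.
for epoch in (0, 50, 99):
    print(epoch, lrs_for_epoch(epoch, total_epochs=100))
```

At batch size 512 both rates start at 0.1; halfway through, the encoder is at 0.05 while the predictor stays at 0.1.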
Given a pre-trained model, to train a supervised linear classifier on frozen features/weights using an 8-GPU machine, run:
```
python main_lincls.py \
  -a resnet50 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  --pretrained [your checkpoint path]/checkpoint_0099.pth.tar \
  --lars \
  [your imagenet-folder with train and val folders]
```
The above command uses the LARS optimizer and a default batch size of 4096.
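The `--pretrained` checkpoint holds the whole SimSiam model, while linear classification needs only the ResNet-50 backbone plus a freshly initialized `fc` layer, with everything except that layer frozen. The pure-Python sketch below shows the key filtering on a plain dict. It assumes that encoder keys carry a `module.encoder.` prefix (from the DistributedDataParallel wrapper) and that the projection MLP sits under `module.encoder.fc`; the constant and function names are hypothetical. On real weights, the dict would come from `torch.load(path, map_location="cpu")["state_dict"]` and be passed to `model.load_state_dict(..., strict=False)`, leaving only `fc.weight` and `fc.bias` missing.

```python
from typing import Dict, Iterable, List, TypeVar

T = TypeVar("T")

ENCODER_PREFIX = "module.encoder."      # assumption: DDP-wrapped encoder keys
PROJECTOR_PREFIX = "module.encoder.fc"  # assumption: projection MLP replaces fc


def backbone_state_dict(state_dict: Dict[str, T]) -> Dict[str, T]:
    """Keep encoder weights below the projector and strip the prefix."""
    return {
        key[len(ENCODER_PREFIX):]: value
        for key, value in state_dict.items()
        if key.startswith(ENCODER_PREFIX) and not key.startswith(PROJECTOR_PREFIX)
    }


def trainable_names(param_names: Iterable[str]) -> List[str]:
    """Names left trainable for linear classification: only the new fc layer."""
    return [name for name in param_names if name in ("fc.weight", "fc.bias")]


# Usage with placeholder values standing in for tensors.
ckpt = {
    "module.encoder.conv1.weight": "w_conv1",
    "module.encoder.layer1.0.bn1.bias": "w_bn",
    "module.encoder.fc.0.weight": "w_proj",    # projector: dropped
    "module.predictor.0.weight": "w_pred",     # predictor: dropped
}
print(backbone_state_dict(ckpt))
# {'conv1.weight': 'w_conv1', 'layer1.0.bn1.bias': 'w_bn'}
```

Freezing then amounts to setting `requires_grad = False` on every parameter whose name is not returned by `trainable_names`.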
Our pre-trained ResNet-50 models and logs:
| pre-train epochs | batch size | pre-train ckpt | pre-train log | linear cls. ckpt | linear cls. log | top-1 acc. (%) |
|---|---|---|---|---|---|---|
| 100 | 512 | link | link | link | link | 68.1 |
| 100 | 256 | link | link | link | link | 68.3 |
Settings for the above: 8 NVIDIA V100 GPUs, CUDA 10.1/cuDNN 7.6.5, PyTorch 1.7.0.
For object detection transfer, follow the same procedure as MoCo; see moco/detection.
This project is under the CC-BY-NC 4.0 license. See LICENSE for details.