Forked from Luke Melas-Kyriazi's repository.
git clone https://github.com/arkel23/PyTorch-Pretrained-ViT.git
cd PyTorch-Pretrained-ViT
pip install -e .
python download_convert_models.py # modify to download different models; by default it downloads all 5 ViTs pretrained on ImageNet-21k
from pytorch_pretrained_vit import ViT, ViTConfigExtended, PRETRAINED_CONFIGS
# build a ViT-B/16 with its default pretrained configuration
model_name = 'B_16'
def_config = PRETRAINED_CONFIGS[model_name]['config']
configuration = ViTConfigExtended(**def_config)
model = ViT(configuration, name=model_name, pretrained=True, load_repr_layer=False, ret_attn_scores=False)
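Continuing from the snippet above, a minimal sketch of running inference on a dummy batch (the (batch, num_classes) output shape is my assumption, based on the standard classification-head setup):
import torch
# dummy batch at the configured input resolution; real images need your own
# resizing/normalization pipeline
x = torch.randn(1, 3, configuration.image_size, configuration.image_size)
model.eval()
with torch.no_grad():
    logits = model(x)  # assumed shape: (1, configuration.num_classes)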
- Added support for the 'H_14' and 'L_16' ViT models.
- Added support for downloading the models directly from Google's cloud storage.
- Corrected the JAX to PyTorch weights conversion. The previous methodology produced .pth state_dict files without the 'representation layer', so loading a model with
model = ViT(configuration, name=model_name, pretrained=True, load_repr_layer=True)
would lead to an error. The representation layer is unnecessary if you only care about inference, as discussed in the original Vision Transformer paper, but it may be useful for other applications and experiments, so I added a download_convert_models.py script that first downloads the required models and converts them with all of their weights; you can then fully tune the parameters.
- Added support for visualizing attention by returning the score values from the multi-head self-attention layers (a minimal sketch appears after this change list). The visualization script was mostly taken from the jeonsworld/ViT-pytorch repository.
- Added examples for inference (single image) and fine-tuning/training (using CIFAR-10).
- Modified loading of models by using configurations similar to HuggingFace's Transformers.
# Change the default configuration by accessing individual attributes
configuration.image_size = 128
configuration.num_classes = 10
configuration.num_hidden_layers = 3
model = ViT(config=configuration, name='B_16', pretrained=True)
# for another example see examples/configurations/load_configs.py
- Added support for partially loading a ViT:
model = ViT(config=configuration, name='B_16')
pretrained_mode = 'full_tokenizer'
weights_path = "/hdd/edwin/support/torch/hub/checkpoints/B_16.pth"
model.load_partial(weights_path=weights_path, pretrained_image_size=configuration.pretrained_image_size,
                   pretrained_mode=pretrained_mode, verbose=True)
# the available modes load the patch projection, position embeddings, and class token,
# individually or in combination
for pretrained_mode in ['full_tokenizer', 'patchprojection', 'posembeddings', 'clstoken',
                        'patchandposembeddings', 'patchandclstoken', 'posembeddingsandclstoken']:
    model.load_partial(weights_path=weights_path,
                       pretrained_image_size=configuration.pretrained_image_size,
                       pretrained_mode=pretrained_mode, verbose=True)
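A minimal sketch of grabbing the attention scores mentioned in the change list, assuming that with ret_attn_scores=True the forward pass also returns the per-layer attention maps (the exact return format is an assumption; see the attention visualization example in this repository for the real usage):
import torch
model = ViT(configuration, name=model_name, pretrained=True, ret_attn_scores=True)
model.eval()
x = torch.randn(1, 3, configuration.image_size, configuration.image_size)
with torch.no_grad():
    # assumed return format: classification logits plus the per-layer
    # multi-head self-attention score tensors
    logits, attn_scores = model(x)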
This repository contains an op-for-op PyTorch reimplementation of the Vision Transformer architecture from Google, along with pre-trained models and examples.
Vision Transformers (ViT) are a straightforward application of the transformer architecture to image classification. Even in computer vision, it seems, attention is all you need.
The ViT architecture works as follows: (1) it splits an image into fixed-size patches and linearly embeds them, treating the result as a 1-dimensional sequence of tokens; (2) it prepends a learnable classification token to the sequence; (3) it passes the sequence through a transformer encoder (like BERT); (4) it passes the first token of the encoder output through a small MLP head to obtain the classification logits. ViT is trained on a large-scale dataset (ImageNet-21k) with a huge amount of compute.
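As a worked illustration of that pipeline, here is a minimal, self-contained shape-flow sketch in plain PyTorch (illustrative only, not this repository's implementation), assuming a 224x224 input, 16x16 patches, and the ViT-B/16 hidden size of 768:
import torch
import torch.nn as nn
image_size, patch_size, hidden_dim = 224, 16, 768
num_patches = (image_size // patch_size) ** 2                      # 14 * 14 = 196 patches
patch_embed = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)
cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_dim))
head = nn.Linear(hidden_dim, 1000)
x = torch.randn(1, 3, image_size, image_size)                            # (B, 3, H, W)
tokens = patch_embed(x).flatten(2).transpose(1, 2)                       # (1) patch sequence: (B, 196, 768)
tokens = torch.cat([cls_token.expand(x.size(0), -1, -1), tokens], dim=1) # (2) prepend [CLS]: (B, 197, 768)
tokens = tokens + pos_embed                                              # add position embeddings
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=12, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)
encoded = encoder(tokens)                                                # (3) transformer encoder: (B, 197, 768)
logits = head(encoded[:, 0])                                             # (4) classify from [CLS]: (B, 1000)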
Other great repositories with this model include:
- Google Research's repo
- Ross Wightman's repo
- Phil Wang's repo
- Eunkwang Jeon's repo
- Luke Melas-Kyriazi's repo
If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.
I look forward to seeing what the community does with these models!