Feature/pre commit #140

Open
wants to merge 6 commits into
base: master
5 changes: 5 additions & 0 deletions .flake8
@@ -0,0 +1,5 @@
[flake8]
exclude = .venv, .idea, .pytest_cache, __pycache__, .git, .scripts/*, logs/*, docker/*, build/*
ignore = E501, E203, W503
per-file-ignores = */__init__.py: F401
max-line-length = 88
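For context on the `ignore` list: E203 (whitespace before `:`) and W503 (line break before a binary operator) are the pycodestyle checks that black-formatted code commonly trips, and 88 is black's default line length. A minimal sketch of the two patterns follows; the function names are hypothetical, and black would only split the `return` expression this way when it does not fit on one line.

```python
def window(items, lower, upper, offset):
    # Spaces around ":" in a complex slice are what flake8 reports as E203.
    return items[lower + offset : upper + offset]


def total(first_value, second_value, third_value):
    # Breaking the line before each "+" is what flake8 reports as W503.
    return (
        first_value
        + second_value
        + third_value
    )
```

With both codes ignored, flake8 should stay quiet on layouts like these instead of contradicting the formatter.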
4 changes: 4 additions & 0 deletions .isort.cfg
@@ -0,0 +1,4 @@
[settings]
multi_line_output=3
lines_after_imports=2
sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
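As a rough illustration of these settings: `multi_line_output=3` is isort's vertical hanging indent mode, and `lines_after_imports=2` leaves two blank lines before the first statement. Assuming the 88-character limit from the black profile, an import too long for one line should end up laid out roughly as below. The helper function is hypothetical, and most of the imported names exist only to push the import past the limit.

```python
import os
from collections import (
    ChainMap,
    Counter,
    OrderedDict,
    UserDict,
    defaultdict,
    deque,
    namedtuple,
)


def most_common_extension(paths):
    """Return the most frequent file extension among the given paths."""
    counts = Counter(os.path.splitext(p)[1] for p in paths)
    return counts.most_common(1)[0][0]
```

Third-party and first-party imports would follow in their own blank-line-separated groups, in the order given by `sections`.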
27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,27 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.10
args: ["--profile", "black"]
- repo: https://github.com/ambv/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.10
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
language_version: python3.10
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-yaml
- id: debug-statements
- id: trailing-whitespace
default_language_version:
python: python3.10
22 changes: 11 additions & 11 deletions README.md
@@ -3,9 +3,9 @@
<img src="figures/teaser.png" width="700">
</p>

This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362)
(2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709)
This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362)
(2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709)

## Update

@@ -40,32 +40,32 @@ Results on CIFAR-10:
| |Arch | Setting | Loss | Accuracy(%) |
|----------|:----:|:---:|:---:|:---:|
| SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 95.0 |
| SupContrast | ResNet50 | Supervised | Contrastive | 96.0 |
| SupContrast | ResNet50 | Supervised | Contrastive | 96.0 |
| SimCLR | ResNet50 | Unsupervised | Contrastive | 93.6 |

Results on CIFAR-100:
| |Arch | Setting | Loss | Accuracy(%) |
|----------|:----:|:---:|:---:|:---:|
| SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 75.3 |
| SupContrast | ResNet50 | Supervised | Contrastive | 76.5 |
| SupContrast | ResNet50 | Supervised | Contrastive | 76.5 |
| SimCLR | ResNet50 | Unsupervised | Contrastive | 70.7 |

Results on ImageNet (Stay tuned):
| |Arch | Setting | Loss | Accuracy(%) |
|----------|:----:|:---:|:---:|:---:|
| SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | - |
| SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) |
| SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) |
| SimCLR | ResNet50 | Unsupervised | Contrastive | - |

## Running
You might use `CUDA_VISIBLE_DEVICES` to set proper number of GPUs, and/or switch to CIFAR100 by `--dataset cifar100`.
You might use `CUDA_VISIBLE_DEVICES` to set proper number of GPUs, and/or switch to CIFAR100 by `--dataset cifar100`.
**(1) Standard Cross-Entropy**
```
python main_ce.py --batch_size 1024 \
--learning_rate 0.8 \
--cosine --syncBN \
```
**(2) Supervised Contrastive Learning**
**(2) Supervised Contrastive Learning**
Pretraining stage:
```
python main_supcon.py --batch_size 1024 \
@@ -84,7 +84,7 @@ python main_linear.py --batch_size 512 \
--learning_rate 5 \
--ckpt /path/to/model.pth
```
**(3) SimCLR**
**(3) SimCLR**
Pretraining stage:
```
python main_supcon.py --batch_size 1024 \
@@ -104,7 +104,7 @@ python main_linear.py --batch_size 512 \
On custom dataset:
```
python main_supcon.py --batch_size 1024 \
--learning_rate 0.5 \
--learning_rate 0.5 \
--temp 0.1 --cosine \
--dataset path \
--data_folder ./path \
@@ -115,7 +115,7 @@ python main_supcon.py --batch_size 1024 \

The `--data_folder` must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension.

and
and
## t-SNE Visualization

**(1) Standard Cross-Entropy**
36 changes: 17 additions & 19 deletions losses.py
@@ -2,17 +2,15 @@
Author: Yonglong Tian ([email protected])
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn
from torch import nn


class SupConLoss(nn.Module):
"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
It also supports the unsupervised contrastive loss in SimCLR"""
def __init__(self, temperature=0.07, contrast_mode='all',
base_temperature=0.07):

def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07):
super(SupConLoss, self).__init__()
self.temperature = temperature
self.contrast_mode = contrast_mode
@@ -31,44 +29,44 @@ def forward(self, features, labels=None, mask=None):
Returns:
A loss scalar.
"""
device = (torch.device('cuda')
if features.is_cuda
else torch.device('cpu'))
device = torch.device("cuda") if features.is_cuda else torch.device("cpu")

if len(features.shape) < 3:
raise ValueError('`features` needs to be [bsz, n_views, ...],'
'at least 3 dimensions are required')
raise ValueError(
"`features` needs to be [bsz, n_views, ...],"
"at least 3 dimensions are required"
)
if len(features.shape) > 3:
features = features.view(features.shape[0], features.shape[1], -1)

batch_size = features.shape[0]
if labels is not None and mask is not None:
raise ValueError('Cannot define both `labels` and `mask`')
raise ValueError("Cannot define both `labels` and `mask`")
elif labels is None and mask is None:
mask = torch.eye(batch_size, dtype=torch.float32).to(device)
elif labels is not None:
labels = labels.contiguous().view(-1, 1)
if labels.shape[0] != batch_size:
raise ValueError('Num of labels does not match num of features')
raise ValueError("Num of labels does not match num of features")
mask = torch.eq(labels, labels.T).float().to(device)
else:
mask = mask.float().to(device)

contrast_count = features.shape[1]
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
if self.contrast_mode == 'one':
if self.contrast_mode == "one":
anchor_feature = features[:, 0]
anchor_count = 1
elif self.contrast_mode == 'all':
elif self.contrast_mode == "all":
anchor_feature = contrast_feature
anchor_count = contrast_count
else:
raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
raise ValueError("Unknown mode: {}".format(self.contrast_mode))

# compute logits
anchor_dot_contrast = torch.div(
torch.matmul(anchor_feature, contrast_feature.T),
self.temperature)
torch.matmul(anchor_feature, contrast_feature.T), self.temperature
)
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
@@ -80,7 +78,7 @@ def forward(self, features, labels=None, mask=None):
torch.ones_like(mask),
1,
torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
0
0,
)
mask = mask * logits_mask

@@ -92,7 +90,7 @@ def forward(self, features, labels=None, mask=None):
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

# loss
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()

return loss
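Since the changes to `losses.py` are formatting-only, the computation in `forward` should be unchanged. To make that computation easier to follow, here is a pure-Python sketch of the `contrast_mode="all"` path with labels, written for tiny inputs. It is an illustrative re-derivation, not the code in this PR: the name `supcon_loss` and the list-based inputs are assumptions, and it uses the standard library in place of `torch`.

```python
import math


def supcon_loss(features, labels, temperature=0.07, base_temperature=0.07):
    """Sketch of the supervised contrastive loss for small list inputs.

    features: list of samples, each a list of views, each an L2-normalised vector.
    labels:   one integer label per sample.
    Needs at least two views per sample so every anchor has a positive.
    """
    n_views = len(features[0])
    # Stack views like torch.cat(torch.unbind(features, dim=1), dim=0):
    # all first views, then all second views, and so on.
    flat = [sample[v] for v in range(n_views) for sample in features]
    flat_labels = [lab for _ in range(n_views) for lab in labels]

    losses = []
    for i, anchor in enumerate(flat):
        # Temperature-scaled similarity to every other view (self excluded).
        logits = {
            j: sum(a * b for a, b in zip(anchor, other)) / temperature
            for j, other in enumerate(flat)
            if j != i
        }
        peak = max(logits.values())  # subtract the max for numerical stability
        log_denom = peak + math.log(sum(math.exp(v - peak) for v in logits.values()))
        # Positives are all other views that share the anchor's label.
        positives = [j for j in logits if flat_labels[j] == flat_labels[i]]
        mean_log_prob_pos = sum(logits[j] - log_denom for j in positives) / len(positives)
        losses.append(-(temperature / base_temperature) * mean_log_prob_pos)
    return sum(losses) / len(losses)


a, b = [1.0, 0.0], [0.0, 1.0]
features = [[a, a], [a, a], [b, b], [b, b]]  # 4 samples, 2 views each
aligned = supcon_loss(features, labels=[0, 0, 1, 1])
mixed = supcon_loss(features, labels=[0, 1, 0, 1])
```

Here `aligned` should come out lower than `mixed`, because in the first call the labels agree with how the feature vectors cluster.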