diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..f1d91c99 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +exclude = .venv, .idea, .pytest_cache, __pycache__, .git, .scripts/*, logs/*, docker/*, build/* +ignore = E501, E203, W503 +per-file-ignores = */__init__.py: F401 +max-line-length = 88 \ No newline at end of file diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 00000000..bd2a5fd4 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,4 @@ +[settings] +multi_line_output=3 +lines_after_imports=2 +sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..be27aba1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,27 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + language_version: python3.10 + args: ["--profile", "black"] + - repo: https://github.com/ambv/black + rev: 23.3.0 + hooks: + - id: black + language_version: python3.10 + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + language_version: python3.10 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: debug-statements + - id: trailing-whitespace +default_language_version: + python: python3.10 diff --git a/README.md b/README.md index 4ac5357f..c4c10da7 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@

-This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example: -(1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362) -(2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709) +This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example: +(1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362) +(2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709) ## Update @@ -40,32 +40,32 @@ Results on CIFAR-10: | |Arch | Setting | Loss | Accuracy(%) | |----------|:----:|:---:|:---:|:---:| | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 95.0 | -| SupContrast | ResNet50 | Supervised | Contrastive | 96.0 | +| SupContrast | ResNet50 | Supervised | Contrastive | 96.0 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 93.6 | Results on CIFAR-100: | |Arch | Setting | Loss | Accuracy(%) | |----------|:----:|:---:|:---:|:---:| | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 75.3 | -| SupContrast | ResNet50 | Supervised | Contrastive | 76.5 | +| SupContrast | ResNet50 | Supervised | Contrastive | 76.5 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 70.7 | Results on ImageNet (Stay tuned): | |Arch | Setting | Loss | Accuracy(%) | |----------|:----:|:---:|:---:|:---:| | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | - | -| SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) | +| SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) | | SimCLR | ResNet50 | Unsupervised | Contrastive | - | ## Running -You might use `CUDA_VISIBLE_DEVICES` to set proper number of GPUs, and/or switch to CIFAR100 by `--dataset cifar100`. +You might use `CUDA_VISIBLE_DEVICES` to set proper number of GPUs, and/or switch to CIFAR100 by `--dataset cifar100`. **(1) Standard Cross-Entropy** ``` python main_ce.py --batch_size 1024 \ --learning_rate 0.8 \ --cosine --syncBN \ ``` -**(2) Supervised Contrastive Learning** +**(2) Supervised Contrastive Learning** Pretraining stage: ``` python main_supcon.py --batch_size 1024 \ @@ -84,7 +84,7 @@ python main_linear.py --batch_size 512 \ --learning_rate 5 \ --ckpt /path/to/model.pth ``` -**(3) SimCLR** +**(3) SimCLR** Pretraining stage: ``` python main_supcon.py --batch_size 1024 \ @@ -104,7 +104,7 @@ python main_linear.py --batch_size 512 \ On custom dataset: ``` python main_supcon.py --batch_size 1024 \ - --learning_rate 0.5 \ + --learning_rate 0.5 \ --temp 0.1 --cosine \ --dataset path \ --data_folder ./path \ @@ -115,7 +115,7 @@ python main_supcon.py --batch_size 1024 \ The `--data_folder` must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension. -and +and ## t-SNE Visualization **(1) Standard Cross-Entropy** diff --git a/losses.py b/losses.py index 17117d42..c58c4a34 100644 --- a/losses.py +++ b/losses.py @@ -2,17 +2,15 @@ Author: Yonglong Tian (yonglong@mit.edu) Date: May 07, 2020 """ -from __future__ import print_function - import torch -import torch.nn as nn +from torch import nn class SupConLoss(nn.Module): """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. It also supports the unsupervised contrastive loss in SimCLR""" - def __init__(self, temperature=0.07, contrast_mode='all', - base_temperature=0.07): + + def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07): super(SupConLoss, self).__init__() self.temperature = temperature self.contrast_mode = contrast_mode @@ -31,44 +29,44 @@ def forward(self, features, labels=None, mask=None): Returns: A loss scalar. """ - device = (torch.device('cuda') - if features.is_cuda - else torch.device('cpu')) + device = torch.device("cuda") if features.is_cuda else torch.device("cpu") if len(features.shape) < 3: - raise ValueError('`features` needs to be [bsz, n_views, ...],' - 'at least 3 dimensions are required') + raise ValueError( + "`features` needs to be [bsz, n_views, ...]," + "at least 3 dimensions are required" + ) if len(features.shape) > 3: features = features.view(features.shape[0], features.shape[1], -1) batch_size = features.shape[0] if labels is not None and mask is not None: - raise ValueError('Cannot define both `labels` and `mask`') + raise ValueError("Cannot define both `labels` and `mask`") elif labels is None and mask is None: mask = torch.eye(batch_size, dtype=torch.float32).to(device) elif labels is not None: labels = labels.contiguous().view(-1, 1) if labels.shape[0] != batch_size: - raise ValueError('Num of labels does not match num of features') + raise ValueError("Num of labels does not match num of features") mask = torch.eq(labels, labels.T).float().to(device) else: mask = mask.float().to(device) contrast_count = features.shape[1] contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) - if self.contrast_mode == 'one': + if self.contrast_mode == "one": anchor_feature = features[:, 0] anchor_count = 1 - elif self.contrast_mode == 'all': + elif self.contrast_mode == "all": anchor_feature = contrast_feature anchor_count = contrast_count else: - raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) + raise ValueError("Unknown mode: {}".format(self.contrast_mode)) # compute logits anchor_dot_contrast = torch.div( - torch.matmul(anchor_feature, contrast_feature.T), - self.temperature) + torch.matmul(anchor_feature, contrast_feature.T), self.temperature + ) # for numerical stability logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) logits = anchor_dot_contrast - logits_max.detach() @@ -80,7 +78,7 @@ def forward(self, features, labels=None, mask=None): torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(device), - 0 + 0, ) mask = mask * logits_mask @@ -92,7 +90,7 @@ def forward(self, features, labels=None, mask=None): mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) # loss - loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos loss = loss.view(anchor_count, batch_size).mean() return loss diff --git a/main_ce.py b/main_ce.py index 29573d30..5c0f1cb1 100644 --- a/main_ce.py +++ b/main_ce.py @@ -1,168 +1,184 @@ -from __future__ import print_function - +import argparse +import math import os import sys -import argparse import time -import math -import tensorboard_logger as tb_logger import torch -import torch.backends.cudnn as cudnn -from torchvision import transforms, datasets +from torch.backends import cudnn +from torchvision import datasets, transforms -from util import AverageMeter -from util import adjust_learning_rate, warmup_learning_rate, accuracy -from util import set_optimizer, save_model from networks.resnet_big import SupCEResNet - -try: - import apex - from apex import amp, optimizers -except ImportError: - pass +from util import ( + AverageMeter, + accuracy, + adjust_learning_rate, + save_model, + set_optimizer, + warmup_learning_rate, +) def parse_option(): - parser = argparse.ArgumentParser('argument for training') - - parser.add_argument('--print_freq', type=int, default=10, - help='print frequency') - parser.add_argument('--save_freq', type=int, default=50, - help='save frequency') - parser.add_argument('--batch_size', type=int, default=256, - help='batch_size') - parser.add_argument('--num_workers', type=int, default=16, - help='num of workers to use') - parser.add_argument('--epochs', type=int, default=500, - help='number of training epochs') + parser = argparse.ArgumentParser("argument for training") + + parser.add_argument("--print_freq", type=int, default=10, help="print frequency") + parser.add_argument("--save_freq", type=int, default=50, help="save frequency") + parser.add_argument("--batch_size", type=int, default=256, help="batch_size") + parser.add_argument( + "--num_workers", type=int, default=16, help="num of workers to use" + ) + parser.add_argument( + "--epochs", type=int, default=500, help="number of training epochs" + ) # optimization - parser.add_argument('--learning_rate', type=float, default=0.2, - help='learning rate') - parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450', - help='where to decay lr, can be a list') - parser.add_argument('--lr_decay_rate', type=float, default=0.1, - help='decay rate for learning rate') - parser.add_argument('--weight_decay', type=float, default=1e-4, - help='weight decay') - parser.add_argument('--momentum', type=float, default=0.9, - help='momentum') + parser.add_argument( + "--learning_rate", type=float, default=0.2, help="learning rate" + ) + parser.add_argument( + "--lr_decay_epochs", + type=str, + default="350,400,450", + help="where to decay lr, can be a list", + ) + parser.add_argument( + "--lr_decay_rate", type=float, default=0.1, help="decay rate for learning rate" + ) + parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay") + parser.add_argument("--momentum", type=float, default=0.9, help="momentum") # model dataset - parser.add_argument('--model', type=str, default='resnet50') - parser.add_argument('--dataset', type=str, default='cifar10', - choices=['cifar10', 'cifar100'], help='dataset') + parser.add_argument("--model", type=str, default="resnet50") + parser.add_argument( + "--dataset", + type=str, + default="cifar10", + choices=["cifar10", "cifar100"], + help="dataset", + ) # other setting - parser.add_argument('--cosine', action='store_true', - help='using cosine annealing') - parser.add_argument('--syncBN', action='store_true', - help='using synchronized batch normalization') - parser.add_argument('--warm', action='store_true', - help='warm-up for large batch training') - parser.add_argument('--trial', type=str, default='0', - help='id for recording multiple runs') + parser.add_argument("--cosine", action="store_true", help="using cosine annealing") + parser.add_argument( + "--warm", action="store_true", help="warm-up for large batch training" + ) + parser.add_argument( + "--trial", type=str, default="0", help="id for recording multiple runs" + ) opt = parser.parse_args() # set the path according to the environment - opt.data_folder = './datasets/' - opt.model_path = './save/SupCon/{}_models'.format(opt.dataset) - opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset) + opt.data_folder = "./datasets/" + opt.model_path = "./save/SupCon/{}_models".format(opt.dataset) - iterations = opt.lr_decay_epochs.split(',') + iterations = opt.lr_decay_epochs.split(",") opt.lr_decay_epochs = list([]) for it in iterations: opt.lr_decay_epochs.append(int(it)) - opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.\ - format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, - opt.batch_size, opt.trial) + opt.model_name = "SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}".format( + opt.dataset, + opt.model, + opt.learning_rate, + opt.weight_decay, + opt.batch_size, + opt.trial, + ) if opt.cosine: - opt.model_name = '{}_cosine'.format(opt.model_name) + opt.model_name = "{}_cosine".format(opt.model_name) # warm-up for large-batch training, if opt.batch_size > 256: opt.warm = True if opt.warm: - opt.model_name = '{}_warm'.format(opt.model_name) + opt.model_name = "{}_warm".format(opt.model_name) opt.warmup_from = 0.01 opt.warm_epochs = 10 if opt.cosine: - eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) - opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( - 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 + eta_min = opt.learning_rate * (opt.lr_decay_rate**3) + opt.warmup_to = ( + eta_min + + (opt.learning_rate - eta_min) + * (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) + / 2 + ) else: opt.warmup_to = opt.learning_rate - opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) - if not os.path.isdir(opt.tb_folder): - os.makedirs(opt.tb_folder) - opt.save_folder = os.path.join(opt.model_path, opt.model_name) if not os.path.isdir(opt.save_folder): os.makedirs(opt.save_folder) - if opt.dataset == 'cifar10': + if opt.dataset == "cifar10": opt.n_cls = 10 - elif opt.dataset == 'cifar100': + elif opt.dataset == "cifar100": opt.n_cls = 100 else: - raise ValueError('dataset not supported: {}'.format(opt.dataset)) + raise ValueError("dataset not supported: {}".format(opt.dataset)) return opt def set_loader(opt): # construct data loader - if opt.dataset == 'cifar10': + if opt.dataset == "cifar10": mean = (0.4914, 0.4822, 0.4465) std = (0.2023, 0.1994, 0.2010) - elif opt.dataset == 'cifar100': + elif opt.dataset == "cifar100": mean = (0.5071, 0.4867, 0.4408) std = (0.2675, 0.2565, 0.2761) else: - raise ValueError('dataset not supported: {}'.format(opt.dataset)) + raise ValueError("dataset not supported: {}".format(opt.dataset)) normalize = transforms.Normalize(mean=mean, std=std) - train_transform = transforms.Compose([ - transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ]) - - val_transform = transforms.Compose([ - transforms.ToTensor(), - normalize, - ]) - - if opt.dataset == 'cifar10': - train_dataset = datasets.CIFAR10(root=opt.data_folder, - transform=train_transform, - download=True) - val_dataset = datasets.CIFAR10(root=opt.data_folder, - train=False, - transform=val_transform) - elif opt.dataset == 'cifar100': - train_dataset = datasets.CIFAR100(root=opt.data_folder, - transform=train_transform, - download=True) - val_dataset = datasets.CIFAR100(root=opt.data_folder, - train=False, - transform=val_transform) + train_transform = transforms.Compose( + [ + transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ) + + val_transform = transforms.Compose( + [ + transforms.ToTensor(), + normalize, + ] + ) + + if opt.dataset == "cifar10": + train_dataset = datasets.CIFAR10( + root=opt.data_folder, transform=train_transform, download=True + ) + val_dataset = datasets.CIFAR10( + root=opt.data_folder, train=False, transform=val_transform + ) + elif opt.dataset == "cifar100": + train_dataset = datasets.CIFAR100( + root=opt.data_folder, transform=train_transform, download=True + ) + val_dataset = datasets.CIFAR100( + root=opt.data_folder, train=False, transform=val_transform + ) else: raise ValueError(opt.dataset) train_sampler = None train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), - num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) + train_dataset, + batch_size=opt.batch_size, + shuffle=(train_sampler is None), + num_workers=opt.num_workers, + pin_memory=True, + sampler=train_sampler, + ) val_loader = torch.utils.data.DataLoader( - val_dataset, batch_size=256, shuffle=False, - num_workers=8, pin_memory=True) + val_dataset, batch_size=256, shuffle=False, num_workers=8, pin_memory=True + ) return train_loader, val_loader @@ -171,10 +187,6 @@ def set_model(opt): model = SupCEResNet(name=opt.model, num_classes=opt.n_cls) criterion = torch.nn.CrossEntropyLoss() - # enable synchronized Batch Normalization - if opt.syncBN: - model = apex.parallel.convert_syncbn_model(model) - if torch.cuda.is_available(): if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) @@ -225,13 +237,21 @@ def train(train_loader, model, criterion, optimizer, epoch, opt): # print info if (idx + 1) % opt.print_freq == 0: - print('Train: [{0}][{1}/{2}]\t' - 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'loss {loss.val:.3f} ({loss.avg:.3f})\t' - 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( - epoch, idx + 1, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses, top1=top1)) + print( + "Train: [{0}][{1}/{2}]\t" + "BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + "DT {data_time.val:.3f} ({data_time.avg:.3f})\t" + "loss {loss.val:.3f} ({loss.avg:.3f})\t" + "Acc@1 {top1.val:.3f} ({top1.avg:.3f})".format( + epoch, + idx + 1, + len(train_loader), + batch_time=batch_time, + data_time=data_time, + loss=losses, + top1=top1, + ) + ) sys.stdout.flush() return losses.avg, top1.avg @@ -266,14 +286,20 @@ def validate(val_loader, model, criterion, opt): end = time.time() if idx % opt.print_freq == 0: - print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( - idx, len(val_loader), batch_time=batch_time, - loss=losses, top1=top1)) - - print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) + print( + "Test: [{0}/{1}]\t" + "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + "Loss {loss.val:.4f} ({loss.avg:.4f})\t" + "Acc@1 {top1.val:.3f} ({top1.avg:.3f})".format( + idx, + len(val_loader), + batch_time=batch_time, + loss=losses, + top1=top1, + ) + ) + + print(" * Acc@1 {top1.avg:.3f}".format(top1=top1)) return losses.avg, top1.avg @@ -290,9 +316,6 @@ def main(): # build optimizer optimizer = set_optimizer(opt, model) - # tensorboard - logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) - # training routine for epoch in range(1, opt.epochs + 1): adjust_learning_rate(opt, optimizer, epoch) @@ -301,33 +324,33 @@ def main(): time1 = time.time() loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt) time2 = time.time() - print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) + print("epoch {}, total time {:.2f}".format(epoch, time2 - time1)) # tensorboard logger - logger.log_value('train_loss', loss, epoch) - logger.log_value('train_acc', train_acc, epoch) - logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) + print(f"epoch {epoch} train_loss: {loss}") + print(f"\ttrain_acc: {train_acc}") + print(f"\tlearning_rate {optimizer.param_groups[0]['lr']}") # evaluation loss, val_acc = validate(val_loader, model, criterion, opt) - logger.log_value('val_loss', loss, epoch) - logger.log_value('val_acc', val_acc, epoch) + print(f"epoch {epoch} val_loss: {loss}") + print(f"\tval_acc: {val_acc}") if val_acc > best_acc: best_acc = val_acc if epoch % opt.save_freq == 0: save_file = os.path.join( - opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) + opt.save_folder, "ckpt_epoch_{epoch}.pth".format(epoch=epoch) + ) save_model(model, optimizer, opt, epoch, save_file) # save the last model - save_file = os.path.join( - opt.save_folder, 'last.pth') + save_file = os.path.join(opt.save_folder, "last.pth") save_model(model, optimizer, opt, opt.epochs, save_file) - print('best accuracy: {:.2f}'.format(best_acc)) + print("best accuracy: {:.2f}".format(best_acc)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/main_linear.py b/main_linear.py index 1b5e7960..579a5eef 100644 --- a/main_linear.py +++ b/main_linear.py @@ -1,101 +1,110 @@ -from __future__ import print_function - -import sys import argparse -import time import math +import sys +import time import torch -import torch.backends.cudnn as cudnn +from torch.backends import cudnn from main_ce import set_loader -from util import AverageMeter -from util import adjust_learning_rate, warmup_learning_rate, accuracy -from util import set_optimizer -from networks.resnet_big import SupConResNet, LinearClassifier - -try: - import apex - from apex import amp, optimizers -except ImportError: - pass +from networks.resnet_big import LinearClassifier, SupConResNet +from util import ( + AverageMeter, + accuracy, + adjust_learning_rate, + set_optimizer, + warmup_learning_rate, +) def parse_option(): - parser = argparse.ArgumentParser('argument for training') - - parser.add_argument('--print_freq', type=int, default=10, - help='print frequency') - parser.add_argument('--save_freq', type=int, default=50, - help='save frequency') - parser.add_argument('--batch_size', type=int, default=256, - help='batch_size') - parser.add_argument('--num_workers', type=int, default=16, - help='num of workers to use') - parser.add_argument('--epochs', type=int, default=100, - help='number of training epochs') + parser = argparse.ArgumentParser("argument for training") + + parser.add_argument("--print_freq", type=int, default=10, help="print frequency") + parser.add_argument("--save_freq", type=int, default=50, help="save frequency") + parser.add_argument("--batch_size", type=int, default=256, help="batch_size") + parser.add_argument( + "--num_workers", type=int, default=16, help="num of workers to use" + ) + parser.add_argument( + "--epochs", type=int, default=100, help="number of training epochs" + ) # optimization - parser.add_argument('--learning_rate', type=float, default=0.1, - help='learning rate') - parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90', - help='where to decay lr, can be a list') - parser.add_argument('--lr_decay_rate', type=float, default=0.2, - help='decay rate for learning rate') - parser.add_argument('--weight_decay', type=float, default=0, - help='weight decay') - parser.add_argument('--momentum', type=float, default=0.9, - help='momentum') + parser.add_argument( + "--learning_rate", type=float, default=0.1, help="learning rate" + ) + parser.add_argument( + "--lr_decay_epochs", + type=str, + default="60,75,90", + help="where to decay lr, can be a list", + ) + parser.add_argument( + "--lr_decay_rate", type=float, default=0.2, help="decay rate for learning rate" + ) + parser.add_argument("--weight_decay", type=float, default=0, help="weight decay") + parser.add_argument("--momentum", type=float, default=0.9, help="momentum") # model dataset - parser.add_argument('--model', type=str, default='resnet50') - parser.add_argument('--dataset', type=str, default='cifar10', - choices=['cifar10', 'cifar100'], help='dataset') + parser.add_argument("--model", type=str, default="resnet50") + parser.add_argument( + "--dataset", + type=str, + default="cifar10", + choices=["cifar10", "cifar100"], + help="dataset", + ) # other setting - parser.add_argument('--cosine', action='store_true', - help='using cosine annealing') - parser.add_argument('--warm', action='store_true', - help='warm-up for large batch training') + parser.add_argument("--cosine", action="store_true", help="using cosine annealing") + parser.add_argument( + "--warm", action="store_true", help="warm-up for large batch training" + ) - parser.add_argument('--ckpt', type=str, default='', - help='path to pre-trained model') + parser.add_argument( + "--ckpt", type=str, default="", help="path to pre-trained model" + ) opt = parser.parse_args() # set the path according to the environment - opt.data_folder = './datasets/' + opt.data_folder = "./datasets/" - iterations = opt.lr_decay_epochs.split(',') + iterations = opt.lr_decay_epochs.split(",") opt.lr_decay_epochs = list([]) for it in iterations: opt.lr_decay_epochs.append(int(it)) - opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\ - format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, - opt.batch_size) + opt.model_name = "{}_{}_lr_{}_decay_{}_bsz_{}".format( + opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, opt.batch_size + ) if opt.cosine: - opt.model_name = '{}_cosine'.format(opt.model_name) + opt.model_name = "{}_cosine".format(opt.model_name) # warm-up for large-batch training, if opt.warm: - opt.model_name = '{}_warm'.format(opt.model_name) + opt.model_name = "{}_warm".format(opt.model_name) opt.warmup_from = 0.01 opt.warm_epochs = 10 if opt.cosine: - eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) - opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( - 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 + eta_min = opt.learning_rate * (opt.lr_decay_rate**3) + opt.warmup_to = ( + eta_min + + (opt.learning_rate - eta_min) + * (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) + / 2 + ) else: opt.warmup_to = opt.learning_rate - if opt.dataset == 'cifar10': + if opt.dataset == "cifar10": opt.n_cls = 10 - elif opt.dataset == 'cifar100': + elif opt.dataset == "cifar100": opt.n_cls = 100 else: - raise ValueError('dataset not supported: {}'.format(opt.dataset)) + raise ValueError("dataset not supported: {}".format(opt.dataset)) return opt @@ -106,8 +115,8 @@ def set_model(opt): classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls) - ckpt = torch.load(opt.ckpt, map_location='cpu') - state_dict = ckpt['model'] + ckpt = torch.load(opt.ckpt, map_location="cpu") + state_dict = ckpt["model"] if torch.cuda.is_available(): if torch.cuda.device_count() > 1: @@ -125,7 +134,7 @@ def set_model(opt): model.load_state_dict(state_dict) else: - raise NotImplementedError('This code requires GPU') + raise NotImplementedError("This code requires GPU") return model, classifier, criterion @@ -173,13 +182,21 @@ def train(train_loader, model, classifier, criterion, optimizer, epoch, opt): # print info if (idx + 1) % opt.print_freq == 0: - print('Train: [{0}][{1}/{2}]\t' - 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'loss {loss.val:.3f} ({loss.avg:.3f})\t' - 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( - epoch, idx + 1, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses, top1=top1)) + print( + "Train: [{0}][{1}/{2}]\t" + "BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + "DT {data_time.val:.3f} ({data_time.avg:.3f})\t" + "loss {loss.val:.3f} ({loss.avg:.3f})\t" + "Acc@1 {top1.val:.3f} ({top1.avg:.3f})".format( + epoch, + idx + 1, + len(train_loader), + batch_time=batch_time, + data_time=data_time, + loss=losses, + top1=top1, + ) + ) sys.stdout.flush() return losses.avg, top1.avg @@ -215,14 +232,20 @@ def validate(val_loader, model, classifier, criterion, opt): end = time.time() if idx % opt.print_freq == 0: - print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( - idx, len(val_loader), batch_time=batch_time, - loss=losses, top1=top1)) - - print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) + print( + "Test: [{0}/{1}]\t" + "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + "Loss {loss.val:.4f} ({loss.avg:.4f})\t" + "Acc@1 {top1.val:.3f} ({top1.avg:.3f})".format( + idx, + len(val_loader), + batch_time=batch_time, + loss=losses, + top1=top1, + ) + ) + + print(" * Acc@1 {top1.avg:.3f}".format(top1=top1)) return losses.avg, top1.avg @@ -245,19 +268,23 @@ def main(): # train for one epoch time1 = time.time() - loss, acc = train(train_loader, model, classifier, criterion, - optimizer, epoch, opt) + loss, acc = train( + train_loader, model, classifier, criterion, optimizer, epoch, opt + ) time2 = time.time() - print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format( - epoch, time2 - time1, acc)) + print( + "Train epoch {}, total time {:.2f}, accuracy:{:.2f}".format( + epoch, time2 - time1, acc + ) + ) # eval for one epoch loss, val_acc = validate(val_loader, model, classifier, criterion, opt) if val_acc > best_acc: best_acc = val_acc - print('best accuracy: {:.2f}'.format(best_acc)) + print("best accuracy: {:.2f}".format(best_acc)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/main_supcon.py b/main_supcon.py index ea6a625f..1372f184 100644 --- a/main_supcon.py +++ b/main_supcon.py @@ -1,126 +1,149 @@ -from __future__ import print_function - +import argparse +import math import os import sys -import argparse import time -import math -import tensorboard_logger as tb_logger import torch -import torch.backends.cudnn as cudnn -from torchvision import transforms, datasets +from torch.backends import cudnn +from torchvision import datasets, transforms -from util import TwoCropTransform, AverageMeter -from util import adjust_learning_rate, warmup_learning_rate -from util import set_optimizer, save_model -from networks.resnet_big import SupConResNet from losses import SupConLoss - -try: - import apex - from apex import amp, optimizers -except ImportError: - pass +from networks.resnet_big import SupConResNet +from util import ( + AverageMeter, + TwoCropTransform, + adjust_learning_rate, + save_model, + set_optimizer, + warmup_learning_rate, +) def parse_option(): - parser = argparse.ArgumentParser('argument for training') - - parser.add_argument('--print_freq', type=int, default=10, - help='print frequency') - parser.add_argument('--save_freq', type=int, default=50, - help='save frequency') - parser.add_argument('--batch_size', type=int, default=256, - help='batch_size') - parser.add_argument('--num_workers', type=int, default=16, - help='num of workers to use') - parser.add_argument('--epochs', type=int, default=1000, - help='number of training epochs') + parser = argparse.ArgumentParser("argument for training") + + parser.add_argument("--print_freq", type=int, default=10, help="print frequency") + parser.add_argument("--save_freq", type=int, default=50, help="save frequency") + parser.add_argument("--batch_size", type=int, default=256, help="batch_size") + parser.add_argument( + "--num_workers", type=int, default=16, help="num of workers to use" + ) + parser.add_argument( + "--epochs", type=int, default=1000, help="number of training epochs" + ) # optimization - parser.add_argument('--learning_rate', type=float, default=0.05, - help='learning rate') - parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900', - help='where to decay lr, can be a list') - parser.add_argument('--lr_decay_rate', type=float, default=0.1, - help='decay rate for learning rate') - parser.add_argument('--weight_decay', type=float, default=1e-4, - help='weight decay') - parser.add_argument('--momentum', type=float, default=0.9, - help='momentum') + parser.add_argument( + "--learning_rate", type=float, default=0.05, help="learning rate" + ) + parser.add_argument( + "--lr_decay_epochs", + type=str, + default="700,800,900", + help="where to decay lr, can be a list", + ) + parser.add_argument( + "--lr_decay_rate", type=float, default=0.1, help="decay rate for learning rate" + ) + parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay") + parser.add_argument("--momentum", type=float, default=0.9, help="momentum") # model dataset - parser.add_argument('--model', type=str, default='resnet50') - parser.add_argument('--dataset', type=str, default='cifar10', - choices=['cifar10', 'cifar100', 'path'], help='dataset') - parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple') - parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple') - parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset') - parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop') + parser.add_argument("--model", type=str, default="resnet50") + parser.add_argument( + "--dataset", + type=str, + default="cifar10", + choices=["cifar10", "cifar100", "path"], + help="dataset", + ) + parser.add_argument( + "--mean", type=str, help="mean of dataset in path in form of str tuple" + ) + parser.add_argument( + "--std", type=str, help="std of dataset in path in form of str tuple" + ) + parser.add_argument( + "--data_folder", type=str, default=None, help="path to custom dataset" + ) + parser.add_argument( + "--size", type=int, default=32, help="parameter for RandomResizedCrop" + ) # method - parser.add_argument('--method', type=str, default='SupCon', - choices=['SupCon', 'SimCLR'], help='choose method') + parser.add_argument( + "--method", + type=str, + default="SupCon", + choices=["SupCon", "SimCLR"], + help="choose method", + ) # temperature - parser.add_argument('--temp', type=float, default=0.07, - help='temperature for loss function') + parser.add_argument( + "--temp", type=float, default=0.07, help="temperature for loss function" + ) # other setting - parser.add_argument('--cosine', action='store_true', - help='using cosine annealing') - parser.add_argument('--syncBN', action='store_true', - help='using synchronized batch normalization') - parser.add_argument('--warm', action='store_true', - help='warm-up for large batch training') - parser.add_argument('--trial', type=str, default='0', - help='id for recording multiple runs') + parser.add_argument("--cosine", action="store_true", help="using cosine annealing") + parser.add_argument( + "--warm", action="store_true", help="warm-up for large batch training" + ) + parser.add_argument( + "--trial", type=str, default="0", help="id for recording multiple runs" + ) opt = parser.parse_args() # check if dataset is path that passed required arguments - if opt.dataset == 'path': - assert opt.data_folder is not None \ - and opt.mean is not None \ - and opt.std is not None + if opt.dataset == "path": + assert ( + opt.data_folder is not None and opt.mean is not None and opt.std is not None + ) # set the path according to the environment if opt.data_folder is None: - opt.data_folder = './datasets/' - opt.model_path = './save/SupCon/{}_models'.format(opt.dataset) - opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset) + opt.data_folder = "./datasets/" + opt.model_path = "./save/SupCon/{}_models".format(opt.dataset) - iterations = opt.lr_decay_epochs.split(',') + iterations = opt.lr_decay_epochs.split(",") opt.lr_decay_epochs = list([]) for it in iterations: opt.lr_decay_epochs.append(int(it)) - opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\ - format(opt.method, opt.dataset, opt.model, opt.learning_rate, - opt.weight_decay, opt.batch_size, opt.temp, opt.trial) + opt.model_name = "{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}".format( + opt.method, + opt.dataset, + opt.model, + opt.learning_rate, + opt.weight_decay, + opt.batch_size, + opt.temp, + opt.trial, + ) if opt.cosine: - opt.model_name = '{}_cosine'.format(opt.model_name) + opt.model_name = "{}_cosine".format(opt.model_name) # warm-up for large-batch training, if opt.batch_size > 256: opt.warm = True if opt.warm: - opt.model_name = '{}_warm'.format(opt.model_name) + opt.model_name = "{}_warm".format(opt.model_name) opt.warmup_from = 0.01 opt.warm_epochs = 10 if opt.cosine: - eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) - opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( - 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 + eta_min = opt.learning_rate * (opt.lr_decay_rate**3) + opt.warmup_to = ( + eta_min + + (opt.learning_rate - eta_min) + * (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) + / 2 + ) else: opt.warmup_to = opt.learning_rate - opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) - if not os.path.isdir(opt.tb_folder): - os.makedirs(opt.tb_folder) - opt.save_folder = os.path.join(opt.model_path, opt.model_name) if not os.path.isdir(opt.save_folder): os.makedirs(opt.save_folder) @@ -130,48 +153,58 @@ def parse_option(): def set_loader(opt): # construct data loader - if opt.dataset == 'cifar10': + if opt.dataset == "cifar10": mean = (0.4914, 0.4822, 0.4465) std = (0.2023, 0.1994, 0.2010) - elif opt.dataset == 'cifar100': + elif opt.dataset == "cifar100": mean = (0.5071, 0.4867, 0.4408) std = (0.2675, 0.2565, 0.2761) - elif opt.dataset == 'path': + elif opt.dataset == "path": mean = eval(opt.mean) std = eval(opt.std) else: - raise ValueError('dataset not supported: {}'.format(opt.dataset)) + raise ValueError("dataset not supported: {}".format(opt.dataset)) normalize = transforms.Normalize(mean=mean, std=std) - train_transform = transforms.Compose([ - transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)), - transforms.RandomHorizontalFlip(), - transforms.RandomApply([ - transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) - ], p=0.8), - transforms.RandomGrayscale(p=0.2), - transforms.ToTensor(), - normalize, - ]) - - if opt.dataset == 'cifar10': - train_dataset = datasets.CIFAR10(root=opt.data_folder, - transform=TwoCropTransform(train_transform), - download=True) - elif opt.dataset == 'cifar100': - train_dataset = datasets.CIFAR100(root=opt.data_folder, - transform=TwoCropTransform(train_transform), - download=True) - elif opt.dataset == 'path': - train_dataset = datasets.ImageFolder(root=opt.data_folder, - transform=TwoCropTransform(train_transform)) + train_transform = transforms.Compose( + [ + transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.ToTensor(), + normalize, + ] + ) + + if opt.dataset == "cifar10": + train_dataset = datasets.CIFAR10( + root=opt.data_folder, + transform=TwoCropTransform(train_transform), + download=True, + ) + elif opt.dataset == "cifar100": + train_dataset = datasets.CIFAR100( + root=opt.data_folder, + transform=TwoCropTransform(train_transform), + download=True, + ) + elif opt.dataset == "path": + train_dataset = datasets.ImageFolder( + root=opt.data_folder, transform=TwoCropTransform(train_transform) + ) else: raise ValueError(opt.dataset) train_sampler = None train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), - num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) + train_dataset, + batch_size=opt.batch_size, + shuffle=(train_sampler is None), + num_workers=opt.num_workers, + pin_memory=True, + sampler=train_sampler, + ) return train_loader @@ -180,10 +213,6 @@ def set_model(opt): model = SupConResNet(name=opt.model) criterion = SupConLoss(temperature=opt.temp) - # enable synchronized Batch Normalization - if opt.syncBN: - model = apex.parallel.convert_syncbn_model(model) - if torch.cuda.is_available(): if torch.cuda.device_count() > 1: model.encoder = torch.nn.DataParallel(model.encoder) @@ -219,13 +248,12 @@ def train(train_loader, model, criterion, optimizer, epoch, opt): features = model(images) f1, f2 = torch.split(features, [bsz, bsz], dim=0) features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) - if opt.method == 'SupCon': + if opt.method == "SupCon": loss = criterion(features, labels) - elif opt.method == 'SimCLR': + elif opt.method == "SimCLR": loss = criterion(features) else: - raise ValueError('contrastive method not supported: {}'. - format(opt.method)) + raise ValueError("contrastive method not supported: {}".format(opt.method)) # update metric losses.update(loss.item(), bsz) @@ -241,12 +269,19 @@ def train(train_loader, model, criterion, optimizer, epoch, opt): # print info if (idx + 1) % opt.print_freq == 0: - print('Train: [{0}][{1}/{2}]\t' - 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'loss {loss.val:.3f} ({loss.avg:.3f})'.format( - epoch, idx + 1, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses)) + print( + "Train: [{0}][{1}/{2}]\t" + "BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + "DT {data_time.val:.3f} ({data_time.avg:.3f})\t" + "loss {loss.val:.3f} ({loss.avg:.3f})".format( + epoch, + idx + 1, + len(train_loader), + batch_time=batch_time, + data_time=data_time, + loss=losses, + ) + ) sys.stdout.flush() return losses.avg @@ -264,9 +299,6 @@ def main(): # build optimizer optimizer = set_optimizer(opt, model) - # tensorboard - logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) - # training routine for epoch in range(1, opt.epochs + 1): adjust_learning_rate(opt, optimizer, epoch) @@ -275,22 +307,22 @@ def main(): time1 = time.time() loss = train(train_loader, model, criterion, optimizer, epoch, opt) time2 = time.time() - print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) + print("epoch {}, total time {:.2f}".format(epoch, time2 - time1)) # tensorboard logger - logger.log_value('loss', loss, epoch) - logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) + print(f"epoch {epoch} loss: {loss}") + print(f"learning_rate: {optimizer.param_groups[0]['lr']}") if epoch % opt.save_freq == 0: save_file = os.path.join( - opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) + opt.save_folder, "ckpt_epoch_{epoch}.pth".format(epoch=epoch) + ) save_model(model, optimizer, opt, epoch, save_file) # save the last model - save_file = os.path.join( - opt.save_folder, 'last.pth') + save_file = os.path.join(opt.save_folder, "last.pth") save_model(model, optimizer, opt, opt.epochs, save_file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/networks/resnet_big.py b/networks/resnet_big.py index 4f4e37df..2299e410 100644 --- a/networks/resnet_big.py +++ b/networks/resnet_big.py @@ -5,8 +5,8 @@ Adapted from: https://github.com/bearpaw/pytorch-classification """ import torch -import torch.nn as nn -import torch.nn.functional as F +from torch import nn +from torch.nn import functional as F class BasicBlock(nn.Module): @@ -15,16 +15,26 @@ class BasicBlock(nn.Module): def __init__(self, in_planes, planes, stride=1, is_last=False): super(BasicBlock, self).__init__() self.is_last = is_last - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion * planes) + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(self.expansion * planes), ) def forward(self, x): @@ -47,16 +57,26 @@ def __init__(self, in_planes, planes, stride=1, is_last=False): self.is_last = is_last self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) + self.conv3 = nn.Conv2d( + planes, self.expansion * planes, kernel_size=1, bias=False + ) self.bn3 = nn.BatchNorm2d(self.expansion * planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion * planes) + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(self.expansion * planes), ) def forward(self, x): @@ -77,8 +97,9 @@ def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): super(ResNet, self).__init__() self.in_planes = 64 - self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, - bias=False) + self.conv1 = nn.Conv2d( + in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) @@ -88,7 +109,7 @@ def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -141,15 +162,16 @@ def resnet101(**kwargs): model_dict = { - 'resnet18': [resnet18, 512], - 'resnet34': [resnet34, 512], - 'resnet50': [resnet50, 2048], - 'resnet101': [resnet101, 2048], + "resnet18": [resnet18, 512], + "resnet34": [resnet34, 512], + "resnet50": [resnet50, 2048], + "resnet101": [resnet101, 2048], } class LinearBatchNorm(nn.Module): """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose""" + def __init__(self, dim, affine=True): super(LinearBatchNorm, self).__init__() self.dim = dim @@ -164,21 +186,21 @@ def forward(self, x): class SupConResNet(nn.Module): """backbone + projection head""" - def __init__(self, name='resnet50', head='mlp', feat_dim=128): + + def __init__(self, name="resnet50", head="mlp", feat_dim=128): super(SupConResNet, self).__init__() model_fun, dim_in = model_dict[name] self.encoder = model_fun() - if head == 'linear': + if head == "linear": self.head = nn.Linear(dim_in, feat_dim) - elif head == 'mlp': + elif head == "mlp": self.head = nn.Sequential( nn.Linear(dim_in, dim_in), nn.ReLU(inplace=True), - nn.Linear(dim_in, feat_dim) + nn.Linear(dim_in, feat_dim), ) else: - raise NotImplementedError( - 'head not supported: {}'.format(head)) + raise NotImplementedError("head not supported: {}".format(head)) def forward(self, x): feat = self.encoder(x) @@ -188,7 +210,8 @@ def forward(self, x): class SupCEResNet(nn.Module): """encoder + classifier""" - def __init__(self, name='resnet50', num_classes=10): + + def __init__(self, name="resnet50", num_classes=10): super(SupCEResNet, self).__init__() model_fun, dim_in = model_dict[name] self.encoder = model_fun() @@ -200,7 +223,8 @@ def forward(self, x): class LinearClassifier(nn.Module): """Linear classifier""" - def __init__(self, name='resnet50', num_classes=10): + + def __init__(self, name="resnet50", num_classes=10): super(LinearClassifier, self).__init__() _, feat_dim = model_dict[name] self.fc = nn.Linear(feat_dim, num_classes) diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 00000000..bc072cc7 --- /dev/null +++ b/poetry.lock @@ -0,0 +1,569 @@ +# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. + +[[package]] +name = "certifi" +version = "2023.7.22" +description = "Python package for providing Mozilla's CA Bundle." +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, + {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.3.1" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset-normalizer-3.3.1.tar.gz", hash = "sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-win32.whl", hash = "sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-win32.whl", hash = "sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-win32.whl", hash = "sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-win32.whl", hash = "sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-win32.whl", hash = "sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-win32.whl", hash = "sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727"}, + {file = "charset_normalizer-3.3.1-py3-none-any.whl", hash = "sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708"}, +] + +[[package]] +name = "cmake" +version = "3.27.7" +description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "cmake-3.27.7-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:d582ef3e9ff0bd113581c1a32e881d1c2f9a34d2de76c93324a28593a76433db"}, + {file = "cmake-3.27.7-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:8056c99e371ff57229df2068364d7c32fea716cb53b4675f639edfb62663decf"}, + {file = "cmake-3.27.7-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:68983b09de633fc1ce6ab6bce9a25bfa181e41598e7c6bc0a6c0108773ee01cb"}, + {file = "cmake-3.27.7-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bd1e1fa4fc8de7605c663d9408dceb649112f855aab05cca31fdb72e4d78364"}, + {file = "cmake-3.27.7-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:c981aafcca2cd7210bd210ec75710c0f34e1fde1998cdcab812e4133e3ab615d"}, + {file = "cmake-3.27.7-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1b9067ce0251cba3d4c018f2e1577ba9078e9c1eff6ad607ad5ce867843d4571"}, + {file = "cmake-3.27.7-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b8a2fcb619b89d1cce7b52828316de9a1f27f0c90c2e39d1eae886428c8ee8c6"}, + {file = "cmake-3.27.7-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:499b38c77d52fb1964dbb38d0228fed246263a181939a8e753fde8ca227c8e1e"}, + {file = "cmake-3.27.7-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:2fb48c780f1a6a3d19e785ebbb754be79d369e25a1cb81043fab049e709564da"}, + {file = "cmake-3.27.7-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:7bf96237ba11ce2437dc5e071d96b510120a1be4708c631a64b2f38fb46bbd77"}, + {file = "cmake-3.27.7-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:009058bdf4f488709f38eaa5dd0ef0f89c6b9c6b6edd9d5b475a308ef75f80bb"}, + {file = "cmake-3.27.7-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:591f6b056527aefec009bc61a388776b2fc62444deb0038112a471031f61aeca"}, + {file = "cmake-3.27.7-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:bd40d46dbad3555d5b3ce054bef24b85f256b19139493773751ab6f2b71c1219"}, + {file = "cmake-3.27.7-py2.py3-none-win32.whl", hash = "sha256:bdbf0256f554f68c7b1d9740f5d059daf875b685c81a479cbe69038e84eb2fb9"}, + {file = "cmake-3.27.7-py2.py3-none-win_amd64.whl", hash = "sha256:810e592b606d05a3080a9c19ea839b13226f62cae447a22485b2365782f6b926"}, + {file = "cmake-3.27.7-py2.py3-none-win_arm64.whl", hash = "sha256:72289361866314f73be2ae63ddee224ff70223dcef9feb66d0072bf17e245564"}, + {file = "cmake-3.27.7.tar.gz", hash = "sha256:9f4a7c7be2a25de5901f045618f41b833ea6c0f647115201d38e4fdf7e2815bc"}, +] + +[package.extras] +test = ["coverage (>=4.2)", "flake8 (>=3.0.4)", "path.py (>=11.5.0)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)", "pytest-runner (>=2.9)", "pytest-virtualenv (>=1.7.0)", "scikit-build (>=0.10.0)", "setuptools (>=28.0.0)", "virtualenv (>=15.0.3)", "wheel"] + +[[package]] +name = "filelock" +version = "3.12.4" +description = "A platform independent file lock." +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"}, + {file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"}, +] + +[package.extras] +docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"] +typing = ["typing-extensions (>=4.7.1)"] + +[[package]] +name = "idna" +version = "3.4" +description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" +optional = false +python-versions = ">=3.5" +files = [ + {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, + {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, +] + +[[package]] +name = "jinja2" +version = "3.1.2" +description = "A very fast and expressive template engine." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, + {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + +[[package]] +name = "lit" +version = "17.0.3" +description = "A Software Testing Tool" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "lit-17.0.3.tar.gz", hash = "sha256:e6049032462be1e2928686cbd4a6cc5b3c545d83ecd078737fe79412c1f3fcc1"}, +] + +[[package]] +name = "markupsafe" +version = "2.1.3" +description = "Safely add untrusted strings to HTML/XML markup." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, + {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + +[[package]] +name = "networkx" +version = "3.2" +description = "Python package for creating and manipulating graphs and networks" +category = "main" +optional = false +python-versions = ">=3.9" +files = [ + {file = "networkx-3.2-py3-none-any.whl", hash = "sha256:8b25f564bd28f94ac821c58b04ae1a3109e73b001a7d476e4bb0d00d63706bf8"}, + {file = "networkx-3.2.tar.gz", hash = "sha256:bda29edf392d9bfa5602034c767d28549214ec45f620081f0b74dc036a1fbbc1"}, +] + +[package.extras] +default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + +[[package]] +name = "numpy" +version = "1.25.2" +description = "Fundamental package for array computing in Python" +category = "main" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.25.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:db3ccc4e37a6873045580d413fe79b68e47a681af8db2e046f1dacfa11f86eb3"}, + {file = "numpy-1.25.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:90319e4f002795ccfc9050110bbbaa16c944b1c37c0baeea43c5fb881693ae1f"}, + {file = "numpy-1.25.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfe4a913e29b418d096e696ddd422d8a5d13ffba4ea91f9f60440a3b759b0187"}, + {file = "numpy-1.25.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08f2e037bba04e707eebf4bc934f1972a315c883a9e0ebfa8a7756eabf9e357"}, + {file = "numpy-1.25.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bec1e7213c7cb00d67093247f8c4db156fd03075f49876957dca4711306d39c9"}, + {file = "numpy-1.25.2-cp310-cp310-win32.whl", hash = "sha256:7dc869c0c75988e1c693d0e2d5b26034644399dd929bc049db55395b1379e044"}, + {file = "numpy-1.25.2-cp310-cp310-win_amd64.whl", hash = "sha256:834b386f2b8210dca38c71a6e0f4fd6922f7d3fcff935dbe3a570945acb1b545"}, + {file = "numpy-1.25.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c5462d19336db4560041517dbb7759c21d181a67cb01b36ca109b2ae37d32418"}, + {file = "numpy-1.25.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c5652ea24d33585ea39eb6a6a15dac87a1206a692719ff45d53c5282e66d4a8f"}, + {file = "numpy-1.25.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d60fbae8e0019865fc4784745814cff1c421df5afee233db6d88ab4f14655a2"}, + {file = "numpy-1.25.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60e7f0f7f6d0eee8364b9a6304c2845b9c491ac706048c7e8cf47b83123b8dbf"}, + {file = "numpy-1.25.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bb33d5a1cf360304754913a350edda36d5b8c5331a8237268c48f91253c3a364"}, + {file = "numpy-1.25.2-cp311-cp311-win32.whl", hash = "sha256:5883c06bb92f2e6c8181df7b39971a5fb436288db58b5a1c3967702d4278691d"}, + {file = "numpy-1.25.2-cp311-cp311-win_amd64.whl", hash = "sha256:5c97325a0ba6f9d041feb9390924614b60b99209a71a69c876f71052521d42a4"}, + {file = "numpy-1.25.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b79e513d7aac42ae918db3ad1341a015488530d0bb2a6abcbdd10a3a829ccfd3"}, + {file = "numpy-1.25.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eb942bfb6f84df5ce05dbf4b46673ffed0d3da59f13635ea9b926af3deb76926"}, + {file = "numpy-1.25.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e0746410e73384e70d286f93abf2520035250aad8c5714240b0492a7302fdca"}, + {file = "numpy-1.25.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7806500e4f5bdd04095e849265e55de20d8cc4b661b038957354327f6d9b295"}, + {file = "numpy-1.25.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8b77775f4b7df768967a7c8b3567e309f617dd5e99aeb886fa14dc1a0791141f"}, + {file = "numpy-1.25.2-cp39-cp39-win32.whl", hash = "sha256:2792d23d62ec51e50ce4d4b7d73de8f67a2fd3ea710dcbc8563a51a03fb07b01"}, + {file = "numpy-1.25.2-cp39-cp39-win_amd64.whl", hash = "sha256:76b4115d42a7dfc5d485d358728cdd8719be33cc5ec6ec08632a5d6fca2ed380"}, + {file = "numpy-1.25.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1a1329e26f46230bf77b02cc19e900db9b52f398d6722ca853349a782d4cff55"}, + {file = "numpy-1.25.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c3abc71e8b6edba80a01a52e66d83c5d14433cbcd26a40c329ec7ed09f37901"}, + {file = "numpy-1.25.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:1b9735c27cea5d995496f46a8b1cd7b408b3f34b6d50459d9ac8fe3a20cc17bf"}, + {file = "numpy-1.25.2.tar.gz", hash = "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760"}, +] + +[[package]] +name = "pillow" +version = "10.1.0" +description = "Python Imaging Library (Fork)" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "Pillow-10.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1ab05f3db77e98f93964697c8efc49c7954b08dd61cff526b7f2531a22410106"}, + {file = "Pillow-10.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6932a7652464746fcb484f7fc3618e6503d2066d853f68a4bd97193a3996e273"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f63b5a68daedc54c7c3464508d8c12075e56dcfbd42f8c1bf40169061ae666"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0949b55eb607898e28eaccb525ab104b2d86542a85c74baf3a6dc24002edec2"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ae88931f93214777c7a3aa0a8f92a683f83ecde27f65a45f95f22d289a69e593"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b0eb01ca85b2361b09480784a7931fc648ed8b7836f01fb9241141b968feb1db"}, + {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d27b5997bdd2eb9fb199982bb7eb6164db0426904020dc38c10203187ae2ff2f"}, + {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7df5608bc38bd37ef585ae9c38c9cd46d7c81498f086915b0f97255ea60c2818"}, + {file = "Pillow-10.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:41f67248d92a5e0a2076d3517d8d4b1e41a97e2df10eb8f93106c89107f38b57"}, + {file = "Pillow-10.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1fb29c07478e6c06a46b867e43b0bcdb241b44cc52be9bc25ce5944eed4648e7"}, + {file = "Pillow-10.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2cdc65a46e74514ce742c2013cd4a2d12e8553e3a2563c64879f7c7e4d28bce7"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50d08cd0a2ecd2a8657bd3d82c71efd5a58edb04d9308185d66c3a5a5bed9610"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:062a1610e3bc258bff2328ec43f34244fcec972ee0717200cb1425214fe5b839"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:61f1a9d247317fa08a308daaa8ee7b3f760ab1809ca2da14ecc88ae4257d6172"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a646e48de237d860c36e0db37ecaecaa3619e6f3e9d5319e527ccbc8151df061"}, + {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:47e5bf85b80abc03be7455c95b6d6e4896a62f6541c1f2ce77a7d2bb832af262"}, + {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a92386125e9ee90381c3369f57a2a50fa9e6aa8b1cf1d9c4b200d41a7dd8e992"}, + {file = "Pillow-10.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f7c276c05a9767e877a0b4c5050c8bee6a6d960d7f0c11ebda6b99746068c2a"}, + {file = "Pillow-10.1.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:a89b8312d51715b510a4fe9fc13686283f376cfd5abca8cd1c65e4c76e21081b"}, + {file = "Pillow-10.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:00f438bb841382b15d7deb9a05cc946ee0f2c352653c7aa659e75e592f6fa17d"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d929a19f5469b3f4df33a3df2983db070ebb2088a1e145e18facbc28cae5b27"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a92109192b360634a4489c0c756364c0c3a2992906752165ecb50544c251312"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0248f86b3ea061e67817c47ecbe82c23f9dd5d5226200eb9090b3873d3ca32de"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9882a7451c680c12f232a422730f986a1fcd808da0fd428f08b671237237d651"}, + {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1c3ac5423c8c1da5928aa12c6e258921956757d976405e9467c5f39d1d577a4b"}, + {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:806abdd8249ba3953c33742506fe414880bad78ac25cc9a9b1c6ae97bedd573f"}, + {file = "Pillow-10.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:eaed6977fa73408b7b8a24e8b14e59e1668cfc0f4c40193ea7ced8e210adf996"}, + {file = "Pillow-10.1.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:fe1e26e1ffc38be097f0ba1d0d07fcade2bcfd1d023cda5b29935ae8052bd793"}, + {file = "Pillow-10.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7e3daa202beb61821c06d2517428e8e7c1aab08943e92ec9e5755c2fc9ba5e"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24fadc71218ad2b8ffe437b54876c9382b4a29e030a05a9879f615091f42ffc2"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa1d323703cfdac2036af05191b969b910d8f115cf53093125e4058f62012c9a"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:912e3812a1dbbc834da2b32299b124b5ddcb664ed354916fd1ed6f193f0e2d01"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7dbaa3c7de82ef37e7708521be41db5565004258ca76945ad74a8e998c30af8d"}, + {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9d7bc666bd8c5a4225e7ac71f2f9d12466ec555e89092728ea0f5c0c2422ea80"}, + {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:baada14941c83079bf84c037e2d8b7506ce201e92e3d2fa0d1303507a8538212"}, + {file = "Pillow-10.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:2ef6721c97894a7aa77723740a09547197533146fba8355e86d6d9a4a1056b14"}, + {file = "Pillow-10.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0a026c188be3b443916179f5d04548092e253beb0c3e2ee0a4e2cdad72f66099"}, + {file = "Pillow-10.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:04f6f6149f266a100374ca3cc368b67fb27c4af9f1cc8cb6306d849dcdf12616"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb40c011447712d2e19cc261c82655f75f32cb724788df315ed992a4d65696bb"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a8413794b4ad9719346cd9306118450b7b00d9a15846451549314a58ac42219"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c9aeea7b63edb7884b031a35305629a7593272b54f429a9869a4f63a1bf04c34"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b4005fee46ed9be0b8fb42be0c20e79411533d1fd58edabebc0dd24626882cfd"}, + {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4d0152565c6aa6ebbfb1e5d8624140a440f2b99bf7afaafbdbf6430426497f28"}, + {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d921bc90b1defa55c9917ca6b6b71430e4286fc9e44c55ead78ca1a9f9eba5f2"}, + {file = "Pillow-10.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cfe96560c6ce2f4c07d6647af2d0f3c54cc33289894ebd88cfbb3bcd5391e256"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:937bdc5a7f5343d1c97dc98149a0be7eb9704e937fe3dc7140e229ae4fc572a7"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c25762197144e211efb5f4e8ad656f36c8d214d390585d1d21281f46d556ba"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:afc8eef765d948543a4775f00b7b8c079b3321d6b675dde0d02afa2ee23000b4"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:883f216eac8712b83a63f41b76ddfb7b2afab1b74abbb413c5df6680f071a6b9"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b920e4d028f6442bea9a75b7491c063f0b9a3972520731ed26c83e254302eb1e"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c41d960babf951e01a49c9746f92c5a7e0d939d1652d7ba30f6b3090f27e412"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1fafabe50a6977ac70dfe829b2d5735fd54e190ab55259ec8aea4aaea412fa0b"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3b834f4b16173e5b92ab6566f0473bfb09f939ba14b23b8da1f54fa63e4b623f"}, + {file = "Pillow-10.1.0.tar.gz", hash = "sha256:e6bf8de6c36ed96c86ea3b6e1d5273c53f46ef518a062464cd7ef5dd2cf92e38"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "sympy" +version = "1.12" +description = "Computer algebra system (CAS) in Python" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, + {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, +] + +[package.dependencies] +mpmath = ">=0.19" + +[[package]] +name = "torch" +version = "2.0.0+cu118" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "torch-2.0.0%2Bcu118-cp310-cp310-linux_x86_64.whl", hash = "sha256:4b690e2b77f21073500c65d8bb9ea9656b8cb4e969f357370bbc992a3b074764"}, +] + +[package.dependencies] +filelock = "*" +jinja2 = "*" +networkx = "*" +sympy = "*" +triton = {version = "2.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +typing-extensions = "*" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] + +[package.source] +type = "url" +url = "https://download.pytorch.org/whl/cu118/torch-2.0.0%2Bcu118-cp310-cp310-linux_x86_64.whl" +[[package]] +name = "torchvision" +version = "0.15.0+cu118" +description = "image and video datasets and models for torch deep learning" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "torchvision-0.15.0%2Bcu118-cp310-cp310-linux_x86_64.whl", hash = "sha256:5a9614b080d31c1c7b23574a301114b28cb25d86b25b60ab85b2eaedd0b3e6e9"}, +] + +[package.dependencies] +numpy = "*" +pillow = ">=5.3.0,<8.3.0 || >=8.4.0" +requests = "*" +torch = "2.0.0+cu118" + +[package.extras] +scipy = ["scipy"] + +[package.source] +type = "url" +url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.0%2Bcu118-cp310-cp310-linux_x86_64.whl" +[[package]] +name = "triton" +version = "2.0.0" +description = "A language and compiler for custom Deep Learning operations" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"}, + {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"}, + {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"}, + {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"}, + {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"}, + {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"}, + {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"}, + {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"}, + {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"}, + {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, + {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, + {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, + {file = "triton-2.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fedce6a381901b1547e0e7e1f2546e4f65dca6d91e2d8a7305a2d1f5551895be"}, + {file = "triton-2.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75834f27926eab6c7f00ce73aaf1ab5bfb9bec6eb57ab7c0bfc0a23fac803b4c"}, + {file = "triton-2.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0117722f8c2b579cd429e0bee80f7731ae05f63fe8e9414acd9a679885fcbf42"}, + {file = "triton-2.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcd9be5d0c2e45d2b7e6ddc6da20112b6862d69741576f9c3dbaf941d745ecae"}, + {file = "triton-2.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42a0d2c3fc2eab4ba71384f2e785fbfd47aa41ae05fa58bf12cb31dcbd0aeceb"}, + {file = "triton-2.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52c47b72c72693198163ece9d90a721299e4fb3b8e24fd13141e384ad952724f"}, +] + +[package.dependencies] +cmake = "*" +filelock = "*" +lit = "*" +torch = "*" + +[package.extras] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + +[[package]] +name = "typing-extensions" +version = "4.8.0" +description = "Backported and Experimental Type Hints for Python 3.8+" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, + {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, +] + +[[package]] +name = "urllib3" +version = "2.0.7" +description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"}, + {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[metadata] +lock-version = "2.0" +python-versions = "^3.10" +content-hash = "049b8c18f0fcabe2117d06166110312bba4d11568489470249ed552eb4d1c13f" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..8ccf6bc6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[tool.poetry] +name = "supervised-contrastive-loss" +version = "0.1.0" +description = "" +authors = ["Sean (Seok-Won) Yi "] +readme = "README.md" +# packages = [{include = "supervised_contrastive_loss"}] + +[tool.poetry.dependencies] +python = "^3.10" +torch = {url = "https://download.pytorch.org/whl/cu118/torch-2.0.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} +torchvision = {url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/util.py b/util.py index b6323530..436156a4 100644 --- a/util.py +++ b/util.py @@ -1,13 +1,15 @@ from __future__ import print_function import math + import numpy as np import torch -import torch.optim as optim +from torch import optim class TwoCropTransform: """Create two crops of the same image""" + def __init__(self, transform): self.transform = transform @@ -17,6 +19,7 @@ def __call__(self, x): class AverageMeter(object): """Computes and stores the average and current value""" + def __init__(self): self.reset() @@ -53,43 +56,47 @@ def accuracy(output, target, topk=(1,)): def adjust_learning_rate(args, optimizer, epoch): lr = args.learning_rate if args.cosine: - eta_min = lr * (args.lr_decay_rate ** 3) - lr = eta_min + (lr - eta_min) * ( - 1 + math.cos(math.pi * epoch / args.epochs)) / 2 + eta_min = lr * (args.lr_decay_rate**3) + lr = ( + eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2 + ) else: steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) if steps > 0: - lr = lr * (args.lr_decay_rate ** steps) + lr = lr * (args.lr_decay_rate**steps) for param_group in optimizer.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): if args.warm and epoch <= args.warm_epochs: - p = (batch_id + (epoch - 1) * total_batches) / \ - (args.warm_epochs * total_batches) + p = (batch_id + (epoch - 1) * total_batches) / ( + args.warm_epochs * total_batches + ) lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) for param_group in optimizer.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr def set_optimizer(opt, model): - optimizer = optim.SGD(model.parameters(), - lr=opt.learning_rate, - momentum=opt.momentum, - weight_decay=opt.weight_decay) + optimizer = optim.SGD( + model.parameters(), + lr=opt.learning_rate, + momentum=opt.momentum, + weight_decay=opt.weight_decay, + ) return optimizer def save_model(model, optimizer, opt, epoch, save_file): - print('==> Saving...') + print("==> Saving...") state = { - 'opt': opt, - 'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'epoch': epoch, + "opt": opt, + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, } torch.save(state, save_file) del state