diff --git a/cocdownload.py b/cocdownload.py new file mode 100644 index 0000000..1a788fc --- /dev/null +++ b/cocdownload.py @@ -0,0 +1,13 @@ +import fiftyone as fo +import fiftyone.zoo as foz + +dataset = foz.load_zoo_dataset( + "coco/coco-2014", + splits=["validation", "test"], + label_types=["segmentations"], + max_samples=1000, + shuffle=True, + format=".jpg" +) +session = fo.launch_app(dataset) +session.dataset = dataset \ No newline at end of file diff --git a/coco_splits/split_-1.json b/coco_splits/split_-1.json new file mode 100644 index 0000000..d1771c5 --- /dev/null +++ b/coco_splits/split_-1.json @@ -0,0 +1 @@ +{"val": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]} \ No newline at end of file diff --git a/coco_splits/split_0.json b/coco_splits/split_0.json new file mode 100644 index 0000000..ec4d9eb --- /dev/null +++ b/coco_splits/split_0.json @@ -0,0 +1 @@ +{"val": [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 65, 69, 73, 77], "train": [2, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16, 18, 19, 20, 22, 23, 24, 26, 27, 28, 30, 31, 32, 34, 35, 36, 38, 39, 40, 42, 43, 44, 46, 47, 48, 50, 51, 52, 54, 55, 56, 58, 59, 60, 62, 63, 64, 66, 67, 68, 70, 71, 72, 74, 75, 76, 78, 79, 80]} \ No newline at end of file diff --git a/coco_splits/split_1.json b/coco_splits/split_1.json new file mode 100644 index 0000000..cf76929 --- /dev/null +++ b/coco_splits/split_1.json @@ -0,0 +1 @@ +{"val": [2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78], "train": [1, 3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17, 19, 20, 21, 23, 24, 25, 27, 28, 29, 31, 32, 33, 35, 36, 37, 39, 40, 41, 43, 44, 45, 47, 48, 49, 51, 52, 53, 55, 56, 57, 59, 60, 61, 63, 64, 65, 67, 68, 69, 71, 72, 73, 75, 76, 77, 79, 80]} \ No newline at end of file diff --git a/coco_splits/split_2.json b/coco_splits/split_2.json new file mode 100644 index 0000000..454aa1b --- /dev/null +++ b/coco_splits/split_2.json @@ -0,0 +1 @@ +{"val": [3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79], "train": [1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30, 32, 33, 34, 36, 37, 38, 40, 41, 42, 44, 45, 46, 48, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 62, 64, 65, 66, 68, 69, 70, 72, 73, 74, 76, 77, 78, 80]} \ No newline at end of file diff --git a/coco_splits/split_3.json b/coco_splits/split_3.json new file mode 100644 index 0000000..0dfeadd --- /dev/null +++ b/coco_splits/split_3.json @@ -0,0 +1 @@ +{"val": [4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80], "train": [1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15, 17, 18, 19, 21, 22, 23, 25, 26, 27, 29, 30, 31, 33, 34, 35, 37, 38, 39, 41, 42, 43, 45, 46, 47, 49, 50, 51, 53, 54, 55, 57, 58, 59, 61, 62, 63, 65, 66, 67, 69, 70, 71, 73, 74, 75, 77, 78, 79]} \ No newline at end of file diff --git a/config/H_48_D_4_proto.json b/config/H_48_D_4_proto.json new file mode 100644 index 0000000..aab99ac --- /dev/null +++ b/config/H_48_D_4_proto.json @@ -0,0 +1,147 @@ +{ + "dataset": "coco", + "method": "fcn_segmentor", + "data": { + "image_tool": "cv2", + "input_mode": "BGR", + "num_classes": 19, + "label_list": [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33], + "data_dir": "/teamspace/studios/this_studio/lessdata", + 
"workers": 1 + }, + "train": { + "batch_size": 1, + "data_transformer": { + "size_mode": "fix_size", + "input_size": [1024, 512], + "align_method": "only_pad", + "pad_mode": "random" + } + }, + "val": { + "batch_size": 1, + "mode": "ss_test", + "data_transformer": { + "size_mode": "fix_size", + "input_size": [2048, 1024], + "align_method": "only_pad" + } + }, + "test": { + "batch_size": 1, + "mode": "ss_test", + "out_dir": "/msravcshare/dataset/seg_result/cityscapes", + "data_transformer": { + "size_mode": "fix_size", + "input_size": [2048, 1024], + "align_method": "only_pad" + } + }, + "train_trans": { + "trans_seq": ["random_resize", "random_crop", "random_hflip", "random_brightness"], + "random_brightness": { + "ratio": 1.0, + "shift_value": 10 + }, + "random_hflip": { + "ratio": 0.5, + "swap_pair": [] + }, + "random_resize": { + "ratio": 1.0, + "method": "random", + "scale_range": [0.5, 2.0], + "aspect_range": [0.9, 1.1] + }, + "random_crop":{ + "ratio": 1.0, + "crop_size": [1024, 512], + "method": "random", + "allow_outside_center": false + } + }, + "val_trans": { + "trans_seq": [] + }, + "normalize": { + "div_value": 255.0, + "mean_value": [0.485, 0.456, 0.406], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225] + }, + "checkpoints": { + "checkpoints_name": "fs_baseocnet_cityscapes_seg", + "checkpoints_dir": "./checkpoints/cityscapes", + "save_iters": 1000 + }, + "network":{ + "backbone": "deepbase_resnet101_dilated8", + "multi_grid": [1, 1, 1], + "model_name": "base_ocnet", + "bn_type": "torchsyncbn", + "stride": 8, + "factors": [[8, 8]], + "loss_weights": { + "corr_loss": 0.01, + "aux_loss": 0.4, + "seg_loss": 1.0 + } + }, + "logging": { + "logfile_level": "info", + "stdout_level": "info", + "log_file": "./log/cityscapes/fs_baseocnet_cityscapes_seg.log", + "log_format": "%(asctime)s %(levelname)-7s %(message)s", + "rewrite": true + }, + "lr": { + "base_lr": 0.01, + "metric": "iters", + "lr_policy": "lambda_poly", + "step": { + "gamma": 0.5, + "step_size": 100 + } + }, + "solver": { + "display_iter": 10, + "test_interval": 2000, + "max_iters": 40000 + }, + "optim": { + "optim_method": "sgd", + "adam": { + "betas": [0.9, 0.999], + "eps": 1e-08, + "weight_decay": 0.0001 + }, + "sgd": { + "weight_decay": 0.0005, + "momentum": 0.9, + "nesterov": false + } + }, + "loss": { + "loss_type": "pixel_prototype_ce_loss", + "params": { + "ce_weight": [0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, + 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, + 1.0865, 1.0955, 1.0865, 1.1529, 1.0507], + "ce_reduction": "mean", + "ce_ignore_index": -1, + "ohem_minkeep": 100000, + "ohem_thresh": 0.9 + } + }, + "protoseg": { + "gamma": 0.999, + "loss_ppc_weight": 0.01, + "loss_ppd_weight": 0.001, + "num_prototype": 10, + "pretrain_prototype": false, + "use_rmi": false, + "use_prototype": true, + "update_prototype": true, + "warmup_iters": 0 + } +} diff --git a/config/coco.yaml b/config/coco.yaml index 717b94c..5d5b857 100644 --- a/config/coco.yaml +++ b/config/coco.yaml @@ -5,7 +5,7 @@ DATA: val_list: lists/coco/val.txt split: 0 use_split_coco: True - workers: 3 + workers: 1 image_size: 417 mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] @@ -24,10 +24,10 @@ EVALUATION: ckpt_path: model_ckpt/ load_model_id: 1 ckpt_used: model - test_num: 10000 - shot: 5 - batch_size_val: 50 - n_runs: 5 + test_num: 5 + shot: 2 + batch_size_val: 1 + n_runs: 2 support_only_one_novel: True use_training_images_for_supports: False generate_new_support_set_for_each_task: False @@ -41,4 +41,4 @@ 
CLASSIFIER: cls_lr: 0.00125 pi_estimation_strategy: self pi_update_at: [10] - fine_tune_base_classifier: True + fine_tune_base_classifier: True \ No newline at end of file diff --git a/config/coco_resnet_base.yaml b/config/coco_resnet_base.yaml new file mode 100644 index 0000000..c2a155f --- /dev/null +++ b/config/coco_resnet_base.yaml @@ -0,0 +1,66 @@ +Data: + data_root: ../data/coco + train_list: ./lists/coco/train.txt + val_list: ./lists/coco/val.txt + classes: 61 + + +Train: + # Aug + train_h: 417 + train_w: 417 + val_size: 417 + scale_min: 0.5 # minimum random scale + scale_max: 2.0 # maximum random scale + rotate_min: -10 # minimum random rotate + rotate_max: 10 # maximum random rotate + ignore_label: 255 + padding_label: 255 + # Dataset & Mode + split: 0 + data_set: 'coco' + use_split_coco: True # True means FWB setting + # Optimizer + batch_size: 3 # batch size for training (bs12 for 1GPU) + base_lr: 2.5e-4 + epochs: 20 + start_epoch: 0 + stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs + index_split: -1 # index for determining the params group with 10x learning rate + power: 0.9 # 0 means no decay + momentum: 0.9 + weight_decay: 0.0001 + warmup: False + # Viz & Save & Resume + print_freq: 10 + save_freq: 5 + resume: # path to latest checkpoint (default: none, such as epoch_10.pth) + # Validate + evaluate: True + fix_random_seed_val: True + batch_size_val: 4 + resized_val: True + ori_resize: False # use original label for evaluation + # Else + workers: 8 # 8 data loader workers + manual_seed: 321 + seed_deterministic: False + zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] + +Method: + layers: 50 + vgg: False + + + +## deprecated multi-processing training +# Distributed: +# dist_url: tcp://127.0.0.1:6789 +# dist_backend: 'nccl' +# multiprocessing_distributed: False +# world_size: 1 +# rank: 0 +# use_apex: False +# opt_level: 'O0' +# keep_batchnorm_fp32: +# loss_scale:ls \ No newline at end of file diff --git a/data/coco/create_masks.py b/data/coco/create_masks.py index 3964977..d4ed61d 100644 --- a/data/coco/create_masks.py +++ b/data/coco/create_masks.py @@ -5,7 +5,8 @@ from pycocotools.coco import COCO for dataset in ['train2014', 'val2014']: - annFile = os.path.join('annotations', f'instances_{dataset}.json') + annFile = os.path.join('coco-2014/raw', f'instances_{dataset}.json') + print(f'PATH...... 
{annFile}') img_dir = dataset save_dir = 'train' if 'train' in dataset else 'val' diff --git a/downloadCoco.py b/downloadCoco.py new file mode 100644 index 0000000..809b668 --- /dev/null +++ b/downloadCoco.py @@ -0,0 +1,9 @@ +import fiftyone as fo +import fiftyone.zoo as foz + +# To download the COCO dataset for only the "person" and "car" classes +dataset = foz.load_zoo_dataset( + "coco-2014", + splits=["train", "validation", "test"], + label_types=["detections", "segmentations"] +) \ No newline at end of file diff --git a/model/ASPP.py b/model/ASPP.py new file mode 100644 index 0000000..3f7b250 --- /dev/null +++ b/model/ASPP.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data + +class ASPP(nn.Module): + def __init__(self, out_channels=256): + super(ASPP, self).__init__() + self.layer6_0 = nn.Sequential( + nn.Conv2d(out_channels , out_channels, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(), + ) + self.layer6_1 = nn.Sequential( + nn.Conv2d(out_channels , out_channels, kernel_size=1, stride=1, padding=0, bias=True), + nn.ReLU(), + ) + self.layer6_2 = nn.Sequential( + nn.Conv2d(out_channels , out_channels , kernel_size=3, stride=1, padding=6,dilation=6, bias=True), + nn.ReLU(), + ) + self.layer6_3 = nn.Sequential( + nn.Conv2d(out_channels , out_channels, kernel_size=3, stride=1, padding=12, dilation=12, bias=True), + nn.ReLU(), + ) + self.layer6_4 = nn.Sequential( + nn.Conv2d(out_channels , out_channels , kernel_size=3, stride=1, padding=18, dilation=18, bias=True), + nn.ReLU(), + ) + + self._init_weight() + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + feature_size = x.shape[-2:] + global_feature = F.avg_pool2d(x, kernel_size=feature_size) + + global_feature = self.layer6_0(global_feature) + + global_feature = global_feature.expand(-1, -1, feature_size[0], feature_size[1]) + out = torch.cat( + [global_feature, self.layer6_1(x), self.layer6_2(x), self.layer6_3(x), self.layer6_4(x)], dim=1) + return out \ No newline at end of file diff --git a/model/BAM.py b/model/BAM.py new file mode 100644 index 0000000..c0e3c72 --- /dev/null +++ b/model/BAM.py @@ -0,0 +1,332 @@ +import torch +from torch import nn +from torch._C import device +import torch.nn.functional as F +from torch.nn import BatchNorm2d as BatchNorm + +import numpy as np +import random +import time +import cv2 + +import model.resnet as models +import model.vgg as vgg_models +from model.ASPP import ASPP +from model.PPM import PPM +from model.PSPNet import OneModel as PSPNet +from util.util import get_train_val_set + + +def Weighted_GAP(supp_feat, mask): + supp_feat = supp_feat * mask + feat_h, feat_w = supp_feat.shape[-2:][0], supp_feat.shape[-2:][1] + area = F.avg_pool2d(mask, (supp_feat.size()[2], supp_feat.size()[3])) * feat_h * feat_w + 0.0005 + supp_feat = F.avg_pool2d(input=supp_feat, kernel_size=supp_feat.shape[-2:]) * feat_h * feat_w / area + return supp_feat + +def get_gram_matrix(fea): + b, c, h, w = fea.shape + fea = fea.reshape(b, c, h*w) # C*N + fea_T = fea.permute(0, 2, 1) # N*C + fea_norm = fea.norm(2, 2, True) + fea_T_norm = fea_T.norm(2, 1, True) + gram = torch.bmm(fea, fea_T)/(torch.bmm(fea_norm, fea_T_norm) + 1e-7) # C*C + return gram + + +class OneModel(nn.Module): + def __init__(self, args, cls_type=None): + super(OneModel, self).__init__() + + 
self.cls_type = cls_type # 'Base' or 'Novel' + self.layers = args.layers + self.zoom_factor = args.zoom_factor + self.shot = args.shot + self.vgg = args.vgg + self.dataset = args.data_set + self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) + + self.print_freq = args.print_freq/2 + + self.pretrained = True + self.classes = 2 + if self.dataset == 'pascal': + self.base_classes = 15 + elif self.dataset == 'coco': + self.base_classes = 60 + + assert self.layers in [50, 101, 152] + + PSPNet_ = PSPNet(args) + backbone_str = 'vgg' if args.vgg else 'resnet'+str(args.layers) + weight_path = 'initmodel/PSPNet/{}/split{}/{}/best.pth'.format(args.data_set, args.split, backbone_str) + new_param = torch.load(weight_path, map_location=torch.device('cpu'))['state_dict'] + try: + PSPNet_.load_state_dict(new_param) + except RuntimeError: # 1GPU loads mGPU model + for key in list(new_param.keys()): + new_param[key[7:]] = new_param.pop(key) + PSPNet_.load_state_dict(new_param) + + self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = PSPNet_.layer0, PSPNet_.layer1, PSPNet_.layer2, PSPNet_.layer3, PSPNet_.layer4 + + # Base Learner + self.learner_base = nn.Sequential(PSPNet_.ppm, PSPNet_.cls) + + # Meta Learner + reduce_dim = 256 + self.low_fea_id = args.low_fea[-1] + if self.vgg: + fea_dim = 512 + 256 + else: + fea_dim = 1024 + 512 + self.down_query = nn.Sequential( + nn.Conv2d(fea_dim, reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.5)) + self.down_supp = nn.Sequential( + nn.Conv2d(fea_dim, reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.5)) + mask_add_num = 1 + self.init_merge = nn.Sequential( + nn.Conv2d(reduce_dim*2 + mask_add_num, reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True)) + self.ASPP_meta = ASPP(reduce_dim) + self.res1_meta = nn.Sequential( + nn.Conv2d(reduce_dim*5, reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True)) + self.res2_meta = nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True)) + self.cls_meta = nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.1), + nn.Conv2d(reduce_dim, self.classes, kernel_size=1)) + + # Gram and Meta + self.gram_merge = nn.Conv2d(2, 1, kernel_size=1, bias=False) + self.gram_merge.weight = nn.Parameter(torch.tensor([[1.0],[0.0]]).reshape_as(self.gram_merge.weight)) + + # Learner Ensemble + self.cls_merge = nn.Conv2d(2, 1, kernel_size=1, bias=False) + self.cls_merge.weight = nn.Parameter(torch.tensor([[1.0],[0.0]]).reshape_as(self.cls_merge.weight)) + + # K-Shot Reweighting + if args.shot > 1: + self.kshot_trans_dim = args.kshot_trans_dim + if self.kshot_trans_dim == 0: + self.kshot_rw = nn.Conv2d(self.shot, self.shot, kernel_size=1, bias=False) + self.kshot_rw.weight = nn.Parameter(torch.ones_like(self.kshot_rw.weight) / args.shot) + else: + self.kshot_rw = nn.Sequential( + nn.Conv2d(self.shot, self.kshot_trans_dim, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(self.kshot_trans_dim, self.shot, kernel_size=1)) + + self.sigmoid = nn.Sigmoid() + + def get_optim(self, model, args, LR): + if args.shot > 1: + optimizer = torch.optim.SGD( + [ + {'params': model.down_query.parameters()}, + {'params': model.down_supp.parameters()}, + {'params': 
model.init_merge.parameters()}, + {'params': model.ASPP_meta.parameters()}, + {'params': model.res1_meta.parameters()}, + {'params': model.res2_meta.parameters()}, + {'params': model.cls_meta.parameters()}, + {'params': model.gram_merge.parameters()}, + {'params': model.cls_merge.parameters()}, + {'params': model.kshot_rw.parameters()}, + ], lr=LR, momentum=args.momentum, weight_decay=args.weight_decay) + else: + optimizer = torch.optim.SGD( + [ + {'params': model.down_query.parameters()}, + {'params': model.down_supp.parameters()}, + {'params': model.init_merge.parameters()}, + {'params': model.ASPP_meta.parameters()}, + {'params': model.res1_meta.parameters()}, + {'params': model.res2_meta.parameters()}, + {'params': model.cls_meta.parameters()}, + {'params': model.gram_merge.parameters()}, + {'params': model.cls_merge.parameters()}, + ], lr=LR, momentum=args.momentum, weight_decay=args.weight_decay) + return optimizer + + def freeze_modules(self, model): + for param in model.layer0.parameters(): + param.requires_grad = False + for param in model.layer1.parameters(): + param.requires_grad = False + for param in model.layer2.parameters(): + param.requires_grad = False + for param in model.layer3.parameters(): + param.requires_grad = False + for param in model.layer4.parameters(): + param.requires_grad = False + for param in model.learner_base.parameters(): + param.requires_grad = False + + + # que_img, sup_img, sup_mask, que_mask(meta), que_mask(base), cat_idx(meta) + def forward(self, x, s_x, s_y, y_m, y_b, cat_idx=None): + x_size = x.size() + bs = x_size[0] + h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1) + w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1) + + # Query Feature + with torch.no_grad(): + query_feat_0 = self.layer0(x) + query_feat_1 = self.layer1(query_feat_0) + query_feat_2 = self.layer2(query_feat_1) + query_feat_3 = self.layer3(query_feat_2) + query_feat_4 = self.layer4(query_feat_3) + if self.vgg: + query_feat_2 = F.interpolate(query_feat_2, size=(query_feat_3.size(2),query_feat_3.size(3)), mode='bilinear', align_corners=True) + + query_feat = torch.cat([query_feat_3, query_feat_2], 1) + query_feat = self.down_query(query_feat) + + # Support Feature + supp_pro_list = [] + final_supp_list = [] + mask_list = [] + supp_feat_list = [] + for i in range(self.shot): + mask = (s_y[:,i,:,:] == 1).float().unsqueeze(1) + mask_list.append(mask) + with torch.no_grad(): + supp_feat_0 = self.layer0(s_x[:,i,:,:,:]) + supp_feat_1 = self.layer1(supp_feat_0) + supp_feat_2 = self.layer2(supp_feat_1) + supp_feat_3 = self.layer3(supp_feat_2) + mask = F.interpolate(mask, size=(supp_feat_3.size(2), supp_feat_3.size(3)), mode='bilinear', align_corners=True) + supp_feat_4 = self.layer4(supp_feat_3*mask) + final_supp_list.append(supp_feat_4) + if self.vgg: + supp_feat_2 = F.interpolate(supp_feat_2, size=(supp_feat_3.size(2),supp_feat_3.size(3)), mode='bilinear', align_corners=True) + + supp_feat = torch.cat([supp_feat_3, supp_feat_2], 1) + supp_feat = self.down_supp(supp_feat) + supp_pro = Weighted_GAP(supp_feat, mask) + supp_pro_list.append(supp_pro) + supp_feat_list.append(eval('supp_feat_' + self.low_fea_id)) + + # K-Shot Reweighting + que_gram = get_gram_matrix(eval('query_feat_' + self.low_fea_id)) # [bs, C, C] in (0,1) + norm_max = torch.ones_like(que_gram).norm(dim=(1,2)) + est_val_list = [] + for supp_item in supp_feat_list: + supp_gram = get_gram_matrix(supp_item) + gram_diff = que_gram - supp_gram + est_val_list.append((gram_diff.norm(dim=(1,2))/norm_max).reshape(bs,1,1,1)) # 
norm2 + est_val_total = torch.cat(est_val_list, 1) # [bs, shot, 1, 1] + if self.shot > 1: + val1, idx1 = est_val_total.sort(1) + val2, idx2 = idx1.sort(1) + weight = self.kshot_rw(val1) + weight = weight.gather(1, idx2) + weight_soft = torch.softmax(weight, 1) + else: + weight_soft = torch.ones_like(est_val_total) + est_val = (weight_soft * est_val_total).sum(1,True) # [bs, 1, 1, 1] + + # Prior Similarity Mask + corr_query_mask_list = [] + cosine_eps = 1e-7 + for i, tmp_supp_feat in enumerate(final_supp_list): + resize_size = tmp_supp_feat.size(2) + tmp_mask = F.interpolate(mask_list[i], size=(resize_size, resize_size), mode='bilinear', align_corners=True) + + tmp_supp_feat_4 = tmp_supp_feat * tmp_mask + q = query_feat_4 + s = tmp_supp_feat_4 + bsize, ch_sz, sp_sz, _ = q.size()[:] + + tmp_query = q + tmp_query = tmp_query.reshape(bsize, ch_sz, -1) + tmp_query_norm = torch.norm(tmp_query, 2, 1, True) + + tmp_supp = s + tmp_supp = tmp_supp.reshape(bsize, ch_sz, -1) + tmp_supp = tmp_supp.permute(0, 2, 1) + tmp_supp_norm = torch.norm(tmp_supp, 2, 2, True) + + similarity = torch.bmm(tmp_supp, tmp_query)/(torch.bmm(tmp_supp_norm, tmp_query_norm) + cosine_eps) + similarity = similarity.max(1)[0].reshape(bsize, sp_sz*sp_sz) + similarity = (similarity - similarity.min(1)[0].unsqueeze(1))/(similarity.max(1)[0].unsqueeze(1) - similarity.min(1)[0].unsqueeze(1) + cosine_eps) + corr_query = similarity.reshape(bsize, 1, sp_sz, sp_sz) + corr_query = F.interpolate(corr_query, size=(query_feat_3.size()[2], query_feat_3.size()[3]), mode='bilinear', align_corners=True) + corr_query_mask_list.append(corr_query) + corr_query_mask = torch.cat(corr_query_mask_list, 1) + corr_query_mask = (weight_soft * corr_query_mask).sum(1,True) + + # Support Prototype + supp_pro = torch.cat(supp_pro_list, 2) # [bs, 256, shot, 1] + supp_pro = (weight_soft.permute(0,2,1,3) * supp_pro).sum(2,True) + + # Tile & Cat + concat_feat = supp_pro.expand_as(query_feat) + merge_feat = torch.cat([query_feat, concat_feat, corr_query_mask], 1) # 256+256+1 + merge_feat = self.init_merge(merge_feat) + + # Base and Meta + base_out = self.learner_base(query_feat_4) + + query_meta = self.ASPP_meta(merge_feat) + query_meta = self.res1_meta(query_meta) # 1080->256 + query_meta = self.res2_meta(query_meta) + query_meta + meta_out = self.cls_meta(query_meta) + + meta_out_soft = meta_out.softmax(1) + base_out_soft = base_out.softmax(1) + + # Classifier Ensemble + meta_map_bg = meta_out_soft[:,0:1,:,:] # [bs, 1, 60, 60] + meta_map_fg = meta_out_soft[:,1:,:,:] # [bs, 1, 60, 60] + if self.training and self.cls_type == 'Base': + c_id_array = torch.arange(self.base_classes+1, device='cuda') + base_map_list = [] + for b_id in range(bs): + c_id = cat_idx[0][b_id] + 1 + c_mask = (c_id_array!=0)&(c_id_array!=c_id) + base_map_list.append(base_out_soft[b_id,c_mask,:,:].unsqueeze(0).sum(1,True)) + base_map = torch.cat(base_map_list,0) + # + # gather_id = (cat_idx[0]+1).reshape(bs,1,1,1).expand_as(base_out_soft[:,0:1,:,:]).cuda() + # fg_map = base_out_soft.gather(1,gather_id) + # base_map = base_out_soft[:,1:,:,:].sum(1,True) - fg_map + else: + base_map = base_out_soft[:,1:,:,:].sum(1,True) + + est_map = est_val.expand_as(meta_map_fg) + + meta_map_bg = self.gram_merge(torch.cat([meta_map_bg,est_map], dim=1)) + meta_map_fg = self.gram_merge(torch.cat([meta_map_fg,est_map], dim=1)) + + merge_map = torch.cat([meta_map_bg, base_map], 1) + merge_bg = self.cls_merge(merge_map) # [bs, 1, 60, 60] + + final_out = torch.cat([merge_bg, meta_map_fg], dim=1) + + # Output Part + 
if self.zoom_factor != 1: + meta_out = F.interpolate(meta_out, size=(h, w), mode='bilinear', align_corners=True) + base_out = F.interpolate(base_out, size=(h, w), mode='bilinear', align_corners=True) + final_out = F.interpolate(final_out, size=(h, w), mode='bilinear', align_corners=True) + + # Loss + if self.training: + main_loss = self.criterion(final_out, y_m.long()) + aux_loss1 = self.criterion(meta_out, y_m.long()) + aux_loss2 = self.criterion(base_out, y_b.long()) + return final_out.max(1)[1], main_loss, aux_loss1, aux_loss2 + else: + return final_out, meta_out, base_out \ No newline at end of file diff --git a/model/PPM.py b/model/PPM.py new file mode 100644 index 0000000..48d71c2 --- /dev/null +++ b/model/PPM.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class PPM(nn.Module): + def __init__(self, in_dim, reduction_dim, bins): + super(PPM, self).__init__() + self.features = [] + for bin in bins: + self.features.append(nn.Sequential( + nn.AdaptiveAvgPool2d(bin), + nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(reduction_dim), + nn.ReLU(inplace=True) + )) + self.features = nn.ModuleList(self.features) + + def forward(self, x): + x_size = x.size() + out = [x] + for f in self.features: + out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True)) + return torch.cat(out, 1) \ No newline at end of file diff --git a/model/PSPNet.py b/model/PSPNet.py new file mode 100644 index 0000000..cde0691 --- /dev/null +++ b/model/PSPNet.py @@ -0,0 +1,126 @@ +import torch +from torch import nn +from torch._C import device +import torch.nn.functional as F +from torch.nn import BatchNorm2d as BatchNorm + +import numpy as np +import random +import time +import cv2 + +import model.resnet as models +import model.vgg as vgg_models +from model.PPM import PPM + + +def get_vgg16_layer(model): + layer0_idx = range(0,7) + layer1_idx = range(7,14) + layer2_idx = range(14,24) + layer3_idx = range(24,34) + layer4_idx = range(34,43) + layers_0 = [] + layers_1 = [] + layers_2 = [] + layers_3 = [] + layers_4 = [] + for idx in layer0_idx: + layers_0 += [model.features[idx]] + for idx in layer1_idx: + layers_1 += [model.features[idx]] + for idx in layer2_idx: + layers_2 += [model.features[idx]] + for idx in layer3_idx: + layers_3 += [model.features[idx]] + for idx in layer4_idx: + layers_4 += [model.features[idx]] + layer0 = nn.Sequential(*layers_0) + layer1 = nn.Sequential(*layers_1) + layer2 = nn.Sequential(*layers_2) + layer3 = nn.Sequential(*layers_3) + layer4 = nn.Sequential(*layers_4) + return layer0,layer1,layer2,layer3,layer4 + +class OneModel(nn.Module): + def __init__(self, args): + super(OneModel, self).__init__() + + self.layers = args.layers + self.zoom_factor = args.zoom_factor + self.vgg = args.vgg + self.dataset = args.data_set + self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) + + self.pretrained = True + self.classes = 16 if self.dataset=='pascal' else 61 + + assert self.layers in [50, 101, 152] + + if self.vgg: + print('INFO: Using VGG_16 bn') + vgg_models.BatchNorm = BatchNorm + vgg16 = vgg_models.vgg16_bn(pretrained=self.pretrained) + print(vgg16) + self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = get_vgg16_layer(vgg16) + else: + print('INFO: Using ResNet {}'.format(self.layers)) + if self.layers == 50: + resnet = models.resnet50(pretrained=self.pretrained) + elif self.layers == 101: + resnet = models.resnet101(pretrained=self.pretrained) + else: + resnet = 
models.resnet152(pretrained=self.pretrained) + self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1, resnet.conv2, resnet.bn2, resnet.relu2, resnet.conv3, resnet.bn3, resnet.relu3, resnet.maxpool) + self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 + + for n, m in self.layer3.named_modules(): + if 'conv2' in n: + m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) + elif 'downsample.0' in n: + m.stride = (1, 1) + for n, m in self.layer4.named_modules(): + if 'conv2' in n: + m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) + elif 'downsample.0' in n: + m.stride = (1, 1) + + # Base Learner + self.encoder = nn.Sequential(self.layer0, self.layer1, self.layer2, self.layer3, self.layer4) + fea_dim = 512 if self.vgg else 2048 + bins=(1, 2, 3, 6) + self.ppm = PPM(fea_dim, int(fea_dim/len(bins)), bins) + self.cls = nn.Sequential( + nn.Conv2d(fea_dim*2, 512, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.1), + nn.Conv2d(512, self.classes, kernel_size=1)) + + def get_optim(self, model, args, LR): + optimizer = torch.optim.SGD( + [ + {'params': model.encoder.parameters()}, + {'params': model.ppm.parameters()}, + {'params': model.cls.parameters()}, + ], lr=LR, momentum=args.momentum, weight_decay=args.weight_decay) + return optimizer + + + def forward(self, x, y): + x_size = x.size() + h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1) + w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1) # 473 + + x = self.encoder(x) + x = self.ppm(x) + x = self.cls(x) + + if self.zoom_factor != 1: + x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) + + if self.training: + main_loss = self.criterion(x, y.long()) + return x.max(1)[1], main_loss + else: + return x \ No newline at end of file diff --git a/model/resnet.py b/model/resnet.py new file mode 100644 index 0000000..6aedc36 --- /dev/null +++ b/model/resnet.py @@ -0,0 +1,233 @@ +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo + +BatchNorm = nn.BatchNorm2d + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def 
__init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, deep_base=True): + super(ResNet, self).__init__() + self.deep_base = deep_base + if not self.deep_base: + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = BatchNorm(64) + self.relu = nn.ReLU(inplace=True) + else: + self.inplanes = 128 + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, BatchNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + if self.deep_base: + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=True, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + model_path = './initmodel/resnet50_v2.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + model_path = './initmodel/resnet101_v2.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + model_path = './initmodel/resnet152_v2.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model \ No newline at end of file diff --git a/model/vgg.py b/model/vgg.py new file mode 100644 index 0000000..8f7448d --- /dev/null +++ b/model/vgg.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +BatchNorm = nn.BatchNorm2d + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +model_urls = { + 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', + 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', + 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', + 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', + 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', + 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', + 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', + 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', +} + + +class VGG(nn.Module): + + def __init__(self, features, num_classes=1000, init_weights=True): + super(VGG, self).__init__() + self.features = features + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + if init_weights: + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, BatchNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def make_layers(cfg, 
batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, BatchNorm(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +cfg = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def vgg11(pretrained=False, **kwargs): + """VGG 11-layer model (configuration "A") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['A']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) + return model + + +def vgg11_bn(pretrained=False, **kwargs): + """VGG 11-layer model (configuration "A") with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) + return model + + +def vgg13(pretrained=False, **kwargs): + """VGG 13-layer model (configuration "B") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['B']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) + return model + + +def vgg13_bn(pretrained=False, **kwargs): + """VGG 13-layer model (configuration "B") with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) + return model + + +def vgg16(pretrained=False, **kwargs): + """VGG 16-layer model (configuration "D") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['D']), **kwargs) + if pretrained: + #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) + model_path = './initmodel/vgg16.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def vgg16_bn(pretrained=False, **kwargs): + """VGG 16-layer model (configuration "D") with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) + if pretrained: + #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) + model_path = './initmodel/vgg16_bn.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def vgg19(pretrained=False, **kwargs): + """VGG 19-layer model (configuration "E") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + 
model = VGG(make_layers(cfg['E']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) + return model + + +def vgg19_bn(pretrained=False, **kwargs): + """VGG 19-layer model (configuration 'E') with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) + return model + +if __name__ =='__main__': + import os + # os.environ["CUDA_VISIBLE_DEVICES"] = '7' + input = torch.rand(4, 3, 473, 473).cuda() + target = torch.rand(4, 473, 473).cuda()*1.0 + model = vgg16_bn(pretrained=False).cuda() + model.train() + layer0_idx = range(0,6) + layer1_idx = range(6,13) + layer2_idx = range(13,23) + layer3_idx = range(23,33) + layer4_idx = range(34,43) + #layer4_idx = range(34,43) + print(model.features) + layers_0 = [] + layers_1 = [] + layers_2 = [] + layers_3 = [] + layers_4 = [] + for idx in layer0_idx: + layers_0 += [model.features[idx]] + for idx in layer1_idx: + layers_1 += [model.features[idx]] + for idx in layer2_idx: + layers_2 += [model.features[idx]] + for idx in layer3_idx: + layers_3 += [model.features[idx]] + for idx in layer4_idx: + layers_4 += [model.features[idx]] + + layer0 = nn.Sequential(*layers_0) + layer1 = nn.Sequential(*layers_1) + layer2 = nn.Sequential(*layers_2) + layer3 = nn.Sequential(*layers_3) + layer4 = nn.Sequential(*layers_4) + + output = layer0(input) + print(layer0) + print('layer 0: {}'.format(output.size())) + output = layer1(output) + print(layer1) + print('layer 1: {}'.format(output.size())) + output = layer2(output) + print(layer2) + print('layer 2: {}'.format(output.size())) + output = layer3(output) + print(layer3) + print('layer 3: {}'.format(output.size())) + output = layer4(output) + print(layer4) + print('layer 4: {}'.format(output.size())) + \ No newline at end of file diff --git a/pascal_splits/split_-1.json b/pascal_splits/split_-1.json new file mode 100644 index 0000000..0255f72 --- /dev/null +++ b/pascal_splits/split_-1.json @@ -0,0 +1 @@ +{"val": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]} \ No newline at end of file diff --git a/pascal_splits/split_0.json b/pascal_splits/split_0.json new file mode 100644 index 0000000..a85a1a2 --- /dev/null +++ b/pascal_splits/split_0.json @@ -0,0 +1 @@ +{"val": [1, 2, 3, 4, 5], "train": [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]} \ No newline at end of file diff --git a/pascal_splits/split_1.json b/pascal_splits/split_1.json new file mode 100644 index 0000000..aba29d1 --- /dev/null +++ b/pascal_splits/split_1.json @@ -0,0 +1 @@ +{"val": [6, 7, 8, 9, 10], "train": [1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]} \ No newline at end of file diff --git a/pascal_splits/split_10.json b/pascal_splits/split_10.json new file mode 100644 index 0000000..94f68bd --- /dev/null +++ b/pascal_splits/split_10.json @@ -0,0 +1 @@ +{"val": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "train": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]} \ No newline at end of file diff --git a/pascal_splits/split_11.json b/pascal_splits/split_11.json new file mode 100644 index 0000000..dd45a62 --- /dev/null +++ b/pascal_splits/split_11.json @@ -0,0 +1 @@ +{"val": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], "train": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]} \ No newline at end of file diff --git a/pascal_splits/split_2.json 
b/pascal_splits/split_2.json new file mode 100644 index 0000000..49414b3 --- /dev/null +++ b/pascal_splits/split_2.json @@ -0,0 +1 @@ +{"val": [11, 12, 13, 14, 15], "train": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17, 18, 19, 20]} \ No newline at end of file diff --git a/pascal_splits/split_3.json b/pascal_splits/split_3.json new file mode 100644 index 0000000..f65c465 --- /dev/null +++ b/pascal_splits/split_3.json @@ -0,0 +1 @@ +{"val": [16, 17, 18, 19, 20], "train": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]} \ No newline at end of file diff --git a/proto/PixelPrototypeCELoss.py b/proto/PixelPrototypeCELoss.py new file mode 100644 index 0000000..4f6e09f --- /dev/null +++ b/proto/PixelPrototypeCELoss.py @@ -0,0 +1,109 @@ +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import yaml +import json +from .loss_helper import FSAuxRMILoss, FSCELoss +# from lib.utils.tools.logger import Logger as Log + + +class PPC(nn.Module, ABC): + def __init__(self, configer): + super(PPC, self).__init__() + + + # with open("config/H_48_D_4_proto.json", "r") as f: + # self.configer = json.load(f) + + # dataset_name = configer["dataset"] + # num_classes = configer["data"]["num_classes"] + # ce_weight = configer["loss"]["params"]["ce_weight"] + self.ignore_label = -1 + # if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'): + # self.ignore_label = self.configer.get('loss', 'params')['ce_ignore_index'] + + def forward(self, contrast_logits, contrast_target): + loss_ppc = F.cross_entropy(contrast_logits, contrast_target.long(), ignore_index=self.ignore_label) + return loss_ppc + + +class PPD(nn.Module, ABC): + def __init__(self, configer): + super(PPD, self).__init__() + + # with open("config/H_48_D_4_proto.json", "r") as f: + # self.configer = json.load(f) + + # dataset_name = configer["dataset"] + # num_classes = configer["data"]["num_classes"] + # ce_weight = configer["loss"]["params"]["ce_weight"] + + self.ignore_label = -1 + # if "loss" in configer and "params" in configer["loss"] and "ce_weight" in configer["loss"]["params"]: + # self.ignore_label = torch.FloatTensor(configer["loss"]["params"]["ce_weight"]).cuda() + # if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'): + # self.ignore_label = self.configer.get('loss', 'params')['ce_ignore_index'] + + def forward(self, contrast_logits, contrast_target): + contrast_logits = contrast_logits[contrast_target != self.ignore_label, :] + contrast_target = contrast_target[contrast_target != self.ignore_label] + + logits = torch.gather(contrast_logits, 1, contrast_target[:, None].long()) + loss_ppd = (1 - logits).pow(2).mean() + return loss_ppd + + +class PixelPrototypeCELoss(nn.Module, ABC): + def __init__(self, configer=None): + super(PixelPrototypeCELoss, self).__init__() + loss_ppc_weight = 0.01 + loss_ppd_weight = 0.001 + use_rmi = False + # self.configer = configer + + ignore_index = -1 + # if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'): + # ignore_index = self.configer.get('loss', 'params')['ce_ignore_index'] + # Log.info('ignore_index: {}'.format(ignore_index)) + + self.loss_ppc_weight = loss_ppc_weight + self.loss_ppd_weight = loss_ppd_weight + + self.use_rmi = False + + if self.use_rmi: + self.seg_criterion = FSAuxRMILoss(configer=configer) + else: + self.seg_criterion = FSCELoss(configer=configer) + + self.ppc_criterion = PPC(configer=configer) + 
self.ppd_criterion = PPD(configer=configer) + + def forward(self, preds, target): + h, w = target.size(1), target.size(2) + print("first THE PPD LOSS", preds.size()) + print("first TARGET", target.size()) + if isinstance(preds, dict): + assert "seg" in preds + assert "logits" in preds + assert "target" in preds + + seg = preds['seg'] + contrast_logits = preds['logits'] + contrast_target = preds['target'] + loss_ppc = self.ppc_criterion(contrast_logits, contrast_target) + loss_ppd = self.ppd_criterion(contrast_logits, contrast_target) + + pred = F.interpolate(input=seg, size=(h, w), mode='bilinear', align_corners=True) + loss = self.seg_criterion(pred, target) + return loss + self.loss_ppc_weight * loss_ppc + self.loss_ppd_weight * loss_ppd + + # concatenated_tensor = torch.cat([preds[0],preds[1]], dim=0) + print("HITTING THE PPD LOSS", preds.size()) + aux_out, seg_out = preds + seg = seg_out + pred = F.interpolate(input=seg, size=(h, w), mode='bilinear', align_corners=True) + loss = self.seg_criterion(pred, target) + return loss diff --git a/proto/__init__.py b/proto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/proto/loss_helper.py b/proto/loss_helper.py new file mode 100644 index 0000000..92a7898 --- /dev/null +++ b/proto/loss_helper.py @@ -0,0 +1,104 @@ +# ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# ## Created by: Donny You, RainbowSecret +# ## Microsoft Research +# ## yuyua@microsoft.com +# ## Copyright (c) 2019 +# ## +# ## This source code is licensed under the MIT-style license found in the +# ## LICENSE file in the root directory of this source tree +# ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + +import os +import pdb +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from torch.autograd import Variable +from .rmi_loss import RMILoss +import json + +# Cross-entropy Loss +class FSCELoss(nn.Module): + def __init__(self, configer): + super(FSCELoss, self).__init__() + self.configer = configer + + with open("config/H_48_D_4_proto.json", "r") as f: + self.configer = json.load(f) + # print("CONFIGURE....",self.configure) + # print("CONFIGURE....",configure) + weight = None + # Access and print a specific config value (example) + if "loss" in self.configer and "params" in self.configer["loss"]: + ce_weight = self.configer["loss"]["params"].get("ce_weight") + if ce_weight is not None: + print("CE weight from config:", ce_weight) + else: + print("CE weight not found in config") + # if "loss" in configer and "params" in configer["loss"] and "ce_weight" in configer["loss"]["params"]: + # weight = self.configer["loss"]["params"]['ce_weight'] + # weight = torch.FloatTensor(weight).cuda() + # if self.configer.exists('loss', 'params') and 'ce_weight' in self.configer.get('loss', 'params'): + # weight = self.configer.get('loss', 'params')['ce_weight'] + # weight = torch.FloatTensor(weight).cuda() + + reduction = 'mean' + if "loss" in self.configer and "params" in self.configer["loss"] and "ce_reduction" in self.configer["loss"]["params"]: + reduction = self.configer["loss"]["params"]['ce_reduction'] + # if self.configer.exists('loss', 'params') and 'ce_reduction' in self.configer.get('loss', 'params'): + # reduction = self.configer.get('loss', 'params')['ce_reduction'] + + ignore_index = -1 + if "loss" in self.configer and "params" in self.configer["loss"] and "ce_weight" in self.configer["loss"]["params"]: + ignore_index = 
self.configer["loss"]["params"]['ce_ignore_index'] + # if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'): + # ignore_index = self.configer.get('loss', 'params')['ce_ignore_index'] + + self.ce_loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction) + + def forward(self, inputs, *targets, weights=None, **kwargs): + loss = 0.0 + if isinstance(inputs, tuple) or isinstance(inputs, list): + if weights is None: + weights = [1.0] * len(inputs) + + for i in range(len(inputs)): + if len(targets) > 1: + target = self._scale_target(targets[i], (inputs[i].size(2), inputs[i].size(3))) + loss += weights[i] * self.ce_loss(inputs[i], target) + else: + target = self._scale_target(targets[0], (inputs[i].size(2), inputs[i].size(3))) + loss += weights[i] * self.ce_loss(inputs[i], target) + + else: + target = self._scale_target(targets[0], (inputs.size(2), inputs.size(3))) + loss = self.ce_loss(inputs, target) + + return loss + + @staticmethod + def _scale_target(targets_, scaled_size): + targets = targets_.clone().unsqueeze(1).float() + targets = F.interpolate(targets, size=scaled_size, mode='nearest') + return targets.squeeze(1).long() + + +class FSAuxRMILoss(nn.Module): + def __init__(self, configer): + super(FSAuxRMILoss, self).__init__() + self.configer = configer + self.ce_loss = FSCELoss(self.configer) + self.rmi_loss = RMILoss(self.configer) + + def forward(self, inputs, targets, **kwargs): + aux_out, seg_out = inputs + aux_loss = self.ce_loss(aux_out, targets) + seg_loss = self.rmi_loss(seg_out, targets) + loss = self.configer.get('network', 'loss_weights')['seg_loss'] * seg_loss + loss = loss + self.configer.get('network', 'loss_weights')['aux_loss'] * aux_loss + return loss + + + diff --git a/proto/rmi_loss.py b/proto/rmi_loss.py new file mode 100644 index 0000000..949e50d --- /dev/null +++ b/proto/rmi_loss.py @@ -0,0 +1,399 @@ +# coding=utf-8 + +""" +The implementation of the paper: +Region Mutual Information Loss for Semantic Segmentation. 
+""" + +# python 2.X, 3.X compatibility + +import pdb +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['RMILoss'] + +TORCH_VERSION = torch.__version__[:3] + +_euler_num = 2.718281828 # euler number +_pi = 3.14159265 # pi +_ln_2_pi = 1.837877 # ln(2 * pi) +_CLIP_MIN = 1e-6 # min clip value after softmax or sigmoid operations +_CLIP_MAX = 1.0 # max clip value after softmax or sigmoid operations +_POS_ALPHA = 1e-3 # add this factor to ensure the AA^T is positive definite +_IS_SUM = 1 # sum the loss per channel + + +def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True): + """get map pairs + Args: + labels_4D : labels, shape [N, C, H, W] + probs_4D : probabilities, shape [N, C, H, W] + radius : the square radius + Return: + tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)] + """ + # pad to ensure the following slice operation is valid + # pad_beg = int(radius // 2) + # pad_end = radius - pad_beg + + # the original height and width + label_shape = labels_4D.size() + h, w = label_shape[2], label_shape[3] + new_h, new_w = h - (radius - 1), w - (radius - 1) + # https://pytorch.org/docs/stable/nn.html?highlight=f%20pad#torch.nn.functional.pad + # padding = (pad_beg, pad_end, pad_beg, pad_end) + # labels_4D, probs_4D = F.pad(labels_4D, padding), F.pad(probs_4D, padding) + + # get the neighbors + la_ns = [] + pr_ns = [] + # for x in range(0, radius, 1): + for y in range(0, radius, 1): + for x in range(0, radius, 1): + la_now = labels_4D[:, :, y:y + new_h, x:x + new_w] + pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w] + la_ns.append(la_now) + pr_ns.append(pr_now) + + if is_combine: + # for calculating RMI + pair_ns = la_ns + pr_ns + p_vectors = torch.stack(pair_ns, dim=2) + return p_vectors + else: + # for other purpose + la_vectors = torch.stack(la_ns, dim=2) + pr_vectors = torch.stack(pr_ns, dim=2) + return la_vectors, pr_vectors + + +def map_get_pairs_region(labels_4D, probs_4D, radius=3, is_combine=0, num_classeses=21): + """get map pairs + Args: + labels_4D : labels, shape [N, C, H, W]. + probs_4D : probabilities, shape [N, C, H, W]. + radius : The side length of the square region. + Return: + A tensor with shape [N, C, radiu * radius, H // radius, W // raidius] + """ + kernel = torch.zeros([num_classeses, 1, radius, radius]).type_as(probs_4D) + padding = radius // 2 + # get the neighbours + la_ns = [] + pr_ns = [] + for y in range(0, radius, 1): + for x in range(0, radius, 1): + kernel_now = kernel.clone() + kernel_now[:, :, y, x] = 1.0 + la_now = F.conv2d(labels_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) + pr_now = F.conv2d(probs_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) + la_ns.append(la_now) + pr_ns.append(pr_now) + + if is_combine: + # for calculating RMI + pair_ns = la_ns + pr_ns + p_vectors = torch.stack(pair_ns, dim=2) + return p_vectors + else: + # for other purpose + la_vectors = torch.stack(la_ns, dim=2) + pr_vectors = torch.stack(pr_ns, dim=2) + return la_vectors, pr_vectors + return + + +def log_det_by_cholesky(matrix): + """ + Args: + matrix: matrix must be a positive define matrix. + shape [N, C, D, D]. + Ref: + https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/linalg/linalg_impl.py + """ + # This uses the property that the log det(A) = 2 * sum(log(real(diag(C)))) + # where C is the cholesky decomposition of A. 
+ chol = torch.cholesky(matrix) + # return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-6), dim=-1) + return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-8), dim=-1) + + +def batch_cholesky_inverse(matrix): + """ + Args: matrix, 4-D tensor, [N, C, M, M]. + matrix must be a symmetric positive define matrix. + """ + chol_low = torch.cholesky(matrix, upper=False) + chol_low_inv = batch_low_tri_inv(chol_low) + return torch.matmul(chol_low_inv.transpose(-2, -1), chol_low_inv) + + +def batch_low_tri_inv(L): + """ + Batched inverse of lower triangular matrices + Args: + L : a lower triangular matrix + Ref: + https://www.pugetsystems.com/labs/hpc/PyTorch-for-Scientific-Computing + """ + n = L.shape[-1] + invL = torch.zeros_like(L) + for j in range(0, n): + invL[..., j, j] = 1.0 / L[..., j, j] + for i in range(j + 1, n): + S = 0.0 + for k in range(0, i + 1): + S = S - L[..., i, k] * invL[..., k, j].clone() + invL[..., i, j] = S / L[..., i, i] + return invL + + +def log_det_by_cholesky_test(): + """ + test for function log_det_by_cholesky() + """ + a = torch.randn(1, 4, 4) + a = torch.matmul(a, a.transpose(2, 1)) + print(a) + res_1 = torch.logdet(torch.squeeze(a)) + res_2 = log_det_by_cholesky(a) + print(res_1, res_2) + + +def batch_inv_test(): + """ + test for function batch_cholesky_inverse() + """ + a = torch.randn(1, 1, 4, 4) + a = torch.matmul(a, a.transpose(-2, -1)) + print(a) + res_1 = torch.inverse(a) + res_2 = batch_cholesky_inverse(a) + print(res_1, '\n', res_2) + + +def mean_var_test(): + x = torch.randn(3, 4) + y = torch.randn(3, 4) + + x_mean = x.mean(dim=1, keepdim=True) + x_sum = x.sum(dim=1, keepdim=True) / 2.0 + y_mean = y.mean(dim=1, keepdim=True) + y_sum = y.sum(dim=1, keepdim=True) / 2.0 + + x_var_1 = torch.matmul(x - x_mean, (x - x_mean).t()) + x_var_2 = torch.matmul(x, x.t()) - torch.matmul(x_sum, x_sum.t()) + xy_cov = torch.matmul(x - x_mean, (y - y_mean).t()) + xy_cov_1 = torch.matmul(x, y.t()) - x_sum.matmul(y_sum.t()) + + print(x_var_1) + print(x_var_2) + + print(xy_cov, '\n', xy_cov_1) + + +class RMILoss(nn.Module): + """ + region mutual information + I(A, B) = H(A) + H(B) - H(A, B) + This version need a lot of memory if do not dwonsample. 
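+    Pooling the label/probability maps first (controlled by rmi_pool_size / rmi_pool_stride below)
+    keeps the memory footprint manageable.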
+ """ + + def __init__(self, + configer=None): + super(RMILoss, self).__init__() + self.configer = configer + self.use_sigmoid = self.configer.get('loss', 'params')['use_sigmoid'] + self.num_classes = self.configer.get('loss', 'params')['num_classes'] + # radius choices + self.rmi_radius = self.configer.get('loss', 'params')['rmi_radius'] + assert self.rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + self.rmi_pool_way = self.configer.get('loss', 'params')['rmi_pool_way'] + assert self.rmi_pool_way in [0, 1, 2, 3] + + # set the pool_size = rmi_pool_stride + self.rmi_pool_size = self.configer.get('loss', 'params')['rmi_pool_size'] + self.rmi_pool_stride = self.configer.get('loss', 'params')['rmi_pool_stride'] + assert self.rmi_pool_size == self.rmi_pool_stride + + self.weight_lambda = self.configer.get('loss', 'params')['loss_weight_lambda'] + self.loss_weight = self.configer.get('loss', 'params')['loss_weight'] + self.lambda_way = self.configer.get('loss', 'params')['lambda_way'] + + # dimension of the distribution + self.half_d = self.rmi_radius * self.rmi_radius + self.d = 2 * self.half_d + self.kernel_padding = self.rmi_pool_size // 2 + # ignore class + self.ignore_index = 255 + + def forward(self, + cls_score, + label, + weight=None, + **kwargs): + label[label < 0] = 255 + loss = self.loss_weight * self.forward_sigmoid(cls_score, label) + label[label == 255] = -1 + # loss = self.forward_softmax_sigmoid(cls_score, label) + return loss + + def forward_softmax_sigmoid(self, logits_4D, labels_4D): + """ + Using both softmax and sigmoid operations. + Args: + logits_4D : [N, C, H, W], dtype=float32 + labels_4D : [N, H, W], dtype=long + """ + # PART I -- get the normal cross entropy loss + print( + "max label: {} min label: {}".format(labels_4D[labels_4D != 255].max(), labels_4D[labels_4D != 255].min())) + normal_loss = F.cross_entropy(input=logits_4D, + target=labels_4D.long(), + ignore_index=self.ignore_index, + reduction='mean') + + # PART II -- get the lower bound of the region mutual information + # get the valid label and logits + # valid label, [N, C, H, W] + label_mask_3D = labels_4D < self.num_classes + valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), + num_classes=self.num_classes).float() + label_mask_3D = label_mask_3D.float() + valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3) + valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) + # valid probs + probs_4D = F.sigmoid(logits_4D) * label_mask_3D.unsqueeze(dim=1) + probs_4D = probs_4D.clamp(min=_CLIP_MIN, max=_CLIP_MAX) + + # get region mutual information + rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) + + # add together + final_loss = (self.weight_lambda * normal_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way + else normal_loss + rmi_loss * self.weight_lambda) + + return final_loss + + def forward_sigmoid(self, logits_4D, labels_4D): + """ + Using the sigmiod operation both. 
+ Args: + logits_4D : [N, C, H, W], dtype=float32 + labels_4D : [N, H, W], dtype=long + """ + # label mask -- [N, H, W, 1] + label_mask_3D = labels_4D < self.num_classes + + # valid label + valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), + num_classes=self.num_classes).float() + label_mask_3D = label_mask_3D.float() + label_mask_flat = label_mask_3D.view([-1, ]) + valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3) + valid_onehot_labels_4D.requires_grad_(False) + + # PART I -- calculate the sigmoid binary cross entropy loss + valid_onehot_label_flat = valid_onehot_labels_4D.view([-1, self.num_classes]).requires_grad_(False) + logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes]) + + # binary loss, multiplied by the not_ignore_mask + valid_pixels = torch.sum(label_mask_flat) + binary_loss = F.binary_cross_entropy_with_logits(logits_flat, + target=valid_onehot_label_flat, + weight=label_mask_flat.unsqueeze(dim=1), + reduction='sum') + bce_loss = torch.div(binary_loss, valid_pixels + 1.0) + + # PART II -- get rmi loss + # onehot_labels_4D -- [N, C, H, W] + probs_4D = logits_4D.sigmoid() * label_mask_3D.unsqueeze(dim=1) + _CLIP_MIN + valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) + + # get region mutual information + rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) + + # add together + final_loss = (self.weight_lambda * bce_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way + else bce_loss + rmi_loss * self.weight_lambda) + + return final_loss + + def rmi_lower_bound(self, labels_4D, probs_4D): + """ + calculate the lower bound of the region mutual information. + Args: + labels_4D : [N, C, H, W], dtype=float32 + probs_4D : [N, C, H, W], dtype=float32 + """ + assert labels_4D.size() == probs_4D.size() + + p, s = self.rmi_pool_size, self.rmi_pool_stride + if self.rmi_pool_stride > 1: + if self.rmi_pool_way == 0: + labels_4D = F.max_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + probs_4D = F.max_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + elif self.rmi_pool_way == 1: + labels_4D = F.avg_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + probs_4D = F.avg_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + elif self.rmi_pool_way == 2: + # interpolation + shape = labels_4D.size() + new_h, new_w = shape[2] // s, shape[3] // s + labels_4D = F.interpolate(labels_4D, size=(new_h, new_w), mode='nearest') + probs_4D = F.interpolate(probs_4D, size=(new_h, new_w), mode='bilinear', align_corners=True) + else: + raise NotImplementedError("Pool way of RMI is not defined!") + # we do not need the gradient of label. + label_shape = labels_4D.size() + n, c = label_shape[0], label_shape[1] + + # combine the high dimension points from label and probability map. 
new shape [N, C, radius * radius, H, W] + la_vectors, pr_vectors = map_get_pairs(labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0) + + la_vectors = la_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor).requires_grad_(False) + pr_vectors = pr_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor) + + # small diagonal matrix, shape = [1, 1, radius * radius, radius * radius] + diag_matrix = torch.eye(self.half_d).unsqueeze(dim=0).unsqueeze(dim=0) + + # the mean and covariance of these high dimension points + # Var(X) = E(X^2) - E(X) E(X), N * Var(X) = X^2 - X E(X) + la_vectors = la_vectors - la_vectors.mean(dim=3, keepdim=True) + la_cov = torch.matmul(la_vectors, la_vectors.transpose(2, 3)) + + pr_vectors = pr_vectors - pr_vectors.mean(dim=3, keepdim=True) + pr_cov = torch.matmul(pr_vectors, pr_vectors.transpose(2, 3)) + # https://github.com/pytorch/pytorch/issues/7500 + # waiting for batched torch.cholesky_inverse() + pr_cov_inv = torch.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) + # if the dimension of the point is less than 9, you can use the below function + # to acceleration computational speed. + # pr_cov_inv = utils.batch_cholesky_inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) + + la_pr_cov = torch.matmul(la_vectors, pr_vectors.transpose(2, 3)) + # the approxiamation of the variance, det(c A) = c^n det(A), A is in n x n shape; + # then log det(c A) = n log(c) + log det(A). + # appro_var = appro_var / n_points, we do not divide the appro_var by number of points here, + # and the purpose is to avoid underflow issue. + # If A = A^T, A^-1 = (A^-1)^T. + appro_var = la_cov - torch.matmul(la_pr_cov.matmul(pr_cov_inv), la_pr_cov.transpose(-2, -1)) + # appro_var = la_cov - torch.chain_matmul(la_pr_cov, pr_cov_inv, la_pr_cov.transpose(-2, -1)) + # appro_var = torch.div(appro_var, n_points.type_as(appro_var)) + diag_matrix.type_as(appro_var) * 1e-6 + + # The lower bound. If A is nonsingular, ln( det(A) ) = Tr( ln(A) ). + rmi_now = 0.5 * log_det_by_cholesky(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) + # rmi_now = 0.5 * torch.logdet(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) + + # mean over N samples. sum over classes. 
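+        # rmi_now holds one value per (sample, class); average over the batch,
+        # normalize by the point dimension half_d, then reduce over classes.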
+ rmi_per_class = rmi_now.view([-1, self.num_classes]).mean(dim=0).float() + # is_half = False + # if is_half: + # rmi_per_class = torch.div(rmi_per_class, float(self.half_d / 2.0)) + # else: + rmi_per_class = torch.div(rmi_per_class, float(self.half_d)) + + rmi_loss = torch.sum(rmi_per_class) if _IS_SUM else torch.mean(rmi_per_class) + return rmi_loss diff --git a/src/PixelPrototypeCELoss.py b/src/PixelPrototypeCELoss.py new file mode 100644 index 0000000..92991ef --- /dev/null +++ b/src/PixelPrototypeCELoss.py @@ -0,0 +1,204 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +import yaml +import json +from .loss_helper import FSAuxRMILoss, FSCELoss + +def distributed_sinkhorn(Q, nmb_iters): + with torch.no_grad(): + # print(f'Q shape is .....{Q.shape}') + Q = Q.T + # B = Q.shape[1] * Q.shape[2] + B = Q.shape[1] * Q.shape[0] + # print(f'print set....{B}') + K = Q.shape[0] + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + dist.all_reduce(sum_Q) + Q /= sum_Q + for it in range(nmb_iters): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + Q *= B # the colomns must sum to 1 + Q = Q.T + return Q, torch.argmax(Q, dim=1) + +class ProjectionHead(nn.Module): + def __init__(self, dim_in, proj_dim=256, drop_rate=0.1): + super(ProjectionHead, self).__init__() + self.proj = nn.Linear(dim_in, proj_dim) + self.norm = nn.BatchNorm1d(proj_dim) + self.act = nn.ReLU(inplace=True) + self.drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) + x = self.proj(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + return x + +class PixelPrototypeClassifier(nn.Module): + def __init__(self, configer: dict, backbone, feature_dim, num_prototypes): + super(PixelPrototypeClassifier, self).__init__() + self.configer = configer + self.gamma = self.configer['protoseg']['gamma'] + self.num_prototype = self.configer['protoseg']['num_prototype'] + self.use_prototype = self.configer['protoseg']['use_prototype'] + self.update_prototype = self.configer['protoseg']['update_prototype'] + self.pretrain_prototype = self.configer['protoseg']['pretrain_prototype'] + self.num_classes = self.configer['data']['num_classes'] + self.backbone = backbone + self.feature_dim = feature_dim + + # Prototype layer + self.prototypes = nn.Parameter(torch.zeros(self.num_classes, num_prototypes, self.feature_dim), + requires_grad=True) + + self.proj_head = ProjectionHead(self.feature_dim, self.feature_dim) + self.feat_norm = nn.LayerNorm(self.feature_dim) + self.mask_norm = nn.LayerNorm(self.num_classes) + + trunc_normal_(self.prototypes, std=0.02) + + def prototype_learning(self, features, gt_seg): + # Flatten features + _c = rearrange(features, 'b c h w -> (b h w) c') + _c = self.feat_norm(_c) + _c = l2_normalize(_c) # Assuming l2_normalize is defined elsewhere + + # Normalize prototypes + self.prototypes.data.copy_(l2_normalize(self.prototypes)) + + # Cosine similarity + masks = torch.einsum('nd,kmd->nmk', _c, self.prototypes) + out_seg = torch.amax(masks, dim=1) + out_seg = self.mask_norm(out_seg) + out_seg = rearrange(out_seg, "(b h w) k -> b k h w", b=features.shape[0], h=features.shape[2]) + + if self.use_prototype and gt_seg is not None: + gt_seg = 
F.interpolate(gt_seg.float(), size=features.size()[2:], mode='nearest').view(-1) + contrast_logits, contrast_target = self.prototype_learning_step(_c, out_seg, gt_seg, masks) + return {'seg': out_seg, 'logits': contrast_logits, 'target': contrast_target} + + return out_seg + + def prototype_learning_step(self, _c, out_seg, gt_seg, masks): + pred_seg = torch.max(out_seg, 1)[1] + mask = (gt_seg == pred_seg.view(-1)) + cosine_similarity = torch.mm(_c, self.prototypes.view(-1, self.prototypes.shape[-1]).t()) + proto_logits = cosine_similarity + proto_target = gt_seg.clone().float() + + protos = self.prototypes.data.clone() + for k in range(self.num_classes): + init_q = masks[..., k] + init_q = init_q[gt_seg == k, ...] + if init_q.shape[0] == 0: + continue + # Assuming distributed_sinkhorn is defined elsewhere + q, indexs = distributed_sinkhorn(init_q, nmb_iters=3) + m_k = mask[gt_seg == k] + c_k = _c[gt_seg == k, ...] + m_k_tile = repeat(m_k, 'n -> n tile', tile=self.num_prototype) + m_q = q * m_k_tile + c_k_tile = repeat(m_k, 'n -> n tile', tile=c_k.shape[-1]) + c_q = c_k * c_k_tile + f = m_q.transpose(0, 1) @ c_q + n = torch.sum(m_q, dim=0) + if torch.sum(n) > 0 and self.update_prototype is True: + f = F.normalize(f, p=2, dim=-1) + new_value = momentum_update(old_value=protos[k, n != 0, :], new_value=f[n != 0, :], + momentum=self.gamma) + protos[k, n != 0, :] = new_value + proto_target[gt_seg == k] = indexs.float() + (self.num_prototype * k) + self.prototypes = nn.Parameter(l2_normalize(protos), requires_grad=False) + return proto_logits, proto_target + + def forward(self, x, gt_semantic_seg=None, pretrain_prototype=False): + features = self.backbone(x)[-1] + features = self.proj_head(features) + if pretrain_prototype is False and self.use_prototype and gt_semantic_seg is not None: + output = self.prototype_learning(features, gt_semantic_seg) + else: + output = self.prototype_learning(features, None) + return output + +class PPC(nn.Module): + def __init__(self, ignore_label=-1): + super(PPC, self).__init__() + self.ignore_label = ignore_label + + def forward(self, contrast_logits, contrast_target): + loss_ppc = F.cross_entropy(contrast_logits, contrast_target.long(), ignore_index=self.ignore_label) + return loss_ppc + +class PPD(nn.Module): + def __init__(self, ignore_label=-1): + super(PPD, self).__init__() + self.ignore_label = ignore_label + + def forward(self, contrast_logits, contrast_target): + contrast_logits = contrast_logits[contrast_target != self.ignore_label, :] + contrast_target = contrast_target[contrast_target != self.ignore_label] + + logits = torch.gather(contrast_logits, 1, contrast_target[:, None].long()) + loss_ppd = (1 - logits).pow(2).mean() + return loss_ppd + +class PixelPrototypeCELoss(nn.Module): + def __init__(self, configer=None, use_rmi=False): + super(PixelPrototypeCELoss, self).__init__() + self.configer = configer + ignore_index = -1 + if 'loss' in self.configer and 'params' in self.configer['loss'] and 'ce_ignore_index' in self.configer['loss']['params']: + ignore_index = self.configer['loss']['params']['ce_ignore_index'] + + self.loss_ppc_weight = self.configer['protoseg']['loss_ppc_weight'] + self.loss_ppd_weight = self.configer['protoseg']['loss_ppd_weight'] + + self.use_rmi = use_rmi + + # Replace with your appropriate segmentation loss + if self.use_rmi: + self.seg_criterion = FSAuxRMILoss(configer=configer) + else: + self.seg_criterion = FSCELoss(configer=configer) + + self.ppc_criterion = PPC(ignore_label=ignore_index) + self.ppd_criterion = 
PPD(ignore_label=ignore_index) + + def forward(self, preds, target): + h, w = target.size(1), target.size(2) + if isinstance(preds, dict): + seg = preds['seg'] + contrast_logits = preds['logits'] + contrast_target = preds['target'] + loss_ppc = self.ppc_criterion(contrast_logits, contrast_target) + loss_ppd = self.ppd_criterion(contrast_logits, contrast_target) + + pred = F.interpolate(input=seg, size=(h, w), mode='bilinear', align_corners=True) + loss_seg = self.seg_criterion(pred, target) + return loss_seg + self.loss_ppc_weight * loss_ppc + self.loss_ppd_weight * loss_ppd + + # Assuming preds is the segmentation output directly + pred = F.interpolate(input=preds, size=(h, w), mode='bilinear', align_corners=True) + loss_seg = self.seg_criterion(pred, target) + return loss_seg + +def l2_normalize(x): + return x / (torch.norm(x, p=2, dim=-1, keepdim=True) + 1e-10) + +def momentum_update(old_value, new_value, momentum): + update = momentum * old_value + (1 - momentum) * new_value + return update \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aaf/layers.py b/src/aaf/layers.py new file mode 100644 index 0000000..c3c9e5e --- /dev/null +++ b/src/aaf/layers.py @@ -0,0 +1,198 @@ +import torch +import torch.nn.functional as F +import numpy as np + + +def eightway_activation(x): + """Retrieves neighboring pixels/features on the eight corners from + a 3x3 patch. + + Args: + x: A tensor of size [batch_size, height_in, width_in, channels] + + Returns: + A tensor of size [batch_size, height_in, width_in, channels, 8] + """ + # Get the number of channels in the input. + shape_x = list(x.shape) + if len(shape_x) != 4: + raise ValueError('Only support for 4-D tensors!') + + # Pad at the margin. + x = F.pad(x, + pad=(0, 0, 1, 1, 1, 1, 0, 0), + mode='reflect') + # Get eight neighboring pixels/features. + x_groups = [ + x[:, 1:-1, :-2, :].clone(), # left + x[:, 1:-1, 2:, :].clone(), # right + x[:, :-2, 1:-1, :].clone(), # up + x[:, 2:, 1:-1, :].clone(), # down + x[:, :-2, :-2, :].clone(), # left-up + x[:, 2:, :-2, :].clone(), # left-down + x[:, :-2, 2:, :].clone(), # right-up + x[:, 2:, 2:, :].clone() # right-down + ] + output = [ + torch.unsqueeze(c, dim=-1) for c in x_groups + ] + output = torch.cat(output, dim=-1) + + return output + + +def eightcorner_activation(x, size): + """Retrieves neighboring pixels one the eight corners from a + (2*size+1)x(2*size+1) patch. + + Args: + x: A tensor of size [batch_size, height_in, width_in, channels] + size: A number indicating the half size of a patch. + + Returns: + A tensor of size [batch_size, height_in, width_in, channels, 8] + """ + # Get the number of channels in the input. + shape_x = list(x.shape) + if len(shape_x) != 4: + raise ValueError('Only support for 4-D tensors!') + n, c, h, w = shape_x + + # Pad at the margin. + p = size + x_pad = F.pad(x, + pad=(p, p, p, p, 0, 0, 0, 0), + mode='constant', + value=0) + + # Get eight corner pixels/features in the patch. + x_groups = [] + for st_y in range(0, 2 * size + 1, size): + for st_x in range(0, 2 * size + 1, size): + if st_y == size and st_x == size: + # Ignore the center pixel/feature. 
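+                # the center location is the reference pixel itself, so only the
+                # eight surrounding corner shifts are collected here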
+ continue + + x_neighbor = x_pad[:, :, st_y:st_y + h, st_x:st_x + w].clone() + x_groups.append(x_neighbor) + + output = [torch.unsqueeze(c, dim=-1) for c in x_groups] + output = torch.cat(output, dim=-1) + + return output + + +def ignores_from_label(labels, num_classes, size, ignore_index): + """Retrieves ignorable pixels from the ground-truth labels. + + This function returns a binary map in which 1 denotes ignored pixels + and 0 means not ignored ones. For those ignored pixels, they are not + only the pixels with label value >= num_classes, but also the + corresponding neighboring pixels, which are on the the eight cornerls + from a (2*size+1)x(2*size+1) patch. + + Args: + labels: A tensor of size [batch_size, height_in, width_in], indicating + semantic segmentation ground-truth labels. + num_classes: A number indicating the total number of valid classes. The + labels ranges from 0 to (num_classes-1), and any value >= num_classes + would be ignored. + size: A number indicating the half size of a patch. + + Return: + A tensor of size [batch_size, height_in, width_in, 8] + """ + # Get the number of channels in the input. + shape_lab = list(labels.shape) + if len(shape_lab) != 3: + raise ValueError('Only support for 3-D label tensors!') + n, h, w = shape_lab + + # Retrieve ignored pixels with label value >= num_classes. + # ignore = labels>num_classes-1 # NxHxW + ignore = (labels == ignore_index) + + # Pad at the margin. + p = size + ignore_pad = F.pad(ignore, + pad=(p, p, p, p, 0, 0), + mode='constant', + value=1) + + # Retrieve eight corner pixels from the center, where the center + # is ignored. Note that it should be bi-directional. For example, + # when computing AAF loss with top-left pixels, the ignored pixels + # might be the center or the top-left ones. + ignore_groups = [] + for st_y in range(2 * size, -1, -size): + for st_x in range(2 * size, -1, -size): + if st_y == size and st_x == size: + continue + ignore_neighbor = ignore_pad[:, st_y:st_y + h, st_x:st_x + w].clone() + mask = ignore_neighbor | ignore + ignore_groups.append(mask) + + ig = 0 + for st_y in range(0, 2 * size + 1, size): + for st_x in range(0, 2 * size + 1, size): + if st_y == size and st_x == size: + continue + ignore_neighbor = ignore_pad[:, st_y:st_y + h, st_x:st_x + w].clone() + mask = ignore_neighbor | ignore_groups[ig] + ignore_groups[ig] = mask + ig += 1 + + ignore_groups = [ + torch.unsqueeze(c, dim=-1) for c in ignore_groups + ] # NxHxWx1 + ignore = torch.cat(ignore_groups, dim=-1) # NxHxWx8 + + return ignore + + +def edges_from_label(labels, size, ignore_class=255): + """Retrieves edge positions from the ground-truth labels. + + This function computes the edge map by considering if the pixel values + are equal between the center and the neighboring pixels on the eight + corners from a (2*size+1)*(2*size+1) patch. Ignore edges where the any + of the paired pixels with label value >= num_classes. + + Args: + labels: A tensor of size [batch_size, height_in, width_in], indicating + semantic segmentation ground-truth labels. + size: A number indicating the half size of a patch. + ignore_class: A number indicating the label value to ignore. + + Return: + A tensor of size [batch_size, height_in, width_in, 1, 8] + """ + # Get the number of channels in the input. + shape_lab = list(labels.shape) + if len(shape_lab) != 4: + raise ValueError('Only support for 4-D label tensors!') + n, h, w, c = shape_lab + + # Pad at the margin. 
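+    # padding with ignore_class keeps out-of-image neighbors from matching any real label;
+    # those border positions are dropped later via the ignore mask from ignores_from_label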
+ p = size + labels_pad = F.pad( + labels, pad=(0, 0, p, p, p, p, 0, 0), + mode='constant', + value=ignore_class) + + # Get the edge by comparing label value of the center and it paired pixels. + edge_groups = [] + for st_y in range(0, 2 * size + 1, size): + for st_x in range(0, 2 * size + 1, size): + if st_y == size and st_x == size: + continue + labels_neighbor = labels_pad[:, st_y:st_y + h, st_x:st_x + w] + edge = labels_neighbor != labels + edge_groups.append(edge) + + edge_groups = [ + torch.unsqueeze(c, dim=-1) for c in edge_groups + ] # NxHxWx1x1 + edge = torch.cat(edge_groups, dim=-1) # NxHxWx1x8 + + return edge diff --git a/src/aaf/losses.py b/src/aaf/losses.py new file mode 100644 index 0000000..d25ab15 --- /dev/null +++ b/src/aaf/losses.py @@ -0,0 +1,192 @@ +import torch +import torch.nn.functional as F +from src.aaf import layers as nnx +import numpy as np + + +def affinity_loss(labels, + probs, + num_classes, + kld_margin): + """Affinity Field (AFF) loss. + + This function computes AFF loss. There are several components in the + function: + 1) extracts edges from the ground-truth labels. + 2) extracts ignored pixels and their paired pixels (the neighboring + pixels on the eight corners). + 3) extracts neighboring pixels on the eight corners from a 3x3 patch. + 4) computes KL-Divergence between center pixels and their neighboring + pixels from the eight corners. + + Args: + labels: A tensor of size [batch_size, height_in, width_in], indicating + semantic segmentation ground-truth labels. + probs: A tensor of size [batch_size, height_in, width_in, num_classes], + indicating segmentation predictions. + num_classes: A number indicating the total number of valid classes. + kld_margin: A number indicating the margin for KL-Divergence at edge. + + Returns: + Two 1-D tensors value indicating the loss at edge and non-edge. + """ + # Compute ignore map (e.g, label of 255 and their paired pixels). + + labels = torch.squeeze(labels, dim=1) # NxHxW + ignore = nnx.ignores_from_label(labels, num_classes, 1) # NxHxWx8 + not_ignore = np.logical_not(ignore) + not_ignore = torch.unsqueeze(not_ignore, dim=3) # NxHxWx1x8 + + # Compute edge map. + one_hot_lab = F.one_hot(labels, depth=num_classes) + edge = nnx.edges_from_label(one_hot_lab, 1, 255) # NxHxWxCx8 + + # Remove ignored pixels from the edge/non-edge. + edge = np.logical_and(edge, not_ignore) + not_edge = np.logical_and(np.logical_not(edge), not_ignore) + + edge_indices = torch.nonzero(torch.reshape(edge, (-1,))) + not_edge_indices = torch.nonzero(torch.reshape(not_edge, (-1,))) + + # Extract eight corner from the center in a patch as paired pixels. + probs_paired = nnx.eightcorner_activation(probs, 1) # NxHxWxCx8 + probs = torch.unsqueeze(probs, dim=-1) # NxHxWxCx1 + bot_epsilon = 1e-4 + top_epsilon = 1.0 + + neg_probs = np.clip( + 1 - probs, bot_epsilon, top_epsilon) + neg_probs_paired = np.clip( + 1 - probs_paired, bot_epsilon, top_epsilon) + probs = np.clip( + probs, bot_epsilon, top_epsilon) + probs_paired = np.clip( + probs_paired, bot_epsilon, top_epsilon) + + # Compute KL-Divergence. 
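+    # per-class Bernoulli KL between the paired (corner) prediction and the center prediction:
+    #   KL = p_pair * log(p_pair / p_center) + (1 - p_pair) * log((1 - p_pair) / (1 - p_center))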
+ kldiv = probs_paired * torch.log(probs_paired / probs) + kldiv += neg_probs_paired * torch.log(neg_probs_paired / neg_probs) + edge_loss = torch.max(0.0, kld_margin - kldiv) + not_edge_loss = kldiv + + not_edge_loss = torch.reshape(not_edge_loss, (-1,)) + not_edge_loss = torch.gather(not_edge_loss, 0, not_edge_indices) + edge_loss = torch.reshape(edge_loss, (-1,)) + edge_loss = torch.gather(edge_loss, 0, edge_indices) + + return edge_loss, not_edge_loss + +# from lib.utils.tools.logger import Logger as Log + +def adaptive_affinity_loss(labels, + one_hot_lab, + probs, + size, + num_classes, + kld_margin, + w_edge, + w_not_edge, + ignore_index=-1): + """Adaptive affinity field (AAF) loss. + + This function computes AAF loss. There are three components in the function: + 1) extracts edges from the ground-truth labels. + 2) extracts ignored pixels and their paired pixels (usually the eight corner + pixels). + 3) extracts eight corner pixels/predictions from the center in a + (2*size+1)x(2*size+1) patch + 4) computes KL-Divergence between center pixels and their paired pixels (the + eight corner). + 5) imposes adaptive weightings on the loss. + + Args: + labels: A tensor of size [batch_size, height_in, width_in], indicating + semantic segmentation ground-truth labels. + one_hot_lab: A tensor of size [batch_size, num_classes, height_in, width_in] + which is the ground-truth labels in the form of one-hot vector. + probs: A tensor of size [batch_size, num_classes, height_in, width_in], + indicating segmentation predictions. + size: A number indicating the half size of a patch. + num_classes: A number indicating the total number of valid classes. The + kld_margin: A number indicating the margin for KL-Divergence at edge. + w_edge: A number indicating the weighting for KL-Divergence at edge. + w_not_edge: A number indicating the weighting for KL-Divergence at non-edge. + ignore_index: ignore index + + Returns: + Two 1-D tensors value indicating the loss at edge and non-edge. + """ + # Compute ignore map (e.g, label of 255 and their paired pixels). + labels = torch.squeeze(labels, dim=1) # NxHxW + ignore = nnx.ignores_from_label(labels, num_classes, size, ignore_index) # NxHxWx8 + not_ignore = ~ignore + not_ignore = torch.unsqueeze(not_ignore, dim=3) # NxHxWx1x8 + + # Compute edge map. + edge = nnx.edges_from_label(one_hot_lab, size, ignore_index) # NxHxWxCx8 + + # Log.info('{} {}'.format(edge.shape, not_ignore.shape)) + + # Remove ignored pixels from the edge/non-edge. + edge = edge & not_ignore + + + not_edge = ~edge & not_ignore + + edge_indices = torch.nonzero(torch.reshape(edge, (-1,))) + # print(edge_indices.size()) + if edge_indices.size()[0] == 0: + edge_loss = torch.tensor(0.0, requires_grad=False).cuda() + not_edge_loss = torch.tensor(0.0, requires_grad=False).cuda() + return edge_loss, not_edge_loss + + not_edge_indices = torch.nonzero(torch.reshape(not_edge, (-1,))) + + # Extract eight corner from the center in a patch as paired pixels. 
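+    # both probs and 1 - probs are clamped to [bot_epsilon, top_epsilon] below so the log terms stay finite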
+ probs_paired = nnx.eightcorner_activation(probs, size) # NxHxWxCx8 + probs = torch.unsqueeze(probs, dim=-1) # NxHxWxCx1 + bot_epsilon = torch.tensor(1e-4, requires_grad=False).cuda() + top_epsilon = torch.tensor(1.0, requires_grad=False).cuda() + + neg_probs = torch.where(1 - probs < bot_epsilon, bot_epsilon, 1 - probs) + neg_probs = torch.where(neg_probs > top_epsilon, top_epsilon, neg_probs) + + neg_probs_paired = torch.where(1 - probs_paired < bot_epsilon, bot_epsilon, 1 - probs_paired) + neg_probs_paired = torch.where(neg_probs_paired > top_epsilon, top_epsilon, neg_probs_paired) + + probs = torch.where(probs < bot_epsilon, bot_epsilon, probs) + probs = torch.where(probs > top_epsilon, top_epsilon, probs) + + probs_paired = torch.where(probs_paired < bot_epsilon, bot_epsilon, probs_paired) + probs_paired = torch.where(probs_paired > top_epsilon, top_epsilon, probs_paired) + + # neg_probs = np.clip( + # 1-probs, bot_epsilon, top_epsilon) + # neg_probs_paired = np.clip( + # 1-probs_paired, bot_epsilon, top_epsilon) + # probs = np.clip( + # probs, bot_epsilon, top_epsilon) + # probs_paired = np.clip( + # probs_paired, bot_epsilon, top_epsilon) + + # Compute KL-Divergence. + kldiv = probs_paired * torch.log(probs_paired / probs) + kldiv += neg_probs_paired * torch.log(neg_probs_paired / neg_probs) + edge_loss = torch.max(torch.tensor(0.0, requires_grad=False).cuda(), kld_margin - kldiv) + not_edge_loss = kldiv + + # Impose weights on edge/non-edge losses. + one_hot_lab = torch.unsqueeze(one_hot_lab, dim=-1) + + w_edge = torch.sum(w_edge * one_hot_lab.float(), dim=3, keepdim=True) # NxHxWx1x1 + w_not_edge = torch.sum(w_not_edge * one_hot_lab.float(), dim=3, keepdim=True) # NxHxWx1x1 + + edge_loss *= w_edge.permute(0, 3, 1, 2, 4) + not_edge_loss *= w_not_edge.permute(0, 3, 1, 2, 4) + + not_edge_loss = torch.reshape(not_edge_loss, (-1, 1)) + not_edge_loss = torch.gather(not_edge_loss, 0, not_edge_indices) + edge_loss = torch.reshape(edge_loss, (-1, 1)) + edge_loss = torch.gather(edge_loss, 0, edge_indices) + + return edge_loss, not_edge_loss diff --git a/src/classifier.py b/src/classifier.py index 80c9d04..5ba7c78 100644 --- a/src/classifier.py +++ b/src/classifier.py @@ -5,7 +5,7 @@ from src.util import compute_wce from .util import to_one_hot - +from proto.PixelPrototypeCELoss import PixelPrototypeCELoss class Classifier(object): def __init__(self, args, base_weight, base_bias, n_tasks): @@ -251,7 +251,8 @@ def optimize(self, features_s: torch.tensor, features_q: torch.tensor, gt_s: tor valid_pixels_q : shape [batch_size_val, 1, h, w] """ l1, l2, l3, l4 = self.weights - + # //added section + criterion = PixelPrototypeCELoss() params = [self.novel_weight, self.novel_bias] if self.fine_tune_base_classifier: params.extend([self.base_weight, self.base_bias]) @@ -269,9 +270,12 @@ def optimize(self, features_s: torch.tensor, features_q: torch.tensor, gt_s: tor valid_pixels_q = F.interpolate(valid_pixels_q.float(), size=features_q.size()[-2:], mode='nearest').long() for iteration in range(self.adapt_iter): + # Create dictionary with required outputs + # preds = {"seg": seg_output, "logits": logits_q,"target": gt_q } logits_s, logits_q = self.get_logits(features_s), self.get_logits(features_q) proba_s, proba_q = self.get_probas(logits_s), self.get_probas(logits_q) - + preds = {"seg": logits_s, "logits": logits_q,"target": features_q } + loss_ppc = criterion(preds, features_q) snapshot_proba_q = self.get_base_snapshot_probas(features_q) distillation = self.distillation_loss(proba_q, 
snapshot_proba_q, valid_pixels_q, reduction='none') d_kl, entropy, marginal = self.get_entropies(valid_pixels_q, proba_q, reduction='none') diff --git a/src/classifierV1.py b/src/classifierV1.py new file mode 100644 index 0000000..9b6900a --- /dev/null +++ b/src/classifierV1.py @@ -0,0 +1,311 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F + +from src.util import compute_wce +from .util import to_one_hot +from .PixelPrototypeCELoss import PixelPrototypeCELoss + +class Classifier(object): + def __init__(self, args, base_weight, base_bias, n_tasks,cfg: dict, backbone,): + self.num_base_classes_and_bg = base_weight.size(-1) + self.num_novel_classes = args.num_classes_val + self.num_classes = self.num_base_classes_and_bg + self.num_novel_classes + self.n_tasks = n_tasks + + self.snapshot_weight = base_weight.squeeze(0).squeeze(0).clone() # Snapshot of the model right after training, frozen + self.snapshot_bias = base_bias.clone() + self.base_weight = base_weight.squeeze(0).repeat(self.n_tasks, 1, 1) # [n_tasks, c, num_base_classes_and_bg] + self.base_bias = base_bias.unsqueeze(0).repeat(self.n_tasks, 1) # [n_tasks, num_base_classes_and_bg] + + self.novel_weight, self.novel_bias = None, None + self.pi, self.true_pi = None, None + + self.fine_tune_base_classifier = args.fine_tune_base_classifier + self.lr = args.cls_lr + self.adapt_iter = args.adapt_iter + self.weights = args.weights + self.pi_estimation_strategy = args.pi_estimation_strategy + self.pi_update_at = args.pi_update_at + self.pixel_prototype_cls = PixelPrototypeClassifier( + configer=cfg, + backbone=backbone, + feature_dim=self.feature_dim, + num_prototypes=20, + num_classes=self.num_classes + ).cuda() + self.pixel_prototype_loss = PixelPrototypeCELoss(configer=cfg).cuda() + + @staticmethod + def _valid_mean(t, valid_pixels, dim): + s = (valid_pixels * t).sum(dim=dim) + return s / (valid_pixels.sum(dim=dim) + 1e-10) + + + def init_prototypes(self, features_s: torch.tensor, gt_s: torch.tensor) -> None: + """ + inputs: + features_s : shape [num_novel_classes, shot, c, h, w] + gt_s : shape [num_novel_classes, shot, H, W] + """ + # 1. Downsample support masks to match feature map resolution + ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest') + ds_gt_s = ds_gt_s.long().unsqueeze(2) # [n_novel_classes, shot, 1, h, w] + + # 2. 
Initialize novel class prototypes (from support set) + num_novel_classes = args.num_classes_val # Number of novel classes + + self.novel_prototypes = torch.zeros((features_s.size(2), num_novel_classes), device=features_s.device) + for cls in range(num_novel_classes): + + class_mask = (ds_gt_s == cls) + class_prototype = self._valid_mean(features_s, class_mask, (0, 1, 3, 4)) + self.novel_prototypes[:, cls] = class_prototype + + self.novel_prototypes /= self.novel_prototypes.norm(dim=0).unsqueeze(0) + 1e-10 + + + def get_logits(self, features: torch.tensor) -> torch.tensor: + """ + Computes logits for given features + + inputs: + features : shape [1 or batch_size_val, num_novel_classes * shot or 1, c, h, w] + + returns : + logits : shape [batch_size_val, num_novel_classes * shot or 1, num_classes, h, w] + + """ + equation = 'bochw,bcC->boChw' # 'o' is n_novel_classes * shot for support and is 1 for query + + novel_logits = torch.einsum(equation, features, self.novel_weight) + base_logits = torch.einsum(equation, features, self.base_prototypes) # Use base_prototypes + novel_logits += self.novel_bias.unsqueeze(1).unsqueeze(3).unsqueeze(4) + base_logits += self.base_bias.unsqueeze(1).unsqueeze(3).unsqueeze(4) + + logits = torch.concat([base_logits, novel_logits], dim=2) + return logits + + # equation = 'bochw,bcC->boChw' # 'o' is n_novel_classes * shot for support and is 1 for query + + # novel_logits = torch.einsum(equation, features, self.novel_weight) + # base_logits = torch.einsum(equation, features, self.base_weight) + # novel_logits += self.novel_bias.unsqueeze(1).unsqueeze(3).unsqueeze(4) + # base_logits += self.base_bias.unsqueeze(1).unsqueeze(3).unsqueeze(4) + + # logits = torch.concat([base_logits, novel_logits], dim=2) + # return logits + + @staticmethod + def get_probas(logits: torch.tensor) -> torch.tensor: + """ + inputs: + logits : shape [batch_size_val, num_novel_classes * shot or 1, num_classes, h, w] + + returns : + probas : shape [batch_size_val, num_novel_classes * shot or 1, num_classes, h, w] + """ + return torch.softmax(logits, dim=2) + + def get_base_snapshot_probas(self, features: torch.tensor) -> torch.tensor: + """ + Computes probability maps for given query features, using the snapshot of the base model right after the + training. It only computes values for base classes. + + inputs: + features : shape [batch_size_val, 1, c, h, w] + + returns : + probas : shape [batch_size_val, 1, num_base_classes_and_bg, h, w] + """ + logits = torch.einsum('bochw,cC->boChw', features, self.snapshot_weight) + self.snapshot_bias.view(1, 1, -1, 1, 1) + return torch.softmax(logits, dim=2) + + def self_estimate_pi(self, features_q: torch.tensor, unsqueezed_valid_pixels_q: torch.tensor) -> torch.tensor: + """ + Estimates pi using model's prototypes + + inputs: + features_q : shape [batch_size_val, 1, c, h, w] + unsqueezed_valid_pixels_q : shape [batch_size_val, 1, 1, h, w] + + returns : + pi : shape [batch_size_val, num_classes] + """ + logits_q = self.get_logits(features_q) + probas = torch.softmax(logits_q, dim=2).detach() + return self._valid_mean(probas, unsqueezed_valid_pixels_q, (1, 3, 4)) + + def image_level_supervision_pi(self, features_q: torch.tensor, + unsqueezed_valid_pixels_q: torch.tensor) -> torch.tensor: + """ + Estimates pi using model's prototypes and information about whether each class is present in a query image. 
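+        Logits of classes marked absent in true_pi are set to -inf before the softmax,
+        so those classes receive zero probability mass in the estimate.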
+ + inputs: + features_q : shape [batch_size_val, 1, c, h, w] + unsqueezed_valid_pixels_q : shape [batch_size_val, 1, 1, h, w] + + returns : + pi : shape [batch_size_val, num_classes] + """ + logits_q = self.get_logits(features_q) + absent_indices = torch.where(self.true_pi == 0) + logits_q[absent_indices[0], :, absent_indices[1], :, :] = -torch.inf + probas = torch.softmax(logits_q, dim=2).detach() + return self._valid_mean(probas, unsqueezed_valid_pixels_q, (1, 3, 4)) + + def compute_pi(self, features_q: torch.tensor, valid_pixels_q: torch.tensor, + gt_q: torch.tensor = None) -> torch.tensor: + """ + inputs: + features_q : shape [batch_size_val, 1, c, h, w] + valid_pixels_q : shape [batch_size_val, 1, h, w] + gt_q : shape [batch_size_val, 1, H, W] + """ + valid_pixels_q = F.interpolate(valid_pixels_q.float(), size=features_q.size()[-2:], mode='nearest').long() + valid_pixels_q = valid_pixels_q.unsqueeze(2) + + if gt_q is not None: + ds_gt_q = F.interpolate(gt_q.float(), size=features_q.size()[-2:], mode='nearest').long() + one_hot_gt_q = to_one_hot(ds_gt_q, self.num_classes) # [batch_size_val, shot, num_classes, h, w] + self.true_pi = self._valid_mean(one_hot_gt_q, valid_pixels_q, (1, 3, 4)) + + if self.pi_estimation_strategy == 'upperbound': + self.pi = self.true_pi + elif self.pi_estimation_strategy == 'self': + self.pi = self.self_estimate_pi(features_q, valid_pixels_q) + elif self.pi_estimation_strategy == 'imglvl': + self.pi = self.image_level_supervision_pi(features_q, valid_pixels_q) + elif self.pi_estimation_strategy == 'uniform': + pi = 1 / self.num_classes + self.pi = torch.full_like(self.true_pi, pi) # [batch_size_val, num_classes] + else: + raise ValueError('pi_estimation_strategy is not implemented') + + def distillation_loss(self, curr_p: torch.tensor, snapshot_p: torch.tensor, valid_pixels: torch.tensor, + reduction: str = 'mean') -> torch.tensor: + """ + inputs: + curr_p : shape [batch_size_val, 1, num_classes, h, w] + snapshot_p : shape [batch_size_val, 1, num_base_classes_and_bg, h, w] + valid_pixels : shape [batch_size_val, 1, h, w] + + returns: + kl : Distillation loss for the query + """ + adjusted_curr_p = curr_p.clone()[:, :, :self.num_base_classes_and_bg, ...] + adjusted_curr_p[:, :, 0, ...] 
+= curr_p[:, :, self.num_base_classes_and_bg:, ...].sum(dim=2) + kl = (adjusted_curr_p * torch.log(1e-10 + adjusted_curr_p / (1e-10 + snapshot_p))).sum(dim=2) + kl = self._valid_mean(kl, valid_pixels, (1, 2, 3)) + if reduction == 'sum': + kl = kl.sum(0) + elif reduction == 'mean': + kl = kl.mean(0) + return kl + + def get_entropies(self, valid_pixels: torch.tensor, probas: torch.tensor, + reduction: str = 'mean') -> Tuple[torch.tensor, torch.tensor, torch.tensor]: + """ + inputs: + valid_pixels: shape [batch_size_val, 1, h, w] + probas : shape [batch_size_val, 1, num_classes, h, w] + + returns: + d_kl : Classes proportion kl + entropy : Entropy of predictions + marginal : Current marginal distribution over labels [batch_size_val, num_classes] + """ + entropy = - (probas * torch.log(probas + 1e-10)).sum(2) + entropy = self._valid_mean(entropy, valid_pixels, (1, 2, 3)) + marginal = self._valid_mean(probas, valid_pixels.unsqueeze(2), (1, 3, 4)) + + d_kl = (marginal * torch.log(1e-10 + marginal / (self.pi + 1e-10))).sum(1) + + if reduction == 'sum': + entropy = entropy.sum(0) + d_kl = d_kl.sum(0) + assert not torch.isnan(entropy), entropy + assert not torch.isnan(d_kl), d_kl + elif reduction == 'mean': + entropy = entropy.mean(0) + d_kl = d_kl.mean(0) + return d_kl, entropy, marginal + + def get_ce(self, probas: torch.tensor, valid_pixels: torch.tensor, one_hot_gt: torch.tensor, + reduction: str = 'mean') -> torch.tensor: + """ + inputs: + probas : shape [batch_size_val, num_novel_classes * shot, c, h, w] + valid_pixels : shape [1, num_novel_classes * shot, h, w] + one_hot_gt: shape [1, num_novel_classes * shot, num_classes, h, w] + + returns: + ce : Cross-Entropy between one_hot_gt and probas + """ + probas = probas.clone() + probas[:, :, 0, ...] += probas[:, :, 1:self.num_base_classes_and_bg, ...].sum(dim=2) + probas[:, :, 1:self.num_base_classes_and_bg, ...] = 0. 
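+        # base-class probability mass has been folded into the background channel above,
+        # since the support labels only distinguish background from the novel classes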
+ + ce = - (one_hot_gt * torch.log(probas + 1e-10)) + ce = (ce * compute_wce(one_hot_gt, self.num_novel_classes)).sum(2) + ce = self._valid_mean(ce, valid_pixels, (1, 2, 3)) # [batch_size_val,] + + if reduction == 'sum': + ce = ce.sum(0) + elif reduction == 'mean': + ce = ce.mean(0) + return ce + + def optimize(self, features_s: torch.tensor, features_q: torch.tensor, gt_s: torch.tensor, + valid_pixels_q: torch.tensor) -> torch.tensor: + """ + DIaM inference optimization + + inputs: + features_s : shape [num_novel_classes, shot, c, h, w] + features_q : shape [batch_size_val, 1, c, h, w] + gt_s : shape [num_novel_classes, shot, h, w] + valid_pixels_q : shape [batch_size_val, 1, h, w] + """ + l1, l2, l3, l4 = self.weights + + # Get the PixelPrototypeCELoss instance from the classifier + criterion = self.criterion # Access the criterion from the classifier object + + # Only novel prototypes are trainable parameters + params = [self.novel_prototypes] # Optimize only novel prototypes + for p in params: + p.requires_grad = True + optimizer = torch.optim.SGD(params, lr=self.lr) + + # Flatten the dimensions of different novel classes and shots + features_s = features_s.flatten(0, 1).unsqueeze(0) + gt_s = gt_s.flatten(0, 1).unsqueeze(0) + + ds_gt_s = F.interpolate(gt_s.float(), size=features_s.size()[-2:], mode='nearest').long() + one_hot_gt_s = to_one_hot(ds_gt_s, self.num_classes) # [1, num_novel_classes * shot, num_classes, h, w] + valid_pixels_s = (ds_gt_s != 255).float() + valid_pixels_q = F.interpolate(valid_pixels_q.float(), size=features_q.size()[-2:], mode='nearest').long() + + for iteration in range(self.adapt_iter): + logits_s, logits_q = self.get_logits(features_s), self.get_logits(features_q) + proba_s, proba_q = self.get_probas(logits_s), self.get_probas(logits_q) + + snapshot_proba_q = self.get_base_snapshot_probas(features_q) + distillation = self.distillation_loss(proba_q, snapshot_proba_q, valid_pixels_q, reduction='none') + d_kl, entropy, marginal = self.get_entropies(valid_pixels_q, proba_q, reduction='none') + + # Calculate loss with PixelPrototypeCELoss (using the instance from the classifier) + loss = classifier.criterion( + {'seg': proba_s, 'logits': logits_q, 'target': one_hot_gt_s}, ds_gt_s + ) + loss += l2 * d_kl + l3 * entropy + l4 * distillation # Keep your other loss terms + + optimizer.zero_grad() + loss.sum(0).backward() + optimizer.step() + + # Update pi + if (iteration + 1) in self.pi_update_at and (self.pi_estimation_strategy == 'self') and (l2 != 0): + self.compute_pi(features_q, valid_pixels_q) diff --git a/src/contrast.py b/src/contrast.py new file mode 100644 index 0000000..d8634d4 --- /dev/null +++ b/src/contrast.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def momentum_update(old_value, new_value, momentum, debug=False): + update = momentum * old_value + (1 - momentum) * new_value + if debug: + print("old prot: {:.3f} x |{:.3f}|, new val: {:.3f} x |{:.3f}|, result= |{:.3f}|".format( + momentum, torch.norm(old_value, p=2), (1 - momentum), torch.norm(new_value, p=2), + torch.norm(update, p=2))) + return update + + +def l2_normalize(x): + return F.normalize(x, p=2, dim=-1) + + +class ProjectionHead(nn.Module): + def __init__(self, dim_in, proj_dim=256): + super(ProjectionHead, self).__init__() + + self.proj = self.mlp2 = nn.Sequential( + nn.Conv2d(dim_in, dim_in, 1), + nn.ReLU(inplace=True), + nn.Conv2d(dim_in, proj_dim, 1)) + + def forward(self, x): + return l2_normalize(self.proj(x)) \ No newline at end of 
file diff --git a/src/dataset/classes.py b/src/dataset/classes.py index 4736828..c791bc3 100644 --- a/src/dataset/classes.py +++ b/src/dataset/classes.py @@ -1,6 +1,8 @@ import argparse from collections import defaultdict from typing import Dict, Any +import os +import json classId2className = {'coco': { 1: 'person', @@ -82,7 +84,8 @@ 77: 'scissors', 78: 'teddy bear', 79: 'hair drier', - 80: 'toothbrush'}, + 80: 'toothbrush' + }, 'pascal': { 1: 'airplane', @@ -114,7 +117,7 @@ className2classId[dataset][classId2className[dataset][id]] = id -def get_split_classes(args: argparse.Namespace) -> Dict[str, Any]: +def get_split_classes(cfg: dict, args: argparse.Namespace) -> Dict[str, Any]: """ Returns the split of classes for Pascal-5i, Pascal-10i and Coco-20i inputs: @@ -124,15 +127,52 @@ def get_split_classes(args: argparse.Namespace) -> Dict[str, Any]: split_classes : Dict. split_classes['coco'][0]['train'] = training classes in fold 0 of Coco-20i """ + + def save_splits_to_files(split_classes): + for dataset_name, splits in split_classes.items(): + dataset_dir = f"{dataset_name}_splits" + os.makedirs(dataset_dir, exist_ok=True) + for split_key, split in splits.items(): + file_path = os.path.join(dataset_dir, f"split_{split_key}.json") + with open(file_path, 'w') as f: + json.dump(split, f) + print(f"Saved split {split_key} for {dataset_name} to {file_path}") + + def load_splits_from_files(): + split_classes = {'coco': defaultdict(dict), 'pascal': defaultdict(dict)} + for dataset_name in split_classes.keys(): + dataset_dir = f"{dataset_name}_splits" + if not os.path.exists(dataset_dir): + continue + for split_file in os.listdir(dataset_dir): + split_key = int(split_file.split('_')[1].split('.')[0]) + file_path = os.path.join(dataset_dir, split_file) + with open(file_path, 'r') as f: + try: + split_classes[dataset_name][split_key] = json.load(f) + print(f"Loaded split {split_key} for {dataset_name} from {file_path}") + except json.JSONDecodeError as e: + print(f"Error loading {file_path}: {e}") + continue + return split_classes + + # Check if splits are already saved + if os.path.exists("coco_splits") or os.path.exists("pascal_splits"): + split_classes = load_splits_from_files() + else: + split_classes = {'coco': defaultdict(dict), 'pascal': defaultdict(dict)} + + split_classes = {'coco': defaultdict(dict), 'pascal': defaultdict(dict)} # =============== COCO =================== name = 'coco' class_list = list(range(1, 81)) + # class_list = list(range(1, 8)) split_classes[name][-1]['val'] = class_list - if args.use_split_coco: - vals_lists = [list(range(1, 78, 4)), list(range(2, 79, 4)), - list(range(3, 80, 4)), list(range(4, 81, 4))] + if cfg['DATA']['use_split_coco']: + vals_lists = [list(range(1, 78, 4)), list(range(2, 79, 4)),list(range(3, 80, 4)), list(range(4, 81, 4))] + # vals_lists = [list(range(1, 7, 4)), list(range(2, 8, 4)),list(range(3, 8, 4)), list(range(4, 8, 4))] for i, val_list in enumerate(vals_lists): split_classes[name][i]['val'] = val_list split_classes[name][i]['train'] = sorted(list(set(class_list) - set(val_list))) @@ -144,6 +184,8 @@ def get_split_classes(args: argparse.Namespace) -> Dict[str, Any]: for i, val_list in enumerate(vals_lists): split_classes[name][i]['val'] = val_list split_classes[name][i]['train'] = sorted(list(set(class_list) - set(val_list))) + + print(f'split classes is here.. 
{split_classes}') # =============== Pascal =================== name = 'pascal' @@ -157,5 +199,11 @@ def get_split_classes(args: argparse.Namespace) -> Dict[str, Any]: for i, val_list in vals_lists: split_classes[name][i]['val'] = val_list split_classes[name][i]['train'] = sorted(list(set(class_list) - set(val_list))) + + + save_splits_to_files(split_classes) + # print("CLASSES FROM....", split_classes) + # Now, let's create folders for each split + - return split_classes + return split_classes \ No newline at end of file diff --git a/src/dataset/dataset.py b/src/dataset/dataset.py index 34a6ad5..45bc2fa 100755 --- a/src/dataset/dataset.py +++ b/src/dataset/dataset.py @@ -13,21 +13,22 @@ from .utils import make_dataset -def get_val_loader(args: argparse.Namespace) -> torch.utils.data.DataLoader: +def get_val_loader(cfg: dict, args: argparse.Namespace) -> torch.utils.data.DataLoader: """ Build the validation loader. """ assert args.split in [0, 1, 2, 3, 10, 11, -1] val_transform = transform.Compose([ - transform.Resize(args.image_size), + transform.Resize(cfg['DATA']['image_size']), transform.ToTensor(), - transform.Normalize(mean=args.mean, std=args.std)]) - split_classes = get_split_classes(args) + transform.Normalize(mean=cfg['DATA']['mean'], std=cfg['DATA']['std'])]) + + split_classes = get_split_classes(cfg, args) # ===================== Get base and novel classes ===================== - print(f'Data: {args.data_name}, S{args.split}') - base_class_list = split_classes[args.data_name][args.split]['train'] - novel_class_list = split_classes[args.data_name][args.split]['val'] + print(f"Data: {cfg['DATA']['data_name']}, S{cfg['DATA']['split']}") + base_class_list = split_classes[cfg['DATA']['data_name']][cfg['DATA']['split']]['train'] + novel_class_list = split_classes[cfg['DATA']['data_name']][cfg['DATA']['split']]['val'] print('Novel classes:', novel_class_list) args.num_classes_tr = len(base_class_list) + 1 # +1 for bg args.num_classes_val = len(novel_class_list) @@ -37,15 +38,17 @@ def get_val_loader(args: argparse.Namespace) -> torch.utils.data.DataLoader: val_data = MultiClassValData(transform=val_transform, base_class_list=base_class_list, novel_class_list=novel_class_list, - data_list_path_train=args.train_list, - data_list_path_test=args.val_list, - args=args) + data_list_path_train=cfg['DATA']['train_list'], + data_list_path_test=cfg['DATA']['val_list'], + args=args, + cfg=cfg) + val_loader = torch.utils.data.DataLoader(val_data, - batch_size=args.batch_size_val, + batch_size=cfg['EVALUATION']['batch_size_val'], drop_last=False, - shuffle=args.shuffle_test_data, - num_workers=args.workers, - pin_memory=args.pin_memory, + shuffle=cfg['EVALUATION']['shuffle_test_data'], + num_workers=cfg['DATA']['workers'], + pin_memory=cfg['DATA']['pin_memory'], sampler=val_sampler) return val_loader @@ -190,22 +193,22 @@ def __getitem__(self, index): class MultiClassValData(Dataset): def __init__(self, transform: transform.Compose, base_class_list: List[int], novel_class_list: List[int], - data_list_path_train: str, data_list_path_test: str, args: argparse.Namespace): - self.support_only_one_novel = args.support_only_one_novel - self.use_training_images_for_supports = args.use_training_images_for_supports + data_list_path_train: str, data_list_path_test: str, args: argparse.Namespace,cfg: dict): + self.support_only_one_novel = cfg['EVALUATION']['support_only_one_novel'] + self.use_training_images_for_supports = cfg['EVALUATION']['use_training_images_for_supports'] assert not 
self.use_training_images_for_supports or data_list_path_train support_data_list_path = data_list_path_train if self.use_training_images_for_supports else data_list_path_test - self.shot = args.shot - self.data_root = args.data_root + self.shot = cfg['EVALUATION']['shot'] + self.data_root = cfg['DATA']['data_root'] self.base_class_list = base_class_list # Does not contain bg self.novel_class_list = novel_class_list - self.query_data_list, _ = make_dataset(args.data_root, data_list_path_test, + self.query_data_list, _ = make_dataset(cfg['DATA']['data_root'], data_list_path_test, self.base_class_list + self.novel_class_list, keep_small_area_classes=True) self.complete_query_data_list = self.query_data_list.copy() print('Total number of kept images (query):', len(self.query_data_list)) - support_data_list, self.support_sub_class_file_list = make_dataset(args.data_root, support_data_list_path, + support_data_list, self.support_sub_class_file_list = make_dataset(cfg['DATA']['data_root'], support_data_list_path, self.novel_class_list, keep_small_area_classes=False) print('Total number of kept images (support):', len(support_data_list)) @@ -236,14 +239,17 @@ def __getitem__(self, index): # It only gives the query return qry_img, label, valid_pixels, image_path def generate_support(self, query_image_path_list, remove_them_from_query_data_list=False): + print(f'generating data...') image_list, label_list = list(), list() support_image_path_list, support_label_path_list = list(), list() for c in self.novel_class_list: + file_class_chosen = self.support_sub_class_file_list[c] num_file = len(file_class_chosen) indices_list = list(range(num_file)) random.shuffle(indices_list) current_path_list = list() + print(f"NOVEL DOWN...") for idx in indices_list: if len(current_path_list) >= self.shot: break @@ -268,26 +274,36 @@ def generate_support(self, query_image_path_list, remove_them_from_query_data_li indices_to_repeat = random.choices(range(found_images_count), k=self.shot-found_images_count) image_list.extend([image_list[i] for i in indices_to_repeat]) label_list.extend([label_list[i] for i in indices_to_repeat]) - + + transformed_image_list, transformed_label_list = list(), list() + print(f"NOVEL SHOT...{self.shot}") if self.shot == 1: for i, l in zip(image_list, label_list): transformed_i, transformed_l = self.transform(i, l) transformed_image_list.append(transformed_i.unsqueeze(0)) transformed_label_list.append(transformed_l.unsqueeze(0)) else: + print(f'OTHER {self.shot}..') with Pool(self.shot) as pool: + print(f'LOOPING STARTING HERE...') + print(f'putting data together..{len(image_list)}') + print(f'putting data together..{len(label_list)}') for transformed_i, transformed_l in pool.starmap(self.transform, zip(image_list, label_list)): + print(f'LOOPING.....WORKING.....') transformed_image_list.append(transformed_i.unsqueeze(0)) transformed_label_list.append(transformed_l.unsqueeze(0)) + print("LOOP.....") pool.close() pool.join() - + print("FINAL") + print(f'putting data together..') spprt_imgs = torch.cat(transformed_image_list, 0) spprt_labels = torch.cat(transformed_label_list, 0) - + print(f'support length...') if remove_them_from_query_data_list and not self.use_training_images_for_supports: self.query_data_list = self.complete_query_data_list.copy() for i, l in zip(support_image_path_list, support_label_path_list): self.query_data_list.remove((i, l)) + print(f'exiting generating of data') return spprt_imgs, spprt_labels diff --git a/src/dataset/datasetV1.py b/src/dataset/datasetV1.py new 
file mode 100644 index 0000000..20d2290 --- /dev/null +++ b/src/dataset/datasetV1.py @@ -0,0 +1,335 @@ +import argparse +import random +from multiprocessing import Pool +from typing import List +import sys +import cv2 +import numpy as np +import torch +from torch.utils.data import Dataset + +import src.dataset.transform as transform +from .classes import get_split_classes +from .utils import make_dataset +from torchvision.transforms.functional import resize +# from torchvision import transforms + +def get_val_loader(cfg: dict, args: argparse.Namespace) -> torch.utils.data.DataLoader: + """ + Build the validation loader. + """ + assert args.split in [0, 1, 2, 3, 10, 11, -1] + val_transform = transform.Compose([ + transform.Resize(cfg['DATA']['image_size']), + transform.ToTensor(), + transform.Normalize(mean=cfg['DATA']['mean'], std=cfg['DATA']['std'])]) + + split_classes = get_split_classes(cfg, args) + # ===================== Get base and novel classes ===================== + # print(f"Data: {cfg['DATA']['data_name']}, S{cfg['DATA']['split']}") + # print(f'Data: {cfg['DATA']['data_name']}, S{cfg['DATA']['split']}') + + base_class_list = split_classes[cfg['DATA']['data_name']][cfg['DATA']['split']]['train'] + novel_class_list = split_classes[cfg['DATA']['data_name']][cfg['DATA']['split']]['val'] + print('Novel classes:', novel_class_list) + print('Base classes:', base_class_list) + args.num_classes_tr = len(base_class_list) + 1 # +1 for bg + args.num_classes_val = len(novel_class_list) + print(f"Novel classes {args.num_classes_val} {args.num_classes_tr}") # Add this line + # sys.exit(1) + # ===================== Build loader ===================== + val_sampler = None + val_data = MultiClassValData(transform=val_transform, + base_class_list=base_class_list, + novel_class_list=novel_class_list, + data_list_path_train=cfg['DATA']['train_list'], + data_list_path_test=cfg['DATA']['val_list'], + args=args, + cfg=cfg) + + val_loader = torch.utils.data.DataLoader(val_data, + batch_size=cfg['EVALUATION']['batch_size_val'], + drop_last=False, + shuffle=cfg['EVALUATION']['shuffle_test_data'], + num_workers=cfg['DATA']['workers'], + pin_memory=cfg['DATA']['pin_memory'], + sampler=val_sampler) + + # prepare data iterator + + print(f'number of novel class in dataset preparation... {len(val_data.novel_class_list)}') + print(f'number of ALL class in dataset preparation... 
{len(val_data.all_classes)}') + args.num_novel_classes = len(val_data.novel_class_list) + total_samples = len(val_loader) + # print(f"Total data samples: {total_samples}") + return val_loader +# data/coco/val2014/ +def get_image_and_label(image_path, label_path): + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = np.float32(image) + label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]: + raise (RuntimeError("Image & label shape mismatch: " + image_path + " " + label_path + "\n")) + return image, label + +def adjust_label(base_class_list, novel_class_list, label, chosen_novel_class, base_label=-1, other_novels_label=255): + # -1 for base_label or other_novels_label means including the true labels + assert base_label in [-1, 0, 255] and other_novels_label in [-1, 0, 255] + new_label = np.zeros_like(label) # background + for lab in base_class_list: + indexes = np.where(label == lab) + if base_label == -1: + new_label[indexes[0], indexes[1]] = base_class_list.index(lab) + 1 # Add 1 because class 0 is bg + else: + new_label[indexes[0], indexes[1]] = base_label + + for lab in novel_class_list: + indexes = np.where(label == lab) + if other_novels_label == -1: + new_label[indexes[0], indexes[1]] = 1 + len(base_class_list) + novel_class_list.index(lab) + elif lab == chosen_novel_class: + new_label[indexes[0], indexes[1]] = 1 + len(base_class_list) + else: + new_label[indexes[0], indexes[1]] = other_novels_label + + ignore_pix = np.where(label == 255) + new_label[ignore_pix] = 255 + + return new_label + + +class ClassicValData(Dataset): + def __init__(self, transform: transform.Compose, base_class_list: List[int], novel_class_list: List[int], + data_list_path_train: str, data_list_path_test: str, args: argparse.Namespace): + assert args.support_only_one_novel + self.shot = args.shot + self.data_root = args.data_root + self.base_class_list = base_class_list + self.novel_class_list = novel_class_list + self.transform = transform + + self.use_training_images_for_supports = args.use_training_images_for_supports + assert not self.use_training_images_for_supports or data_list_path_train + support_data_list_path = data_list_path_train if self.use_training_images_for_supports else data_list_path_test + + self.query_data_list, _ = make_dataset(args.data_root, data_list_path_test, + self.base_class_list + self.novel_class_list, + keep_small_area_classes=True) + print('Total number of kept images (query):', len(self.query_data_list)) + self.support_data_list, self.support_sub_class_file_list = make_dataset(args.data_root, support_data_list_path, + self.novel_class_list, + keep_small_area_classes=False) + print('Total number of kept images (support):', len(self.support_data_list)) + + @property + def num_novel_classes(self): + return len(self.novel_class_list) + + @property + def all_classes(self): + return [0] + self.base_class_list + self.novel_class_list + + def _adjust_label(self, label, chosen_novel_class, base_label=-1, other_novels_label=255): + return adjust_label(self.base_class_list, self.novel_class_list, + label, chosen_novel_class, base_label, other_novels_label) + + def __len__(self): + return len(self.query_data_list) + + def __getitem__(self, index): + # ========= Read query image and Choose class ======================= + image_path, label_path = self.query_data_list[index] + qry_img, label = get_image_and_label(image_path, label_path) + + # # =========== RESIZE 
QUERY IMAGE HERE ============= + # desired_output_size = (26, 26) # Example: Set your desired output size + # required_input_size = (desired_output_size[0] * 8, + # desired_output_size[1] * 8) + # qry_img = cv2.resize(qry_img, required_input_size, interpolation=cv2.INTER_LINEAR) + if self.transform is not None: + qry_img, label = self.transform(qry_img, label) + # print(f"Query image and label: {qry_img.shape}.....{label.shape}") # Add this line + # desired_output_size = (26, 26) # Example: Set your desired output size + # required_input_size = (desired_output_size[0] * 8, + # desired_output_size[1] * 8) + + # # Resize using PyTorch + # qry_img = resize(qry_img, required_input_size, interpolation=transforms.InterpolationMode.BILINEAR) + # print(f"Query image after transform shape: {qry_img.shape}") # Add this line + # == From classes in the query image, choose one randomly === + label_class = set(np.unique(label)) + label_class -= {0, 255} + novel_classes_in_image = list(label_class.intersection(set(self.novel_class_list))) + if len(novel_classes_in_image) > 0: + class_chosen = np.random.choice(novel_classes_in_image) + else: + class_chosen = np.random.choice(self.novel_class_list) + + q_valid_pixels = (label != 255).float() + target = self._adjust_label(label, class_chosen, base_label=-1, other_novels_label=0) + + support_image_list = [] + support_label_list = [] + + file_class_chosen = self.support_sub_class_file_list[class_chosen] + num_file = len(file_class_chosen) + + # ========= Build support ============================================== + # == First, randomly choose indexes of support images == + support_image_path_list = [] + support_label_path_list = [] + support_idx_list = [] + + for _ in range(self.shot): + support_idx = random.randint(1, num_file) - 1 + support_image_path = image_path + support_label_path = label_path + while (support_image_path == image_path and support_label_path == label_path) or support_idx in support_idx_list: + support_idx = random.randint(1, num_file) - 1 + support_image_path, support_label_path = file_class_chosen[support_idx] + support_idx_list.append(support_idx) + support_image_path_list.append(support_image_path) + support_label_path_list.append(support_label_path) + + # == Second, read support images and masks ============ + for k in range(self.shot): + support_image_path, support_label_path = support_image_path_list[k], support_label_path_list[k] + support_image, support_label = get_image_and_label(support_image_path, support_label_path) + support_label = self._adjust_label(support_label, class_chosen, base_label=0, other_novels_label=0) + support_image_list.append(support_image) + support_label_list.append(support_label) + + # == Forward images through transforms ================= + if self.transform is not None: + for k in range(len(support_image_list)): + support_image_list[k], support_label_list[k] = self.transform(support_image_list[k], support_label_list[k]) + support_image_list[k] = support_image_list[k].unsqueeze(0) + support_label_list[k] = support_label_list[k].unsqueeze(0) + + # == Reshape properly ================================== + spprt_imgs = torch.cat(support_image_list, 0) + spprt_labels = torch.cat(support_label_list, 0) + + return qry_img, target, q_valid_pixels, spprt_imgs, spprt_labels, class_chosen + +class MultiClassValData(Dataset): + def __init__(self, transform: transform.Compose, base_class_list: List[int], novel_class_list: List[int], + data_list_path_train: str, data_list_path_test: str, args: argparse.Namespace, 
cfg: dict): + self.support_only_one_novel = cfg['EVALUATION']['support_only_one_novel'] + self.use_training_images_for_supports = cfg['EVALUATION']['use_training_images_for_supports'] + assert not self.use_training_images_for_supports or data_list_path_train + support_data_list_path = data_list_path_train if self.use_training_images_for_supports else data_list_path_test + + self.shot = cfg['EVALUATION']['shot'] + self.data_root = cfg['DATA']['data_root'] + self.base_class_list = base_class_list # Does not contain bg + self.novel_class_list = novel_class_list + self.query_data_list, _ = make_dataset(cfg['DATA']['data_root'], data_list_path_test, + self.base_class_list + self.novel_class_list, + keep_small_area_classes=True) + self.complete_query_data_list = self.query_data_list.copy() + print('Total number of kept images (query)-MULTICLASS:', len(self.query_data_list)) + support_data_list, self.support_sub_class_file_list = make_dataset(cfg['DATA']['data_root'], support_data_list_path, + self.novel_class_list, + keep_small_area_classes=False) + print('Total number of kept images (support):', len(support_data_list)) + self.transform = transform + + @property + def num_novel_classes(self): + return len(self.novel_class_list) + + @property + def all_classes(self): + return [0] + self.base_class_list + self.novel_class_list + + def _adjust_label(self, label, chosen_novel_class, base_label=-1, other_novels_label=255): + return adjust_label(self.base_class_list, self.novel_class_list, + label, chosen_novel_class, base_label, other_novels_label) + + def __len__(self): + return len(self.query_data_list) + + def __getitem__(self, index): # It only gives the query + image_path, label_path = self.query_data_list[index] + qry_img, label = get_image_and_label(image_path, label_path) + # print(f"Query image original shape: {qry_img.shape}") # Add this line + + label = self._adjust_label(label, -1, base_label=-1, other_novels_label=-1) + if self.transform is not None: + qry_img, label = self.transform(qry_img, label) + + # print(f"Query image after transform shape: {qry_img.shape}") # Add this line torch.Size([3, 417, 417]) + valid_pixels = (label != 255).float() + # query image get item shape ..torch.Size([3, 417, 417]) and the ..... lable ... torch.Size([417, 417]) + # print(f'query image get item shape ..{qry_img.shape} and the ..... lable ... 
{label.shape}') + return qry_img, label, valid_pixels, image_path + + def generate_support(self, query_image_path_list, remove_them_from_query_data_list=False): + # print("GENERATION_SUPPORT_IMAGES......") + image_list, label_list = list(), list() + support_image_path_list, support_label_path_list = list(), list() + # print(f"Number of novel classes before: {len(self.novel_class_list)}") + for c in self.novel_class_list: + file_class_chosen = self.support_sub_class_file_list[c] + num_file = len(file_class_chosen) + indices_list = list(range(num_file)) + random.shuffle(indices_list) + current_path_list = list() + for idx in indices_list: + if len(current_path_list) >= self.shot: + break + image_path, label_path = file_class_chosen[idx] + if image_path in (query_image_path_list + current_path_list): + continue + image, label = get_image_and_label(image_path, label_path) + # print(f'image after resizing is...{image.shape}') + # print(f'Support image original shape: {image.shape}') # Add this line + + if self.support_only_one_novel: # Ignore images that have multiple novel classes + present_novel_classes = set(np.unique(label)) - {0, 255} - set(self.base_class_list) + if len(present_novel_classes) > 1: + continue + + label = self._adjust_label(label, -1, base_label=0, other_novels_label=-1) # If support_only_one_novel is True, images with more than one novel classes won't reach this line. So, -1 won't make the image contain two different novel classes. + image_list.append(image) + label_list.append(label) + current_path_list.append(image_path) + support_image_path_list.append(image_path) + support_label_path_list.append(label_path) + found_images_count = len(current_path_list) + assert found_images_count > 0, f'No support candidate for class {c} out of {num_file} images' + if found_images_count < self.shot: + indices_to_repeat = random.choices(range(found_images_count), k=self.shot-found_images_count) + image_list.extend([image_list[i] for i in indices_to_repeat]) + label_list.extend([label_list[i] for i in indices_to_repeat]) + + transformed_image_list, transformed_label_list = list(), list() + if self.shot == 1: + for i, l in zip(image_list, label_list): + transformed_i, transformed_l = self.transform(i, l) + # print(f"Support image after transform: {transformed_i.shape}") + transformed_image_list.append(transformed_i.unsqueeze(0)) + transformed_label_list.append(transformed_l.unsqueeze(0)) + else: + with Pool(self.shot) as pool: + for transformed_i, transformed_l in pool.starmap(self.transform, zip(image_list, label_list)): + transformed_image_list.append(transformed_i.unsqueeze(0)) + transformed_label_list.append(transformed_l.unsqueeze(0)) + pool.close() + pool.join() + + spprt_imgs = torch.cat(transformed_image_list, 0) + spprt_labels = torch.cat(transformed_label_list, 0) + + if remove_them_from_query_data_list and not self.use_training_images_for_supports: + self.query_data_list = self.complete_query_data_list.copy() + for i, l in zip(support_image_path_list, support_label_path_list): + self.query_data_list.remove((i, l)) + + print("RETURNING GENERATE IMAGES.....") + print(f'{spprt_labels.shape}') + # Query image after transform shape generation support section: torch.Size([20, 3, 417, 417]) and label .....torch.Size([20, 417, 417]) + print(f"Query image after transform shape generation support section: {spprt_imgs.shape} and label .....{spprt_labels.shape}") + return spprt_imgs, spprt_labels diff --git a/src/loss_helper.py b/src/loss_helper.py new file mode 100644 index 
0000000..8684c9e --- /dev/null +++ b/src/loss_helper.py @@ -0,0 +1,132 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Donny You, RainbowSecret +## Microsoft Research +## yuyua@microsoft.com +## Copyright (c) 2019 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pdb +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from torch.autograd import Variable +from .rmi_loss import RMILoss + +from .lovasz_loss import lovasz_softmax_flat, flatten_probas + + + +# Cross-entropy Loss +class FSCELoss(nn.Module): + def __init__(self, configer=None): + super(FSCELoss, self).__init__() + self.configer = configer + weight = None + # if 'loss' in self.configer and 'params' in self.configer['loss'] and 'ce_weight' in self.configer['loss']['params']: + # weight = self.configer["loss"]["params"]["ce_weight"] + # weight = torch.FloatTensor(weight).cuda() + # if self.configer.exists('loss', 'params') and 'ce_weight' in self.configer.get('loss', 'params'): + # weight = self.configer.get('loss', 'params')['ce_weight'] + # weight = torch.FloatTensor(weight).cuda() + + reduction = 'mean' + if 'loss' in self.configer and 'params' in self.configer['loss'] and 'ce_reduction' in self.configer['loss']['params']: + reduction = self.configer['loss']['params']['ce_reduction'] + # if self.configer.exists('loss', 'params') and 'ce_reduction' in self.configer.get('loss', 'params'): + # reduction = self.configer.get('loss', 'params')['ce_reduction'] + + ignore_index = -1 + # if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'): + # ignore_index = self.configer.get('loss', 'params')['ce_ignore_index'] + if 'loss' in self.configer and 'params' in self.configer['loss'] and 'ce_ignore_index' in self.configer['loss']['params']: + ignore_index = self.configer['loss']['params']['ce_ignore_index'] + + # print(f'ignore index......{ignore_index}') + # print(f'weight index......{weight.shape}') + # print(f'reduction index......{reduction}') + # self.ce_loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction) + self.ce_loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction) + + def forward(self, inputs, *targets, weights=None, **kwargs): + loss = 0.0 + if isinstance(inputs, tuple) or isinstance(inputs, list): + if weights is None: + weights = [1.0] * len(inputs) + + for i in range(len(inputs)): + if len(targets) > 1: + target = self._scale_target(targets[i], (inputs[i].size(2), inputs[i].size(3))) + loss += weights[i] * self.ce_loss(inputs[i], target) + else: + target = self._scale_target(targets[0], (inputs[i].size(2), inputs[i].size(3))) + loss += weights[i] * self.ce_loss(inputs[i], target) + + else: + # print(f'target shape dimension ....{targets[0].shape}') + # targets = targets[0].squeeze(1) + targets = targets[0].clone().unsqueeze(1).float() + # print(f'type of input object.....{type(inputs)} and type of target object....{type(targets)}') + # for key, value in inputs.items(): + # print(f"{key}: {value.shape}") + + # target = self._scale_target(targets[0], (inputs.size(2), inputs.size(3))) + # print(f"the 
segmentation......{inputs['seg'].shape} and {inputs['seg'].size(2)}" ) + # print(f'INPUT IN FORWARD OF LOSS {inputs.shape} {inputs.size(2)}') + # target = self._scale_target(targets[0], (inputs['seg'].size(2), inputs['seg'].size(3))) + target = self._scale_target(targets[0], (inputs.size(2), inputs.size(3))) + # INPUT TO LOSS SHAPE......torch.Size([20, 81, 53, 53]) + # print(f"INPUT TO LOSS SHAPE......{inputs['seg'].shape}") + # TARGET TO LOSS SHAPE......torch.Size([1, 53, 53]) + # print(f"TARGET TO LOSS SHAPE......{target.shape}") + # TARGET SHAPE IN FORWARD OF LOSS torch.Size([1, 53, 53]) INPUT SHAPE IN FORWARD OF LOSS torch.Size([20, 81, 53, 53]) + # print(f'TARGET SHAPE IN FORWARD OF LOSS {target.shape} INPUT SHAPE IN FORWARD OF LOSS {inputs.shape}') + # target = target.repeat(20, 1, 1) + # .torch.Size([20, 1, 417, 417]) + # inputs = inputs.reshape(1, 20, 53, 53) + # inputs = inputs.view(1, 20, 53, 53) + # inputs = inputs.repeat(1, 20, 1, 1) + inputs = inputs.reshape(1, -1, 53, 53) + # print(f'TARGET SHAPE IN FORWARD OF LOSS {target.shape} INPUT SHAPE IN FORWARD OF LOSS {inputs.shape}') + # target = target.repeat(20, 1, 1) # Repeat the target tensor along the first dimension (batch size) + # print(f"TARGET TO LOSS SHAPED HERE......{target.shape}") # Output: TARGET TO LOSS SHAPE......torch.Size([20, 53, 53]) + # TARGET SHAPE IN FORWARD OF LOSS torch.Size([1, 512, 1024]) INPUT SHAPE IN FORWARD OF LOSS torch.Size([1, 19, 512, 1024]) + loss = self.ce_loss(inputs, target) + # print(f'loss value....{loss}') + + return loss + + @staticmethod + def _scale_target(targets_, scaled_size): + targets = targets_.clone().unsqueeze(1).float() + # targets = targets_.clone().float() + # print(f'shape of target......{targets.shape}') + # print(f'shape of scaled_size...{scaled_size}') + # Remove singleton dimensions + # targets = targets.reshape(20, 1, 417, 417) + # print(f'shape of reshape target...{targets.shape}') + targets = F.interpolate(targets, size=scaled_size, mode='nearest') + return targets.squeeze(1).long() + +class FSAuxRMILoss(nn.Module): + def __init__(self, configer=None): + super(FSAuxRMILoss, self).__init__() + self.configer = configer + self.ce_loss = FSCELoss(self.configer) + self.rmi_loss = RMILoss(self.configer) + + def forward(self, inputs, targets, **kwargs): + aux_out, seg_out = inputs + aux_loss = self.ce_loss(aux_out, targets) + seg_loss = self.rmi_loss(seg_out, targets) + loss = self.configer.get('network', 'loss_weights')['seg_loss'] * seg_loss + loss = loss + self.configer.get('network', 'loss_weights')['aux_loss'] * aux_loss + return loss diff --git a/src/lovasz_loss.py b/src/lovasz_loss.py new file mode 100644 index 0000000..0a1eca8 --- /dev/null +++ b/src/lovasz_loss.py @@ -0,0 +1,431 @@ +from itertools import filterfalse as ifilterfalse + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +from torch.nn import BCELoss +from src.aaf import losses as lossx + + + +# weights +# ATR training +# [0.85978634, 1.19630769, 1.02639146, 1.30664970, 0.97220603, 1.04885815, +# 1.01745278, 1.01481690, 1.27155077, 1.12947663, 1.13016390, 1.06514227, +# 1.08384483, 1.08506841, 1.09560942, 1.09565198, 1.07504567, 1.20411509] + +# CCF +# [0.82073458, 1.23651165, 1.0366326, 0.97076566, 1.2802332, 0.98860602, +# 1.29035071, 1.03882453, 0.96725283, 1.05142434, 1.0075884, 0.98630539, +# 1.06208869, 1.0160915, 1.1613597, 1.17624919, 1.1701143, 1.24720215] + +# PPSS +# [0.89680465, 1.14352656, 1.20982646, 0.99269248, +# 1.17911144, 
1.00641032, 1.47017195, 1.16447113] + +# Pascal +# [0.82877791, 0.95688253, 0.94921949, 1.00538108, 1.0201687, 1.01665831, 1.05470914] + +# Lip +# [0.7602572, 0.94236198, 0.85644457, 1.04346266, 1.10627293, 0.80980162, +# 0.95168713, 0.8403769, 1.05798412, 0.85746254, 1.01274366, 1.05854692, +# 1.03430773, 0.84867818, 0.88027721, 0.87580925, 0.98747462, 0.9876475, +# 1.00016535, 1.00108882] + +class ABRLovaszLoss(nn.Module): + """Lovasz loss for Alpha process""" + + def __init__(self, ignore_index=None, only_present=True): + super(ABRLovaszLoss, self).__init__() + self.ignore_index = ignore_index + self.only_present = only_present + # self.weight = torch.FloatTensor([0.80777327, 1.00125961, 0.90997236, 1.10867908, 1.17541499, + # 0.86041422, 1.01116758, 0.89290045, 1.12410812, 0.91105395, + # 1.07604013, 1.12470610, 1.09895196, 0.90172057, 0.93529453, + # 0.93054733, 1.04919178, 1.04937547, 1.06267568, 1.06365688]) + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) + + def forward(self, preds, targets): + h, w = targets[0].size(1), targets[0].size(2) + # seg loss + pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) + pred = F.softmax(input=pred, dim=1) + loss = lovasz_softmax_flat(*flatten_probas(pred, targets[0], self.ignore_index), only_present=self.only_present) + + # dsn loss + pred_dsn = F.interpolate(input=preds[-1], size=(h, w), mode='bilinear', align_corners=True) + loss_dsn = self.criterion(pred_dsn, targets[0]) + return loss + 0.4 * loss_dsn + + +class SegmentationLoss(nn.Module): + """Lovasz loss for Alpha process""" + + def __init__(self, ignore_index=None, only_present=True): + super(SegmentationLoss, self).__init__() + self.ignore_index = ignore_index + self.only_present = only_present + + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) + + def forward(self, preds, targets): + h, w = targets.size(1), targets.size(2) + # seg loss + pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) + loss_ce = self.criterion(pred, targets) + + # dsn loss + pred_dsn = F.interpolate(input=preds[-1], size=(h, w), mode='bilinear', align_corners=True) + loss_dsn = self.criterion(pred_dsn, targets) + total_loss = loss_ce + 0.4 * loss_dsn + + return total_loss + + +class ABRLovaszCELoss(nn.Module): + """Lovasz loss for Alpha process""" + + def __init__(self, ignore_index=None, only_present=True): + super(ABRLovaszCELoss, self).__init__() + self.ignore_index = ignore_index + self.only_present = only_present + + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) + + def forward(self, preds, targets): + h, w = targets.size(1), targets.size(2) + # seg loss + pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) + loss_ce = self.criterion(pred, targets) + + pred = F.softmax(input=pred, dim=1) + loss = lovasz_softmax_flat(*flatten_probas(pred, targets, self.ignore_index), + only_present=self.only_present) + + # dsn loss + pred_dsn = F.interpolate(input=preds[-1], size=(h, w), mode='bilinear', align_corners=True) + loss_dsn = self.criterion(pred_dsn, targets) + total_loss = loss_ce + loss + 0.4 * loss_dsn + + return total_loss + + +class LovaszSoftmaxLoss(nn.Module): + """Lovasz loss for Deep Supervision""" + + def __init__(self, ignore_index=None, only_present=False, per_image=False): + super(LovaszSoftmaxLoss, self).__init__() + self.ignore_index = ignore_index + self.only_present = only_present + self.per_image = per_image + self.weight = 
torch.FloatTensor([0.80777327, 1.00125961, 0.90997236, 1.10867908, 1.17541499, + 0.86041422, 1.01116758, 0.89290045, 1.12410812, 0.91105395, + 1.07604013, 1.12470610, 1.09895196, 0.90172057, 0.93529453, + 0.93054733, 1.04919178, 1.04937547, 1.06267568, 1.06365688]) + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, weight=self.weight) + + def forward(self, preds, targets): + h, w = targets.size(1), targets.size(2) + # seg loss + pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) + pred = F.softmax(input=pred, dim=1) + if self.per_image: + loss = mean(lovasz_softmax_flat(*flatten_probas(pre.unsqueeze(0), tar.unsqueeze(0), self.ignore_index), + only_present=self.only_present) for pre, tar in zip(pred, targets)) + else: + loss = lovasz_softmax_flat(*flatten_probas(pred, targets, self.ignore_index), + only_present=self.only_present) + # dsn loss + pred_dsn = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True) + loss_dsn = self.criterion(pred_dsn, targets) + return loss + 0.4 * loss_dsn + + +def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): + """ + Multi-class Lovasz-Softmax loss + probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. + labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + per_image: compute the loss per image instead of per batch + ignore: void class labels + """ + if per_image: + loss = mean( + lovasz_softmax_flat_ori(*flatten_probas_ori(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) + for prob, lab in zip(probas, labels)) + else: + loss = lovasz_softmax_flat_ori(*flatten_probas_ori(probas, labels, ignore), classes=classes) + return loss + + +def lovasz_softmax_flat_ori(probas, labels, classes='present'): + """ + Multi-class Lovasz-Softmax loss + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0. 
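+    # How the per-class loss below is built: pixels are sorted by decreasing error,
+    # lovasz_grad() turns the sorted ground-truth vector into the discrete gradient of
+    # the Jaccard index (1 - |intersection| / |union| over the first k pixels, first-differenced),
+    # and the loss is the dot product of the sorted errors with that gradient.
+    # Toy illustration (values invented for clarity, not taken from this repo):
+    #   fg_sorted = [1, 0, 1] -> intersection = [1, 1, 0], union = [2, 3, 3]
+    #   jaccard   = [0.5, 2/3, 1] -> gradient = [0.5, 1/6, 1/3]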
+ C = probas.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes is 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (Variable(fg) - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) + return mean(losses) + + +def flatten_probas_ori(probas, labels, ignore=None): + """ + Flattens predictions in the batch + """ + if probas.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probas.size() + probas = probas.view(B, 1, H, W) + B, C, H, W = probas.size() + probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = (labels != ignore) + vprobas = probas[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobas, vlabels + + +def lovasz_softmax_flat(preds, targets, only_present=False): + """ + Multi-class Lovasz-Softmax loss + :param preds: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + :param targets: [P] Tensor, ground truth labels (between 0 and C - 1) + :param only_present: average only on classes present in ground truth + """ + if preds.numel() == 0: + # only void pixels, the gradients should be 0 + return preds * 0. + + C = preds.size(1) + losses = [] + for c in range(C): + fg = (targets == c).float() # foreground for class c + if only_present and fg.sum() == 0: + continue + errors = (Variable(fg) - preds[:, c]).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) + return mean(losses) + + +def lovasz_grad(gt_sorted): + """ + Computes gradient of the Lovasz extension w.r.t sorted errors + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. 
- intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_probas(preds, targets, ignore=None): + """ + Flattens predictions in the batch + """ + B, C, H, W = preds.size() + preds = preds.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C + targets = targets.view(-1) + if ignore is None: + return preds, targets + valid = (targets != ignore) + vprobas = preds[valid.nonzero().squeeze()] + vlabels = targets[valid] + return vprobas, vlabels + + +# --------------------------- BINARY LOSSES --------------------------- + + +def lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) + for log, lab in zip(logits, labels)) + else: + loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) + return loss + + +def lovasz_hinge_flat(logits, labels): + """ + Binary Lovasz hinge loss + logits: [P] Variable, logits at each prediction (between -\infty and +\infty) + labels: [P] Tensor, binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * Variable(signs)) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), Variable(grad)) + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = (labels != ignore) + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +def mean(l, ignore_nan=True, empty=0): + """ + nan mean compatible with generators. 
+ """ + l = iter(l) + if ignore_nan: + l = ifilterfalse(isnan, l) + try: + n = 1 + acc = next(l) + except StopIteration: + if empty == 'raise': + raise ValueError('Empty mean') + return empty + for n, v in enumerate(l, 2): + acc += v + if n == 1: + return acc + return acc / n + + +def isnan(x): + return x != x + + +class AAF_Loss(nn.Module): + """ + Loss function for multiple outputs + """ + + def __init__(self, ignore_index=255, num_classes=7): + super(AAF_Loss, self).__init__() + self.ignore_index = ignore_index + self.num_classes = num_classes + self.kld_margin = 3.0 + self.kld_lambda_1 = 1.0 + self.kld_lambda_2 = 1.0 + # self.dec = 1e-3 + self.dec = 1e-2 + self.softmax = nn.Softmax(dim=1) + self.w_edge = torch.zeros(1, 1, 1, self.num_classes, 1, 3) + self.w_edge_softmax = nn.Softmax(dim=-1) + self.w_not_edge = torch.zeros(1, 1, 1, self.num_classes, 1, 3) + self.w_not_edge_softmax = nn.Softmax(dim=-1) + + def forward(self, preds, targets): + h, w = targets.size(1), targets.size(2) + # seg loss + pred = F.interpolate(input=preds, size=(h, w), mode='bilinear', align_corners=True) + pred = F.softmax(input=pred, dim=1) + + # aaf loss + labels = targets.unsqueeze(1) + one_label = labels.clone() + one_label[labels == self.ignore_index] = 0 + # one_hot_lab = F.one_hot(one_label, num_classes=self.num_classes) + + one_hot_lab = torch.zeros(one_label.size(0), self.num_classes, one_label.size(2), one_label.size(3)).cuda() + one_hot_lab = one_hot_lab.scatter_(1, one_label.data, 1) + + targets_p_node_list = list(torch.split(one_hot_lab, 1, dim=1)) + for i in range(self.num_classes): + # Log.info('{} {}'.format(targets_p_node_list[i].shape, labels.shape)) + targets_p_node_list[i] = targets_p_node_list[i].squeeze(-1) + targets_p_node_list[i][labels == self.ignore_index] = self.ignore_index + one_hot_lab = torch.cat(targets_p_node_list, dim=1).permute(0, 2, 3, 1) + + prob = pred + w_edge = self.w_edge_softmax(self.w_edge).cuda() + w_not_edge = self.w_not_edge_softmax(self.w_not_edge).cuda() + # Log.info('{} {} {} {}'.format(one_hot_lab.shape, labels.shape, w_edge.shape, w_not_edge.shape)) + + # w_edge_shape=list(w_edge.shape) + # Apply AAF on 3x3 patch. + eloss_1, neloss_1 = lossx.adaptive_affinity_loss(labels, + one_hot_lab, + prob, + 1, + self.num_classes, + self.kld_margin, + w_edge[..., 0], + w_not_edge[..., 0]) + # Apply AAF on 5x5 patch. + # eloss_2, neloss_2 = lossx.adaptive_affinity_loss(labels, + # one_hot_lab, + # prob, + # 2, + # self.num_classes, + # self.kld_margin, + # w_edge[..., 1], + # w_not_edge[..., 1]) + # # Apply AAF on 7x7 patch. 
+ # eloss_3, neloss_3 = lossx.adaptive_affinity_loss(labels, + # one_hot_lab, + # prob, + # 3, + # self.num_classes, + # self.kld_margin, + # w_edge[..., 2], + # w_not_edge[..., 2]) + dec = self.dec + aaf_loss = torch.mean(eloss_1) * self.kld_lambda_1 * dec + # aaf_loss += torch.mean(eloss_2) * self.kld_lambda_1*dec + # aaf_loss += torch.mean(eloss_3) * self.kld_lambda_1*dec + aaf_loss += torch.mean(neloss_1) * self.kld_lambda_2 * dec + # aaf_loss += torch.mean(neloss_2) * self.kld_lambda_2*dec + # aaf_loss += torch.mean(neloss_3) * self.kld_lambda_2*dec + + return aaf_loss \ No newline at end of file diff --git a/src/model/pspnet.py b/src/model/pspnet.py index 4d2227d..05be907 100755 --- a/src/model/pspnet.py +++ b/src/model/pspnet.py @@ -1,12 +1,11 @@ import torch import torch.nn.functional as F from torch import nn - from .resnet import resnet50, resnet101 -def get_model(args) -> nn.Module: - return PSPNet(args, zoom_factor=8, use_ppm=True) +def get_model(cfg: dict, args) -> nn.Module: + return PSPNet(cfg, args, zoom_factor=8, use_ppm=True) class PPM(nn.Module): @@ -30,20 +29,20 @@ def forward(self, x): class PSPNet(nn.Module): - def __init__(self, args, zoom_factor, use_ppm): + def __init__(self, cfg, args, zoom_factor, use_ppm): super(PSPNet, self).__init__() - assert 2048 % len(args.bins) == 0 - assert args.get('num_classes_tr') is not None, 'Get the data loaders first' + assert 2048 % len(cfg['MODEL']['bins']) == 0 + assert args.num_classes_tr is not None, 'Get the data loaders first' assert zoom_factor in [1, 2, 4, 8] self.zoom_factor = zoom_factor self.use_ppm = use_ppm - self.m_scale = args.m_scale - self.bottleneck_dim = args.bottleneck_dim + self.m_scale = cfg['MODEL']['m_scale'] + self.bottleneck_dim = cfg['MODEL']['bottleneck_dim'] - if args.layers == 50: - resnet = resnet50(pretrained=args.pretrained) + if cfg['MODEL']['layers'] == 50: + resnet = resnet50(pretrained=cfg['MODEL']['pretrained']) else: - resnet = resnet101(pretrained=args.pretrained) + resnet = resnet101(pretrained=cfg['MODEL']['pretrained']) self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 @@ -65,13 +64,13 @@ def __init__(self, args, zoom_factor, use_ppm): fea_dim = 2048 if use_ppm: - self.ppm = PPM(fea_dim, int(fea_dim/len(args.bins)), args.bins) + self.ppm = PPM(fea_dim, int(fea_dim/len(cfg['MODEL']['bins'])), cfg['MODEL']['bins']) fea_dim *= 2 self.bottleneck = nn.Sequential( nn.Conv2d(fea_dim, self.bottleneck_dim, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(self.bottleneck_dim), nn.ReLU(inplace=True), - nn.Dropout2d(p=args.dropout), + nn.Dropout2d(p=cfg['MODEL']['dropout']), ) self.classifier = nn.Conv2d(self.bottleneck_dim, args.num_classes_tr, kernel_size=1) diff --git a/src/rmi_loss.py b/src/rmi_loss.py new file mode 100644 index 0000000..e1a2f3d --- /dev/null +++ b/src/rmi_loss.py @@ -0,0 +1,402 @@ +# coding=utf-8 + +""" +The implementation of the paper: +Region Mutual Information Loss for Semantic Segmentation. 
+""" + +# python 2.X, 3.X compatibility +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import + +import pdb +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['RMILoss'] + +TORCH_VERSION = torch.__version__[:3] + +_euler_num = 2.718281828 # euler number +_pi = 3.14159265 # pi +_ln_2_pi = 1.837877 # ln(2 * pi) +_CLIP_MIN = 1e-6 # min clip value after softmax or sigmoid operations +_CLIP_MAX = 1.0 # max clip value after softmax or sigmoid operations +_POS_ALPHA = 1e-3 # add this factor to ensure the AA^T is positive definite +_IS_SUM = 1 # sum the loss per channel + + +def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True): + """get map pairs + Args: + labels_4D : labels, shape [N, C, H, W] + probs_4D : probabilities, shape [N, C, H, W] + radius : the square radius + Return: + tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)] + """ + # pad to ensure the following slice operation is valid + # pad_beg = int(radius // 2) + # pad_end = radius - pad_beg + + # the original height and width + label_shape = labels_4D.size() + h, w = label_shape[2], label_shape[3] + new_h, new_w = h - (radius - 1), w - (radius - 1) + # https://pytorch.org/docs/stable/nn.html?highlight=f%20pad#torch.nn.functional.pad + # padding = (pad_beg, pad_end, pad_beg, pad_end) + # labels_4D, probs_4D = F.pad(labels_4D, padding), F.pad(probs_4D, padding) + + # get the neighbors + la_ns = [] + pr_ns = [] + # for x in range(0, radius, 1): + for y in range(0, radius, 1): + for x in range(0, radius, 1): + la_now = labels_4D[:, :, y:y + new_h, x:x + new_w] + pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w] + la_ns.append(la_now) + pr_ns.append(pr_now) + + if is_combine: + # for calculating RMI + pair_ns = la_ns + pr_ns + p_vectors = torch.stack(pair_ns, dim=2) + return p_vectors + else: + # for other purpose + la_vectors = torch.stack(la_ns, dim=2) + pr_vectors = torch.stack(pr_ns, dim=2) + return la_vectors, pr_vectors + + +def map_get_pairs_region(labels_4D, probs_4D, radius=3, is_combine=0, num_classeses=21): + """get map pairs + Args: + labels_4D : labels, shape [N, C, H, W]. + probs_4D : probabilities, shape [N, C, H, W]. + radius : The side length of the square region. + Return: + A tensor with shape [N, C, radiu * radius, H // radius, W // raidius] + """ + kernel = torch.zeros([num_classeses, 1, radius, radius]).type_as(probs_4D) + padding = radius // 2 + # get the neighbours + la_ns = [] + pr_ns = [] + for y in range(0, radius, 1): + for x in range(0, radius, 1): + kernel_now = kernel.clone() + kernel_now[:, :, y, x] = 1.0 + la_now = F.conv2d(labels_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) + pr_now = F.conv2d(probs_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) + la_ns.append(la_now) + pr_ns.append(pr_now) + + if is_combine: + # for calculating RMI + pair_ns = la_ns + pr_ns + p_vectors = torch.stack(pair_ns, dim=2) + return p_vectors + else: + # for other purpose + la_vectors = torch.stack(la_ns, dim=2) + pr_vectors = torch.stack(pr_ns, dim=2) + return la_vectors, pr_vectors + return + + +def log_det_by_cholesky(matrix): + """ + Args: + matrix: matrix must be a positive define matrix. + shape [N, C, D, D]. + Ref: + https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/linalg/linalg_impl.py + """ + # This uses the property that the log det(A) = 2 * sum(log(real(diag(C)))) + # where C is the cholesky decomposition of A. 
+ chol = torch.cholesky(matrix) + # return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-6), dim=-1) + return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-8), dim=-1) + + +def batch_cholesky_inverse(matrix): + """ + Args: matrix, 4-D tensor, [N, C, M, M]. + matrix must be a symmetric positive define matrix. + """ + chol_low = torch.cholesky(matrix, upper=False) + chol_low_inv = batch_low_tri_inv(chol_low) + return torch.matmul(chol_low_inv.transpose(-2, -1), chol_low_inv) + + +def batch_low_tri_inv(L): + """ + Batched inverse of lower triangular matrices + Args: + L : a lower triangular matrix + Ref: + https://www.pugetsystems.com/labs/hpc/PyTorch-for-Scientific-Computing + """ + n = L.shape[-1] + invL = torch.zeros_like(L) + for j in range(0, n): + invL[..., j, j] = 1.0 / L[..., j, j] + for i in range(j + 1, n): + S = 0.0 + for k in range(0, i + 1): + S = S - L[..., i, k] * invL[..., k, j].clone() + invL[..., i, j] = S / L[..., i, i] + return invL + + +def log_det_by_cholesky_test(): + """ + test for function log_det_by_cholesky() + """ + a = torch.randn(1, 4, 4) + a = torch.matmul(a, a.transpose(2, 1)) + print(a) + res_1 = torch.logdet(torch.squeeze(a)) + res_2 = log_det_by_cholesky(a) + print(res_1, res_2) + + +def batch_inv_test(): + """ + test for function batch_cholesky_inverse() + """ + a = torch.randn(1, 1, 4, 4) + a = torch.matmul(a, a.transpose(-2, -1)) + print(a) + res_1 = torch.inverse(a) + res_2 = batch_cholesky_inverse(a) + print(res_1, '\n', res_2) + + +def mean_var_test(): + x = torch.randn(3, 4) + y = torch.randn(3, 4) + + x_mean = x.mean(dim=1, keepdim=True) + x_sum = x.sum(dim=1, keepdim=True) / 2.0 + y_mean = y.mean(dim=1, keepdim=True) + y_sum = y.sum(dim=1, keepdim=True) / 2.0 + + x_var_1 = torch.matmul(x - x_mean, (x - x_mean).t()) + x_var_2 = torch.matmul(x, x.t()) - torch.matmul(x_sum, x_sum.t()) + xy_cov = torch.matmul(x - x_mean, (y - y_mean).t()) + xy_cov_1 = torch.matmul(x, y.t()) - x_sum.matmul(y_sum.t()) + + print(x_var_1) + print(x_var_2) + + print(xy_cov, '\n', xy_cov_1) + + +class RMILoss(nn.Module): + """ + region mutual information + I(A, B) = H(A) + H(B) - H(A, B) + This version need a lot of memory if do not dwonsample. 
+ """ + + def __init__(self, + configer=None): + super(RMILoss, self).__init__() + self.configer = configer + self.use_sigmoid = self.configer.get('loss', 'params')['use_sigmoid'] + self.num_classes = self.configer.get('loss', 'params')['num_classes'] + # radius choices + self.rmi_radius = self.configer.get('loss', 'params')['rmi_radius'] + assert self.rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + self.rmi_pool_way = self.configer.get('loss', 'params')['rmi_pool_way'] + assert self.rmi_pool_way in [0, 1, 2, 3] + + # set the pool_size = rmi_pool_stride + self.rmi_pool_size = self.configer.get('loss', 'params')['rmi_pool_size'] + self.rmi_pool_stride = self.configer.get('loss', 'params')['rmi_pool_stride'] + assert self.rmi_pool_size == self.rmi_pool_stride + + self.weight_lambda = self.configer.get('loss', 'params')['loss_weight_lambda'] + self.loss_weight = self.configer.get('loss', 'params')['loss_weight'] + self.lambda_way = self.configer.get('loss', 'params')['lambda_way'] + + # dimension of the distribution + self.half_d = self.rmi_radius * self.rmi_radius + self.d = 2 * self.half_d + self.kernel_padding = self.rmi_pool_size // 2 + # ignore class + self.ignore_index = 255 + + def forward(self, + cls_score, + label, + weight=None, + **kwargs): + label[label < 0] = 255 + loss = self.loss_weight * self.forward_sigmoid(cls_score, label) + label[label == 255] = -1 + # loss = self.forward_softmax_sigmoid(cls_score, label) + return loss + + def forward_softmax_sigmoid(self, logits_4D, labels_4D): + """ + Using both softmax and sigmoid operations. + Args: + logits_4D : [N, C, H, W], dtype=float32 + labels_4D : [N, H, W], dtype=long + """ + # PART I -- get the normal cross entropy loss + print( + "max label: {} min label: {}".format(labels_4D[labels_4D != 255].max(), labels_4D[labels_4D != 255].min())) + normal_loss = F.cross_entropy(input=logits_4D, + target=labels_4D.long(), + ignore_index=self.ignore_index, + reduction='mean') + + # PART II -- get the lower bound of the region mutual information + # get the valid label and logits + # valid label, [N, C, H, W] + label_mask_3D = labels_4D < self.num_classes + valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), + num_classes=self.num_classes).float() + label_mask_3D = label_mask_3D.float() + valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3) + valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) + # valid probs + probs_4D = F.sigmoid(logits_4D) * label_mask_3D.unsqueeze(dim=1) + probs_4D = probs_4D.clamp(min=_CLIP_MIN, max=_CLIP_MAX) + + # get region mutual information + rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) + + # add together + final_loss = (self.weight_lambda * normal_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way + else normal_loss + rmi_loss * self.weight_lambda) + + return final_loss + + def forward_sigmoid(self, logits_4D, labels_4D): + """ + Using the sigmiod operation both. 
+ Args: + logits_4D : [N, C, H, W], dtype=float32 + labels_4D : [N, H, W], dtype=long + """ + # label mask -- [N, H, W, 1] + label_mask_3D = labels_4D < self.num_classes + + # valid label + valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), + num_classes=self.num_classes).float() + label_mask_3D = label_mask_3D.float() + label_mask_flat = label_mask_3D.view([-1, ]) + valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3) + valid_onehot_labels_4D.requires_grad_(False) + + # PART I -- calculate the sigmoid binary cross entropy loss + valid_onehot_label_flat = valid_onehot_labels_4D.view([-1, self.num_classes]).requires_grad_(False) + logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes]) + + # binary loss, multiplied by the not_ignore_mask + valid_pixels = torch.sum(label_mask_flat) + binary_loss = F.binary_cross_entropy_with_logits(logits_flat, + target=valid_onehot_label_flat, + weight=label_mask_flat.unsqueeze(dim=1), + reduction='sum') + bce_loss = torch.div(binary_loss, valid_pixels + 1.0) + + # PART II -- get rmi loss + # onehot_labels_4D -- [N, C, H, W] + probs_4D = logits_4D.sigmoid() * label_mask_3D.unsqueeze(dim=1) + _CLIP_MIN + valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) + + # get region mutual information + rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) + + # add together + final_loss = (self.weight_lambda * bce_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way + else bce_loss + rmi_loss * self.weight_lambda) + + return final_loss + + def rmi_lower_bound(self, labels_4D, probs_4D): + """ + calculate the lower bound of the region mutual information. + Args: + labels_4D : [N, C, H, W], dtype=float32 + probs_4D : [N, C, H, W], dtype=float32 + """ + assert labels_4D.size() == probs_4D.size() + + p, s = self.rmi_pool_size, self.rmi_pool_stride + if self.rmi_pool_stride > 1: + if self.rmi_pool_way == 0: + labels_4D = F.max_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + probs_4D = F.max_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + elif self.rmi_pool_way == 1: + labels_4D = F.avg_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + probs_4D = F.avg_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) + elif self.rmi_pool_way == 2: + # interpolation + shape = labels_4D.size() + new_h, new_w = shape[2] // s, shape[3] // s + labels_4D = F.interpolate(labels_4D, size=(new_h, new_w), mode='nearest') + probs_4D = F.interpolate(probs_4D, size=(new_h, new_w), mode='bilinear', align_corners=True) + else: + raise NotImplementedError("Pool way of RMI is not defined!") + # we do not need the gradient of label. + label_shape = labels_4D.size() + n, c = label_shape[0], label_shape[1] + + # combine the high dimension points from label and probability map. 
new shape [N, C, radius * radius, H, W] + la_vectors, pr_vectors = map_get_pairs(labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0) + + la_vectors = la_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor).requires_grad_(False) + pr_vectors = pr_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor) + + # small diagonal matrix, shape = [1, 1, radius * radius, radius * radius] + diag_matrix = torch.eye(self.half_d).unsqueeze(dim=0).unsqueeze(dim=0) + + # the mean and covariance of these high dimension points + # Var(X) = E(X^2) - E(X) E(X), N * Var(X) = X^2 - X E(X) + la_vectors = la_vectors - la_vectors.mean(dim=3, keepdim=True) + la_cov = torch.matmul(la_vectors, la_vectors.transpose(2, 3)) + + pr_vectors = pr_vectors - pr_vectors.mean(dim=3, keepdim=True) + pr_cov = torch.matmul(pr_vectors, pr_vectors.transpose(2, 3)) + # https://github.com/pytorch/pytorch/issues/7500 + # waiting for batched torch.cholesky_inverse() + pr_cov_inv = torch.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) + # if the dimension of the point is less than 9, you can use the below function + # to acceleration computational speed. + # pr_cov_inv = utils.batch_cholesky_inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) + + la_pr_cov = torch.matmul(la_vectors, pr_vectors.transpose(2, 3)) + # the approxiamation of the variance, det(c A) = c^n det(A), A is in n x n shape; + # then log det(c A) = n log(c) + log det(A). + # appro_var = appro_var / n_points, we do not divide the appro_var by number of points here, + # and the purpose is to avoid underflow issue. + # If A = A^T, A^-1 = (A^-1)^T. + appro_var = la_cov - torch.matmul(la_pr_cov.matmul(pr_cov_inv), la_pr_cov.transpose(-2, -1)) + # appro_var = la_cov - torch.chain_matmul(la_pr_cov, pr_cov_inv, la_pr_cov.transpose(-2, -1)) + # appro_var = torch.div(appro_var, n_points.type_as(appro_var)) + diag_matrix.type_as(appro_var) * 1e-6 + + # The lower bound. If A is nonsingular, ln( det(A) ) = Tr( ln(A) ). + rmi_now = 0.5 * log_det_by_cholesky(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) + # rmi_now = 0.5 * torch.logdet(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) + + # mean over N samples. sum over classes. 
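+        # Reading of the quantity above (following the RMI paper referenced in the module
+        # docstring): appro_var is the Schur complement Sigma_YY - Sigma_YP Sigma_PP^{-1} Sigma_PY,
+        # i.e. an (unnormalized) estimate of Var(Y | P). Under a Gaussian bound, 0.5 * log det of
+        # this matrix upper-bounds H(Y | P) up to constants, so minimizing it maximizes a lower
+        # bound on the region mutual information I(Y; P).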
+ rmi_per_class = rmi_now.view([-1, self.num_classes]).mean(dim=0).float() + # is_half = False + # if is_half: + # rmi_per_class = torch.div(rmi_per_class, float(self.half_d / 2.0)) + # else: + rmi_per_class = torch.div(rmi_per_class, float(self.half_d)) + + rmi_loss = torch.sum(rmi_per_class) if _IS_SUM else torch.mean(rmi_per_class) + return rmi_loss diff --git a/src/test.py b/src/test.py index 22d9f64..2812d2e 100755 --- a/src/test.py +++ b/src/test.py @@ -1,5 +1,6 @@ import argparse import os +import yaml import time from typing import Tuple @@ -12,8 +13,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP from tqdm import tqdm -from .classifier import Classifier -from .dataset.dataset import get_val_loader +from .classifierV1 import Classifier +from .dataset.datasetV1 import get_val_loader from .model.pspnet import get_model from .util import get_model_dir, fast_intersection_and_union, setup_seed, resume_random_state, find_free_port, setup, \ cleanup, get_cfg @@ -22,26 +23,32 @@ def parse_args(): parser = argparse.ArgumentParser(description='Testing') return get_cfg(parser) - - +# data/coco/val2014/COCO_val2014_000000054091.jpg +# data/coco/val2014/COCO_val2014_000000039115.jpg def main_worker(rank: int, world_size: int, args: argparse.Namespace) -> None: print(f"==> Running evaluation script") + # Access the 'config' argument directly + print(f'NAME SPACE....{args}') + cfg = args setup(args, rank, world_size) - setup_seed(args.manual_seed) + + setup_seed(cfg['EVALUATION']['manual_seed']) + # setup_seed(args.manual_seed) # ========== Data ========== - val_loader = get_val_loader(args) + val_loader = get_val_loader(cfg, args) + # val_loader = get_val_loader(args) # ========== Model ========== - model = get_model(args).to(rank) + model = get_model(cfg,args).to(0) model = nn.SyncBatchNorm.convert_sync_batchnorm(model) - model = DDP(model, device_ids=[rank]) + model = DDP(model, device_ids=[0]) - root = get_model_dir(args) + root = get_model_dir(cfg, args) print("=> Creating the model") - if args.ckpt_used is not None: - filepath = os.path.join(root, f'{args.ckpt_used}.pth') + if cfg['EVALUATION']['ckpt_used'] is not None: + filepath = os.path.join(root, f"{cfg['EVALUATION']['ckpt_used']}.pth") assert os.path.isfile(filepath), filepath checkpoint = torch.load(filepath) model.load_state_dict(checkpoint['state_dict']) @@ -50,27 +57,27 @@ def main_worker(rank: int, world_size: int, args: argparse.Namespace) -> None: print("=> Not loading anything") # ========== Test ========== - validate(args=args, val_loader=val_loader, model=model) + validate(args=args, val_loader=val_loader, model=model, cfg=cfg) cleanup() -def validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, model: DDP) -> Tuple[torch.tensor, torch.tensor]: - print('\n==> Start testing ({} runs)'.format(args.n_runs), flush=True) - random_state = setup_seed(args.manual_seed, return_old_state=True) +def validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, model: DDP,cfg: dict) -> Tuple[torch.tensor, torch.tensor]: + print('\n==> Start testing ({} runs)'.format(cfg['EVALUATION']['n_runs']), flush=True) + random_state = setup_seed(cfg['EVALUATION']['manual_seed'], return_old_state=True) device = torch.device('cuda:{}'.format(dist.get_rank())) model.eval() c = model.module.bottleneck_dim h = model.module.feature_res[0] w = model.module.feature_res[1] - - nb_episodes = len(val_loader) if args.test_num == -1 else int(args.test_num / args.batch_size_val) - runtimes = 
torch.zeros(args.n_runs) - base_mIoU, novel_mIoU = [torch.zeros(args.n_runs, device=device) for _ in range(2)] + print(f'channel .....{c}') + nb_episodes = len(val_loader) if cfg['EVALUATION']['test_num'] == -1 else int(cfg['EVALUATION']['test_num'] / cfg['EVALUATION']['batch_size_val']) + runtimes = torch.zeros(cfg['EVALUATION']['n_runs']) + base_mIoU, novel_mIoU = [torch.zeros(cfg['EVALUATION']['n_runs'], device=device) for _ in range(2)] # ========== Perform the runs ========== - for run in range(args.n_runs): - print('Run', run + 1, 'of', args.n_runs) + for run in range(cfg['EVALUATION']['n_runs']): + print('Run', run + 1, 'of', cfg['EVALUATION']['n_runs']) # The order of classes in the following tensors is the same as the order of classifier (novels at last) cls_intersection = torch.zeros(args.num_classes_tr + args.num_classes_val) cls_union = torch.zeros(args.num_classes_tr + args.num_classes_val) @@ -78,14 +85,15 @@ def validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, runtime = 0 features_s, gt_s = None, None - if not args.generate_new_support_set_for_each_task: + if not cfg['EVALUATION']['generate_new_support_set_for_each_task']: with torch.no_grad(): spprt_imgs, s_label = val_loader.dataset.generate_support([], remove_them_from_query_data_list=True) - nb_episodes = len(val_loader) if args.test_num == -1 else nb_episodes # Updates nb_episodes since some images were removed by generate_support + nb_episodes = len(val_loader) if cfg['EVALUATION']['test_num'] == -1 else nb_episodes # Updates nb_episodes since some images were removed by generate_support spprt_imgs = spprt_imgs.to(device, non_blocking=True) s_label = s_label.to(device, non_blocking=True) features_s = model.module.extract_features(spprt_imgs).detach().view((args.num_classes_val, args.shot, c, h, w)) gt_s = s_label.view((args.num_classes_val, args.shot, args.image_size, args.image_size)) + print(f'running here...') for _ in tqdm(range(nb_episodes), leave=True): t0 = time.time() @@ -102,19 +110,21 @@ def validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, features_q = model.module.extract_features(qry_img).detach().unsqueeze(1) valid_pixels_q = q_valid_pix.unsqueeze(1).to(device) gt_q = q_label.unsqueeze(1) - + print(f'running here. 
images..') query_image_path_list = list(img_path) - if args.generate_new_support_set_for_each_task: + if cfg['EVALUATION']['generate_new_support_set_for_each_task']: spprt_imgs, s_label = val_loader.dataset.generate_support(query_image_path_list) spprt_imgs = spprt_imgs.to(device, non_blocking=True) s_label = s_label.to(device, non_blocking=True) - features_s = model.module.extract_features(spprt_imgs).detach().view((args.num_classes_val, args.shot, c, h, w)) - gt_s = s_label.view((args.num_classes_val, args.shot, args.image_size, args.image_size)) + features_s = model.module.extract_features(spprt_imgs).detach().view((args.num_classes_val, cfg['EVALUATION']['shot'], c, h, w)) + gt_s = s_label.view((args.num_classes_val, cfg['EVALUATION']['shot'], cfg['DATA']['image_size'], cfg['DATA']['image_size'])) # =========== Initialize the classifier and run the method =============== base_weight = model.module.classifier.weight.detach().clone().T base_bias = model.module.classifier.bias.detach().clone() - classifier = Classifier(args, base_weight, base_bias, n_tasks=features_q.size(0)) + print(f'running classifier...') + # classifier = DIaMClassifier(args, base_weight, base_bias, n_tasks=features_q.size(0), cfg=cfg, backbone=model,features_s=features_s,gt_s=gt_s,num_novel_classes= args.num_novel_classes,feature_dim=feature_dim) + classifier = Classifier(args, base_weight, base_bias, n_tasks=features_q.size(0), cfg=cfg, backbone=model) classifier.init_prototypes(features_s, gt_s) classifier.compute_pi(features_q, valid_pixels_q, gt_q) # gt_q won't be used in optimization if pi estimation strategy is self or uniform classifier.optimize(features_s, features_q, gt_s, valid_pixels_q) @@ -162,10 +172,12 @@ def validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, if __name__ == "__main__": + parser = argparse.ArgumentParser(description='DIaM Training and Testing Script') + parser.add_argument('--config', type=str, help='Path to YAML config file') args = parse_args() os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpus) os.environ['OPENBLAS_NUM_THREADS'] = '1' - + if args.debug: args.test_num = 64 args.n_runs = 2 @@ -183,3 +195,19 @@ def validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, args=(world_size, args), nprocs=world_size, join=True) + # parser = argparse.ArgumentParser(description='DIaM Training and Testing Script') + # parser.add_argument('--config', type=str, help='Path to YAML config file') + # parser.add_argument('--split', type=int, help='Data split to use') + # parser.add_argument('--shot', type=int, help='Number of shots') + # parser.add_argument('--gpus', type=int, default=0, help='GPU ID to use (-1 for CPU)') + # args = parser.parse_args() + # # if args.debug: + # # args.test_num = 64 + # # args.n_runs = 2 + + # world_size = len(str(args.gpus)) + # distributed = world_size > 1 + # assert not distributed, 'Testing should not be done in a distributed way' + # args.distributed = distributed + # args.port = find_free_port() + # main_worker(0, world_size, args) diff --git a/src/testNotebook.ipynb b/src/testNotebook.ipynb new file mode 100644 index 0000000..11dd1c2 --- /dev/null +++ b/src/testNotebook.ipynb @@ -0,0 +1,637 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "44000167-076c-458b-a015-5cc7d8f1989a", + "metadata": {}, + "outputs": [], + "source": [ + "import argparse\n", + "import os\n", + "import time\n", + "from typing import Tuple\n", + "\n", + "import torch\n", + "import 
torch.distributed as dist\n", + "import torch.multiprocessing as mp\n", + "import torch.nn as nn\n", + "import torch.nn.parallel\n", + "import torch.utils.data\n", + "from torch.nn.parallel import DistributedDataParallel as DDP\n", + "from tqdm import tqdm\n", + "\n", + "from classifier import Classifier\n", + "from dataset.dataset import get_val_loader\n", + "from model.pspnet import get_model\n", + "from util import get_model_dir, fast_intersection_and_union, setup_seed, resume_random_state, find_free_port, setup, \\\n", + " cleanup, get_cfg\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a0e8e8d-cbbc-4457-a44d-7415de077372", + "metadata": {}, + "outputs": [], + "source": [ + "def parse_args():\n", + " parser = argparse.ArgumentParser(description='Testing')\n", + " return get_cfg(parser)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c59b9c9c-a5cb-4939-a82e-7adfd98471fb", + "metadata": {}, + "outputs": [], + "source": [ + "def main_worker(rank: int, world_size: int, args: argparse.Namespace) -> None:\n", + " print(f\"==> Running evaluation script\")\n", + " setup(args, rank, world_size)\n", + " setup_seed(args.manual_seed)\n", + "\n", + " # ========== Data ==========\n", + " val_loader = get_val_loader(args)\n", + "\n", + " # ========== Model ==========\n", + " model = get_model(args).to(rank)\n", + " model = nn.SyncBatchNorm.convert_sync_batchnorm(model)\n", + " model = DDP(model, device_ids=[rank])\n", + "\n", + " root = get_model_dir(args)\n", + "\n", + " print(\"=> Creating the model\")\n", + " if args.ckpt_used is not None:\n", + " filepath = os.path.join(root, f'{args.ckpt_used}.pth')\n", + " assert os.path.isfile(filepath), filepath\n", + " checkpoint = torch.load(filepath)\n", + " model.load_state_dict(checkpoint['state_dict'])\n", + " print(\"=> Loaded weight '{}'\".format(filepath))\n", + " else:\n", + " print(\"=> Not loading anything\")\n", + "\n", + " # ========== Test ==========\n", + " validate(args=args, val_loader=val_loader, model=model)\n", + " cleanup()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d4cf292-6d92-4ef7-87fc-76b92007b9fe", + "metadata": {}, + "outputs": [], + "source": [ + "def validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, model: DDP) -> Tuple[torch.tensor, torch.tensor]:\n", + " print('\\n==> Start testing ({} runs)'.format(args.n_runs), flush=True)\n", + " random_state = setup_seed(args.manual_seed, return_old_state=True)\n", + " device = torch.device('cuda:{}'.format(dist.get_rank()))\n", + " model.eval()\n", + "\n", + " c = model.module.bottleneck_dim\n", + " h = model.module.feature_res[0]\n", + " w = model.module.feature_res[1]\n", + "\n", + " nb_episodes = len(val_loader) if args.test_num == -1 else int(args.test_num / args.batch_size_val)\n", + " runtimes = torch.zeros(args.n_runs)\n", + " base_mIoU, novel_mIoU = [torch.zeros(args.n_runs, device=device) for _ in range(2)]\n", + "\n", + " # ========== Perform the runs ==========\n", + " for run in range(args.n_runs):\n", + " print('Run', run + 1, 'of', args.n_runs)\n", + " # The order of classes in the following tensors is the same as the order of classifier (novels at last)\n", + " cls_intersection = torch.zeros(args.num_classes_tr + args.num_classes_val)\n", + " cls_union = torch.zeros(args.num_classes_tr + args.num_classes_val)\n", + " cls_target = torch.zeros(args.num_classes_tr + args.num_classes_val)\n", + "\n", + " runtime = 0\n", + " features_s, gt_s = None, None\n", + " if 
not args.generate_new_support_set_for_each_task:\n", + " with torch.no_grad():\n", + " spprt_imgs, s_label = val_loader.dataset.generate_support([], remove_them_from_query_data_list=True)\n", + " nb_episodes = len(val_loader) if args.test_num == -1 else nb_episodes # Updates nb_episodes since some images were removed by generate_support\n", + " spprt_imgs = spprt_imgs.to(device, non_blocking=True)\n", + " s_label = s_label.to(device, non_blocking=True)\n", + " features_s = model.module.extract_features(spprt_imgs).detach().view((args.num_classes_val, args.shot, c, h, w))\n", + " gt_s = s_label.view((args.num_classes_val, args.shot, args.image_size, args.image_size))\n", + "\n", + " for _ in tqdm(range(nb_episodes), leave=True):\n", + " t0 = time.time()\n", + " with torch.no_grad():\n", + " try:\n", + " loader_output = next(iter_loader)\n", + " except (UnboundLocalError, StopIteration):\n", + " iter_loader = iter(val_loader)\n", + " loader_output = next(iter_loader)\n", + " qry_img, q_label, q_valid_pix, img_path = loader_output\n", + "\n", + " qry_img = qry_img.to(device, non_blocking=True)\n", + " q_label = q_label.to(device, non_blocking=True)\n", + " features_q = model.module.extract_features(qry_img).detach().unsqueeze(1)\n", + " valid_pixels_q = q_valid_pix.unsqueeze(1).to(device)\n", + " gt_q = q_label.unsqueeze(1)\n", + "\n", + " query_image_path_list = list(img_path)\n", + " if args.generate_new_support_set_for_each_task:\n", + " spprt_imgs, s_label = val_loader.dataset.generate_support(query_image_path_list)\n", + " spprt_imgs = spprt_imgs.to(device, non_blocking=True)\n", + " s_label = s_label.to(device, non_blocking=True)\n", + " features_s = model.module.extract_features(spprt_imgs).detach().view((args.num_classes_val, args.shot, c, h, w))\n", + " gt_s = s_label.view((args.num_classes_val, args.shot, args.image_size, args.image_size))\n", + "\n", + " # =========== Initialize the classifier and run the method ===============\n", + " base_weight = model.module.classifier.weight.detach().clone().T\n", + " base_bias = model.module.classifier.bias.detach().clone()\n", + " classifier = Classifier(args, base_weight, base_bias, n_tasks=features_q.size(0))\n", + " classifier.init_prototypes(features_s, gt_s)\n", + " classifier.compute_pi(features_q, valid_pixels_q, gt_q) # gt_q won't be used in optimization if pi estimation strategy is self or uniform\n", + " classifier.optimize(features_s, features_q, gt_s, valid_pixels_q)\n", + "\n", + " runtime += time.time() - t0\n", + "\n", + " # =========== Perform inference and compute metrics ===============\n", + " logits = classifier.get_logits(features_q).detach()\n", + " probas = classifier.get_probas(logits)\n", + "\n", + " intersection, union, target = fast_intersection_and_union(probas, gt_q) # [batch_size_val, 1, num_classes]\n", + " intersection, union, target = intersection.squeeze(1).cpu(), union.squeeze(1).cpu(), target.squeeze(1).cpu()\n", + " cls_intersection += intersection.sum(0)\n", + " cls_union += union.sum(0)\n", + " cls_target += target.sum(0)\n", + "\n", + " base_count, novel_count, sum_base_IoU, sum_novel_IoU = 4 * [0]\n", + " for i, class_ in enumerate(val_loader.dataset.all_classes):\n", + " if cls_union[i] == 0:\n", + " continue\n", + " IoU = cls_intersection[i] / (cls_union[i] + 1e-10)\n", + " print(\"Class {}: \\t{:.4f}\".format(class_, IoU))\n", + " if class_ in val_loader.dataset.base_class_list:\n", + " sum_base_IoU += IoU\n", + " base_count += 1\n", + " elif class_ in val_loader.dataset.novel_class_list:\n", + " 
sum_novel_IoU += IoU\n", + " novel_count += 1\n", + "\n", + " avg_base_IoU, avg_novel_IoU = sum_base_IoU / base_count, sum_novel_IoU / novel_count\n", + " print('Mean base IoU: {:.4f}, Mean novel IoU: {:.4f}'.format(avg_base_IoU, avg_novel_IoU), flush=True)\n", + "\n", + " base_mIoU[run], novel_mIoU[run] = avg_base_IoU, avg_novel_IoU\n", + " runtimes[run] = runtime\n", + "\n", + " agg_mIoU = (base_mIoU.mean() + novel_mIoU.mean()) / 2\n", + " print('==>')\n", + " print('Average of base mIoU: {:.4f}\\tAverage of novel mIoU: {:.4f} \\t(over {} runs)'.format(\n", + " base_mIoU.mean(), novel_mIoU.mean(), args.n_runs))\n", + " print('Mean --- {:.4f}'.format(agg_mIoU), flush=True)\n", + " print('Average runtime / run --- {:.1f}\\n'.format(runtimes.mean()))\n", + "\n", + " resume_random_state(random_state)\n", + " return agg_mIoU\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17f2d2f8-b495-4df3-b20d-e24a944a9583", + "metadata": {}, + "outputs": [], + "source": [ + "if __name__ == \"__main__\":\n", + " print(parse_args())\n", + " args = parse_args()\n", + " os.environ[\"CUDA_VISIBLE_DEVICES\"] = ','.join(str(x) for x in args.gpus)\n", + " os.environ['OPENBLAS_NUM_THREADS'] = '1'\n", + "\n", + " if args.debug:\n", + " args.test_num = 64\n", + " args.n_runs = 2\n", + "\n", + " world_size = len(args.gpus)\n", + " distributed = world_size > 1\n", + " assert not distributed, 'Testing should not be done in a distributed way'\n", + " args.distributed = distributed\n", + " args.port = find_free_port()\n", + " try:\n", + " mp.set_start_method('spawn')\n", + " except RuntimeError:\n", + " pass\n", + " mp.spawn(main_worker,\n", + " args=(world_size, args),\n", + " nprocs=world_size,\n", + " join=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "62bd6e1f-b79f-4b49-877c-8887ec5ea9fc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/zeus/miniconda3/envs/cloudspace/bin/python: Error while finding module specification for 'src.test' (ModuleNotFoundError: No module named 'src')\n", + "/home/zeus/miniconda3/envs/cloudspace/bin/python: Error while finding module specification for 'src.test' (ModuleNotFoundError: No module named 'src')\n", + "/home/zeus/miniconda3/envs/cloudspace/bin/python: Error while finding module specification for 'src.test' (ModuleNotFoundError: No module named 'src')\n", + "/home/zeus/miniconda3/envs/cloudspace/bin/python: Error while finding module specification for 'src.test' (ModuleNotFoundError: No module named 'src')\n" + ] + } + ], + "source": [ + "!bash ../test.sh coco20i 5 upperbound [0] out.log" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f50abdef-5bd1-425e-863e-12e5467f573e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/teamspace/studios/this_studio/DIaM/DIaM/src'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%pwd" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a5dee953-48ea-49a3-9011-a2610478d6e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/teamspace/studios/this_studio/DIaM/DIaM\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.\n", + " self.shell.db['dhist'] = 
compress_dhist(dhist)[-100:]\n" + ] + } + ], + "source": [ + "%cd ../" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "30f120dc-4abb-45ed-8b2a-202352b79d45", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "adapt_iter: 100\n", + "batch_size_val: 50\n", + "bins: [1, 2, 3, 6]\n", + "bottleneck_dim: 512\n", + "ckpt_path: model_ckpt/\n", + "ckpt_used: model\n", + "cls_lr: 0.00125\n", + "data_name: coco\n", + "data_root: data/coco/\n", + "debug: False\n", + "dropout: 0.1\n", + "fine_tune_base_classifier: True\n", + "generate_new_support_set_for_each_task: False\n", + "gpus: [0]\n", + "image_size: 417\n", + "layers: 50\n", + "load_model_id: 1\n", + "m_scale: False\n", + "manual_seed: 2023\n", + "mean: [0.485, 0.456, 0.406]\n", + "n_runs: 5\n", + "pi_estimation_strategy: upperbound\n", + "pi_update_at: [10]\n", + "pin_memory: True\n", + "pretrained: True\n", + "shot: 5\n", + "shuffle_test_data: True\n", + "split: 0\n", + "std: [0.229, 0.224, 0.225]\n", + "support_only_one_novel: True\n", + "test_num: 10000\n", + "train_list: lists/coco/train.txt\n", + "use_split_coco: True\n", + "use_training_images_for_supports: False\n", + "val_list: lists/coco/val.txt\n", + "weights: [100, 1, 1, 100]\n", + "workers: 3\n", + "==> Running evaluation script\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n", + " return _run_code(code, main_globals, None,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 86, in _run_code\n", + " exec(code, run_globals)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 183, in \n", + " mp.spawn(main_worker,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 241, in spawn\n", + " return start_processes(fn, args, nprocs, join, daemon, start_method=\"spawn\")\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 197, in start_processes\n", + " while not context.join():\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 158, in join\n", + " raise ProcessRaisedException(msg, error_index, failed_process.pid)\n", + "torch.multiprocessing.spawn.ProcessRaisedException: \n", + "\n", + "-- Process 0 terminated with the following error:\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 68, in _wrap\n", + " fn(i, *args)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 29, in main_worker\n", + " setup(args, rank, world_size)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/util.py\", line 41, in setup\n", + " dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/c10d_logger.py\", line 86, in wrapper\n", + " func_return = func(*args, **kwargs)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1184, in init_process_group\n", + " default_pg, _ = _new_process_group_helper(\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1339, in 
_new_process_group_helper\n", + " backend_class = ProcessGroupNCCL(\n", + "ValueError: ProcessGroupNCCL is only supported with GPUs, no GPUs found!\n", + "\n", + "adapt_iter: 100\n", + "batch_size_val: 50\n", + "bins: [1, 2, 3, 6]\n", + "bottleneck_dim: 512\n", + "ckpt_path: model_ckpt/\n", + "ckpt_used: model\n", + "cls_lr: 0.00125\n", + "data_name: coco\n", + "data_root: data/coco/\n", + "debug: False\n", + "dropout: 0.1\n", + "fine_tune_base_classifier: True\n", + "generate_new_support_set_for_each_task: False\n", + "gpus: [0]\n", + "image_size: 417\n", + "layers: 50\n", + "load_model_id: 1\n", + "m_scale: False\n", + "manual_seed: 2023\n", + "mean: [0.485, 0.456, 0.406]\n", + "n_runs: 5\n", + "pi_estimation_strategy: upperbound\n", + "pi_update_at: [10]\n", + "pin_memory: True\n", + "pretrained: True\n", + "shot: 5\n", + "shuffle_test_data: True\n", + "split: 1\n", + "std: [0.229, 0.224, 0.225]\n", + "support_only_one_novel: True\n", + "test_num: 10000\n", + "train_list: lists/coco/train.txt\n", + "use_split_coco: True\n", + "use_training_images_for_supports: False\n", + "val_list: lists/coco/val.txt\n", + "weights: [100, 1, 1, 100]\n", + "workers: 3\n", + "==> Running evaluation script\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n", + " return _run_code(code, main_globals, None,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 86, in _run_code\n", + " exec(code, run_globals)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 183, in \n", + " mp.spawn(main_worker,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 241, in spawn\n", + " return start_processes(fn, args, nprocs, join, daemon, start_method=\"spawn\")\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 197, in start_processes\n", + " while not context.join():\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 158, in join\n", + " raise ProcessRaisedException(msg, error_index, failed_process.pid)\n", + "torch.multiprocessing.spawn.ProcessRaisedException: \n", + "\n", + "-- Process 0 terminated with the following error:\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 68, in _wrap\n", + " fn(i, *args)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 29, in main_worker\n", + " setup(args, rank, world_size)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/util.py\", line 41, in setup\n", + " dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/c10d_logger.py\", line 86, in wrapper\n", + " func_return = func(*args, **kwargs)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1184, in init_process_group\n", + " default_pg, _ = _new_process_group_helper(\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1339, in _new_process_group_helper\n", + " backend_class = ProcessGroupNCCL(\n", + "ValueError: ProcessGroupNCCL is only supported with GPUs, 
no GPUs found!\n", + "\n", + "adapt_iter: 100\n", + "batch_size_val: 50\n", + "bins: [1, 2, 3, 6]\n", + "bottleneck_dim: 512\n", + "ckpt_path: model_ckpt/\n", + "ckpt_used: model\n", + "cls_lr: 0.00125\n", + "data_name: coco\n", + "data_root: data/coco/\n", + "debug: False\n", + "dropout: 0.1\n", + "fine_tune_base_classifier: True\n", + "generate_new_support_set_for_each_task: False\n", + "gpus: [0]\n", + "image_size: 417\n", + "layers: 50\n", + "load_model_id: 1\n", + "m_scale: False\n", + "manual_seed: 2023\n", + "mean: [0.485, 0.456, 0.406]\n", + "n_runs: 5\n", + "pi_estimation_strategy: upperbound\n", + "pi_update_at: [10]\n", + "pin_memory: True\n", + "pretrained: True\n", + "shot: 5\n", + "shuffle_test_data: True\n", + "split: 2\n", + "std: [0.229, 0.224, 0.225]\n", + "support_only_one_novel: True\n", + "test_num: 10000\n", + "train_list: lists/coco/train.txt\n", + "use_split_coco: True\n", + "use_training_images_for_supports: False\n", + "val_list: lists/coco/val.txt\n", + "weights: [100, 1, 1, 100]\n", + "workers: 3\n", + "==> Running evaluation script\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n", + " return _run_code(code, main_globals, None,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 86, in _run_code\n", + " exec(code, run_globals)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 183, in \n", + " mp.spawn(main_worker,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 241, in spawn\n", + " return start_processes(fn, args, nprocs, join, daemon, start_method=\"spawn\")\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 197, in start_processes\n", + " while not context.join():\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 158, in join\n", + " raise ProcessRaisedException(msg, error_index, failed_process.pid)\n", + "torch.multiprocessing.spawn.ProcessRaisedException: \n", + "\n", + "-- Process 0 terminated with the following error:\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 68, in _wrap\n", + " fn(i, *args)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 29, in main_worker\n", + " setup(args, rank, world_size)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/util.py\", line 41, in setup\n", + " dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/c10d_logger.py\", line 86, in wrapper\n", + " func_return = func(*args, **kwargs)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1184, in init_process_group\n", + " default_pg, _ = _new_process_group_helper(\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1339, in _new_process_group_helper\n", + " backend_class = ProcessGroupNCCL(\n", + "ValueError: ProcessGroupNCCL is only supported with GPUs, no GPUs found!\n", + "\n", + "adapt_iter: 100\n", + "batch_size_val: 50\n", + "bins: [1, 2, 3, 6]\n", + "bottleneck_dim: 512\n", + 
"ckpt_path: model_ckpt/\n", + "ckpt_used: model\n", + "cls_lr: 0.00125\n", + "data_name: coco\n", + "data_root: data/coco/\n", + "debug: False\n", + "dropout: 0.1\n", + "fine_tune_base_classifier: True\n", + "generate_new_support_set_for_each_task: False\n", + "gpus: [0]\n", + "image_size: 417\n", + "layers: 50\n", + "load_model_id: 1\n", + "m_scale: False\n", + "manual_seed: 2023\n", + "mean: [0.485, 0.456, 0.406]\n", + "n_runs: 5\n", + "pi_estimation_strategy: upperbound\n", + "pi_update_at: [10]\n", + "pin_memory: True\n", + "pretrained: True\n", + "shot: 5\n", + "shuffle_test_data: True\n", + "split: 3\n", + "std: [0.229, 0.224, 0.225]\n", + "support_only_one_novel: True\n", + "test_num: 10000\n", + "train_list: lists/coco/train.txt\n", + "use_split_coco: True\n", + "use_training_images_for_supports: False\n", + "val_list: lists/coco/val.txt\n", + "weights: [100, 1, 1, 100]\n", + "workers: 3\n", + "==> Running evaluation script\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n", + " return _run_code(code, main_globals, None,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py\", line 86, in _run_code\n", + " exec(code, run_globals)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 183, in \n", + " mp.spawn(main_worker,\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 241, in spawn\n", + " return start_processes(fn, args, nprocs, join, daemon, start_method=\"spawn\")\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 197, in start_processes\n", + " while not context.join():\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 158, in join\n", + " raise ProcessRaisedException(msg, error_index, failed_process.pid)\n", + "torch.multiprocessing.spawn.ProcessRaisedException: \n", + "\n", + "-- Process 0 terminated with the following error:\n", + "Traceback (most recent call last):\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/multiprocessing/spawn.py\", line 68, in _wrap\n", + " fn(i, *args)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/test.py\", line 29, in main_worker\n", + " setup(args, rank, world_size)\n", + " File \"/teamspace/studios/this_studio/DIaM/DIaM/src/util.py\", line 41, in setup\n", + " dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/c10d_logger.py\", line 86, in wrapper\n", + " func_return = func(*args, **kwargs)\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1184, in init_process_group\n", + " default_pg, _ = _new_process_group_helper(\n", + " File \"/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py\", line 1339, in _new_process_group_helper\n", + " backend_class = ProcessGroupNCCL(\n", + "ValueError: ProcessGroupNCCL is only supported with GPUs, no GPUs found!\n", + "\n" + ] + } + ], + "source": [ + "!bash test.sh coco20i 5 upperbound [0] out.log" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "409d12eb-3377-420d-ac56-ba41c2880eee", + "metadata": {}, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "LICENSE\t\tconfig\t lists models.zip\t src\t util\n", + "README.md\tdata\t model out.log\t\t test.sh\n", + "cocdownload.py\tinitmodel model_ckpt requirements.txt train.py\n" + ] + } + ], + "source": [ + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "767edacb-b263-4c6e-966c-c2f42aafcf8c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..d173f86 --- /dev/null +++ b/src/train.py @@ -0,0 +1,268 @@ +import argparse +import yaml +import os +import time +import torch +from tqdm import tqdm +import torch.nn as nn +from typing import Tuple +import torch.nn.functional as F +import torch.distributed as dist +from .model.pspnet import get_model +from torch.nn.parallel import DistributedDataParallel as DDP +from .dataset.datasetV1 import get_val_loader +from .util import get_model_dir, fast_intersection_and_union, setup_seed, resume_random_state, find_free_port, setup, \ + cleanup, get_cfg +from .model.pspnet import get_model +from .classifierV1 import Classifier +import torch +import matplotlib.pyplot as plt +import numpy as np +import torchvision +import sys + + + +def main(rank: int, world_size: int, args: argparse.Namespace) -> None: + # 1. Load Configuration + torch.cuda.empty_cache() + with open(args.config, 'r') as f: + cfg = yaml.safe_load(f) + setup(args, rank, world_size) + # 2. Setup Device (GPU or CPU) + device = torch.device('cuda:{}'.format(args.gpus) if torch.cuda.is_available() and args.gpus != -1 else 'cpu') + #setup the args + print(f"==> Running setup script") + # setup(args, rank, world_size) + setup_seed(cfg['EVALUATION']['manual_seed']) + # 3. 
Datasets and DataLoaders + # ========== Data ========== + val_loader = get_val_loader(cfg, args) + print(f'rank in raw....',{rank}) + + # ========== Model ========== + model = get_model(cfg, args).to(0) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DDP(model, device_ids=[0]) + + # cfg: dict, args: argparse.Namespace, run_id=None + root = get_model_dir(cfg, args) + + print("=> Creating the model") + if cfg['EVALUATION']['ckpt_used'] is not None: + filepath = os.path.join(root, f"{cfg['EVALUATION']['ckpt_used']}.pth") + assert os.path.isfile(filepath), filepath + checkpoint = torch.load(filepath) + model.load_state_dict(checkpoint['state_dict']) + print("=> Loaded weight '{}'".format(filepath)) + else: + print("=> Not loading anything") + + # ========== Test ========== + print('starting validation in ptrain file....') + + validateNow(args=args, val_loader=val_loader, model=model, cfg=cfg) + # cleanup() + +def validateNow(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, model: DDP,cfg: dict) -> Tuple[torch.tensor, torch.tensor]: + print('\n==> Start testing ({} runs)'.format(cfg['EVALUATION']['n_runs']), flush=True) + random_state = setup_seed(cfg['EVALUATION']['manual_seed'], return_old_state=True) + device = torch.device('cuda:{}'.format(dist.get_rank())) + model.eval() + + c = model.module.bottleneck_dim + h = model.module.feature_res[0] + w = model.module.feature_res[1] + print(f'channel .....{c}') + nb_episodes = len(val_loader) if cfg['EVALUATION']['test_num'] == -1 else int(cfg['EVALUATION']['test_num'] / cfg['EVALUATION']['batch_size_val']) + runtimes = torch.zeros(cfg['EVALUATION']['n_runs']) + base_mIoU, novel_mIoU = [torch.zeros(cfg['EVALUATION']['n_runs'], device=device) for _ in range(2)] + + # ========== Perform the runs ========== + for run in range(cfg['EVALUATION']['n_runs']): + print('Run', run + 1, 'of', cfg['EVALUATION']['n_runs']) + # The order of classes in the following tensors is the same as the order of classifier (novels at last) + cls_intersection = torch.zeros(args.num_classes_tr + args.num_classes_val) + # print(f'CLS_INTERSECTION.....{cls_intersection.shape}') + print(f'CLS_INTERSECTION.....{args.num_classes_tr}') + # print(f'CLS_INTERSECTION.....{args.num_classes_val}') + cls_union = torch.zeros(args.num_classes_tr + args.num_classes_val) + cls_target = torch.zeros(args.num_classes_tr + args.num_classes_val) + # print(f'CLS_INTERSECTION.....{args.num_classes_tr} + {args.num_classes_val}') + # print(f'CLS_INTERSECTION.....{args.num_classes_tr}') + runtime = 0 + features_s, gt_s = None, None + if not cfg['EVALUATION']['generate_new_support_set_for_each_task']: + with torch.no_grad(): + spprt_imgs, s_label = val_loader.dataset.generate_support([], remove_them_from_query_data_list=True) + nb_episodes = len(val_loader) if cfg['EVALUATION']['test_num'] == -1 else nb_episodes # Updates nb_episodes since some images were removed by generate_support + spprt_imgs = spprt_imgs.to(device, non_blocking=True) + s_label = s_label.to(device, non_blocking=True) + # print(f'support shape model extractor before........{spprt_imgs.shape}') + features_s = model.module.extract_features(spprt_imgs).detach().view((args.num_classes_val, cfg['EVALUATION']['shot'], c, h, w)) + # print(f'features extracted here is...{features_s.shape}') + # print(f'NUM_CLASSES_IN_VAL....{args.num_classes_val}') + gt_s = s_label.view((args.num_classes_val, cfg['EVALUATION']['shot'], cfg['DATA']['image_size'], cfg['DATA']['image_size'])) + print(f'features extracted 
here is...{gt_s.shape}') + # sys.exit(1) + for _ in tqdm(range(nb_episodes), leave=True): + t0 = time.time() + with torch.no_grad(): + try: + loader_output = next(iter_loader) + except (UnboundLocalError, StopIteration): + iter_loader = iter(val_loader) + loader_output = next(iter_loader) + qry_img, q_label, q_valid_pix, img_path = loader_output + # hape of query image in train... torch.Size([20, 3, 417, 417]) + # print(f'shape of query image in train...', qry_img.shape) + # # shape of label image in train... torch.Size([20, 3, 417, 417]) + # print(f'shape of label image in train...', qry_img.shape) + # # shape of VALID_PIXEL in train... torch.Size([20, 417, 417]) + # print(f'shape of VALID_PIXEL in train...', q_valid_pix.shape) + image_height = qry_img.shape[1] # Height + image_width = qry_img.shape[2] # Width + # print(f"Image Height: {image_height}, Width: {image_width}") + qry_img = qry_img.to(device, non_blocking=True) + q_label = q_label.to(device, non_blocking=True) + features_q = model.module.extract_features(qry_img).detach().unsqueeze(1) + # FEATURE AFTER EXTRACTION IS.....torch.Size([20, 1, 512, 53, 53]) + # print(f'FEATURE AFTER EXTRACTION IS.....{features_q.shape}') + valid_pixels_q = q_valid_pix.unsqueeze(1).to(device) + # VALID PIXEL SHAPE HERE.....torch.Size([20, 1, 417, 417]) + # print(f'VALID PIXEL SHAPE HERE.....{valid_pixels_q.shape}') + gt_q = q_label.unsqueeze(1) + # GROUND LABEL SHAPE.....torch.Size([20, 1, 417, 417]) + print(f'GROUND LABEL SHAPE.....{gt_q.shape}') + + query_image_path_list = list(img_path) + if cfg['EVALUATION']['generate_new_support_set_for_each_task']: + spprt_imgs, s_label = val_loader.dataset.generate_support(query_image_path_list) + spprt_imgs = spprt_imgs.to(device, non_blocking=True) + s_label = s_label.to(device, non_blocking=True) + features_s = model.module.extract_features(spprt_imgs).detach().view((args.num_classes_val, args.shot, c, h, w)) + gt_s = s_label.view((args.num_classes_val, cfg['EVALUATION']['shot'], cfg['DATA']['image_size'], cfg['DATA']['image_size'])) + print(f'GROUND TRUTH SHAPE INITIAlIZED.....{gt_s.shape}') + # sys.exit(1) + # =========== Initialize the classifier and run the method =============== + # print("SHAPE OF BASE WEIGHT BEFORE TRANSPOSE...",model.module.classifier.weight.shape ) + if len(model.module.classifier.weight.shape) == 2: + base_weight = model.module.classifier.weight.detach().clone().mT # or .T + else: + num_dims = model.module.classifier.weight.ndim # Get the number of dimensions + base_weight = model.module.classifier.weight.detach().clone().permute(*torch.arange(num_dims - 1, -1, -1)) + feature_dim = model.module.get_feature_dim() + # print(f' THE FEATURE DIMENSION IS ....{feature_dim}') + # Now you can use base_weight + # print("NEW SHAPE AFTER TRANSPOSE...", base_weight.shape) + # print(f'the feature dimension from the model is {num_dims}') + # base_weight = model.module.classifier.weight.detach().clone().T + base_bias = model.module.classifier.bias.detach().clone() + # (self, configer, backbone, feature_dim, num_prototypes, model): + # print("INITIAL SHAPE features_q...", features_q.shape) + # print("INITIAL SHAPE gt_s...", gt_s.shape) + print(f"number of novel classes in training before classifier passed to...{args.num_novel_classes}") + classifier = DIaMClassifier(args, base_weight, base_bias, n_tasks=features_q.size(0), cfg=cfg, backbone=model,features_s=features_s,gt_s=gt_s,num_novel_classes= args.num_novel_classes,feature_dim=feature_dim) + # print(f'After the training here..') + # 
classifier = Classifier(args, base_weight, base_bias, n_tasks=features_q.size(0)) + print(f'Before the init_prototype section feaures support {features_s.shape} and ground support {gt_s.shape}') + classifier.init_prototypes(features_s, gt_s) + + # INITIAL SHAPE features_s... torch.Size([20, 1, 512, 53, 53]) + # print("INITIAL SHAPE features_s...", features_s.shape) + # INITIAL SHAPE features_q... torch.Size([20, 1, 512, 53, 53]) + # print("INITIAL SHAPE features_q...", features_q.shape) + # INITIAL SHAPE gt_s... torch.Size([20, 1, 417, 417]) + print("INITIAL SHAPE gt_s...", gt_s.shape) + # INITIAL SHAPE valid_pixels_q... torch.Size([20, 1, 417, 417]) + # print("INITIAL SHAPE valid_pixels_q...", valid_pixels_q.shape) + # sys.exit(1) + classifier.optimize(features_s, features_q, gt_s, valid_pixels_q) + # print(f'After optimization here....') + runtime += time.time() - t0 + + # =========== Perform inference and compute metrics =============== + logits = classifier.get_logits(features_q).detach() + probas = classifier.get_probas(logits) + print(f'Probability ....{probas.shape} and ground truth shape ... {gt_q.shape}') + # Probability ....torch.Size([20, 1, 20, 81, 53]) and ground truth shape ... torch.Size([20, 1, 417, 417]) + intersection, union, target = fast_intersection_and_union(probas, gt_q) # [batch_size_val, 1, num_classes] + print(f"Intersection shape: {intersection.shape}") + print(f"Union shape: {union.shape}") + print(f"Target shape: {target.shape}") + intersection, union, target = intersection.squeeze(1).cpu(), union.squeeze(1).cpu(), target.squeeze(1).cpu() + + print(f"Squeezed Intersection shape: {intersection.shape}") + print(f"Squeezed Union shape: {union.shape}") + print(f"Squeezed Target shape: {target.shape}") + + try: + # cls_intersection += intersection.sum(0) + # cls_union += union.sum(0) + # cls_target += target.sum(0) + cls_intersection[args.num_classes_tr:] += intersection.sum(0) # Change 2: Update the relevant part + cls_union[args.num_classes_tr:] += union.sum(0) # Change 2: Update the relevant part + cls_target[args.num_classes_tr:] += target.sum(0) + except RuntimeError as e: + print(f"Error in cls_intersection update: {e}. Skipping this operation.") + print(f"Error in cls_intersection update: {e}. 
Skipping this operation.") + print(f"cls_intersection shape: {cls_intersection.shape}") + print(f"cls_union shape: {cls_union.shape}") + print(f"cls_target shape: {cls_target.shape}") + # cls_intersection += intersection.sum(0) + # cls_union += union.sum(0) + # cls_target += target.sum(0) + + base_count, novel_count, sum_base_IoU, sum_novel_IoU = 4 * [0] + for i, class_ in enumerate(val_loader.dataset.all_classes): + if cls_union[i] == 0: + continue + IoU = cls_intersection[i] / (cls_union[i] + 1e-10) + print("Class {}: \t{:.4f}".format(class_, IoU)) + if class_ in val_loader.dataset.base_class_list: + sum_base_IoU += IoU + base_count += 1 + elif class_ in val_loader.dataset.novel_class_list: + sum_novel_IoU += IoU + novel_count += 1 + + avg_base_IoU, avg_novel_IoU = sum_base_IoU / base_count, sum_novel_IoU / novel_count + print('Mean base IoU: {:.4f}, Mean novel IoU: {:.4f}'.format(avg_base_IoU, avg_novel_IoU), flush=True) + + base_mIoU[run], novel_mIoU[run] = avg_base_IoU, avg_novel_IoU + runtimes[run] = runtime + + agg_mIoU = (base_mIoU.mean() + novel_mIoU.mean()) / 2 + print('==>') + print('Average of base mIoU: {:.4f}\tAverage of novel mIoU: {:.4f} \t(over {} runs)'.format( + base_mIoU.mean(), novel_mIoU.mean(), args.n_runs)) + print('Mean --- {:.4f}'.format(agg_mIoU), flush=True) + print('Average runtime / run --- {:.1f}\n'.format(runtimes.mean())) + + resume_random_state(random_state) + return agg_mIoU + +def evaluate(model, data_loader, device, cfg): + return 0 + + +if __name__ == "__main__": + # Original tensor + + parser = argparse.ArgumentParser(description='DIaM Training and Testing Script') + parser.add_argument('--config', type=str, help='Path to YAML config file') + parser.add_argument('--split', type=int, help='Data split to use') + parser.add_argument('--shot', type=int, help='Number of shots') + parser.add_argument('--gpus', type=int, default=0, help='GPU ID to use (-1 for CPU)') + args = parser.parse_args() + # if args.debug: + # args.test_num = 64 + # args.n_runs = 2 + + world_size = len(str(args.gpus)) + distributed = world_size > 1 + assert not distributed, 'Testing should not be done in a distributed way' + args.distributed = distributed + args.port = find_free_port() + main(0, world_size, args) + diff --git a/src/util.py b/src/util.py index 3d3b9c1..f963dc9 100755 --- a/src/util.py +++ b/src/util.py @@ -71,17 +71,22 @@ def get_next_run_id(args) -> int: return max(map(int, os.listdir(args.ckpt_path))) + 1 -def get_model_dir(args: argparse.Namespace, run_id=None) -> str: +def get_model_dir(cfg:dict,args: argparse.Namespace, run_id=None) -> str: """ Obtain the directory to save/load the model """ if run_id is None: - run_id = args.load_model_id - path = os.path.join(args.ckpt_path, + run_id = cfg['EVALUATION']['load_model_id'] + path = os.path.join(cfg['EVALUATION']['ckpt_path'], str(run_id), - args.data_name, - f'split{args.split}', - f'pspnet_resnet{args.layers}') + cfg['DATA']['data_name'], + f'split{cfg["DATA"]["split"]}', + f'pspnet_resnet{cfg["MODEL"]["layers"]}') + # path = os.path.join(cfg['EVALUATION']['ckpt_path'], + # str(run_id), + # cfg['EVALUATION']['data_name'], + # f'split{cfg["DATA"]["split"]}', + # f'pspnet_resnet{cfg["MODEL"]["layers"]}') return path @@ -244,6 +249,7 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False): super(CfgNode, self).__init__(init_dict) def __getattr__(self, name): + print(f'ATTRIBUTE NAME....{name}') if name in self: return self[name] else: diff --git a/train.py b/train.py new file mode 100644 index 
0000000..b9d5f5b --- /dev/null +++ b/train.py @@ -0,0 +1,416 @@ +import os +import datetime +import random +import time +import cv2 +import numpy as np +import logging +import argparse +import math +from visdom import Visdom +import os.path as osp +from shutil import copyfile + +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler + +from tensorboardX import SummaryWriter + +from model import PSPNet + +from util import dataset +from util import transform, transform_tri, config +from util.util import AverageMeter, poly_learning_rate, intersectionAndUnionGPU, get_model_para_number, setup_seed, get_logger, get_save_path, \ + is_same_model, fix_bn, sum_list, check_makedirs + +cv2.ocl.setUseOpenCL(False) +cv2.setNumThreads(0) +# os.environ["CUDA_VISIBLE_DEVICES"] = '8' + +def get_parser(): + parser = argparse.ArgumentParser(description='PyTorch Semantic Segmentation') + parser.add_argument('--arch', type=str, default='PSPNet') # + parser.add_argument('--viz', action='store_true', default=False) + parser.add_argument('--config', type=str, default='config/pascal/pascal_split0_vgg_base.yaml', help='config file') # coco/coco_split0_resnet50.yaml + parser.add_argument('--local_rank', type=int, default=-1, help='number of cpu threads to use during batch generation') + parser.add_argument('--opts', help='see config/ade20k/ade20k_pspnet50.yaml for all options', default=None, nargs=argparse.REMAINDER) + args = parser.parse_args() + assert args.config is not None + cfg = config.load_cfg_from_cfg_file(args.config) + cfg = config.merge_cfg_from_args(cfg, args) + if args.opts is not None: + cfg = config.merge_cfg_from_list(cfg, args.opts) + return cfg + + +def get_model(args): + + model = eval(args.arch).OneModel(args) + optimizer = model.get_optim(model, args, LR=args.base_lr) + + if hasattr(model,'freeze_modules'): + model.freeze_modules(model) + + if args.distributed: + # Initialize Process Group + dist.init_process_group(backend='nccl') + print('args.local_rank: ', args.local_rank) + torch.cuda.set_device(args.local_rank) + device = torch.device('cuda', args.local_rank) + model.to(device) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) + else: + model = model.cuda() + + # Resume + get_save_path(args) + check_makedirs(args.snapshot_path) + check_makedirs(args.result_path) + + if args.resume: + resume_path = osp.join(args.snapshot_path, args.resume) + if os.path.isfile(resume_path): + if main_process(): + logger.info("=> loading checkpoint '{}'".format(resume_path)) + checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage.cuda()) + args.start_epoch = checkpoint['epoch'] + new_param = checkpoint['state_dict'] + try: + model.load_state_dict(new_param) + except RuntimeError: # 1GPU loads mGPU model + for key in list(new_param.keys()): + new_param[key[7:]] = new_param.pop(key) + model.load_state_dict(new_param) + optimizer.load_state_dict(checkpoint['optimizer']) + if main_process(): + logger.info("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch'])) + else: + if main_process(): + logger.info("=> no checkpoint found at '{}'".format(resume_path)) + + # Get model para. 
+ total_number, learnable_number = get_model_para_number(model) + if main_process(): + print('Number of Parameters: %d' % (total_number)) + print('Number of Learnable Parameters: %d' % (learnable_number)) + + time.sleep(5) + return model, optimizer + +def main_process(): + return not args.distributed or (args.distributed and (args.local_rank == 0)) + +def main(): + global args, logger, writer + args = get_parser() + logger = get_logger() + args.distributed = True if torch.cuda.device_count() > 1 else False + if main_process(): + print(args) + + if args.manual_seed is not None: + setup_seed(args.manual_seed, args.seed_deterministic) + + assert args.classes > 1 + assert args.zoom_factor in [1, 2, 4, 8] + assert (args.train_h - 1) % 8 == 0 and (args.train_w - 1) % 8 == 0 + + if main_process(): + logger.info("=> creating model ...") + model, optimizer = get_model(args) + if main_process(): + logger.info(model) + if main_process() and args.viz: + writer = SummaryWriter(args.result_path) + +# ---------------------- DATASET ---------------------- + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + # Train + train_transform = transform.Compose([ + transform.RandScale([args.scale_min, args.scale_max]), + transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.padding_label), + transform.RandomGaussianBlur(), + transform.RandomHorizontalFlip(), + transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.padding_label), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)]) + if args.data_set == 'pascal' or args.data_set == 'coco': + train_data = dataset.BaseData(split=args.split, mode='train', data_root=args.data_root, data_list=args.train_list, \ + data_set=args.data_set, use_split_coco=args.use_split_coco, \ + transform=train_transform, main_process=main_process(), batch_size=args.batch_size) + train_sampler = DistributedSampler(train_data) if args.distributed else None + train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, \ + pin_memory=True, sampler=train_sampler, drop_last=True, \ + shuffle=False if args.distributed else True) + # Val + if args.evaluate: + if args.resized_val: + val_transform = transform.Compose([ + transform.Resize(size=args.val_size), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)]) + else: + val_transform = transform.Compose([ + transform.test_Resize(size=args.val_size), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)]) + if args.data_set == 'pascal' or args.data_set == 'coco': + val_data = dataset.BaseData(split=args.split, mode='val', data_root=args.data_root, data_list=args.val_list, \ + data_set=args.data_set, use_split_coco=args.use_split_coco, \ + transform=val_transform, main_process=main_process(), batch_size=args.batch_size_val) + val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=False, sampler=None) + if args.ori_resize: + assert args.batch_size_val == 1 + +# ---------------------- TRAINVAL ---------------------- + global best_miou, best_epoch, keep_epoch, val_num + best_miou = 0. 
+ best_epoch = 0 + keep_epoch = 0 + val_num = 0 + + start_time = time.time() + + for epoch in range(args.start_epoch, args.epochs): + if keep_epoch == args.stop_interval: + break + if args.fix_random_seed_val: + setup_seed(args.manual_seed + epoch, args.seed_deterministic) + + epoch_log = epoch + 1 + keep_epoch += 1 + if args.distributed: + train_sampler.set_epoch(epoch) + + # ---------------------- TRAIN ---------------------- + train(train_loader, model, optimizer, epoch) + + # save model for + if (epoch % args.save_freq == 0) and (epoch > 0) and main_process(): + filename = args.snapshot_path + '/epoch_{}.pth'.format(epoch) + logger.info('Saving checkpoint to: ' + filename) + if osp.exists(filename): + os.remove(filename) + torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename) + + # ----------------------- VAL ----------------------- + if args.evaluate and epoch%1==0: + mIoU = validate(val_loader, model) + val_num += 1 + if main_process() and args.viz: + writer.add_scalar('mIoU_val', mIoU, epoch_log) + + # save model for + if (mIoU > best_miou): + best_miou, best_epoch = mIoU, epoch + keep_epoch = 0 + if main_process(): + filename = args.snapshot_path + '/train_epoch_' + str(epoch) + '_{:.4f}'.format(best_miou) + '.pth' + logger.info('Saving checkpoint to: ' + filename) + torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename) + copyfile(filename, args.snapshot_path + '/best.pth') + + total_time = time.time() - start_time + t_m, t_s = divmod(total_time, 60) + t_h, t_m = divmod(t_m, 60) + total_time = '{:02d}h {:02d}m {:02d}s'.format(int(t_h), int(t_m), int(t_s)) + + if main_process(): + print('\nEpoch: {}/{} \t Total running time: {}'.format(epoch_log, args.epochs, total_time)) + print('The number of models validated: {}'.format(val_num)) + print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<< Final Best Result <<<<<<<<<<<<<<<<<<<<<<<<<<<<<') + print(args.arch + '\t Group:{} \t Best_mIoU:{:.4f} \t Best_step:{}'.format(args.split, best_miou, best_epoch)) + print('>'*80) + print ('%s' % datetime.datetime.now()) + + +def train(train_loader, model, optimizer, epoch): + batch_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + intersection_meter = AverageMeter() + union_meter = AverageMeter() + target_meter = AverageMeter() + + model.train() + + end = time.time() + val_time = 0. 
+ max_iter = args.epochs * len(train_loader) + if main_process(): + print('Warmup: {}'.format(args.warmup)) + print(train_loader) + for i, (input, target) in enumerate(train_loader): + + data_time.update(time.time() - end - val_time) + current_iter = epoch * len(train_loader) + i + 1 + + poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=args.index_split, warmup=args.warmup, warmup_step=len(train_loader)//2) + + input = input.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + output, main_loss = model(x=input, y=target) + + loss = main_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + n = input.size(0) # batch_size + + intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) + intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy() + intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target) + + accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) # allAcc + + loss_meter.update(loss.item(), n) + + batch_time.update(time.time() - end - val_time) + end = time.time() + + remain_iter = max_iter - current_iter + remain_time = remain_iter * batch_time.avg + t_m, t_s = divmod(remain_time, 60) + t_h, t_m = divmod(t_m, 60) + remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s)) + + if (i + 1) % args.print_freq == 0 and main_process(): + logger.info('Epoch: [{}/{}][{}/{}] ' + 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' + 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Remain {remain_time} ' + 'Loss {loss_meter.val:.4f} ' + 'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader), + batch_time=batch_time, + data_time=data_time, + remain_time=remain_time, + loss_meter=loss_meter, + accuracy=accuracy)) + if args.viz: + writer.add_scalar('loss_train', loss_meter.val, current_iter) + + iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) + accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) + mIoU = np.mean(iou_class) + mAcc = np.mean(accuracy_class) + allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) + + if main_process(): + logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch, args.epochs, mIoU, mAcc, allAcc)) + + +def validate(val_loader, model): + if main_process(): + logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>') + batch_time = AverageMeter() + model_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + + intersection_meter = AverageMeter() + union_meter = AverageMeter() + target_meter = AverageMeter() + + class_intersection_meter = [0]*(args.classes-1) + class_union_meter = [0]*(args.classes-1) + + if args.manual_seed is not None and args.fix_random_seed_val: + setup_seed(args.manual_seed, args.seed_deterministic) + + criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) + + model.eval() + end = time.time() + val_start = end + + iter_num = 0 + + for i, logits in enumerate(val_loader): + iter_num += 1 + data_time.update(time.time() - end) + + if args.batch_size_val == 1: + input, target, ori_label = logits + ori_label = ori_label.cuda(non_blocking=True) + else: + input, target = logits + input = input.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + start_time = time.time() + output = model(x=input, y=target) + model_time.update(time.time() - 
start_time) + + if args.ori_resize: + longerside = max(ori_label.size(1), ori_label.size(2)) + backmask = torch.ones(ori_label.size(0), longerside, longerside, device='cuda')*255 + backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label + target = backmask.clone().long() + + output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True) + + loss = criterion(output, target) + + output = output.max(1)[1] + + intersection, union, new_target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) + intersection, union, new_target = intersection.cpu().numpy(), union.cpu().numpy(), new_target.cpu().numpy() + intersection_meter.update(intersection), union_meter.update(union), target_meter.update(new_target) + for idx in range(1,len(intersection)): + class_intersection_meter[idx-1] += intersection[idx] + class_union_meter[idx-1] += union[idx] + + accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) + loss_meter.update(loss.item(), input.size(0)) + batch_time.update(time.time() - end) + end = time.time() + if ((iter_num % 100 == 0) or (iter_num == len(val_loader))) and main_process(): + logger.info('Test: [{}/{}] ' + 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' + 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) ' + 'Accuracy {accuracy:.4f}.'.format(iter_num, len(val_loader), + data_time=data_time, + batch_time=batch_time, + loss_meter=loss_meter, + accuracy=accuracy)) + val_time = time.time()-val_start + + iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) + mIoU = np.mean(iou_class) + + class_iou_class = [] + class_miou = 0 + for i in range(len(class_intersection_meter)): + class_iou = class_intersection_meter[i]/(class_union_meter[i]+ 1e-10) + class_iou_class.append(class_iou) + class_miou += class_iou + class_miou = class_miou*1.0 / len(class_intersection_meter) + + if main_process(): + logger.info('meanIoU---Val result: mIoU {:.4f}.'.format(class_miou)) + for i in range(len(class_intersection_meter)): + logger.info('Class_{} Result: iou_b {:.4f}.'.format(i+1, class_iou_class[i])) + logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<') + print('total time: {:.4f}, avg inference time: {:.4f}, count: {}'.format(val_time, model_time.avg, iter_num)) + + return class_miou + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/util/config.py b/util/config.py new file mode 100644 index 0000000..a282fd1 --- /dev/null +++ b/util/config.py @@ -0,0 +1,172 @@ +# ----------------------------------------------------------------------------- +# Functions for parsing args +# ----------------------------------------------------------------------------- +import yaml +import os +from ast import literal_eval +import copy + + +class CfgNode(dict): + """ + CfgNode represents an internal node in the configuration tree. It's a simple + dict-like container that allows for attribute-based access to keys. 
+ """ + + def __init__(self, init_dict=None, key_list=None, new_allowed=False): + # Recursively convert nested dictionaries in init_dict into CfgNodes + init_dict = {} if init_dict is None else init_dict + key_list = [] if key_list is None else key_list + for k, v in init_dict.items(): + if type(v) is dict: + # Convert dict to CfgNode + init_dict[k] = CfgNode(v, key_list=key_list + [k]) + super(CfgNode, self).__init__(init_dict) + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + self[name] = value + + def __str__(self): + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + r = "" + s = [] + for k, v in sorted(self.items()): + seperator = "\n" if isinstance(v, CfgNode) else " " + attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + attr_str = _indent(attr_str, 2) + s.append(attr_str) + r += "\n".join(s) + return r + + def __repr__(self): # print + return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) + + +def load_cfg_from_cfg_file(file): + cfg = {} + assert os.path.isfile(file) and file.endswith('.yaml'), \ + '{} is not a yaml file'.format(file) + + with open(file, 'r') as f: + cfg_from_file = yaml.safe_load(f) + + for key in cfg_from_file: + for k, v in cfg_from_file[key].items(): + cfg[k] = v + + cfg = CfgNode(cfg) + return cfg + +def merge_cfg_from_args(cfg, args): + args_dict = args.__dict__ + for k ,v in args_dict.items(): + if not k == 'config' or k == 'opts': + cfg[k] = v + + return cfg + +def merge_cfg_from_list(cfg, cfg_list): + new_cfg = copy.deepcopy(cfg) + assert len(cfg_list) % 2 == 0 + for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): + subkey = full_key.split('.')[-1] + assert subkey in cfg, 'Non-existent key: {}'.format(full_key) + value = _decode_cfg_value(v) + value = _check_and_coerce_cfg_value_type( + value, cfg[subkey], subkey, full_key + ) + setattr(new_cfg, subkey, value) + + return new_cfg + + +def _decode_cfg_value(v): + """Decodes a raw config value (e.g., from a yaml config files or command + line argument) into a Python object. + """ + # All remaining processing is only applied to strings + if not isinstance(v, str): + return v + # Try to interpret `v` as a: + # string, number, tuple, list, dict, boolean, or None + try: + v = literal_eval(v) + # The following two excepts allow v to pass through when it represents a + # string. + # + # Longer explanation: + # The type of v is always a string (before calling literal_eval), but + # sometimes it *represents* a string and other times a data structure, like + # a list. In the case that v represents a string, what we got back from the + # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is + # ok with '"foo"', but will raise a ValueError if given 'foo'. In other + # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval + # will raise a SyntaxError. + except ValueError: + pass + except SyntaxError: + pass + return v + + +def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): + """Checks that `replacement`, which is intended to replace `original` is of + the right type. The type is correct if it matches exactly or is one of a few + cases in which the type can be easily coerced. 
+ """ + original_type = type(original) + replacement_type = type(replacement) + + # The types must match (with some exceptions) + if replacement_type == original_type: + return replacement + + # Cast replacement from from_type to to_type if the replacement and original + # types match from_type and to_type + def conditional_cast(from_type, to_type): + if replacement_type == from_type and original_type == to_type: + return True, to_type(replacement) + else: + return False, None + + # Conditionally casts + # list <-> tuple + casts = [(tuple, list), (list, tuple)] + # For py2: allow converting from str (bytes) to a unicode string + try: + casts.append((str, unicode)) # noqa: F821 + except Exception: + pass + + for (from_type, to_type) in casts: + converted, converted_value = conditional_cast(from_type, to_type) + if converted: + return converted_value + + raise ValueError( + "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " + "key: {}".format( + original_type, replacement_type, original, replacement, full_key + ) + ) + + +def _assert_with_logging(cond, msg): + if not cond: + logger.debug(msg) + assert cond, msg \ No newline at end of file diff --git a/util/dataset.py b/util/dataset.py new file mode 100644 index 0000000..1fe1b12 --- /dev/null +++ b/util/dataset.py @@ -0,0 +1,678 @@ +import os +import os.path +import cv2 +import numpy as np +import copy + +from torch.utils.data import Dataset +import torch.nn.functional as F +import torch +import random +import time +from tqdm import tqdm + +from .get_weak_anns import transform_anns + +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] + + + +def is_image_file(filename): + filename_lower = filename.lower() + return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(split=0, data_root=None, data_list=None, sub_list=None, filter_intersection=False): + assert split in [0, 1, 2, 3] + if not os.path.isfile(data_list): + raise (RuntimeError("Image list file do not exist: " + data_list + "\n")) + + # Shaban uses these lines to remove small objects: + # if util.change_coordinates(mask, 32.0, 0.0).sum() > 2: + # filtered_item.append(item) + # which means the mask will be downsampled to 1/32 of the original size and the valid area should be larger than 2, + # therefore the area in original size should be accordingly larger than 2 * 32 * 32 + image_label_list = [] + list_read = open(data_list).readlines() + print("Processing data...".format(sub_list)) + sub_class_file_list = {} + for sub_c in sub_list: + sub_class_file_list[sub_c] = [] + + for l_idx in tqdm(range(len(list_read))): + line = list_read[l_idx] + line = line.strip() + line_split = line.split(' ') + image_name = os.path.join(data_root, line_split[0]) + label_name = os.path.join(data_root, line_split[1]) + item = (image_name, label_name) + label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE) + label_class = np.unique(label).tolist() + + if 0 in label_class: + label_class.remove(0) + if 255 in label_class: + label_class.remove(255) + + new_label_class = [] + + if filter_intersection: # filter images containing objects of novel categories during meta-training + if set(label_class).issubset(set(sub_list)): + for c in label_class: + if c in sub_list: + tmp_label = np.zeros_like(label) + target_pix = np.where(label == c) + tmp_label[target_pix[0],target_pix[1]] = 1 + if tmp_label.sum() >= 2 * 32 * 32: + new_label_class.append(c) + else: + for c in label_class: + if c in sub_list: + tmp_label = np.zeros_like(label) + target_pix = 
np.where(label == c) + tmp_label[target_pix[0],target_pix[1]] = 1 + if tmp_label.sum() >= 2 * 32 * 32: + new_label_class.append(c) + + label_class = new_label_class + + if len(label_class) > 0: + image_label_list.append(item) + for c in label_class: + if c in sub_list: + sub_class_file_list[c].append(item) + + print("Checking image&label pair {} list done! ".format(split)) + return image_label_list, sub_class_file_list + + + +class SemData(Dataset): + def __init__(self, split=3, shot=1, data_root=None, base_data_root=None, data_list=None, data_set=None, use_split_coco=False, \ + transform=None, transform_tri=None, mode='train', ann_type='mask', \ + ft_transform=None, ft_aug_size=None, \ + ms_transform=None): + + assert mode in ['train', 'val', 'demo', 'finetune'] + assert data_set in ['pascal', 'coco'] + if mode == 'finetune': + assert ft_transform is not None + assert ft_aug_size is not None + + if data_set == 'pascal': + self.num_classes = 20 + elif data_set == 'coco': + self.num_classes = 80 + + self.mode = mode + self.split = split + self.shot = shot + self.data_root = data_root + self.base_data_root = base_data_root + self.ann_type = ann_type + + if data_set == 'pascal': + self.class_list = list(range(1, 21)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + if self.split == 3: + self.sub_list = list(range(1, 16)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + self.sub_val_list = list(range(16, 21)) # [16,17,18,19,20] + elif self.split == 2: + self.sub_list = list(range(1, 11)) + list(range(16, 21)) # [1,2,3,4,5,6,7,8,9,10,16,17,18,19,20] + self.sub_val_list = list(range(11, 16)) # [11,12,13,14,15] + elif self.split == 1: + self.sub_list = list(range(1, 6)) + list(range(11, 21)) # [1,2,3,4,5,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(6, 11)) # [6,7,8,9,10] + elif self.split == 0: + self.sub_list = list(range(6, 21)) # [6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(1, 6)) # [1,2,3,4,5] + + elif data_set == 'coco': + if use_split_coco: + print('INFO: using SPLIT COCO (FWB)') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_val_list = list(range(4, 81, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 2: + self.sub_val_list = list(range(3, 80, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 1: + self.sub_val_list = list(range(2, 79, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 0: + self.sub_val_list = list(range(1, 78, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + else: + print('INFO: using COCO (PANet)') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_list = list(range(1, 61)) + self.sub_val_list = list(range(61, 81)) + elif self.split == 2: + self.sub_list = list(range(1, 41)) + list(range(61, 81)) + self.sub_val_list = list(range(41, 61)) + elif self.split == 1: + self.sub_list = list(range(1, 21)) + list(range(41, 81)) + self.sub_val_list = list(range(21, 41)) + elif self.split == 0: + self.sub_list = list(range(21, 81)) + self.sub_val_list = list(range(1, 21)) + + print('sub_list: ', self.sub_list) + print('sub_val_list: ', self.sub_val_list) + + # @@@ For convenience, we skip the step of building datasets and instead use the pre-generated lists @@@ + # if self.mode == 'train': + # self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_list, True) + # assert 
len(self.sub_class_file_list.keys()) == len(self.sub_list) + # elif self.mode == 'val' or self.mode == 'demo' or self.mode == 'finetune': + # self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_val_list, False) + # assert len(self.sub_class_file_list.keys()) == len(self.sub_val_list) + + mode = 'train' if self.mode=='train' else 'val' + self.base_path = os.path.join(self.base_data_root, mode, str(self.split)) + + fss_list_root = './lists/{}/fss_list/{}/'.format(data_set, mode) + fss_data_list_path = fss_list_root + 'data_list_{}.txt'.format(split) + fss_sub_class_file_list_path = fss_list_root + 'sub_class_file_list_{}.txt'.format(split) + + # Write FSS Data + # with open(fss_data_list_path, 'w') as f: + # for item in self.data_list: + # img, label = item + # f.write(img + ' ') + # f.write(label + '\n') + # with open(fss_sub_class_file_list_path, 'w') as f: + # f.write(str(self.sub_class_file_list)) + + # Read FSS Data + with open(fss_data_list_path, 'r') as f: + f_str = f.readlines() + self.data_list = [] + for line in f_str: + img, mask = line.split(' ') + self.data_list.append((img, mask.strip())) + + with open(fss_sub_class_file_list_path, 'r') as f: + f_str = f.read() + self.sub_class_file_list = eval(f_str) + + self.transform = transform + self.transform_tri = transform_tri + self.ft_transform = ft_transform + self.ft_aug_size = ft_aug_size + self.ms_transform_list = ms_transform + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + label_class = [] + image_path, label_path = self.data_list[index] + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = np.float32(image) + label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + label_b = cv2.imread(os.path.join(self.base_path,label_path.split('/')[-1]), cv2.IMREAD_GRAYSCALE) + + if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]: + raise (RuntimeError("Query Image & label shape mismatch: " + image_path + " " + label_path + "\n")) + label_class = np.unique(label).tolist() + if 0 in label_class: + label_class.remove(0) + if 255 in label_class: + label_class.remove(255) + new_label_class = [] + for c in label_class: + if c in self.sub_val_list: + if self.mode == 'val' or self.mode == 'demo' or self.mode == 'finetune': + new_label_class.append(c) + if c in self.sub_list: + if self.mode == 'train': + new_label_class.append(c) + label_class = new_label_class + assert len(label_class) > 0 + + class_chosen = label_class[random.randint(1,len(label_class))-1] + target_pix = np.where(label == class_chosen) + ignore_pix = np.where(label == 255) + label[:,:] = 0 + if target_pix[0].shape[0] > 0: + label[target_pix[0],target_pix[1]] = 1 + label[ignore_pix[0],ignore_pix[1]] = 255 + + # for cls in range(1,self.num_classes+1): + # select_pix = np.where(label_b_tmp == cls) + # if cls in self.sub_list: + # label_b[select_pix[0],select_pix[1]] = self.sub_list.index(cls) + 1 + # else: + # label_b[select_pix[0],select_pix[1]] = 0 + + file_class_chosen = self.sub_class_file_list[class_chosen] + num_file = len(file_class_chosen) + + support_image_path_list = [] + support_label_path_list = [] + support_idx_list = [] + for k in range(self.shot): + support_idx = random.randint(1,num_file)-1 + support_image_path = image_path + support_label_path = label_path + while((support_image_path == image_path and support_label_path == label_path) or support_idx in support_idx_list): + support_idx = 
random.randint(1,num_file)-1 + support_image_path, support_label_path = file_class_chosen[support_idx] + support_idx_list.append(support_idx) + support_image_path_list.append(support_image_path) + support_label_path_list.append(support_label_path) + + support_image_list_ori = [] + support_label_list_ori = [] + support_label_list_ori_mask = [] + subcls_list = [] + if self.mode == 'train': + subcls_list.append(self.sub_list.index(class_chosen)) + else: + subcls_list.append(self.sub_val_list.index(class_chosen)) + for k in range(self.shot): + support_image_path = support_image_path_list[k] + support_label_path = support_label_path_list[k] + support_image = cv2.imread(support_image_path, cv2.IMREAD_COLOR) + support_image = cv2.cvtColor(support_image, cv2.COLOR_BGR2RGB) + support_image = np.float32(support_image) + support_label = cv2.imread(support_label_path, cv2.IMREAD_GRAYSCALE) + target_pix = np.where(support_label == class_chosen) + ignore_pix = np.where(support_label == 255) + support_label[:,:] = 0 + support_label[target_pix[0],target_pix[1]] = 1 + + support_label, support_label_mask = transform_anns(support_label, self.ann_type) # mask/bbox + support_label[ignore_pix[0],ignore_pix[1]] = 255 + support_label_mask[ignore_pix[0],ignore_pix[1]] = 255 + if support_image.shape[0] != support_label.shape[0] or support_image.shape[1] != support_label.shape[1]: + raise (RuntimeError("Support Image & label shape mismatch: " + support_image_path + " " + support_label_path + "\n")) + support_image_list_ori.append(support_image) + support_label_list_ori.append(support_label) + support_label_list_ori_mask.append(support_label_mask) + assert len(support_label_list_ori) == self.shot and len(support_image_list_ori) == self.shot + + raw_image = image.copy() + raw_label = label.copy() + raw_label_b = label_b.copy() + support_image_list = [[] for _ in range(self.shot)] + support_label_list = [[] for _ in range(self.shot)] + if self.transform is not None: + image, label, label_b = self.transform_tri(image, label, label_b) # transform the triple + for k in range(self.shot): + support_image_list[k], support_label_list[k] = self.transform(support_image_list_ori[k], support_label_list_ori[k]) + + s_xs = support_image_list + s_ys = support_label_list + s_x = s_xs[0].unsqueeze(0) + for i in range(1, self.shot): + s_x = torch.cat([s_xs[i].unsqueeze(0), s_x], 0) + s_y = s_ys[0].unsqueeze(0) + for i in range(1, self.shot): + s_y = torch.cat([s_ys[i].unsqueeze(0), s_y], 0) + + # Return + if self.mode == 'train': + return image, label, label_b, s_x, s_y, subcls_list + elif self.mode == 'val': + return image, label, label_b, s_x, s_y, subcls_list, raw_label, raw_label_b + elif self.mode == 'demo': + total_image_list = support_image_list_ori.copy() + total_image_list.append(raw_image) + return image, label, label_b, s_x, s_y, subcls_list, total_image_list, support_label_list_ori, support_label_list_ori_mask, raw_label, raw_label_b + + + +# -------------------------- GFSS -------------------------- + +def make_GFSS_dataset(split=0, data_root=None, data_list=None, sub_list=None, sub_val_list=None): + assert split in [0, 1, 2, 3] + if not os.path.isfile(data_list): + raise (RuntimeError("Image list file do not exist: " + data_list + "\n")) + + image_label_list = [] + list_read = open(data_list).readlines() + print("Processing data...".format(sub_val_list)) + sub_class_list_sup = {} + for sub_c in sub_val_list: + sub_class_list_sup[sub_c] = [] + + for l_idx in tqdm(range(len(list_read))): + line = list_read[l_idx] + line = 
line.strip() + line_split = line.split(' ') + image_name = os.path.join(data_root, line_split[0]) + label_name = os.path.join(data_root, line_split[1]) + item = (image_name, label_name) + label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE) + label_class = np.unique(label).tolist() + + if 0 in label_class: + label_class.remove(0) + if 255 in label_class: + label_class.remove(255) + + for c in label_class: + if c in sub_val_list: + sub_class_list_sup[c].append(item) + + image_label_list.append(item) + + print("Checking image&label pair {} list done! ".format(split)) + return sub_class_list_sup, image_label_list + +class GSemData(Dataset): + # Generalized Few-Shot Segmentation + def __init__(self, split=3, shot=1, data_root=None, base_data_root=None, data_list=None, data_set=None, use_split_coco=False, \ + transform=None, transform_tri=None, mode='val', ann_type='mask'): + + assert mode in ['val', 'demo'] + assert data_set in ['pascal', 'coco'] + + if data_set == 'pascal': + self.num_classes = 20 + elif data_set == 'coco': + self.num_classes = 80 + + self.mode = mode + self.split = split + self.shot = shot + self.data_root = data_root + self.base_data_root = base_data_root + self.ann_type = ann_type + + if data_set == 'pascal': + self.class_list = list(range(1, 21)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + if self.split == 3: + self.sub_list = list(range(1, 16)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + self.sub_val_list = list(range(16, 21)) # [16,17,18,19,20] + elif self.split == 2: + self.sub_list = list(range(1, 11)) + list(range(16, 21)) # [1,2,3,4,5,6,7,8,9,10,16,17,18,19,20] + self.sub_val_list = list(range(11, 16)) # [11,12,13,14,15] + elif self.split == 1: + self.sub_list = list(range(1, 6)) + list(range(11, 21)) # [1,2,3,4,5,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(6, 11)) # [6,7,8,9,10] + elif self.split == 0: + self.sub_list = list(range(6, 21)) # [6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(1, 6)) # [1,2,3,4,5] + + elif data_set == 'coco': + if use_split_coco: + print('INFO: using SPLIT COCO (FWB)') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_val_list = list(range(4, 81, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 2: + self.sub_val_list = list(range(3, 80, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 1: + self.sub_val_list = list(range(2, 79, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 0: + self.sub_val_list = list(range(1, 78, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + else: + print('INFO: using COCO (PANet)') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_list = list(range(1, 61)) + self.sub_val_list = list(range(61, 81)) + elif self.split == 2: + self.sub_list = list(range(1, 41)) + list(range(61, 81)) + self.sub_val_list = list(range(41, 61)) + elif self.split == 1: + self.sub_list = list(range(1, 21)) + list(range(41, 81)) + self.sub_val_list = list(range(21, 41)) + elif self.split == 0: + self.sub_list = list(range(21, 81)) + self.sub_val_list = list(range(1, 21)) + + print('sub_list: ', self.sub_list) + print('sub_val_list: ', self.sub_val_list) + + self.sub_class_list_sup, self.data_list = make_GFSS_dataset(split, data_root, data_list, self.sub_list, self.sub_val_list) + assert len(self.sub_class_list_sup.keys()) == len(self.sub_val_list) + + self.transform = transform 
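+        # note: __getitem__ below applies self.transform jointly to the query
+        # (image, label_t) pair and to each support (image, label) pair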
+ self.transform_tri = transform_tri + + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + + # Choose a query image + image_path, label_path = self.data_list[index] + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = np.float32(image) + label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + label_t = label.copy() + label_t_tmp = label.copy() + + if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]: + raise (RuntimeError("Query Image & label shape mismatch: " + image_path + " " + label_path + "\n")) + + # Get the category information of the query image + label_class = np.unique(label).tolist() + if 0 in label_class: + label_class.remove(0) + if 255 in label_class: + label_class.remove(255) + label_class_novel = [] + label_class_base = [] + for c in label_class: + if c in self.sub_val_list: + label_class_novel.append(c) + else: + label_class_base.append(c) + + # Choose the category of this episode + if len(label_class_base) == 0: + class_chosen = random.choice(label_class_novel) # rule out the possibility that the image contains only "background" + else: + class_chosen = random.choice(self.sub_val_list) + + # Generate new annotations + for cls in range(1,self.num_classes+1): + select_pix = np.where(label_t_tmp == cls) + if cls in self.sub_list: + label_t[select_pix[0],select_pix[1]] = self.sub_list.index(cls) + 1 + elif cls == class_chosen: + label_t[select_pix[0],select_pix[1]] = self.num_classes*3/4 + 1 + else: + label_t[select_pix[0],select_pix[1]] = 0 + + # Sample K-shot images + file_class_chosen = self.sub_class_list_sup[class_chosen] + num_file = len(file_class_chosen) + + support_image_path_list = [] + support_label_path_list = [] + support_idx_list = [] + for k in range(self.shot): + support_idx = random.randint(1,num_file)-1 + support_image_path = image_path + support_label_path = label_path + while((support_image_path == image_path and support_label_path == label_path) or support_idx in support_idx_list): + support_idx = random.randint(1,num_file)-1 + support_image_path, support_label_path = file_class_chosen[support_idx] + support_idx_list.append(support_idx) + support_image_path_list.append(support_image_path) + support_label_path_list.append(support_label_path) + + support_image_list_ori = [] + support_label_list_ori = [] + support_label_list_ori_mask = [] + subcls_list = [] + subcls_list.append(self.sub_val_list.index(class_chosen)) + for k in range(self.shot): + support_image_path = support_image_path_list[k] + support_label_path = support_label_path_list[k] + support_image = cv2.imread(support_image_path, cv2.IMREAD_COLOR) + support_image = cv2.cvtColor(support_image, cv2.COLOR_BGR2RGB) + support_image = np.float32(support_image) + support_label = cv2.imread(support_label_path, cv2.IMREAD_GRAYSCALE) + target_pix = np.where(support_label == class_chosen) + ignore_pix = np.where(support_label == 255) + support_label[:,:] = 0 + support_label[target_pix[0],target_pix[1]] = 1 + + support_label, support_label_mask = transform_anns(support_label, self.ann_type) + support_label[ignore_pix[0],ignore_pix[1]] = 255 + support_label_mask[ignore_pix[0],ignore_pix[1]] = 255 + if support_image.shape[0] != support_label.shape[0] or support_image.shape[1] != support_label.shape[1]: + raise (RuntimeError("Support Image & label shape mismatch: " + support_image_path + " " + support_label_path + "\n")) + support_image_list_ori.append(support_image) + 
support_label_list_ori.append(support_label) + support_label_list_ori_mask.append(support_label_mask) + assert len(support_label_list_ori) == self.shot and len(support_image_list_ori) == self.shot + + # Transform + raw_image = image.copy() + raw_label_t = label_t.copy() + support_image_list = [[] for _ in range(self.shot)] + support_label_list = [[] for _ in range(self.shot)] + if self.transform is not None: + image, label_t = self.transform(image, label_t) + for k in range(self.shot): + support_image_list[k], support_label_list[k] = self.transform(support_image_list_ori[k], support_label_list_ori[k]) + + s_xs = support_image_list + s_ys = support_label_list + s_x = s_xs[0].unsqueeze(0) + for i in range(1, self.shot): + s_x = torch.cat([s_xs[i].unsqueeze(0), s_x], 0) + s_y = s_ys[0].unsqueeze(0) + for i in range(1, self.shot): + s_y = torch.cat([s_ys[i].unsqueeze(0), s_y], 0) + + # Return + if self.mode == 'val': + return image, label_t, s_x, s_y, subcls_list, raw_label_t + elif self.mode == 'demo': + total_image_list = support_image_list_ori.copy() + total_image_list.append(raw_image) + return image, label_t, s_x, s_y, subcls_list, total_image_list, support_label_list_ori, support_label_list_ori_mask, raw_label_t + + + +# -------------------------- Pre-Training -------------------------- + +class BaseData(Dataset): + def __init__(self, split=3, mode=None, data_root=None, data_list=None, data_set=None, use_split_coco=False, transform=None, main_process=False, \ + batch_size=None): + + assert data_set in ['pascal', 'coco'] + assert mode in ['train', 'val'] + + if data_set == 'pascal': + self.num_classes = 20 + elif data_set == 'coco': + self.num_classes = 80 + + self.mode = mode + self.split = split + self.data_root = data_root + self.batch_size = batch_size + + if data_set == 'pascal': + self.class_list = list(range(1, 21)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + if self.split == 3: + self.sub_list = list(range(1, 16)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + self.sub_val_list = list(range(16, 21)) # [16,17,18,19,20] + elif self.split == 2: + self.sub_list = list(range(1, 11)) + list(range(16, 21)) # [1,2,3,4,5,6,7,8,9,10,16,17,18,19,20] + self.sub_val_list = list(range(11, 16)) # [11,12,13,14,15] + elif self.split == 1: + self.sub_list = list(range(1, 6)) + list(range(11, 21)) # [1,2,3,4,5,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(6, 11)) # [6,7,8,9,10] + elif self.split == 0: + self.sub_list = list(range(6, 21)) # [6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(1, 6)) # [1,2,3,4,5] + + elif data_set == 'coco': + if use_split_coco: + print('INFO: using SPLIT COCO (FWB)') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_val_list = list(range(4, 81, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 2: + self.sub_val_list = list(range(3, 80, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 1: + self.sub_val_list = list(range(2, 79, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 0: + self.sub_val_list = list(range(1, 78, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + else: + print('INFO: using COCO (PANet)') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_list = list(range(1, 61)) + self.sub_val_list = list(range(61, 81)) + elif self.split == 2: + self.sub_list = list(range(1, 41)) + list(range(61, 81)) + 
self.sub_val_list = list(range(41, 61)) + elif self.split == 1: + self.sub_list = list(range(1, 21)) + list(range(41, 81)) + self.sub_val_list = list(range(21, 41)) + elif self.split == 0: + self.sub_list = list(range(21, 81)) + self.sub_val_list = list(range(1, 21)) + + print('sub_list: ', self.sub_list) + print('sub_val_list: ', self.sub_val_list) + + self.data_list = [] + list_read = open(data_list).readlines() + print("Processing data...") + + for l_idx in tqdm(range(len(list_read))): + line = list_read[l_idx] + line = line.strip() + line_split = line.split(' ') + image_name = os.path.join(self.data_root, line_split[0]) + label_name = os.path.join(self.data_root, line_split[1]) + item = (image_name, label_name) + self.data_list.append(item) + + self.transform = transform + + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + image_path, label_path = self.data_list[index] + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = np.float32(image) + label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + label_tmp = label.copy() + + for cls in range(1, self.num_classes+1): + select_pix = np.where(label_tmp == cls) + if cls in self.sub_list: + label[select_pix[0],select_pix[1]] = self.sub_list.index(cls) + 1 + else: + label[select_pix[0],select_pix[1]] = 0 + + raw_label = label.copy() + + if self.transform is not None: + image, label = self.transform(image, label) + + # Return + if self.mode == 'val' and self.batch_size == 1: + return image, label, raw_label + else: + return image, label + \ No newline at end of file diff --git a/util/get_mulway_base_data.py b/util/get_mulway_base_data.py new file mode 100644 index 0000000..510f734 --- /dev/null +++ b/util/get_mulway_base_data.py @@ -0,0 +1,88 @@ +import cv2 +import numpy as np +import argparse +import os.path as osp +from tqdm import tqdm +from util import get_train_val_set, check_makedirs + +# Get the annotations of base categories + +# root_path +# ├── BAM/ +# │ ├── util/ +# │ ├── config/ +# │ ├── model/ +# │ ├── README.md +# │ ├── train.py +# │ ├── train_base.py +# │ └── test.py +# └── data/ +# ├── base_annotation/ # the scripts to create THIS folder +# │ ├── pascal/ +# │ │ ├── train/ +# │ │ │ ├── 0/ # annotations of PASCAL-5^0 +# │ │ │ ├── 1/ +# │ │ │ ├── 2/ +# │ │ │ └── 3/ +# │ │ └── val/ +# │ └── coco/ # the same file structure for COCO +# ├── VOCdevkit2012/ +# └── MSCOCO2014/ + +parser = argparse.ArgumentParser() +args = parser.parse_args() + +args.data_set = 'pascal' # pascal coco +args.use_split_coco = True +args.mode = 'train' # train val +args.split = 0 # 0 1 2 3 +if args.data_set == 'pascal': + num_classes = 20 +elif args.data_set == 'coco': + num_classes = 80 + +root_path = '/disk2/lcb/study/FS_Seg/' +data_path = osp.join(root_path, 'data/base_annotation/') +save_path = osp.join(data_path, args.data_set, args.mode, str(args.split)) +check_makedirs(save_path) + +# get class list +sub_list, sub_val_list = get_train_val_set(args) + +# get data_list +fss_list_root = root_path + '/BAM/lists/{}/fss_list/{}/'.format(args.data_set, args.mode) +fss_data_list_path = fss_list_root + 'data_list_{}.txt'.format(args.split) +with open(fss_data_list_path, 'r') as f: + f_str = f.readlines() +data_list = [] +for line in f_str: + img, mask = line.split(' ') + data_list.append((img, mask.strip())) + +# Start Processing +for index in tqdm(range(len(data_list))): + image_path, label_path = data_list[index] + image_path, label_path = root_path + image_path[3:], 
root_path+ label_path[3:] # + label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + label_tmp = label.copy() + + for cls in range(1,num_classes+1): + select_pix = np.where(label_tmp == cls) + if cls in sub_list: + label[select_pix[0],select_pix[1]] = sub_list.index(cls) + 1 + else: + label[select_pix[0],select_pix[1]] = 0 + + # for pix in np.nditer(label, op_flags=['readwrite']): + # if pix == 255: + # pass + # elif pix not in sub_list: + # pix[...] = 0 + # else: + # pix[...] = sub_list.index(pix) + 1 + + save_item_path = osp.join(save_path, label_path.split('/')[-1]) + cv2.imwrite(save_item_path, label) + + +print('end') \ No newline at end of file diff --git a/util/get_weak_anns.py b/util/get_weak_anns.py new file mode 100644 index 0000000..7102cd7 --- /dev/null +++ b/util/get_weak_anns.py @@ -0,0 +1,41 @@ +from __future__ import absolute_import, division + +import networkx as nx +import numpy as np +from scipy.ndimage import binary_dilation, binary_erosion, maximum_filter +from scipy.special import comb +from skimage.filters import rank +from skimage.morphology import dilation, disk, erosion, medial_axis +from sklearn.neighbors import radius_neighbors_graph +import cv2 +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from scipy import ndimage + +def find_bbox(mask): + _, labels, stats, centroids = cv2.connectedComponentsWithStats(mask.astype(np.uint8)) + return stats[1:] # remove bg stat + +def transform_anns(mask, ann_type): + mask_ori = mask.copy() + + if ann_type == 'bbox': + bboxs = find_bbox(mask) + for j in bboxs: + cv2.rectangle(mask, (j[0], j[1]), (j[0] + j[2], j[1] + j[3]), 1, -1) # -1->fill; 2->draw_rec + return mask, mask_ori + + elif ann_type == 'mask': + return mask, mask_ori + + +if __name__ == '__main__': + label_path = '2008_001227.png' + mask = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + bboxs = find_bbox(mask) + mask_color = cv2.imread(label_path, cv2.IMREAD_COLOR) + for j in bboxs: + cv2.rectangle(mask_color, (j[0], j[1]), (j[0] + j[2], j[1] + j[3]), (0,255,0), -1) + cv2.imwrite('bbox.png', mask_color) + + print('done') \ No newline at end of file diff --git a/util/transform.py b/util/transform.py new file mode 100644 index 0000000..e5ff298 --- /dev/null +++ b/util/transform.py @@ -0,0 +1,424 @@ +import random +import math +import numpy as np +import numbers +import collections.abc +import collections + +import cv2 + +import torch + +manual_seed = 123 +torch.manual_seed(manual_seed) +np.random.seed(manual_seed) +torch.manual_seed(manual_seed) +torch.cuda.manual_seed_all(manual_seed) +random.seed(manual_seed) + +class Compose(object): + # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()]) + def __init__(self, segtransform): + self.segtransform = segtransform + + def __call__(self, image, label): + for t in self.segtransform: + image, label = t(image, label) + return image, label + +import time +class ToTensor(object): + # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 
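+    # A minimal usage sketch with assumed shapes (img/msk are hypothetical names):
+    #   img = np.zeros((480, 640, 3), dtype=np.float32)  # H x W x C, e.g. from cv2.imread
+    #   msk = np.zeros((480, 640), dtype=np.uint8)       # H x W label map
+    #   img_t, msk_t = ToTensor()(img, msk)              # -> FloatTensor (3, 480, 640), LongTensor (480, 640)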
+ def __call__(self, image, label): + if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray): + raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray" + "[eg: data readed by cv2.imread()].\n")) + if len(image.shape) > 3 or len(image.shape) < 2: + raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray with 3 dims or 2 dims.\n")) + if len(image.shape) == 2: + image = np.expand_dims(image, axis=2) + if not len(label.shape) == 2: + raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray labellabel with 2 dims.\n")) + + image = torch.from_numpy(image.transpose((2, 0, 1))) + if not isinstance(image, torch.FloatTensor): + image = image.float() + label = torch.from_numpy(label) + if not isinstance(label, torch.LongTensor): + label = label.long() + return image, label + +class ToNumpy(object): + # Converts torch.FloatTensor of shape (C x H x W) to a numpy.ndarray (H x W x C). + def __call__(self, image, label): + if not isinstance(image, torch.Tensor) or not isinstance(label, torch.Tensor): + raise (RuntimeError("segtransform.ToNumpy() only handle torch.tensor")) + + image = image.cpu().numpy().transpose((1, 2, 0)) + if not image.dtype == np.uint8: + image = image.astype(np.uint8) + label = label.cpu().numpy().transpose((1, 2, 0)) + if not label.dtype == np.uint8: + label = label.astype(np.uint8) + return image, label + +class Normalize(object): + # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std + def __init__(self, mean, std=None): + if std is None: + assert len(mean) > 0 + else: + assert len(mean) == len(std) + self.mean = mean + self.std = std + + def __call__(self, image, label): + if self.std is None: + for t, m in zip(image, self.mean): + t.sub_(m) + else: + for t, m, s in zip(image, self.mean, self.std): + t.sub_(m).div_(s) + return image, label + +class UnNormalize(object): + # UnNormalize tensor with mean and standard deviation along channel: channel = (channel * std) + mean + def __init__(self, mean, std=None): + if std is None: + assert len(mean) > 0 + else: + assert len(mean) == len(std) + self.mean = mean + self.std = std + + def __call__(self, image, label): + if self.std is None: + for t, m in zip(image, self.mean): + t.add_(m) + else: + for t, m, s in zip(image, self.mean, self.std): + t.mul_(s).add_(m) + return image, label + + +class Resize(object): + # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 
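+    # Note: in this implementation `size` is effectively a single int. The longer image
+    # side is scaled to `size`, both dims are rounded down to multiples of 8, and the
+    # result is pasted into a size x size canvas (zeros for the image, 255 for the label).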
+ def __init__(self, size): + self.size = size + + def __call__(self, image, label): + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + def find_new_hw(ori_h, ori_w, test_size): + if ori_h >= ori_w: + ratio = test_size*1.0 / ori_h + new_h = test_size + new_w = int(ori_w * ratio) + elif ori_w > ori_h: + ratio = test_size*1.0 / ori_w + new_h = int(ori_h * ratio) + new_w = test_size + + if new_h % 8 != 0: + new_h = (int(new_h /8))*8 + else: + new_h = new_h + if new_w % 8 != 0: + new_w = (int(new_w /8))*8 + else: + new_w = new_w + return new_h, new_w + + test_size = self.size + new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) + #new_h, new_w = test_size, test_size + image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) + back_crop = np.zeros((test_size, test_size, 3)) + # back_crop[:,:,0] = mean[0] + # back_crop[:,:,1] = mean[1] + # back_crop[:,:,2] = mean[2] + back_crop[:new_h, :new_w, :] = image_crop + image = back_crop + + s_mask = label + new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) + #new_h, new_w = test_size, test_size + s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask = np.ones((test_size, test_size)) * 255 + back_crop_s_mask[:new_h, :new_w] = s_mask + label = back_crop_s_mask + + return image, label + + +class test_Resize(object): + # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). + def __init__(self, size): + self.size = size + + def __call__(self, image, label): + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + def find_new_hw(ori_h, ori_w, test_size): + if max(ori_h, ori_w) > test_size: + if ori_h >= ori_w: + ratio = test_size*1.0 / ori_h + new_h = test_size + new_w = int(ori_w * ratio) + elif ori_w > ori_h: + ratio = test_size*1.0 / ori_w + new_h = int(ori_h * ratio) + new_w = test_size + + if new_h % 8 != 0: + new_h = (int(new_h /8))*8 + else: + new_h = new_h + if new_w % 8 != 0: + new_w = (int(new_w /8))*8 + else: + new_w = new_w + return new_h, new_w + else: + return ori_h, ori_w + + test_size = self.size + new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) + if new_w != image.shape[0] or new_h != image.shape[1]: + image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) + else: + image_crop = image.copy() + back_crop = np.zeros((test_size, test_size, 3)) + back_crop[:new_h, :new_w, :] = image_crop + image = back_crop + + s_mask = label + new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) + if new_w != s_mask.shape[0] or new_h != s_mask.shape[1]: + s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask = np.ones((test_size, test_size)) * 255 + back_crop_s_mask[:new_h, :new_w] = s_mask + label = back_crop_s_mask + + return image, label + +class Direct_Resize(object): + # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 
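+    # Note: unlike Resize/test_Resize above, this warps image and label directly to
+    # (size, size) without preserving aspect ratio or padding; `size` is a single int.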
+ def __init__(self, size): + self.size = size + + def __call__(self, image, label): + + test_size = self.size + + image = cv2.resize(image, dsize=(test_size, test_size), interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label.astype(np.float32), dsize=(test_size, test_size),interpolation=cv2.INTER_NEAREST) + + return image, label + + +class RandScale(object): + # Randomly resize image & label with scale factor in [scale_min, scale_max] + def __init__(self, scale, aspect_ratio=None): + assert (isinstance(scale, collections.abc.Iterable) and len(scale) == 2) + if isinstance(scale, collections.abc.Iterable) and len(scale) == 2 \ + and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \ + and 0 < scale[0] < scale[1]: + self.scale = scale + else: + raise (RuntimeError("segtransform.RandScale() scale param error.\n")) + if aspect_ratio is None: + self.aspect_ratio = aspect_ratio + elif isinstance(aspect_ratio, collections.abc.Iterable) and len(aspect_ratio) == 2 \ + and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \ + and 0 < aspect_ratio[0] < aspect_ratio[1]: + self.aspect_ratio = aspect_ratio + else: + raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n")) + + def __call__(self, image, label): + temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random() + temp_aspect_ratio = 1.0 + if self.aspect_ratio is not None: + temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random() + temp_aspect_ratio = math.sqrt(temp_aspect_ratio) + scale_factor_x = temp_scale * temp_aspect_ratio + scale_factor_y = temp_scale / temp_aspect_ratio + image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) + return image, label + + +class Crop(object): + """Crops the given ndarray image (H*W*C or H*W). + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is made. 
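+        crop_type (str): 'center' or 'rand'; how the crop window is placed.
+        padding (list): per-channel pad values used when the input is smaller than the crop.
+        ignore_label (int): pad value for the label map (default 255).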
+ """ + def __init__(self, size, crop_type='center', padding=None, ignore_label=255): + self.size = size + if isinstance(size, int): + self.crop_h = size + self.crop_w = size + elif isinstance(size, collections.abc.Iterable) and len(size) == 2 \ + and isinstance(size[0], int) and isinstance(size[1], int) \ + and size[0] > 0 and size[1] > 0: + self.crop_h = size[0] + self.crop_w = size[1] + else: + raise (RuntimeError("crop size error.\n")) + if crop_type == 'center' or crop_type == 'rand': + self.crop_type = crop_type + else: + raise (RuntimeError("crop type error: rand | center\n")) + if padding is None: + self.padding = padding + elif isinstance(padding, list): + if all(isinstance(i, numbers.Number) for i in padding): + self.padding = padding + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if len(padding) != 3: + raise (RuntimeError("padding channel is not equal with 3\n")) + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if isinstance(ignore_label, int): + self.ignore_label = ignore_label + else: + raise (RuntimeError("ignore_label should be an integer number\n")) + + def __call__(self, image, label): + h, w = label.shape + + + pad_h = max(self.crop_h - h, 0) + pad_w = max(self.crop_w - w, 0) + pad_h_half = int(pad_h / 2) + pad_w_half = int(pad_w / 2) + if pad_h > 0 or pad_w > 0: + if self.padding is None: + raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n")) + image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding) + label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label) + h, w = label.shape + raw_label = label + raw_image = image + + if self.crop_type == 'rand': + h_off = random.randint(0, h - self.crop_h) + w_off = random.randint(0, w - self.crop_w) + else: + h_off = int((h - self.crop_h) / 2) + w_off = int((w - self.crop_w) / 2) + image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + raw_pos_num = np.sum(raw_label == 1) + pos_num = np.sum(label == 1) + crop_cnt = 0 + while(pos_num < 0.85*raw_pos_num and crop_cnt<=30): + image = raw_image + label = raw_label + if self.crop_type == 'rand': + h_off = random.randint(0, h - self.crop_h) + w_off = random.randint(0, w - self.crop_w) + else: + h_off = int((h - self.crop_h) / 2) + w_off = int((w - self.crop_w) / 2) + image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + raw_pos_num = np.sum(raw_label == 1) + pos_num = np.sum(label == 1) + crop_cnt += 1 + if crop_cnt >= 50: + image = cv2.resize(raw_image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR) + label = cv2.resize(raw_label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) + + if image.shape != (self.size[0], self.size[0], 3): + image = cv2.resize(image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) + + return image, label + + +class RandRotate(object): + # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max] + def __init__(self, rotate, padding, ignore_label=255, p=0.5): + assert (isinstance(rotate, collections.abc.Iterable) and len(rotate) == 2) + if isinstance(rotate[0], numbers.Number) and 
isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]: + self.rotate = rotate + else: + raise (RuntimeError("segtransform.RandRotate() scale param error.\n")) + assert padding is not None + assert isinstance(padding, list) and len(padding) == 3 + if all(isinstance(i, numbers.Number) for i in padding): + self.padding = padding + else: + raise (RuntimeError("padding in RandRotate() should be a number list\n")) + assert isinstance(ignore_label, int) + self.ignore_label = ignore_label + self.p = p + + def __call__(self, image, label): + if random.random() < self.p: + angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random() + h, w = label.shape + matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=self.padding) + label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=self.ignore_label) + return image, label + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, image, label): + if random.random() < self.p: + image = cv2.flip(image, 1) + label = cv2.flip(label, 1) + return image, label + + +class RandomVerticalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, image, label): + if random.random() < self.p: + image = cv2.flip(image, 0) + label = cv2.flip(label, 0) + return image, label + + +class RandomGaussianBlur(object): + def __init__(self, radius=5): + self.radius = radius + + def __call__(self, image, label): + if random.random() < 0.5: + image = cv2.GaussianBlur(image, (self.radius, self.radius), 0) + return image, label + + +class RGB2BGR(object): + # Converts image from RGB order to BGR order, for model initialized from Caffe + def __call__(self, image, label): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + return image, label + + +class BGR2RGB(object): + # Converts image from BGR order to RGB order, for model initialized from Pytorch + def __call__(self, image, label): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image, label \ No newline at end of file diff --git a/util/transform_tri.py b/util/transform_tri.py new file mode 100644 index 0000000..f92afda --- /dev/null +++ b/util/transform_tri.py @@ -0,0 +1,452 @@ +import random +import math +import numpy as np +import numbers +import collections +import cv2 + +import torch + +manual_seed = 123 +torch.manual_seed(manual_seed) +np.random.seed(manual_seed) +torch.manual_seed(manual_seed) +torch.cuda.manual_seed_all(manual_seed) +random.seed(manual_seed) + +class Compose(object): + # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()]) + def __init__(self, segtransform): + self.segtransform = segtransform + + def __call__(self, image, label, label2): + for t in self.segtransform: + image, label, label2 = t(image, label, label2) + return image, label, label2 + +import time +class ToTensor(object): + # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 
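+    # A minimal usage sketch with assumed shapes (img/msk/msk_b are hypothetical names):
+    #   img = np.zeros((480, 640, 3), dtype=np.float32)   # H x W x C image
+    #   msk = np.zeros((480, 640), dtype=np.uint8)        # episode label
+    #   msk_b = np.zeros((480, 640), dtype=np.uint8)      # base-class label
+    #   img_t, msk_t, msk_b_t = ToTensor()(img, msk, msk_b)  # -> float (3, H, W), long (H, W), long (H, W)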
+ def __call__(self, image, label, label2): + if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray): + raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray" + "[eg: data readed by cv2.imread()].\n")) + if len(image.shape) > 3 or len(image.shape) < 2: + raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray with 3 dims or 2 dims.\n")) + if len(image.shape) == 2: + image = np.expand_dims(image, axis=2) + if not len(label.shape) == 2: + raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray labellabel with 2 dims.\n")) + + image = torch.from_numpy(image.transpose((2, 0, 1))) + if not isinstance(image, torch.FloatTensor): + image = image.float() + label = torch.from_numpy(label) + label2 = torch.from_numpy(label2) + if not isinstance(label, torch.LongTensor): + label = label.long() + label2 = label2.long() + return image, label, label2 + +class ToNumpy(object): + # Converts torch.FloatTensor of shape (C x H x W) to a numpy.ndarray (H x W x C). + def __call__(self, image, label, label2): + if not isinstance(image, torch.Tensor) or not isinstance(label, torch.Tensor): + raise (RuntimeError("segtransform.ToNumpy() only handle torch.tensor")) + + image = image.cpu().numpy().transpose((1, 2, 0)) + if not image.dtype == np.uint8: + image = image.astype(np.uint8) + label = label.cpu().numpy().transpose((1, 2, 0)) + label2 = label2.cpu().numpy().transpose((1, 2, 0)) + if not label.dtype == np.uint8: + label = label.astype(np.uint8) + label2 = label2.astype(np.uint8) + return image, label, label2 + +class Normalize(object): + # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std + def __init__(self, mean, std=None): + if std is None: + assert len(mean) > 0 + else: + assert len(mean) == len(std) + self.mean = mean + self.std = std + + def __call__(self, image, label, label2): + if self.std is None: + for t, m in zip(image, self.mean): + t.sub_(m) + else: + for t, m, s in zip(image, self.mean, self.std): + t.sub_(m).div_(s) + return image, label, label2 + +class UnNormalize(object): + # UnNormalize tensor with mean and standard deviation along channel: channel = (channel * std) + mean + def __init__(self, mean, std=None): + if std is None: + assert len(mean) > 0 + else: + assert len(mean) == len(std) + self.mean = mean + self.std = std + + def __call__(self, image, label, label2): + if self.std is None: + for t, m in zip(image, self.mean): + t.add_(m) + else: + for t, m, s in zip(image, self.mean, self.std): + t.mul_(s).add_(m) + return image, label, label2 + + +class Resize(object): + # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 
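+    # Note: `size` is effectively a single int here; the longer side is scaled to `size`,
+    # dims are rounded down to multiples of 8, and image/label/label2 are each pasted into
+    # a size x size canvas (zeros for the image, 255 for both labels).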
+ def __init__(self, size): + self.size = size + + def __call__(self, image, label, label2): + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + def find_new_hw(ori_h, ori_w, test_size): + if ori_h >= ori_w: + ratio = test_size*1.0 / ori_h + new_h = test_size + new_w = int(ori_w * ratio) + elif ori_w > ori_h: + ratio = test_size*1.0 / ori_w + new_h = int(ori_h * ratio) + new_w = test_size + + if new_h % 8 != 0: + new_h = (int(new_h /8))*8 + else: + new_h = new_h + if new_w % 8 != 0: + new_w = (int(new_w /8))*8 + else: + new_w = new_w + return new_h, new_w + + test_size = self.size + new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) + #new_h, new_w = test_size, test_size + image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) + back_crop = np.zeros((test_size, test_size, 3)) + # back_crop[:,:,0] = mean[0] + # back_crop[:,:,1] = mean[1] + # back_crop[:,:,2] = mean[2] + back_crop[:new_h, :new_w, :] = image_crop + image = back_crop + + s_mask = label + new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) + s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask = np.ones((test_size, test_size)) * 255 + back_crop_s_mask[:new_h, :new_w] = s_mask + label = back_crop_s_mask + + s_mask2 = label2 + new_h2, new_w2 = find_new_hw(s_mask2.shape[0], s_mask2.shape[1], test_size) + s_mask2 = cv2.resize(s_mask2.astype(np.float32), dsize=(int(new_w2), int(new_h2)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask2 = np.ones((test_size, test_size)) * 255 + back_crop_s_mask2[:new_h2, :new_w2] = s_mask2 + label2 = back_crop_s_mask2 + + return image, label, label2 + + +class test_Resize(object): + # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 
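+    # Note: inputs are only downscaled when their longer side exceeds `size` (a single int);
+    # smaller inputs keep their resolution and are padded into the size x size canvas.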
+ def __init__(self, size): + self.size = size + + def __call__(self, image, label, label2): + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + def find_new_hw(ori_h, ori_w, test_size): + if max(ori_h, ori_w) > test_size: + if ori_h >= ori_w: + ratio = test_size*1.0 / ori_h + new_h = test_size + new_w = int(ori_w * ratio) + elif ori_w > ori_h: + ratio = test_size*1.0 / ori_w + new_h = int(ori_h * ratio) + new_w = test_size + + if new_h % 8 != 0: + new_h = (int(new_h /8))*8 + else: + new_h = new_h + if new_w % 8 != 0: + new_w = (int(new_w /8))*8 + else: + new_w = new_w + return new_h, new_w + else: + return ori_h, ori_w + + test_size = self.size + new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) + if new_w != image.shape[0] or new_h != image.shape[1]: + image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) + else: + image_crop = image.copy() + back_crop = np.zeros((test_size, test_size, 3)) + back_crop[:new_h, :new_w, :] = image_crop + image = back_crop + + s_mask = label + new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) + if new_w != s_mask.shape[0] or new_h != s_mask.shape[1]: + s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask = np.ones((test_size, test_size)) * 255 + back_crop_s_mask[:new_h, :new_w] = s_mask + label = back_crop_s_mask + + s_mask2 = label2 + new_h2, new_w2 = find_new_hw(s_mask2.shape[0], s_mask2.shape[1], test_size) + if new_w2 != s_mask.shape[0] or new_h2 != s_mask2.shape[1]: + s_mask2 = cv2.resize(s_mask2.astype(np.float32), dsize=(int(new_w2), int(new_h2)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask2 = np.ones((test_size, test_size)) * 255 + back_crop_s_mask2[:new_h2, :new_w2] = s_mask2 + label2 = back_crop_s_mask2 + + return image, label, label2 + +class Direct_Resize(object): + # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 
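+    # Note: warps image, label and label2 directly to (size, size) without preserving
+    # aspect ratio or padding; `size` is a single int.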
+ def __init__(self, size): + self.size = size + + def __call__(self, image, label, label2): + + test_size = self.size + + image = cv2.resize(image, dsize=(test_size, test_size), interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label.astype(np.float32), dsize=(test_size, test_size),interpolation=cv2.INTER_NEAREST) + label2 = cv2.resize(label2.astype(np.float32), dsize=(test_size, test_size),interpolation=cv2.INTER_NEAREST) + + return image, label, label2 + + +class RandScale(object): + # Randomly resize image & label with scale factor in [scale_min, scale_max] + def __init__(self, scale, aspect_ratio=None): + assert (isinstance(scale, collections.Iterable) and len(scale) == 2) + if isinstance(scale, collections.Iterable) and len(scale) == 2 \ + and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \ + and 0 < scale[0] < scale[1]: + self.scale = scale + else: + raise (RuntimeError("segtransform.RandScale() scale param error.\n")) + if aspect_ratio is None: + self.aspect_ratio = aspect_ratio + elif isinstance(aspect_ratio, collections.Iterable) and len(aspect_ratio) == 2 \ + and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \ + and 0 < aspect_ratio[0] < aspect_ratio[1]: + self.aspect_ratio = aspect_ratio + else: + raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n")) + + def __call__(self, image, label, label2): + temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random() + temp_aspect_ratio = 1.0 + if self.aspect_ratio is not None: + temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random() + temp_aspect_ratio = math.sqrt(temp_aspect_ratio) + scale_factor_x = temp_scale * temp_aspect_ratio + scale_factor_y = temp_scale / temp_aspect_ratio + image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) + label2 = cv2.resize(label2, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) + return image, label, label2 + + +class Crop(object): + """Crops the given ndarray image (H*W*C or H*W). + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is made. 
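+        crop_type (str): 'center' or 'rand'; how the crop window is placed.
+        padding (list): per-channel pad values used when the input is smaller than the crop.
+        ignore_label (int): pad value for the label map (default 255).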
+    """
+    def __init__(self, size, crop_type='center', padding=None, ignore_label=255):
+        self.size = size
+        if isinstance(size, int):
+            self.crop_h = size
+            self.crop_w = size
+        elif isinstance(size, collections.Iterable) and len(size) == 2 \
+                and isinstance(size[0], int) and isinstance(size[1], int) \
+                and size[0] > 0 and size[1] > 0:
+            self.crop_h = size[0]
+            self.crop_w = size[1]
+        else:
+            raise (RuntimeError("crop size error.\n"))
+        if crop_type == 'center' or crop_type == 'rand':
+            self.crop_type = crop_type
+        else:
+            raise (RuntimeError("crop type error: rand | center\n"))
+        if padding is None:
+            self.padding = padding
+        elif isinstance(padding, list):
+            if all(isinstance(i, numbers.Number) for i in padding):
+                self.padding = padding
+            else:
+                raise (RuntimeError("padding in Crop() should be a number list\n"))
+            if len(padding) != 3:
+                raise (RuntimeError("padding channel is not equal with 3\n"))
+        else:
+            raise (RuntimeError("padding in Crop() should be a number list\n"))
+        if isinstance(ignore_label, int):
+            self.ignore_label = ignore_label
+        else:
+            raise (RuntimeError("ignore_label should be an integer number\n"))
+
+    def __call__(self, image, label, label2):
+        h, w = label.shape
+
+        # Pad image and labels up to the crop size if the input is smaller.
+        pad_h = max(self.crop_h - h, 0)
+        pad_w = max(self.crop_w - w, 0)
+        pad_h_half = int(pad_h / 2)
+        pad_w_half = int(pad_w / 2)
+        if pad_h > 0 or pad_w > 0:
+            if self.padding is None:
+                raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n"))
+            image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding)
+            label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+            label2 = cv2.copyMakeBorder(label2, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+        h, w = label.shape
+        raw_label = label
+        raw_label2 = label2
+        raw_image = image
+
+        if self.crop_type == 'rand':
+            h_off = random.randint(0, h - self.crop_h)
+            w_off = random.randint(0, w - self.crop_w)
+        else:
+            h_off = int((h - self.crop_h) / 2)
+            w_off = int((w - self.crop_w) / 2)
+        image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+        label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+        label2 = label2[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+        raw_pos_num = np.sum(raw_label == 1)
+        pos_num = np.sum(label == 1)
+        crop_cnt = 0
+        # Re-sample the crop window until it keeps at least 85% of the foreground
+        # pixels, with a bounded number of retries.
+        while pos_num < 0.85 * raw_pos_num and crop_cnt <= 30:
+            image = raw_image
+            label = raw_label
+            label2 = raw_label2
+            if self.crop_type == 'rand':
+                h_off = random.randint(0, h - self.crop_h)
+                w_off = random.randint(0, w - self.crop_w)
+            else:
+                h_off = int((h - self.crop_h) / 2)
+                w_off = int((w - self.crop_w) / 2)
+            image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+            label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+            label2 = label2[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+            pos_num = np.sum(label == 1)
+            crop_cnt += 1
+        # If no acceptable crop was found within the retry budget, fall back to resizing.
+        if pos_num < 0.85 * raw_pos_num:
+            image = cv2.resize(raw_image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR)
+            label = cv2.resize(raw_label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST)
+            label2 = cv2.resize(raw_label2, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST)
+
+        if image.shape != (self.size[0], self.size[0], 3):
+            image = cv2.resize(image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR)
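+            # label maps are resized with nearest-neighbour interpolation so class ids are never blended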
+ label = cv2.resize(label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) + label2 = cv2.resize(label2, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) + + return image, label, label2 + + +class RandRotate(object): + # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max] + def __init__(self, rotate, padding, ignore_label=255, p=0.5): + assert (isinstance(rotate, collections.Iterable) and len(rotate) == 2) + if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]: + self.rotate = rotate + else: + raise (RuntimeError("segtransform.RandRotate() scale param error.\n")) + assert padding is not None + assert isinstance(padding, list) and len(padding) == 3 + if all(isinstance(i, numbers.Number) for i in padding): + self.padding = padding + else: + raise (RuntimeError("padding in RandRotate() should be a number list\n")) + assert isinstance(ignore_label, int) + self.ignore_label = ignore_label + self.p = p + + def __call__(self, image, label, label2): + if random.random() < self.p: + angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random() + h, w = label.shape + matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=self.padding) + label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=self.ignore_label) + label2 = cv2.warpAffine(label2, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=self.ignore_label) + return image, label, label2 + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, image, label, label2): + if random.random() < self.p: + image = cv2.flip(image, 1) + label = cv2.flip(label, 1) + label2 = cv2.flip(label2, 1) + return image, label, label2 + + +class RandomVerticalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, image, label, label2): + if random.random() < self.p: + image = cv2.flip(image, 0) + label = cv2.flip(label, 0) + label2 = cv2.flip(label2, 0) + return image, label, label2 + + +class RandomGaussianBlur(object): + def __init__(self, radius=5): + self.radius = radius + + def __call__(self, image, label, label2): + if random.random() < 0.5: + image = cv2.GaussianBlur(image, (self.radius, self.radius), 0) + return image, label, label2 + + +class RGB2BGR(object): + # Converts image from RGB order to BGR order, for model initialized from Caffe + def __call__(self, image, label, label2): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + return image, label, label2 + + +class BGR2RGB(object): + # Converts image from BGR order to RGB order, for model initialized from Pytorch + def __call__(self, image, label, label2): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image, label, label2 \ No newline at end of file diff --git a/util/util.py b/util/util.py new file mode 100644 index 0000000..dc9f107 --- /dev/null +++ b/util/util.py @@ -0,0 +1,281 @@ +import os +import numpy as np +from PIL import Image +import random +import logging +import cv2 +import matplotlib.pyplot as plt +from matplotlib.pyplot import MultipleLocator +from matplotlib.ticker import FuncFormatter, FormatStrFormatter +from matplotlib import font_manager +from matplotlib import rcParams +import seaborn as sns +import pandas as pd +import math +from seaborn.distributions import distplot 
+from tqdm import tqdm +from scipy import ndimage + +from util.get_weak_anns import find_bbox + +import torch +from torch import nn +import torch.backends.cudnn as cudnn +import torch.nn.init as initer + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def step_learning_rate(optimizer, base_lr, epoch, step_epoch, multiplier=0.1): + """Sets the learning rate to the base LR decayed by 10 every step epochs""" + lr = base_lr * (multiplier ** (epoch // step_epoch)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9, index_split=-1, scale_lr=10., warmup=False, warmup_step=500): + """poly learning rate policy""" + if warmup and curr_iter < warmup_step: + lr = base_lr * (0.1 + 0.9 * (curr_iter/warmup_step)) + else: + lr = base_lr * (1 - float(curr_iter) / max_iter) ** power + + # if curr_iter % 50 == 0: + # print('Base LR: {:.4f}, Curr LR: {:.4f}, Warmup: {}.'.format(base_lr, lr, (warmup and curr_iter < warmup_step))) + + for index, param_group in enumerate(optimizer.param_groups): + if index <= index_split: + param_group['lr'] = lr + else: + param_group['lr'] = lr * scale_lr # 10x LR + + +def intersectionAndUnion(output, target, K, ignore_index=255): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. + assert (output.ndim in [1, 2, 3]) + assert output.shape == target.shape + output = output.reshape(output.size).copy() + target = target.reshape(target.size) + output[np.where(target == ignore_index)[0]] = ignore_index + intersection = output[np.where(output == target)[0]] + area_intersection, _ = np.histogram(intersection, bins=np.arange(K+1)) + area_output, _ = np.histogram(output, bins=np.arange(K+1)) + area_target, _ = np.histogram(target, bins=np.arange(K+1)) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def intersectionAndUnionGPU(output, target, K, ignore_index=255): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
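+    # Both tensors are flattened, positions equal to ignore_index are excluded from the
+    # intersection, and per-class intersection / union / target areas are returned as
+    # 1-D tensors of length K. Note that `output` is modified in place.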
+ assert (output.dim() in [1, 2, 3]) + assert output.shape == target.shape + output = output.view(-1) + target = target.view(-1) + output[target == ignore_index] = ignore_index + intersection = output[output == target] + area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1) + area_output = torch.histc(output, bins=K, min=0, max=K-1) + area_target = torch.histc(target, bins=K, min=0, max=K-1) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + +def check_mkdir(dir_name): + if not os.path.exists(dir_name): + os.mkdir(dir_name) + +def check_makedirs(dir_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name) + +def del_file(path): + for i in os.listdir(path): + path_file = os.path.join(path,i) + if os.path.isfile(path_file): + os.remove(path_file) + else: + del_file(path_file) + +def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'): + """ + :param model: Pytorch Model which is nn.Module + :param conv: 'kaiming' or 'xavier' + :param batchnorm: 'normal' or 'constant' + :param linear: 'kaiming' or 'xavier' + :param lstm: 'kaiming' or 'xavier' + """ + for m in model.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + if conv == 'kaiming': + initer.kaiming_normal_(m.weight) + elif conv == 'xavier': + initer.xavier_normal_(m.weight) + else: + raise ValueError("init type of conv error.\n") + if m.bias is not None: + initer.constant_(m.bias, 0) + + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):#, BatchNorm1d, BatchNorm2d, BatchNorm3d)): + if batchnorm == 'normal': + initer.normal_(m.weight, 1.0, 0.02) + elif batchnorm == 'constant': + initer.constant_(m.weight, 1.0) + else: + raise ValueError("init type of batchnorm error.\n") + initer.constant_(m.bias, 0.0) + + elif isinstance(m, nn.Linear): + if linear == 'kaiming': + initer.kaiming_normal_(m.weight) + elif linear == 'xavier': + initer.xavier_normal_(m.weight) + else: + raise ValueError("init type of linear error.\n") + if m.bias is not None: + initer.constant_(m.bias, 0) + + elif isinstance(m, nn.LSTM): + for name, param in m.named_parameters(): + if 'weight' in name: + if lstm == 'kaiming': + initer.kaiming_normal_(param) + elif lstm == 'xavier': + initer.xavier_normal_(param) + else: + raise ValueError("init type of lstm error.\n") + elif 'bias' in name: + initer.constant_(param, 0) + +def colorize(gray, palette): + # gray: numpy array of the label and 1*3N size list palette + color = Image.fromarray(gray.astype(np.uint8)).convert('P') + color.putpalette(palette) + return color + + +# ------------------------------------------------------ +def get_model_para_number(model): + total_number = 0 + learnable_number = 0 + for para in model.parameters(): + total_number += torch.numel(para) + if para.requires_grad == True: + learnable_number+= torch.numel(para) + return total_number, learnable_number + +def setup_seed(seed=2021, deterministic=False): + if deterministic: + cudnn.benchmark = False + cudnn.deterministic = True + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + +def get_logger(): + logger_name = "main-logger" + logger = logging.getLogger() + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + handler.setFormatter(logging.Formatter(fmt)) + 
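+    # the handler is attached to the root logger, so records from module-level loggers
+    # propagate here and share this format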
logger.addHandler(handler) + return logger + +def get_save_path(args): + backbone_str = 'vgg' if args.vgg else 'resnet'+str(args.layers) + args.snapshot_path = 'exp/{}/{}/split{}/{}/snapshot'.format(args.data_set, args.arch, args.split, backbone_str) + args.result_path = 'exp/{}/{}/split{}/{}/result'.format(args.data_set, args.arch, args.split, backbone_str) + +def get_train_val_set(args): + if args.data_set == 'pascal': + class_list = list(range(1, 21)) #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + if args.split == 3: + sub_list = list(range(1, 16)) #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + sub_val_list = list(range(16, 21)) #[16,17,18,19,20] + elif args.split == 2: + sub_list = list(range(1, 11)) + list(range(16, 21)) #[1,2,3,4,5,6,7,8,9,10,16,17,18,19,20] + sub_val_list = list(range(11, 16)) #[11,12,13,14,15] + elif args.split == 1: + sub_list = list(range(1, 6)) + list(range(11, 21)) #[1,2,3,4,5,11,12,13,14,15,16,17,18,19,20] + sub_val_list = list(range(6, 11)) #[6,7,8,9,10] + elif args.split == 0: + sub_list = list(range(6, 21)) #[6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + sub_val_list = list(range(1, 6)) #[1,2,3,4,5] + + elif args.data_set == 'coco': + if args.use_split_coco: + print('INFO: using SPLIT COCO (FWB)') + class_list = list(range(1, 81)) + if args.split == 3: + sub_val_list = list(range(4, 81, 4)) + sub_list = list(set(class_list) - set(sub_val_list)) + elif args.split == 2: + sub_val_list = list(range(3, 80, 4)) + sub_list = list(set(class_list) - set(sub_val_list)) + elif args.split == 1: + sub_val_list = list(range(2, 79, 4)) + sub_list = list(set(class_list) - set(sub_val_list)) + elif args.split == 0: + sub_val_list = list(range(1, 78, 4)) + sub_list = list(set(class_list) - set(sub_val_list)) + else: + print('INFO: using COCO (PANet)') + class_list = list(range(1, 81)) + if args.split == 3: + sub_list = list(range(1, 61)) + sub_val_list = list(range(61, 81)) + elif args.split == 2: + sub_list = list(range(1, 41)) + list(range(61, 81)) + sub_val_list = list(range(41, 61)) + elif args.split == 1: + sub_list = list(range(1, 21)) + list(range(41, 81)) + sub_val_list = list(range(21, 41)) + elif args.split == 0: + sub_list = list(range(21, 81)) + sub_val_list = list(range(1, 21)) + + return sub_list, sub_val_list + +def is_same_model(model1, model2): + flag = 0 + count = 0 + for k, v in model1.state_dict().items(): + model1_val = v + model2_val = model2.state_dict()[k] + if (model1_val==model2_val).all(): + pass + else: + flag+=1 + print('value of key <{}> mismatch'.format(k)) + count+=1 + + return True if flag==0 else False + +def fix_bn(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.eval() + +def sum_list(list): + sum = 0 + for item in list: + sum += item + return sum \ No newline at end of file
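The two sketches below are illustrative only and are not part of the committed files. They show one way the pieces added in this diff could be wired together; the module paths (`util.transform`, `util.util`), the file names, the 473-pixel crop size, and the ImageNet-style padding values are assumptions, and `model` / `val_loader` are placeholders.

```python
# Hypothetical training-time augmentation chain using the three-argument transforms above.
import cv2
import numpy as np
from util import transform  # adjust if the transform module lives elsewhere

mean = [item * 255 for item in [0.485, 0.456, 0.406]]  # assumed padding colour (ImageNet mean)
aug = [
    transform.RandScale([0.9, 1.1]),
    transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
    transform.RandomGaussianBlur(),
    transform.RandomHorizontalFlip(),
    transform.Crop([473, 473], crop_type='rand', padding=mean, ignore_label=255),
]

image = cv2.imread('image.jpg', cv2.IMREAD_COLOR).astype(np.float32)
label = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)   # per-pixel class ids, 255 = ignore
label2 = label.copy()                                   # second label map expected by these transforms
for t in aug:
    image, label, label2 = t(image, label, label2)
```

```python
# Hypothetical validation sketch: accumulating per-class IoU with the utilities above.
import torch
from util.util import AverageMeter, intersectionAndUnionGPU

def validate(model, val_loader, num_classes, ignore_index=255, device='cuda'):
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    model.eval()
    with torch.no_grad():
        for image, label in val_loader:
            image, label = image.to(device), label.to(device)
            pred = model(image).argmax(dim=1)  # N x H x W class indices
            # .float() because torch.histc (used internally) expects floating-point input
            inter, union, _ = intersectionAndUnionGPU(pred.float(), label.float(), num_classes, ignore_index)
            intersection_meter.update(inter.cpu().numpy())
            union_meter.update(union.cpu().numpy())
    iou_per_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    return float(iou_per_class.mean())
```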