Problem with mask predictions #22
Comments
You will have to define the classes with a 1:1 correspondence to Pascal VOC and set them as the active classes.
To evaluate on Pascal VOC you will have to implement a corresponding dataloader and evaluation function. A note: interestingly, the Pascal classes seem to be easier for the model, as one-shot performance on the Pascal classes in COCO is better than on the 4 symmetric splits we use in the paper. |
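As a rough sketch of the class setup described above: the index list is the one used by the training script posted later in this thread (1-based positions within the 80 COCO classes), while the helper name `split_voc_classes` is hypothetical and not part of the repository.

```python
# Indices of the 20 Pascal VOC categories within the 80 COCO classes,
# copied from the training script posted later in this thread.
PASCAL_CLASSES = [1, 2, 3, 4, 5, 6, 7, 9, 15, 16, 17, 18, 19, 20,
                  40, 57, 58, 59, 61, 63]


def split_voc_classes(pascal_classes, num_coco_classes=80):
    """Return (train_classes, eval_classes) with no overlap.

    Hypothetical helper: train on every COCO class that is not in
    Pascal VOC, and evaluate on the VOC classes only.
    """
    pascal = sorted(set(pascal_classes))
    train = [i for i in range(1, num_coco_classes + 1) if i not in pascal]
    return train, pascal


train_classes, eval_classes = split_voc_classes(PASCAL_CLASSES)
print(len(train_classes), len(eval_classes))
```

The training dataset would then presumably get `ACTIVE_CLASSES = np.array(train_classes)` and the evaluation dataset `ACTIVE_CLASSES = np.array(eval_classes)`, mirroring how the script in this thread assigns `coco_train.ACTIVE_CLASSES`.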
Hi @michaelisc! Thanks for your reply! Did you also report AP50 on Pascal VOC for a model trained on the non-VOC COCO classes in your paper? If not, what result did you get in your experiments? I am trying to reproduce this experiment and would like to know what a reasonable AP50 would be. Thanks a lot! |
We did not evaluate performance on Pascal VOC. We did measure performance on COCO using the VOC classes as a split but in the end decided to go with the four balanced splits reported in the paper because we think that this is the "cleaner" solution. |
Hi @michaelisc! Did you provide the learning curve of your model? I have tried to train on the non-VOC COCO classes and the performance is still low after 100 epochs. I used the same large config as this one to train the model. Thanks! |
@NCTU-VRDL the learning rate decay step after 120 epochs is very important for performance. I'd guess this is the problem but let me know if the performance is still low a few epochs after that. |
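To make the timing of that decay step concrete, here is a minimal sketch of the three-phase schedule used by the script in this thread. It assumes that `model.train(..., epochs=N)` treats `N` as the cumulative end epoch, as the Matterport Mask R-CNN `train` method does. The base learning rate of `0.01` is a placeholder rather than the repository's value, and `build_schedule` and `phase_for_epoch` are hypothetical helpers.

```python
from collections import OrderedDict


def build_schedule(base_lr):
    """Schedule keyed by the epoch at which each phase ends."""
    schedule = OrderedDict()
    schedule[1] = {"learning_rate": base_lr, "layers": "heads"}
    schedule[120] = {"learning_rate": base_lr, "layers": "all"}
    schedule[160] = {"learning_rate": base_lr / 10, "layers": "all"}
    return schedule


def phase_for_epoch(schedule, epoch):
    """Return the phase covering a 0-based epoch index, or None past the end."""
    for end_epoch, params in schedule.items():
        if epoch < end_epoch:
            return params
    return None


schedule = build_schedule(base_lr=0.01)  # placeholder learning rate
for epoch in (0, 100, 119, 120, 140):
    print(epoch, phase_for_epoch(schedule, epoch))
```

Under these assumptions, a run stopped at 100 epochs has not yet reached the reduced learning rate, while a run at 140 epochs is already 20 epochs into the decayed phase.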
The performance is still poor after 140 epochs of training; you can see that the training loss is still high in the learning curve below. You can find my train.py here, but it is almost the same as the notebook you provide. The only difference is that I use two V100s and 6 images_per_gpu to train this model. Please take a look at the code, many thanks!
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
sess_config = tf.ConfigProto()
import sys
import os
COCO_DATA = 'data/coco'
MASK_RCNN_MODEL_PATH = 'lib/Mask_RCNN/'
if MASK_RCNN_MODEL_PATH not in sys.path:
    sys.path.append(MASK_RCNN_MODEL_PATH)
from samples.coco import coco
from mrcnn import utils
from mrcnn import model as modellib
from mrcnn import visualize
from lib import utils as siamese_utils
from lib import model as siamese_model
from lib import config as siamese_config
import time
import datetime
import random
import numpy as np
import skimage.io
import imgaug
import pickle
import matplotlib.pyplot as plt
from collections import OrderedDict
# Root directory of the project
ROOT_DIR = os.getcwd()
# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
# train_classes = coco_nopascal_classes
pascal_classes = [1,2,3,4,5,6,7,9,15,16,17,18,19,20,40,57,58,59,61,63]
no_pascal_classes = [i for i in range(1,81) if i not in pascal_classes]
train_classes = np.array(no_pascal_classes)
# Load COCO/train dataset
coco_train = siamese_utils.IndexedCocoDataset()
coco_train.load_coco(COCO_DATA, subset="train", subsubset="train", year="2017")
coco_train.prepare()
coco_train.build_indices()
coco_train.ACTIVE_CLASSES = train_classes
# Load COCO/val dataset
coco_val = siamese_utils.IndexedCocoDataset()
coco_val.load_coco(COCO_DATA, subset="train", subsubset="val", year="2017")
coco_val.prepare()
coco_val.build_indices()
coco_val.ACTIVE_CLASSES = train_classes
# ### Model
class LargeTrainConfig(siamese_config.Config):
    # Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 2
    IMAGES_PER_GPU = 6 # 4 16GB GPUs are required for a batch_size of 12
    NUM_CLASSES = 1 + 1
    NAME = 'large_coco_novoc'
    EXPERIMENT = 'example'
    CHECKPOINT_DIR = 'checkpoints/'
    # Reduced image sizes
    TARGET_MAX_DIM = 192
    TARGET_MIN_DIM = 150
    IMAGE_MIN_DIM = 800
    IMAGE_MAX_DIM = 1024
    # Reduce model size
    FPN_CLASSIF_FC_LAYERS_SIZE = 1024
    FPN_FEATUREMAPS = 256
    # Reduce number of rois at all stages
    RPN_ANCHOR_STRIDE = 1
    RPN_TRAIN_ANCHORS_PER_IMAGE = 256
    POST_NMS_ROIS_TRAINING = 2000
    POST_NMS_ROIS_INFERENCE = 1000
    TRAIN_ROIS_PER_IMAGE = 200
    DETECTION_MAX_INSTANCES = 100
    MAX_GT_INSTANCES = 100
    # Adapt NMS Threshold
    DETECTION_NMS_THRESHOLD = 0.5
    # Adapt loss weights
    LOSS_WEIGHTS = {'rpn_class_loss': 2.0,
                    'rpn_bbox_loss': 0.1,
                    'mrcnn_class_loss': 2.0,
                    'mrcnn_bbox_loss': 0.5,
                    'mrcnn_mask_loss': 1.0}
config = LargeTrainConfig()
config.display()
# Create model object in training mode.
model = siamese_model.SiameseMaskRCNN(mode="training", model_dir=MODEL_DIR, config=config)
train_schedule = OrderedDict()
train_schedule[1] = {"learning_rate": config.LEARNING_RATE, "layers": "heads"}
train_schedule[120] = {"learning_rate": config.LEARNING_RATE, "layers": "all"}
train_schedule[160] = {"learning_rate": config.LEARNING_RATE/10, "layers": "all"}
model.load_imagenet_weights(pretraining='imagenet_687')
for epochs, parameters in train_schedule.items():
    print("")
    print("training layers {} until epoch {} with learning_rate {}".format(parameters["layers"],
                                                                          epochs,
                                                                          parameters["learning_rate"]))
    model.train(coco_train, coco_val,
                learning_rate=parameters["learning_rate"],
                epochs=epochs,
                layers=parameters["layers"])
What is the performance when you evaluate? Loss and AP are not always that correlated because the model uses a mixture of 5 loss functions. In my experience the classifier loss, for example, would barely drop, yet the model would still classify somewhat decently in the end (it is still the weak spot though). |
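The point about the loss mixture can be illustrated with a small sketch. It assumes the total training loss is a weighted sum of the five terms (ignoring any regularization term), with the weights taken from the `LOSS_WEIGHTS` in the config posted above. The per-term loss values are invented for illustration and are not measurements.

```python
# Loss weights copied from the LargeTrainConfig posted in this thread.
LOSS_WEIGHTS = {
    "rpn_class_loss": 2.0,
    "rpn_bbox_loss": 0.1,
    "mrcnn_class_loss": 2.0,
    "mrcnn_bbox_loss": 0.5,
    "mrcnn_mask_loss": 1.0,
}


def weighted_terms(losses, weights=LOSS_WEIGHTS):
    """Return each loss term multiplied by its weight."""
    return {name: weights[name] * value for name, value in losses.items()}


def weighted_total(losses, weights=LOSS_WEIGHTS):
    """Weighted sum of all loss terms (assumed form of the total loss)."""
    return sum(weighted_terms(losses, weights).values())


# Invented snapshots: the classifier loss barely moves, the others drop.
early = {"rpn_class_loss": 0.05, "rpn_bbox_loss": 0.6,
         "mrcnn_class_loss": 0.40, "mrcnn_bbox_loss": 0.5,
         "mrcnn_mask_loss": 0.5}
late = {"rpn_class_loss": 0.03, "rpn_bbox_loss": 0.4,
        "mrcnn_class_loss": 0.38, "mrcnn_bbox_loss": 0.3,
        "mrcnn_mask_loss": 0.3}

for label, snapshot in (("early", early), ("late", late)):
    terms = weighted_terms(snapshot)
    largest = max(terms, key=terms.get)
    print(label, round(weighted_total(snapshot), 3), "largest term:", largest)
```

In this invented example the weighted classifier term is the largest contribution in both snapshots, so the total stays high even though the box and mask terms improve. A curve of the summed loss alone therefore says little about which branch is failing.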
When I used the pre-trained "large_siamese_mrcnn_coco_i4_0160.h5" weights for inference on the VOC dataset, I got an AP50 of 0.58. That makes sense, since the "i4" split may contain some of the same classes as VOC. But I get an AP50 of only 0.04 with the model trained using the script above. |
Is this detection or instance segmentation performance? |
instance segmentation performance |
Can you test object detection? You are now the second person to report this and it seems as if something broke in the segmentation training. And can you visualize some of the predictions? Maybe this gives some hints at what is going on. |
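To show why checking detection separately is informative, here is a toy sketch in plain Python (not the repository's evaluation code) in which a predicted mask has exactly the same bounding box as the ground truth but a mask IoU below the 0.5 threshold used for AP50. All function names are hypothetical, and masks are assumed to be non-empty 2D lists of 0/1.

```python
def mask_iou(a, b):
    """IoU of two binary masks given as equally sized 2D lists of 0/1."""
    pairs = [(x, y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    inter = sum(1 for x, y in pairs if x and y)
    union = sum(1 for x, y in pairs if x or y)
    return inter / union if union else 0.0


def mask_to_box(mask):
    """Tight box (y1, x1, y2, x2) around a non-empty mask, end-exclusive."""
    rows = [r for r, row in enumerate(mask) if any(row)]
    cols = [c for row in mask for c, v in enumerate(row) if v]
    return min(rows), min(cols), max(rows) + 1, max(cols) + 1


def box_iou(a, b):
    """IoU of two (y1, x1, y2, x2) boxes."""
    height = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    width = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = height * width
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


# Ground truth fills the whole 4x4 box; the prediction only covers the diagonal.
gt_mask = [[1, 1, 1, 1] for _ in range(4)]
pred_mask = [[1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]]

print("box IoU:", box_iou(mask_to_box(gt_mask), mask_to_box(pred_mask)))
print("mask IoU:", mask_iou(gt_mask, pred_mask))
```

A prediction like this would count as a correct detection at IoU 0.5 but as a miss for instance segmentation, which is the pattern one would expect if only the mask branch is broken.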
A quick check did not turn up any obvious bugs in the model. I'll have to check some of the older commits to figure out where the problem originated. That may take a bit. |
Maybe you can try using the pre-trained weights as initialization to train a new model. I also trained on the LVIS dataset, using "coco_full_0320.h5" as initialization. But the performance gets much worse after 100 epochs of training: AP50 (segm) drops from 0.24 to 0.16, where 0.24 is what the "coco_full_0320.h5" model achieves without any fine-tuning. |
I first want to check if training with an older commit of the repo works. In the end I did run the experiments which led to the pre-trained checkpoints at some time, so something must have changed in between. But as I said, this may take a while. |
That's what I suspected. Won't be easy to fix as it is something that goes wrong during training and only for the mask branch (object detection and the pre-trained model work). |
Hi @NCTU-VRDL |
Hi!
I am wondering how to run inference on unseen classes. In evaluate.ipynb, I notice that you commented out coco_nopascal_classes.
More specifically, how do I train a model on the COCO classes that are not in Pascal VOC and evaluate it on the Pascal VOC dataset?
Thank you!