Skip to content

Commit

Permalink
Update mask_generator_with_label.py
Browse files Browse the repository at this point in the history
  • Loading branch information
CarlHuangNuc authored Jul 6, 2023
1 parent 20247b7 commit 3aa8a70
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions configs/common/models/mask_generator_with_label.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# ------------------------------------------------------------------------------
# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
Expand Down Expand Up @@ -25,10 +26,15 @@
from mask2former.modeling.matcher import HungarianMatcher
from mask2former.modeling.pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder


model = L(CategoryODISE)(
sem_seg_head=L(MaskFormerHead)(
ignore_value=255,
num_classes=133,
#ignore_value=2555,
#num_classes=133,
#num_classes=1203,
num_classes=150,
#num_classes=847,
pixel_decoder=L(MSDeformAttnPixelDecoder)(
conv_dim=256,
mask_dim=256,
Expand Down Expand Up @@ -83,14 +89,28 @@
),
category_head=L(CategoryEmbed)(
clip_model_name="ViT-L-14-336",
labels=L(get_openseg_labels)(dataset="coco_panoptic", prompt_engineered=True),
#labels=L(get_openseg_labels)(dataset="coco_panoptic", prompt_engineered=True),
labels=L(get_openseg_labels)(dataset="ade20k_150", prompt_engineered=True),
#labels=L(get_openseg_labels)(dataset="ade20k_847", prompt_engineered=True),


# ade20k_847
#labels=L(get_openseg_labels)(dataset="lvis_1203", prompt_engineered=True),

projection_dim="${..sem_seg_head.transformer_predictor.post_mask_embed.projection_dim}",
),
clip_head=L(PoolingCLIPHead)(),
num_queries=100,
object_mask_threshold=0.0,
overlap_threshold=0.8,
metadata=L(MetadataCatalog.get)(name="coco_2017_train_panoptic_with_sem_seg"),

#metadata=L(MetadataCatalog.get)(name="coco_2017_train_panoptic_with_sem_seg"),

#### carl ...change ....
metadata=L(MetadataCatalog.get)(name="ade20k_panoptic_train"),
#metadata=L(MetadataCatalog.get)(name="ade20k_full_panoptic_val"),

#metadata=L(MetadataCatalog.get)(name="lvis_v1_train_with_sem_seg"),
size_divisibility=64,
sem_seg_postprocess_before_inference=True,
# normalize to [0, 1]
Expand Down

0 comments on commit 3aa8a70

Please sign in to comment.