diff --git a/odise/modeling/meta_arch/odise.py b/odise/modeling/meta_arch/odise.py
index a49d5ad..3b96604 100644
--- a/odise/modeling/meta_arch/odise.py
+++ b/odise/modeling/meta_arch/odise.py
@@ -175,4 +175,5 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        self.category_head = category_head
         self.clip_head = clip_head
@@ -184,9 +185,8 @@ def cal_pred_logits(self, outputs):
         # [K, C]
         text_embed = outputs["text_embed"]
         # [1, C]
-        text_embed = outputs["text_embed"]
         null_embed = outputs["null_embed"]
         labels = outputs["labels"]
         mask_embed = F.normalize(mask_embed, dim=-1)
         text_embed = F.normalize(text_embed, dim=-1)
         logit_scale = outputs["logit_scale"]
@@ -356,7 +356,8 @@ def forward(self, batched_inputs):
                 processed_results[-1]["sem_seg"] = r
 
             # panoptic segmentation inference
-            if self.panoptic_on:
+            # NOTE: panoptic inference is intentionally disabled in this experiment.
+            if False:  # was: if self.panoptic_on:
                 panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(
                     mask_cls_result, mask_pred_result
                 )
@@ -1445,6 +1446,8 @@ def __init__(
         self.prompt = prompt
         if train_labels is None:
             self.train_labels = get_openseg_labels("coco_panoptic", prompt_engineered=True)
+            # Alternative label set tried for open-vocabulary experiments:
+            # self.train_labels = get_openseg_labels("ade20k_150", prompt_engineered=True)
         else:
             self.train_labels = train_labels
         self.bg_labels = bg_labels
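
Note on the `cal_pred_logits` hunk: below is a minimal standalone sketch of
the logit computation, useful for sanity-checking the tensor shapes the diff
comments refer to. The lines between the F.normalize calls and the final
concat are elided in this patch, so the einsum formulation is an assumption
based on the usual CLIP-style cosine similarity; `cal_pred_logits_sketch`
and the shape constants are illustrative, not ODISE's exact code.

import torch
import torch.nn.functional as F

def cal_pred_logits_sketch(mask_embed, text_embed, null_embed, logit_scale):
    """mask_embed: [B, Q, C], text_embed: [K, C], null_embed: [1, C]."""
    # Normalize so the dot products below are cosine similarities,
    # mirroring the F.normalize calls visible in the hunk above.
    mask_embed = F.normalize(mask_embed, dim=-1)
    text_embed = F.normalize(text_embed, dim=-1)
    null_embed = F.normalize(null_embed, dim=-1)
    # [B, Q, K]: similarity of each query embedding to each class name
    pred = logit_scale * torch.einsum("bqc,kc->bqk", mask_embed, text_embed)
    # [B, Q, 1]: similarity to the learned "no object" embedding
    null_pred = logit_scale * torch.einsum("bqc,kc->bqk", mask_embed, null_embed)
    # [B, Q, K+1]: matches `pred = torch.cat([pred, null_pred], dim=-1)`
    return torch.cat([pred, null_pred], dim=-1)

# Smoke test: 133 COCO panoptic classes plus one "no object" column.
B, Q, K, C = 2, 100, 133, 512
logits = cal_pred_logits_sketch(
    torch.randn(B, Q, C), torch.randn(K, C), torch.randn(1, C),
    logit_scale=torch.tensor(100.0),
)
assert logits.shape == (B, Q, K + 1)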