Skip to content

Commit

Permalink
Update odise.py
Browse files Browse the repository at this point in the history
  • Loading branch information
CarlHuangNuc authored Jul 6, 2023
1 parent 915ed04 commit 198bd55
Showing 1 changed file with 68 additions and 3 deletions.
71 changes: 68 additions & 3 deletions odise/modeling/meta_arch/odise.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)

self.category_head = category_head
self.clip_head = clip_head

Expand All @@ -184,11 +185,12 @@ def cal_pred_logits(self, outputs):
# [K, C]
text_embed = outputs["text_embed"]
# [1, C]
text_embed = outputs["text_embed"]
#text_embed = outputs["text_embed"]
null_embed = outputs["null_embed"]

labels = outputs["labels"]


mask_embed = F.normalize(mask_embed, dim=-1)
text_embed = F.normalize(text_embed, dim=-1)
logit_scale = outputs["logit_scale"]
Expand All @@ -203,6 +205,9 @@ def cal_pred_logits(self, outputs):

# [B, Q, K+1]
pred = torch.cat([pred, null_pred], dim=-1)

#print(pred.shape)
#print("rrrrrrrrrrrrrrrrrrrr")

return pred

Expand Down Expand Up @@ -233,15 +238,22 @@ def forward(self, batched_inputs):
segments_info (list[dict]): Describe each segment in `panoptic_seg`.
Each dict contains keys "id", "category_id", "isthing".
"""

#print("1111111111111111111111111")
# print(batched_inputs[0].keys())

images = [x["image"].to(self.device) for x in batched_inputs]
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
images = ImageList.from_tensors(images, self.size_divisibility)

denormalized_images = ImageList.from_tensors(
[x["image"].to(self.device) / 255.0 for x in batched_inputs]
)


features = self.backbone(images.tensor)
#print(features.keys())
#print("3333333333333333444444444444444444444")
outputs = self.sem_seg_head(features)
outputs["images"] = denormalized_images.tensor

Expand All @@ -254,8 +266,12 @@ def forward(self, batched_inputs):
targets = None

if self.category_head is not None:


category_head_outputs = self.category_head(outputs, targets)
outputs.update(category_head_outputs)


# inplace change pred_logits
outputs["pred_logits"] = self.cal_pred_logits(outputs)
if "aux_outputs" in outputs:
Expand All @@ -269,6 +285,8 @@ def forward(self, batched_inputs):
# targets = self.clip_head.prepare_targets(outputs, targets)

# bipartite matching-based loss


losses = self.criterion(outputs, targets)

for k in list(losses.keys()):
Expand All @@ -284,24 +302,49 @@ def forward(self, batched_inputs):
# get text_embeddings
outputs.update(self.category_head(outputs))

#print("666666666666666666666666666")
#print(outputs.keys())

#print(outputs["pred_logits"].shape)
outputs["pred_logits"] = self.cal_pred_logits(outputs)
#print("777777777777777777")
#print(outputs.keys())
#print(outputs["pred_logits"].shape)


mask_pred_results = outputs["pred_masks"]
mask_cls_results = outputs["pred_logits"]

#print("8888888888888888")
#print(mask_pred_results.shape)
#print(mask_cls_results.shape)

if self.clip_head is not None:
if self.clip_head.with_bg:
# [B, Q, K+1]
outputs["pred_open_logits"] = outputs["pred_logits"]
outputs.update(self.clip_head(outputs))
mask_cls_results = outputs["pred_open_logits"]
else:
#print("bbbbbbbbbbbbbbbbbb")
# [B, Q, K]
#print(outputs["pred_logits"].shape)

outputs["pred_open_logits"] = outputs["pred_logits"][..., :-1]
#print(outputs["pred_open_logits"].shape)
#print("ccccccccccccccccccccccccccc")
#print(outputs.keys())
#exit()

outputs.update(self.clip_head(outputs))
#print("ddddddddddddddddddd")
#print(outputs.keys())

# merge with bg scores
open_logits = outputs["pred_open_logits"]
#print(open_logits.shape)

#print("9999999999999999999999999")
#print(mask_cls_results.shape)

# in case the prediction is not binary
binary_probs = torch.zeros(
Expand Down Expand Up @@ -336,18 +379,23 @@ def forward(self, batched_inputs):
for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
):

#print("fffffffffffffffffffff")
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
processed_results.append({})

if self.sem_seg_postprocess_before_inference:
# print("1111111111111111111111111111111111--------------------------")
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
mask_pred_result, image_size, height, width
)
mask_cls_result = mask_cls_result.to(mask_pred_result)

# semantic segmentation inference
#if False:
if self.semantic_on:
#print("2222222222222222222222222------------------------------------")
r = retry_if_cuda_oom(self.semantic_inference)(
mask_cls_result, mask_pred_result
)
Expand All @@ -356,14 +404,18 @@ def forward(self, batched_inputs):
processed_results[-1]["sem_seg"] = r

# panoptic segmentation inference
if self.panoptic_on:
if False:

#if self.panoptic_on:
#print("3333333333333333333333333----------------------------------------")
panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(
mask_cls_result, mask_pred_result
)
processed_results[-1]["panoptic_seg"] = panoptic_r

# instance segmentation inference
if self.instance_on:
#print("44444444444444444444444----------------------------------------")
instance_r = retry_if_cuda_oom(self.instance_inference)(
mask_cls_result, mask_pred_result
)
Expand Down Expand Up @@ -1225,6 +1277,10 @@ def __init__(
prompt=None,
):
super().__init__()
print("4444444444444444444")
print(labels)
print(len(labels))

self.labels = labels

self.clip_model_name = clip_model_name
Expand Down Expand Up @@ -1298,12 +1354,19 @@ def forward(self, outputs, targets=None):
else:
assert targets is None
assert self.test_labels is not None

labels = self.test_labels
text_embed = self.get_and_cache_test_text_embed(prompt_labels(labels, self.prompt))

#print("rrrrrrrr")
#print(text_embed.shape)
text_embed = self.text_proj(text_embed)
null_embed = self.text_proj(self.null_embed)

#print(text_embed.shape)
#print(null_embed.shape)


return {"text_embed": text_embed, "null_embed": null_embed, "labels": labels}


Expand Down Expand Up @@ -1445,6 +1508,8 @@ def __init__(
self.prompt = prompt
if train_labels is None:
self.train_labels = get_openseg_labels("coco_panoptic", prompt_engineered=True)
#self.train_labels = get_openseg_labels("ade20k_150", prompt_engineered=True)

else:
self.train_labels = train_labels
self.bg_labels = bg_labels
Expand Down

0 comments on commit 198bd55

Please sign in to comment.