add HeadKin to facornet model (#71)
vitalwarley committed Mar 21, 2024
1 parent def26ca commit b08dae7
Showing 2 changed files with 56 additions and 33 deletions.
48 changes: 31 additions & 17 deletions ours/models/facornet.py
@@ -51,34 +51,45 @@ def to_input(pil_rgb_image):
return tensor


class FaCoR(torch.nn.Module):
def __init__(self):
super(FaCoR, self).__init__()
self.backbone = load_pretrained_model("ir_101")
self.projection = nn.Sequential(
torch.nn.Linear(512 * 6, 256),
torch.nn.BatchNorm1d(256),
class HeadKin(nn.Module): # couldn't be HeadFamily because there is no family label
def __init__(self, in_features=512, out_features=4, ratio=8):
super().__init__()
self.projection_head = nn.Sequential(
torch.nn.Linear(2 * in_features, in_features // ratio),
torch.nn.BatchNorm1d(in_features // ratio),
torch.nn.ReLU(),
torch.nn.Linear(256, 1),
torch.nn.Linear(in_features // ratio, out_features),
)
self.channel = 64
self.spatial_ca = SpatialCrossAttention(self.channel * 8, CA=True)
self.channel_ca = ChannelCrossAttention(self.channel * 8)
self.CCA = ChannelInteraction(1024)
self.avg_pool = nn.AdaptiveAvgPool2d(1)

self._initialize_weights()
self.initialize_weights(self.projection_head)

def _initialize_weights(self):
for m in self.projection.modules():
def initialize_weights(self, proj_head):
for m in proj_head.modules():
if isinstance(m, nn.Linear):
nn.init.uniform_(m.weight, -0.05, 0.05)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def forward(self, em):
return self.projection_head(em)


class FaCoR(torch.nn.Module):
def __init__(self):
super(FaCoR, self).__init__()
self.backbone = load_pretrained_model("ir_101")
self.channel = 64
self.spatial_ca = SpatialCrossAttention(self.channel * 8, CA=True)
self.channel_ca = ChannelCrossAttention(self.channel * 8)
self.CCA = ChannelInteraction(1024)
self.avg_pool = nn.AdaptiveAvgPool2d(1)

self.task_kin = HeadKin(512, 12, 8)

def forward(self, imgs, aug=False):
img1, img2 = imgs
idx = [2, 1, 0]
@@ -110,7 +121,10 @@ def forward(self, imgs, aug=False):
f1s = torch.flatten(f1s, 1)
f2s = torch.flatten(f2s, 1)

return f1s, f2s, att_map0
fc = torch.cat([f1s, f2s], dim=1)
kin = self.task_kin(fc)

return kin, f1s, f2s, att_map0


class SpatialCrossAttention(nn.Module):
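For context, a minimal sketch of the new HeadKin head on its own. The shapes follow the diff: two 512-d embeddings are concatenated into a 1024-d input and projected to 12 kinship-relation logits, matching HeadKin(512, 12, 8) in FaCoR.__init__. The import path and batch size are assumptions, not part of the commit.

import torch
from models.facornet import HeadKin  # assumed import path; adjust to where ours/models/facornet.py resolves

head = HeadKin(in_features=512, out_features=12, ratio=8)
f1s = torch.randn(8, 512)  # batch of embeddings for the first image in each pair
f2s = torch.randn(8, 512)  # batch of embeddings for the second image
logits = head(torch.cat([f1s, f2s], dim=1))  # concatenation yields 2 * 512 = 1024 features
print(logits.shape)  # torch.Size([8, 12])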
41 changes: 25 additions & 16 deletions ours/tasks/facornet.py
@@ -30,6 +30,7 @@ def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, tor
similarities = torch.zeros(dataset_size, device=device)
y_true = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
pred_kin_relations = torch.zeros(dataset_size, dtype=torch.uint8, device=device)

current_index = 0
for img1, img2, labels in tqdm(val_loader, total=len(val_loader), bar_format=TQDM_BAR_FORMAT):
@@ -38,23 +39,24 @@ def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, tor
(kin_relation, is_kin) = labels
kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)

f1, f2, _ = model([img1, img2])
kin, f1, f2, _ = model([img1, img2])
sim = torch.cosine_similarity(f1, f2)

# Fill preallocated tensors
similarities[current_index : current_index + batch_size_current] = sim
y_true[current_index : current_index + batch_size_current] = is_kin
y_true_kin_relations[current_index : current_index + batch_size_current] = kin_relation
pred_kin_relations[current_index : current_index + batch_size_current] = kin.argmax(dim=1)

current_index += batch_size_current

return similarities, y_true, y_true_kin_relations
return similarities, y_true, pred_kin_relations, y_true_kin_relations


def validate(model, dataloader, device=0, threshold=None):
model.eval()
# Compute similarities
similarities, y_true, y_true_kin_relations = predict(model, dataloader)
similarities, y_true, pred_kin_relations, y_true_kin_relations = predict(model, dataloader)
# Compute metrics
auc = tm.functional.auroc(similarities, y_true, task="binary")
fpr, tpr, thresholds = tm.functional.roc(similarities, y_true, task="binary")
@@ -75,7 +77,8 @@ def validate(model, dataloader, device=0, threshold=None):
acc_kin_relations[kin_relation] = tm.functional.accuracy(
similarities[mask], y_true[mask], task="binary", threshold=threshold
)
return auc, threshold, acc, acc_kin_relations
kin_acc = tm.functional.accuracy(pred_kin_relations, y_true_kin_relations, task="multiclass", num_classes=12)
return auc, threshold, acc, acc_kin_relations, kin_acc


def train(args):
@@ -114,16 +117,17 @@ def train(args):
total_steps = len(train_loader)
print(f"Total steps: {total_steps}")
global_step = 0
best_model_auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)
out = f"epoch: 0 | auc: {best_model_auc:.6f} | acc: {val_acc:.6f}"
# Add acc_kr to out
for kin_relation, acc in acc_kr.items():
out += f" | acc_{kin_relation}: {acc:.6f}"
best_model_auc, _, val_acc, acc_kv, acc_clf_kr = validate(model, val_model_sel_loader)
out = f"epoch: 0 | auc: {best_model_auc:.6f} | acc_kv: {val_acc:.6f} | acc_clf_kr: {acc_clf_kr:.6f}"
out = acc_kr_to_str(out, acc_kv)
print(out)

ce_loss = torch.nn.CrossEntropyLoss()

for epoch in range(args.num_epoch):
model.train()
loss_epoch = 0.0
contrastive_loss_epoch = 0.0
kin_loss_epoch = 0.0
for step, data in enumerate(train_loader):
global_step = step + epoch * args.steps_per_epoch

@@ -135,10 +139,13 @@
kin_relation = kin_relation.to(args.device)
is_kin = is_kin.to(args.device)

x1, x2, att = model([image1, image2])
loss = facornet_contrastive_loss(x1, x2, beta=att)
kin, x1, x2, att = model([image1, image2])
contrastive_loss = facornet_contrastive_loss(x1, x2, beta=att)
kin_loss = ce_loss(kin, kin_relation)

loss_epoch += loss.item()
contrastive_loss_epoch += contrastive_loss.item()
kin_loss_epoch += kin_loss.item()
loss = contrastive_loss + kin_loss

optimizer.zero_grad()
loss.backward()
@@ -151,17 +158,19 @@
train_dataset.set_bias(use_sample)

# Save model checkpoints
auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)
auc, _, val_acc, acc_kv, acc_clf_kr = validate(model, val_model_sel_loader)

if auc > best_model_auc:
best_model_auc = auc
torch.save(model.state_dict(), args.output_dir / "best.pth")

out = (
f"epoch: {epoch + 1:>2} | step: {global_step} "
+ f"| loss: {loss_epoch / args.steps_per_epoch:.3f} | auc: {auc:.6f} | acc: {val_acc:.6f}"
+ f"| loss: {contrastive_loss_epoch / args.steps_per_epoch:.3f} "
+ f"| kin_loss: {kin_loss_epoch / args.steps_per_epoch:.3f} "
+ f"| auc: {auc:.6f} | acc_kv: {val_acc:.6f} | acc_clf_kr: {acc_clf_kr:.6f}"
)
out = acc_kr_to_str(out, acc_kr)
out = acc_kr_to_str(out, acc_kv)
print(out)


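For context, a minimal sketch of how the new four-value forward output is consumed at inference, mirroring the updated predict(): cosine similarity of the two embeddings scores kinship verification, while the argmax over the kin logits gives the relation class. Here model is assumed to be a trained FaCoR instance and img1/img2 batched image tensors on the same device.

model.eval()
with torch.no_grad():
    kin_logits, f1, f2, att_map = model([img1, img2])

similarity = torch.cosine_similarity(f1, f2)  # kinship verification score, thresholded as in validate()
relation = kin_logits.argmax(dim=1)           # 12-way kinship-relation prediction, as in predict()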
