Skip to content

Commit

Permalink
fix facornet validation loss aggregation (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalwarley committed Apr 3, 2024
1 parent bd76555 commit 4a1a5c2
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions ours/tasks/facornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, tor

current_index += batch_size_current

return loss, similarities, y_true, pred_kin_relations, y_true_kin_relations
return loss_values, similarities, y_true, pred_kin_relations, y_true_kin_relations


def validate(model, dataloader, device=0, threshold=None):
Expand Down Expand Up @@ -83,6 +83,7 @@ def validate(model, dataloader, device=0, threshold=None):
)
# kin_acc = tm.functional.accuracy(pred_kin_relations, y_true_kin_relations, task="multiclass", num_classes=12)
# return loss, auc, threshold, acc, acc_kin_relations, kin_acc
loss = loss.mean().item()
return loss, auc, threshold, acc, acc_kin_relations


Expand Down Expand Up @@ -127,12 +128,12 @@ def train(args):
out = acc_kr_to_str(out, acc_kv)
print(out)

ce_loss = torch.nn.CrossEntropyLoss()
# ce_loss = torch.nn.CrossEntropyLoss()

for epoch in range(args.num_epoch):
model.train()
contrastive_loss_epoch = 0.0
kin_loss_epoch = 0.0
# kin_loss_epoch = 0.0
for step, data in enumerate(train_loader):
global_step = step + epoch * args.steps_per_epoch

Expand Down

0 comments on commit 4a1a5c2

Please sign in to comment.