diff --git a/ours/tasks/facornet.py b/ours/tasks/facornet.py index 078f6ec..c1dceba 100644 --- a/ours/tasks/facornet.py +++ b/ours/tasks/facornet.py @@ -54,7 +54,7 @@ def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, tor current_index += batch_size_current - return loss, similarities, y_true, pred_kin_relations, y_true_kin_relations + return loss_values, similarities, y_true, pred_kin_relations, y_true_kin_relations def validate(model, dataloader, device=0, threshold=None): @@ -83,6 +83,7 @@ def validate(model, dataloader, device=0, threshold=None): ) # kin_acc = tm.functional.accuracy(pred_kin_relations, y_true_kin_relations, task="multiclass", num_classes=12) # return loss, auc, threshold, acc, acc_kin_relations, kin_acc + loss = loss.mean().item() return loss, auc, threshold, acc, acc_kin_relations @@ -127,12 +128,12 @@ def train(args): out = acc_kr_to_str(out, acc_kv) print(out) - ce_loss = torch.nn.CrossEntropyLoss() + # ce_loss = torch.nn.CrossEntropyLoss() for epoch in range(args.num_epoch): model.train() contrastive_loss_epoch = 0.0 - kin_loss_epoch = 0.0 + # kin_loss_epoch = 0.0 for step, data in enumerate(train_loader): global_step = step + epoch * args.steps_per_epoch