eval metrics: benchmarks/imagenet
EricLiclair committed Oct 23, 2024
1 parent 6278112 commit e6b3b5a
Showing 3 changed files with 81 additions and 73 deletions.
46 changes: 41 additions & 5 deletions benchmarks/imagenet/resnet50/main.py
@@ -1,7 +1,7 @@
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
from typing import Sequence, Union
from typing import Sequence, Union, Dict

import barlowtwins
import byol
@@ -121,11 +121,11 @@ def main(
precision=precision,
ckpt_path=ckpt_path,
)

eval_metrics: Dict[str, Dict[str, float]] = {}
if skip_knn_eval:
print_rank_zero("Skipping KNN eval.")
else:
knn_eval.knn_eval(
eval_metrics["knn"] = knn_eval.knn_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
@@ -140,7 +140,7 @@ def main(
if skip_linear_eval:
print_rank_zero("Skipping linear eval.")
else:
linear_eval.linear_eval(
eval_metrics["linear"] = linear_eval.linear_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
@@ -156,7 +156,7 @@ def main(
if skip_finetune_eval:
print_rank_zero("Skipping fine-tune eval.")
else:
finetune_eval.finetune_eval(
eval_metrics["finetune"] = finetune_eval.finetune_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
@@ -169,6 +169,10 @@ def main(
precision=precision,
)

if eval_metrics:
print(f"Results for {method}:")
print(eval_metrics_to_markdown(eval_metrics))


def pretrain(
model: LightningModule,
@@ -246,6 +250,38 @@ def pretrain(
print_rank_zero(f"max {metric}: {max(metric_callback.val_metrics[metric])}")


def eval_metrics_to_markdown(metrics: Dict[str, Dict[str, float]]) -> str:
EVAL_NAME_COLUMN_NAME = "Eval Name"
METRIC_COLUMN_NAME = "Metric Name"
VALUE_COLUMN_NAME = "Value"

eval_name_max_len = max(
len(eval_name) for eval_name in list(metrics.keys()) + [EVAL_NAME_COLUMN_NAME]
)
metric_name_max_len = max(
len(metric_name)
for metric_dict in metrics.values()
for metric_name in list(metric_dict.keys()) + [METRIC_COLUMN_NAME]
)
value_max_len = max(
len(metric_value)
for metric_dict in metrics.values()
for metric_value in list(f"{value:.2f}" for value in metric_dict.values())
+ [VALUE_COLUMN_NAME]
)

header = f"| {EVAL_NAME_COLUMN_NAME.ljust(eval_name_max_len)} | {METRIC_COLUMN_NAME.ljust(metric_name_max_len)} | {VALUE_COLUMN_NAME.ljust(value_max_len)} |"
separator = f"|:{'-' * (eval_name_max_len)}:|:{'-' * (metric_name_max_len)}:|:{'-' * (value_max_len)}:|"

lines = [header, separator] + [
f"| {eval_name.ljust(eval_name_max_len)} | {metric_name.ljust(metric_name_max_len)} | {f'{metric_value:.2f}'.ljust(value_max_len)} |"
for eval_name, metric_dict in metrics.items()
for metric_name, metric_value in metric_dict.items()
]

return "\n".join(lines)


if __name__ == "__main__":
args = parser.parse_args()
main(**vars(args))
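
Note: the snippet below is only an illustrative sketch of what the new `eval_metrics_to_markdown` helper produces, assuming the function defined above is in scope. The metric names and values are made up for the example; the actual keys depend on what `knn_eval.knn_eval`, `linear_eval.linear_eval`, and `finetune_eval.finetune_eval` return.

```python
# Illustrative only: metric names and values are invented for this example.
example_metrics = {
    "knn": {"val_top1": 0.72, "val_top5": 0.91},
    "linear": {"val_top1": 0.76},
}

print(eval_metrics_to_markdown(example_metrics))
# Expected output shape:
# | Eval Name | Metric Name | Value |
# |:---------:|:-----------:|:-----:|
# | knn       | val_top1    | 0.72  |
# | knn       | val_top5    | 0.91  |
# | linear    | val_top1    | 0.76  |
```

Column widths adapt to the longest entry in each column, so the table stays aligned regardless of the metric names that the eval functions return.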
59 changes: 40 additions & 19 deletions benchmarks/imagenet/vitb16/main.py
@@ -1,7 +1,7 @@
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
from typing import Sequence, Union
from typing import Sequence, Union, Dict

import aim
import finetune_eval
@@ -20,14 +20,6 @@
from lightly.transforms.utils import IMAGENET_NORMALIZE
from lightly.utils.benchmarking import MetricCallback

from ...metrics import (
EvalMetrics,
FinetuneEvalMetric,
KNNEvalMetric,
LinearEvalMetric,
eval_metrics_to_markdown,
)

parser = ArgumentParser("ImageNet ViT-B/16 Benchmarks")
parser.add_argument("--train-dir", type=Path, default="/datasets/imagenet/train")
parser.add_argument("--val-dir", type=Path, default="/datasets/imagenet/val")
@@ -108,11 +100,11 @@ def main(
strategy=strategy,
)

eval_metrics: EvalMetrics = []
eval_metrics: Dict[str, Dict[str, float]] = {}
if skip_knn_eval:
print("Skipping KNN eval.")
else:
for k, v in knn_eval.knn_eval(
eval_metrics["knn"] = knn_eval.knn_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
@@ -122,13 +114,12 @@ def main(
num_workers=num_workers,
accelerator=accelerator,
devices=devices,
).items():
eval_metrics.append(KNNEvalMetric(k, v))
)

if skip_linear_eval:
print("Skipping linear eval.")
else:
for k, v in linear_eval.linear_eval(
eval_metrics["linear"] = linear_eval.linear_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
@@ -139,13 +130,12 @@ def main(
accelerator=accelerator,
devices=devices,
precision=precision,
).items():
eval_metrics.append(LinearEvalMetric(k, v))
)

if skip_finetune_eval:
print("Skipping fine-tune eval.")
else:
for k, v in finetune_eval.finetune_eval(
eval_metrics["finetune"] = finetune_eval.finetune_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
@@ -156,8 +146,7 @@ def main(
accelerator=accelerator,
devices=devices,
precision=precision,
).items():
eval_metrics.append(FinetuneEvalMetric(k, v))
)

if eval_metrics:
print(f"Results for {method}:")
@@ -234,6 +223,38 @@ def pretrain(
print(f"max {metric}: {max(metric_callback.val_metrics[metric])}")


def eval_metrics_to_markdown(metrics: Dict[str, Dict[str, float]]) -> str:
EVAL_NAME_COLUMN_NAME = "Eval Name"
METRIC_COLUMN_NAME = "Metric Name"
VALUE_COLUMN_NAME = "Value"

eval_name_max_len = max(
len(eval_name) for eval_name in list(metrics.keys()) + [EVAL_NAME_COLUMN_NAME]
)
metric_name_max_len = max(
len(metric_name)
for metric_dict in metrics.values()
for metric_name in list(metric_dict.keys()) + [METRIC_COLUMN_NAME]
)
value_max_len = max(
len(metric_value)
for metric_dict in metrics.values()
for metric_value in list(f"{value:.2f}" for value in metric_dict.values())
+ [VALUE_COLUMN_NAME]
)

header = f"| {EVAL_NAME_COLUMN_NAME.ljust(eval_name_max_len)} | {METRIC_COLUMN_NAME.ljust(metric_name_max_len)} | {VALUE_COLUMN_NAME.ljust(value_max_len)} |"
separator = f"|:{'-' * (eval_name_max_len)}:|:{'-' * (metric_name_max_len)}:|:{'-' * (value_max_len)}:|"

lines = [header, separator] + [
f"| {eval_name.ljust(eval_name_max_len)} | {metric_name.ljust(metric_name_max_len)} | {f'{metric_value:.2f}'.ljust(value_max_len)} |"
for eval_name, metric_dict in metrics.items()
for metric_name, metric_value in metric_dict.items()
]

return "\n".join(lines)


if __name__ == "__main__":
args = parser.parse_args()
main(**vars(args))
49 changes: 0 additions & 49 deletions benchmarks/metrics.py

This file was deleted.
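
The deleted file's contents are not shown in this diff. Judging only from the import removed in benchmarks/imagenet/vitb16/main.py (EvalMetrics, FinetuneEvalMetric, KNNEvalMetric, LinearEvalMetric, eval_metrics_to_markdown) and the old call sites such as `eval_metrics.append(KNNEvalMetric(k, v))`, the module appears to have bundled small per-eval metric records plus a shared markdown formatter. A rough, hypothetical sketch of that interface (inferred, not the actual deleted code):

```python
# Hypothetical reconstruction of the interface the deleted benchmarks/metrics.py
# appears to have exposed, inferred from the removed import and the old call
# sites in vitb16/main.py. Field names and the formatter body are assumptions.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class KNNEvalMetric:
    name: str  # metric key, e.g. the "k" in eval_metrics.append(KNNEvalMetric(k, v))
    value: float


@dataclass
class LinearEvalMetric:
    name: str
    value: float


@dataclass
class FinetuneEvalMetric:
    name: str
    value: float


# The old code initialized `eval_metrics: EvalMetrics = []`, so this was
# presumably a list alias over the metric record types.
EvalMetrics = List[Union[KNNEvalMetric, LinearEvalMetric, FinetuneEvalMetric]]


def eval_metrics_to_markdown(metrics: EvalMetrics) -> str:
    # Presumably rendered one markdown table row per metric record, similar
    # to the per-benchmark helper added in this commit; the body is not shown
    # in the diff, so it is left as a stub here.
    ...
```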
