Update example
mthrok committed Jan 9, 2025
1 parent 4921cf7 commit 2dd6822
Showing 1 changed file with 20 additions and 96 deletions.
116 changes: 20 additions & 96 deletions examples/imagenet_classification.py
@@ -14,36 +14,27 @@
.. include:: ../plots/imagenet_classification_chart.txt

-A file list can be created, for example, by:
-
-.. code-block:: bash
-
-   cd /data/users/moto/imagenet/
-   find val -name '*.JPEG' > ~/imagenet.val.flist

To run the benchmark, invoke the script like the following.

.. code-block::

   python imagenet_classification.py
-       --input-flist ~/imagenet.val.flist
-       --prefix /data/users/moto/imagenet/
+       --root-dir ~/imagenet/
+       --split val
"""

# pyre-ignore-all-errors

import contextlib
import logging
import os.path
import re
import time
from collections.abc import Awaitable, Callable, Iterator
from pathlib import Path

import spdl.io
import spdl.utils
import torch
-from spdl.dataloader import DataLoader
+from spdl.dataloader import DataLoader, ImageNet
from torch import Tensor
from torch.profiler import profile

@@ -53,15 +44,12 @@
__all__ = [
    "entrypoint",
    "benchmark",
-    "source",
    "get_decode_func",
    "get_dataloader",
    "get_model",
    "ModelBundle",
    "Classification",
    "Preprocessing",
-    "get_mappings",
-    "parse_wnid",
]


@@ -73,20 +61,20 @@ def _parse_args(args):
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--debug", action="store_true")
-    parser.add_argument("--input-flist", type=Path, required=True)
-    parser.add_argument("--max-samples", type=int, default=float("inf"))
-    parser.add_argument("--prefix", default="")
+    parser.add_argument("--root-dir", type=Path, required=True)
+    parser.add_argument("--max-batches", type=int, default=float("inf"))
    parser.add_argument("--batch-size", type=int, default=32)
+    parser.add_argument("--split", default="val", choices=["train", "val"])
    parser.add_argument("--trace", type=Path)
-    parser.add_argument("--queue-size", type=int, default=16)
+    parser.add_argument("--buffer-size", type=int, default=16)
    parser.add_argument("--num-threads", type=int, default=16)
    parser.add_argument("--no-compile", action="store_false", dest="compile")
    parser.add_argument("--no-bf16", action="store_false", dest="use_bf16")
    parser.add_argument("--use-nvdec", action="store_true")
    parser.add_argument("--use-nvjpeg", action="store_true")
    args = parser.parse_args(args)
    if args.trace:
-        args.max_samples = args.batch_size * 60
+        args.max_batches = 60
    return args
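
Putting the new flags together, a typical invocation (paths illustrative)
would look like:

.. code-block:: bash

   python imagenet_classification.py \
       --root-dir ~/imagenet/ \
       --split val \
       --batch-size 32 \
       --max-batches 100 \
       --buffer-size 16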


@@ -221,34 +209,6 @@ def get_model(
    return ModelBundle(model, preprocessing, classification, use_bf16)


-def source(
-    path: Path,
-    prefix: str = "",
-    max_samples: int = float("inf"),
-) -> Iterator[tuple[str, int]]:
-    """Iterate a file containing a list of paths.
-
-    Args:
-        path: Path to the file containing a list of file paths.
-        prefix: Prepended to the paths in the list.
-        max_samples: Maximum number of samples to yield.
-
-    Yields:
-        The path of the image and its class label.
-    """
-    class_mapping = get_mappings()
-
-    with open(path) as f:
-        i = 0
-        for line in f:
-            if line := line.strip():
-                path_ = prefix + line
-                label = class_mapping[parse_wnid(path_)]
-                yield path_, label
-                if (i := i + 1) >= max_samples:
-                    return


def get_decode_func(
    device_index: int,
    width: int = 224,
@@ -396,12 +356,17 @@ def get_dataloader(
    )


-def benchmark(dataloader: Iterator[tuple[Tensor, Tensor]], model: ModelBundle) -> None:
+def benchmark(
+    dataloader: Iterator[tuple[Tensor, Tensor]],
+    model: ModelBundle,
+    max_batches: int | float = float("inf"),
+) -> None:
    """The main loop that measures the performance of dataloading and model inference.

    Args:
        dataloader: The dataloader to benchmark.
        model: The model to benchmark.
+        max_batches: The maximum number of batches to process before stopping.
    """

    _LG.info("Running inference.")
@@ -422,6 +387,9 @@ def benchmark(dataloader: Iterator[tuple[Tensor, Tensor]], model: ModelBundle) -> None:
            num_frames += batch.shape[0]
            num_correct_top1 += top1
            num_correct_top5 += top5
+
+            if i + 1 >= max_batches:
+                break
    finally:
        elapsed = time.monotonic() - t0
        if num_frames != 0:
@@ -436,7 +404,7 @@ def benchmark(dataloader: Iterator[tuple[Tensor, Tensor]], model: ModelBundle) -> None:


def _get_dataloader(args, device_index) -> DataLoader:
-    src = source(args.input_flist, args.prefix, args.max_samples)
+    src = ImageNet(args.root_dir, split=args.split)

    if args.use_nvjpeg:
        decode_func = _get_experimental_nvjpeg_decode_function(device_index)
@@ -449,7 +417,7 @@ def _get_dataloader(args, device_index) -> DataLoader:
        src,
        args.batch_size,
        decode_func,
-        args.queue_size,
+        args.buffer_size,
        args.num_threads,
    )
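
Here ImageNet stands in for the removed source() generator; it is assumed to
iterate the samples under <root-dir>/<split> and hand each one to the same
decode function. A minimal sketch under that assumption:

.. code-block:: python

   from spdl.dataloader import ImageNet

   # Hypothetical root; mirrors ImageNet(args.root_dir, split=args.split) above.
   src = ImageNet("~/imagenet/", split="val")
   for item in src:
       ...  # each item is passed to decode_func by the dataloader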

@@ -476,7 +444,7 @@ def entrypoint(args: list[int] | None = None):
        profile() if args.trace else contextlib.nullcontext() as prof,
        spdl.utils.tracing(f"{trace_path}.pftrace", enable=args.trace is not None),
    ):
-        benchmark(dataloader, model)
+        benchmark(dataloader, model, args.max_batches)

    if args.trace:
        prof.export_chrome_trace(f"{trace_path}.json")
@@ -488,49 +456,5 @@ def _init_logging(debug=False):
    logging.basicConfig(format=fmt, level=level)


-def get_mappings() -> dict[str, int]:
-    """Get the mapping from WordNet ID to class and label.
-
-    1000 IDs from ILSVRC2012 are used. The class index is the position of the
-    WordNet ID in sorted order, which matches most publicly available models.
-
-    Returns:
-        Mapping from WordNet ID to class index.
-
-    Example:
-
-        .. code-block::
-
-           >>> class_mapping = get_mappings()
-           >>> print(class_mapping["n03709823"])
-           636
-    """
-    class_mapping = {}
-
-    path = os.path.join(os.path.dirname(__file__), "imagenet_class.tsv")
-    with open(path, mode="r", encoding="utf-8") as f:
-        for line in f:
-            if line := line.strip():
-                class_, wnid = line.split("\t")[:2]
-                class_mapping[wnid] = int(class_)
-    return class_mapping
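
For reference, each row of imagenet_class.tsv pairs a class index with a
WordNet ID, tab-separated, matching the split("\t")[:2] above:

.. code-block::

   636	n03709823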


-def parse_wnid(s: str):
-    """Parse a WordNet ID (nXXXXXXXX) from a string.
-
-    Args:
-        s (str): String to parse.
-
-    Returns:
-        (str): The WordNet ID, if found; otherwise an exception is raised.
-            If the string contains multiple WordNet IDs, the first one is returned.
-    """
-    if match := re.search(r"n\d{8}", s):
-        return match.group(0)
-    raise ValueError(f"The given string does not contain WNID: {s}")
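
A doctest-style illustration of the pattern (the file path is made up):

.. code-block::

   >>> import re
   >>> re.search(r"n\d{8}", "val/n03709823/ILSVRC2012_val_00000001.JPEG").group(0)
   'n03709823'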


if __name__ == "__main__":
    entrypoint()
