diff --git a/slowfast/datasets/multigrid_helper.py b/slowfast/datasets/multigrid_helper.py index f596e98d..c6d3fb31 100644 --- a/slowfast/datasets/multigrid_helper.py +++ b/slowfast/datasets/multigrid_helper.py @@ -10,7 +10,7 @@ TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) -if TORCH_MAJOR >= 1 and TORCH_MINOR >= 8: +if TORCH_MAJOR >= 2 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 8): _int_classes = int else: from torch._six import int_classes as _int_classes