diff --git a/lightly/data/_helpers.py b/lightly/data/_helpers.py index 00ab4fe09..b73bd4283 100644 --- a/lightly/data/_helpers.py +++ b/lightly/data/_helpers.py @@ -4,7 +4,7 @@ # All Rights Reserved import os -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Dict, Optional, Tuple from torchvision import datasets @@ -34,32 +34,29 @@ VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mpg", ".hevc", ".m4v", ".webm", ".mpeg") -def _dir_contains_videos(root: str, extensions: tuple): - """Checks whether directory contains video files. +def _dir_contains_videos(root: str, extensions: Tuple[str, ...]) -> bool: + """Checks whether the directory contains video files. Args: root: Root directory path. + extensions: Tuple of valid video file extensions. Returns: - True if root contains video files. - + True if the root directory contains video files, False otherwise. """ with os.scandir(root) as scan_dir: return any(f.name.lower().endswith(extensions) for f in scan_dir) -def _contains_videos(root: str, extensions: tuple): - """Checks whether directory or any subdirectory contains video files. - - Iterates over all subdirectories of "root" recursively and returns True - if any of the subdirectories contains a file with a VIDEO_EXTENSION. +def _contains_videos(root: str, extensions: Tuple[str, ...]) -> bool: + """Checks whether the directory or any subdirectory contains video files. Args: root: Root directory path. + extensions: Tuple of valid video file extensions. Returns: - True if "root" or any subdir contains video files. - + True if the root directory or any subdirectory contains video files, False otherwise. """ for subdir, _, _ in os.walk(root): if _dir_contains_videos(subdir, extensions): @@ -67,28 +64,26 @@ def _contains_videos(root: str, extensions: tuple): return False -def _is_lightly_output_dir(dirname: str): +def _is_lightly_output_dir(dirname: str) -> bool: """Checks whether the directory is a lightly_output directory. Args: - dirname: Directory to check. + dirname: Directory name to check. Returns: - True if dirname is "lightly_outputs" else false. - + True if the directory name is "lightly_outputs", False otherwise. """ return "lightly_outputs" in dirname -def _contains_subdirs(root: str): - """Checks whether directory contains subdirectories. +def _contains_subdirs(root: str) -> bool: + """Checks whether the directory contains subdirectories. Args: root: Root directory path. Returns: - True if root contains subdirectories else false. - + True if the root directory contains subdirectories (excluding "lightly_outputs"), False otherwise. """ with os.scandir(root) as scan_dir: return any(not _is_lightly_output_dir(f.name) for f in scan_dir if f.is_dir()) @@ -96,42 +91,45 @@ def _contains_subdirs(root: str): def _load_dataset_from_folder( root: str, - transform, + transform: Callable[[Any], Any], is_valid_file: Optional[Callable[[str], bool]] = None, - tqdm_args: Dict[str, Any] = None, + tqdm_args: Optional[Dict[str, Any]] = None, num_workers_video_frame_counting: int = 0, -): - """Initializes dataset from folder. +) -> datasets.VisionDataset: + """Initializes a dataset from a folder. + + This function determines the appropriate dataset type based on the contents of the root directory + and returns the corresponding dataset object. Args: - root: (str) Root directory path - transform: (torchvision.transforms.Compose) image transformations + root: Root directory path. + transform: Composed image transformations to be applied to the dataset. + is_valid_file: Optional function to determine valid files. + tqdm_args: Optional dictionary of arguments for tqdm progress bar. + num_workers_video_frame_counting: Number of workers for video frame counting. Returns: - Dataset consisting of images/videos in the root directory. + A dataset object (VideoDataset, ImageFolder, or DatasetFolder) based on the directory contents. Raises: - ValueError: If the specified dataset doesn't exist - + ValueError: If the specified dataset directory doesn't exist or if videos are present + but VideoDataset is not available. """ if not os.path.exists(root): raise ValueError(f"The input directory {root} does not exist!") - # if there is a video in the input directory but we do not have - # the right dependencies, raise a ValueError contains_videos = _contains_videos(root, VIDEO_EXTENSIONS) if contains_videos and not VIDEO_DATASET_AVAILABLE: raise ValueError( f"The input directory {root} contains videos " - "but the VideoDataset is not available. \n" + "but the VideoDataset is not available. " "Make sure you have installed the right " "dependencies. The error from the imported " f"module was: {VIDEO_DATASET_ERRORMSG}" ) if contains_videos: - # root contains videos -> create a video dataset - dataset = VideoDataset( + return VideoDataset( root, extensions=VIDEO_EXTENSIONS, transform=transform, @@ -140,17 +138,13 @@ def _load_dataset_from_folder( num_workers=num_workers_video_frame_counting, ) elif _contains_subdirs(root): - # root contains subdirectories -> create an image folder dataset - dataset = datasets.ImageFolder( + return datasets.ImageFolder( root, transform=transform, is_valid_file=is_valid_file ) else: - # root contains plain images -> create a folder dataset - dataset = DatasetFolder( + return DatasetFolder( root, extensions=IMG_EXTENSIONS, transform=transform, is_valid_file=is_valid_file, ) - - return dataset