diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 5bc38f69ea..003ec2cf0b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -418,6 +418,10 @@ class PydicomReader(ImageReader): If provided, only the matched files will be included. For example, to include the file name "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. + to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading. + Default is False. CuPy and Kvikio are required for this option. + In practical use, it's recommended to add a warm up call before the actual loading. + A related tutorial will be prepared in the future, and the document will be updated accordingly. kwargs: additional args for `pydicom.dcmread` API. more details about available args: https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html If the `get_data` function will be called @@ -434,6 +438,7 @@ def __init__( prune_metadata: bool = True, label_dict: dict | None = None, fname_regex: str = "", + to_gpu: bool = False, **kwargs, ): super().__init__() @@ -444,6 +449,33 @@ def __init__( self.prune_metadata = prune_metadata self.label_dict = label_dict self.fname_regex = fname_regex + if to_gpu and (not has_cp or not has_kvikio): + warnings.warn( + "PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading." + ) + to_gpu = False + + if to_gpu: + self.warmup_kvikio() + + self.to_gpu = to_gpu + + def warmup_kvikio(self): + """ + Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc. + This can accelerate the data loading process when `to_gpu` is set to True. + """ + if has_cp and has_kvikio: + a = cp.arange(100) + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_name = tmp_file.name + f = kvikio.CuFile(tmp_file_name, "w") + f.write(a) + f.close() + + b = cp.empty_like(a) + f = kvikio.CuFile(tmp_file_name, "r") + f.read(b) def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ @@ -475,12 +507,15 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) + self.filenames = list(filenames) kwargs_ = self.kwargs.copy() + if self.to_gpu: + kwargs["defer_size"] = "100 KB" kwargs_.update(kwargs) self.has_series = False - for name in filenames: + for i, name in enumerate(filenames): name = f"{name}" if Path(name).is_dir(): # read DICOM series @@ -489,20 +524,28 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): else: series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] slices = [] + loaded_slc_names = [] for slc in series_slcs: try: slices.append(pydicom.dcmread(fp=slc, **kwargs_)) + loaded_slc_names.append(slc) except pydicom.errors.InvalidDicomError as e: warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) - img_.append(slices if len(slices) > 1 else slices[0]) if len(slices) > 1: self.has_series = True + img_.append(slices) + self.filenames[i] = loaded_slc_names # type: ignore + else: + img_.append(slices[0]) # type: ignore + self.filenames[i] = loaded_slc_names[0] # type: ignore else: ds = pydicom.dcmread(fp=name, **kwargs_) - img_.append(ds) - return img_ if len(filenames) > 1 else img_[0] + img_.append(ds) # type: ignore + if len(filenames) == 1: + return img_[0] + return img_ - def _combine_dicom_series(self, data: Iterable): + def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]): """ Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new dimension as the last dimension. @@ -522,28 +565,27 @@ def _combine_dicom_series(self, data: Iterable): """ slices: list = [] # for a dicom series - for slc_ds in data: + for slc_ds, filename in zip(data, filenames): if hasattr(slc_ds, "InstanceNumber"): - slices.append(slc_ds) + slices.append((slc_ds, filename)) else: - warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.") - slices = sorted(slices, key=lambda s: s.InstanceNumber) - + warnings.warn(f"slice: {filename} does not have InstanceNumber tag, skip it.") + slices = sorted(slices, key=lambda s: s[0].InstanceNumber) if len(slices) == 0: raise ValueError("the input does not have valid slices.") - first_slice = slices[0] + first_slice, first_filename = slices[0] average_distance = 0.0 - first_array = self._get_array_data(first_slice) + first_array = self._get_array_data(first_slice, first_filename) shape = first_array.shape - spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0]) + spacing = getattr(first_slice, "PixelSpacing", [1.0] * len(shape)) prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2] stack_array = [first_array] for idx in range(1, len(slices)): - slc_array = self._get_array_data(slices[idx]) + slc_array = self._get_array_data(slices[idx][0], slices[idx][1]) slc_shape = slc_array.shape - slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0)) - slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] + slc_spacing = getattr(slices[idx][0], "PixelSpacing", [1.0] * len(shape)) + slc_pos = getattr(slices[idx][0], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] if not np.allclose(slc_spacing, spacing): warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.") if shape != slc_shape: @@ -555,11 +597,14 @@ def _combine_dicom_series(self, data: Iterable): if len(slices) > 1: average_distance /= len(slices) - 1 spacing.append(average_distance) - stack_array = np.stack(stack_array, axis=-1) + if self.to_gpu: + stack_array = cp.stack(stack_array, axis=-1) + else: + stack_array = np.stack(stack_array, axis=-1) stack_metadata = self._get_meta_dict(first_slice) stack_metadata["spacing"] = np.asarray(spacing) - if hasattr(slices[-1], "ImagePositionPatient"): - stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient) + if hasattr(slices[-1][0], "ImagePositionPatient"): + stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1][0].ImagePositionPatient) stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),) else: stack_array = stack_array[0] @@ -597,29 +642,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: if self.has_series is True: # a list, all objects within a list belong to one dicom series if not isinstance(data[0], list): - dicom_data.append(self._combine_dicom_series(data)) + # input is a dir, self.filenames is a list of list of filenames + dicom_data.append(self._combine_dicom_series(data, self.filenames[0])) # type: ignore # a list of list, each inner list represents a dicom series else: - for series in data: - dicom_data.append(self._combine_dicom_series(series)) + for i, series in enumerate(data): + dicom_data.append(self._combine_dicom_series(series, self.filenames[i])) # type: ignore else: # a single pydicom dataset object if not isinstance(data, list): data = [data] - for d in data: + for i, d in enumerate(data): if hasattr(d, "SegmentSequence"): - data_array, metadata = self._get_seg_data(d) + data_array, metadata = self._get_seg_data(d, self.filenames[i]) else: - data_array = self._get_array_data(d) + data_array = self._get_array_data(d, self.filenames[i]) metadata = self._get_meta_dict(d) metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape dicom_data.append((data_array, metadata)) + # TODO: the actual type is list[np.ndarray | cp.ndarray] + # should figure out how to define correct types without having cupy not found error + # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 img_array: list[np.ndarray] = [] compatible_meta: dict = {} for data_array, metadata in ensure_tuple(dicom_data): - img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array)) + if self.swap_ij: + data_array = cp.swapaxes(data_array, 0, 1) if self.to_gpu else np.swapaxes(data_array, 0, 1) + img_array.append(cp.ascontiguousarray(data_array) if self.to_gpu else np.ascontiguousarray(data_array)) affine = self._get_affine(metadata, self.affine_lps_to_ras) metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS if self.swap_ij: @@ -641,7 +692,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: _copy_compatible_dict(metadata, compatible_meta) - return _stack_images(img_array, compatible_meta), compatible_meta + return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta def _get_meta_dict(self, img) -> dict: """ @@ -713,7 +764,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True): affine = orientation_ras_lps(affine) return affine - def _get_frame_data(self, img) -> Iterator: + def _get_frame_data(self, img, filename, array_data) -> Iterator: """ yield frames and description from the segmentation image. This function is adapted from Highdicom: @@ -751,48 +802,54 @@ def _get_frame_data(self, img) -> Iterator: """ if not hasattr(img, "PerFrameFunctionalGroupsSequence"): - raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required." - ) + raise NotImplementedError(f"To read dicom seg: {filename}, 'PerFrameFunctionalGroupsSequence' is required.") frame_seg_nums = [] for f in img.PerFrameFunctionalGroupsSequence: if not hasattr(f, "SegmentIdentificationSequence"): raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame." + f"To read dicom seg: {filename}, 'SegmentIdentificationSequence' is required for each frame." ) frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber)) - frame_seg_nums_arr = np.array(frame_seg_nums) + frame_seg_nums_arr = cp.array(frame_seg_nums) if self.to_gpu else np.array(frame_seg_nums) seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence} - for i in np.unique(frame_seg_nums_arr): - indices = np.where(frame_seg_nums_arr == i)[0] - yield (img.pixel_array[indices, ...], seg_descriptions[i]) + for i in np.unique(frame_seg_nums_arr) if not self.to_gpu else cp.unique(frame_seg_nums_arr): + indices = np.where(frame_seg_nums_arr == i)[0] if not self.to_gpu else cp.where(frame_seg_nums_arr == i)[0] + yield (array_data[indices, ...], seg_descriptions[i]) - def _get_seg_data(self, img): + def _get_seg_data(self, img, filename): """ Get the array data and metadata of the segmentation image. Aegs: img: a Pydicom dataset object that has attribute "SegmentSequence". + filename: the file path of the image. """ metadata = self._get_meta_dict(img) n_classes = len(img.SegmentSequence) - spatial_shape = list(img.pixel_array.shape) + array_data = self._get_array_data(img, filename) + spatial_shape = list(array_data.shape) spatial_shape[0] = spatial_shape[0] // n_classes if self.label_dict is not None: metadata["labels"] = self.label_dict - all_segs = np.zeros([*spatial_shape, len(self.label_dict)]) + if self.to_gpu: + all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype) + else: + all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype) else: metadata["labels"] = {} - all_segs = np.zeros([*spatial_shape, n_classes]) + if self.to_gpu: + all_segs = cp.zeros([*spatial_shape, n_classes], dtype=array_data.dtype) + else: + all_segs = np.zeros([*spatial_shape, n_classes], dtype=array_data.dtype) - for i, (frames, description) in enumerate(self._get_frame_data(img)): + for i, (frames, description) in enumerate(self._get_frame_data(img, filename, array_data)): segment_label = getattr(description, "SegmentLabel", f"label_{i}") class_name = getattr(description, "SegmentDescription", segment_label) if class_name not in metadata["labels"].keys(): @@ -840,19 +897,79 @@ def _get_seg_data(self, img): return all_segs, metadata - def _get_array_data(self, img): + def _get_array_data_from_gpu(self, img, filename): + """ + Get the raw array data of the image. This function is used when `to_gpu` is set to True. + + Args: + img: a Pydicom dataset object. + filename: the file path of the image. + + """ + rows = getattr(img, "Rows", None) + columns = getattr(img, "Columns", None) + bits_allocated = getattr(img, "BitsAllocated", None) + samples_per_pixel = getattr(img, "SamplesPerPixel", 1) + number_of_frames = getattr(img, "NumberOfFrames", 1) + pixel_representation = getattr(img, "PixelRepresentation", 1) + + if rows is None or columns is None or bits_allocated is None: + warnings.warn( + f"dicom data: {filename} does not have Rows, Columns or BitsAllocated, falling back to CPU loading." + ) + + if not hasattr(img, "pixel_array"): + raise ValueError(f"dicom data: {filename} does not have pixel_array.") + data = img.pixel_array + + return data + + if bits_allocated == 8: + dtype = cp.int8 if pixel_representation == 1 else cp.uint8 + elif bits_allocated == 16: + dtype = cp.int16 if pixel_representation == 1 else cp.uint16 + elif bits_allocated == 32: + dtype = cp.int32 if pixel_representation == 1 else cp.uint32 + else: + raise ValueError("Unsupported BitsAllocated value") + + bytes_per_pixel = bits_allocated // 8 + total_pixels = rows * columns * samples_per_pixel * number_of_frames + expected_pixel_data_length = total_pixels * bytes_per_pixel + + pixel_data_tag = pydicom.tag.Tag(0x7FE0, 0x0010) + if pixel_data_tag not in img: + raise ValueError(f"dicom data: {filename} does not have pixel data.") + + offset = img.get_item(pixel_data_tag, keep_deferred=True).value_tell + + with kvikio.CuFile(filename, "r") as f: + buffer = cp.empty(expected_pixel_data_length, dtype=cp.int8) + f.read(buffer, expected_pixel_data_length, offset) + + new_shape = (number_of_frames, rows, columns) if number_of_frames > 1 else (rows, columns) + data = buffer.view(dtype).reshape(new_shape) + + return data + + def _get_array_data(self, img, filename): """ Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data - will be rescaled. The output data has the dtype np.float32 if the rescaling is applied. + will be rescaled. The output data has the dtype float32 if the rescaling is applied. Args: img: a Pydicom dataset object. + filename: the file path of the image. """ # process Dicom series - if not hasattr(img, "pixel_array"): - raise ValueError(f"dicom data: {img.filename} does not have pixel_array.") - data = img.pixel_array + + if self.to_gpu: + data = self._get_array_data_from_gpu(img, filename) + else: + if not hasattr(img, "pixel_array"): + raise ValueError(f"dicom data: {filename} does not have pixel_array.") + data = img.pixel_array slope, offset = 1.0, 0.0 rescale_flag = False @@ -862,8 +979,14 @@ def _get_array_data(self, img): if hasattr(img, "RescaleIntercept"): offset = img.RescaleIntercept rescale_flag = True + if rescale_flag: - data = data.astype(np.float32) * slope + offset + if self.to_gpu: + slope = cp.asarray(slope, dtype=cp.float32) + offset = cp.asarray(offset, dtype=cp.float32) + data = data.astype(cp.float32) * slope + offset + else: + data = data.astype(np.float32) * slope + offset return data @@ -884,8 +1007,6 @@ class NibabelReader(ImageReader): Default is False. CuPy and Kvikio are required for this option. Note: For compressed NIfTI files, some operations may still be performed on CPU memory, and the acceleration may not be significant. In some cases, it may be slower than loading on CPU. - In practical use, it's recommended to add a warm up call before the actual loading. - A related tutorial will be prepared in the future, and the document will be updated accordingly. kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 498b9972b4..07acf7c179 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -168,6 +168,16 @@ def get_data(self, _obj): # test reader consistency between PydicomReader and ITKReader on dicom data TEST_CASE_22 = ["tests/testing_data/CT_DICOM"] +# test pydicom gpu reader +TEST_CASE_GPU_5 = [{"reader": "PydicomReader", "to_gpu": True}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)] + +TEST_CASE_GPU_6 = [ + {"reader": "PydicomReader", "ensure_channel_first": True, "force": True, "to_gpu": True}, + "tests/testing_data/CT_DICOM", + (16, 16, 4), + (1, 16, 16, 4), +] + TESTS_META = [] for track_meta in (False, True): TESTS_META.append([{}, (128, 128, 128), track_meta]) @@ -242,16 +252,17 @@ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape): @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) def test_itk_reader(self, input_param, filenames, expected_shape): - test_image = np.random.rand(128, 128, 128) + test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy() + print("Test image value range:", test_image.min(), test_image.max()) with tempfile.TemporaryDirectory() as tempdir: for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) - itk_np_view = itk.image_view_from_array(test_image) - itk.imwrite(itk_np_view, filenames[i]) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) result = LoadImage(image_only=True, **input_param)(filenames) - self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) - diag = torch.as_tensor(np.diag([-1, -1, 1, 1])) - np.testing.assert_allclose(result.affine, diag) + ext = "".join(Path(name).suffixes) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext)) + self.assertEqual(result.meta["space"], "RAS") + assert_allclose(result.affine, torch.eye(4)) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_19, TEST_CASE_20, TEST_CASE_21]) @@ -271,6 +282,26 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e ) self.assertTupleEqual(result.shape, expected_np_shape) + @SkipIfNoModule("pydicom") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + @parameterized.expand([TEST_CASE_GPU_5, TEST_CASE_GPU_6]) + def test_pydicom_gpu_reader(self, input_param, filenames, expected_shape, expected_np_shape): + result = LoadImage(image_only=True, **input_param)(filenames) + self.assertEqual(result.meta["filename_or_obj"], f"{Path(filenames)}") + assert_allclose( + result.affine, + torch.tensor( + [ + [-0.488281, 0.0, 0.0, 125.0], + [0.0, -0.488281, 0.0, 128.100006], + [0.0, 0.0, 68.33333333, -99.480003], + [0.0, 0.0, 0.0, 1.0], + ] + ), + ) + self.assertTupleEqual(result.shape, expected_np_shape) + def test_no_files(self): with self.assertRaisesRegex(RuntimeError, "list index out of range"): # fname_regex excludes everything LoadImage(image_only=True, reader="PydicomReader", fname_regex=r"^(?!.*).*")("tests/testing_data/CT_DICOM") @@ -317,6 +348,21 @@ def test_dicom_reader_consistency(self, filenames): np.testing.assert_allclose(pydicom_result, itk_result) np.testing.assert_allclose(pydicom_result.affine, itk_result.affine) + @SkipIfNoModule("pydicom") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + @parameterized.expand([TEST_CASE_22]) + def test_pydicom_reader_gpu_cpu_consistency(self, filenames): + gpu_param = {"reader": "PydicomReader", "to_gpu": True} + cpu_param = {"reader": "PydicomReader", "to_gpu": False} + for affine_flag in [True, False]: + gpu_param["affine_lps_to_ras"] = affine_flag + cpu_param["affine_lps_to_ras"] = affine_flag + gpu_result = LoadImage(image_only=True, **gpu_param)(filenames) + cpu_result = LoadImage(image_only=True, **cpu_param)(filenames) + np.testing.assert_allclose(gpu_result.cpu(), cpu_result) + np.testing.assert_allclose(gpu_result.affine.cpu(), cpu_result.affine) + def test_dicom_reader_consistency_single(self): itk_param = {"reader": "ITKReader"} pydicom_param = {"reader": "PydicomReader"}