Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update pydicom reader to enable gpu load #8283

Merged
merged 16 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 170 additions & 49 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ class PydicomReader(ImageReader):
If provided, only the matched files will be included. For example, to include the file name
"image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
Default is False. CuPy and Kvikio are required for this option.
In practical use, it's recommended to add a warm up call before the actual loading.
A related tutorial will be prepared in the future, and the document will be updated accordingly.
kwargs: additional args for `pydicom.dcmread` API. more details about available args:
https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
If the `get_data` function will be called
Expand All @@ -434,6 +438,7 @@ def __init__(
prune_metadata: bool = True,
label_dict: dict | None = None,
fname_regex: str = "",
to_gpu: bool = False,
**kwargs,
):
super().__init__()
Expand All @@ -444,6 +449,33 @@ def __init__(
self.prune_metadata = prune_metadata
self.label_dict = label_dict
self.fname_regex = fname_regex
if to_gpu and (not has_cp or not has_kvikio):
warnings.warn(
"PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
)
to_gpu = False

if to_gpu:
self.warmup_kvikio()

self.to_gpu = to_gpu

def warmup_kvikio(self):
    """
    Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
    This can accelerate the data loading process when `to_gpu` is set to True.

    A small CuPy array is written to a temporary file through `kvikio.CuFile` and
    read back into an empty CuPy array. Nothing is done when CuPy or Kvikio is
    not available.
    """
    if not (has_cp and has_kvikio):
        return

    a = cp.arange(100)
    with tempfile.NamedTemporaryFile() as tmp_file:
        tmp_file_name = tmp_file.name
        # context managers guarantee that both the write and the read handle are
        # closed, even if the warm-up I/O raises
        with kvikio.CuFile(tmp_file_name, "w") as f:
            f.write(a)

        b = cp.empty_like(a)
        with kvikio.CuFile(tmp_file_name, "r") as f:
            f.read(b)

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
"""
Expand Down Expand Up @@ -475,12 +507,15 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_ = []

filenames: Sequence[PathLike] = ensure_tuple(data)
self.filenames = list(filenames)
kwargs_ = self.kwargs.copy()
if self.to_gpu:
kwargs["defer_size"] = "100 KB"
kwargs_.update(kwargs)

self.has_series = False

for name in filenames:
for i, name in enumerate(filenames):
name = f"{name}"
if Path(name).is_dir():
# read DICOM series
Expand All @@ -489,20 +524,28 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
else:
series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
slices = []
loaded_slc_names = []
for slc in series_slcs:
try:
slices.append(pydicom.dcmread(fp=slc, **kwargs_))
loaded_slc_names.append(slc)
except pydicom.errors.InvalidDicomError as e:
warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2)
img_.append(slices if len(slices) > 1 else slices[0])
if len(slices) > 1:
self.has_series = True
img_.append(slices)
self.filenames[i] = loaded_slc_names # type: ignore
else:
img_.append(slices[0]) # type: ignore
self.filenames[i] = loaded_slc_names[0] # type: ignore
else:
ds = pydicom.dcmread(fp=name, **kwargs_)
img_.append(ds)
return img_ if len(filenames) > 1 else img_[0]
img_.append(ds) # type: ignore
if len(filenames) == 1:
return img_[0]
return img_

def _combine_dicom_series(self, data: Iterable):
def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):
"""
Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new
dimension as the last dimension.
Expand All @@ -522,28 +565,27 @@ def _combine_dicom_series(self, data: Iterable):
"""
slices: list = []
# for a dicom series
for slc_ds in data:
for slc_ds, filename in zip(data, filenames):
if hasattr(slc_ds, "InstanceNumber"):
slices.append(slc_ds)
slices.append((slc_ds, filename))
else:
warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.")
slices = sorted(slices, key=lambda s: s.InstanceNumber)

warnings.warn(f"slice: {filename} does not have InstanceNumber tag, skip it.")
slices = sorted(slices, key=lambda s: s[0].InstanceNumber)
if len(slices) == 0:
raise ValueError("the input does not have valid slices.")

first_slice = slices[0]
first_slice, first_filename = slices[0]
average_distance = 0.0
first_array = self._get_array_data(first_slice)
first_array = self._get_array_data(first_slice, first_filename)
shape = first_array.shape
spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0])
spacing = getattr(first_slice, "PixelSpacing", [1.0] * len(shape))
prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2]
stack_array = [first_array]
for idx in range(1, len(slices)):
slc_array = self._get_array_data(slices[idx])
slc_array = self._get_array_data(slices[idx][0], slices[idx][1])
slc_shape = slc_array.shape
slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0))
slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2]
slc_spacing = getattr(slices[idx][0], "PixelSpacing", [1.0] * len(shape))
slc_pos = getattr(slices[idx][0], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2]
if not np.allclose(slc_spacing, spacing):
warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.")
if shape != slc_shape:
Expand All @@ -555,11 +597,14 @@ def _combine_dicom_series(self, data: Iterable):
if len(slices) > 1:
average_distance /= len(slices) - 1
spacing.append(average_distance)
stack_array = np.stack(stack_array, axis=-1)
if self.to_gpu:
stack_array = cp.stack(stack_array, axis=-1)
else:
stack_array = np.stack(stack_array, axis=-1)
stack_metadata = self._get_meta_dict(first_slice)
stack_metadata["spacing"] = np.asarray(spacing)
if hasattr(slices[-1], "ImagePositionPatient"):
stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient)
if hasattr(slices[-1][0], "ImagePositionPatient"):
stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1][0].ImagePositionPatient)
stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),)
else:
stack_array = stack_array[0]
Expand Down Expand Up @@ -597,29 +642,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
if self.has_series is True:
# a list, all objects within a list belong to one dicom series
if not isinstance(data[0], list):
dicom_data.append(self._combine_dicom_series(data))
# input is a dir, self.filenames is a list of list of filenames
dicom_data.append(self._combine_dicom_series(data, self.filenames[0])) # type: ignore
# a list of list, each inner list represents a dicom series
else:
for series in data:
dicom_data.append(self._combine_dicom_series(series))
for i, series in enumerate(data):
dicom_data.append(self._combine_dicom_series(series, self.filenames[i])) # type: ignore
else:
# a single pydicom dataset object
if not isinstance(data, list):
data = [data]
for d in data:
for i, d in enumerate(data):
if hasattr(d, "SegmentSequence"):
data_array, metadata = self._get_seg_data(d)
data_array, metadata = self._get_seg_data(d, self.filenames[i])
else:
data_array = self._get_array_data(d)
data_array = self._get_array_data(d, self.filenames[i])
metadata = self._get_meta_dict(d)
metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape
dicom_data.append((data_array, metadata))

# TODO: the actual type is list[np.ndarray | cp.ndarray]
# should figure out how to define correct types without having cupy not found error
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
img_array: list[np.ndarray] = []
compatible_meta: dict = {}

for data_array, metadata in ensure_tuple(dicom_data):
img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array))
if self.swap_ij:
data_array = cp.swapaxes(data_array, 0, 1) if self.to_gpu else np.swapaxes(data_array, 0, 1)
img_array.append(cp.ascontiguousarray(data_array) if self.to_gpu else np.ascontiguousarray(data_array))
affine = self._get_affine(metadata, self.affine_lps_to_ras)
metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS
if self.swap_ij:
Expand All @@ -641,7 +692,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:

_copy_compatible_dict(metadata, compatible_meta)

return _stack_images(img_array, compatible_meta), compatible_meta
return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta

def _get_meta_dict(self, img) -> dict:
"""
Expand Down Expand Up @@ -713,7 +764,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True):
affine = orientation_ras_lps(affine)
return affine

def _get_frame_data(self, img) -> Iterator:
def _get_frame_data(self, img, filename, array_data) -> Iterator:
"""
yield frames and description from the segmentation image.
This function is adapted from Highdicom:
Expand Down Expand Up @@ -751,48 +802,54 @@ def _get_frame_data(self, img) -> Iterator:
"""

if not hasattr(img, "PerFrameFunctionalGroupsSequence"):
raise NotImplementedError(
f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required."
)
raise NotImplementedError(f"To read dicom seg: {filename}, 'PerFrameFunctionalGroupsSequence' is required.")

frame_seg_nums = []
for f in img.PerFrameFunctionalGroupsSequence:
if not hasattr(f, "SegmentIdentificationSequence"):
raise NotImplementedError(
f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame."
f"To read dicom seg: {filename}, 'SegmentIdentificationSequence' is required for each frame."
)
frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber))

frame_seg_nums_arr = np.array(frame_seg_nums)
frame_seg_nums_arr = cp.array(frame_seg_nums) if self.to_gpu else np.array(frame_seg_nums)

seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence}

for i in np.unique(frame_seg_nums_arr):
indices = np.where(frame_seg_nums_arr == i)[0]
yield (img.pixel_array[indices, ...], seg_descriptions[i])
for i in np.unique(frame_seg_nums_arr) if not self.to_gpu else cp.unique(frame_seg_nums_arr):
indices = np.where(frame_seg_nums_arr == i)[0] if not self.to_gpu else cp.where(frame_seg_nums_arr == i)[0]
yield (array_data[indices, ...], seg_descriptions[i])

def _get_seg_data(self, img):
def _get_seg_data(self, img, filename):
"""
Get the array data and metadata of the segmentation image.

Args:
img: a Pydicom dataset object that has attribute "SegmentSequence".
filename: the file path of the image.

"""

metadata = self._get_meta_dict(img)
n_classes = len(img.SegmentSequence)
spatial_shape = list(img.pixel_array.shape)
array_data = self._get_array_data(img, filename)
spatial_shape = list(array_data.shape)
spatial_shape[0] = spatial_shape[0] // n_classes

if self.label_dict is not None:
metadata["labels"] = self.label_dict
all_segs = np.zeros([*spatial_shape, len(self.label_dict)])
if self.to_gpu:
all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
else:
all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
else:
metadata["labels"] = {}
all_segs = np.zeros([*spatial_shape, n_classes])
if self.to_gpu:
all_segs = cp.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)
else:
all_segs = np.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)

for i, (frames, description) in enumerate(self._get_frame_data(img)):
for i, (frames, description) in enumerate(self._get_frame_data(img, filename, array_data)):
segment_label = getattr(description, "SegmentLabel", f"label_{i}")
class_name = getattr(description, "SegmentDescription", segment_label)
if class_name not in metadata["labels"].keys():
Expand Down Expand Up @@ -840,19 +897,79 @@ def _get_seg_data(self, img):

return all_segs, metadata

def _get_array_data(self, img):
def _get_array_data_from_gpu(self, img, filename):
    """
    Get the raw array data of the image. This function is used when `to_gpu` is set to True.

    The pixel data is read from `filename` straight into a CuPy buffer with Kvikio, starting at
    the file offset of the Pixel Data element (7FE0,0010). The raw bytes are then viewed with a
    dtype derived from `BitsAllocated` and `PixelRepresentation`.

    Args:
        img: a Pydicom dataset object.
        filename: the file path of the image.

    Returns:
        A CuPy array of shape `(Rows, Columns)`, or `(NumberOfFrames, Rows, Columns)` when the
        dataset has more than one frame. If `Rows`, `Columns` or `BitsAllocated` is missing,
        `img.pixel_array` is decoded on CPU and copied to the GPU as a CuPy array.

    Raises:
        ValueError: if the CPU fallback is needed but `img` has no `pixel_array`, if
            `BitsAllocated` is not 8, 16 or 32, or if `img` has no Pixel Data element.

    """
    rows = getattr(img, "Rows", None)
    columns = getattr(img, "Columns", None)
    bits_allocated = getattr(img, "BitsAllocated", None)
    samples_per_pixel = getattr(img, "SamplesPerPixel", 1)
    number_of_frames = getattr(img, "NumberOfFrames", 1)
    pixel_representation = getattr(img, "PixelRepresentation", 1)

    if rows is None or columns is None or bits_allocated is None:
        warnings.warn(
            f"dicom data: {filename} does not have Rows, Columns or BitsAllocated, falling back to CPU loading."
        )

        if not hasattr(img, "pixel_array"):
            raise ValueError(f"dicom data: {filename} does not have pixel_array.")
        # move the CPU-decoded array to the GPU so that this function always returns a CuPy
        # array; callers on the GPU path apply CuPy operations to the result
        return cp.asarray(img.pixel_array)

    # PixelRepresentation == 1 selects the signed dtype, anything else the unsigned one
    if bits_allocated == 8:
        dtype = cp.int8 if pixel_representation == 1 else cp.uint8
    elif bits_allocated == 16:
        dtype = cp.int16 if pixel_representation == 1 else cp.uint16
    elif bits_allocated == 32:
        dtype = cp.int32 if pixel_representation == 1 else cp.uint32
    else:
        raise ValueError("Unsupported BitsAllocated value")

    bytes_per_pixel = bits_allocated // 8
    total_pixels = rows * columns * samples_per_pixel * number_of_frames
    expected_pixel_data_length = total_pixels * bytes_per_pixel

    pixel_data_tag = pydicom.tag.Tag(0x7FE0, 0x0010)
    if pixel_data_tag not in img:
        raise ValueError(f"dicom data: {filename} does not have pixel data.")

    # file position of the Pixel Data value; keep_deferred avoids loading the deferred value
    offset = img.get_item(pixel_data_tag, keep_deferred=True).value_tell

    with kvikio.CuFile(filename, "r") as f:
        # read the raw bytes; they are reinterpreted with the real dtype below
        buffer = cp.empty(expected_pixel_data_length, dtype=cp.int8)
        f.read(buffer, expected_pixel_data_length, offset)

    # NOTE(review): samples_per_pixel is counted in the buffer length but is not part of the
    # target shape, so the reshape presumably fails for SamplesPerPixel > 1 - confirm whether
    # multi-sample (e.g. RGB) data should be supported here.
    new_shape = (number_of_frames, rows, columns) if number_of_frames > 1 else (rows, columns)
    data = buffer.view(dtype).reshape(new_shape)

    return data

def _get_array_data(self, img, filename):
"""
Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
will be rescaled. The output data has the dtype np.float32 if the rescaling is applied.
will be rescaled. The output data has the dtype float32 if the rescaling is applied.

Args:
img: a Pydicom dataset object.
filename: the file path of the image.

"""
# process Dicom series
if not hasattr(img, "pixel_array"):
raise ValueError(f"dicom data: {img.filename} does not have pixel_array.")
data = img.pixel_array

if self.to_gpu:
data = self._get_array_data_from_gpu(img, filename)
else:
if not hasattr(img, "pixel_array"):
raise ValueError(f"dicom data: {filename} does not have pixel_array.")
data = img.pixel_array

slope, offset = 1.0, 0.0
rescale_flag = False
Expand All @@ -862,8 +979,14 @@ def _get_array_data(self, img):
if hasattr(img, "RescaleIntercept"):
offset = img.RescaleIntercept
rescale_flag = True

if rescale_flag:
data = data.astype(np.float32) * slope + offset
if self.to_gpu:
slope = cp.asarray(slope, dtype=cp.float32)
offset = cp.asarray(offset, dtype=cp.float32)
data = data.astype(cp.float32) * slope + offset
else:
data = data.astype(np.float32) * slope + offset

return data

Expand All @@ -884,8 +1007,6 @@ class NibabelReader(ImageReader):
Default is False. CuPy and Kvikio are required for this option.
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
In practical use, it's recommended to add a warm up call before the actual loading.
A related tutorial will be prepared in the future, and the document will be updated accordingly.
kwargs: additional args for `nibabel.load` API. more details about available args:
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py

Expand Down
Loading
Loading