Skip to content

Commit

Permalink
simplify list_data_collate and collate_meta_tensor (#7165)
Browse files Browse the repository at this point in the history
Fixes #5917


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
KumoLiu and ericspod authored Oct 26, 2023
1 parent 85243f5 commit fa15eec
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@
issequenceiterable,
look_up_option,
optional_import,
pytorch_after,
)

if pytorch_after(1, 13):
# import private code for reuse purposes, comment in case things break in the future
from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map
pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")
Expand Down Expand Up @@ -444,22 +448,31 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
return data


def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
"""
Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
and so should not be used as a collate function directly in dataloaders.
"""
collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate
collated = collate_fn(batch) # type: ignore
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
if common_:
meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts]
collated.meta = default_collate(meta_dicts)
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
collated.is_batch = True
return collated


def collate_meta_tensor(batch):
"""collate a sequence of meta tensor sequences/dictionaries into
a single batched metatensor or a dictionary of batched metatensor"""
if not isinstance(batch, Sequence):
raise NotImplementedError()
elem_0 = first(batch)
if isinstance(elem_0, MetaObj):
collated = default_collate(batch)
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
if common_:
meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts]
collated.meta = default_collate(meta_dicts)
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
collated.is_batch = True
return collated
return collate_meta_tensor_fn(batch)
if isinstance(elem_0, Mapping):
return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0}
if isinstance(elem_0, (tuple, list)):
Expand All @@ -479,9 +492,16 @@ def list_data_collate(batch: Sequence):
Need to use this collate if apply some transforms that can generate batch data.
"""

if pytorch_after(1, 13):
# needs to go here to avoid circular import
from monai.data.meta_tensor import MetaTensor

default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
elem = batch[0]
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
key = None
collate_fn = default_collate if pytorch_after(1, 13) else collate_meta_tensor
try:
if config.USE_META_DICT:
data = pickle_operations(data) # bc 0.9.0
Expand All @@ -490,9 +510,9 @@ def list_data_collate(batch: Sequence):
for k in elem:
key = k
data_for_batch = [d[key] for d in data]
ret[key] = collate_meta_tensor(data_for_batch)
ret[key] = collate_fn(data_for_batch)
else:
ret = collate_meta_tensor(data)
ret = collate_fn(data)
return ret
except RuntimeError as re:
re_str = str(re)
Expand Down

0 comments on commit fa15eec

Please sign in to comment.