Skip to content

Commit

Permalink
Add more PyGrain transforms
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 733319106
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Mar 4, 2025
1 parent 939b2f6 commit 818c38f
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 24 deletions.
12 changes: 11 additions & 1 deletion kauldron/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,20 @@
# DO NOT ADD preprocessing ops here. Instead, add them to `kd.contrib.data`
# ****************************************************************************

# TODO(epot): Should migrate all users to use explicitly `kd.data.tf`
# Transforms here support both PyGrain (`kd.data.py`) and `tf.data`
# (`kd.data.tf`) pipelines.
# For extra PyGrain-only or TF-only transforms, see `kd.data.py` or
# `kd.data.tf`

# ====== Structure transforms ======
from kauldron.data.transforms.base import AddConstants
from kauldron.data.transforms.base import Elements
from kauldron.data.transforms.base import ElementWiseTransform
from kauldron.data.transforms.base import TreeFlattenWithPath
# ====== Random transforms ======
# ====== Map transforms ======
from kauldron.data.transforms.map_transforms import Cast
from kauldron.data.transforms.map_transforms import Gather
from kauldron.data.transforms.map_transforms import Rearrange
from kauldron.data.transforms.map_transforms import Resize
from kauldron.data.transforms.map_transforms import ValueRange
6 changes: 5 additions & 1 deletion kauldron/data/py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@
# *****************************************************************************

# ====== Structure transforms ======
from kauldron.data.transforms.base import AddConstants
from kauldron.data.transforms.base import Elements
from kauldron.data.transforms.base import ElementWiseTransform
from kauldron.data.transforms.base import TreeFlattenWithPath
from kauldron.data.py.transform_utils import Slice
from kauldron.data.py.transform_utils import SliceDataset
# ====== Random transforms ======
# ====== Map transforms ======
from kauldron.data.transforms.map_transforms import Cast
from kauldron.data.transforms.map_transforms import Gather
from kauldron.data.transforms.map_transforms import Rearrange
from kauldron.data.transforms.map_transforms import Resize
from kauldron.data.transforms.map_transforms import ValueRange
17 changes: 9 additions & 8 deletions kauldron/data/py/transform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,31 @@
_MISSING: Any = object()


class Slice:
"""Slice transform.
class SliceDataset:
"""Transform which selects a subset of the dataset.
Transforms that select a subset of the dataset (e.g. to debug or train on
a subset of the data).
Can be useful to debug or train on a subset of the data.
```python
ds = kd.data.py.Tfds(
name='mnist',
split='train',
transforms=[
kd.data.py.Slice(10), # Select ds[:10]
kd.data.py.SliceDataset(10), # Select ds[:10]
],
)
```
"""

# Internally, this is converted to `ds = ds.slice` in `_apply_transform`

def __init__(
self,
start: int | None = _MISSING,
stop: int | None = _MISSING,
step: int | None = _MISSING,
):
# Called as `Slice(stop)`
# Called as `SliceDataset(stop)`
if start is not _MISSING and stop is _MISSING and step is _MISSING:
stop = start
start = _MISSING
Expand Down Expand Up @@ -92,7 +93,7 @@ def map(self, element: Any) -> Any:
def _adapt_for_pygrain(
transform: tr_normalize.Transformation,
) -> grain.Transformation:
if isinstance(transform, (grain.Transformation, Slice)):
if isinstance(transform, (grain.Transformation, SliceDataset)):
return transform
return tr_normalize.adapt_transform(transform, _KD_TO_PYGRAIN_ADAPTERS)

Expand Down Expand Up @@ -122,7 +123,7 @@ def _apply_transform(
ds = ds.filter(tr)
case grain.Batch():
ds = ds.batch(tr.batch_size, drop_remainder=tr.drop_remainder)
case Slice():
case SliceDataset():
ds = ds.slice(tr.slice)
case _:
raise ValueError(f"Unexpected transform type: {tr}")
Expand Down
6 changes: 4 additions & 2 deletions kauldron/data/tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
# *****************************************************************************

# ====== Structure transforms ======
from kauldron.data.transforms.base import AddConstants
from kauldron.data.transforms.base import Elements
from kauldron.data.transforms.base import ElementWiseTransform
from kauldron.data.transforms.base import TreeFlattenWithPath
Expand All @@ -42,12 +43,13 @@
from kauldron.data.tf.random_transforms import RandomCrop
from kauldron.data.tf.random_transforms import RandomFlipLeftRight
# ====== Map transforms ======
from kauldron.data.transforms.map_transforms import Cast
from kauldron.data.transforms.map_transforms import Gather
from kauldron.data.transforms.map_transforms import Rearrange
from kauldron.data.transforms.map_transforms import ValueRange
from kauldron.data.tf.map_transforms import Cast
from kauldron.data.tf.map_transforms import CenterCrop
from kauldron.data.tf.map_transforms import OneHot
# TODO(epot): Unify Resize & ResizeSmall and have better API.
# TODO(epot): Unify Resize & ResizeSmall and have better API. Should be
# replaced by the generic `kd.data.py.Resize`
from kauldron.data.tf.map_transforms import Resize
from kauldron.data.tf.map_transforms import ResizeSmall
13 changes: 1 addition & 12 deletions kauldron/data/tf/map_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Any, Optional
from typing import Optional

import einops
from kauldron.data.tf import transform_utils
Expand All @@ -26,17 +26,6 @@
import tensorflow as tf


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class Cast(base.ElementWiseTransform):
"""Cast an element to the specified dtype."""

dtype: Any

@typechecked
def map_element(self, element: TfArray["*any"]) -> TfArray["*any"]:
return tf.cast(element, self.dtype)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class CenterCrop(base.ElementWiseTransform):
"""Crop the input data to the specified shape from the center.
Expand Down
27 changes: 27 additions & 0 deletions kauldron/data/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,30 @@ def map(self, features):
else:
output[key] = element
return output


@dataclasses.dataclass(frozen=True, eq=True)
class AddConstants(tr_abc.MapTransform):
  """Adds constant elements.

  ```python
  kd.data.AddConstants({
      'my_field': 1.0,
  })
  ```

  Can be used with mixtures when some datasets have missing fields.

  Attributes:
    values: Mapping from feature name to the constant value added to every
      example.
  """

  values: Mapping[str, Any] = flax.core.FrozenDict()

  def map(self, features):
    """Returns a copy of `features` extended with the constant `values`.

    Args:
      features: Mapping of the features of one example. It is left unmodified.

    Returns:
      A new `dict` containing every entry of `features` plus every entry of
      `values`.

    Raises:
      KeyError: If a key of `values` is already present in `features`.
    """
    overwrites = set(self.values.keys()) & set(features.keys())
    if overwrites:
      raise KeyError(
          f"Tried adding key(s) {sorted(overwrites)!r} but"
          " target names already exist. Implicit overwriting is not supported."
          " Please explicitly drop target keys that should be overwritten."
      )
    # Build a fresh dict instead of calling `features.update(...)`: the input
    # example is not mutated behind the caller's back, and immutable mappings
    # (which have no `update`) are accepted too.
    output = dict(features)
    output.update(self.values)
    return output
81 changes: 81 additions & 0 deletions kauldron/data/transforms/map_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,33 @@

import einops
from etils import enp
from etils import epy
import flax.core
import jax
from kauldron.data.transforms import base
from kauldron.typing import XArray, typechecked # pylint: disable=g-multiple-import,g-importing-member
import numpy as np

with epy.lazy_imports():
import tensorflow as tf # pylint: disable=g-import-not-at-top

_FrozenDict = dict if typing.TYPE_CHECKING else flax.core.FrozenDict


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class Cast(base.ElementWiseTransform):
  """Casts each selected element to `dtype`.

  Attributes:
    dtype: Target dtype, forwarded as-is to `tf.cast` (TensorFlow inputs) or to
      the element's own `.astype` (all other inputs).
  """

  dtype: Any

  @typechecked
  def map_element(self, element: XArray["*any"]) -> XArray["*any"]:
    # Non-TF arrays are cast through their own `.astype`; TensorFlow tensors
    # go through `tf.cast`.
    if not enp.lazy.is_tf(element):
      return element.astype(self.dtype)
    return tf.cast(element, self.dtype)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class Rearrange(base.ElementWiseTransform):
"""Einops rearrange on a single element.
Expand Down Expand Up @@ -102,3 +121,65 @@ class Gather(base.ElementWiseTransform):
def map_element(self, element: XArray) -> XArray:
xnp = enp.lazy.get_xnp(element)
return xnp.take(element, self.indices, axis=self.axis)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class Resize(base.ElementWiseTransform):
  """Resizes an image.

  Attributes:
    size: The new size of the image.
    method: The resizing method. If `None`, uses `nearest` for integer inputs
      and `area` for all other inputs. `area` is rejected for NumPy / JAX
      inputs, so those need an explicit method when the dtype is not an
      integer.
    antialias: Whether to use an anti-aliasing filter.
  """

  size: tuple[int, int]
  method: str | jax.image.ResizeMethod | tf.image.ResizeMethod | None = None
  antialias: bool = True

  @typechecked
  def map_element(self, element: XArray["*b h w c"]) -> XArray["*b h2 w2 c"]:
    # Resolve the default method from the input dtype.
    method = self.method
    if method is None:
      method = "nearest" if _is_integer(element.dtype) else "area"

    if enp.lazy.is_tf(element):
      # Collapse all leading batch dimensions into a single one for the
      # resize call, and restore them afterwards.
      leading_dims = tf.shape(element)[:-3]
      flat_imgs = einops.rearrange(element, "... h w c -> (...) h w c")
      flat_imgs = tf.image.resize(
          flat_imgs,
          self.size,
          method=method,
          antialias=self.antialias,
      )
      restored_shape = tf.concat(
          [leading_dims, tf.shape(flat_imgs)[-3:]], axis=0
      )
      return tf.reshape(flat_imgs, restored_shape)

    if not (enp.lazy.is_np(element) or enp.lazy.is_jax(element)):
      raise ValueError(f"Unsupported type: {type(element)}")

    if method == "area":
      raise ValueError(
          "Area resizing is not supported in JAX for float inputs"
          " (Upvote: https://github.com/jax-ml/jax/issues/20098).\n"
          "Please explicitly provide a resizing method."
      )

    # Only the two spatial dimensions change; batch and channel dimensions
    # are carried over from the input shape.
    *leading, _, _, channels = element.shape
    target_shape = (*leading, *self.size, channels)
    # Explicitly set device to avoid `Disallowed host-to-device transfer`
    cpu_device = jax.devices("cpu")[0]
    element = jax.device_put(element, cpu_device)
    return jax.image.resize(
        element,
        target_shape,
        method=method,
        antialias=self.antialias,
    )


def _is_integer(dtype: Any) -> bool:
  """Returns whether `dtype` is an integer dtype.

  The dtype is first converted with `enp.lazy.as_dtype`, then tested against
  `np.integer`.
  """
  converted = enp.lazy.as_dtype(dtype)
  return np.issubdtype(converted, np.integer)
11 changes: 11 additions & 0 deletions kauldron/data/transforms/map_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,14 @@ def test_transforms(

def _as_shape(shape: str) -> tuple[int, ...]:
  """Parses a whitespace-separated string of ints, e.g. `"2 3"` -> `(2, 3)`."""
  return tuple(map(int, shape.split()))


@enp.testing.parametrize_xnp(skip=["torch"])
def test_resize(xnp: enp.NpModule):
  """Checks that `Resize` outputs the requested spatial size."""
  resize = kd.data.py.Resize(key="img", size=(12, 12))
  img = xnp.zeros((5, 5, 3), dtype=xnp.uint8)
  resized = resize.map({"img": img})["img"]
  assert resized.shape == (12, 12, 3)
3 changes: 3 additions & 0 deletions kauldron/typing/type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def typechecked(fn):
@jaxtyping.jaxtyped(typechecker=None)
@functools.wraps(fn)
def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs):
# Hide the function from the traceback. Supported by Pytest and IPython
__tracebackhide__ = True # pylint: disable=unused-variable,invalid-name

if not (TYPECHECKING_ENABLED and _typecheck):
# typechecking disabled globally or locally -> just return fn(...)
return fn(*args, **kwargs)
Expand Down
3 changes: 3 additions & 0 deletions kauldron/utils/train_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def _add_is_training_kwargs(fn: _FnT) -> _FnT:

@_internal.wraps_with_reload(fn)
def decorated(*args, is_training_property: bool | None = None, **kwargs): # pylint: disable=redefined-outer-name
# Hide the function from the traceback. Supported by Pytest and IPython
__tracebackhide__ = True # pylint: disable=unused-variable,invalid-name

if is_training_property is not None:
cm = _set_train_property(is_training_property)
else:
Expand Down

0 comments on commit 818c38f

Please sign in to comment.