From ff28c354af4692f6b8cc2b7662f5d010479bd6f1 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Wed, 20 Mar 2024 17:09:39 -0700 Subject: [PATCH] Add `JaxLayer` and `FlaxLayer` to wrap JAX/Flax modules as layers. (#19342) - `JaxLayer` can wrap any JAX model defined by a function. - `FlaxLayer` is a subclass of `JaxLayer` that can wrap a Flax module. --- keras/utils/jax_layer.py | 667 ++++++++++++++++++++++++++++++++++ keras/utils/jax_layer_test.py | 657 +++++++++++++++++++++++++++++++++ requirements.txt | 1 + 3 files changed, 1325 insertions(+) create mode 100644 keras/utils/jax_layer.py create mode 100644 keras/utils/jax_layer_test.py diff --git a/keras/utils/jax_layer.py b/keras/utils/jax_layer.py new file mode 100644 index 00000000000..871eb6106ed --- /dev/null +++ b/keras/utils/jax_layer.py @@ -0,0 +1,667 @@ +import inspect + +import jax +import numpy as np + +from keras import backend +from keras.api_export import keras_export +from keras.layers.layer import Layer +from keras.saving import serialization_lib +from keras.utils import shape_utils +from keras.utils import tracking + + +@keras_export("keras.layers.JaxLayer") +class JaxLayer(Layer): + """Keras Layer that wraps a JAX model. + + This layer enables the use of JAX components within Keras when using JAX as + the backend for Keras. + + ## Model function + + This layer accepts JAX models in the form of a function, `call_fn`, which + must take the following arguments with these exact names: + + - `params`: trainable parameters of the model. + - `state` (*optional*): non-trainable state of the model. Can be omitted if + the model has no non-trainable state. + - `rng` (*optional*): a `jax.random.PRNGKey` instance. Can be omitted if the + model does not need RNGs, neither during training nor during inference. + - `inputs`: inputs to the model, a JAX array or a `PyTree` of arrays. 
+ - `training` (*optional*): an argument specifying if we're in training mode + or inference mode, `True` is passed in training mode. Can be omitted if + the model behaves the same in training mode and inference mode. + + The `inputs` argument is mandatory. Inputs to the model must be provided via + a single argument. If the JAX model takes multiple inputs as separate + arguments, they must be combined into a single structure, for instance in a + `tuple` or a `dict`. + + ## Model weights initialization + + The initialization of the `params` and `state` of the model can be handled + by this layer, in which case the `init_fn` argument must be provided. This + allows the model to be initialized dynamically with the right shape. + Alternatively, and if the shape is known, the `params` argument and + optionally the `state` argument can be used to create an already initialized + model. + + The `init_fn` function, if provided, must take the following arguments with + these exact names: + + - `rng`: a `jax.random.PRNGKey` instance. + - `inputs`: a JAX array or a `PyTree` of arrays with placeholder values to + provide the shape of the inputs. + - `training` (*optional*): an argument specifying if we're in training mode + or inference mode. `True` is always passed to `init_fn`. Can be omitted + regardless of whether `call_fn` has a `training` argument. + + ## Models with non-trainable state + + For JAX models that have non-trainable state: + + - `call_fn` must have a `state` argument + - `call_fn` must return a `tuple` containing the outputs of the model and + the new non-trainable state of the model + - `init_fn` must return a `tuple` containing the initial trainable params of + the model and the initial non-trainable state of the model. + + This code shows a possible combination of `call_fn` and `init_fn` signatures + for a model with non-trainable state. In this example, the model has a + `training` argument and an `rng` argument in `call_fn`. 
+ + ```python + def stateful_call(params, state, rng, inputs, training): + outputs = ... + new_state = ... + return outputs, new_state + + def stateful_init(rng, inputs): + initial_params = ... + initial_state = ... + return initial_params, initial_state + ``` + + ## Models without non-trainable state + + For JAX models with no non-trainable state: + + - `call_fn` must not have a `state` argument + - `call_fn` must return only the outputs of the model + - `init_fn` must return only the initial trainable params of the model. + + This code shows a possible combination of `call_fn` and `init_fn` signatures + for a model without non-trainable state. In this example, the model does not + have a `training` argument and does not have an `rng` argument in `call_fn`. + + ```python + def stateless_call(params, inputs): + outputs = ... + return outputs + + def stateless_init(rng, inputs): + initial_params = ... + return initial_params + ``` + + ## Conforming to the required signature + + If a model has a different signature than the one required by `JaxLayer`, + one can easily write a wrapper method to adapt the arguments. This example + shows a model that has multiple inputs as separate arguments, expects + multiple RNGs in a `dict`, and has a `deterministic` argument with the + opposite meaning of `training`. To conform, the inputs are combined in a + single structure using a `tuple`, the RNG is split and used to populate the + expected `dict`, and the Boolean flag is negated: + + ```python + def my_model_fn(params, rngs, input1, input2, deterministic): + ... + if not deterministic: + dropout_rng = rngs["dropout"] + keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape) + x = jax.numpy.where(keep, x / dropout_rate, 0) + ... + ... 
+ return outputs + + def my_model_wrapper_fn(params, rng, inputs, training): + input1, input2 = inputs + rng1, rng2 = jax.random.split(rng) + rngs = {"dropout": rng1, "preprocessing": rng2} + deterministic = not training + return my_model_fn(params, rngs, input1, input2, deterministic) + + keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params) + ``` + + ## Usage with Haiku modules + + `JaxLayer` enables the use of [Haiku](https://dm-haiku.readthedocs.io) + components in the form of + [`haiku.Module`](https://dm-haiku.readthedocs.io/en/latest/api.html#module). + This is achieved by transforming the module per the Haiku pattern and then + passing `module.apply` in the `call_fn` parameter and `module.init` in the + `init_fn` parameter if needed. + + If the model has non-trainable state, it should be transformed with + [`haiku.transform_with_state`]( + https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform_with_state). + If the model has no non-trainable state, it should be transformed with + [`haiku.transform`]( + https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform). + Additionally, and optionally, if the module does not use RNGs in "apply", it + can be transformed with + [`haiku.without_apply_rng`]( + https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng). 
+ + The following example shows how to create a `JaxLayer` from a Haiku module + that uses random number generators via `hk.next_rng_key()` and takes a + training positional argument: + + ```python + class MyHaikuModule(hk.Module): + def __call__(self, x, training): + x = hk.Conv2D(32, (3, 3))(x) + x = jax.nn.relu(x) + x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x) + x = hk.Flatten()(x) + x = hk.Linear(200)(x) + if training: + x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x) + x = jax.nn.relu(x) + x = hk.Linear(10)(x) + x = jax.nn.softmax(x) + return x + + def my_haiku_module_fn(inputs, training): + module = MyHaikuModule() + return module(inputs, training) + + transformed_module = hk.transform(my_haiku_module_fn) + + keras_layer = JaxLayer( + call_fn=transformed_module.apply, + init_fn=transformed_module.init, + ) + ``` + + Args: + call_fn: The function to call the model. See description above for the + list of arguments it takes and the outputs it returns. + init_fn: The function to call to initialize the model. See description + above for the list of arguments it takes and the outputs it returns. + If `None`, then `params` and/or `state` must be provided. + params: A `PyTree` containing all the model trainable parameters. This + allows passing trained parameters or controlling the initialization. + If both `params` and `state` are `None`, `init_fn` is called at + build time to initialize the trainable parameters of the model. + state: A `PyTree` containing all the model non-trainable state. This + allows passing learned state or controlling the initialization. If + both `params` and `state` are `None`, and `call_fn` takes a `state` + argument, then `init_fn` is called at build time to initialize the + non-trainable state of the model. + seed: Seed for random number generator. Optional. 
+ """ + + def __init__( + self, + call_fn, + init_fn=None, + params=None, + state=None, + seed=None, + **kwargs, + ): + if backend.backend() != "jax": + raise ValueError( + "JaxLayer is only supported with the JAX backend. Current " + f"backend: {backend.backend()}" + ) + + if init_fn is None and params is None and state is None: + raise ValueError( + "`init_fn`, `params` and `state` cannot all be `None`." + ) + + super().__init__(**kwargs) + self.call_fn = call_fn + self.init_fn = init_fn + self.tracked_params = self._create_variables(params, trainable=True) + self.tracked_state = self._create_variables(state, trainable=False) + self.seed_generator = backend.random.SeedGenerator(seed) + + self.call_fn_arguments = self._validate_signature( + call_fn, + "call_fn", + {"params", "state", "rng", "inputs", "training"}, + {"inputs"}, + ) + self.has_state = "state" in self.call_fn_arguments + + if init_fn: + self.init_fn_arguments = self._validate_signature( + init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"} + ) + + def _validate_signature(self, fn, fn_name, allowed, required): + fn_parameters = inspect.signature(fn).parameters + for parameter_name in required: + if parameter_name not in fn_parameters: + raise ValueError( + f"Missing required argument in `{fn_name}`: " + f"`{parameter_name}`" + ) + + parameter_names = [] + for parameter in fn_parameters.values(): + if parameter.name not in allowed: + raise ValueError( + f"Unsupported argument in `{fn_name}`: `{parameter.name}`, " + f"supported arguments are `{'`, `'.join(allowed)}`" + ) + parameter_names.append(parameter.name) + + return parameter_names + + @tracking.no_automatic_dependency_tracking + def _create_variables(self, values, trainable): + """Create a structure of variables from a structure of JAX arrays. + + `values` is traversed via JAX's `tree_map`. When a leaf is a JAX array + or a tensor-like object, a corresponding variable is created with it as + the initial value. 
The resulting structure of variables is assigned to + `self.params` or `self.state` depending on `trainable`. Then, a + flattened version of the variables is returned for tracking. + `self.params` or `self.state` are intentionally not tracked because + structures like `TrackedList` interfere with `jax.tree_utils`. + Note that leaf objects that are not JAX arrays and not tensor-like are + left intact as they are assumed to be configuration used by the model. + + Args: + values: the structure of values to traverse. + trainable: whether to create trainable variables. + + Returns: + flat list of variables initialized with `values` for tracking. + """ + + def create_variable(value): + if backend.is_tensor(value) or isinstance(value, np.ndarray): + variable = self.add_weight( + value.shape, initializer="zeros", trainable=trainable + ) + variable.assign(value) + return variable + elif isinstance(value, (np.generic, int, float)): + variable = self.add_weight( + (), initializer="zeros", trainable=trainable + ) + variable.assign(value) + return variable + else: + return value + + # Use JAX's tree_map as it understands registered classes. + variables = jax.tree_util.tree_map(create_variable, values) + + if trainable: + self.params = variables + else: + self.state = variables + + return jax.tree_util.tree_flatten(variables) + + def _get_init_rng(self): + """ + Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`. + + By default, this returns a single `PRNGKey` retrieved by calling + `self.seed_generator.next()`. Override this to return a different + structure. + + Returns: + a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as + the `rng` argument of `init_fn`. + """ + return self.seed_generator.next() + + def _get_call_rng(self, training): + """ + Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`. 
+ + By default, this returns a single `PRNGKey` retrieved by calling + `self.seed_generator.next()` when `training` is `True`, and `None` when + `training` is `False`. Override this to return a different structure or + to pass RNGs in inference mode too. + + Returns: + a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as + the `rng` argument of `call_fn`. + """ + if training: + return self.seed_generator.next() + else: + return None + + def build(self, input_shape): + if self.params is not None or self.state is not None: + return + + # Initialize `params` and `state` if needed by calling `init_fn`. + def create_input(shape): + shape = [d if d is not None else 1 for d in shape] + return jax.numpy.ones(shape) + + init_inputs = shape_utils.map_shape_structure(create_input, input_shape) + init_args = [] + for argument_name in self.init_fn_arguments: + if argument_name == "rng": + init_args.append(self._get_init_rng()) + elif argument_name == "inputs": + init_args.append(init_inputs) + elif argument_name == "training": + init_args.append(True) + + init_result = self.init_fn(*init_args) + if self.has_state: + init_params, init_state = init_result + else: + init_params, init_state = init_result, None + + self.tracked_params = self._create_variables( + init_params, trainable=True + ) + self.tracked_state = self._create_variables(init_state, trainable=False) + + def call(self, inputs, training=False): + def unwrap_variable(variable): + return None if variable is None else variable.value + + call_args = [] + for argument_name in self.call_fn_arguments: + if argument_name == "params": + call_args.append( + jax.tree_util.tree_map(unwrap_variable, self.params) + ) + elif argument_name == "state": + call_args.append( + jax.tree_util.tree_map(unwrap_variable, self.state) + ) + elif argument_name == "rng": + call_args.append(self._get_call_rng(training)) + elif argument_name == "inputs": + call_args.append(inputs) + elif argument_name == "training": + 
call_args.append(training) + + def assign_state_to_variable(value, variable): + # This exists only to make debugging this error case easier. + if not hasattr(variable, "assign"): + raise ValueError( + "Structure mismatch: the structure of the state returned " + "by `call` does not match the structure of the state at " + "initialization time." + ) + variable.assign(value) + + if self.has_state: + predictions, new_state = self.call_fn(*call_args) + jax.tree_util.tree_map( + assign_state_to_variable, new_state, self.state + ) + return predictions + else: + return self.call_fn(*call_args) + + def get_config(self): + config = { + "call_fn": serialization_lib.serialize_keras_object(self.call_fn), + "init_fn": serialization_lib.serialize_keras_object(self.init_fn), + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + call_fn = serialization_lib.deserialize_keras_object(config["call_fn"]) + init_fn = serialization_lib.deserialize_keras_object(config["init_fn"]) + config["call_fn"] = call_fn + config["init_fn"] = init_fn + return super().from_config(config) + + +@keras_export("keras.layers.FlaxLayer") +class FlaxLayer(JaxLayer): + """Keras Layer that wraps a [Flax](https://flax.readthedocs.io) module. + + This layer enables the use of Flax components in the form of + [`flax.linen.Module`]( + https://flax.readthedocs.io/en/latest/flax.linen.html#module) + instances within Keras when using JAX as the backend for Keras. + + The module method to use for the forward pass can be specified via the + `method` argument and is `__call__` by default. This method must take the + following arguments with these exact names: + + - `self` if the method is bound to the module, which is the case for the + default of `__call__`, and `module` otherwise to pass the module. + - `inputs`: the inputs to the model, a JAX array or a `PyTree` of arrays. 
+ - `training` *(optional)*: an argument specifying if we're in training mode + or inference mode, `True` is passed in training mode. + + `FlaxLayer` handles the non-trainable state of your model and required RNGs + automatically. Note that the `mutable` parameter of + [`flax.linen.Module.apply()`]( + https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply) + is set to `DenyList(["params"])`, therefore making the assumption that all + the variables outside of the "params" collection are non-trainable weights. + + This example shows how to create a `FlaxLayer` from a Flax `Module` with + the default `__call__` method and no training argument: + + ```python + class MyFlaxModule(flax.linen.Module): + @flax.linen.compact + def __call__(self, inputs): + x = inputs + x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = flax.linen.Dense(features=200)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=10)(x) + x = flax.linen.softmax(x) + return x + + flax_module = MyFlaxModule() + keras_layer = FlaxLayer(flax_module) + ``` + + This example shows how to wrap the module method to conform to the required + signature. This allows having multiple input arguments and a training + argument that has a different name and values. This additionally shows how + to use a function that is not bound to the module. + + ```python + class MyFlaxModule(flax.linen.Module): + @flax.linen.compact + def forward(self, input1, input2, deterministic): + ... + return outputs + + def my_flax_module_wrapper(module, inputs, training): + input1, input2 = inputs + return module.forward(input1, input2, not training) + + flax_module = MyFlaxModule() + keras_layer = FlaxLayer( + module=flax_module, + method=my_flax_module_wrapper, + ) + ``` + + Args: + module: An instance of `flax.linen.Module` or subclass. 
+ method: The method to call the model. This is generally a method in the + `Module`. If not provided, the `__call__` method is used. `method` + can also be a function not defined in the `Module`, in which case it + must take the `Module` as the first argument. It is used for both + `Module.init` and `Module.apply`. Details are documented in the + `method` argument of [`flax.linen.Module.apply()`]( + https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply). + variables: A `dict` containing all the variables of the module in the + same format as what is returned by [`flax.linen.Module.init()`]( + https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init). + It should contain a "params" key and, if applicable, other keys for + collections of variables for non-trainable state. This allows + passing trained parameters and learned non-trainable state or + controlling the initialization. If `None` is passed, the module's + `init` function is called at build time to initialize the variables + of the model. + """ + + def __init__( + self, + module, + method=None, + variables=None, + **kwargs, + ): + # Late import to only require Flax when this is used. + from flax.core import scope as flax_scope + + if backend.backend() != "jax": + raise ValueError( + "FlaxLayer is only supported with the JAX backend. 
Current " + f"backend: {backend.backend()}" + ) + + self.module = module + self.method = method + + apply_mutable = flax_scope.DenyList(["params"]) + + def apply_with_training(params, state, rng, inputs, training): + return self.module.apply( + self._params_and_state_to_variables(params, state), + inputs, + rngs=rng, + method=self.method, + mutable=apply_mutable, + training=training, + ) + + def apply_without_training(params, state, rng, inputs): + return self.module.apply( + self._params_and_state_to_variables(params, state), + inputs, + rngs=rng, + method=self.method, + mutable=apply_mutable, + ) + + def init_with_training(rng, inputs, training): + return self._variables_to_params_and_state( + self.module.init( + rng, + inputs, + method=self.method, + training=training, + ) + ) + + def init_without_training(rng, inputs): + return self._variables_to_params_and_state( + self.module.init( + rng, + inputs, + method=self.method, + ) + ) + + if ( + "training" + in inspect.signature(method or module.__call__).parameters + ): + call_fn, init_fn = apply_with_training, init_with_training + else: + call_fn, init_fn = apply_without_training, init_without_training + + params, state = self._variables_to_params_and_state(variables) + + super().__init__( + call_fn=call_fn, + init_fn=init_fn, + params=params, + state=state, + **kwargs, + ) + + def _params_and_state_to_variables(self, params, state): + if params: + if state: + return {**params, **state} + else: + return params + elif state: + return state + return {} + + def _variables_to_params_and_state(self, variables): + # neither params nor state + if variables is None: + return None, None + # state only + if "params" not in variables: + return {}, variables + # params only + if len(variables) == 1: + return variables, {} + # both, we need to split + params = {"params": variables["params"]} + state = {k: v for k, v in variables.items() if k != "params"} + return params, state + + def _get_init_rng(self): + return { + 
"params": self.seed_generator.next(), + "dropout": self.seed_generator.next(), + } + + def _get_call_rng(self, training): + if training: + return {"dropout": self.seed_generator.next()} + else: + return {} + + def get_config(self): + config_method = self.method + if ( + hasattr(self.method, "__self__") + and self.method.__self__ == self.module + ): + # A method bound to the module is serialized by name. + config_method = self.method.__name__ + config = { + "module": serialization_lib.serialize_keras_object(self.module), + "method": serialization_lib.serialize_keras_object(config_method), + } + base_config = super().get_config() + # call_fn and init_fn come from module, do not save them. + base_config.pop("call_fn") + base_config.pop("init_fn") + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + module = serialization_lib.deserialize_keras_object(config["module"]) + method = serialization_lib.deserialize_keras_object(config["method"]) + if isinstance(config["method"], str): + # Deserialize bound method from the module. 
+ method = getattr(module, method) + config["module"] = module + config["method"] = method + return cls(**config) diff --git a/keras/utils/jax_layer_test.py b/keras/utils/jax_layer_test.py new file mode 100644 index 00000000000..cf89c0658f5 --- /dev/null +++ b/keras/utils/jax_layer_test.py @@ -0,0 +1,657 @@ +import os + +import flax +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras import backend +from keras import layers +from keras import metrics +from keras import models +from keras import saving +from keras import testing +from keras import utils +from keras.export import export_lib +from keras.saving import object_registration +from keras.utils import tree +from keras.utils.jax_layer import FlaxLayer +from keras.utils.jax_layer import JaxLayer + +num_classes = 10 +input_shape = (28, 28, 1) # Excluding batch_size + + +@object_registration.register_keras_serializable() +def jax_stateless_init(rng, inputs): + layer_sizes = [784, 300, 100, 10] + params = [] + w_init = jax.nn.initializers.glorot_normal() + b_init = jax.nn.initializers.normal(0.1) + for m, n in zip(layer_sizes[:-1], layer_sizes[1:]): + rng, w_rng = jax.random.split(rng) + rng, b_rng = jax.random.split(rng) + params.append([w_init(w_rng, (m, n)), b_init(b_rng, (n,))]) + return params + + +@object_registration.register_keras_serializable() +def jax_stateless_apply(params, inputs): + activations = inputs.reshape((inputs.shape[0], -1)) # flatten + for w, b in params[:-1]: + outputs = jnp.dot(activations, w) + b + activations = jnp.tanh(outputs) + + final_w, final_b = params[-1] + logits = jnp.dot(activations, final_w) + final_b + return jax.nn.softmax(logits, axis=-1) + + +@object_registration.register_keras_serializable() +def jax_stateful_init(rng, inputs, training): + params = jax_stateless_init(rng, inputs) + state = jnp.zeros([], jnp.int32) + return params, state + + 
@object_registration.register_keras_serializable()
def jax_stateful_apply(params, state, inputs, training):
    """Run `jax_stateless_apply` and thread a counter through as state.

    Args:
        params: parameters forwarded unchanged to `jax_stateless_apply`.
        state: current state value; one is added to it in training mode.
        inputs: inputs forwarded unchanged to `jax_stateless_apply`.
        training: when truthy, the returned state is `state + 1`; otherwise
            `state` is returned as received.

    Returns:
        A `(outputs, state)` tuple.
    """
    predictions = jax_stateless_apply(params, inputs)
    # Build a new value rather than updating in place so the caller's
    # `state` object is never mutated.
    updated_state = state + 1 if training else state
    return predictions, updated_state


@object_registration.register_keras_serializable()
class FlaxTrainingIndependentModel(flax.linen.Module):
    """Flax module whose forward method, `forward`, has no `training` flag."""

    @flax.linen.compact
    def forward(self, inputs):
        """Apply conv/pool x2, flatten, two dense layers, then softmax."""
        nn = flax.linen
        pooling = {"window_shape": (2, 2), "strides": (2, 2)}

        # Submodules are instantiated in the same order as before; under
        # `compact`, that order determines the auto-generated parameter names.
        hidden = nn.Conv(features=32, kernel_size=(3, 3))(inputs)
        hidden = nn.avg_pool(nn.relu(hidden), **pooling)
        hidden = nn.Conv(features=64, kernel_size=(3, 3))(hidden)
        hidden = nn.avg_pool(nn.relu(hidden), **pooling)
        hidden = hidden.reshape((hidden.shape[0], -1))  # flatten
        hidden = nn.relu(nn.Dense(features=200)(hidden))
        logits = nn.Dense(features=10)(hidden)
        return nn.softmax(logits)

    def get_config(self):
        """Return the serialization config; this module has no settings."""
        return dict()

    @classmethod
    def from_config(cls, config):
        """Rebuild the module from the mapping produced by `get_config`."""
        return cls(**config)


@object_registration.register_keras_serializable()
class FlaxDropoutModel(flax.linen.Module):
    """Flax module with dropout, exposed through the `my_apply` method."""

    @flax.linen.compact
    def my_apply(self, inputs, training):
        """Same stack as above, with dropout after the first dense layer.

        Dropout is deterministic (disabled) whenever `training` is falsy.
        """
        nn = flax.linen
        pooling = {"window_shape": (2, 2), "strides": (2, 2)}

        hidden = nn.Conv(features=32, kernel_size=(3, 3))(inputs)
        hidden = nn.avg_pool(nn.relu(hidden), **pooling)
        hidden = nn.Conv(features=64, kernel_size=(3, 3))(hidden)
        hidden = nn.avg_pool(nn.relu(hidden), **pooling)
        hidden = hidden.reshape((hidden.shape[0], -1))  # flatten
        hidden = nn.Dense(features=200)(hidden)
        hidden = nn.Dropout(rate=0.3, deterministic=not training)(hidden)
        hidden = nn.relu(hidden)
        logits = nn.Dense(features=10)(hidden)
        return nn.softmax(logits)

    def get_config(self):
        """Return the serialization config; this module has no settings."""
        return dict()

    @classmethod
    def from_config(cls, config):
        """Rebuild the module from the mapping produced by `get_config`."""
        return cls(**config)


@object_registration.register_keras_serializable()
def flax_dropout_wrapper(module, x, training):
    """Unbound adapter that forwards `x` and `training` to `module.my_apply`."""
    outputs = module.my_apply(x, training)
    return outputs
+@object_registration.register_keras_serializable() +class FlaxBatchNormModel(flax.linen.Module): + @flax.linen.compact + def __call__(self, inputs, training=False): + ura = not training + x = inputs + x = flax.linen.Conv(features=12, kernel_size=(3, 3), use_bias=False)(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(x) + x = flax.linen.relu(x) + x = flax.linen.Conv(features=24, kernel_size=(6, 6), strides=(2, 2))(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(x) + x = flax.linen.relu(x) + x = flax.linen.Conv(features=32, kernel_size=(6, 6), strides=(2, 2))(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(x) + x = x.reshape((x.shape[0], -1)) # flatten + x = flax.linen.Dense(features=200, use_bias=True)(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(x) + x = flax.linen.Dropout(rate=0.3, deterministic=not training)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=10)(x) + x = flax.linen.softmax(x) + return x + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="JaxLayer and FlaxLayer are only supported with JAX backend", +) +class TestJaxLayer(testing.TestCase, parameterized.TestCase): + def _test_layer( + self, + model_name, + layer_class, + layer_init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ): + # Fake MNIST data + x_train = np.random.uniform(size=(320, 28, 28, 1)) + y_train = np.eye(num_classes, dtype="int32")[ + (np.random.uniform(size=(320,)) * num_classes).astype("int32") + ] + x_test = np.random.uniform(size=(32, 28, 28, 1)) + + def _count_params(weights): + count = 0 + for weight in weights: + count = count + np.prod(weight.shape) + return count + + def verify_weights_and_params(layer): + + self.assertEqual(trainable_weights, len(layer.trainable_weights)) + self.assertEqual( 
+ trainable_params, + _count_params(layer.trainable_weights), + ) + self.assertEqual( + non_trainable_weights, len(layer.non_trainable_weights) + ) + self.assertEqual( + non_trainable_params, + _count_params(layer.non_trainable_weights), + ) + + # functional model + layer1 = layer_class(**layer_init_kwargs) + inputs1 = layers.Input(shape=input_shape) + outputs1 = layer1(inputs1) + model1 = models.Model( + inputs=inputs1, outputs=outputs1, name=model_name + "1" + ) + model1.summary() + + verify_weights_and_params(layer1) + + model1.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=[metrics.CategoricalAccuracy()], + ) + + tw1_before_fit = tree.map_structure( + backend.convert_to_numpy, layer1.trainable_weights + ) + ntw1_before_fit = tree.map_structure( + backend.convert_to_numpy, layer1.non_trainable_weights + ) + model1.fit(x_train, y_train, epochs=1, steps_per_epoch=10) + tw1_after_fit = tree.map_structure( + backend.convert_to_numpy, layer1.trainable_weights + ) + ntw1_after_fit = tree.map_structure( + backend.convert_to_numpy, layer1.non_trainable_weights + ) + + # verify both trainable and non-trainable weights did change after fit + for before, after in zip(tw1_before_fit, tw1_after_fit): + self.assertNotAllClose(before, after) + for before, after in zip(ntw1_before_fit, ntw1_after_fit): + self.assertNotAllClose(before, after) + + expected_ouput_shape = (x_test.shape[0], num_classes) + output1 = model1(x_test) + self.assertEqual(output1.shape, expected_ouput_shape) + predict1 = model1.predict(x_test, steps=1) + self.assertEqual(predict1.shape, expected_ouput_shape) + + # verify both trainable and non-trainable weights did not change + tw1_after_call = tree.map_structure( + backend.convert_to_numpy, layer1.trainable_weights + ) + ntw1_after_call = tree.map_structure( + backend.convert_to_numpy, layer1.non_trainable_weights + ) + for after_fit, after_call in zip(tw1_after_fit, tw1_after_call): + self.assertAllClose(after_fit, after_call) 
@parameterized.named_parameters(
    {
        "testcase_name": "training_independent",
        "init_kwargs": {
            "call_fn": jax_stateless_apply,
            "init_fn": jax_stateless_init,
        },
        "trainable_weights": 6,
        "trainable_params": 266610,
        "non_trainable_weights": 0,
        "non_trainable_params": 0,
    },
    {
        "testcase_name": "training_state",
        "init_kwargs": {
            "call_fn": jax_stateful_apply,
            "init_fn": jax_stateful_init,
        },
        "trainable_weights": 6,
        "trainable_params": 266610,
        "non_trainable_weights": 1,
        "non_trainable_params": 1,
    },
)
def test_jax_layer(
    self,
    init_kwargs,
    trainable_weights,
    trainable_params,
    non_trainable_weights,
    non_trainable_params,
):
    """Run the shared `_test_layer` scenario against a plain `JaxLayer`.

    Args:
        init_kwargs: keyword arguments for `JaxLayer`; must contain
            `call_fn`, whose `__name__` is used as the model name.
        trainable_weights: expected count forwarded to `_test_layer`.
        trainable_params: expected count forwarded to `_test_layer`.
        non_trainable_weights: expected count forwarded to `_test_layer`.
        non_trainable_params: expected count forwarded to `_test_layer`.
    """
    self._test_layer(
        init_kwargs["call_fn"].__name__,
        JaxLayer,
        init_kwargs,
        trainable_weights,
        trainable_params,
        non_trainable_weights,
        non_trainable_params,
    )


@parameterized.named_parameters(
    {
        "testcase_name": "training_independent_bound_method",
        "flax_model_class": FlaxTrainingIndependentModel,
        "flax_model_method": "forward",
        "init_kwargs": {},
        "trainable_weights": 8,
        "trainable_params": 648226,
        "non_trainable_weights": 0,
        "non_trainable_params": 0,
    },
    {
        "testcase_name": "training_rng_unbound_method",
        "flax_model_class": FlaxDropoutModel,
        "flax_model_method": None,
        "init_kwargs": {
            "method": flax_dropout_wrapper,
        },
        "trainable_weights": 8,
        "trainable_params": 648226,
        "non_trainable_weights": 0,
        "non_trainable_params": 0,
    },
    {
        "testcase_name": "training_rng_state_no_method",
        "flax_model_class": FlaxBatchNormModel,
        "flax_model_method": None,
        "init_kwargs": {},
        "trainable_weights": 13,
        "trainable_params": 354258,
        "non_trainable_weights": 8,
        "non_trainable_params": 536,
    },
)
def test_flax_layer(
    self,
    flax_model_class,
    flax_model_method,
    init_kwargs,
    trainable_weights,
    trainable_params,
    non_trainable_weights,
    non_trainable_params,
):
    """Run the shared `_test_layer` scenario against a `FlaxLayer`.

    Args:
        flax_model_class: Flax module class to instantiate and wrap.
        flax_model_method: name of a method of the module to pass to
            `FlaxLayer` as a bound method, or `None` to pass no bound
            method.
        init_kwargs: extra keyword arguments for `FlaxLayer`.
        trainable_weights: expected count forwarded to `_test_layer`.
        trainable_params: expected count forwarded to `_test_layer`.
        non_trainable_weights: expected count forwarded to `_test_layer`.
        non_trainable_params: expected count forwarded to `_test_layer`.
    """

    def create_wrapper(**kwargs):
        """Build a `FlaxLayer` from `JaxLayer`-style keyword arguments.

        `params` and `state`, when passed, are merged into the single
        `variables` argument that `FlaxLayer` takes.
        """
        params = kwargs.pop("params", None)
        state = kwargs.pop("state", None)
        if params and state:
            variables = {**params, **state}
        elif params:
            variables = params
        elif state:
            variables = state
        else:
            variables = None
        kwargs["variables"] = variables
        flax_model = flax_model_class()
        if flax_model_method:
            kwargs["method"] = getattr(flax_model, flax_model_method)
        # Wrap the very instance the method was bound to, rather than a
        # second, separately constructed module.
        return FlaxLayer(flax_model, **kwargs)

    self._test_layer(
        flax_model_class.__name__,
        create_wrapper,
        init_kwargs,
        trainable_weights,
        trainable_params,
        non_trainable_weights,
        non_trainable_params,
    )


def test_with_no_init_fn_and_no_params(self):
    """`JaxLayer` given only a `call_fn` raises `ValueError`."""

    def jax_fn(params, inputs):
        return inputs

    with self.assertRaises(ValueError):
        JaxLayer(jax_fn)


def test_with_training_in_call_fn_but_not_init_fn(self):
    """`init_fn` may omit `training` even when `call_fn` takes it."""

    def jax_call_fn(params, state, rng, inputs, training):
        return inputs, {}

    def jax_init_fn(rng, inputs):
        return {}, {}

    layer = JaxLayer(jax_call_fn, jax_init_fn)
    layer(np.ones((1,)))


def test_with_different_argument_order(self):
    """The layer builds and runs with arguments declared in reverse order."""

    def jax_call_fn(training, inputs, rng, state, params):
        return inputs, {}

    def jax_init_fn(training, inputs, rng):
        return {}, {}

    layer = JaxLayer(jax_call_fn, jax_init_fn)
    layer(np.ones((1,)))


def test_with_minimal_arguments(self):
    """`call_fn` and `init_fn` taking only `inputs` are accepted."""

    def jax_call_fn(inputs):
        return inputs

    def jax_init_fn(inputs):
        return {}

    layer = JaxLayer(jax_call_fn, jax_init_fn)
    layer(np.ones((1,)))


def test_with_missing_inputs_in_call_fn(self):
    """A `call_fn` without `inputs` is rejected at construction time."""

    def jax_call_fn(params, rng, training):
        return jnp.ones((1,))

    def jax_init_fn(rng, inputs):
        return {}

    with self.assertRaisesRegex(ValueError, "`call_fn`.*`inputs`"):
        JaxLayer(jax_call_fn, jax_init_fn)


def test_with_missing_inputs_in_init_fn(self):
    """An `init_fn` without `inputs` is rejected at construction time."""

    def jax_call_fn(params, rng, inputs, training):
        return jnp.ones((1,))

    def jax_init_fn(rng, training):
        return {}

    with self.assertRaisesRegex(ValueError, "`init_fn`.*`inputs`"):
        JaxLayer(jax_call_fn, jax_init_fn)


def test_with_unsupported_argument_in_call_fn(self):
    """An unknown `call_fn` argument (`mode`) is named in the error."""

    def jax_call_fn(params, rng, inputs, mode):
        return jnp.ones((1,))

    def jax_init_fn(rng, inputs):
        return {}

    with self.assertRaisesRegex(ValueError, "`call_fn`.*`mode`"):
        JaxLayer(jax_call_fn, jax_init_fn)


def test_with_unsupported_argument_in_init_fn(self):
    """An unknown `init_fn` argument (`mode`) is named in the error."""

    def jax_call_fn(params, rng, inputs, training):
        return inputs

    def jax_init_fn(rng, inputs, mode):
        return {}

    with self.assertRaisesRegex(ValueError, "`init_fn`.*`mode`"):
        JaxLayer(jax_call_fn, jax_init_fn)


def test_with_structures_as_inputs_and_outputs(self):
    """A dict of inputs and a tuple of outputs pass through a functional model.

    The symbolic inputs leave their first non-batch dimension as `None`; the
    concrete inputs use lengths 6 and 7, so both concatenations have
    length 13 on axis 1.
    """

    def jax_fn(params, inputs):
        a = inputs["a"]
        b = inputs["b"]
        output1 = jnp.concatenate([a, b], axis=1)
        output2 = jnp.concatenate([b, a], axis=1)
        return output1, output2

    layer = JaxLayer(jax_fn, params={})
    inputs = {
        "a": layers.Input((None, 3)),
        "b": layers.Input((None, 3)),
    }
    outputs = layer(inputs)
    model = models.Model(inputs, outputs)

    test_inputs = {
        "a": np.ones((2, 6, 3)),
        "b": np.ones((2, 7, 3)),
    }
    test_outputs = model(test_inputs)
    self.assertAllClose(test_outputs[0], np.ones((2, 13, 3)))
    self.assertAllClose(test_outputs[1], np.ones((2, 13, 3)))


def test_with_polymorphic_shape_more_than_26_dimension_names(self):
    """Sixty inputs with an unknown dimension are concatenated on axis 1.

    NOTE(review): judging by the test name, this guards against running out
    of single-letter names for polymorphic dimensions; the naming scheme
    itself is not visible from this test.
    """

    def jax_fn(params, inputs):
        return jnp.concatenate(inputs, axis=1)

    layer = JaxLayer(jax_fn, params=())
    inputs = [layers.Input((None, 3)) for _ in range(60)]
    output = layer(inputs)
    model = models.Model(inputs, output)

    test_inputs = [np.ones((2, 1, 3))] * 60
    test_output = model(test_inputs)
    self.assertAllClose(test_output, np.ones((2, 60, 3)))


def test_with_flax_state_no_params(self):
    """A Flax module whose only variable is a counter has state but no params.

    The module increments variable `b` of collection `a` on every call; after
    one call the layer exposes zero params and a state value of 1.
    """

    class MyFlaxLayer(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, x):
            def zeros_init(shape):
                return jnp.zeros(shape, jnp.int32)

            count = self.variable("a", "b", zeros_init, [])
            count.value = count.value + 1
            return x

    layer = FlaxLayer(MyFlaxLayer(), variables={"a": {"b": 0}})
    layer(np.ones((1,)))
    self.assertLen(layer.params, 0)
    self.assertEqual(layer.state["a"]["b"].value, 1)


def test_with_state_none_leaves(self):
    """A `None` leaf in the initial state is kept and the layer is callable."""

    def jax_fn(params, state, inputs):
        return inputs, state

    layer = JaxLayer(jax_fn, state={"foo": None})
    self.assertIsNone(layer.state["foo"])
    layer(np.ones((1,)))


def test_with_state_non_tensor_leaves(self):
    """A string leaf in the initial state is stored unchanged."""

    def jax_fn(params, state, inputs):
        return inputs, state

    layer = JaxLayer(jax_fn, state={"foo": "bar"})
    self.assertEqual(layer.state["foo"], "bar")
    # layer cannot be invoked as jax2tf will fail on strings


def test_with_state_jax_registered_node_class(self):
    """State holding a custom registered pytree node class can be called."""

    @jax.tree_util.register_pytree_node_class
    class NamedPoint:
        def __init__(self, x, y, name):
            self.x = x
            self.y = y
            self.name = name

        def tree_flatten(self):
            # Children are the coordinates; the name is static aux data.
            return ((self.x, self.y), self.name)

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            return cls(*children, aux_data)

    def jax_fn(params, state, inputs):
        return inputs, state

    layer = JaxLayer(jax_fn, state=[NamedPoint(1.0, 2.0, "foo")])
    layer(np.ones((1,)))


@parameterized.named_parameters(
    {
        "testcase_name": "sequence_instead_of_mapping",
        "init_state": [0.0],
        "error_regex": "Expected dict, got ",
    },
    {
        "testcase_name": "mapping_instead_of_sequence",
        "init_state": {"state": {"foo": 0.0}},
        "error_regex": "Expected list, got ",
    },
    {
        "testcase_name": "sequence_instead_of_variable",
        "init_state": {"state": [[0.0]]},
        "error_regex": "Structure mismatch",
    },
    {
        "testcase_name": "no_initial_state",
        "init_state": None,
        "error_regex": "Expected dict, got None",
    },
    {
        "testcase_name": "missing_dict_key",
        "init_state": {"state": {}},
        "error_regex": "Expected list, got ",
    },
    {
        "testcase_name": "missing_variable_in_list",
        "init_state": {"state": {"foo": [2.0]}},
        "error_regex": "Expected list, got ",
    },
)
def test_state_mismatch_during_update(self, init_state, error_regex):
    """Calling the layer fails when the new state's structure differs.

    `jax_fn` always returns `{"state": [<scalar>]}` as the new state; each
    case starts from a differently structured `init_state`.

    Args:
        init_state: initial `state` passed to `JaxLayer`.
        error_regex: pattern the raised `ValueError` message must match.
    """

    def jax_fn(params, state, inputs):
        return inputs, {"state": [jnp.ones([])]}

    layer = JaxLayer(jax_fn, params={}, state=init_state)
    with self.assertRaisesRegex(ValueError, error_regex):
        layer(np.ones((1,)))


def test_rng_seeding(self):
    """Two layers built after the same global seed get identical params."""

    def jax_init(rng, inputs):
        return [jax.nn.initializers.normal(1.0)(rng, inputs.shape)]

    def jax_apply(params, inputs):
        return jnp.dot(inputs, params[0])

    shape = (2, 2)

    utils.set_random_seed(0)
    layer1 = JaxLayer(jax_apply, jax_init)
    layer1.build(shape)
    utils.set_random_seed(0)
    layer2 = JaxLayer(jax_apply, jax_init)
    layer2.build(shape)
    self.assertAllClose(layer1.params[0], layer2.params[0])