diff --git a/equinox/nn/composed.py b/equinox/nn/composed.py
index 192117c3..2c4ad497 100644
--- a/equinox/nn/composed.py
+++ b/equinox/nn/composed.py
@@ -13,7 +13,10 @@
 import jax.nn as jnn
 import jax.random as jrandom
+import jax.numpy as jnp
 from jaxtyping import Array
+from jax import lax
+from jax.tree_util import tree_flatten
 
 from ..custom_types import PRNGKey
 from ..module import Module, static_field
 
@@ -101,6 +104,35 @@ def __init__(
         self.activation = activation
         self.final_activation = final_activation
 
+    def _scan_hidden_layers(self, x):
+        def step(inp, layer_weights):
+            weight = layer_weights[:, :-1]
+            bias = layer_weights[:, -1].T
+            layer = self.layers[1]._tree_unflatten(
+                [
+                    ["weight", "bias"],
+                    ["in_features", "out_features", "use_bias"],
+                    [
+                        self.layers[1].in_features,
+                        self.layers[1].out_features,
+                        self.layers[1].use_bias,
+                    ],
+                ],
+                [weight, bias],
+            )
+            inp = self.activation(layer(inp))
+            return inp, None
+
+        flattened_layers, _ = tree_flatten(self.layers[1:-1])
+        concatenated_weight_bias = [
+            jnp.concatenate([weight, bias.reshape(-1, 1)], axis=1)
+            for weight, bias in zip(flattened_layers[::2], flattened_layers[1::2])
+        ]
+        stacked_weights = jnp.stack(concatenated_weight_bias)
+
+        x, _ = lax.scan(step, x, stacked_weights)
+        return x
+
     def __call__(self, x: Array, *, key: Optional[PRNGKey] = None) -> Array:
         """**Arguments:**
 
@@ -113,9 +145,14 @@ def __call__(self, x: Array, *, key: Optional[PRNGKey] = None) -> Array:
 
         A JAX array with shape `(out_size,)`. (Or shape `()` if `out_size="scalar"`.)
         """
-        for layer in self.layers[:-1]:
-            x = layer(x)
+
+        if len(self.layers) > 1:
+            x = self.layers[0](x)
             x = self.activation(x)
+
+        if len(self.layers) > 2:
+            x = self._scan_hidden_layers(x)
+
         x = self.layers[-1](x)
         x = self.final_activation(x)
         return x
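For reference, here is a minimal standalone sketch of the idea behind `_scan_hidden_layers`: because every hidden layer has the same shape, each layer's weight matrix and bias can be concatenated into one block, the blocks stacked along a leading axis, and the whole stack folded over the input with `lax.scan`. The helper names below (`init_hidden_stack`, `mlp_scan`) are illustrative only; they are not part of Equinox or this PR.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_hidden_stack(key, num_layers, width):
    # One (width, width + 1) block per hidden layer: weight with the bias
    # appended as an extra column, then stacked along a leading layer axis.
    wkey, bkey = jax.random.split(key)
    weights = jax.random.normal(wkey, (num_layers, width, width))
    biases = jax.random.normal(bkey, (num_layers, width, 1))
    return jnp.concatenate([weights, biases], axis=2)


def mlp_scan(stacked, x):
    def step(inp, wb):
        # Split each block back into weight matrix and bias vector.
        weight, bias = wb[:, :-1], wb[:, -1]
        return jax.nn.relu(weight @ inp + bias), None

    # Fold the stacked layers over the input; the layer body is traced once
    # rather than once per layer.
    out, _ = lax.scan(step, x, stacked)
    return out


key = jax.random.PRNGKey(0)
stacked = init_hidden_stack(key, num_layers=4, width=8)
x = jnp.ones(8)

# Scanning over the stacked parameters matches the original Python loop.
looped = x
for wb in stacked:
    looped = jax.nn.relu(wb[:, :-1] @ looped + wb[:, -1])
assert jnp.allclose(mlp_scan(stacked, x), looped)
```

The equivalence check at the end mirrors what the PR relies on: the scan produces the same activations as the unrolled loop, while compile time should stay roughly flat as the number of hidden layers grows.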