diff --git a/README.md b/README.md
index 1e1ff45..0e97353 100644
--- a/README.md
+++ b/README.md
@@ -4,11 +4,11 @@
 # Official Python package for Deep Kernel Shaping (DKS) and Tailored Activation Transformations (TAT)
 
 This Python package implements the activation function transformations and
-weight initializations used Deep Kernel Shaping (DKS) and Tailored Activation
+weight initializations used in Deep Kernel Shaping (DKS) and Tailored Activation
 Transformations (TAT). DKS and TAT, which were introduced in the [DKS paper] and
-[TAT paper], are methods constructing/transforming neural networks to make them
-much easier to train. For example, these methods can be used in conjunction with
-K-FAC to train deep vanilla deep convnets (without skip connections or
+[TAT paper], are methods for constructing/transforming neural networks to make
+them much easier to train. For example, these methods can be used in conjunction
+with K-FAC to train deep vanilla convnets (without skip connections or
 normalization layers) as fast as standard ResNets of the same depth.
 
 The package supports the JAX, PyTorch, and TensorFlow tensor programming
@@ -23,43 +23,51 @@ from Github will be rejected. Instead, please email us if you find a bug.
 
 ## Usage
 
 For each of the supported tensor programming frameworks, there is a
-corresponding directory/subpackage which handles the activation function
-transformations and weight initializations. It's up to the user to import these
-and use them appropriately within their model code. Activation functions are
-transformed by the function `get_transformed_activations()` in module
-`activation_transform` of the appropriate subpackage. Weight sampling is done
-using functions inside of the module `parameter_sampling_functions` of said
-subpackage.
-
-In addition to using these functions, the user is responsble for ensuring that
-their model meets the architectural requirements of DKS/TAT, and for converting
-any weighted sums in their model to "normalized sums" (which are weighted sums
-whoses non-trainable weights have a sum of squares equal to 1). This package
-doesn't currently include an implementation of Per-Location Normalization (PLN)
-data pre-processing. While not required for CIFAR or ImageNet, PLN could
-potentially be important for other datasets. See the section titled "Summary of
-our method" in the [DKS paper] for more details about the requirements and
-execution steps of DKS. To read about the additional requirements of TAT, such
-as the subset maximizing function, refer to Appendix B of the [TAT paper].
-
-Note that ReLUs are only partially supported by DKS, and unsupported by TAT, and
-their use is *highly* discouraged. Instead, one should use Leaky ReLUs, which
+corresponding subpackage which handles the activation function transformations
+and weight initializations. (These are `dks.jax`, `dks.pytorch`, and
+`dks.tensorflow`.) It's up to the user to import these and use them
+appropriately within their model code. Activation functions are transformed by
+the function `get_transformed_activations()` in the module
+`activation_transform` of the appropriate subpackage. Sampling initial
+parameters is done using functions inside of the module
+`parameter_sampling_functions` of said subpackage. Note that in order to avoid
+having to import all of the tensor programming frameworks, the user is required
+to individually import whatever framework subpackage they want, e.g. `import
+dks.jax`. Meanwhile, `import dks` won't actually do anything.
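+
+As a rough, hypothetical sketch (not taken from the package's own examples),
+the following shows how one might transform a softplus activation with DKS in
+JAX. It assumes a simple 100-layer "vanilla" feed-forward network, for which
+the maximal slope function (described below) just raises the common slope of
+the local C maps to the power of the network's depth; the depth, the names, and
+the handling of the return value are all illustrative:
+
+```python
+from dks.jax import activation_transform
+
+# Depth of the assumed vanilla (chain-structured) network. Its deepest
+# subnetwork is the network itself, so its maximal slope function simply
+# raises the local C map slope to the power of the depth.
+DEPTH = 100
+
+transformed_acts = activation_transform.get_transformed_activations(
+    ["softplus"], method="DKS", max_slope_func=lambda slope: slope**DEPTH)
+
+# The transformed activation function is then used inside the model in place
+# of the original softplus.
+transformed_softplus = transformed_acts["softplus"]
+```
+
+The model's initial weights would then be sampled using the functions in
+`dks.jax.parameter_sampling_functions`.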
+
+`get_transformed_activations()` requires the user to pass either the "maximal
+slope function" for DKS, the "subnet maximizing function" for TAT with Leaky
+ReLUs, or the "maximal curvature function" for TAT with smooth activation
+functions. (The subnet maximizing function also handles DKS and TAT with smooth
+activations.) These are special functions that encode information about the
+particular model architecture. See the section titled "Summary of our method" of
+the [DKS paper] for a procedure to construct the maximal slope function for a
+given model, or the appendix section titled "Additional details and pseudocode
+for activation function transformations" of the [TAT paper] for procedures to
+construct the other two functions.
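+
+For illustration only, here is a hypothetical subnetwork maximizing function
+for a toy network made of `num_blocks` residual blocks, where each block's
+residual branch contains two nonlinear layers and is combined with the shortcut
+branch by a normalized sum with weight `shortcut_weight`. Roughly following the
+procedure described in the [TAT paper], `r_fn` is applied once per nonlinear
+layer, normalized sums become convex combinations with the squared weights, and
+the subnetwork consisting of a lone residual branch is also considered (as in
+this package's own tests):
+
+```python
+from dks.jax import activation_transform
+
+
+def subnet_max_func(x, r_fn, num_blocks=50, shortcut_weight=0.8):
+  # Value for the subnetwork consisting of a single residual branch by itself.
+  res_branch_only = r_fn(r_fn(x))
+
+  for _ in range(num_blocks):
+    # Each nonlinear layer on the residual branch applies r_fn once, and the
+    # normalized sum combines the two branches with squared weights summing
+    # to 1.
+    res_x = r_fn(r_fn(x))
+    x = shortcut_weight**2 * x + (1.0 - shortcut_weight**2) * res_x
+
+  return max(x, res_branch_only)
+
+
+transformed_acts = activation_transform.get_transformed_activations(
+    ["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func)
+```
+
+The activation name `"leaky_relu"`, the block count, and the shortcut weight
+above are illustrative choices, not prescribed values.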
+
+In addition to using these functions, the user is responsible for ensuring
+that their model meets the architectural requirements of DKS/TAT, and for
+converting any weighted sums into "normalized sums" (which are weighted sums
+whose non-trainable weights have a sum of squares equal to 1). See the section
+titled "Summary of our method" of the [DKS paper] for more details.
+
+Note that this package doesn't currently include an implementation of
+Per-Location Normalization (PLN) data pre-processing. While not required for
+CIFAR or ImageNet, PLN could potentially be important for other datasets. Also
+note that ReLUs are only partially supported by DKS, and unsupported by TAT, and
+so their use is *highly* discouraged. Instead, one should use Leaky ReLUs, which
 are fully supported by DKS, and work especially well with TAT.
 
-Note that in order to avoid having to import all of the tensor programming
-frameworks, the user is required to individually import whatever framework
-subpackage they want. e.g. `import dks.jax`. Meanwhile, `import dks` won't
-actually do anything.
-
 ## Example
 
-`dks.examples.haiku.modified_resnet` is a Haiku ResNet model which has been
+`dks.examples.haiku.modified_resnet` is a [Haiku] ResNet model which has been
 modified as described in the DKS/TAT papers, and includes support for both DKS
-and TAT. By default, it removes the normalization layers and skip connections
-found in standard ResNets, making it a "vanilla network". It can be used as an
-instructive example for how to build DKS/TAT models using this package. See the
-section titled "Application to various modified ResNets" from the [DKS paper]
-for more details.
+and TAT. When constructed with its default arguments, it removes the
+normalization layers and skip connections found in standard ResNets, making it a
+"vanilla network". It can be used as an instructive example for how to build
+DKS/TAT models using this package. See the section titled "Application to
+various modified ResNets" from the [DKS paper] for more details.
 
 ## Installation
@@ -75,7 +83,7 @@ or
 pip install -e git+https://github.com/deepmind/dks.git#egg=dks[<extras>]
 ```
 
-or from PyPI with
+Or from PyPI with
 
 ```bash
 pip install dks
@@ -87,7 +95,7 @@ or
 pip install dks[<extras>]
 ```
 
-Here `<extras>` is a common-separated list (with no spaces) of strings that can
+Here `<extras>` is a comma-separated list of strings (with no spaces) that can
 be passed to install extra dependencies for different tensor programming
 frameworks. Valid strings are `jax`, `tf`, and `pytorch`. So for example, to
 install `dks` with the extra requirements for JAX and PyTorch, one does
@@ -126,3 +134,4 @@ This is not an official Google product.
 
 [DKS paper]: https://arxiv.org/abs/2110.01765
 [TAT paper]: https://openreview.net/forum?id=U0k7XNTiFEq
+[Haiku]: https://github.com/deepmind/dm-haiku
diff --git a/dks/__init__.py b/dks/__init__.py
index 2d94cdd..0c4782c 100644
--- a/dks/__init__.py
+++ b/dks/__init__.py
@@ -16,4 +16,4 @@
 # Do not directly import this package; it won't do anything. Instead, import one
 # of the framework-specific subpackages.
 
-__version__ = "0.1.0"
+__version__ = "0.1.1"
diff --git a/dks/base/activation_transform.py b/dks/base/activation_transform.py
index 143d7f2..3f55cca 100644
--- a/dks/base/activation_transform.py
+++ b/dks/base/activation_transform.py
@@ -564,7 +564,12 @@ def get_transformed_activations(
 
   See the DKS paper (https://arxiv.org/abs/2110.01765) and the TAT paper
   (https://openreview.net/forum?id=U0k7XNTiFEq) for details about what these
-  are, how they are computed, and what their parameters mean.
+  are, how they are computed, and what their parameters mean. A procedure to
+  compute the "maximal slope function" is given in the section titled "Summary
+  of our method" of the DKS paper. Procedures to compute the "maximal curvature
+  function" and the "subnet maximizing function" are given in the appendix
+  section titled "Additional details and pseudocode for activation function
+  transformations" of the TAT paper.
 
   Note that if you are using the JAX, PyTorch, or TensorFlow frameworks, you
   probably want to be using the version of get_transformed_activations() in the
@@ -604,21 +609,22 @@
       corresponds to "tau" from the paper), and defaults to 0.3. If
       ``tat_params`` is passed as None, it defaults to the empty dictionary (so
       that the parameters will use their default values). (Default: None)
-    max_slope_func: A callable which computes the maximal slope function, as
-      defined in the DKS paper. It should take a single argument representing
-      the slope of each local C map at ``c=1``. If this is required (i.e. when
-      using DKS) but is passed as None, it will be generated using
-      ``subnet_max_func`` if possible. (Default: None)
-    max_curv_func: A callable which computes the maximal curvature function. It
-      should take a single parameter representing the second derivative of each
-      local C map at c=1. If this is required (i.e. when using TAT with smooth
-      activation functions) but is passed as None, it will be generated using
-      ``subnet_max_func`` if possible. (Default: None)
-    subnet_max_func: A callable which computes the subnetwork maximizing
-      function of the network (denoted ``M_{f,r}(x)`` in the TAT paper). It
-      should take two arguments: the input value ``x``, and a callable
-      ``r_func`` which maps a float to a float. This is required when using TAT
-      with Leaky ReLUs. (Default: None)
+    max_slope_func: A callable which computes the "maximal slope function" of
+      the network, as defined in the DKS paper. It should take a single argument
+      representing the slope of each local C map at ``c=1``. If this is required
+      (i.e. when using DKS) but passed as None, it will be generated using
+      ``subnet_max_func`` when possible. (Default: None)
+    max_curv_func: A callable which computes the "maximal curvature function" of
+      the network, as defined in the TAT paper. It should take a single
+      parameter representing the second derivative of each local C map at c=1.
+      If this is required (i.e. when using TAT with smooth activation functions)
+      but is passed as None, it will be generated using ``subnet_max_func`` when
+      possible. (Default: None)
+    subnet_max_func: A callable which computes the "subnetwork maximizing
+      function" of the network, as defined in the TAT paper (and denoted
+      ``M_{f,r}(x)``). It should take two arguments: the input value ``x``, and
+      a callable ``r_func`` which maps a float to a float. This is required when
+      using TAT with Leaky ReLUs. (Default: None)
     activation_getter: A callable which takes a string name for an activation
       function and returns the (untransformed) activation function corresponding
       to this name. Defaults to one returning activation functions in NumPy
diff --git a/tests/test_activation_transform.py b/tests/test_activation_transform.py
index e5c8e8f..42d6996 100644
--- a/tests/test_activation_transform.py
+++ b/tests/test_activation_transform.py
@@ -35,6 +35,8 @@ def _subnet_max_func(x, r_fn, shortcut_weight=0.6):
 
   blocks_per_group = (3, 4, 23, 3)
 
+  res_branch_subnetwork_x = r_fn(r_fn(r_fn(x)))
+
   for i in range(4):
     for j in range(blocks_per_group[i]):
 
@@ -46,7 +48,7 @@ def _subnet_max_func(x, r_fn, shortcut_weight=0.6):
 
     x = r_fn(x)
 
-  return x
+  return max(x, res_branch_subnetwork_x)
 
 
 class ActivationTransformTest(absltest.TestCase):