
Commit 080759f

- Updating README and docstrings to include more discussion of the "maximal slope function" etc.

- Minor fix to subnet maximizing function in tests (which should be a no-op for the relevant depths).
- Bumping version number to 0.1.1.

PiperOrigin-RevId: 432772935
james-martens authored and DKSdev committed Mar 6, 2022
1 parent 00dc283 commit 080759f
Showing 4 changed files with 73 additions and 56 deletions.
85 changes: 47 additions & 38 deletions README.md
@@ -4,11 +4,11 @@
# Official Python package for Deep Kernel Shaping (DKS) and Tailored Activation Transformations (TAT)

This Python package implements the activation function transformations and
weight initializations used in Deep Kernel Shaping (DKS) and Tailored Activation
Transformations (TAT). DKS and TAT, which were introduced in the [DKS paper] and
[TAT paper], are methods for constructing/transforming neural networks to make
them much easier to train. For example, these methods can be used in conjunction
with K-FAC to train deep vanilla convnets (without skip connections or
normalization layers) as fast as standard ResNets of the same depth.

The package supports the JAX, PyTorch, and TensorFlow tensor programming
@@ -23,43 +23,51 @@ from Github will be rejected. Instead, please email us if you find a bug.
## Usage

For each of the supported tensor programming frameworks, there is a
corresponding subpackage which handles the activation function transformations
and weight initializations. (These are `dks.jax`, `dks.pytorch`, and
`dks.tensorflow`.) It's up to the user to import these and use them
appropriately within their model code. Activation functions are transformed by
the function `get_transformed_activations()` in the module
`activation_transform` of the appropriate subpackage. Sampling initial
parameters is done using functions inside of the module
`parameter_sampling_functions` of said subpackage. Note that in order to avoid
having to import all of the tensor programming frameworks, the user is required
to individually import whatever framework subpackage they want, e.g. `import
dks.jax`. Meanwhile, `import dks` won't actually do anything.
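
As a purely illustrative sketch (not part of the README), the intended flow
looks roughly like the following for JAX. The exact signature of
`get_transformed_activations()` (assumed here to take a list of activation
names plus a `method` string, and to return a dict of transformed functions
keyed by name) and the contents of `parameter_sampling_functions` should be
checked against the subpackage docstrings:

```python
# Rough sketch only -- names and arguments not shown elsewhere on this page
# are assumptions, not verified API.
from dks.jax import activation_transform  # framework subpackage imported explicitly

# Hypothetical model: a 50-layer feedforward chain with one activation per
# layer, so the "maximal slope function" (see the next paragraph) is just
# slope**50.
def max_slope_func(slope):
  return slope ** 50

acts = activation_transform.get_transformed_activations(
    ["leaky_relu"], method="DKS", max_slope_func=max_slope_func)
transformed_leaky_relu = acts["leaky_relu"]

# The model's weights would then be sampled with the helpers in
# dks.jax.parameter_sampling_functions, and any weighted sums in the model
# replaced by "normalized sums" (see below).
```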

`get_transformed_activations()` requires the user to pass either the "maximal
slope function" for DKS, the "subnet maximizing function" for TAT with Leaky
ReLUs, or the "maximal curvature function" for TAT with smooth activation
functions. (The subnet maximizing function also handles DKS and TAT with smooth
activations.) These are special functions that encode information about the
particular model architecture. See the section titled "Summary of our method" of
the [DKS paper] for a procedure to construct the maximal slope function for a
given model, or the appendix section titled "Additional details and pseudocode
for activation function transformations" of the [TAT paper] for procedures to
construct the other two functions.
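
To make this concrete, here is an editorial illustration (not from the README)
of a subnet maximizing function for a toy network of three residual blocks,
each of whose residual branches applies two activations and is combined with
the shortcut branch through a normalized sum with weights `(w, sqrt(1 - w**2))`.
The `_subnet_max_func` in the test-file diff at the bottom of this page follows
the same pattern for the full modified ResNet:

```python
# Illustrative only: computes M_{f,r}(x) for the hypothetical 3-block network
# described above. Local C map values combine through a normalized sum using
# the *squared* weights of that sum.
def subnet_max_func(x, r_fn, shortcut_weight=0.6):
  w_sq = shortcut_weight ** 2
  # Subnetwork consisting of a single residual branch on its own.
  single_branch_x = r_fn(r_fn(x))
  for _ in range(3):
    res_x = r_fn(r_fn(x))  # deepest path through this block's residual branch
    x = w_sq * x + (1.0 - w_sq) * res_x
  return max(x, single_branch_x)
```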

In addition to these things, the user is responsible for ensuring that their
model meets the architectural requirements of DKS/TAT, and for converting any
weighted sums into "normalized sums" (which are weighted sums whose
non-trainable weights have a sum of squares equal to 1). See the section titled
"Summary of our method" of the [DKS paper] for more details.

Note that this package doesn't currently include an implementation of
Per-Location Normalization (PLN) data pre-processing. While not required for
CIFAR or ImageNet, PLN could potentially be important for other datasets. Also
note that ReLUs are only partially supported by DKS, and unsupported by TAT, and
so their use is *highly* discouraged. Instead, one should use Leaky ReLUs, which
are fully supported by DKS, and work especially well with TAT.

## Example

`dks.examples.haiku.modified_resnet` is a [Haiku] ResNet model which has been
modified as described in the DKS/TAT papers, and includes support for both DKS
and TAT. When constructed with its default arguments, it removes the
normalization layers and skip connections found in standard ResNets, making it a
"vanilla network". It can be used as an instructive example for how to build
DKS/TAT models using this package. See the section titled "Application to
various modified ResNets" from the [DKS paper] for more details.
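
An editorial sketch of how such a model might be driven from Haiku follows; the
class name and constructor arguments below are placeholders rather than the
module's documented interface, so consult `dks.examples.haiku.modified_resnet`
itself for the real ones:

```python
import haiku as hk
import jax
import jax.numpy as jnp

from dks.examples.haiku import modified_resnet


def forward(images, is_training):
  # Placeholder constructor call; the real class and arguments may differ.
  model = modified_resnet.ModifiedResNet(num_classes=10)
  return model(images, is_training=is_training)


model_fn = hk.transform_with_state(forward)
params, state = model_fn.init(
    jax.random.PRNGKey(0), jnp.zeros([1, 32, 32, 3]), is_training=True)
```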

## Installation

@@ -75,7 +83,7 @@ or
pip install -e git+https://github.com/deepmind/dks.git#egg=dks[<extras>]
```

Or from PyPI with

```bash
pip install dks
@@ -87,7 +95,7 @@ or
pip install dks[<extras>]
```

Here `<extras>` is a comma-separated list of strings (with no spaces) that can
be passed to install extra dependencies for different tensor programming
frameworks. Valid strings are `jax`, `tf`, and `pytorch`. So for example, to
install `dks` with the extra requirements for JAX and PyTorch, one does
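
That is, presumably the following (the command itself sits in an elided part of
the README):

```bash
pip install dks[jax,pytorch]
```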
@@ -126,3 +134,4 @@ This is not an official Google product.

[DKS paper]: https://arxiv.org/abs/2110.01765
[TAT paper]: https://openreview.net/forum?id=U0k7XNTiFEq
[Haiku]: https://github.com/deepmind/dm-haiku
2 changes: 1 addition & 1 deletion dks/__init__.py
@@ -16,4 +16,4 @@
# Do not directly import this package; it won't do anything. Instead, import one
# of the framework-specific subpackages.

__version__ = "0.1.0"
__version__ = "0.1.1"
38 changes: 22 additions & 16 deletions dks/base/activation_transform.py
@@ -564,7 +564,12 @@ def get_transformed_activations(
See the DKS paper (https://arxiv.org/abs/2110.01765) and the TAT paper
(https://openreview.net/forum?id=U0k7XNTiFEq) for details about what these
are, how they are computed, and what their parameters mean. A procedure to
compute the "maximal slope function" is given in the section titled "Summary
of our method" of the DKS paper. Procedures to compute the "maximal curvature
function" and the "subnet maximizing function" are given in the appendix
section titled "Additional details and pseudocode for activation function
transformations" of the TAT paper.

Note that if you are using the JAX, PyTorch, or TensorFlow frameworks, you
probably want to be using the version of get_transformed_activations() in the
@@ -604,21 +609,22 @@ def get_transformed_activations(
corresponds to "tau" from the paper), and defaults to 0.3. If
``tat_params`` is passed as None, it defaults to the empty dictionary (so
that the parameters will use their default values). (Default: None)
max_slope_func: A callable which computes the "maximal slope function" of
the network, as defined in the DKS paper. It should take a single argument
representing the slope of each local C map at ``c=1``. If this is required
(i.e. when using DKS) but passed as None, it will be generated using
``subnet_max_func`` when possible. (Default: None)
max_curv_func: A callable which computes the "maximal curvature function" of
the network, as defined in the TAT paper. It should take a single
parameter representing the second derivative of each local C map at c=1.
If this is required (i.e. when using TAT with smooth activation functions)
but is passed as None, it will be generated using ``subnet_max_func`` when
possible. (Default: None)
subnet_max_func: A callable which computes the "subnetwork maximizing
function" of the network, as defined in the TAT paper (and denoted
``M_{f,r}(x)``). It should take two arguments: the input value ``x``, and
a callable ``r_func`` which maps a float to a float. This is required when
using TAT with Leaky ReLUs. (Default: None)
activation_getter: A callable which takes a string name for an activation
function and returns the (untransformed) activation function corresponding
to this name. Defaults to one returning activation functions in NumPy
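
A further hedged sketch prompted by this docstring: when only ``subnet_max_func``
is supplied, the maximal slope and curvature functions are generated from it, so
a DKS user who already has a subnetwork maximizing function can pass just that
one callable. Names not shown on this page, such as the ``method`` argument and
the ``"softplus"`` activation name, are assumptions:

```python
from dks.base import activation_transform  # framework-agnostic NumPy version


def subnet_max_func(x, r_fn):
  # Hypothetical 20-layer chain: the maximizing subnetwork is the whole chain.
  for _ in range(20):
    x = r_fn(x)
  return x


# max_slope_func is omitted and should be derived from subnet_max_func.
acts = activation_transform.get_transformed_activations(
    ["softplus"], method="DKS", subnet_max_func=subnet_max_func)
```
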
4 changes: 3 additions & 1 deletion tests/test_activation_transform.py
@@ -35,6 +35,8 @@ def _subnet_max_func(x, r_fn, shortcut_weight=0.6):

  blocks_per_group = (3, 4, 23, 3)

  # Subnetwork consisting of a single residual branch applied to the input
  # (three activations); the final max below takes it into account.
  res_branch_subnetwork_x = r_fn(r_fn(r_fn(x)))

  for i in range(4):
    for j in range(blocks_per_group[i]):

@@ -46,7 +48,7 @@ def _subnet_max_func(x, r_fn, shortcut_weight=0.6):

  x = r_fn(x)

  return max(x, res_branch_subnetwork_x)


class ActivationTransformTest(absltest.TestCase):
