
[WIP] Managing initializations via a role scheme dictionary #1030

Status: Open. Wants to merge 7 commits into base: master.
Changes from 6 commits.
131 changes: 95 additions & 36 deletions blocks/bricks/interfaces.py
@@ -5,6 +5,7 @@

from ..config import config
from .base import _Brick, Brick, lazy
from blocks.roles import WEIGHT, BIAS, FILTER, INITIAL_STATE


class ActivationDocumentation(_Brick):
@@ -127,45 +128,109 @@ class Initializable(RNGMixin, Brick):
``True``.
use_bias : :obj:`bool`, optional
Whether to use a bias. Defaults to `True`. Required by
:meth:`~.Brick.initialize`. Only supported by bricks for which
:attr:`has_biases` is ``True``.
:meth:`~.Brick.initialize`.
rng : :class:`numpy.random.RandomState`

Attributes
----------
has_biases : bool
``False`` if the brick does not support biases, and only has
:attr:`weights_init`. For an example of this, see
:class:`.Bidirectional`. If this is ``False``, the brick does not
support the arguments ``biases_init`` or ``use_bias``.

"""
has_biases = True

@lazy()
def __init__(self, weights_init=None, biases_init=None, use_bias=None,
seed=None, **kwargs):
super(Initializable, self).__init__(**kwargs)
self.weights_init = weights_init
if self.has_biases:
self.biases_init = biases_init
elif biases_init is not None or not use_bias:
raise ValueError("This brick does not support biases config")
if use_bias is not None:
self.use_bias = use_bias
def __init__(self, initialization_schemes=None,
use_bias=True, seed=None, **kwargs):
self.use_bias = use_bias
self.seed = seed
self.initialization_schemes = initialization_schemes
if self.initialization_schemes is None:
self.initialization_schemes = {}

initialization_to_role = {"weights_init": WEIGHT, 'biases_init': BIAS,
Contributor:
For the record, it looks like you have chosen to handle only a handful of <role>_init keyword arguments, not all of them. This is probably good, because it clearly makes initialization_schemes the primary way of defining the initialization schemes.

But the code clearly lacks a check for key not in initialization_to_role. Since we now have a clearly defined whitelist, I would probably raise a warning for unrecognized smth_init arguments, but still pass them through to the super().__init__ call.
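The whitelist-plus-warning behaviour suggested here could look roughly like the sketch below. All names are stand-ins rather than the actual blocks API, and `split_init_kwargs` is a hypothetical helper:

```python
import warnings

# Stand-ins for the role constants from blocks.roles.
WEIGHT, BIAS, INITIAL_STATE = "weight", "bias", "initial_state"

INITIALIZATION_TO_ROLE = {
    "weights_init": WEIGHT,
    "biases_init": BIAS,
    "initial_state_init": INITIAL_STATE,
}


def split_init_kwargs(kwargs):
    """Move whitelisted *_init kwargs into a scheme dict.

    Unrecognized *_init kwargs trigger a warning but are still passed
    through, so super().__init__ can deal with them.
    """
    schemes, passthrough = {}, {}
    for key, value in kwargs.items():
        if key.endswith("_init") and key in INITIALIZATION_TO_ROLE:
            schemes[INITIALIZATION_TO_ROLE[key]] = value
        else:
            if key.endswith("_init"):
                warnings.warn("'{}' is not a recognized *_init argument; "
                              "prefer initialization_schemes".format(key))
            passthrough[key] = value
    return schemes, passthrough


schemes, rest = split_init_kwargs(
    {"weights_init": "IsotropicGaussian", "name": "linear"})
print(schemes)  # {'weight': 'IsotropicGaussian'}
print(rest)     # {'name': 'linear'}
```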

'initial_state_init': INITIAL_STATE}
for key in list(kwargs.keys()):
if key[-5:] == "_init":
if key not in initialization_to_role:
raise ValueError("The initialization scheme {} "
"is not defined by default; pass it "
"via initialization_schemes".format(key))
if initialization_to_role[key] in \
self.initialization_schemes.keys():
raise ValueError("All initializations are accepted either "
"through initialization_schemes or the "
"corresponding attribute, but not both")
else:
self.initialization_schemes[initialization_to_role[
key]] = kwargs[key]
kwargs.pop(key)

super(Initializable, self).__init__(**kwargs)

def _validate_roles(self):
Contributor:
I guess here you propagate initialization schemes down the hierarchy of roles. But it seems to me that you do not choose the most specific role out of those for which an initialization scheme is available. E.g. if we have an init. scheme for PARAMETER and one for WEIGHT, and we are looking for an init. scheme for RECURRENT_WEIGHT, then we should always choose the scheme given for WEIGHT. In the code below it seems like the choice will depend on the order of the dictionary.
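A minimal sketch of the "most specific role wins" lookup the reviewer describes, assuming roles form a class hierarchy as in blocks.roles. The role classes below are simplified stand-ins, and specificity is measured by MRO depth so the result does not depend on dictionary order:

```python
# Simplified stand-ins for the blocks.roles hierarchy.
class ParameterRole(object): pass
class WeightRole(ParameterRole): pass
class RecurrentWeightRole(WeightRole): pass

PARAMETER = ParameterRole()
WEIGHT = WeightRole()
RECURRENT_WEIGHT = RecurrentWeightRole()


def most_specific_scheme(role, schemes):
    """Pick the scheme whose role is the most specific ancestor of `role`."""
    candidates = [key for key in schemes if isinstance(role, type(key))]
    if not candidates:
        raise KeyError("no scheme applies to {!r}".format(role))
    # A deeper MRO means a more specific role class.
    return schemes[max(candidates, key=lambda key: len(type(key).__mro__))]


schemes = {PARAMETER: "Constant(0)", WEIGHT: "Orthogonal()"}
# WEIGHT is more specific than PARAMETER for a recurrent weight:
print(most_specific_scheme(RECURRENT_WEIGHT, schemes))  # Orthogonal()
```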

Contributor:
Also, I find it a bit weird that you change initialization_schemes in place. This might make debugging quite a bit harder. I would rather put the search for the most specific init. scheme directly in initialize.

high_level_roles = []
for role in self.parameter_roles:
Contributor:
This validation checks that initialization schemes are given for all parameters. It is very restrictive to run it in _push_initialization_config as you do, because this method can be called in the middle of specifying initialization schemes for bricks.

Also, we get this kind of validation for free when we actually try to initialize the parameters. Instead, I would rather check that no initialization schemes are provided that are clearly useless. Passing something useless as an argument is often a sign of a bug.

if role not in self.initialization_schemes.keys():
for key in list(self.initialization_schemes.keys()):
if isinstance(role, type(key)):
self.initialization_schemes[role] = \
self.initialization_schemes[key]
high_level_roles.append(key)

for key in high_level_roles:
if key not in self.parameter_roles:
self.initialization_schemes.pop(key)

for key in self.initialization_schemes:
Contributor:
This seems to be wrong, because an exception can be raised in perfectly legitimate cases, e.g. when self.parameter_roles contains only the role RECURRENT_WEIGHT and key is WEIGHT.

I think the only case when we can (and should) raise an exception is when an init. scheme is provided for a role that is not a parent of any of the parameter_roles.
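The relaxed check the reviewer asks for might look like this sketch (hypothetical helper and stand-in role classes, not the PR's code): a scheme is flagged only when its role covers none of the collected parameter roles.

```python
# Stand-ins for the blocks.roles hierarchy.
class ParameterRole(object): pass
class WeightRole(ParameterRole): pass
class RecurrentWeightRole(WeightRole): pass
class BiasRole(ParameterRole): pass

WEIGHT, RECURRENT_WEIGHT, BIAS = WeightRole(), RecurrentWeightRole(), BiasRole()


def useless_schemes(schemes, parameter_roles):
    """Scheme roles that are not a parent (or equal) of any parameter role."""
    return [key for key in schemes
            if not any(isinstance(role, type(key))
                       for role in parameter_roles)]


# WEIGHT legitimately covers RECURRENT_WEIGHT, so only the BIAS scheme
# would deserve an exception here.
bad = useless_schemes({WEIGHT: "Orthogonal()", BIAS: "Constant(0)"},
                      {RECURRENT_WEIGHT})
print(bad == [BIAS])  # True
```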

if key not in self.parameter_roles:
raise ValueError("{} is not a member of "
"parameter_roles".format(key))

def _push_initialization_config(self):
self._collect_roles()
self._validate_roles()
for child in self.children:
if isinstance(child, Initializable):
if (isinstance(child, Initializable) and
hasattr(child, 'initialization_schemes')):
child.rng = self.rng
if self.weights_init:
child.weights_init = self.weights_init
if hasattr(self, 'biases_init') and self.biases_init:
for child in self.children:
if (isinstance(child, Initializable) and
hasattr(child, 'biases_init')):
child.biases_init = self.biases_init
for role, scheme in self.initialization_schemes.items():
Contributor:
We should not propagate all the roles to the children, I think. It would not be nice if even the bottom-most brick had initialization schemes for all the roles. We should only push those initialization schemes that can be useful for the child brick. Given that push_initialization_config can be called before allocate, we cannot use self.parameter_roles, at least given the way you collect them.

if role in child.parameter_roles:
child.initialization_schemes[role] = scheme

def _collect_roles(self):
def get_param_roles(obj):
all_roles = []
for param in obj.parameters:
roles = param.tag.roles
# TODO do something smarter
if len(roles) > 0:
all_roles.append(roles[0])
return all_roles

self.parameter_roles = set(get_param_roles(self))
for child in self.children:
if isinstance(child, Initializable):
child._collect_roles()
self.parameter_roles.update(child.parameter_roles)
Contributor:
Should this not recurse?

Contributor:
I guess it's not necessary if called from the constructor (though I think that might be suboptimal)


def _initialize(self):
for param in self.parameters:
Contributor:
This is dangerous code. First, if a parameter has more than one role, it is not clear at all how it should be initialized; I would rather raise an exception upon encountering such a situation. Also, we should choose the most specific initialization scheme (e.g. the one for RECURRENT_WEIGHT, not the one for WEIGHT, if the parameter is a RECURRENT_WEIGHT).
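One way to implement both suggestions, as a sketch (names are hypothetical, not the PR's code): refuse parameters that carry several roles, and otherwise resolve the most specific scheme at initialization time.

```python
# Stand-ins for the blocks.roles hierarchy.
class ParameterRole(object): pass
class WeightRole(ParameterRole): pass
class RecurrentWeightRole(WeightRole): pass

WEIGHT, RECURRENT_WEIGHT = WeightRole(), RecurrentWeightRole()


def scheme_for(param_roles, schemes):
    """Resolve the scheme for a parameter's roles, or raise."""
    if len(param_roles) != 1:
        raise ValueError("parameter has {} roles; initialization is "
                         "ambiguous".format(len(param_roles)))
    role, = param_roles
    candidates = [key for key in schemes if isinstance(role, type(key))]
    if not candidates:
        raise ValueError("no scheme covers {!r}".format(role))
    # Deeper MRO == more specific role, so RECURRENT_WEIGHT beats WEIGHT.
    return schemes[max(candidates, key=lambda key: len(type(key).__mro__))]


schemes = {WEIGHT: "IsotropicGaussian()", RECURRENT_WEIGHT: "Orthogonal()"}
print(scheme_for([RECURRENT_WEIGHT], schemes))  # Orthogonal()
```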

for role in param.tag.roles:
if role in self.parameter_roles:
self.initialization_schemes[role].initialize(param,
self.rng)

def __getattr__(self, name):
if name == "weights_init":
if WEIGHT in self.initialization_schemes:
return self.initialization_schemes[WEIGHT]
elif name == "biases_init":
if BIAS in self.initialization_schemes:
return self.initialization_schemes[BIAS]
super(Initializable, self).__getattr__(name)

def __setattr__(self, name, value):
if name == 'weights_init':
self.initialization_schemes[WEIGHT] = value
elif name == 'biases_init':
self.initialization_schemes[BIAS] = value
else:
super(Initializable, self).__setattr__(name, value)


class LinearLike(Initializable):
@@ -182,6 +247,7 @@ class LinearLike(Initializable):
first and biases (if ``use_bias`` is True) coming second.

"""

@property
def W(self):
return self.parameters[0]
@@ -193,13 +259,6 @@ def b(self):
else:
raise AttributeError('use_bias is False')

def _initialize(self):
# Use self.parameters[] references in case W and b are overridden
# to return non-shared-variables.
if getattr(self, 'use_bias', True):
self.biases_init.initialize(self.parameters[1], self.rng)
self.weights_init.initialize(self.parameters[0], self.rng)


class Random(Brick):
"""A mixin class for Bricks which need Theano RNGs.
3 changes: 0 additions & 3 deletions blocks/bricks/lookup.py
@@ -41,9 +41,6 @@ def _allocate(self):
name='W'))
add_role(self.parameters[-1], WEIGHT)

def _initialize(self):
self.weights_init.initialize(self.W, self.rng)

@application(inputs=['indices'], outputs=['output'])
def apply(self, indices):
"""Perform lookup.
2 changes: 2 additions & 0 deletions blocks/bricks/recurrent/architectures.py
@@ -6,6 +6,7 @@
from ..simple import Initializable, Logistic, Tanh
from ...roles import add_role, WEIGHT, INITIAL_STATE
from ...utils import shared_floatx_nans, shared_floatx_zeros
from ...initialization import Constant
from .base import BaseRecurrent, recurrent


Expand All @@ -32,6 +33,7 @@ def __init__(self, dim, activation, **kwargs):
self.dim = dim
children = [activation]
kwargs.setdefault('children', []).extend(children)
kwargs.setdefault('initial_state_init', Constant(0.))
super(SimpleRecurrent, self).__init__(**kwargs)

@property
4 changes: 0 additions & 4 deletions blocks/bricks/simple.py
@@ -95,10 +95,6 @@ def _allocate(self):
add_role(b, BIAS)
self.parameters.append(b)

def _initialize(self):
b, = self.parameters
self.biases_init.initialize(b, self.rng)

@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
"""Apply the linear transformation.
3 changes: 3 additions & 0 deletions blocks/roles.py
@@ -71,6 +71,9 @@ def __repr__(self):
return re.sub(r'(?!^)([A-Z]+)', r'_\1',
self.__class__.__name__[:-4]).upper()

def __hash__(self):
return hash(str(self))


class InputRole(VariableRole):
pass
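The __hash__ added to VariableRole in the roles.py hunk above is what lets role instances serve as keys in an initialization_schemes dictionary. A self-contained sketch of the same pattern, mirroring the __repr__ shown in the diff:

```python
import re

class VariableRole(object):
    def __repr__(self):
        # WeightRole -> WEIGHT, RecurrentWeightRole -> RECURRENT_WEIGHT
        return re.sub(r'(?!^)([A-Z]+)', r'_\1',
                      self.__class__.__name__[:-4]).upper()

    def __hash__(self):
        # Hash by the printed name so any two instances of the same role
        # class land in the same hash bucket.
        return hash(str(self))

class RecurrentWeightRole(VariableRole):
    pass

print(repr(RecurrentWeightRole()))                                 # RECURRENT_WEIGHT
print(hash(RecurrentWeightRole()) == hash(RecurrentWeightRole()))  # True
```

Note that __eq__ is left as the default identity comparison, so two fresh instances hash alike but still compare unequal; this is fine in blocks because roles are module-level singletons (WEIGHT, BIAS, ...).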
6 changes: 2 additions & 4 deletions tests/bricks/test_recurrent.py
@@ -146,8 +146,7 @@ def test_many_steps(self):

class TestLSTM(unittest.TestCase):
def setUp(self):
self.lstm = LSTM(dim=3, weights_init=Constant(2),
biases_init=Constant(0))
self.lstm = LSTM(dim=3, weights_init=Constant(2))
self.lstm.initialize()

def test_one_step(self):
@@ -244,7 +243,6 @@ def setUp(self):

self.stack2 = RecurrentStack(transitions,
weights_init=Constant(2),
biases_init=Constant(0),
skip_connections=True)
self.stack2.initialize()

@@ -502,7 +500,7 @@ def setUp(self):
dim=3, activation=Tanh()))
self.simple = SimpleRecurrent(dim=3, weights_init=Orthogonal(),
activation=Tanh(), seed=1)
self.bidir.allocate()
self.bidir.initialize()
self.simple.initialize()
self.bidir.children[0].parameters[0].set_value(
self.simple.parameters[0].get_value())