diff --git a/blocks/bricks/interfaces.py b/blocks/bricks/interfaces.py index 531fb2aa..8edf918b 100644 --- a/blocks/bricks/interfaces.py +++ b/blocks/bricks/interfaces.py @@ -1,10 +1,12 @@ """Bricks that are interfaces and/or mixins.""" import numpy +import inspect from six import add_metaclass from theano.sandbox.rng_mrg import MRG_RandomStreams from ..config import config from .base import _Brick, Brick, lazy +from blocks.roles import WEIGHT, BIAS, FILTER, INITIAL_STATE class ActivationDocumentation(_Brick): @@ -127,45 +129,121 @@ class Initializable(RNGMixin, Brick): ``True``. use_bias : :obj:`bool`, optional Whether to use a bias. Defaults to `True`. Required by - :meth:`~.Brick.initialize`. Only supported by bricks for which - :attr:`has_biases` is ``True``. + :meth:`~.Brick.initialize`. rng : :class:`numpy.random.RandomState` - Attributes - ---------- - has_biases : bool - ``False`` if the brick does not support biases, and only has - :attr:`weights_init`. For an example of this, see - :class:`.Bidirectional`. If this is ``False``, the brick does not - support the arguments ``biases_init`` or ``use_bias``. - """ - has_biases = True @lazy() - def __init__(self, weights_init=None, biases_init=None, use_bias=None, - seed=None, **kwargs): - super(Initializable, self).__init__(**kwargs) - self.weights_init = weights_init - if self.has_biases: - self.biases_init = biases_init - elif biases_init is not None or not use_bias: - raise ValueError("This brick does not support biases config") - if use_bias is not None: - self.use_bias = use_bias + def __init__(self, initialization_schemes=None, + use_bias=True, seed=None, **kwargs): + self.use_bias = use_bias self.seed = seed + self.initialization_schemes = initialization_schemes + if self.initialization_schemes is None: + self.initialization_schemes = {} + + initialization_to_role = {"weights_init": WEIGHT, 'biases_init': BIAS, + 'initial_state_init': INITIAL_STATE} + for key in list(kwargs.keys()): + if key[-5:] == "_init": + if key not in initialization_to_role: + raise ValueError("The initlization scheme: {}".format(key), + "is not defined by default, pass it" + "via initialization_schemes") + if initialization_to_role[key] in \ + self.initialization_schemes.keys(): + raise ValueError("All initializations are accepted either" + "through initialization schemes or " + "corresponding attribute but not both") + else: + self.initialization_schemes[initialization_to_role[ + key]] = kwargs[key] + kwargs.pop(key) + + super(Initializable, self).__init__(**kwargs) + + def get_scheme(role, schemes): + for key in schemes: + if role == type(key): + return key + for key in schemes: + if isinstance(role, type(key)): + return key + + def _validate_roles(self): + all_parent_roles = [] + for role in self.parameter_roles: + all_parent_roles += list(inspect.getmro(type(role))) + + for key in self.initialization_schemes: + if type(key) not in all_parent_roles: + raise ValueError("There is no parameter role" + "for initlization sheme {}".format(key)) def _push_initialization_config(self): + self._collect_roles() + self._validate_roles() for child in self.children: - if isinstance(child, Initializable): + if (isinstance(child, Initializable) and + hasattr(child, 'initialization_schemes')): child.rng = self.rng - if self.weights_init: - child.weights_init = self.weights_init - if hasattr(self, 'biases_init') and self.biases_init: - for child in self.children: - if (isinstance(child, Initializable) and - hasattr(child, 'biases_init')): - child.biases_init = self.biases_init + for role, scheme in self.initialization_schemes.items(): + if role in child.parameter_roles: + child.initialization_schemes[role] = scheme + + def _collect_roles(self): + def get_param_roles(obj): + all_roles = [] + for param in obj.parameters: + roles = param.tag.roles + # TODO do something smarter + if len(roles) > 0: + all_roles.append(roles[0]) + return all_roles + + self.parameter_roles = set(get_param_roles(self)) + for child in self.children: + if isinstance(child, Initializable): + child._collect_roles() + self.parameter_roles.update(child.parameter_roles) + + def _initialize(self): + def get_scheme(role, schemes): + if role in schemes: + return role + for key in schemes: + if role == type(key): + return key + for key in schemes: + if isinstance(role, type(key)): + return key + + for param in self.parameters: + for role in param.tag.roles: + if role in self.parameter_roles: + key = get_scheme(role, self.initialization_schemes.keys()) + if key is not None: + self.initialization_schemes[key].initialize(param, + self.rng) + continue + + def __getattr__(self, name): + if name == "weights_init": + if WEIGHT in self.initialization_schemes: + return self.initialization_schemes[WEIGHT] + elif name == "biases_init": + if BIAS in self.initialization_schemes: + return self.initialization_schemes[BIAS] + super(Initializable, self).__getattr__(name) + + def __setattr__(self, name, value): + if name == 'weights_init': + self.initialization_schemes[WEIGHT] = value + elif name == 'biases_init': + self.initialization_schemes[BIAS] = value + else: + super(Initializable, self).__setattr__(name, value) class LinearLike(Initializable): @@ -182,6 +260,7 @@ class LinearLike(Initializable): first and biases (if ``use_bias`` is True) coming second. """ + @property def W(self): return self.parameters[0] @@ -193,13 +272,6 @@ def b(self): else: raise AttributeError('use_bias is False') - def _initialize(self): - # Use self.parameters[] references in case W and b are overridden - # to return non-shared-variables. - if getattr(self, 'use_bias', True): - self.biases_init.initialize(self.parameters[1], self.rng) - self.weights_init.initialize(self.parameters[0], self.rng) - class Random(Brick): """A mixin class for Bricks which need Theano RNGs. diff --git a/blocks/bricks/lookup.py b/blocks/bricks/lookup.py index 2fd20ba4..01b0b648 100644 --- a/blocks/bricks/lookup.py +++ b/blocks/bricks/lookup.py @@ -41,9 +41,6 @@ def _allocate(self): name='W')) add_role(self.parameters[-1], WEIGHT) - def _initialize(self): - self.weights_init.initialize(self.W, self.rng) - @application(inputs=['indices'], outputs=['output']) def apply(self, indices): """Perform lookup. diff --git a/blocks/bricks/recurrent/architectures.py b/blocks/bricks/recurrent/architectures.py index 1367a4df..1ebc259c 100644 --- a/blocks/bricks/recurrent/architectures.py +++ b/blocks/bricks/recurrent/architectures.py @@ -6,6 +6,7 @@ from ..simple import Initializable, Logistic, Tanh from ...roles import add_role, WEIGHT, INITIAL_STATE from ...utils import shared_floatx_nans, shared_floatx_zeros +from ...initialization import Constant from .base import BaseRecurrent, recurrent @@ -32,6 +33,7 @@ def __init__(self, dim, activation, **kwargs): self.dim = dim children = [activation] kwargs.setdefault('children', []).extend(children) + kwargs.setdefault('initial_state_init', Constant(0.)) super(SimpleRecurrent, self).__init__(**kwargs) @property diff --git a/blocks/bricks/simple.py b/blocks/bricks/simple.py index 1243d3ea..0fe6a0bc 100644 --- a/blocks/bricks/simple.py +++ b/blocks/bricks/simple.py @@ -95,10 +95,6 @@ def _allocate(self): add_role(b, BIAS) self.parameters.append(b) - def _initialize(self): - b, = self.parameters - self.biases_init.initialize(b, self.rng) - @application(inputs=['input_'], outputs=['output']) def apply(self, input_): """Apply the linear transformation. diff --git a/blocks/roles.py b/blocks/roles.py index d672189c..6fd3e4bb 100644 --- a/blocks/roles.py +++ b/blocks/roles.py @@ -71,6 +71,9 @@ def __repr__(self): return re.sub(r'(?!^)([A-Z]+)', r'_\1', self.__class__.__name__[:-4]).upper() + def __hash__(self): + return hash(str(self)) + class InputRole(VariableRole): pass diff --git a/tests/bricks/test_recurrent.py b/tests/bricks/test_recurrent.py index f7fc45cb..78c22ed1 100644 --- a/tests/bricks/test_recurrent.py +++ b/tests/bricks/test_recurrent.py @@ -146,8 +146,7 @@ def test_many_steps(self): class TestLSTM(unittest.TestCase): def setUp(self): - self.lstm = LSTM(dim=3, weights_init=Constant(2), - biases_init=Constant(0)) + self.lstm = LSTM(dim=3, weights_init=Constant(2)) self.lstm.initialize() def test_one_step(self): @@ -244,7 +243,6 @@ def setUp(self): self.stack2 = RecurrentStack(transitions, weights_init=Constant(2), - biases_init=Constant(0), skip_connections=True) self.stack2.initialize() @@ -502,7 +500,7 @@ def setUp(self): dim=3, activation=Tanh())) self.simple = SimpleRecurrent(dim=3, weights_init=Orthogonal(), activation=Tanh(), seed=1) - self.bidir.allocate() + self.bidir.initialize() self.simple.initialize() self.bidir.children[0].parameters[0].set_value( self.simple.parameters[0].get_value())