Skip to content

Commit

Permalink
managing initilizations via a a role scheme dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
memimo committed Apr 8, 2016
1 parent 9f8189e commit 21ce4cf
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 30 deletions.
72 changes: 42 additions & 30 deletions blocks/bricks/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,45 +127,55 @@ class Initializable(RNGMixin, Brick):
``True``.
use_bias : :obj:`bool`, optional
Whether to use a bias. Defaults to `True`. Required by
:meth:`~.Brick.initialize`. Only supported by bricks for which
:attr:`has_biases` is ``True``.
:meth:`~.Brick.initialize`.
rng : :class:`numpy.random.RandomState`
Attributes
----------
has_biases : bool
``False`` if the brick does not support biases, and only has
:attr:`weights_init`. For an example of this, see
:class:`.Bidirectional`. If this is ``False``, the brick does not
support the arguments ``biases_init`` or ``use_bias``.
"""
has_biases = True

@lazy()
def __init__(self, weights_init=None, biases_init=None, use_bias=None,
def __init__(self, initialization_schemes=None, use_bias=True,
seed=None, **kwargs):
super(Initializable, self).__init__(**kwargs)
self.weights_init = weights_init
if self.has_biases:
self.biases_init = biases_init
elif biases_init is not None or not use_bias:
raise ValueError("This brick does not support biases config")
if use_bias is not None:
self.use_bias = use_bias
self.use_bias = use_bias
self.seed = seed
self.initialization_schemes = initialization_schemes
self.parameter_roles = set([])
if self.initialization_schemes is None:
self.initialization_schemes = {}

kwargs_ = {}
for key in kwargs:
if key[-5:] == "_init":
if key in self.initialization_schemes:
raise ValueError("All initializations are accepted either"
"through initialization_schemes or "
"correspodong attribute but not both")
else:
self.initialization_schemes[key] = kwargs[key]
else:
kwargs_[key] = kwargs[key]

super(Initializable, self).__init__(**kwargs_)
self._collect_roles()

def _push_initialization_config(self):
for child in self.children:
if (isinstance(child, Initializable) and
hasattr(child, 'initialization_schemes')):
for role in child.initialization_schemes:
if role not in self.parameter_roles:
raise ValueError("The parameter role: " +
"{} is not defined in".format(role) +
"in the class parameter_roles")

for child in self.children:
if isinstance(child, Initializable):
child.rng = self.rng
if self.weights_init:
child.weights_init = self.weights_init
if hasattr(self, 'biases_init') and self.biases_init:
for child in self.children:
if (isinstance(child, Initializable) and
hasattr(child, 'biases_init')):
child.biases_init = self.biases_init
child.initialization_schemes = self.initialization_schemes

def _collect_roles(self):
for child in self.children:
if isinstance(child, Initializable):
self.parameter_roles.update(child.parameter_roles)


class LinearLike(Initializable):
Expand Down Expand Up @@ -196,9 +206,11 @@ def b(self):
def _initialize(self):
# Use self.parameters[] references in case W and b are overridden
# to return non-shared-variables.
if getattr(self, 'use_bias', True):
self.biases_init.initialize(self.parameters[1], self.rng)
self.weights_init.initialize(self.parameters[0], self.rng)
if self.use_bias:
self.initialization_schemes['biases_init'].initialize(
self.parameters[1], self.rng)
self.initialization_schemes['weights_init'].initialize(
self.parameters[0], self.rng)


class Random(Brick):
Expand Down
1 change: 1 addition & 0 deletions blocks/bricks/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, input_dim, output_dim, **kwargs):
super(Linear, self).__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.parameter_roles = set(['weights_init', 'biases_init'])

def _allocate(self):
W = shared_floatx_nans((self.input_dim, self.output_dim), name='W')
Expand Down

0 comments on commit 21ce4cf

Please sign in to comment.