-
Notifications
You must be signed in to change notification settings - Fork 349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Managing initilizations via a role scheme dictionary #1030
base: master
Are you sure you want to change the base?
Changes from 6 commits
e8b1def
f1ef6ab
2f09a35
84dff78
ecb758f
e4bbab9
f3fca53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
|
||
from ..config import config | ||
from .base import _Brick, Brick, lazy | ||
from blocks.roles import WEIGHT, BIAS, FILTER, INITIAL_STATE | ||
|
||
|
||
class ActivationDocumentation(_Brick): | ||
|
@@ -127,45 +128,109 @@ class Initializable(RNGMixin, Brick): | |
``True``. | ||
use_bias : :obj:`bool`, optional | ||
Whether to use a bias. Defaults to `True`. Required by | ||
:meth:`~.Brick.initialize`. Only supported by bricks for which | ||
:attr:`has_biases` is ``True``. | ||
:meth:`~.Brick.initialize`. | ||
rng : :class:`numpy.random.RandomState` | ||
|
||
Attributes | ||
---------- | ||
has_biases : bool | ||
``False`` if the brick does not support biases, and only has | ||
:attr:`weights_init`. For an example of this, see | ||
:class:`.Bidirectional`. If this is ``False``, the brick does not | ||
support the arguments ``biases_init`` or ``use_bias``. | ||
|
||
""" | ||
has_biases = True | ||
|
||
@lazy() | ||
def __init__(self, weights_init=None, biases_init=None, use_bias=None, | ||
seed=None, **kwargs): | ||
super(Initializable, self).__init__(**kwargs) | ||
self.weights_init = weights_init | ||
if self.has_biases: | ||
self.biases_init = biases_init | ||
elif biases_init is not None or not use_bias: | ||
raise ValueError("This brick does not support biases config") | ||
if use_bias is not None: | ||
self.use_bias = use_bias | ||
def __init__(self, initialization_schemes=None, | ||
use_bias=True, seed=None, **kwargs): | ||
self.use_bias = use_bias | ||
self.seed = seed | ||
self.initialization_schemes = initialization_schemes | ||
if self.initialization_schemes is None: | ||
self.initialization_schemes = {} | ||
|
||
initialization_to_role = {"weights_init": WEIGHT, 'biases_init': BIAS, | ||
'initial_state_init': INITIAL_STATE} | ||
for key in list(kwargs.keys()): | ||
if key[-5:] == "_init": | ||
if key not in initialization_to_role: | ||
raise ValueError("The initlization scheme: {}".format(key), | ||
"is not defined by default, pass it" | ||
"via initialization_schemes") | ||
if initialization_to_role[key] in \ | ||
self.initialization_schemes.keys(): | ||
raise ValueError("All initializations are accepted either" | ||
"through initialization schemes or " | ||
"corresponding attribute but not both") | ||
else: | ||
self.initialization_schemes[initialization_to_role[ | ||
key]] = kwargs[key] | ||
kwargs.pop(key) | ||
|
||
super(Initializable, self).__init__(**kwargs) | ||
|
||
def _validate_roles(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess here you propagate initialization schemes does the hierarchy of roles. But it seems to me that you do not choose the most specific role out of those for which an initialization scheme is available. E.g. if we have an init. scheme for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, I find it a bit weird that you change |
||
high_level_roles = [] | ||
for role in self.parameter_roles: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This validation checks that initialization schemes are given for all parameters. It is very restrictive to run it in Also, we get this kind of validation for free when we actually try to initialize parameters. Instead, I would rather check that there no initialization schemes are provided that are clearly useless. This is often a sign of a bug, when something was useless was given as an argument. |
||
if role not in self.initialization_schemes.keys(): | ||
for key in list(self.initialization_schemes.keys()): | ||
if isinstance(role, type(key)): | ||
self.initialization_schemes[role] = \ | ||
self.initialization_schemes[key] | ||
high_level_roles.append(key) | ||
|
||
for key in high_level_roles: | ||
if key not in self.parameter_roles: | ||
self.initialization_schemes.pop(key) | ||
|
||
for key in self.initialization_schemes: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be wrong, because an exception can be raised in perfectly legitimate cases when e.g. I think the only when we can (and should) raise an exception is when an init. scheme is provided for a roles that is not a parent of any of |
||
if key not in self.parameter_roles: | ||
raise ValueError("{} is not member of ".format(key) + | ||
"parameter_roles") | ||
|
||
def _push_initialization_config(self): | ||
self._collect_roles() | ||
self._validate_roles() | ||
for child in self.children: | ||
if isinstance(child, Initializable): | ||
if (isinstance(child, Initializable) and | ||
hasattr(child, 'initialization_schemes')): | ||
child.rng = self.rng | ||
if self.weights_init: | ||
child.weights_init = self.weights_init | ||
if hasattr(self, 'biases_init') and self.biases_init: | ||
for child in self.children: | ||
if (isinstance(child, Initializable) and | ||
hasattr(child, 'biases_init')): | ||
child.biases_init = self.biases_init | ||
for role, scheme in self.initialization_schemes.items(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not propagate all the roles to the children, I think. It would not be nice, if even the bottom-most brick has initialization schemes for all the roles. We should only push those initialization schemes that can be useful for the child brick. Given that |
||
if role in child.parameter_roles: | ||
child.initialization_schemes[role] = scheme | ||
|
||
def _collect_roles(self): | ||
def get_param_roles(obj): | ||
all_roles = [] | ||
for param in obj.parameters: | ||
roles = param.tag.roles | ||
# TODO do something smarter | ||
if len(roles) > 0: | ||
all_roles.append(roles[0]) | ||
return all_roles | ||
|
||
self.parameter_roles = set(get_param_roles(self)) | ||
for child in self.children: | ||
if isinstance(child, Initializable): | ||
child._collect_roles() | ||
self.parameter_roles.update(child.parameter_roles) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this not recurse? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it's not necessary if called from the constructor (though I think that might be suboptimal) |
||
|
||
def _initialize(self): | ||
for param in self.parameters: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a dangerous code. First, if a parameter has more than one role, it is not clear at all how it should be initialized. I would rather raise an exception upon encountering such a situation. Also, we should choose the most specific initialization scheme (e.g. the one for |
||
for role in param.tag.roles: | ||
if role in self.parameter_roles: | ||
self.initialization_schemes[role].initialize(param, | ||
self.rng) | ||
|
||
def __getattr__(self, name): | ||
if name == "weights_init": | ||
if WEIGHT in self.initialization_schemes: | ||
return self.initialization_schemes[WEIGHT] | ||
elif name == "biases_init": | ||
if BIAS in self.initialization_schemes: | ||
return self.initialization_schemes[BIAS] | ||
super(Initializable, self).__getattr__(name) | ||
|
||
def __setattr__(self, name, value): | ||
if name == 'weights_init': | ||
self.initialization_schemes[WEIGHT] = value | ||
elif name == 'biases_init': | ||
self.initialization_schemes[BIAS] = value | ||
else: | ||
super(Initializable, self).__setattr__(name, value) | ||
|
||
|
||
class LinearLike(Initializable): | ||
|
@@ -182,6 +247,7 @@ class LinearLike(Initializable): | |
first and biases (if ``use_bias`` is True) coming second. | ||
|
||
""" | ||
|
||
@property | ||
def W(self): | ||
return self.parameters[0] | ||
|
@@ -193,13 +259,6 @@ def b(self): | |
else: | ||
raise AttributeError('use_bias is False') | ||
|
||
def _initialize(self): | ||
# Use self.parameters[] references in case W and b are overridden | ||
# to return non-shared-variables. | ||
if getattr(self, 'use_bias', True): | ||
self.biases_init.initialize(self.parameters[1], self.rng) | ||
self.weights_init.initialize(self.parameters[0], self.rng) | ||
|
||
|
||
class Random(Brick): | ||
"""A mixin class for Bricks which need Theano RNGs. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the record, it looks like you have chosen to only handle a handful of
<role>_init
keyword arguments, not all of them. This is probably good, because this clearly makesinitialization_schemes
the primary way of defining the initialization schemes.But the code clearly lacks a check
if key not in initialization_to_role
. Since now we have a clearly defined white list, I would probably raise a warning, that calling the argumentssmth_init
is recommended, but still pass this argument to thesuper().__init__
call.