reorganize initialization
memimo committed Apr 12, 2016
1 parent 21ce4cf commit 22b5a09
Showing 7 changed files with 65 additions and 65 deletions.
76 changes: 46 additions & 30 deletions blocks/bricks/interfaces.py
@@ -5,6 +5,7 @@

from ..config import config
from .base import _Brick, Brick, lazy
from blocks.roles import WEIGHT, BIAS, FILTER, INITIAL_STATE


class ActivationDocumentation(_Brick):
@@ -132,6 +133,8 @@ class Initializable(RNGMixin, Brick):
"""

initializable_roles = ['WEIGHT', 'BIAS', 'FILTER', 'INITIAL_STATE']

@lazy()
def __init__(self, initialization_schemes=None, use_bias=True,
seed=None, **kwargs):
@@ -142,41 +145,62 @@ def __init__(self, initialization_schemes=None, use_bias=True,
if self.initialization_schemes is None:
self.initialization_schemes = {}

kwargs_ = {}
for key in kwargs:

initialization_to_role = {"weights_init": 'WEIGHT', 'biases_init': 'BIAS',
'initial_state_init': 'INITIAL_STATE'}
for key in list(kwargs.keys()):
if key[-5:] == "_init":
if key in self.initialization_schemes:
if initialization_to_role[key] in self.initialization_schemes.keys():
raise ValueError("All initializations are accepted either"
"through initialization_schemes or "
"correspodong attribute but not both")
"through initialization schemes or "
"corresponding attribute but not both")
else:
self.initialization_schemes[key] = kwargs[key]
else:
kwargs_[key] = kwargs[key]
self.initialization_schemes[initialization_to_role[key]] = kwargs[key]
kwargs.pop(key)

for key in self.initialization_schemes:
if key not in self.initializable_roles:
raise ValueError("{} is not member of ".format(str(key)) +
"initializable_roles")

super(Initializable, self).__init__(**kwargs)


def _validate_roles_schmes(self):
for role in self.parameter_roles:
if role not in self.initialization_schemes.keys():
found = False
for init_role in list(self.initialization_schemes.keys()):
if isinstance(eval(role), type(eval(init_role))):
self.initialization_schemes[role] = self.initialization_schemes[init_role]
found = True
if not found:
raise ValueError("There is no initialization_schemes"
" defined for {}".format(role))

super(Initializable, self).__init__(**kwargs_)
self._collect_roles()

def _push_initialization_config(self):
self._collect_roles()
self._validate_roles_schmes()
for child in self.children:
if (isinstance(child, Initializable) and
hasattr(child, 'initialization_schemes')):
for role in child.initialization_schemes:
if role not in self.parameter_roles:
raise ValueError("The parameter role: " +
"{} is not defined in".format(role) +
"in the class parameter_roles")

for child in self.children:
if isinstance(child, Initializable):
child.rng = self.rng
child.initialization_schemes = self.initialization_schemes
for role, scheme in self.initialization_schemes.items():
child.initialization_schemes[role] = scheme


def _collect_roles(self):
for child in self.children:
if isinstance(child, Initializable):
self.parameter_roles.update(child.parameter_roles)
for param in self.parameters:
for role in param.tag.roles:
if str(role) in self.initializable_roles:
self.parameter_roles.update(set([str(role)]))

def _initialize(self):
for param in self.parameters:
for role in param.tag.roles:
if str(role) in self.initializable_roles:
self.initialization_schemes[str(role)].initialize(param, self.rng)

class LinearLike(Initializable):
"""Initializable subclass with logic for :class:`Linear`-like classes.
@@ -203,14 +227,6 @@ def b(self):
else:
raise AttributeError('use_bias is False')

def _initialize(self):
# Use self.parameters[] references in case W and b are overridden
# to return non-shared-variables.
if self.use_bias:
self.initialization_schemes['biases_init'].initialize(
self.parameters[1], self.rng)
self.initialization_schemes['weights_init'].initialize(
self.parameters[0], self.rng)


class Random(Brick):
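Note on the new API (editor's sketch, inferred from the diff above, not part of the commit): initialization schemes are now keyed by parameter role and validated against initializable_roles, while the old weights_init/biases_init/initial_state_init keywords are mapped onto the corresponding role. A minimal usage example under that assumption:

    from blocks.bricks import Linear
    from blocks.initialization import IsotropicGaussian, Constant

    # Schemes are keyed by role name; the generic Initializable._initialize
    # matches each parameter's tagged role against this dictionary.
    linear = Linear(input_dim=10, output_dim=5,
                    initialization_schemes={'WEIGHT': IsotropicGaussian(0.01),
                                            'BIAS': Constant(0.)})
    linear.initialize()

Passing both a scheme and the matching keyword (e.g. weights_init) raises a ValueError, as does a scheme whose key is not listed in initializable_roles.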
3 changes: 0 additions & 3 deletions blocks/bricks/lookup.py
@@ -41,9 +41,6 @@ def _allocate(self):
name='W'))
add_role(self.parameters[-1], WEIGHT)

def _initialize(self):
self.weights_init.initialize(self.W, self.rng)

@application(inputs=['indices'], outputs=['output'])
def apply(self, indices):
"""Perform lookup.
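With the generic _initialize in Initializable, LookupTable no longer needs its own override: its W parameter carries the WEIGHT role and is picked up by the base-class loop. A hedged sketch of the expected usage:

    from blocks.bricks.lookup import LookupTable
    from blocks.initialization import IsotropicGaussian

    # weights_init is mapped onto the 'WEIGHT' scheme by Initializable.__init__.
    lookup = LookupTable(length=100, dim=16,
                         weights_init=IsotropicGaussian(0.01))
    lookup.initialize()  # W is initialized by the role-matched scheme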
34 changes: 13 additions & 21 deletions blocks/bricks/recurrent.py
@@ -11,7 +11,7 @@

from blocks.bricks import Initializable, Logistic, Tanh, Linear
from blocks.bricks.base import Application, application, Brick, lazy
from blocks.initialization import NdarrayInitialization
from blocks.initialization import NdarrayInitialization, Constant
from blocks.roles import add_role, WEIGHT, INITIAL_STATE
from blocks.utils import (pack, shared_floatx_nans, shared_floatx_zeros,
dict_union, dict_subset, is_shared_variable)
@@ -279,6 +279,8 @@ class SimpleRecurrent(BaseRecurrent, Initializable):
def __init__(self, dim, activation, **kwargs):
self.dim = dim
children = [activation] + kwargs.get('children', [])
if not 'initial_state_init' in kwargs:
kwargs['initial_state_init'] = Constant(0.)
super(SimpleRecurrent, self).__init__(children=children, **kwargs)

@property
@@ -297,13 +299,10 @@ def _allocate(self):
self.parameters.append(shared_floatx_nans((self.dim, self.dim),
name="W"))
add_role(self.parameters[0], WEIGHT)
self.parameters.append(shared_floatx_zeros((self.dim,),
self.parameters.append(shared_floatx_nans((self.dim,),
name="initial_state"))
add_role(self.parameters[1], INITIAL_STATE)

def _initialize(self):
self.weights_init.initialize(self.W, self.rng)

@recurrent(sequences=['inputs', 'mask'], states=['states'],
outputs=['states'], contexts=[])
def apply(self, inputs, states, mask=None):
@@ -386,6 +385,9 @@ def __init__(self, dim, activation=None, gate_activation=None, **kwargs):

children = ([self.activation, self.gate_activation] +
kwargs.get('children', []))

if not 'initial_state_init' in kwargs:
kwargs['initial_state_init'] = Constant(0.)
super(LSTM, self).__init__(children=children, **kwargs)

def get_dim(self, name):
@@ -408,9 +410,9 @@ def _allocate(self):
name='W_cell_to_out')
# The underscore is required to prevent collision with
# the `initial_state` application method
self.initial_state_ = shared_floatx_zeros((self.dim,),
self.initial_state_ = shared_floatx_nans((self.dim,),
name="initial_state")
self.initial_cells = shared_floatx_zeros((self.dim,),
self.initial_cells = shared_floatx_nans((self.dim,),
name="initial_cells")
add_role(self.W_state, WEIGHT)
add_role(self.W_cell_to_in, WEIGHT)
@@ -423,10 +425,6 @@ def _allocate(self):
self.W_state, self.W_cell_to_in, self.W_cell_to_forget,
self.W_cell_to_out, self.initial_state_, self.initial_cells]

def _initialize(self):
for weights in self.parameters[:4]:
self.weights_init.initialize(weights, self.rng)

@recurrent(sequences=['inputs', 'mask'], states=['states', 'cells'],
contexts=[], outputs=['states', 'cells'])
def apply(self, inputs, states, cells, mask=None):
@@ -533,6 +531,9 @@ def __init__(self, dim, activation=None, gate_activation=None,
self.gate_activation = gate_activation

children = [activation, gate_activation] + kwargs.get('children', [])

if not 'initial_state_init' in kwargs:
kwargs['initial_state_init'] = Constant(0.)
super(GatedRecurrent, self).__init__(children=children, **kwargs)

@property
@@ -557,22 +558,13 @@ def _allocate(self):
name='state_to_state'))
self.parameters.append(shared_floatx_nans((self.dim, 2 * self.dim),
name='state_to_gates'))
self.parameters.append(shared_floatx_zeros((self.dim,),
self.parameters.append(shared_floatx_nans((self.dim,),
name="initial_state"))
for i in range(2):
if self.parameters[i]:
add_role(self.parameters[i], WEIGHT)
add_role(self.parameters[2], INITIAL_STATE)

def _initialize(self):
self.weights_init.initialize(self.state_to_state, self.rng)
state_to_update = self.weights_init.generate(
self.rng, (self.dim, self.dim))
state_to_reset = self.weights_init.generate(
self.rng, (self.dim, self.dim))
self.state_to_gates.set_value(
numpy.hstack([state_to_update, state_to_reset]))

@recurrent(sequences=['mask', 'inputs', 'gate_inputs'],
states=['states'], outputs=['states'], contexts=[])
def apply(self, inputs, gate_inputs, states, mask=None):
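The recurrent bricks now allocate their initial states with NaNs instead of zeros and fall back to Constant(0.) for initial_state_init when none is given, so allocation alone no longer yields usable initial states — which is why the test further down switches from allocate() to initialize(). A sketch of the resulting usage (assumed, inferred from this diff):

    from blocks.bricks import Tanh
    from blocks.bricks.recurrent import SimpleRecurrent
    from blocks.initialization import Constant, Orthogonal

    # initial_state_init is optional; Constant(0.) is the default added here.
    rnn = SimpleRecurrent(dim=3, activation=Tanh(),
                          weights_init=Orthogonal(),
                          initial_state_init=Constant(0.))
    rnn.initialize()  # fills W ('WEIGHT') and initial_state ('INITIAL_STATE')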
5 changes: 0 additions & 5 deletions blocks/bricks/simple.py
@@ -43,7 +43,6 @@ def __init__(self, input_dim, output_dim, **kwargs):
super(Linear, self).__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.parameter_roles = set(['weights_init', 'biases_init'])

def _allocate(self):
W = shared_floatx_nans((self.input_dim, self.output_dim), name='W')
@@ -96,10 +95,6 @@ def _allocate(self):
add_role(b, BIAS)
self.parameters.append(b)

def _initialize(self):
b, = self.parameters
self.biases_init.initialize(b, self.rng)

@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
"""Apply the linear transformation.
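Linear's hard-coded parameter_roles and Bias._initialize are dropped; roles are now collected from the parameters' tags by _collect_roles and initialized by the shared loop. A small sketch under that assumption:

    from blocks.bricks import Bias
    from blocks.initialization import Constant

    bias = Bias(dim=5, biases_init=Constant(1.))  # mapped onto the 'BIAS' scheme
    bias.initialize()  # the role-based loop initializes the b parameter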
4 changes: 2 additions & 2 deletions tests/bricks/test_attention.py
@@ -73,8 +73,8 @@ def test_attention_recurrent():
state_names=wrapped.apply.states,
attended_dim=attended_dim, match_dim=attended_dim)
recurrent = AttentionRecurrent(wrapped, attention, seed=1234)
recurrent.weights_init = IsotropicGaussian(0.5)
recurrent.biases_init = Constant(0)
recurrent.initialization_schemes['WEIGHT'] = IsotropicGaussian(0.5)
recurrent.initialization_schemes['BIAS'] = Constant(0)
recurrent.initialize()

attended = tensor.tensor3("attended")
6 changes: 3 additions & 3 deletions tests/bricks/test_recurrent.py
@@ -502,7 +502,7 @@ def setUp(self):
dim=3, activation=Tanh()))
self.simple = SimpleRecurrent(dim=3, weights_init=Orthogonal(),
activation=Tanh(), seed=1)
self.bidir.allocate()
self.bidir.initialize()
self.simple.initialize()
self.bidir.children[0].parameters[0].set_value(
self.simple.parameters[0].get_value())
@@ -542,8 +542,8 @@ def setUp(self):
for _ in range(3)]
self.stack = RecurrentStack(self.layers)
for fork in self.stack.forks:
fork.weights_init = Identity(1)
fork.biases_init = Constant(0)
fork.initialization_schemes['WEIGHT'] = Identity(1)
fork.initialization_schemes['BIAS'] = Constant(0)
self.stack.initialize()

self.x_val = 0.1 * numpy.asarray(
2 changes: 1 addition & 1 deletion tests/bricks/test_sequence_generators.py
@@ -160,7 +160,7 @@ def test_integer_sequence_generator():
assert outputs_val.shape == (n_steps, batch_size)
assert outputs_val.dtype == 'int64'
assert costs_val.shape == (n_steps, batch_size)
assert_allclose(states_val.sum(), -17.854, rtol=1e-5)
assert_allclose(states_val.sum(), -17.889, rtol=1e-5)
assert_allclose(costs_val.sum(), 482.868, rtol=1e-5)
assert outputs_val.sum() == 629
