Skip to content

Commit

Permalink
Implement nnabla.experimental.distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jan 9, 2020
1 parent fea0942 commit d3eed32
Show file tree
Hide file tree
Showing 10 changed files with 699 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def extopts(library_name, library_dir):
'nnabla.experimental.graph_converters',
'nnabla.experimental.parametric_function_class',
'nnabla.experimental.trainers',
'nnabla.experimental.distributions',
'nnabla.models',
'nnabla.models.imagenet',
'nnabla.models.object_detection',
Expand Down
3 changes: 3 additions & 0 deletions python/src/nnabla/experimental/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from nnabla.experimental.distributions.uniform import Uniform
from nnabla.experimental.distributions.normal import Normal
from nnabla.experimental.distributions.multivariate_normal import MultivariateNormal
94 changes: 94 additions & 0 deletions python/src/nnabla/experimental/distributions/distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import nnabla.functions as F


class Distribution(object):
"""Distribution base class for distribution classes.
"""

def entropy(self):
"""Get entropy of distribution.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
raise NotImplementedError

def mean(self):
"""Get mean of distribution.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
raise NotImplementedError

def stddev(self):
"""Get standard deviation of distribution.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
raise NotImplementedError

def variance(self):
"""Get variance of distribution.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
raise NotImplementedError

def prob(self, x):
"""Get probability of sampled `x` from distribution.
Args:
x (~nnabla.Variable): N-D array.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
raise NotImplementedError

def sample(self, shape):
"""Sample points from distribution.
Args:
shape (:obj:`tuple`): Shape of sampled points.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
raise NotImplementedError

def sample_n(self, n):
"""Sample points from distribution :math:`n` times.
Args:
n (int): The number of sampling points.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
samples = [self.sample() for _ in range(n)]
return F.stack(*samples, axis=1)
138 changes: 138 additions & 0 deletions python/src/nnabla/experimental/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import nnabla.functions as F

from .distribution import Distribution


class MultivariateNormal(Distribution):
"""Multivariate normal distribution.
Multivariate normal distribution defined as follows:
.. math::
p(x | \mu, \Sigma) = \frac{1}{\sqrt{(2 \pi)^k \det(\Sigma)}}
\exp(-\frac{1}{2}(x - \mu)^T \Sigma^(-1) (x - \mu))
where :math:`k` is a rank of `\Sigma`.
Args:
loc (~nnabla.Variable or numpy.ndarray): N-D array of :math:`\mu` in
definition.
scale (~nnabla.Variable or numpy.ndarray): N-D array of diagonal
entries of :math:`L` such that covariance matrix
:math:`\Sigma = L L^T`.
"""

def __init__(self, loc, scale):
assert loc.shape == scale.shape,\
'For now, loc and scale must have same shape.'
if isinstance(loc, np.ndarray):
loc = nn.Variable.from_numpy_array(loc)
loc.persistent = True
if isinstance(scale, np.ndarray):
scale = nn.Variable.from_numpy_array(scale)
scale.persistent = True
self.loc = loc
self.scale = scale

def mean(self):
"""Get mean of multivariate normal distribution.
Returns:
:class:`~nnabla.Variable`: N-D array identical to :math:`\mu`.
"""
# to avoid no parent error
return F.identity(self.loc)

def variance(self):
"""Get covariance matrix of multivariate normal distribution.
.. math::
\Sigma = L L^T
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
diag = self._diag_scale()
return F.batch_matmul(diag, diag, False, True)

def prob(self, x):
"""Get probability of `x` in multivariate normal distribution.
.. math::
p(x | \mu, \Sigma) = \frac{1}{\sqrt{(2 \pi)^k \det(\Sigma)}}
\exp(-\frac{1}{2}(x - \mu)^T \Sigma^(-1) (x - \mu))
Args:
x (~nn.Variable): N-D array.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
k = self.loc.shape[1]
z = 1.0 / ((2 * np.pi) ** k * F.batch_det(self._diag_scale())) ** 0.5

diff = F.reshape(x - self.mean(), self.loc.shape + (1,), False)
inv = F.batch_inv(self._diag_scale())
y = F.batch_matmul(diff, inv, True, False)
norm = F.reshape(F.batch_matmul(y, diff, False, False), (-1,), False)
return z * F.exp(-0.5 * norm)

def entropy(self):
"""Get entropy of multivariate normal distribution.
.. math::
S = \frac{1}{2} \ln \det(2 \pi e \Sigma)
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
det = F.batch_det(2.0 * np.pi * np.e * self._diag_scale())
return 0.5 * F.log(det)

def _diag_scale(self):
return F.matrix_diag(self.scale)

def sample(self, shape=None):
"""Sample points from multivariate normal distribution.
.. math::
x \sim N(\mu, \Sigma)
Args:
shape (:obj:`tuple`): Shape of sampled points. If this is omitted,
the returned shape is identical to :math:`\mu`.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
if shape is None:
shape = self.loc.shape
eps = F.randn(mu=0.0, sigma=1.0, shape=shape)
return self.mean() + self.scale * eps
131 changes: 131 additions & 0 deletions python/src/nnabla/experimental/distributions/normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import nnabla.functions as F

from .distribution import Distribution


class Normal(Distribution):
"""Normal distribution.
Normal distribution defined as follows:
.. math::
p(x | \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}}
\exp(-\frac{(x - \mu)^2}{2\sigma^2})
Args:
loc (~nnabla.Variable or numpy.ndarray): N-D array of :math:`\mu` in
definition.
scale (~nnabla.Variable or numpy.ndarray): N-D array of :math:`\sigma`
in definition.
"""

def __init__(self, loc, scale):
assert loc.shape == scale.shape,\
'For now, loc and scale must have same shape.'
if isinstance(loc, np.ndarray):
loc = nn.Variable.from_numpy_array(loc)
loc.persistent = True
if isinstance(scale, np.ndarray):
scale = nn.Variable.from_numpy_array(scale)
scale.persistent = True
self.loc = loc
self.scale = scale

def mean(self):
"""Get mean of normal distribution.
Returns:
:class:`~nnabla.Variable`: N-D array identical to :math:`\mu`.
"""
# to avoid no parent error
return F.identity(self.loc)

def stddev(self):
"""Get standard deviation of normal distribution.
Returns:
:class:`~nnabla.Variable`: N-D array identical to :math:`\sigma`.
"""
# to avoid no parent error
return F.identity(self.scale)

def variance(self):
"""Get variance of normal distribution.
Returns:
:class:`~nnabla.Variable`: N-D array defined as :math:`\sigma^2`.
"""
return self.stddev() ** 2

def prob(self, x):
"""Get probability of :math:`x` in normal distribution.
.. math::
p(x | \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}}
\exp(-\frac{(x - \mu)^2}{2\sigma^2})
Args:
x (~nnabla.Variable): N-D array.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
z = 1.0 / (2 * np.pi * self.variance()) ** 0.5
return z * F.exp(-0.5 * ((x - self.mean()) ** 2) / self.variance())

def entropy(self):
"""Get entropy of normal distribution.
.. math::
S = \frac{1}{2}\log(2 \pi e \sigma^2)
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
return F.log(self.stddev()) + 0.5 * np.log(2.0 * np.pi * np.e)

def sample(self, shape=None):
"""Sample points from normal distribution.
.. math::
x \sim N(\mu, \sigma^2)
Args:
shape (:obj:`tuple`): Shape of sampled points. If this is omitted,
the returned shape is identical to
:math:`\mu` and :math:`\sigma`.
Returns:
:class:`~nnabla.Variable`: N-D array.
"""
if shape is None:
shape = self.loc.shape
eps = F.randn(mu=0.0, sigma=1.0, shape=shape)
return self.mean() + self.stddev() * eps
Loading

0 comments on commit d3eed32

Please sign in to comment.