Format and increment
thomaspinder committed Mar 3, 2021
1 parent 24fbde8 commit acf7d2f
Showing 21 changed files with 177 additions and 151 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -28,7 +28,7 @@
 author = "Thomas Pinder"

 # The full version, including alpha/beta/rc tags
-release = "0.2.0"
+release = "0.3"


 # -- General configuration ---------------------------------------------------
2 changes: 1 addition & 1 deletion gpjax/__init__.py
@@ -5,4 +5,4 @@
 from .predict import mean, variance
 from .sampling import random_variable, sample

-__version__ = "0.3.0"
+__version__ = "0.3.1"
2 changes: 1 addition & 1 deletion gpjax/kernels.py
@@ -36,7 +36,7 @@ def __call__(self, x: Array, y: Array) -> Array:

 @dispatch(RBF)
 def initialise(kernel: RBF):
-    return {"lengthscale": jnp.array([1.0]*kernel.ndims), "variance": jnp.array([1.0])}
+    return {"lengthscale": jnp.array([1.0] * kernel.ndims), "variance": jnp.array([1.0])}


 def squared_distance(x: Array, y: Array):
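For reference, a minimal sketch of how the dispatched initialise above might be used. It assumes RBF() defaults to a single input dimension (ndims = 1) and that initialise is importable from gpjax.kernels, where this diff defines it; neither is shown in the hunk.

import jax.numpy as jnp
from gpjax.kernels import RBF, initialise

kernel = RBF()
params = initialise(kernel)
# With ndims = 1, this should yield unit defaults:
# {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])}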
2 changes: 1 addition & 1 deletion gpjax/mean_functions.py
@@ -23,7 +23,7 @@ class Zero(MeanFunction):

     def __call__(self, x: Array) -> Array:
         out_shape = (x.shape[0], self.output_dim)
-        return jnp.zeros(shape = out_shape)
+        return jnp.zeros(shape=out_shape)


 @dispatch(Zero)
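A minimal sketch of the Zero mean function touched above, assuming Zero() can be constructed with no arguments (as the tests in this commit do) and that output_dim defaults to 1:

import jax.numpy as jnp
from gpjax.mean_functions import Zero

x = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
mean_fn = Zero()
mu = mean_fn(x)
# mu has shape (x.shape[0], mean_fn.output_dim) and is identically zero.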
5 changes: 4 additions & 1 deletion gpjax/objectives/mlls.py
@@ -39,6 +39,7 @@ def mll(params: dict, x: Array, y: Array, priors: dict = None):
         log_prior_density = evaluate_prior(params, priors)
         constant = jnp.array(-1.0) if negative else jnp.array(1.0)
         return constant * (random_variable.log_prob(y.squeeze()).mean() + log_prior_density)
+
     return mll


@@ -49,7 +50,9 @@ def marginal_ll(
     negative: bool = False,
     jitter: float = 1e-6,
 ) -> Callable:
-    def mll(params: dict, x: Array, y: Array, priors: dict = {'latent': tfd.Normal(loc=0., scale=1.)}):
+    def mll(
+        params: dict, x: Array, y: Array, priors: dict = {"latent": tfd.Normal(loc=0.0, scale=1.0)}
+    ):
         params = untransform(params, transformation)
         n = x.shape[0]
         link = link_function(gp.likelihood)
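The reformatted mll closure above is what marginal_ll returns as the objective to optimise. A minimal sketch of its use, mirroring the calls in tests/objectives/test_mlls.py from this same commit:

import jax.numpy as jnp
from gpjax import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import Gaussian
from gpjax.objectives import marginal_ll
from gpjax.parameters import SoftplusTransformation, initialise, transform

posterior = Prior(kernel=RBF()) * Gaussian()
neg_mll = marginal_ll(posterior, negative=True)  # negative log marginal likelihood

x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(x)
params = transform(params=initialise(posterior), transformation=SoftplusTransformation)
loss = neg_mll(params, x, y)  # a scalar, suitable for gradient-based optimisation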
23 changes: 13 additions & 10 deletions gpjax/parameters/priors.py
@@ -1,11 +1,12 @@
+import warnings
+
 import jax.numpy as jnp
 from jax.interpreters.ad import JVPTracer
 from jax.interpreters.partial_eval import DynamicJaxprTracer
 from multipledispatch import dispatch
 from tensorflow_probability.substrates.jax import distributions as tfd
-from ..gps import NonConjugatePosterior
-import warnings

+from ..gps import NonConjugatePosterior
 from ..types import Array, NoneType


@@ -16,24 +17,26 @@ def log_density(param: jnp.DeviceArray, density: tfd.Distribution) -> Array:

 @dispatch(dict, NoneType)
 def evaluate_prior(params: dict, priors: dict) -> Array:
-    return jnp.array(0.)
+    return jnp.array(0.0)


 @dispatch(dict, dict)
 def evaluate_prior(params: dict, priors: dict) -> Array:
     lpd = jnp.array(0)
     for param, val in priors.items():
-        lpd+=jnp.sum(log_density(params[param], priors[param]))
+        lpd += jnp.sum(log_density(params[param], priors[param]))
     return lpd


 @dispatch(NonConjugatePosterior, dict)
 def prior_checks(gp: NonConjugatePosterior, priors: dict) -> dict:
-    if 'latent' in priors.keys():
-        latent_prior = priors['latent']
-        if latent_prior.name != 'Normal':
-            warnings.warn(f'A {latent_prior.name} distribution prior has been placed on the latent function. It is strongly advised that a unit-Gaussian prior is used.')
+    if "latent" in priors.keys():
+        latent_prior = priors["latent"]
+        if latent_prior.name != "Normal":
+            warnings.warn(
+                f"A {latent_prior.name} distribution prior has been placed on the latent function. It is strongly advised that a unit-Gaussian prior is used."
+            )
         return priors
     else:
-        priors['latent'] = tfd.Normal(loc=0., scale=1.)
-        return priors
\ No newline at end of file
+        priors["latent"] = tfd.Normal(loc=0.0, scale=1.0)
+        return priors
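A minimal sketch of the two dispatch paths above, based on the tests in this commit: evaluate_prior sums the log-densities of whichever priors are supplied, and prior_checks backfills a unit-normal prior on the latent function for non-conjugate posteriors.

import jax.numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd
from gpjax import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import Bernoulli
from gpjax.parameters.priors import evaluate_prior, prior_checks

params = {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])}
priors = {"lengthscale": tfd.Gamma(1.0, 1.0)}

evaluate_prior(params, None)    # NoneType dispatch: returns jnp.array(0.0)
evaluate_prior(params, priors)  # dict dispatch: sum of prior log-densities

posterior = Prior(kernel=RBF()) * Bernoulli()  # a NonConjugatePosterior
checked = prior_checks(posterior, priors)
# checked now also contains {"latent": tfd.Normal(loc=0.0, scale=1.0)}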
2 changes: 1 addition & 1 deletion gpjax/types.py
@@ -6,4 +6,4 @@

 # Array = Union[jnp.ndarray, ShardedDeviceArray, jnp.DeviceArray] # Cannot currently dispatch on a Union type
 # Data = Tuple[Array, Array]
-NoneType = type(None)
\ No newline at end of file
+NoneType = type(None)
18 changes: 11 additions & 7 deletions gpjax/utils.py
@@ -1,5 +1,6 @@
-import jax.numpy as jnp
 from typing import Tuple
+
+import jax.numpy as jnp
 from multipledispatch import dispatch

 from .types import Array
@@ -50,11 +51,13 @@ def standardise(x: jnp.DeviceArray) -> Tuple[jnp.DeviceArray, jnp.DeviceArray, jnp.DeviceArray]:
     """
     xmean = jnp.mean(x, axis=0)
     xstd = jnp.std(x, axis=0)
-    return (x-xmean)/xstd, xmean, xstd
+    return (x - xmean) / xstd, xmean, xstd


 @dispatch(jnp.DeviceArray, jnp.DeviceArray, jnp.DeviceArray)
-def standardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray) -> jnp.DeviceArray:
+def standardise(
+    x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray
+) -> jnp.DeviceArray:
     """
     Standardise a given matrix with respect to a given mean and standard deviation. This is primarily designed for
     standardising a test set of data with respect to the training data.
@@ -64,11 +67,12 @@ def standardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray) -> jnp.DeviceArray:
     :param xstd: A precomputed standard deviation vector
     :return: A matrix of standardised values
     """
-    return (x-xmean)/xstd
-
+    return (x - xmean) / xstd


-def unstandardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray) -> jnp.DeviceArray:
+def unstandardise(
+    x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray
+) -> jnp.DeviceArray:
     """
     Unstandardise a given matrix with respect to a previously computed mean and standard deviation. This is designed
     for remapping a matrix back onto its original scale.
@@ -78,4 +82,4 @@ def unstandardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray) -> jnp.DeviceArray:
     :param xstd: A standard deviation vector.
     :return: A matrix of unstandardised values.
     """
-    return (x*xstd) + xmean
+    return (x * xstd) + xmean
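A round-trip sketch of the three functions above, assuming they are importable from gpjax.utils (the module this hunk edits):

import jax.numpy as jnp
from gpjax.utils import standardise, unstandardise

x = jnp.array([[1.0], [2.0], [3.0]])
x_std, xmean, xstd = standardise(x)       # single-argument dispatch
x_test = standardise(x, xmean, xstd)      # three-argument dispatch, e.g. for test data
x_back = unstandardise(x_std, xmean, xstd)
# x_back recovers x up to floating-point error.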
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@ def parse_requirements_file(filename):

 setup(
     name="GPJax",
-    version="0.3.0",
+    version="0.3.1",
     author="Thomas Pinder",
     author_email="[email protected]",
     packages=find_packages(".", exclude=["tests"]),
28 changes: 15 additions & 13 deletions tests/objectives/test_mlls.py
@@ -1,36 +1,38 @@
-from gpjax.objectives import marginal_ll
-from gpjax import Prior
-from gpjax.likelihoods import Bernoulli, Gaussian
-from gpjax.kernels import RBF
-from gpjax.parameters import transform, SoftplusTransformation, initialise
+from typing import Callable
+
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
-from typing import Callable
 from tensorflow_probability.substrates.jax import distributions as tfd
+
+from gpjax import Prior
+from gpjax.kernels import RBF
+from gpjax.likelihoods import Bernoulli, Gaussian
+from gpjax.objectives import marginal_ll
+from gpjax.parameters import SoftplusTransformation, initialise, transform


 def test_conjugate():
-    posterior = Prior(kernel = RBF()) * Gaussian()
+    posterior = Prior(kernel=RBF()) * Gaussian()
     mll = marginal_ll(posterior)
     assert isinstance(mll, Callable)
     neg_mll = marginal_ll(posterior, negative=True)
-    x = jnp.linspace(-1., 1., 20).reshape(-1, 1)
+    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
     y = jnp.sin(x)
     params = transform(params=initialise(posterior), transformation=SoftplusTransformation)
-    assert neg_mll(params, x, y) == jnp.array(-1.)*mll(params, x, y)
+    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)


 def test_non_conjugate():
-    posterior = Prior(kernel = RBF()) * Bernoulli()
+    posterior = Prior(kernel=RBF()) * Bernoulli()
     mll = marginal_ll(posterior)
     assert isinstance(mll, Callable)
     neg_mll = marginal_ll(posterior, negative=True)
     n = 20
-    x = jnp.linspace(-1., 1., n).reshape(-1, 1)
+    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
     y = jnp.sin(x)
     params = transform(params=initialise(posterior, n), transformation=SoftplusTransformation)
-    assert neg_mll(params, x, y) == jnp.array(-1.)*mll(params, x, y)
+    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)


 def test_prior_mll():
@@ -54,4 +56,4 @@ def test_prior_mll():
     mll_eval_priors = mll(params, x, y, priors)

     assert pytest.approx(mll_eval) == jnp.array(-115.72332969)
-    assert pytest.approx(mll_eval_priors) == jnp.array(-118.97202259)
\ No newline at end of file
+    assert pytest.approx(mll_eval_priors) == jnp.array(-118.97202259)
27 changes: 14 additions & 13 deletions tests/parameters/test_base.py
@@ -1,33 +1,34 @@
-from gpjax.parameters.base import initialise, _initialise_hyperparams, complete
+import jax.numpy as jnp
+import pytest
+
 from gpjax import Prior
 from gpjax.kernels import RBF
+from gpjax.likelihoods import Bernoulli, Gaussian
 from gpjax.mean_functions import Zero
-from gpjax.likelihoods import Gaussian, Bernoulli
-import jax.numpy as jnp
-import pytest
+from gpjax.parameters.base import _initialise_hyperparams, complete, initialise


 def test_complete():
-    posterior = Prior(kernel = RBF()) * Gaussian()
-    partial_params = {'lengthscale': jnp.array(1.0)}
+    posterior = Prior(kernel=RBF()) * Gaussian()
+    partial_params = {"lengthscale": jnp.array(1.0)}
     full_params = complete(partial_params, posterior)
-    assert list(full_params.keys()) == ['lengthscale', 'variance', 'obs_noise']
+    assert list(full_params.keys()) == ["lengthscale", "variance", "obs_noise"]


 def test_initialise():
-    posterior = Prior(kernel = RBF()) * Gaussian()
+    posterior = Prior(kernel=RBF()) * Gaussian()
     params = initialise(posterior)
-    assert list(params.keys()) == ['lengthscale', 'variance', 'obs_noise']
+    assert list(params.keys()) == ["lengthscale", "variance", "obs_noise"]


-@pytest.mark.parametrize('n', [1, 10])
+@pytest.mark.parametrize("n", [1, 10])
 def test_non_conjugate_initialise(n):
     posterior = Prior(kernel=RBF()) * Bernoulli()
     params = initialise(posterior, n)
-    assert list(params.keys()) == ['lengthscale', 'variance', 'latent']
-    assert params['latent'].shape == (n, 1)
+    assert list(params.keys()) == ["lengthscale", "variance", "latent"]
+    assert params["latent"].shape == (n, 1)


 def test_hyperparametr_initialise():
     params = _initialise_hyperparams(RBF(), Zero())
-    assert list(params.keys()) == ['lengthscale', 'variance']
+    assert list(params.keys()) == ["lengthscale", "variance"]
45 changes: 23 additions & 22 deletions tests/parameters/test_priors.py
@@ -1,17 +1,18 @@
-from gpjax.parameters import log_density
-from gpjax.parameters.priors import evaluate_prior, prior_checks
+import jax.numpy as jnp
+import pytest
+from tensorflow_probability.substrates.jax import distributions as tfd
+
 from gpjax.gps import Prior
 from gpjax.kernels import RBF
 from gpjax.likelihoods import Bernoulli
-from tensorflow_probability.substrates.jax import distributions as tfd
-import pytest
-import jax.numpy as jnp
+from gpjax.parameters import log_density
+from gpjax.parameters.priors import evaluate_prior, prior_checks


-@pytest.mark.parametrize('x', [-1., 0., 1.])
+@pytest.mark.parametrize("x", [-1.0, 0.0, 1.0])
 def test_lpd(x):
     val = jnp.array(x)
-    dist = tfd.Normal(loc=0., scale=1.)
+    dist = tfd.Normal(loc=0.0, scale=1.0)
     lpd = log_density(val, dist)
     assert lpd is not None

@@ -22,9 +23,9 @@ def test_prior_evaluation():
     value.
     """
     params = {
-        "lengthscale": jnp.array([1.]),
-        "variance": jnp.array([1.]),
-        "obs_noise": jnp.array([1.]),
+        "lengthscale": jnp.array([1.0]),
+        "variance": jnp.array([1.0]),
+        "obs_noise": jnp.array([1.0]),
     }
     priors = {
         "lengthscale": tfd.Gamma(1.0, 1.0),
@@ -40,22 +41,22 @@ def test_none_prior():
     Test that multiple dispatch is working in the case of no priors.
     """
     params = {
-        "lengthscale": jnp.array([1.]),
-        "variance": jnp.array([1.]),
-        "obs_noise": jnp.array([1.]),
+        "lengthscale": jnp.array([1.0]),
+        "variance": jnp.array([1.0]),
+        "obs_noise": jnp.array([1.0]),
     }
     lpd = evaluate_prior(params, None)
-    assert lpd == 0.
+    assert lpd == 0.0


 def test_incomplete_priors():
     """
     Test the case where a user specifies priors for some, but not all, parameters.
     """
     params = {
-        "lengthscale": jnp.array([1.]),
-        "variance": jnp.array([1.]),
-        "obs_noise": jnp.array([1.]),
+        "lengthscale": jnp.array([1.0]),
+        "variance": jnp.array([1.0]),
+        "obs_noise": jnp.array([1.0]),
     }
     priors = {
         "lengthscale": tfd.Gamma(1.0, 1.0),
@@ -66,20 +67,20 @@ def test_incomplete_priors():


 def test_checks():
-    incomplete_priors = {'lengthscale': jnp.array([1.])}
+    incomplete_priors = {"lengthscale": jnp.array([1.0])}
     posterior = Prior(kernel=RBF()) * Bernoulli()
     priors = prior_checks(posterior, incomplete_priors)
-    assert 'latent' in priors.keys()
-    assert 'variance' not in priors.keys()
+    assert "latent" in priors.keys()
+    assert "variance" not in priors.keys()


 def test_check_needless():
     complete_prior = {
         "lengthscale": tfd.Gamma(1.0, 1.0),
         "variance": tfd.Gamma(2.0, 2.0),
         "obs_noise": tfd.Gamma(3.0, 3.0),
-        "latent": tfd.Normal(loc=0., scale=1.)
+        "latent": tfd.Normal(loc=0.0, scale=1.0),
     }
     posterior = Prior(kernel=RBF()) * Bernoulli()
     priors = prior_checks(posterior, complete_prior)
-    assert priors == complete_prior
\ No newline at end of file
+    assert priors == complete_prior