diff --git a/src/GPLikelihoods.jl b/src/GPLikelihoods.jl index c385e71..ba60d0d 100644 --- a/src/GPLikelihoods.jl +++ b/src/GPLikelihoods.jl @@ -29,6 +29,7 @@ export Link, ProbitLink, NormalCDFLink, SoftMaxLink +export nlatent export expected_loglikelihood # Links @@ -37,6 +38,16 @@ include("links.jl") # Likelihoods abstract type AbstractLikelihood end +""" + nlatent(::AbstractLikelihood)::Int + +Returns the number of latent Gaussian processes needed to build the likelihood. +In other terms the input dimensionality passed to the likelihood from the GP perspective. +It is typically 1, but for some likelihoods like [`CategoricalLikelihood`](@ref) or +[`HeteroscedasticGaussianLikelihood`](@ref) multiple latent GPs are necessary. +""" +nlatent(::AbstractLikelihood) = 1 # Default number of latent GPs required is 1 + include("expectations.jl") include("likelihoods/bernoulli.jl") include("likelihoods/categorical.jl") diff --git a/src/TestInterface.jl b/src/TestInterface.jl index 4556ec3..d2b4acc 100644 --- a/src/TestInterface.jl +++ b/src/TestInterface.jl @@ -1,10 +1,11 @@ module TestInterface using Functors +using ..GPLikelihoods using Random using Test -function test_interface(rng::AbstractRNG, lik, out_dist, D_in=1; functor_args=()) +function test_interface(rng::AbstractRNG, lik, out_dist, D_in=nlatent(lik); functor_args=()) N = 10 T = Float64 # TODO test Float32 as well f, fs = if D_in == 1 @@ -53,7 +54,7 @@ samples is correct and if the functor works as intended. - `functor_args=()`: a collection of symbols of arguments to match functor parameters with. ... """ -function test_interface(lik, out_dist, D_in=1; kwargs...) +function test_interface(lik, out_dist, D_in=nlatent(lik); kwargs...) return test_interface(Random.GLOBAL_RNG, lik, out_dist, D_in; kwargs...) end diff --git a/src/likelihoods/categorical.jl b/src/likelihoods/categorical.jl index 6caedbf..b998812 100644 --- a/src/likelihoods/categorical.jl +++ b/src/likelihoods/categorical.jl @@ -1,7 +1,7 @@ """ - CategoricalLikelihood(l=BijectiveSimplexLink(softmax)) + CategoricalLikelihood(n::Int, l=BijectiveSimplexLink(softmax)) -Categorical likelihood is to be used if we assume that the +Categorical likelihood with `n` categories is to be used if we assume that the uncertainty associated with the data follows a [Categorical distribution](https://en.wikipedia.org/wiki/Categorical_distribution). Assuming a distribution with `n` categories: @@ -27,10 +27,16 @@ For more details, see the end of the section of this [Wikipedia link](https://en where it corresponds to Variant 1 and 2. """ struct CategoricalLikelihood{Tl<:AbstractLink} <: AbstractLikelihood + n::Int # Number of categories invlink::Tl end -CategoricalLikelihood(l=BijectiveSimplexLink(softmax)) = CategoricalLikelihood(link(l)) +function CategoricalLikelihood(n, l=BijectiveSimplexLink(softmax)) + return CategoricalLikelihood(n, link(l)) +end + +nlatent(l::CategoricalLikelihood) = l.n +nlatent(l::CategoricalLikelihood{<:BijectiveSimplexLink}) = l.n - 1 function (l::CategoricalLikelihood)(f::AbstractVector{<:Real}) return Categorical(l.invlink(f)) diff --git a/src/likelihoods/gaussian.jl b/src/likelihoods/gaussian.jl index ca5101b..70f9e06 100644 --- a/src/likelihoods/gaussian.jl +++ b/src/likelihoods/gaussian.jl @@ -55,6 +55,8 @@ end HeteroscedasticGaussianLikelihood(l=exp) = HeteroscedasticGaussianLikelihood(link(l)) +nlatent(::HeteroscedasticGaussianLikelihood) = 2 + function (l::HeteroscedasticGaussianLikelihood)(f::AbstractVector{<:Real}) return Normal(f[1], sqrt(l.invlink(f[2]))) end diff --git a/test/likelihoods/categorical.jl b/test/likelihoods/categorical.jl index 8ed48be..f736579 100644 --- a/test/likelihoods/categorical.jl +++ b/test/likelihoods/categorical.jl @@ -1,13 +1,15 @@ @testset "CategoricalLikelihood" begin - @test CategoricalLikelihood() isa + nclass = 4 + @test CategoricalLikelihood(nclass) isa CategoricalLikelihood{<:GPLikelihoods.BijectiveSimplexLink} + @test CategoricalLikelihood(nclass, softmax) isa CategoricalLikelihood{SoftMaxLink} + @test CategoricalLikelihood(nclass, SoftMaxLink()) isa + CategoricalLikelihood{SoftMaxLink} - @test CategoricalLikelihood(softmax) isa CategoricalLikelihood{SoftMaxLink} - @test CategoricalLikelihood(SoftMaxLink()) isa CategoricalLikelihood{SoftMaxLink} - - OUT_DIM = 4 - lik_bijective = CategoricalLikelihood() - test_interface(lik_bijective, Categorical, OUT_DIM) - lik_nonbijective = CategoricalLikelihood(softmax) - test_interface(lik_nonbijective, Categorical, OUT_DIM) + lik_bijective = CategoricalLikelihood(nclass) + test_interface(lik_bijective, Categorical) + @test nlatent(lik_bijective) == nclass - 1 + lik_nonbijective = CategoricalLikelihood(nclass, softmax) + test_interface(lik_nonbijective, Categorical) + @test nlatent(lik_nonbijective) == nclass end diff --git a/test/likelihoods/gaussian.jl b/test/likelihoods/gaussian.jl index fe9f210..d184dcb 100644 --- a/test/likelihoods/gaussian.jl +++ b/test/likelihoods/gaussian.jl @@ -10,8 +10,7 @@ end end lik = HeteroscedasticGaussianLikelihood() - IN_DIM = 3 - OUT_DIM = 2 # one for the mean the other for the log-standard deviation N = 10 - test_interface(lik, Normal, 2) + test_interface(lik, Normal) + @test nlatent(lik) == 2 end