Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend gradlogpdf to MixtureModels #1827

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/src/mixture.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ var(::UnivariateMixture)
length(::MultivariateMixture)
pdf(::AbstractMixtureModel, ::Any)
logpdf(::AbstractMixtureModel, ::Any)
gradlogpdf(::AbstractMixtureModel, ::Any)
rand(::AbstractMixtureModel)
rand!(::AbstractMixtureModel, ::AbstractArray)
```
Expand Down
1 change: 1 addition & 0 deletions docs/src/truncate.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ are defined for all truncated univariate distributions:
- [`insupport(::UnivariateDistribution, x::Any)`](@ref)
- [`pdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logpdf(::UnivariateDistribution, ::Real)`](@ref)
- [`gradlogpdf(::UnivariateDistribution, ::Real)`](@ref)
- [`cdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logcdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logdiffcdf(::UnivariateDistribution, ::T, ::T) where {T <: Real}`](@ref)
Expand Down
1 change: 1 addition & 0 deletions docs/src/univariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pdfsquaredL2norm
insupport(::UnivariateDistribution, x::Any)
pdf(::UnivariateDistribution, ::Real)
logpdf(::UnivariateDistribution, ::Real)
gradlogpdf(::UnivariateDistribution, ::Real)
loglikelihood(::UnivariateDistribution, ::AbstractArray)
cdf(::UnivariateDistribution, ::Real)
logcdf(::UnivariateDistribution, ::Real)
Expand Down
59 changes: 59 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,46 @@ end
# `_logpdf` should be implemented and has no default definition
# _logpdf(d::Distribution{ArrayLikeVariate{N}}, x::AbstractArray{<:Real,N}) where {N}

"""
    gradlogpdf(d::Distribution{ArrayLikeVariate{N}}, x::AbstractArray{<:Real,M}) where {N,M}

Return the gradient of the log-density of `d` evaluated at `x`.

When `x` has as many dimensions as the variates of `d` (`M == N`), `x` is treated as a
single variate and the result of `_gradlogpdf(d, x)` is returned. When `x` has more
dimensions (`M > N`), its leading `N` dimensions must equal `size(d)` and the gradient is
evaluated for every variate contained in `x`.

A `DimensionMismatch` is thrown if the size of `x` is not compatible with `d`. The size
checks are wrapped in `@boundscheck`, so they can be disabled by using `@inbounds`.

# Implementation

Distributions should implement `_gradlogpdf(d, x)` rather than `gradlogpdf`; `_gradlogpdf`
does not have to check the size of `x`.

See also: [`logpdf`](@ref).
"""
@inline function gradlogpdf(
    d::Distribution{ArrayLikeVariate{N}}, x::AbstractArray{<:Real,M}
) where {N,M}
    if M != N
        # `x` is expected to hold several variates stacked along its trailing dimensions.
        @boundscheck begin
            if !(M > N)
                throw(DimensionMismatch(
                    "number of dimensions of the variates ($M) must be greater than or equal to the dimension of the distribution ($N)"
                ))
            end
            if ntuple(i -> size(x, i), Val(N)) != size(d)
                throw(DimensionMismatch("inconsistent array dimensions"))
            end
        end
        return @inbounds map(Base.Fix1(gradlogpdf, d), eachvariate(x, variate_form(typeof(d))))
    end

    # `x` is a single variate.
    @boundscheck begin
        if size(x) != size(d)
            throw(DimensionMismatch("inconsistent array dimensions"))
        end
    end
    return _gradlogpdf(d, x)
end

# `_gradlogpdf` should be implemented and has no default definition
# _gradlogpdf(d::Distribution{ArrayLikeVariate{N}}, x::AbstractArray{<:Real,N}) where {N}

# TODO: deprecate?
"""
pdf(d::Distribution{ArrayLikeVariate{N}}, x) where {N}
Expand Down Expand Up @@ -315,6 +355,25 @@ Base.@propagate_inbounds function logpdf(
return map(Base.Fix1(logpdf, d), x)
end

"""
    gradlogpdf(d::Distribution{ArrayLikeVariate{N}}, x) where {N}

Return the gradient of the log-density of `d` for each element of the collection `x`.

Every element of `x` is checked for a size compatible with `d`; these checks can be
disabled by using `@inbounds`.

The collection `x` can be
- an array of dimension `> N` with `size(x)[1:N] == size(d)`, or
- an array of arrays `xi` of dimension `N` with `size(xi) == size(d)`.
"""
Base.@propagate_inbounds function gradlogpdf(
    d::Distribution{ArrayLikeVariate{N}}, x::AbstractArray{<:AbstractArray{<:Real,N}},
) where {N}
    # Apply the single-variate method to each inner array.
    return map(x) do xi
        gradlogpdf(d, xi)
    end
end

"""
pdf!(out, d::Distribution{ArrayLikeVariate{N}}, x) where {N}

Expand Down
15 changes: 15 additions & 0 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ Here, `x` can be a single sample or an array of multiple samples.
"""
logpdf(d::AbstractMixtureModel, x::Any)

"""
    gradlogpdf(d::Union{UnivariateMixture, MultivariateMixture}, x)

Evaluate the gradient of the logarithm of the (mixed) probability density function of `d`
at `x`.

For a single sample, the gradient is the sum over the components of
`p_i * pdf(c_i, x) * gradlogpdf(c_i, x)`, where `p_i` is the prior probability of
component `c_i`, divided by the mixture density `pdf(d, x)`. If `x` is outside the
support of `d`, `zero(x)` is returned.

Here, `x` can be a single sample or an array of multiple samples.
"""
gradlogpdf(d::AbstractMixtureModel, x::Any)

"""
rand(d::Union{UnivariateMixture, MultivariateMixture})

Expand Down Expand Up @@ -359,15 +367,22 @@ function _mixlogpdf!(r::AbstractArray, d::AbstractMixtureModel, x)
return r
end

# Gradient of the log-density of the mixture `d` at a single sample `x`.
#
# With prior probabilities `p_i` and components `c_i`, this evaluates
#
#     (sum_i p_i * pdf(c_i, x) * gradlogpdf(c_i, x)) / (sum_i p_i * pdf(c_i, x)),
#
# where the denominator is the mixture density at `x`. Components with zero prior
# probability or zero density at `x` are skipped, so their `gradlogpdf` is never evaluated.
#
# Returns `zero(x)` if `x` is outside the support of `d`, and also if no component
# contributes at `x` (an unguarded reduction over the empty set of contributing
# components would throw instead of returning a value).
#
# NOTE(review): components are read through the `components` field; confirm that every
# `AbstractMixtureModel` reaching this function provides it.
function _mixgradlogpdf1(d::AbstractMixtureModel, x)
    insupport(d, x) || return zero(x)

    weights = probs(d)
    numerator = zero(x)
    density = zero(eltype(weights))
    for (i, w) in enumerate(weights)
        iszero(w) && continue
        comp = d.components[i]
        # Evaluate each component density once and reuse it for both sums.
        comp_density = pdf(comp, x)
        iszero(comp_density) && continue
        weighted = w * comp_density
        numerator = numerator .+ weighted .* gradlogpdf(comp, x)
        density += weighted
    end

    return iszero(density) ? zero(x) : numerator ./ density
end

# Pointwise evaluation for univariate mixtures: delegate to the single-sample helpers.
function pdf(d::UnivariateMixture, x::Real)
    return _mixpdf1(d, x)
end

function logpdf(d::UnivariateMixture, x::Real)
    return _mixlogpdf1(d, x)
end

function gradlogpdf(d::UnivariateMixture, x::Real)
    return _mixgradlogpdf1(d, x)
end

# In-place evaluation for univariate mixtures: delegate to the mutating helpers.
function _pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture{Discrete}, x::UnitRange)
    return _mixpdf!(r, d, x)
end

function _pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real})
    return _mixpdf!(r, d, x)
end

function _logpdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real})
    return _mixlogpdf!(r, d, x)
end

# Single-sample evaluation for multivariate mixtures.
function _pdf(d::MultivariateMixture, x::AbstractVector{<:Real})
    return _mixpdf1(d, x)
end

function _logpdf(d::MultivariateMixture, x::AbstractVector{<:Real})
    return _mixlogpdf1(d, x)
end

function _gradlogpdf(d::MultivariateMixture, x::AbstractVector{<:Real})
    return _mixgradlogpdf1(d, x)
end

# In-place evaluation for multivariate mixtures over a matrix of samples.
function _pdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real})
    return _mixpdf!(r, d, x)
end

function _logpdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real})
    return _mixlogpdf!(r, d, x)
end

Expand Down
40 changes: 40 additions & 0 deletions test/gradlogpdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,43 @@ using Test
[0.191919191919192, 1.080808080808081] ,atol=1.0e-8)
# Bivariate MvTDist: compare the gradient at a fixed point with hard-coded reference values.
@test isapprox(gradlogpdf(MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.]), [0.7, 0.9]),
[0.2150711513583442, 1.2111901681759383] ,atol=1.0e-8)

# Test for gradlogpdf on univariate mixture distributions against a centered
# finite difference of logpdf.
@testset "gradlogpdf: univariate mixtures" begin
    # Scoped inside the test set so the names are not shared with other sections.
    x = [-0.2, 0.3, 0.8, 1.3, 10.5]
    delta = 0.001

    for d in (
        MixtureModel([Normal(-4.5, 2.0)], [1.0]),
        MixtureModel([Exponential(2.0)], [1.0]),
        MixtureModel([Uniform(-1.0, 1.0)], [1.0]),
        MixtureModel([Normal(-2.0, 3.5), Normal(-4.5, 2.0)], [0.0, 1.0]),
        MixtureModel([Beta(1.5, 3.0), Chi(5.0), Chisq(7.0)], [0.4, 0.3, 0.3]),
        MixtureModel(
            [Exponential(2.0), Gamma(9.0, 0.5), Gumbel(3.5, 1.0), Laplace(7.0)],
            [0.3, 0.2, 0.4, 0.1],
        ),
        MixtureModel(
            [Logistic(-6.0), LogNormal(5.5), TDist(8.0), Weibull(2.0)],
            [0.3, 0.2, 0.4, 0.1],
        ),
    )
        # Only compare at points inside the support of the mixture.
        xs = filter(s -> insupport(d, s), x)
        # Broadcast the scalar method, matching how `logpdf` is evaluated below.
        glp1 = gradlogpdf.(d, xs)
        glp2 = (logpdf.(d, xs .+ delta) .- logpdf.(d, xs .- delta)) ./ 2delta
        @test isapprox(glp1, glp2, atol = delta)
    end
end

# Test for gradlogpdf on multivariate mixture distributions against a centered
# finite difference of logpdf along each coordinate.
@testset "gradlogpdf: multivariate mixtures" begin
    # Scoped inside the test set so the names are not shared with other sections.
    x = [[0.2, 0.3], [0.8, 1.3], [-1.0, 10.5]]
    delta = 0.001
    # Finite-difference steps along the first and second coordinate.
    step1 = [delta, 0.0]
    step2 = [0.0, delta]

    for d in (
        MixtureModel([MvNormal([1., 2.], [1. 0.1; 0.1 1.])], [1.0]),
        MixtureModel(
            [
                MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]),
                MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4]),
            ],
            [0.4, 0.6],
        ),
        MixtureModel(
            [
                MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]),
                MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4]),
            ],
            [1.0, 0.0],
        ),
        MixtureModel([MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [1.0]),
        MixtureModel(
            [
                MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]),
                MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.]),
            ],
            [0.4, 0.6],
        ),
    )
        # Only compare at points inside the support of the mixture.
        xs = filter(s -> insupport(d, s), x)
        glp = gradlogpdf(d, xs)
        # `Ref` keeps each step vector whole while broadcasting over the samples.
        glpx = (logpdf(d, xs .+ Ref(step1)) - logpdf(d, xs .- Ref(step1))) ./ 2delta
        glpy = (logpdf(d, xs .+ Ref(step2)) - logpdf(d, xs .- Ref(step2))) ./ 2delta
        @test isapprox(getindex.(glp, 1), glpx, atol = delta)
        @test isapprox(getindex.(glp, 2), glpy, atol = delta)
    end
end
Loading