diff --git a/docs/src/mixture.md b/docs/src/mixture.md index e6b24b103c..5549de039a 100644 --- a/docs/src/mixture.md +++ b/docs/src/mixture.md @@ -99,6 +99,7 @@ var(::UnivariateMixture) length(::MultivariateMixture) pdf(::AbstractMixtureModel, ::Any) logpdf(::AbstractMixtureModel, ::Any) +gradlogpdf(::AbstractMixtureModel, ::Any) rand(::AbstractMixtureModel) rand!(::AbstractMixtureModel, ::AbstractArray) ``` diff --git a/docs/src/truncate.md b/docs/src/truncate.md index 4d4f63f9f3..2362d166b3 100644 --- a/docs/src/truncate.md +++ b/docs/src/truncate.md @@ -26,6 +26,7 @@ are defined for all truncated univariate distributions: - [`insupport(::UnivariateDistribution, x::Any)`](@ref) - [`pdf(::UnivariateDistribution, ::Real)`](@ref) - [`logpdf(::UnivariateDistribution, ::Real)`](@ref) +- [`gradlogpdf(::UnivariateDistribution, ::Real)`](@ref) - [`cdf(::UnivariateDistribution, ::Real)`](@ref) - [`logcdf(::UnivariateDistribution, ::Real)`](@ref) - [`logdiffcdf(::UnivariateDistribution, ::T, ::T) where {T <: Real}`](@ref) diff --git a/docs/src/univariate.md b/docs/src/univariate.md index 0b2c48c6ea..c04da39136 100644 --- a/docs/src/univariate.md +++ b/docs/src/univariate.md @@ -73,6 +73,7 @@ pdfsquaredL2norm insupport(::UnivariateDistribution, x::Any) pdf(::UnivariateDistribution, ::Real) logpdf(::UnivariateDistribution, ::Real) +gradlogpdf(::UnivariateDistribution, ::Real) loglikelihood(::UnivariateDistribution, ::AbstractArray) cdf(::UnivariateDistribution, ::Real) logcdf(::UnivariateDistribution, ::Real) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index c3d3b1f919..ff573d822c 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -87,6 +87,13 @@ Here, `x` can be a single sample or an array of multiple samples. """ logpdf(d::AbstractMixtureModel, x::Any) +""" + gradlogpdf(d::Union{UnivariateMixture, MultivariateMixture}, x) + +Evaluate the gradient of the logarithm of the (mixed) probability density function over a single sample `x`. +""" +gradlogpdf(d::AbstractMixtureModel, x::Any) + """ rand(d::Union{UnivariateMixture, MultivariateMixture}) @@ -362,6 +369,38 @@ end pdf(d::UnivariateMixture, x::Real) = _mixpdf1(d, x) logpdf(d::UnivariateMixture, x::Real) = _mixlogpdf1(d, x) +function gradlogpdf(d::UnivariateMixture, x::Real) + ps = probs(d) + cs = components(d) + + # `d` is expected to have at least one distribution, otherwise this will just error + psi, idxps = iterate(ps) + csi, idxcs = iterate(cs) + pdfx1 = pdf(csi, x) + pdfx = psi * pdfx1 + glp = pdfx * gradlogpdf(csi, x) + if iszero(psi) || iszero(pdfx) + glp = zero(glp) + end + + while (iterps = iterate(ps, idxps)) !== nothing && (itercs = iterate(cs, idxcs)) !== nothing + psi, idxps = iterps + csi, idxcs = itercs + if !iszero(psi) + pdfxi = pdf(csi, x) + if !iszero(pdfxi) + pipdfxi = psi * pdfxi + pdfx += pipdfxi + glp += pipdfxi * gradlogpdf(csi, x) + end + end + end + if !iszero(pdfx) # else glp is already zero + glp /= pdfx + end + return glp +end + _pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture{Discrete}, x::UnitRange) = _mixpdf!(r, d, x) _pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real}) = _mixpdf!(r, d, x) _logpdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real}) = _mixlogpdf!(r, d, x) @@ -371,6 +410,37 @@ _logpdf(d::MultivariateMixture, x::AbstractVector{<:Real}) = _mixlogpdf1(d, x) _pdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixpdf!(r, d, x) _logpdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixlogpdf!(r, d, x) +function gradlogpdf(d::MultivariateMixture, x::AbstractVector{<:Real}) + ps = probs(d) + cs = components(d) + + # `d` is expected to have at least one distribution, otherwise this will just error + psi, idxps = iterate(ps) + csi, idxcs = iterate(cs) + pdfx1 = pdf(csi, x) + pdfx = psi * pdfx1 + glp = pdfx * gradlogpdf(csi, x) + if iszero(psi) || iszero(pdfx) + fill!(glp, zero(eltype(glp))) + end + + while (iterps = iterate(ps, idxps)) !== nothing && (itercs = iterate(cs, idxcs)) !== nothing + psi, idxps = iterps + csi, idxcs = itercs + if !iszero(psi) + pdfxi = pdf(csi, x) + if !iszero(pdfxi) + pipdfxi = psi * pdfxi + pdfx += pipdfxi + glp .+= pipdfxi .* gradlogpdf(csi, x) + end + end + end + if !iszero(pdfx) # else glp is already zero + glp ./= pdfx + end + return glp +end ## component-wise pdf and logpdf diff --git a/test/gradlogpdf.jl b/test/gradlogpdf.jl index f4216a67e6..faa1ca0f92 100644 --- a/test/gradlogpdf.jl +++ b/test/gradlogpdf.jl @@ -25,3 +25,61 @@ using Test [0.191919191919192, 1.080808080808081] ,atol=1.0e-8) @test isapprox(gradlogpdf(MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.]), [0.7, 0.9]), [0.2150711513583442, 1.2111901681759383] ,atol=1.0e-8) + +# Test for gradlogpdf on univariate mixture distributions + +x = [-0.2, 0.3, 0.8, 1.0, 1.3, 10.5] +delta = 0.0001 + +for di in ( + Normal(-4.5, 2.0), + Exponential(2.0), + Uniform(0.0, 1.0), + Beta(2.0, 3.0), + Beta(0.5, 0.5) +) + d = MixtureModel([di], [1.0]) + glp1 = gradlogpdf.(d, x) + glp2 = gradlogpdf.(di, x) + @info "Testing `gradlogpdf` on $d" + @test isapprox(glp1, glp2, atol = 0.01) +end + +for d in ( + MixtureModel([Normal(1//1, 2//1), Beta(2//1, 3//1), Exponential(3//2)], [3//10, 4//10, 3//10]), + MixtureModel([Normal(-2.0, 3.5), Normal(-4.5, 2.0)], [0.0, 1.0]), + MixtureModel([Beta(1.5, 3.0), Chi(5.0), Chisq(7.0)], [0.4, 0.3, 0.3]), + MixtureModel([Exponential(2.0), Gamma(9.0, 0.5), Gumbel(3.5, 1.0), Laplace(7.0)], [0.3, 0.2, 0.4, 0.1]), + MixtureModel([Logistic(-6.0), LogNormal(5.5), TDist(8.0), Weibull(2.0)], [0.3, 0.2, 0.4, 0.1]) +) + + # finite differences don't handle when not in the interior of the support + xs = filter(s -> all(insupport.(d, [s - delta, s, s + delta])), x) + + glp1 = gradlogpdf.(d, xs) + glp2 = ( logpdf.(d, xs .+ delta) - logpdf.(d, xs .- delta) ) ./ 2delta + @info "Testing `gradlogpdf` on $d" + @test isapprox(glp1, glp2, atol = 0.01) +end + +# Test for gradlogpdf on multivariate mixture distributions against centered finite-difference on logpdf + +x = [[0.2, 0.3], [0.8, 1.3], [-1.0, 10.5]] +delta = 0.001 + +for d in ( + MixtureModel([MvNormal([1., 2.], [1. 0.1; 0.1 1.])], [1.0]), + MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4])], [0.4, 0.6]), + MixtureModel([MvNormal([3.0, 2.0], [0.2 0.3; 0.3 0.5]), MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4])], [0.0, 1.0, 0.0]), + MixtureModel([MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [1.0]), + MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [0.4, 0.6]) +) + xs = filter(s -> insupport(d, s), x) + for xi in xs + glp = gradlogpdf(d, xi) + glpx = ( logpdf(d, xi .+ [delta, 0]) - logpdf(d, xi .- [delta, 0]) ) ./ 2delta + glpy = ( logpdf(d, xi .+ [0, delta]) - logpdf(d, xi .- [0, delta]) ) ./ 2delta + @test isapprox(glp[1], glpx, atol=delta) + @test isapprox(glp[2], glpy, atol=delta) + end +end