Skip to content

Commit

Permalink
Fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jun 28, 2024
1 parent f74d096 commit 8653c2d
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 23 deletions.
20 changes: 10 additions & 10 deletions src/truncated/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,28 @@
# Given `truncate(LogNormal(μ, σ), a, b)`, return `truncate(Normal(μ, σ), log(a), log(b))`
function _truncnorm(d::Truncated{<:LogNormal})
μ, σ = params(d.untruncated)
a = d.lower === nothing ? nothing : log(minimum(d))
b = d.upper === nothing ? nothing : log(maximum(d))
T = partype(d)
a = d.lower === nothing ? nothing : log(T(minimum(d)))
b = d.upper === nothing ? nothing : log(T(maximum(d)))
return truncated(Normal(μ, σ), a, b)
end

mean(d::Truncated{<:LogNormal}) = mgf(_truncnorm(d), 1)

function var(d::Truncated{<:LogNormal})
tn = _truncnorm(d)
m1 = mgf(tn, 1)
m2 = sqrt(mgf(tn, 2))
return (m2 - m1) * (m2 + m1)
# Ensure the variance doesn't end up negative, which can occur due to numerical issues
return max(mgf(tn, 2) - mgf(tn, 1)^2, 0)
end

function skewness(d::Truncated{<:LogNormal})
tn = _truncnorm(d)
m1 = mgf(tn, 1)
m2 = sqrt(mgf(tn, 2))
m2 = mgf(tn, 2)
m3 = mgf(tn, 3)
v = (m2 - m1) * (m2 + m1)
return (m3 - 3 * m1 * v - m1^3) / (v * sqrt(v))
sqm1 = m1^2
v = m2 - sqm1
return (m3 + m1 * (-3 * m2 + 2 * sqm1)) / (v * sqrt(v))
end

function kurtosis(d::Truncated{<:LogNormal})
Expand All @@ -35,8 +36,7 @@ function kurtosis(d::Truncated{<:LogNormal})
m2 = mgf(tn, 2)
m3 = mgf(tn, 3)
m4 = mgf(tn, 4)
sm2 = sqrt(m2)
v = (sm2 - m1) * (sm2 + m1)
v = m2 - m1^2
return evalpoly(m1, (m4, -4m3, 6m2, 0, -3)) / v^2 - 3
end

Expand Down
5 changes: 3 additions & 2 deletions src/truncated/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ function entropy(d::Truncated{<:Normal{<:Real},Continuous})
end

function mgf(d::Truncated{<:Normal{<:Real},Continuous}, t::Real)
a, b = extrema(d)
T = promote_type(partype(d), typeof(t), typeof(a))
T = promote_type(partype(d), typeof(t))
a = T(minimum(d))
b = T(maximum(d))
if isnan(a) || isnan(b) # TODO: Disallow constructing `Truncated` with a `NaN` bound?
return T(NaN)
elseif isinf(a) && isinf(b) && a != b
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const tests = [
"truncated/exponential",
"truncated/uniform",
"truncated/discrete_uniform",
"truncated/lognormal",
"censored",
"univariate/continuous/normal",
"univariate/continuous/laplace",
Expand Down
11 changes: 11 additions & 0 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ function _linspace(a::Float64, b::Float64, n::Int)
return r
end

# Enables testing against values computed at high precision by transforming an expression
# that uses numeric literals and constants to wrap those in `big()`, similar to how the
# high-precision values for irrational constants are defined with `Base.@irrational` and
# in IrrationalConstants.jl. See e.g. `test/truncated/normal.jl` for example use.
bigly(x) = x
bigly(x::Symbol) = x in (, :ℯ, :Inf, :NaN) ? Expr(:call, :big, x) : x
bigly(x::Real) = Expr(:call, :big, x)
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
macro bigly(ex)
return esc(bigly(ex))
end

#################################################
#
Expand Down
36 changes: 36 additions & 0 deletions test/truncated/lognormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Distributions, Test
using Distributions: expectation

naive_moment(d, n, μ, σ²) == sqrt(σ²); expectation(x -> ((x - μ) / σ)^n, d))

@testset "Truncated log normal" begin
@testset "truncated(LogNormal{$T}(0, 1), ℯ⁻², ℯ²)" for T in (Float32, Float64, BigFloat)
d = truncated(LogNormal{T}(zero(T), one(T)), exp(T(-2)), exp(T(2)))
tn = truncated(Normal{BigFloat}(big(0.0), big(1.0)), -2, 2)
bigmean = mgf(tn, 1)
bigvar = mgf(tn, 2) - bigmean^2
@test @inferred(mean(d)) bigmean
@test @inferred(var(d)) bigvar
@test @inferred(median(d)) one(T)
@test @inferred(skewness(d)) naive_moment(d, 3, bigmean, bigvar)
@test @inferred(kurtosis(d)) naive_moment(d, 4, bigmean, bigvar) - big(3)
@test mean(d) isa T
end
@testset "Bound with no effect" begin
# Uses the example distribution from issue #709, though what's tested here is
# mostly unrelated to that issue (aside from `mean` not erroring).
# The specified left truncation at 0 has no effect for `LogNormal`
d1 = truncated(LogNormal(1, 5), 0, 1e5)
@test mean(d1) 0 atol=eps()
v1 = var(d1)
@test v1 0 atol=eps()
# Without a `max(_, 0)`, this would be within machine precision of 0 (as above) but
# numerically negative, which could cause downstream issues that assume a nonnegative
# variance
@test v1 > 0
# Compare results with not specifying a lower bound at all
d2 = truncated(LogNormal(1, 5); upper=1e5)
@test mean(d1) == mean(d2)
@test var(d1) == var(d2)
end
end
15 changes: 4 additions & 11 deletions test/truncated/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,12 @@ end
end
end

bigly(x) = x
bigly(x::Symbol) = x === || x === :ℯ ? Expr(:call, :big, x) : x
bigly(x::Real) = Expr(:call, :big, x)
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
macro bigly(ex)
return esc(bigly(ex))
end

@testset "Truncated normal MGF" begin
sqrt2 = sqrt(big(2))
two = big(2)
sqrt2 = sqrt(two)
invsqrt2 = inv(sqrt2)
inv2sqrt2 = inv(big(2) * sqrt2)
twoerfsqrt2 = big(2) * erf(sqrt2)
inv2sqrt2 = inv(two * sqrt2)
twoerfsqrt2 = two * erf(sqrt2)

for T in (Float32, Float64, BigFloat)
d = truncated(Normal{T}(zero(T), one(T)), -2, 2)
Expand Down

0 comments on commit 8653c2d

Please sign in to comment.