Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loglogistic, logitexp, log1mlogistic and logit1mexp #82

Merged
merged 29 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8fa3801
Add `loglogistic`, `logitexp`, `log1mlogistic` and `logit1mexp`
andrewjradcliffe May 3, 2024
218d508
Register inverses; why not?
andrewjradcliffe May 3, 2024
f89c9c5
Update documentation
andrewjradcliffe May 3, 2024
33e6677
Fix failing test: be slightly less stringent
andrewjradcliffe May 3, 2024
51b9a5b
Re-style documentation
andrewjradcliffe May 3, 2024
9f24823
Use alternate design for `loglogistic`
andrewjradcliffe May 3, 2024
a7f9cdd
Simplify dispatch cases
andrewjradcliffe May 4, 2024
223327d
Add `ChangesOfVariables` support
andrewjradcliffe May 4, 2024
27322fc
Fix duplication in docstring
andrewjradcliffe May 4, 2024
716b886
Be as careful as possible in the evaluation; we get ulp error of `log…
andrewjradcliffe May 4, 2024
3da3f8a
Apply the same conversion technique as `loglogistic`
andrewjradcliffe May 4, 2024
8ee3db1
Eliminate 3 lines of code
andrewjradcliffe May 4, 2024
c9ab3f3
Hide the proper handling of integers and rationals.
andrewjradcliffe May 4, 2024
ac17f3c
Add tests of inverses
andrewjradcliffe May 4, 2024
16d6c9c
Add ChainRules extension and respective tests
andrewjradcliffe May 4, 2024
0d65559
Add ChangesOfVariables tests
andrewjradcliffe May 4, 2024
5d84266
Fix typo
andrewjradcliffe May 4, 2024
86ac5f5
Update documentation
andrewjradcliffe May 4, 2024
eb1633a
Add tests to ensure correctness for subtypes of `Real` in `Base`
andrewjradcliffe May 4, 2024
46c0e60
Simplify definitions
andrewjradcliffe May 5, 2024
18a58ef
Remove ChainRules definitions
andrewjradcliffe May 5, 2024
519a6a4
Also remove ChainRules tests
andrewjradcliffe May 5, 2024
716b671
Match expected form of `ChangesOfVariables` tests
andrewjradcliffe May 5, 2024
58133ad
Fix oversight in function type dispatches
andrewjradcliffe May 5, 2024
1b0f914
Add `ChainRulesCore` support back
andrewjradcliffe May 6, 2024
2aa9f44
Revert "Add `ChainRulesCore` support back"
andrewjradcliffe May 17, 2024
177cebb
Add extensive return type tests
andrewjradcliffe May 24, 2024
0d60abd
Fix minor oversight from running tests locally
andrewjradcliffe May 24, 2024
15cb4ef
Add tests for `NaN`
andrewjradcliffe May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Various special functions based on `log` and `exp` moved from [StatsFuns.jl](htt

The original authors of these functions are the StatsFuns.jl contributors.

LogExpFunctions supports [`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl) and [`ChangesOfVariables.test_with_logabsdet_jacobian`](https://github.com/JuliaMath/ChangesOfVariables.jl) for `log1mexp`, `log1pexp`, `log2mexp`, `logexpm1`, `logistic`, `logit`, and `logcosh` (no inverse).
LogExpFunctions supports [`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl) and [`ChangesOfVariables.test_with_logabsdet_jacobian`](https://github.com/JuliaMath/ChangesOfVariables.jl) for `log1mexp`, `log1pexp`, `log2mexp`, `logexpm1`, `logistic`, `logit`, `loglogistic`, `logitexp`, `log1mlogistic`, `logit1mexp`, and `logcosh` (no inverse).

```@docs
xlogx
Expand All @@ -31,4 +31,8 @@ softmax!
softmax
cloglog
cexpexp
loglogistic
logitexp
log1mlogistic
logit1mexp
```
5 changes: 5 additions & 0 deletions ext/LogExpFunctionsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,9 @@ end
ChainRulesCore.@scalar_rule(cloglog(x), (-inv((1 - x) * log1p(-x)),))
ChainRulesCore.@scalar_rule(cexpexp(x), (-xexpx(-exp(x)),))

ChainRulesCore.@scalar_rule(loglogistic(x::Real), (logistic(-x),))
ChainRulesCore.@scalar_rule(log1mlogistic(x::Real), (-logistic(x),))
ChainRulesCore.@scalar_rule(logitexp(x::Real), (inv(1 - exp(x)),))
ChainRulesCore.@scalar_rule(logit1mexp(x::Real), (-inv(1 - exp(x)),))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO these are not needed - AD handles the functions already fine since rules for log1pexp etc. are defined.

Suggested change
ChainRulesCore.@scalar_rule(loglogistic(x::Real), (logistic(-x),))
ChainRulesCore.@scalar_rule(log1mlogistic(x::Real), (-logistic(x),))
ChainRulesCore.@scalar_rule(logitexp(x::Real), (inv(1 - exp(x)),))
ChainRulesCore.@scalar_rule(logit1mexp(x::Real), (-inv(1 - exp(x)),))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As directed, these definitions have been removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking through the issues, #36 seems to indicate that these functions should support ChainRulesCore forward and reverse rules.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that's just not correct. These new functions are supported by ChainRulesCore-based AD systems just fine without any additional definitions. See https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/which_functions_need_rules.html for some additional information on this matter.

end # module
20 changes: 20 additions & 0 deletions ext/LogExpFunctionsChangesOfVariablesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,24 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logcosh), x::Real)
return y, log1mexp(a) - z
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(loglogistic), x::Real)
y = loglogistic(x)
return y, y - x
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mlogistic), x::Real)
y = log1mlogistic(x)
return y, x + y
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logitexp), x::Real)
y = logitexp(x)
return y, y - x
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logit1mexp), x::Real)
y = logit1mexp(x)
return y, -y - x
end

end # module
6 changes: 6 additions & 0 deletions ext/LogExpFunctionsInverseFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,10 @@ InverseFunctions.inverse(::typeof(logistic)) = logit
InverseFunctions.inverse(::typeof(cloglog)) = cexpexp
InverseFunctions.inverse(::typeof(cexpexp)) = cloglog

InverseFunctions.inverse(::typeof(loglogistic)) = logitexp
InverseFunctions.inverse(::typeof(logitexp)) = loglogistic

InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp
InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic

end # module
3 changes: 2 additions & 1 deletion src/LogExpFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import LinearAlgebra

export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
softmax!, logcosh, logabssinh, cloglog, cexpexp
softmax!, logcosh, logabssinh, cloglog, cexpexp,
loglogistic, logitexp, log1mlogistic, logit1mexp

# expm1(::Float16) is not defined in older Julia versions,
# hence for better Float16 support we use an internal function instead
Expand Down
65 changes: 65 additions & 0 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,68 @@ $(SIGNATURES)
Compute the complementary double exponential, `1 - exp(-exp(x))`.
"""
cexpexp(x) = -_expm1(-exp(x))

#=
this uses the identity:

log(logistic(x)) = -log(1 + exp(-x))
=#
"""
$(SIGNATURES)

Return `log(logistic(x))`, computed more carefully and with fewer calls
than the naive composition of functions.

Its inverse is the [`logitexp`](@ref) function.
"""
loglogistic(x::Real) = -log1pexp(-float(x))

#=
this uses the identity:

logit(exp(x)) = log(exp(x) / (1 + exp(x))) = -log(exp(-x) - 1)
=#
"""
$(SIGNATURES)

Return `logit(exp(x))`, computed more carefully and with fewer calls than
the naive composition of functions.

Its inverse is the [`loglogistic`](@ref) function.
"""
logitexp(x::Real) = -logexpm1(-float(x))

#=
this uses the identity:

log(logistic(-x)) = -log(1 + exp(x))

that is, negation in the log-odds domain.
=#

"""
$(SIGNATURES)

Return `log(1 - logistic(x))`, computed more carefully and with fewer calls than
the naive composition of functions.

Its inverse is the [`logit1mexp`](@ref) function.
"""
log1mlogistic(x::Real) = -log1pexp(x)

#=

this uses the same identity:

-logit(exp(x)) = logit(1 - exp(x)) = log((1 - exp(x)) / exp(x)) = log(exp(-x) - 1)
=#

"""
$(SIGNATURES)

Return `logit(1 - exp(x))`, computed more carefully and with fewer calls than
the naive composition of functions.

Its inverse is the [`log1mlogistic`](@ref) function.
"""
logit1mexp(x::Real) = logexpm1(-float(x))
90 changes: 90 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,93 @@ end
@test cexpexp(-Inf) == 0.0
@test cexpexp(0) == (ℯ - 1) / ℯ
end

@testset "loglogistic: $T" for T in (Float16, Float32, Float64)
lim1 = T === Float16 ? -14.0 : -50.0
lim2 = T === Float16 ? -10.0 : -37.0
xs = T[Inf, -Inf, 0.0, lim1, lim2]
for x in xs
@test loglogistic(x) == log(logistic(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we test multiple types T, I think it would be good to also check that - as desired - the results are of type T.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Comprehensive return type tests have been added for the 4 functions included in this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anything else, @devmotion?

end

ϵ = eps(T)
xs = T[ϵ, 1.0, 18.0, 33.3, 50.0]
for x in xs
lhs = loglogistic(x)
rhs = log(logistic(x))
@test abs(lhs - rhs) < ϵ
end

# misc
@test loglogistic(T(Inf)) == -zero(T)
@test loglogistic(-T(Inf)) == -T(Inf)
@test loglogistic(-T(103.0)) == -T(103.0)
@test abs(loglogistic(T(35.0))) < 3eps(T)
@test abs(loglogistic(T(103.0))) < eps(T)
@test isfinite(loglogistic(-T(745.0)))
@test isfinite(loglogistic(T(50.0)))
@test isfinite(loglogistic(T(745.0)))
end


@testset "logitexp: $T" for T in (Float16, Float32, Float64)
ϵ = eps(T)
xs = T[ϵ, √ϵ, 0.2, 0.4, 0.8, 1.0 - √ϵ, 1.0 - ϵ]
neg_xs = -xs
for x in xs
@test abs(logitexp(loglogistic(x)) - x) <= ϵ
end
for x in neg_xs
@test abs(logitexp(loglogistic(x)) - x) <= 2ϵ
end
xs = T[-Inf, 0.0, Inf]
for x in xs
@test logitexp(loglogistic(x)) == x
end
end

@testset "log1mlogistic: $T" for T in (Float16, Float32, Float64)
@test log1mlogistic(T(Inf)) == -T(Inf)
@test log1mlogistic(-T(Inf)) == -zero(T)
@test log1mlogistic(-T(103.0)) < eps(T)
@test abs(log1mlogistic(T(35.0))) == T(35.0)
@test abs(log1mlogistic(T(103.0))) == T(103.0)
@test isfinite(log1mlogistic(-T(745.0)))
@test isfinite(log1mlogistic(T(50.0)))
@test isfinite(log1mlogistic(T(745.0)))
end


@testset "logit1mexp: $T" for T in (Float16, Float32, Float64)
ϵ = eps(T)
xs = T[ϵ, √ϵ, 0.2, 0.4, 0.8, 1.0 - √ϵ, 1.0 - ϵ]
neg_xs = -xs
for x in xs
@test abs(logit1mexp(log1mlogistic(x)) - x) <= 2ϵ
end
for x in neg_xs
@test abs(logit1mexp(log1mlogistic(x)) - x) <= ϵ
end
xs = T[-Inf, 0.0, Inf]
for x in xs
@test logit1mexp(log1mlogistic(x)) == x
end
end

@testset "correctness wrt Unsigned, Rational" begin
@test loglogistic(UInt64(5)) == loglogistic(5.0)
@test log1mlogistic(UInt64(5)) == log1mlogistic(5.0)

@test loglogistic(0x01//0x02) == loglogistic(0.5)
@test log1mlogistic(0x01//0x02) == log1mlogistic(0.5)
end

@testset "correctness wrt Integer edge case" begin
# If not handled, these will be zero
@test loglogistic(typemin(Int)) == loglogistic(float(typemin(Int)))
@test log1mlogistic(typemin(Int)) == log1mlogistic(float(typemin(Int)))

# If not handled, these would throw since negation at typemin is a round-trip
@test logitexp(typemin(Int)) == logitexp(float(typemin(Int)))
@test logit1mexp(typemin(Int)) == logit1mexp(float(typemin(Int)))
end
26 changes: 26 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,30 @@
test_scalar(cloglog, rand())

test_scalar(cexpexp, rand())

@testset "loglogistic, log1mlogistic" begin
# relationship to logistic suggests an analogous test regime
for x in (-821.4, -23.5, 12.3, 41.2)
test_frule(loglogistic, x)
test_rrule(loglogistic, x)
end
for x in (-123.2f0, -21.4f0, 8.3f0, 21.5f0)
test_frule(loglogistic, x; rtol=1f-3, atol=1f-3)
test_rrule(loglogistic, x; rtol=1f-3, atol=1f-3)
end

for x in (-821.4, -23.5, 12.3, 41.2)
test_frule(log1mlogistic, x)
test_rrule(log1mlogistic, x)
end
for x in (-123.2f0, -21.4f0, 8.3f0, 21.5f0)
test_frule(log1mlogistic, x; rtol=1f-3, atol=1f-3)
test_rrule(log1mlogistic, x; rtol=1f-3, atol=1f-3)
end
end

test_frule(logitexp, -x)
test_frule(logit1mexp, -x)
test_rrule(logitexp, -x)
test_rrule(logit1mexp, -x)
end
6 changes: 6 additions & 0 deletions test/inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,10 @@

InverseFunctions.test_inverse(cloglog, rand())
InverseFunctions.test_inverse(cexpexp, rand())

InverseFunctions.test_inverse(loglogistic, randexp())
InverseFunctions.test_inverse(logitexp, -randexp())

InverseFunctions.test_inverse(log1mlogistic, randexp())
InverseFunctions.test_inverse(logit1mexp, -randexp())
end
6 changes: 6 additions & 0 deletions test/with_logabsdet_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,10 @@

ChangesOfVariables.test_with_logabsdet_jacobian(logcosh, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(logcosh, -x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(loglogistic, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(logitexp, -x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(log1mlogistic, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(logit1mexp, -x, derivative)
end
Loading