Skip to content

Commit

Permalink
implement suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
acertain committed Aug 19, 2024
1 parent 9d36d88 commit 35d78f7
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,23 @@ end

with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)

# from stan https://github.com/stan-dev/math/blob/develop/stan/math/prim/constraint/lub_constrain.hpp etc
function truncated_inv_logabsdetjac(x, a, b)
function truncated_inv_logabsdetjac(y, a, b)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
return log(b - a) - abs(x) + 2.0 * LogExpFunctions.log1pexp(-abs(x))
elseif lowerbounded
return x
elseif upperbounded
return x
return log(b - a) - abs(y) + 2 * LogExpFunctions.log1pexp(-abs(y))
elseif lowerbounded || upperbounded
return convert(promote_type(typeof(y), typeof(a), typeof(b)), y)
else
return zero(x)
return zero(y)
end
end

function logabsdetjac(ib::Inverse{<:TruncatedBijector}, x)
function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y)
a, b = ib.orig.lb, ib.orig.ub
return truncated_inv_logabsdetjac.(x, a, b)
return truncated_inv_logabsdetjac.(y, a, b)
end

with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, x) = transform(ib, x), logabsdetjac(ib, x)
with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y)

# It's only monotonically decreasing if it's only upper-bounded.
# In the multivariate case, we can only say something reasonable if entries are monotonic.
Expand Down

0 comments on commit 35d78f7

Please sign in to comment.