diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index d468bbe9..9517807e 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -68,6 +68,25 @@ end with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) +function truncated_inv_logabsdetjac(y, a, b) + lowerbounded, upperbounded = isfinite(a), isfinite(b) + if lowerbounded && upperbounded + abs_y = abs(y) + return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y) + elseif lowerbounded || upperbounded + return convert(promote_type(typeof(y), typeof(a), typeof(b)), y) + else + return zero(y) + end +end + +function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y) + a, b = ib.orig.lb, ib.orig.ub + return truncated_inv_logabsdetjac.(y, a, b) +end + +with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y) + # It's only monotonically decreasing if it's only upper-bounded. # In the multivariate case, we can only say something reasonable if entries are monotonic. function is_monotonically_increasing(b::TruncatedBijector)