diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 854d497d..5c1cf8de 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -230,15 +230,11 @@ end end function _with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) - ys = map(zip(sb.bs, sb.ranges_in)) do tup - b, r = tup[1], tup[2] - b(x[r]) - end - y = reduce(vcat, ys) - logjac = sum(zip(sb.bs, sb.ranges_in)) do tup - b, r = tup[1], tup[2] - logabsdetjac(b, x[r]) + ys_and_logjacs = map(zip(sb.bs, sb.ranges_in)) do (b, r) + with_logabsdet_jacobian(b, x[r]) end + y = reduce(vcat, map(first, ys_and_logjacs)) + logjac = sum(map(last, ys_and_logjacs)) return (y, logjac) end