Skip to content

Commit

Permalink
fix Stacked by fusing two loops into one in a Zygote-friendly way
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Jun 6, 2024
1 parent ce27322 commit 81da609
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,11 @@ end
end

function _with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
ys = map(zip(sb.bs, sb.ranges_in)) do tup
b, r = tup[1], tup[2]
b(x[r])
end
y = reduce(vcat, ys)
logjac = sum(zip(sb.bs, sb.ranges_in)) do tup
b, r = tup[1], tup[2]
logabsdetjac(b, x[r])
ys_and_logjacs = map(zip(sb.bs, sb.ranges_in)) do (b, r)
with_logabsdet_jacobian(b, x[r])
end
y = reduce(vcat, map(first, ys_and_logjacs))
logjac = sum(map(last, ys_and_logjacs))
return (y, logjac)
end

Expand Down

0 comments on commit 81da609

Please sign in to comment.