diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index c5bb4ef6..eacda64e 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -136,6 +136,12 @@ function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where {T} return -lp end + +# Needed to avoid falling back to `with_logabsdet_jacobian` for matrix inputs. +function logabsdetjac(b::SimplexBijector, x::AbstractMatrix{<:Real}) + return sum(Base.Fix1(logabsdetjac, b), eachcol(x)) +end + function simplex_logabsdetjac_gradient(x::AbstractVector) T = eltype(x) ϵ = _eps(T) diff --git a/test/interface.jl b/test/interface.jl index 3272e0f1..c3221307 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -230,6 +230,7 @@ end filldist(Exponential(), 2), filldist(Exponential(), 2, 3), filldist(filldist(Exponential(), 2), 3), + filldist(Dirichlet(ones(2)), 3), ] x = rand(dist) b = bijector(dist)