From e522155c06565673d8b95c330f2a99db13a6d3b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 13:06:29 +0100 Subject: [PATCH] Bugfix for `SimplexBijector` on `Matrix` (#302) * fix for `SimplexBijector` with matrix inputs * added tests for filldist with `Dirichlet` * bump patch version --- Project.toml | 2 +- src/bijectors/simplex.jl | 6 ++++++ test/interface.jl | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9df2d5d5..2a7ccaf3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.8" +version = "0.13.9" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index c5bb4ef6..eacda64e 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -136,6 +136,12 @@ function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where {T} return -lp end + +# Needed to avoid falling back to `with_logabsdet_jacobian` for matrix inputs. +function logabsdetjac(b::SimplexBijector, x::AbstractMatrix{<:Real}) + return sum(Base.Fix1(logabsdetjac, b), eachcol(x)) +end + function simplex_logabsdetjac_gradient(x::AbstractVector) T = eltype(x) ϵ = _eps(T) diff --git a/test/interface.jl b/test/interface.jl index 3272e0f1..c3221307 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -230,6 +230,7 @@ end filldist(Exponential(), 2), filldist(Exponential(), 2, 3), filldist(filldist(Exponential(), 2), 3), + filldist(Dirichlet(ones(2)), 3), ] x = rand(dist) b = bijector(dist)