-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add weak dependency on ChainRulesCore (#246)
* Add weak dependency on ChainRulesCore * Adjust tolerance * Rename extension * Update ext/DistancesChainRulesCoreExt.jl * Add tests with matrices of repeated columns * Use StableRNGs to fix spurious test failures * Update test/runtests.jl --------- Co-authored-by: David Widmann <[email protected]>
- Loading branch information
Showing
4 changed files
with
212 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
module DistancesChainRulesCoreExt | ||
|
||
using Distances | ||
|
||
import ChainRulesCore | ||
|
||
const CRC = ChainRulesCore | ||
|
||
## SqEuclidean | ||
|
||
function CRC.rrule( | ||
::CRC.RuleConfig{>:CRC.HasReverseMode}, | ||
dist::SqEuclidean, | ||
x::AbstractVector{<:Real}, | ||
y::AbstractVector{<:Real} | ||
) | ||
Ω = dist(x, y) | ||
|
||
function SqEuclidean_pullback(ΔΩ) | ||
x̄ = (2 * CRC.unthunk(ΔΩ)) .* (x .- y) | ||
return CRC.NoTangent(), x̄, -x̄ | ||
end | ||
|
||
return Ω, SqEuclidean_pullback | ||
end | ||
|
||
function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(colwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}) | ||
Ω = colwise(dist, X, Y) | ||
|
||
function colwise_SqEuclidean_pullback(ΔΩ) | ||
X̄ = 2 .* CRC.unthunk(ΔΩ)' .* (X .- Y) | ||
return CRC.NoTangent(), CRC.NoTangent(), X̄, -X̄ | ||
end | ||
|
||
return Ω, colwise_SqEuclidean_pullback | ||
end | ||
|
||
function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) | ||
dims = Distances.deprecated_dims(dims) | ||
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) | ||
Ω = pairwise(dist, X; dims=dims) | ||
|
||
function pairwise_SqEuclidean_X_pullback(ΔΩ) | ||
Δ = CRC.unthunk(ΔΩ) | ||
A = Δ .+ transpose(Δ) | ||
X̄ = if dims == 1 | ||
2 .* (sum(A; dims=2) .* X .- A * X) | ||
else | ||
2 .* (X .* sum(A; dims=1) .- X * A) | ||
end | ||
return CRC.NoTangent(), CRC.NoTangent(), X̄ | ||
end | ||
|
||
return Ω, pairwise_SqEuclidean_X_pullback | ||
end | ||
|
||
function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) | ||
dims = Distances.deprecated_dims(dims) | ||
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) | ||
Ω = pairwise(dist, X, Y; dims=dims) | ||
|
||
function pairwise_SqEuclidean_X_Y_pullback(ΔΩ) | ||
Δ = CRC.unthunk(ΔΩ) | ||
Δt = transpose(Δ) | ||
X̄ = if dims == 1 | ||
2 .* (sum(Δ; dims=2) .* X .- Δ * Y) | ||
else | ||
2 .* (X .* sum(Δt; dims=1) .- Y * Δt) | ||
end | ||
Ȳ = if dims == 1 | ||
2 .* (sum(Δt; dims=2) .* Y .- Δt * X) | ||
else | ||
2 .* (Y .* sum(Δ; dims=1) .- X * Δ) | ||
end | ||
return CRC.NoTangent(), CRC.NoTangent(), X̄, Ȳ | ||
end | ||
|
||
return Ω, pairwise_SqEuclidean_X_Y_pullback | ||
end | ||
|
||
## Euclidean | ||
|
||
_normalize(x::Real, nrm::Real) = iszero(nrm) && !isnan(x) ? one(x / nrm) : x / nrm | ||
|
||
function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, dist::Euclidean, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) | ||
Ω = dist(x, y) | ||
|
||
function Euclidean_pullback(ΔΩ) | ||
x̄ = _normalize(CRC.unthunk(ΔΩ), Ω) .* (x .- y) | ||
return CRC.NoTangent(), x̄, -x̄ | ||
end | ||
|
||
return Ω, Euclidean_pullback | ||
end | ||
|
||
function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(colwise), dist::Euclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}) | ||
Ω = colwise(dist, X, Y) | ||
|
||
function colwise_Euclidean_pullback(ΔΩ) | ||
X̄ = _normalize.(CRC.unthunk(ΔΩ)', Ω') .* (X .- Y) | ||
return CRC.NoTangent(), CRC.NoTangent(), X̄, -X̄ | ||
end | ||
|
||
return Ω, colwise_Euclidean_pullback | ||
end | ||
|
||
function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) | ||
dims = Distances.deprecated_dims(dims) | ||
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) | ||
Ω = pairwise(dist, X; dims=dims) | ||
|
||
function pairwise_Euclidean_X_pullback(ΔΩ) | ||
Δ = CRC.unthunk(ΔΩ) | ||
A = _normalize.(Δ .+ transpose(Δ), Ω) | ||
X̄ = if dims == 1 | ||
sum(A; dims=2) .* X .- A * X | ||
else | ||
X .* sum(A; dims=1) .- X * A | ||
end | ||
return CRC.NoTangent(), CRC.NoTangent(), X̄ | ||
end | ||
|
||
return Ω, pairwise_Euclidean_X_pullback | ||
end | ||
|
||
function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) | ||
dims = Distances.deprecated_dims(dims) | ||
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) | ||
Ω = pairwise(dist, X, Y; dims=dims) | ||
|
||
function pairwise_Euclidean_X_Y_pullback(ΔΩ) | ||
Δ = _normalize.(CRC.unthunk(ΔΩ), Ω) | ||
Δt = transpose(Δ) | ||
X̄ = if dims == 1 | ||
sum(Δ; dims=2) .* X .- Δ * Y | ||
else | ||
X .* sum(Δt; dims=1) .- Y * Δt | ||
end | ||
Ȳ = if dims == 1 | ||
sum(Δt; dims=2) .* Y .- Δt * X | ||
else | ||
Y .* sum(Δ; dims=1) .- X * Δ | ||
end | ||
return CRC.NoTangent(), CRC.NoTangent(), X̄, Ȳ | ||
end | ||
|
||
return Ω, pairwise_Euclidean_X_Y_pullback | ||
end | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
using ChainRulesCore | ||
using ChainRulesTestUtils | ||
using StableRNGs | ||
|
||
@testset "ChainRulesCore extension" begin | ||
n = 4 | ||
rng = StableRNG(100) | ||
x = randn(rng, n) | ||
y = randn(rng, n) | ||
X = randn(rng, n, 3) | ||
Y = randn(rng, n, 3) | ||
Xrep = repeat(x, 1, 3) | ||
Yrep = repeat(y, 1, 3) | ||
|
||
@testset for metric in (SqEuclidean(), Euclidean()) | ||
# Single evaluation | ||
test_rrule(metric ⊢ NoTangent(), x, y) | ||
test_rrule(metric ⊢ NoTangent(), x, x) | ||
|
||
for A in (X, Xrep) | ||
# Column-wise distance | ||
test_rrule(colwise, metric ⊢ NoTangent(), A, A) | ||
|
||
# Pairwise distances | ||
# Finite differencing yields impressively inaccurate derivatives for `Euclidean`, | ||
# see https://github.com/FluxML/Zygote.jl/blob/45bf883491d2b52580d716d577e2fa8577a07230/test/gradcheck.jl#L1206 | ||
kwargs = metric isa Euclidean ? (rtol=1e-3, atol=1e-3) : () | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A; kwargs...) | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=1,), kwargs...) | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=2,), kwargs...) | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A, A; kwargs...) | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=1,), kwargs...) | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=2,), kwargs...) | ||
|
||
for B in (Y, Yrep) | ||
# Column-wise distance | ||
test_rrule(colwise, metric ⊢ NoTangent(), A, B) | ||
|
||
# Pairwise distances | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A, B) | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=1,)) | ||
test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=2,)) | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6d0110d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
6d0110d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/92860
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: