Skip to content

Commit

Permalink
Added SkipMissing wrapper metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
rofinn committed Dec 29, 2023
1 parent 886ad02 commit 0ce9761
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.10.11"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
Expand All @@ -19,6 +20,7 @@ DistancesSparseArraysExt = "SparseArrays"
[compat]
ChainRulesCore = "1"
LinearAlgebra = "<0.0.1, 1"
Missings = "1"
SparseArrays = "<0.0.1, 1"
Statistics = "<0.0.1, 1"
StatsAPI = "1"
Expand Down
2 changes: 2 additions & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Distances

using LinearAlgebra
using Missings
using Statistics: mean
import StatsAPI: pairwise, pairwise!

Expand Down Expand Up @@ -116,6 +117,7 @@ include("haversine.jl")
include("mahalanobis.jl")
include("bhattacharyya.jl")
include("bregman.jl")
include("missing.jl")

include("deprecated.jl")

Expand Down
47 changes: 47 additions & 0 deletions src/missing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
    SkipMissing(d::PreMetric)

Wrapper metric that excludes any index holding a `missing` value from the evaluation of
the wrapped distance `d`. Masking is applied when the wrapper is called on two vectors;
any other pair of arguments is passed to `d` unchanged.
"""
struct SkipMissing{D<:PreMetric} <: PreMetric
    d::D
end

# The result type of the wrapper is whatever the wrapped metric reports for the same types.
function result_type(dist::SkipMissing, a::Type, b::Type)
    return result_type(dist.d, a, b)
end

# Generic fallback: hand the arguments straight to the wrapped metric
function (dist::SkipMissing)(a, b)
    return dist.d(a, b)
end

# Vector arguments are special cased so that incomplete cases can be masked out:
# any index where either `a` or `b` is `missing` is dropped before evaluating the metric.
function (dist::SkipMissing)(a::AbstractVector, b::AbstractVector)
    require_one_based_indexing(a)
    require_one_based_indexing(b)
    n = length(a)
    length(b) == n || throw(DimensionMismatch("a and b have different lengths"))

    # keep[i] stays true only when both a[i] and b[i] are present
    keep = trues(n)
    @inbounds for i in 1:n
        if ismissing(a[i]) || ismissing(b[i])
            keep[i] = false
        end
    end

    # NOTE: We call disallowmissing to avoid downstream type promotion issues.
    a_kept = disallowmissing(view(a, keep))
    b_kept = disallowmissing(view(b, keep))

    # Metrics outside of `UnionMetrics` are simply called on the masked inputs
    dist.d isa UnionMetrics || return dist.d(a_kept, b_kept)

    # Calling `_evaluate` allows us to also mask metric parameters like weights or periods.
    # I don't think this can be generalized to user defined metric types though without
    # knowing what the parameters mean.
    params = parameters(dist.d)
    masked_params = isnothing(params) ? params : view(params, keep)
    return _evaluate(dist.d, a_kept, b_kept, masked_params)
end

# Convenience function: wrap `dist` in a `SkipMissing` and evaluate it on `args` in one call.
# NOTE(review): the name is not qualified with `Base.`, so this presumably defines a
# module-local `skipmissing` rather than extending `Base.skipmissing` — confirm intended.
function skipmissing(dist::PreMetric, args...)
    wrapped = SkipMissing(dist)
    return wrapped(args...)
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Distances

using Test
using LinearAlgebra
using Missings
using OffsetArrays
using Random
using Statistics
Expand All @@ -15,7 +16,7 @@ include("test_dists.jl")
# Test ChainRules definitions on Julia versions that support weak dependencies
# Support for extensions was added in
# https://github.com/JuliaLang/julia/commit/93587d7c1015efcd4c5184e9c42684382f1f9ab2
# https://github.com/JuliaLang/julia/pull/47695
# https://github.com/JuliaLang/julia/pull/47695
VERSION >= v"1.9.0-alpha1.18" && include("chainrules.jl")
120 changes: 119 additions & 1 deletion test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ function test_colwise(dist, x, y, T)
# ≈ and all( .≈ ) seem to behave slightly differently for F64
@test all(colwise(dist, x, y) .≈ r1)
@test all(colwise(dist, (x[:,i] for i in axes(x, 2)), (y[:,i] for i in axes(y, 2))) .≈ r1)

@test colwise!(dist, r4, x, y) @test_deprecated(colwise!(r5, dist, x, y))
@test r4 r5

Expand Down Expand Up @@ -1051,6 +1051,124 @@ end
end
end

# Exercise `Distances.SkipMissing` against a sample of metrics. For each metric we first pin
# down how the bare metric behaves on empty input and on input containing `missing`, and
# then check that the wrapped metric only uses the complete cases.
@testset "skip missing" begin
    x = Float64[]
    # Only indices 1 and 5 are non-missing in both `a` and `b`
    a = [1, missing, 3, missing, 5]
    b = [6, missing, missing, 9, 10]
    # Columns 1 (of A) and 2 (of B) get a missing entry; columns 3 and 4 stay complete
    A = allowmissing(reshape(1:20, 5, 4))
    B = allowmissing(reshape(21:40, 5, 4))
    A[3, 1] = missing
    B[4, 2] = missing
    w = collect(0.2:0.2:1.0)

    # Sampling of different types of distance calculations to check against
    dists = [
        SqEuclidean(),
        Euclidean(),
        Cityblock(),
        TotalVariation(),
        Chebyshev(),
        Minkowski(2.5),
        Hamming(),
        Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x),
        CosineDist(),
        CorrDist(),
        ChiSqDist(),
        KLDivergence(),
        GenKLDivergence(),
        RenyiDivergence(0.0),
        JSDivergence(),
        SpanNormDist(),
        BhattacharyyaDist(),
        HellingerDist(),
        BrayCurtis(),
        Jaccard(),
        WeightedEuclidean(w),
        WeightedSqEuclidean(w),
        WeightedCityblock(w),
        WeightedMinkowski(w, 2.5),
        WeightedHamming(w),
        MeanAbsDeviation(),
        MeanSqDeviation(),
        RMSDeviation(),
        NormRMSDeviation(),
    ]

    # NOTE: Most of this is special casing the failure conditions
    # for using `missing` in the base metrics
    @testset "Distance $dist" for dist in dists
        D = Distances.SkipMissing(dist)

        # Baseline that our wrapped metric has the same empty case
        # behaviour
        if dist isa NormRMSDeviation
            @test_throws ArgumentError dist(x, x)
            @test_throws ArgumentError D(x, x)
        elseif nameof(typeof(dist)) in nameof.(Distances.weightedmetrics)
            @test_throws DimensionMismatch dist(x, x)
            @test_throws BoundsError D(x, x) # Could choose to special case this?
        elseif dist isa Union{
            BhattacharyyaDist,
            HellingerDist,
            MeanAbsDeviation,
            MeanSqDeviation,
            RMSDeviation,
        }
            @test isnan(dist(x, x))
            @test isnan(D(x, x))
        else
            @test D(x, x) == dist(x, x)
        end

        # Cover existing failure cases with missings
        # TODO: Simplify this with an error variable
        if dist isa Union{
            Hamming,
            WeightedHamming,
            ChiSqDist,
            KLDivergence,
            GenKLDivergence,
            SpanNormDist,
        }
            @test_throws TypeError dist(a, b)
            @test_throws TypeError colwise(dist, A, B)
            @test_throws TypeError pairwise(dist, A, B, dims=2)
            @test_throws TypeError pairwise(dist, A, dims=2)
        elseif dist isa JSDivergence
            @test_throws TypeError dist(a, b)
            @test_throws MethodError colwise(dist, A, B)
            @test_throws MethodError pairwise(dist, A, B, dims=2)
            @test_throws MethodError pairwise(dist, A, dims=2)
        elseif dist isa Bregman
            @test_throws ArgumentError dist(a, b)
            @test_throws ArgumentError colwise(dist, A, B)
            @test_throws ArgumentError pairwise(dist, A, B, dims=2)
            @test_throws ArgumentError pairwise(dist, A, dims=2)
        elseif dist isa CorrDist
            @test_throws UndefVarError dist(a, b)
            @test_throws UndefVarError colwise(dist, A, B)
            @test_throws MethodError pairwise(dist, A, B, dims=2)
            @test_throws MethodError pairwise(dist, A, dims=2)
        elseif dist isa RenyiDivergence
            @test_throws MethodError dist(a, b)
            @test_throws MethodError colwise(dist, A, B)
            @test_throws MethodError pairwise(dist, A, B, dims=2)
            @test_throws MethodError pairwise(dist, A, dims=2)

            # Doesn't handle eltype Union{T, Missing}
            @test_throws MethodError colwise(dist, A[:, 3:4], B[:, 3:4])
        else
            @test ismissing(dist(a, b))
            @test_throws MethodError colwise(dist, A, B)
            @test_throws MethodError pairwise(dist, A, B, dims=2)
            @test_throws MethodError pairwise(dist, A, dims=2)
        end

        # Handle weights: the parameters must be masked with the same indices as the data
        if dist isa Distances.UnionMetrics && Distances.parameters(dist) isa Vector
            @test D(a, b) == Distances._evaluate(
                dist, [1, 5], [6, 10], Distances.parameters(dist)[[1, 5]]
            )
        else
            @test D(a, b) == dist([1, 5], [6, 10])
        end

        # On the complete columns the wrapped metric must agree with the bare metric
        @test colwise(D, A, B)[3:4] ≈
            colwise(dist, disallowmissing(A[:, 3:4]), disallowmissing(B[:, 3:4]))

        M = pairwise(D, A, B, dims=2)
        @test size(M) == (4, 4)
        @test !any(ismissing, M)

        M = pairwise(D, A, dims=2)
        @test size(M) == (4, 4)
        @test !any(ismissing, M)
    end
end
#=
@testset "zero allocation colwise!" begin
d = Euclidean()
Expand Down

0 comments on commit 0ce9761

Please sign in to comment.