Sinkhorn Divergence #92

Closed
Changes from 18 commits
27 commits
30de49c
Initianning sikhorn divergence
davibarreira Jun 2, 2021
1a03325
Merge branch 'master' of https://github.com/JuliaOptimalTransport/Opt…
davibarreira Jun 2, 2021
4a1f380
Sinkhorn divergence implemented
davibarreira Jun 2, 2021
bdc1b5b
Added PyCall to test dependencies
davibarreira Jun 2, 2021
416dcb4
Added tests for sinkhorn divergence
davibarreira Jun 2, 2021
f593377
Added Sinkhorn Divergence to docs
davibarreira Jun 2, 2021
21d38a8
Creating FiniteDiscreteMeasure struct
davibarreira Jun 3, 2021
e17bba5
Modifications:
davibarreira Jun 3, 2021
10e8849
FixedDiscreteMeasure normalizes the weights to sum 1
davibarreira Jun 3, 2021
52b3c7a
FixedDiscreteMeasure checks if probabilities are positive
davibarreira Jun 3, 2021
7d2924d
Created tests for FiniteDiscreteMeasure
davibarreira Jun 3, 2021
7cf44a6
Added tests for sinkhorn divergence and finite discrete measure
davibarreira Jun 3, 2021
4764b00
Fixed the code for creating cost matrices in the sinkhorn_divergence
davibarreira Jun 3, 2021
98784c5
Added costmatrix.jl to tests
davibarreira Jun 3, 2021
1fb0fc1
Fixed docstring for costmatrix
davibarreira Jun 3, 2021
808d6ac
Fixed errors in the tests
davibarreira Jun 3, 2021
3415386
Minor fixes in the tests
davibarreira Jun 3, 2021
a7361ff
Fixed tests and merged
davibarreira Jun 13, 2021
5e448f9
Update Project.toml
davibarreira Jun 16, 2021
a3a6e9c
Update src/OptimalTransport.jl
davibarreira Jun 16, 2021
f10938a
Update test/runtests.jl
davibarreira Jun 16, 2021
957e7ed
Update src/entropic/sinkhorn.jl
davibarreira Jun 16, 2021
725bee7
Update src/entropic/sinkhorn.jl
davibarreira Jun 16, 2021
1768af8
removed incorrect file
davibarreira Jun 17, 2021
cc9bd7a
Simplified loop format in the test
davibarreira Jun 17, 2021
ad5ce9e
Adjusted function to the new Sinkhorn format
davibarreira Jun 30, 2021
e261025
Fixed format
davibarreira Jul 1, 2021
1 change: 1 addition & 0 deletions Project.toml
@@ -35,6 +35,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
davibarreira marked this conversation as resolved.

[targets]
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip", "HCubature"]
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -24,6 +24,7 @@ sinkhorn2
sinkhorn_stabilized_epsscaling
sinkhorn_stabilized
sinkhorn_barycenter
sinkhorn_divergence
```

## Unbalanced optimal transport
3 changes: 2 additions & 1 deletion src/OptimalTransport.jl
@@ -4,6 +4,7 @@

module OptimalTransport

using LinearAlgebra: AbstractMatrix
davibarreira marked this conversation as resolved.
using Distances
using LinearAlgebra
using IterativeSolvers, SparseArrays
@@ -14,7 +15,7 @@ using PDMats
using QuadGK
using StatsBase: StatsBase

export sinkhorn, sinkhorn2
export sinkhorn, sinkhorn2, sinkhorn_divergence
export emd, emd2
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
export sinkhorn_unbalanced, sinkhorn_unbalanced2
82 changes: 82 additions & 0 deletions src/entropic/sinkhorn.jl
@@ -504,3 +504,85 @@ function sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max
end
return u[:, 1] .* (K * v[:, 1])
end

"""
sinkhorn_divergence(
c,
μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
ν::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
Review comment (Member):
It's not really what we support, maybe just omit the type or add two docstrings (one basically referring to the other)?

ε; regularization=nothing, plan=nothing, kwargs...
Review comment (Member):
It would be useful if we would restructure sinkhorn and introduce the different algorithms, as proposed in the other PR. Then users could also choose the stabilized algorithm or epsilon scaling here.

davibarreira marked this conversation as resolved.
)

Compute the Sinkhorn Divergence between finite discrete
measures `μ` and `ν` with respect to a cost function `c`
and entropic regularization parameter `ε`.

A pre-computed optimal transport `plan` between `μ` and `ν` may be provided.

The Sinkhorn Divergence is computed as:
```math
\\operatorname{S}_{c,ε}(μ,ν) := \\operatorname{OT}_{c,ε}(μ,ν)
- \\frac{1}{2}(\\operatorname{OT}_{c,ε}(μ,μ) + \\operatorname{OT}_{c,ε}(ν,ν)),
```
where ``\\operatorname{OT}_{c,ε}(μ,ν)``, ``\\operatorname{OT}_{c,ε}(μ,μ)`` and
``\\operatorname{OT}_{c,ε}(ν,ν)`` are the entropically regularized optimal transport costs
between `(μ,ν)`, `(μ,μ)` and `(ν,ν)`, respectively.

The formulation of the Sinkhorn Divergence varies slightly depending on the paper consulted.
The Sinkhorn Divergence was initially proposed by [^GPC18]; this package uses the formulation given by
[^FeydyP19], which is also the one used in the Python Optimal Transport package.

[^GPC18]: Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences,
Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018

[^FeydyP19]: Jean Feydy, Thibault Séjourné, François-Xavier Vialard, Shun-ichi
Amari, Alain Trouvé, and Gabriel Peyré. Interpolating between optimal
transport and MMD using Sinkhorn divergences. In The 22nd International
Conference on Artificial Intelligence and Statistics, pages
2681–2690. PMLR, 2019.

See also: [`sinkhorn2`](@ref)
"""
function sinkhorn_divergence(
c, μ::T, ν::T, ε; regularization=nothing, plan=nothing, kwargs...
) where {T<:Union{FiniteDiscreteMeasure,DiscreteNonParametric}}
return sinkhorn_divergence(
pairwise(c, μ.support, ν.support),
pairwise(c, μ.support),
pairwise(c, ν.support),
μ,
ν,
ε;
regularization=regularization,
plan=plan,
kwargs...,
)
end

"""
sinkhorn_divergence(
Cμν, Cμμ, Cνν,
μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
ν::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
Review comment (Member):
Same here.

ε; regularization=nothing, plan=nothing, kwargs...
davibarreira marked this conversation as resolved.
)

Compute the Sinkhorn Divergence between finite discrete
measures `μ` and `ν` with respect to the precomputed cost matrices `Cμν`,
`Cμμ` and `Cνν`, and entropic regularization parameter `ε`.

A pre-computed optimal transport `plan` between `μ` and `ν` may be provided.

See also: [`sinkhorn2`](@ref)
"""
function sinkhorn_divergence(
Cμν, Cμ, Cν, μ::T, ν::T, ε; regularization=nothing, plan=nothing, kwargs...
) where {T<:Union{FiniteDiscreteMeasure,DiscreteNonParametric}}
if regularization !== nothing
@warn "`sinkhorn_divergence` does not support the `regularization` keyword argument"
end

OTμν = sinkhorn2(μ.p, ν.p, Cμν, ε; plan=plan, regularization=false, kwargs...)
OTμ = sinkhorn2(μ.p, μ.p, Cμ, ε; plan=nothing, regularization=false, kwargs...)
OTν = sinkhorn2(ν.p, ν.p, Cν, ε; plan=nothing, regularization=false, kwargs...)
return max(0, OTμν - (OTμ + OTν) / 2)
end
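
As a rough illustration of the formula in the docstring above, the divergence can be sketched with a plain Sinkhorn loop in Base Julia. This is not the package's implementation: `sinkhorn_cost`, `divergence_sketch` and `sqdist` are hypothetical helper names, the fixed iteration count is an assumption, and the returned cost is the transport cost ⟨γ, C⟩ without the entropic term, mirroring the `regularization=false` calls in the PR.

```julia
using LinearAlgebra: dot

# Transport cost ⟨γ, C⟩ of the entropically regularized plan, computed with
# plain (unstabilized) Sinkhorn iterations. Illustrative only.
function sinkhorn_cost(a, b, C, ε; iters=500)
    K = exp.(-C ./ ε)            # Gibbs kernel
    u = ones(length(a))
    v = ones(length(b))
    for _ in 1:iters
        v = b ./ (K' * u)        # scale to match the second marginal
        u = a ./ (K * v)         # scale to match the first marginal
    end
    γ = u .* K .* v'             # plan diag(u) * K * diag(v)
    return dot(γ, C)
end

# S(μ, ν) = OT(μ, ν) - (OT(μ, μ) + OT(ν, ν)) / 2, clamped at zero as in the PR.
function divergence_sketch(Cμν, Cμ, Cν, a, b, ε; kwargs...)
    OTμν = sinkhorn_cost(a, b, Cμν, ε; kwargs...)
    OTμ = sinkhorn_cost(a, a, Cμ, ε; kwargs...)
    OTν = sinkhorn_cost(b, b, Cν, ε; kwargs...)
    return max(0, OTμν - (OTμ + OTν) / 2)
end

# Toy one-dimensional supports with uniform weights (made-up data).
x = [0.0, 0.5, 1.0]
y = [0.2, 0.9]
a = fill(1 / 3, 3)
b = fill(1 / 2, 2)
sqdist(s, t) = [(si - tj)^2 for si in s, tj in t]

d = divergence_sketch(sqdist(x, y), sqdist(x, x), sqdist(y, y), a, b, 0.1)
d_self = divergence_sketch(sqdist(x, x), sqdist(x, x), sqdist(x, x), a, a, 0.1)
```

In this sketch the divergence of a measure with itself vanishes because all three terms are computed from identical inputs, which is the property the tests in this PR check with `sinkhorn_divergence(metric, μ, μ, ε) ≈ 0.0`.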
90 changes: 90 additions & 0 deletions test/entropic/sinkhorn.jl
@@ -4,6 +4,7 @@ using Distances
using ForwardDiff
using LogExpFunctions
using PythonOT: PythonOT
using Distributions

using LinearAlgebra
using Random
@@ -195,4 +196,93 @@ Random.seed!(100)
@test μ_interp ≈ μ_interp_pot rtol = 1e-6
end
end

@testset "sinkhorn divergence" begin
@testset "univariate examples" begin
# create distributions
n = 20
m = 10
μsupp = [rand(1) for i in 1:n]
νsupp = [rand(1) for i in 1:m]
μprobs = normalize!(rand(n), 1)
μ = OptimalTransport.discretemeasure(μsupp, μprobs)
ν = OptimalTransport.discretemeasure(νsupp)

for (ε, metrics) in Iterators.product(
[0.1, 1.0, 10.0],
[
(sqeuclidean, SqEuclidean()),
(euclidean, Euclidean()),
(totalvariation, TotalVariation()),
],
)
Review comment (Member):
This is a bit uncommon, usually one would just use two for loops here. One can even use the short notation

for x in xs, y in ys

for metric in metrics
@test sinkhorn_divergence(metric, μ, μ, ε) ≈ 0.0
@test sinkhorn_divergence(metric, ν, ν, ε) ≈ 0.0

sd_c = sinkhorn_divergence(metric, μ, ν, ε)

# calculating cost matrices to use in POT.sinkhorn2
Cμν = pairwise(metric, μ.support, ν.support)
Cμ = pairwise(metric, μ.support)
Cν = pairwise(metric, ν.support)

sd_C = sinkhorn_divergence(Cμν, Cμ, Cν, μ, ν, ε)

# empirical_sinkhorn_divergence raises an error if the weights are not all equal,
# so it is more reliable to compute the reference value with sinkhorn2
sd_pot =
POT.sinkhorn2(μ.p, ν.p, Cμν, ε) -
(POT.sinkhorn2(μ.p, μ.p, Cμ, ε) + POT.sinkhorn2(ν.p, ν.p, Cν, ε)) /
2

@test sd_c ≈ sd_pot[1]
@test sd_C ≈ sd_pot[1]
end
end
end
@testset "multivariate examples" begin
# create distributions
n = 20
m = 10
μsupp = [rand(3) for i in 1:n]
νsupp = [rand(3) for i in 1:m]
μprobs = normalize!(rand(n), 1)
μ = OptimalTransport.discretemeasure(μsupp, μprobs)
ν = OptimalTransport.discretemeasure(νsupp)

for (ε, metrics) in Iterators.product(
Review comment (Member):
Same here.

[0.1, 1.0, 10.0],
[
(sqeuclidean, SqEuclidean()),
(euclidean, Euclidean()),
(totalvariation, TotalVariation()),
],
)
for metric in metrics
@test sinkhorn_divergence(metric, μ, μ, ε) ≈ 0.0
@test sinkhorn_divergence(metric, ν, ν, ε) ≈ 0.0

sd_c = sinkhorn_divergence(metric, μ, ν, ε)

# calculating cost matrices to use in POT.sinkhorn2
Cμν = pairwise(metric, μ.support, ν.support)
Cμ = pairwise(metric, μ.support)
Cν = pairwise(metric, ν.support)

sd_C = sinkhorn_divergence(Cμν, Cμ, Cν, μ, ν, ε)

# empirical_sinkhorn_divergence raises an error if the weights are not all equal,
# so it is more reliable to compute the reference value with sinkhorn2
sd_pot =
POT.sinkhorn2(μ.p, ν.p, Cμν, ε) -
(POT.sinkhorn2(μ.p, μ.p, Cμ, ε) + POT.sinkhorn2(ν.p, ν.p, Cν, ε)) /
2

@test sd_c ≈ sd_pot[1]
@test sd_C ≈ sd_pot[1]
end
end
end
end
end
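
The reviewer's suggestion to replace `Iterators.product` with the short nested-loop notation could look roughly like the sketch below. Symbols stand in for the actual `Distances` functions and metric objects so that the snippet runs without extra packages; `εs`, `metric_pairs` and the two `visited_*` collections are illustrative names, not part of the PR.

```julia
# Stand-ins for the regularization values and (function, metric) pairs in the tests.
εs = [0.1, 1.0, 10.0]
metric_pairs = [(:sqeuclidean, :SqEuclidean), (:euclidean, :Euclidean)]

# Pattern used in the PR: a product iterator plus an inner loop over each pair.
visited_product = Tuple{Float64,Symbol}[]
for (ε, metrics) in Iterators.product(εs, metric_pairs)
    for metric in metrics
        push!(visited_product, (ε, metric))
    end
end

# Short notation: one `for` with comma-separated ranges; the leftmost range is
# the outermost loop, and later ranges may depend on earlier loop variables.
visited_nested = Tuple{Float64,Symbol}[]
for ε in εs, metrics in metric_pairs, metric in metrics
    push!(visited_nested, (ε, metric))
end
```

Both loops visit the same combinations, although not necessarily in the same order, since `Iterators.product` varies its first argument fastest.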
55 changes: 55 additions & 0 deletions test/finitediscretemeasure.jl
@@ -0,0 +1,55 @@
using Distributions: DiscreteNonParametric
Review comment (Member):
This file seems unrelated to the PR?

using OptimalTransport
using Distributions
using Random

Random.seed!(100)

@testset "finitediscretemeasure.jl" begin
@testset "Univariate Finite Discrete Measure" begin
n = 100
μsupp = rand(n)
νsupp = rand(n, 1)
μ = FiniteDiscreteMeasure(μsupp)
ν = FiniteDiscreteMeasure(νsupp, rand(n))
# check if it assigns equal probabilities to all entries
@test μ.p ≈ ones(n) ./ n
@test probs(μ) ≈ ones(n) ./ n
# check if probabilities sum to 1
@test sum(ν.p) ≈ 1
@test sum(probs(ν)) ≈ 1
# check if probabilities are all positive (non-negative)
@test all(ν.p .>= 0)
@test all(probs(ν) .>= 0)
# check if it assigns to DiscreteNonParametric when Vector/Matrix is 1D
@test typeof(μ) <: DiscreteNonParametric
@test typeof(ν) <: DiscreteNonParametric
# check if support is correctly assigned
@test sort(μsupp) == μ.support
@test sort(μsupp) == support(μ)
@test sort(vec(νsupp)) == ν.support
@test sort(vec(νsupp)) == support(ν)
end
@testset "Multivariate Finite Discrete Measure" begin
n = 10
m = 3
μsupp = rand(n, m)
νsupp = rand(n, m)
μ = FiniteDiscreteMeasure(μsupp)
ν = FiniteDiscreteMeasure(νsupp, rand(n))
# check if it assigns equal probabilities to all entries
@test μ.p ≈ ones(n) ./ n
@test probs(μ) ≈ ones(n) ./ n
# check if probabilities sum to 1
@test sum(ν.p) ≈ 1
@test sum(probs(ν)) ≈ 1
# check if probabilities are all positive (non-negative)
@test all(ν.p .>= 0)
@test all(probs(ν) .>= 0)
# check if support is correctly assigned
@test μsupp == μ.support
@test μsupp == support(μ)
@test νsupp == ν.support
@test νsupp == support(ν)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,3 +1,4 @@
using LinearAlgebra: symmetric
davibarreira marked this conversation as resolved.
using OptimalTransport
using Pkg: Pkg
using SafeTestsets