Sinkhorn Divergence #92
Conversation
…imalTransport.jl into sinkhorndivergence
Pull Request Test Coverage Report for Build 899735377

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

💛 - Coveralls
It would be good to add it to PythonOT.
I'll see how this can be done.
Submitted a PR to the PythonOT repository.
src/OptimalTransport.jl
Outdated

The Sinkhorn Divergence is computed as:
```math
S_{c,ε}(μ,ν) := OT_{c,ε}(μ,ν) - \\frac{1}{2}(OT_{c,ε}(μ,μ) + OT_{c,ε}(ν,ν)),
```
Could we use `\\operatorname{OT}` here instead?
Same for `\\operatorname{S}`.
Yeah, better to add the `\\operatorname`. Will do.
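For reference, a sketch of the formula with the suggested change applied (inside the Julia docstring the backslashes would be escaped, i.e. `\\operatorname`):

```math
\operatorname{S}_{c,ε}(μ,ν) := \operatorname{OT}_{c,ε}(μ,ν) - \frac{1}{2}\bigl(\operatorname{OT}_{c,ε}(μ,μ) + \operatorname{OT}_{c,ε}(ν,ν)\bigr)
```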
src/OptimalTransport.jl
Outdated

```julia
        regularization=regularization,
        kwargs...,
    )
    return max(0, OTμν - 0.5 * (OTμμ + OTνν))
```
Is there a reason we take `max(0, ...)` here?
To guarantee that the value is greater than or equal to 0. The same is also present in `POT.py`.
Makes sense. It doesn't come from the 'math' though, I assume any negative values would be due to numerical issues (i.e. any negative values would be very small?)
Correct. In theory, this is a "proper" metric (symmetric, zero iff x = y, triangle inequality), so it would never be negative.
src/OptimalTransport.jl
Outdated

```julia
See also: [`sinkhorn2`](@ref)
"""
function sinkhorn_divergence(
    c,
```
Could we stick to the calling convention of `sinkhorn` here, i.e. `(\mu, \nu, C, \varepsilon)` in that order?
I thought we were going to revert all functions to `(c, \mu, \nu, ...)`, like the implementations of `ot_cost` and `ot_plan`. Also, I think we should use lowercase `c` for the cost function, and uppercase `C` for the cost matrix.
Ah, I see. Makes sense. Which PR is this from? Sorry, I've not kept up with all of them :P
Not your fault at all! I thought we had settled on it in issue #63, but now I see it's not there. So probably @devmotion and I talked about it in one of the Multivariate Normals or 1D implementations (which have hundreds of comments). Things are moving so fast in this package now that I actually don't know where it is.

I'll comment in issue #63, so we can make a decision. If I remember correctly, it was actually a suggestion of @devmotion, based on the fact that one usually declares arguments that are functions at the beginning (I'm guessing this is a Julia standard or something).
Yes, this is a standard and is mentioned in the Julia documentation about the ordering of function arguments, since it allows one to use the `do` syntax

```julia
map(rand(20)) do x
    return x^2
end
```

instead of

```julia
map(x -> x^2, rand(20))
```

which is very convenient for longer and more complicated functions.
src/OptimalTransport.jl
Outdated

```julia
See also: [`sinkhorn2`](@ref)
"""
function sinkhorn_divergence(
    c,
```
Could we implement this with the cost matrices as inputs instead? I think this would have multiple advantages:

- it would allow reusing precomputed cost matrices, here and in other functions
- it would be guaranteed to work on the GPU (BTW, GPU tests should be added if possible), whereas the default implementation of `pairwise` probably doesn't for most distances and custom functions (IIRC it uses incorrect containers and indexing, which should be avoided on GPUs and is often disabled by users)
Additionally, this would still allow defining the function as implemented here by just forwarding it to the one with cost matrices that are computed from the `DiscreteNonParametric` and the cost function.
BTW, it also seems quite restrictive to only allow univariate measures here.
Yeah, the univariate restriction was a mistake. I forgot that `DiscreteNonParametric` is only for the univariate case. I could implement it with cost matrices, but it would actually require 3 matrices, so it would be quite "ugly". It would be something like `sinkhorn_divergence(mu, nu, C1, C2, C3, eps)`. So I think we should stick with the cost function... I don't know much about programming for GPUs, but don't you think we could adapt this somehow?
I could implement another `sinkhorn_divergence` that would take the cost matrices as arguments, and we could perhaps indicate that such a version should be used in case the user is interested in using the GPU.
You can just define both and forward the call from the one with the cost function to the other:

> Additionally, this would still allow to define the function implemented by just forwarding it to the one with cost matrices that are computed from the DiscreteNonParametric and the cost function.
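A minimal sketch of that forwarding, assuming a hypothetical cost-matrix method `sinkhorn_divergence(μprobs, νprobs, Cμν, Cμ, Cν, ε)`; the `pairwise` calls mirror the ones already appearing in this PR's diffs:

```julia
# Hypothetical forwarding method: compute the three cost matrices once and
# delegate to the (assumed) cost-matrix method.
function sinkhorn_divergence(c, μ, ν, ε; kwargs...)
    Cμν = pairwise(c, μ.support, ν.support)            # cost between supports
    Cμ = pairwise(c, μ.support; symmetric=true)        # self-cost of μ
    Cν = pairwise(c, ν.support; symmetric=true)        # self-cost of ν
    return sinkhorn_divergence(μ.p, ν.p, Cμν, Cμ, Cν, ε; kwargs...)
end
```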
> such version should be used in case the user is interested in using GPU.
It could be useful without GPUs as well, e.g., if you evaluate `sinkhorn`, `sinkhorn2`, or other regularized OT distances with the same cost matrices.

In fact, it could also be useful to allow specifying pre-computed plans for the different `sinkhorn2` calls.
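For instance, a small sketch of the reuse idea (assuming the hypothetical cost-matrix methods, with `c`, `μ`, `ν`, `ε`, `pairwise`, and `sinkhorn2` as used elsewhere in this PR):

```julia
# compute the cost matrix between the two supports once...
C = pairwise(c, μ.support, ν.support)
# ...and reuse it across several regularized OT evaluations
cost = sinkhorn2(μ.p, ν.p, C, ε)
# the same C (plus the two symmetric self-cost matrices) could then be passed
# to a cost-matrix sinkhorn_divergence method instead of being recomputed
```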
Convinced.
src/OptimalTransport.jl
Outdated

```julia
    μ::DiscreteNonParametric,
    ν::DiscreteNonParametric,
    ε;
    regularization=false,
```
This can be removed since it's the default in `sinkhorn2`:

```diff
- regularization=false,
```
src/OptimalTransport.jl
Outdated

```julia
    ν.p,
    pairwise(c, μ.support, ν.support),
    ε;
    regularization=regularization,
```
```diff
- regularization=regularization,
```
src/OptimalTransport.jl
Outdated

```julia
    μ.p,
    pairwise(c, μ.support; symmetric=true),
    ε;
    regularization=regularization,
```
```diff
- regularization=regularization,
```
src/OptimalTransport.jl
Outdated

```julia
    ν.p,
    pairwise(c, ν.support; symmetric=true),
    ε;
    regularization=regularization,
```
```diff
- regularization=regularization,
```
src/OptimalTransport.jl
Outdated

```julia
        regularization=regularization,
        kwargs...,
    )
    return max(0, OTμν - 0.5 * (OTμμ + OTνν))
```
We should avoid unwanted promotions (e.g. when working with `Float32` on GPUs):

```diff
- return max(0, OTμν - 0.5 * (OTμμ + OTνν))
+ return max(0, OTμν - (OTμμ + OTνν) / 2)
```
Intuitively I would have thought that the divergence can be negative if the regularization terms are included (BTW does one actually want to include them)?
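To illustrate the promotion point above (a quick REPL check, not from the PR itself): the `Float64` literal `0.5` promotes a `Float32` result, while dividing by an `Int` does not:

```julia
julia> typeof(0.5 * 1.0f0)  # the Float64 literal promotes the Float32 value
Float64

julia> typeof(1.0f0 / 2)    # dividing by an Int keeps Float32
Float32
```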
Yeah, if the regularization terms are added, it could be negative, so you are correct that we should not allow the `regularization` keyword. For the Sinkhorn Divergence, one should probably never include the regularization terms, because the goal here is to turn the entropic OT cost into a "proper" metric on the space of probability distributions. Besides, Feydy (in the paper I cited) proved some neat properties of this metric, such as the fact that it also metrizes weak convergence.

TL;DR: I don't think one would ever want to include the regularization terms in the Sinkhorn Divergence.

For `sinkhorn2`, I saw one paper where the author suggested that the cost plus the regularization term behaves better than when one removes it. It can be negative, though, so it loses some of the interpretability. I usually refer to it as "Sinkhorn loss" when I add the regularization, and "Sinkhorn cost" or "Sinkhorn distance" when I remove it. But this is personal, and not generally adopted in the literature (though it probably should be, because people use the terms interchangeably, and it becomes very confusing).
Thanks for the detailed explanation! This confirms my intuition, so let's just remove the `regularization` keyword argument. Maybe then we should even check that `kwargs...` does not contain a keyword argument `regularization`, and otherwise either throw an error or (maybe better) drop the `regularization` keyword argument and display a `@warn`ing. Maybe something like

```julia
....; regularization=nothing, kwargs...)
    if regularization !== nothing
        @warn "`sinkhorn_divergence` does not support the `regularization` keyword argument"
    end
```
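A more complete sketch of that guard in context, assembled from the pieces shown in this PR (the signature and the `sinkhorn2`/`pairwise` calls are assumptions pieced together from the diffs, not the final implementation):

```julia
function sinkhorn_divergence(c, μ, ν, ε; regularization=nothing, kwargs...)
    if regularization !== nothing
        @warn "`sinkhorn_divergence` does not support the `regularization` keyword argument"
    end
    # the three entropic OT costs from the definition above
    OTμν = sinkhorn2(μ.p, ν.p, pairwise(c, μ.support, ν.support), ε; kwargs...)
    OTμμ = sinkhorn2(μ.p, μ.p, pairwise(c, μ.support; symmetric=true), ε; kwargs...)
    OTνν = sinkhorn2(ν.p, ν.p, pairwise(c, ν.support; symmetric=true), ε; kwargs...)
    # clamp tiny negative values caused by numerical inaccuracies
    return max(0, OTμν - (OTμμ + OTνν) / 2)
end
```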
test/entropic.jl
Outdated

```julia
@testset "example" begin
    # create distributions
    N = 100
    μ = DiscreteNonParametric(rand(N), ones(N) / N)
```
Can we use random histograms?
Yeah. I'll update.
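For instance, random histograms could be generated along these lines (a sketch that just normalizes random weights; not the final test code):

```julia
using Distributions: DiscreteNonParametric

N = 100
μprobs = rand(N)
μprobs ./= sum(μprobs)  # normalize so the weights form a probability vector
μ = DiscreteNonParametric(rand(N), μprobs)
```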
- Created the struct FiniteDiscreteMeasure,
- Implemented two versions of sinkhorn_divergence,
- Disabled the use of regularization on sinkhorn_divergence,
- Fixed docstring with suggestions.
So, this current push to the PR is actually too much code for a single PR. I intend to break it down into other PRs, but I thought it would be good to submit the whole thing here so you guys can take a look at how these new functions tie into the rest of the package.

@devmotion and @zsteve, what do you think of these two new additions? Should I submit a PR for each one? I already wrote the tests for all these functions, so if you guys prefer, we can actually just review everything in this PR.
I think it's a good idea to move these two additions to separate PRs; then it's easier to discuss and polish them. I also think that probably a … I am a bit less convinced by …
The error in the check seems to be unrelated to the PR.
src/entropic/sinkhorn.jl
Outdated

```julia
    μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
    ν::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
```
It's not really what we support, maybe just omit the type or add two docstrings (one basically referring to the other)?
src/entropic/sinkhorn.jl
Outdated

```julia
    c,
    μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
    ν::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
    ε; regularization=false, plan=nothing, kwargs...
```
It would be useful if we restructured `sinkhorn` and introduced the different algorithms, as proposed in the other PR. Then users could also choose the stabilized algorithm or epsilon scaling here.
src/entropic/sinkhorn.jl
Outdated

```julia
    μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
    ν::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
```
Same here.
test/entropic/sinkhorn.jl
Outdated

```julia
for (ε, metrics) in Iterators.product(
    [0.1, 1.0, 10.0],
    [
        (sqeuclidean, SqEuclidean()),
        (euclidean, Euclidean()),
        (totalvariation, TotalVariation()),
    ],
)
```
This is a bit uncommon; usually one would just use two `for` loops here. One can even use the short notation `for x in xs, y in ys`, as in the sketch below.
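Applied to the loop in the diff above, the short notation would read something like this (a sketch; the test body is omitted):

```julia
using Distances: sqeuclidean, euclidean, totalvariation,
    SqEuclidean, Euclidean, TotalVariation

for ε in [0.1, 1.0, 10.0],
    metrics in [
        (sqeuclidean, SqEuclidean()),
        (euclidean, Euclidean()),
        (totalvariation, TotalVariation()),
    ]

    # ... same test body as with Iterators.product
end
```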
test/entropic/sinkhorn.jl
Outdated

```julia
μ = OptimalTransport.discretemeasure(μsupp, μprobs)
ν = OptimalTransport.discretemeasure(νsupp)

for (ε, metrics) in Iterators.product(
```
Same here.
test/finitediscretemeasure.jl
Outdated

```diff
@@ -0,0 +1,55 @@
+ using Distributions: DiscreteNonParametric
```
This file seems unrelated to the PR?
Co-authored-by: David Widmann <[email protected]>
Took a break from the computer for a couple of days. Made the requested changes.
Codecov Report

```
@@            Coverage Diff             @@
##           master      #92      +/-   ##
==========================================
- Coverage   96.25%   96.14%   -0.12%
==========================================
  Files          11       12       +1
  Lines         561      570       +9
==========================================
+ Hits          540      548       +8
- Misses         21       22       +1
```

Continue to review full report at Codecov.
I think this one is finally ready to go.

@devmotion or @zsteve, whenever you guys can, please review this PR.
Implemented the function `sinkhorn_divergence`, which is an unbiased version of `sinkhorn` and is also a metric on the space of probability distributions. This function is similar to `ot.bregman.empirical_sinkhorn_divergence` from `POT.py`.

The tests required the use of `PyCall`, because `ot.bregman.empirical_sinkhorn_divergence` is not supported in `PythonOT.jl`.
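A sketch of how such a `PyCall`-based comparison might look (an assumption, not the actual test code: `μsupp`, `νsupp` are hypothetical support matrices with one sample per row, `μprobs`, `νprobs` hypothetical weight vectors, and the call follows POT's documented signature `empirical_sinkhorn_divergence(X_s, X_t, reg, a=..., b=...)`):

```julia
using PyCall

pot = pyimport("ot")
ε = 0.01
# POT expects sample matrices (n_samples × dim) and weight vectors
s_pot = pot.bregman.empirical_sinkhorn_divergence(μsupp, νsupp, ε; a=μprobs, b=νprobs)
```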