Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add davibarreira's sinkhorn_divergence with some modifications #145

Merged
merged 28 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ecb81e1
add davibarreira's sinkhorn_divergence with some modifications
zsteve Sep 13, 2021
5e981e7
added documentation entry for sinkhorn_divergence
zsteve Sep 13, 2021
9828028
add statsbase to test deps
zsteve Sep 13, 2021
5c308e9
add empirical measure example for sinkhorn divergence
zsteve Sep 13, 2021
43d7a1d
format
zsteve Sep 13, 2021
e3fa3d9
add literate
zsteve Sep 13, 2021
5c347b1
implement symmetric sinkhorn
zsteve Sep 14, 2021
7f737ee
implement sinkhorn_loss
zsteve Sep 14, 2021
44659ae
change formula for obj()
zsteve Sep 14, 2021
a0dc59a
make empirical example run faster
zsteve Sep 14, 2021
48baadc
update docstrings
zsteve Sep 14, 2021
2f7179e
Update src/entropic/sinkhorn.jl
zsteve Sep 14, 2021
94055c2
Update src/entropic/sinkhorn.jl
zsteve Sep 14, 2021
a43ef32
address review comments
zsteve Sep 18, 2021
fefc503
Merge branch 'sinkhorn_divergence' of https://github.com/JuliaOptimal…
zsteve Sep 18, 2021
c8c8e9c
fix naming of plan and sinkhorn_plan
zsteve Sep 18, 2021
5cb77be
address comments
zsteve Sep 19, 2021
1026a44
fix sinkhorn_divergence and docs
zsteve Sep 19, 2021
edabeee
update docs
zsteve Sep 19, 2021
5b19d3b
format
zsteve Sep 19, 2021
2df1f54
remove sinkhorn_loss
zsteve Sep 19, 2021
440c191
Update examples/empirical_sinkhorn_div/script.jl
zsteve Sep 20, 2021
e214a68
Update src/entropic/sinkhorn_divergence.jl
zsteve Sep 20, 2021
5e02240
Update src/entropic/symmetric.jl
zsteve Sep 20, 2021
80dc5fe
Update src/entropic/sinkhorn_gibbs.jl
zsteve Sep 20, 2021
fb0277b
Update src/entropic/sinkhorn_gibbs.jl
zsteve Sep 20, 2021
ed1b947
Update src/entropic/symmetric.jl
zsteve Sep 20, 2021
f4d6474
bump version
zsteve Sep 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
export SinkhornBarycenterGibbs
export QuadraticOTNewton

export sinkhorn, sinkhorn2
export sinkhorn, sinkhorn2, sinkhorn_loss
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
export sinkhorn_unbalanced, sinkhorn_unbalanced2
export sinkhorn_divergence
Expand Down
36 changes: 25 additions & 11 deletions src/entropic/sinkhorn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ function sinkhorn2(μ, ν, C, ε, alg::Sinkhorn; regularization=false, plan=noth
return cost
end

# Fallback method for `sinkhorn_loss`: a specialized implementation exists only
# for `SinkhornGibbs` (and `SymmetricSinkhornGibbs`); every other `Sinkhorn`
# algorithm raises an informative error directing callers to `sinkhorn2`.
function sinkhorn_loss(μ, ν, C, ε, alg::Sinkhorn; kwargs...)
    msg = "sinkhorn_loss is only implemented for alg::SinkhornGibbs. For other algorithms please use sinkhorn2"
    return error(msg)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just integrate this in sinkhorn2? Or rename sinkhorn2 to sinkhorn_loss?

"""
sinkhorn_divergence(μ::AbstractVecOrMat, ν::AbstractVecOrMat, C, ε, alg::Sinkhorn = SinkhornGibbs(); regularization = nothing, plan = nothing, kwargs...)
Compute the Sinkhorn Divergence between finite discrete
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Compute the Sinkhorn Divergence between finite discrete
Compute the Sinkhorn Divergence between finite discrete

Expand Down Expand Up @@ -251,13 +257,19 @@ function sinkhorn_divergence(
alg::Sinkhorn=SinkhornGibbs(),
algμ::Sinkhorn=SymmetricSinkhornGibbs(),
algν::Sinkhorn=SymmetricSinkhornGibbs();
regularization=nothing,
regularization=true,
plan=nothing,
kwargs...,
)
OTμν = sinkhorn2(μ, ν, C, ε, alg; plan=plan, regularization=false, kwargs...)
OTμ = sinkhorn2(μ, C, ε, algμ; plan=nothing, regularization=false, kwargs...)
OTν = sinkhorn2(ν, C, ε, algν; plan=nothing, regularization=false, kwargs...)
OTμν, OTμ, OTν = if (regularization == true) && (plan === nothing)
sinkhorn_loss(μ, ν, C, ε, alg; kwargs...),
sinkhorn_loss(μ, C, ε, algμ; kwargs...),
sinkhorn_loss(ν, C, ε, algν; kwargs...)
else
sinkhorn2(μ, ν, C, ε, alg; plan=plan, regularization=false, kwargs...),
sinkhorn2(μ, C, ε, algμ; plan=nothing, regularization=false, kwargs...),
sinkhorn2(ν, C, ε, algν; plan=nothing, regularization=false, kwargs...)
end
zsteve marked this conversation as resolved.
Show resolved Hide resolved
return max.(0, OTμν .- (OTμ .+ OTν) / 2)
zsteve marked this conversation as resolved.
Show resolved Hide resolved
end
"""
Expand Down Expand Up @@ -290,16 +302,18 @@ function sinkhorn_divergence(
alg::Sinkhorn=SinkhornGibbs(),
algμ::Sinkhorn=SymmetricSinkhornGibbs(),
algν::Sinkhorn=SymmetricSinkhornGibbs();
regularization=nothing,
regularization=true,
zsteve marked this conversation as resolved.
Show resolved Hide resolved
plan=nothing,
kwargs...,
)
if regularization !== nothing
@warn "`sinkhorn_divergence` does not support the `regularization` keyword argument"
OTμν, OTμ, OTν = if (regularization == true) && (plan === nothing)
sinkhorn_loss(μ, ν, Cμν, ε, alg; kwargs...),
sinkhorn_loss(μ, Cμ, ε, algμ; kwargs...),
sinkhorn_loss(ν, Cν, ε, algν; kwargs...)
else
sinkhorn2(μ, ν, Cμν, ε, alg; plan=plan, regularization=false, kwargs...),
sinkhorn2(μ, Cμ, ε, algμ; plan=nothing, regularization=false, kwargs...),
sinkhorn2(ν, Cν, ε, algν; plan=nothing, regularization=false, kwargs...)
end

OTμν = sinkhorn2(μ, ν, Cμν, ε, alg; plan=plan, regularization=false, kwargs...)
OTμ = sinkhorn2(μ, Cμ, ε, algμ; plan=nothing, regularization=false, kwargs...)
OTν = sinkhorn2(ν, Cν, ε, algν; plan=nothing, regularization=false, kwargs...)
return max.(0, OTμν - (OTμ + OTν) / 2)
end
10 changes: 10 additions & 0 deletions src/entropic/sinkhorn_gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ function sinkhorn2(
)
end

"""
    sinkhorn_loss(μ, ν, C, ε, alg::SinkhornGibbs; kwargs...)

Run the Sinkhorn algorithm for source `μ`, target `ν`, cost matrix `C`, and
regularization parameter `ε`, and return the loss evaluated by `obj` from the
converged solver state (dual scalings `u` and `v`).
"""
function sinkhorn_loss(μ, ν, C, ε, alg::SinkhornGibbs; kwargs...)
    # Construct the solver state and iterate the Sinkhorn updates to convergence.
    s = build_solver(μ, ν, C, ε, alg; kwargs...)
    solve!(s)
    # Evaluate the loss from the converged scalings stored in the solver cache.
    return obj(s.cache.u, s.cache.v, s.source, s.target, s.eps)
end

# interface

prestep!(::SinkhornSolver{SinkhornGibbs}, ::Int) = nothing
Expand Down
10 changes: 10 additions & 0 deletions src/entropic/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,13 @@ function sinkhorn2(

return cost
end

"""
    sinkhorn_loss(μ, C, ε, alg::SymmetricSinkhornGibbs; kwargs...)

Run the symmetric Sinkhorn algorithm for measure `μ` with cost matrix `C` and
regularization parameter `ε`, and return the loss evaluated by `obj`. In the
symmetric case both marginals coincide, so the single scaling `u` and the
source measure are passed twice.
"""
function sinkhorn_loss(μ, C, ε, alg::SymmetricSinkhornGibbs; kwargs...)
    # Construct the symmetric solver state and iterate to convergence.
    s = build_solver(μ, C, ε, alg; kwargs...)
    solve!(s)
    # Symmetric problem: reuse `u` and `source` for both sides of `obj`.
    u = s.cache.u
    return obj(u, u, s.source, s.source, s.eps)
end
101 changes: 61 additions & 40 deletions test/entropic/sinkhorn_divergence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,67 @@ Random.seed!(100)
C = pairwise(SqEuclidean(), x)
f(x; μ, σ) = exp(-((x - μ) / σ)^2)
# regularization parameter
ε = 0.01
ε = 0.05
@testset "basic" begin
μ = normalize!(f.(x; μ=0, σ=0.5), 1)
M = 100

ν_all = [normalize!(f.(x; μ=y, σ=0.5), 1) for y in range(-1, 1; length=M)]

loss = map(ν -> sinkhorn_divergence(μ, ν, C, ε), ν_all)
loss_ = map(
ν ->
sinkhorn2(μ, ν, C, ε) -
(sinkhorn2(μ, μ, C, ε) + sinkhorn2(ν, ν, C, ε)) / 2,
ν_all,
)
for reg in (true, false)
loss = map(ν -> sinkhorn_divergence(μ, ν, C, ε; regularization=reg), ν_all)
loss_ = map(
ν ->
sinkhorn2(μ, ν, C, ε; regularization=reg) -
(
sinkhorn2(μ, μ, C, ε; regularization=reg) +
sinkhorn2(ν, ν, C, ε; regularization=reg)
) / 2,
ν_all,
)

@test loss ≈ loss_
@test all(loss .≥ 0)
@test sinkhorn_divergence(μ, μ, C, ε) ≈ 0 atol = 1e-9
@test loss ≈ loss_ rtol = 1e-6
@test all(loss .≥ 0)
@test sinkhorn_divergence(μ, μ, C, ε) ≈ 0 atol = 1e-9
end
end
@testset "batch" begin
M = 10
μ = hcat([normalize!(f.(x; μ=randn(), σ=0.5), 1) for _ in 1:M]...)
ν = hcat([normalize!(f.(x; μ=randn(), σ=0.5), 1) for _ in 1:M]...)
loss_batch = sinkhorn_divergence(μ, ν, C, ε)
@test loss_batch ≈ [
sinkhorn_divergence(x, y, C, ε) for (x, y) in zip(eachcol(μ), eachcol(ν))
]
loss_batch_μ = sinkhorn_divergence(μ, ν[:, 1], C, ε)
@test loss_batch_μ ≈ [sinkhorn_divergence(x, ν[:, 1], C, ε) for x in eachcol(μ)]
loss_batch_ν = sinkhorn_divergence(μ[:, 1], ν, C, ε)
@test loss_batch_ν ≈ [sinkhorn_divergence(μ[:, 1], y, C, ε) for y in eachcol(ν)]
for reg in (true, false)
loss_batch = sinkhorn_divergence(μ, ν, C, ε; regularization=reg)
@test loss_batch ≈ [
sinkhorn_divergence(x, y, C, ε; regularization=reg) for
(x, y) in zip(eachcol(μ), eachcol(ν))
]
loss_batch_μ = sinkhorn_divergence(μ, ν[:, 1], C, ε; regularization=reg)
@test loss_batch_μ ≈ [
sinkhorn_divergence(x, ν[:, 1], C, ε; regularization=reg) for
x in eachcol(μ)
]
loss_batch_ν = sinkhorn_divergence(μ[:, 1], ν, C, ε; regularization=reg)
@test loss_batch_ν ≈ [
sinkhorn_divergence(μ[:, 1], y, C, ε; regularization=reg) for
y in eachcol(ν)
]
end
end
@testset "AD" begin
ε = 0.05
μ = normalize!(f.(x; μ=-0.5, σ=0.5), 1)
ν = normalize!(f.(x; μ=0.5, σ=0.5), 1)
for Diff in [ForwardDiff, ReverseDiff]
∇ = Diff.gradient(log.(ν)) do xs
sinkhorn_divergence(μ, softmax(xs), C, ε)
end
@test size(∇) == size(ν)
∇ = Diff.gradient(log.(μ)) do xs
sinkhorn_divergence(μ, softmax(xs), C, ε)
for reg in (true, false)
∇ = Diff.gradient(log.(ν)) do xs
sinkhorn_divergence(μ, softmax(xs), C, ε; regularization=reg)
end
@test size(∇) == size(ν)
∇ = Diff.gradient(log.(μ)) do xs
sinkhorn_divergence(μ, softmax(xs), C, ε; regularization=reg)
end
@test norm(∇, Inf) ≈ 0 atol = 1e-9 # Sinkhorn divergence has minimum at SD(μ, μ)
end
@test norm(∇, Inf) ≈ 0 rtol = 1e-9 # Sinkhorn divergence has minimum at SD(μ, μ)
end
end
end
Expand All @@ -82,27 +98,32 @@ Random.seed!(100)
Cμν = pairwise(SqEuclidean(), μ_spt', ν_spt'; dims=2)
Cμ = pairwise(SqEuclidean(), μ_spt'; dims=2)
Cν = pairwise(SqEuclidean(), ν_spt'; dims=2)
ε = 0.05 * max(mean(Cμν), mean(Cμ), mean(Cν))
ε = 0.1 * max(mean(Cμν), mean(Cμ), mean(Cν))

@testset "basic" begin
@test sinkhorn_divergence(μ, ν, Cμν, Cμ, Cν, ε) ≥ 0
@test sinkhorn_divergence(μ, μ, Cμ, Cμ, Cμ, ε) ≈ 0
for reg in (true, false)
@test sinkhorn_divergence(μ, ν, Cμν, Cμ, Cν, ε; regularization=reg) ≥ 0
@test sinkhorn_divergence(μ, μ, Cμ, Cμ, Cμ, ε; regularization=reg) ≈ 0 rtol =
1e-6
end
end

@testset "AD" begin
for Diff in [ForwardDiff, ReverseDiff]
∇ = Diff.gradient(ν_spt) do xs
Cμν = pairwise(SqEuclidean(), μ_spt', xs'; dims=2)
Cν = pairwise(SqEuclidean(), xs'; dims=2)
sinkhorn_divergence(μ, ν, Cμν, Cμ, Cν, ε)
end
@test size(∇) == size(ν_spt)
∇ = Diff.gradient(μ_spt) do xs
Cμν = pairwise(SqEuclidean(), μ_spt', xs'; dims=2)
Cν = pairwise(SqEuclidean(), xs'; dims=2)
sinkhorn_divergence(μ, μ, Cμν, Cμ, Cν, ε)
for reg in (true, false)
∇ = Diff.gradient(ν_spt) do xs
Cμν = pairwise(SqEuclidean(), μ_spt', xs'; dims=2)
Cν = pairwise(SqEuclidean(), xs'; dims=2)
sinkhorn_divergence(μ, ν, Cμν, Cμ, Cν, ε)
end
@test size(∇) == size(ν_spt)
∇ = Diff.gradient(μ_spt) do xs
Cμν = pairwise(SqEuclidean(), μ_spt', xs'; dims=2)
Cν = pairwise(SqEuclidean(), xs'; dims=2)
sinkhorn_divergence(μ, μ, Cμν, Cμ, Cν, ε)
end
@test norm(∇, Inf) ≈ 0 rtol = 1e-6
end
@test norm(∇, Inf) ≈ 0
end
end
end
Expand Down