Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General regularisation #66

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ sinkhorn_unbalanced
sinkhorn_unbalanced2
```

## Quadratically regularised optimal transport
## Optimal transport with general regularisation
```@docs
quadreg
ot_reg_plan
```
4 changes: 2 additions & 2 deletions examples/basic/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ sinkhorn2(μ, ν, C, ε)
# resulting transport plan $\gamma$ is *sparse*. We take advantage of this and represent it as
# a sparse matrix.

quadreg(μ, ν, C, ε; maxiter=500);
ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500);
ot_reg_plan(μ, ν, C, ε; reg_func="L2", method="lorenz", maxiter=500);


# ## Stabilized Sinkhorn algorithm
#
Expand Down Expand Up @@ -190,7 +190,7 @@ heatmap(
# Notice how the "edges" of the transport plan are sharper if we use quadratic regularisation
# instead of entropic regularisation:

γquad = Matrix(quadreg(μ, ν, C, 5; maxiter=500))
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500))
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func="L2", method="lorenz", maxiter=500))

heatmap(
μsupport,
νsupport,
Expand Down
49 changes: 48 additions & 1 deletion src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export sinkhorn, sinkhorn2
export emd, emd2
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
export sinkhorn_unbalanced, sinkhorn_unbalanced2
export quadreg
export ot_reg_plan, ot_reg_cost

const MOI = MathOptInterface

Expand Down Expand Up @@ -506,6 +506,53 @@ function sinkhorn_barycenter(
return u_all[1, :] .* (K_all[1] * v_all[1, :])
end

"""
ot_reg_plan(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...)

Compute the optimal transport plan between `mu` and `nu` for optimal transport with a
general choice of regulariser `math Ω(γ)`. Solves for `gamma` that minimises

```math
\\inf_{γ ∈ Π(μ, ν)} \\langle γ, C \\rangle + ε Ω(γ)
```

Supported choices of `math Ω` are:
- L2: ``Ω(γ) = \\frac{1}{2} \\| γ \\|_2^2``, `reg_func = "L2"`

Supported solution methods are:
- L2: `method = "lorenz"` for the semi-smooth Newton method of Lorenz et al.

References

Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. Applied Mathematics & Optimization, pp.1-31.
"""
function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
if (reg_func == "L2") && (method == "lorenz")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach is problematic in my opinion: it is not possible to add support for other methods or regularization types for users or downstream packages, one always has to modify this function here.

I guess this could be avoided with the suggestion in #63 (comment) - everyone could just add other regularizations and/or algorithms.

return quadreg(mu, nu, C, eps; kwargs...)
else
@warn "Unimplemented"
end
end

"""
ot_reg_cost(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...)

Compute the optimal transport cost between `mu` and `nu` for optimal transport with a
general choice of regulariser `math Ω(γ)`.

See also: [`ot_reg_plan`](@ref)

"""
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
nothing

γ = if (reg_func == "L2") && (method == "lorenz")
quadreg(mu, nu, C, eps; kwargs...)
else
@warn "Unimplemented"
nothing
end
return dot(γ, C)
end

"""
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)

Expand Down
10 changes: 9 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,18 @@ end

# compute optimal transport map (Julia implementation + POT)
eps = 0.25
γ = quadreg(μ, ν, C, eps)
<<<<<<< HEAD
γ = ot_reg_plan(μ, ν, C, eps)
γ_pot = POT.Smooth.smooth_ot_dual(μ, ν, C, eps; stopThr=1e-9)
=======
γ = ot_reg_plan(μ, ν, C, eps; reg_func="L2", method="lorenz")
γ_pot = sparse(POT.smooth_ot_dual(μ, ν, C, eps; max_iter=5000))
>>>>>>> d6c9ee3 (updated tests and docstrings)
# need to use a larger tolerance here because of a quirk with the POT solver
@test norm(γ - γ_pot, Inf) < 1e-4
c = ot_reg_cost(μ, ν, C, eps; reg_func="L2", method="lorenz")
c_pot = dot(γ_pot, C)
@test c ≈ c_pot atol = 1e-4
end
end

Expand Down