Skip to content

Commit

Permalink
Add wasserstein and squared2wasserstein (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jun 2, 2021
1 parent 3572ba3 commit 7e8cbe2
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimalTransport"
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
authors = ["zsteve <[email protected]>"]
version = "0.3.6"
version = "0.3.7"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand Down
137 changes: 129 additions & 8 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,22 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BenchmarkTools]]
deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"]
git-tree-sha1 = "068fda9b756e41e6c75da7b771e6f89fa8a43d15"
git-tree-sha1 = "01ca3823217f474243cc2c8e6e1d1f45956fe872"
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
version = "0.7.0"
version = "1.0.0"

[[Bzip2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2"
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
version = "1.0.8+0"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "4b28f88cecf5d9a07c85b9ce5209a361ecaff34a"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.45"

[[CodecBzip2]]
deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"]
git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7"
Expand All @@ -33,16 +39,51 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.30.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[DataAPI]]
git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.6.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.9"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[Distances]]
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
git-tree-sha1 = "abe4ad222b26af3337262b8afb28fab8d215e9f8"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.3"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "64a3e756c44dcf33bd33e7f500113d9992a02e92"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.2"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
Expand All @@ -59,11 +100,17 @@ version = "0.26.3"
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.7"

[[HTTP]]
deps = ["Base64", "Dates", "IniFile", "MbedTLS", "NetworkOptions", "Sockets", "URIs"]
git-tree-sha1 = "b855bf8247d6e946c75bb30f593bfe7fe591058d"
git-tree-sha1 = "86ed84701fbfd1142c9786f8e53c595ff5a4def9"
uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
version = "0.9.8"
version = "0.9.10"

[[IOCapture]]
deps = ["Logging"]
Expand Down Expand Up @@ -134,6 +181,12 @@ git-tree-sha1 = "32b517d4d8219d3bbab199de3416ace45010bdb3"
uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
version = "2.8.0"

[[LogExpFunctions]]
deps = ["DocStringExtensions", "LinearAlgebra"]
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.2.4"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

Expand All @@ -143,9 +196,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MathOptInterface]]
deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "JSON", "JSONSchema", "LinearAlgebra", "MutableArithmetics", "OrderedCollections", "SparseArrays", "Test", "Unicode"]
git-tree-sha1 = "cd3057ca89a9ab83ce37ec42324523b8db0c60dc"
git-tree-sha1 = "575644e3c05b258250bb599e57cf73bbf1062901"
uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
version = "0.9.21"
version = "0.9.22"

[[MbedTLS]]
deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"]
Expand All @@ -157,6 +210,12 @@ version = "1.0.3"
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.0.0"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

Expand All @@ -172,17 +231,29 @@ version = "0.2.19"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.4+0"

[[OptimalTransport]]
deps = ["Distances", "IterativeSolvers", "LinearAlgebra", "MathOptInterface", "SparseArrays"]
deps = ["Distances", "Distributions", "IterativeSolvers", "LinearAlgebra", "LogExpFunctions", "MathOptInterface", "QuadGK", "SparseArrays", "StatsBase"]
path = ".."
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
version = "0.3.0"
version = "0.3.6"

[[OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.1"

[[PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "f82a0e71f222199de8e9eb9a09977bd0767d52a0"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.0"

[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
Expand All @@ -203,6 +274,12 @@ version = "1.2.2"
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.4.1"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Expand All @@ -216,19 +293,47 @@ git-tree-sha1 = "b3fb709f3c97bfc6e948be68beeecb55a0b340ae"
uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
version = "1.1.1"

[[Rmath]]
deps = ["Random", "Rmath_jll"]
git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.7.0"

[[Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.3.0+0"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.0.0"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "371204984184315ed7228bcc604d08e1bbc18f31"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.4.2"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -238,6 +343,22 @@ git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.0.0"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.8"

[[StatsFuns]]
deps = ["LogExpFunctions", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "30cd8c360c54081f806b1ee14d2eecbef3c04c49"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.8"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33"

[compat]
Distributions = "0.25"
Documenter = "0.26"
Literate = "2.8"
OptimalTransport = "0.3"
Expand Down
2 changes: 2 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using OptimalTransport
using Distributions

using Literate: Literate
using Pkg: Pkg

Expand Down
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
emd
emd2
ot_plan
ot_plan(::Any, ::ContinuousUnivariateDistribution, ::UnivariateDistribution)
ot_plan(::Any, ::DiscreteNonParametric, ::DiscreteNonParametric)
ot_cost
ot_cost(::Any, ::ContinuousUnivariateDistribution, ::UnivariateDistribution)
ot_cost(::Any, ::DiscreteNonParametric, ::DiscreteNonParametric)
wasserstein
squared2wasserstein
```

## Entropically regularised optimal transport
Expand Down
3 changes: 2 additions & 1 deletion src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ export emd, emd2
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
export sinkhorn_unbalanced, sinkhorn_unbalanced2
export quadreg
export ot_cost, ot_plan
export ot_cost, ot_plan, wasserstein, squared2wasserstein

const MOI = MathOptInterface

include("exact.jl")
include("wasserstein.jl")

dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
function dot_matwise(x::AbstractArray, y::AbstractMatrix)
Expand Down
32 changes: 32 additions & 0 deletions src/exact.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,35 @@
"""
ot_plan(c, μ, ν; kwargs...)
Compute the optimal transport plan for the Monge-Kantorovich problem with source and target
marginals `μ` and `ν` and cost `c`.
The optimal transport plan solves
```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\int c(x, y) \\, \\mathrm{d}\\gamma(x, y)
```
where ``\\Pi(\\mu, \\nu)`` denotes the couplings of ``\\mu`` and ``\\nu``.
See also: [`ot_cost`](@ref)
"""
function ot_plan end

"""
ot_cost(c, μ, ν; kwargs...)
Compute the optimal transport cost for the Monge-Kantorovich problem with source and target
marginals `μ` and `ν` and cost `c`.
The optimal transport cost is the scalar value
```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\int c(x, y) \\, \\mathrm{d}\\gamma(x, y)
```
where ``\\Pi(\\mu, \\nu)`` denotes the couplings of ``\\mu`` and ``\\nu``.
See also: [`ot_plan`](@ref)
"""
function ot_cost end

#############
# Discrete OT
#############
Expand Down
44 changes: 44 additions & 0 deletions src/wasserstein.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
wasserstein(μ, ν; metric=Euclidean(), p=Val(1), kwargs...)
Compute the `p`-Wasserstein distance with respect to the `metric` between measures `μ` and
`ν`.
Order `p` can be provided as a scalar of type `Real` or as a parameter of a value type
`Val(p)`. For certain combinations of `metric` and `p`, such as `metric=Euclidean()` and
`p=Val(2)`, the computations are more efficient if `p` is specified as a value type. The
remaining keyword arguments are forwarded to [`ot_cost`](@ref).
See also: [`squared2wasserstein`](@ref), [`ot_cost`](@ref)
"""
function wasserstein(μ, ν; metric=Euclidean(), p::Union{Real,Val}=Val(1), kwargs...)
cost = ot_cost(p2distance(metric, p), μ, ν; kwargs...)
return prt(cost, p)
end

# Build the cost function that corresponds to a metric raised to the exponent `p`.

# Exponent 1 (value type): the metric itself is the cost function.
function p2distance(metric, ::Val{1})
    return metric
end

# Euclidean metric with exponent 2 (value type): return `SqEuclidean` constructed from the
# threshold of `d` instead of a closure.
function p2distance(d::Euclidean, ::Val{2})
    return SqEuclidean(d.thresh)
end

# Any other exponent given as a value type: closure with the exponent `P` as type parameter.
function p2distance(metric, ::Val{P}) where {P}
    return (a, b) -> metric(a, b)^P
end

# Exponent given as a plain value: closure that captures `p`.
function p2distance(metric, p)
    return (a, b) -> metric(a, b)^p
end

# Take the `p`-th root of `x`; `p` is either a value type `Val(p)` or a plain value.

# First root: nothing to do.
function prt(x, ::Val{1})
    return x
end

# Square root via the dedicated `sqrt`.
function prt(x, ::Val{2})
    return sqrt(x)
end

# Cube root via the dedicated `cbrt`.
function prt(x, ::Val{3})
    return cbrt(x)
end

# Any other order given as a value type.
function prt(x, ::Val{P}) where {P}
    return x^(1 / P)
end

# Order given as a plain value.
function prt(x, p)
    return x^(1 / p)
end

"""
squared2wasserstein(μ, ν; metric=Euclidean(), kwargs...)
Compute the squared 2-Wasserstein distance with respect to the `metric` between measures `μ`
and `ν`.
The remaining keyword arguments are forwarded to [`ot_cost`](@ref).
See also: [`wasserstein`](@ref), [`ot_cost`](@ref)
"""
function squared2wasserstein(μ, ν; metric=Euclidean(), kwargs...)
return ot_cost(p2distance(metric, Val(2)), μ, ν; kwargs...)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ const GROUP = get(ENV, "GROUP", "All")
@safetestset "Unbalanced OT" begin
include("unbalanced.jl")
end
@safetestset "Wasserstein distance" begin
include("wasserstein.jl")
end
end

# CUDA requires Julia >= 1.6
Expand Down
Loading

2 comments on commit 7e8cbe2

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37976

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.3.7 -m "<description of version>" 7e8cbe21f45149e2abef277d2e6ecf971e15b3d7
git push origin v0.3.7

Please sign in to comment.