-
Notifications
You must be signed in to change notification settings - Fork 0
/
rde_birkhoff.jl
106 lines (90 loc) · 3.13 KB
/
rde_birkhoff.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
using Pkg
Pkg.activate(@__DIR__)
# Select the sub directory (relative to this script) to run RDE in.
# The first command line argument is used if one is given; otherwise the user is
# prompted for it.
if length(ARGS) > 0
    subdir = ARGS[1]
else
    print("Please enter sub directory to run RDE in: ")
    subdir = readline()
end
# Input validation: ask again until `subdir` names an existing directory below the
# script location. An empty answer is rejected explicitly, because
# `joinpath(@__DIR__, "")` is the script directory itself and would pass `isdir`
# (enter "." to run in the script directory on purpose).
while isempty(subdir) || !isdir(joinpath(@__DIR__, subdir))
    print("Invalid directory $subdir. Please enter sub directory to run RDE in: ")
    # `readline` keeps returning "" once standard input is exhausted; stop with a
    # clear error instead of prompting forever.
    eof(stdin) && error("No valid sub directory given and no more input available.")
    global subdir = readline()
end
# LinearAlgebra provides `I` and `BLAS`, which the per-sample loop below relies on.
# NOTE(review): they may already be brought into scope by "custom_oralces.jl";
# loading the standard library explicitly is harmless either way.
using LinearAlgebra
using PyCall
# Make Python modules located in the chosen sub directory importable
pushfirst!(PyVector(pyimport("sys")["path"]), joinpath(@__DIR__, subdir))
import FrankWolfe
# NOTE(review): the file name is spelled "oralces"; presumably it matches the file
# on disk, so it is left unchanged.
include("custom_oralces.jl")
# load indices, rates, max_iter
# (presumably also where `all_rates` and `fw_arguments` used below come from)
include(joinpath(@__DIR__, subdir, "config_birkhoff.jl"))
# Enter the sub directory that was validated above. It is resolved against the script
# location, like every other path in this file, and not against the current working
# directory, so the script also works when it is started from elsewhere.
cd(joinpath(@__DIR__, subdir))
# Get the Python side of RDE
rde = pyimport("rde")
for idx in indices
# Setup LMO (linear minimization oracle over the Birkhoff polytope)
lmo = FrankWolfe.BirkhoffPolytopeLMO()
# Ask the Python side for sample `idx` (data and file name) ...
sample = rde.get_data_sample(idx)
x, fname = sample
# ... and for the distortion functional belonging to that sample
distortion = rde.get_distortion(x)
f, df, node, pred = distortion
# helper for prototype ordering vectors: values linearly spaced from 0 to 1, with
# the length and element type of `x`
lin_p = Vector{eltype(x)}(LinRange(0.0, 1.0, length(x)))
# Build the prototype vector for rate `k`: a vector with the length and element type
# of `x` whose last `k` entries are one and whose remaining entries are zero.
function k_to_p(k)
    n = length(x)
    pk = zeros(eltype(x), n)
    # switch on the trailing `k` entries
    fill!(view(pk, (n - k + 1):n), 1.0)
    return pk
end
# Wrap objective and gradient functions
# Objective: the sum of `f(S * k_to_p(rate))` over every rate in `all_rates`.
# A plain `Matrix` with another element type is first converted to the element type
# of `x`; any other kind of `S` is used as it is.
function func(S)
    T = eltype(x)
    M = (S isa Matrix && !(S isa Matrix{T})) ? convert(Matrix{T}, S) : S
    total = convert(T, 0.0)
    for rate in all_rates
        total += f(M * k_to_p(rate))
    end
    return total
end
# Gradient of `func` with respect to `S`, written in place into `storage`.
# For every rate in `all_rates` the rank-one term `df(S * pk) * transpose(pk)` with
# `pk = k_to_p(rate)` is accumulated. Rates for which a NaN shows up are skipped and
# reported as a warning. Returns `storage`.
function grad!(storage, S)
    T = eltype(x)
    n = length(x)
    # A plain `Matrix` with another element type is converted to the one of `x`
    if (S isa Matrix) && !(S isa Matrix{T})
        S = convert(Matrix{T}, S)
    end
    df_sum = zeros(T, n, n)
    for rate in all_rates
        pk = k_to_p(rate)
        g = df(S * pk)
        if any(isnan, g) || any(isnan, pk)
            # Logged at warning level (not info): the gradient written to `storage`
            # is missing the contribution of this rate.
            @warn "Numerical instabilities, skipped rate k=$rate"
        else
            # in-place rank-one update, df_sum = df_sum + g*transpose(pk)
            BLAS.ger!(one(T), g, pk, df_sum)
        end
    end
    storage .= df_sum
    return storage
end
# Run FrankWolfe
println("Running sample $idx")
# start iterate: the identity matrix, in the element type of `x`
S0 = Matrix{eltype(x)}(I, length(x), length(x))
# Solvers that were tried in place of `frank_wolfe` (kept for reference):
#   FrankWolfe.away_frank_wolfe
#   FrankWolfe.blended_conditional_gradient
#   FrankWolfe.lazified_conditional_gradient
# `func` and `grad!` are passed directly; wrapping them in lambdas adds nothing.
@time S, V, primal, dual_gap = FrankWolfe.frank_wolfe(
    func,
    grad!,
    lmo,
    S0;
    fw_arguments...
)
# reset adaptive step size if necessary
# `hasproperty` guards against a configuration without a `line_search` entry, which
# would otherwise abort here, after the optimisation but before results are stored.
if hasproperty(fw_arguments, :line_search) &&
        fw_arguments.line_search isa FrankWolfe.MonotonousNonConvexStepSize
    fw_arguments.line_search.factor = 0
end
# Store results
# `per_rate` has one row per entry of `rates`; each row is `S * k_to_p(rate)` for the
# corresponding rate, in the element type of `x`.
per_rate = zeros(eltype(x), length(rates), length(x))
for (ridx, rate) in enumerate(rates)
    s_rate = S * k_to_p(rate)
    per_rate[ridx, :] = s_rate
    # hand the stored row (a copy) for this single rate to the Python side
    rde.store_single_result(per_rate[ridx, :], idx, fname, rate)
end
# Store multiple rate results
s_lin = S * lin_p
rde.store_collected_results(per_rate, idx, node, pred, fname, rates, nothing, S, s_lin)
end