Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timesteps to quadratic regularizers #55

Merged
merged 2 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
/examples/**/plots/
/examples/**/trajectories/
pardiso.lic
/.CondaPkg/
/.CondaPkg/
*.code-workspace
67 changes: 58 additions & 9 deletions src/objectives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ function QuadraticRegularizer(;
times::AbstractVector{Int}=1:traj.T,
dim::Int=nothing,
R::AbstractVector{<:Real}=ones(traj.dims[name]),
eval_hessian=true
eval_hessian=true,
timestep_symbol=:Δt
)

@assert !isnothing(name) "name must be specified"
Expand All @@ -316,18 +317,34 @@ function QuadraticRegularizer(;
@views function L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory)
J = 0.0
for t ∈ times
if Z.timestep isa Symbol
Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)]
else
Δt = Z.timestep
end

vₜ = Z⃗[slice(t, Z.components[name], Z.dim)]
J += 0.5 * vₜ' * (R .* vₜ)
rₜ = Δt .* vₜ
J += 0.5 * rₜ' * (R .* rₜ)
end
return J
end

@views function ∇L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory)
∇ = zeros(Z.dim * Z.T)
∇ = zeros(Z.dim * Z.T)
Threads.@threads for t ∈ times
vₜ_slice = slice(t, Z.components[name], Z.dim)
vₜ_slice = slice(t, Z.components[name], Z.dim)
vₜ = Z⃗[vₜ_slice]
∇[vₜ_slice] = R .* vₜ

if Z.timestep isa Symbol
Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim)
Δt = Z⃗[Δt_slice]
∇[Δt_slice] .= vₜ' * (R .* (Δt .* vₜ))
else
Δt = Z.timestep
end

∇[vₜ_slice] .= R .* (Δt.^2 .* vₜ)
end
return ∇
end
Expand All @@ -339,16 +356,48 @@ function QuadraticRegularizer(;

∂²L_structure = Z -> begin
structure = []
# vₜ Hessian structure (eq. 17)
# Hessian structure (eq. 17)
for t ∈ times
vₜ_slice = slice(t, Z.components[name], Z.dim)
diag_inds = collect(zip(vₜ_slice, vₜ_slice))
append!(structure, diag_inds)
vₜ_vₜ_inds = collect(zip(vₜ_slice, vₜ_slice))
append!(structure, vₜ_vₜ_inds)

if Z.timestep isa Symbol
Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim)
# ∂²_vₜ_Δt
vₜ_Δt_inds = [(i, j) for i ∈ vₜ_slice for j ∈ Δt_slice]
append!(structure, vₜ_Δt_inds)
# ∂²_Δt_vₜ
Δt_vₜ_inds = [(i, j) for i ∈ Δt_slice for j ∈ vₜ_slice]
append!(structure, Δt_vₜ_inds)
# ∂²_Δt_Δt
Δt_Δt_inds = collect(zip(Δt_slice, Δt_slice))
append!(structure, Δt_Δt_inds)
end
end
return structure
end

∂²L = (Z⃗, Z) -> vcat(fill(R, length(times))...)
∂²L = (Z⃗, Z) -> begin
values = []
# Match Hessian structure indices
for t ∈ times
if Z.timestep isa Symbol
Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)]
append!(values, R .* Δt.^2)
# ∂²_vₜ_Δt, ∂²_Δt_vₜ
vₜ = Z⃗[slice(t, Z.components[name], Z.dim)]
append!(values, 2 * (R .* (Δt .* vₜ)))
append!(values, 2 * (R .* (Δt .* vₜ)))
# ∂²_Δt_Δt
append!(values, vₜ' * (R .* vₜ))
else
Δt = Z.timestep
append!(values, R .* Δt.^2)
end
end
return values
end
end

return Objective(L, ∇L, ∂²L, ∂²L_structure, Dict[params])
Expand Down