Skip to content

Commit

Permalink
Add value_flatten as described in JuliaGaussianProcesses#37.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonniedie committed Aug 27, 2021
1 parent 4390386 commit 8cff65d
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.3.6"
version = "0.3.7"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand Down
2 changes: 1 addition & 1 deletion src/ParameterHandling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ChainRulesCore
using LinearAlgebra
using SparseArrays

export flatten, positive, bounded, fixed, deferred, orthogonal, positive_definite
export flatten, value_flatten, positive, bounded, fixed, deferred, orthogonal, positive_definite

include("flatten.jl")
include("parameters.jl")
Expand Down
21 changes: 21 additions & 0 deletions src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,24 @@ _cumsum(x) = cumsum(x)
if VERSION < v"1.5"
_cumsum(x::Tuple) = (_cumsum(collect(x))..., )
end

"""
value_flatten([eltype=Float64], x)
Operates similarly to `flatten`, but the returned `unflatten` function returns an object
like `x`, but with unwrapped values.
Doing
```julia
v, unflatten = value_flatten(x)
```
is the same as doing
```julia
v, _unflatten = flatten(x)
unflatten = ParameterHandling.value ∘ _unflatten
```
"""
function value_flatten(args...)
v, unflatten = flatten(args...)
return v, value unflatten
end
46 changes: 46 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,52 @@ function test_flatten_interface(x::T; check_inferred::Bool=true) where T
return nothing
end

function test_value_flatten_interface(x::T; check_inferred::Bool=true) where T
@testset "value_flatten($T)" begin
# Checks default eltype still works and ensure that
# basic functionality is implemented.
v, unflatten = value_flatten(x)
@test typeof(v) === Vector{Float64}
@test default_equality(value(x), unflatten(v))
@test unflatten(v) isa T

# Check that everything infers properly.
check_inferred && @inferred value_flatten(x)

# Test with different precisions
@testset "Float64" begin
_v, _unflatten = value_flatten(Float64, x)
@test typeof(_v) === Vector{Float64}
@test _v == v
@test default_equality(value(x), unflatten(_v))
@test _unflatten(_v) isa T

# Check that everything infers properly.
check_inferred && @inferred value_flatten(Float64, x)
end
@testset "Float32" begin
_v, _unflatten = value_flatten(Float32, x)
@test typeof(_v) === Vector{Float32}
@test default_equality(value(x), _unflatten(_v); atol=1e-5)
@test _unflatten(_v) isa T

# Check that everything infers properly.
check_inferred && @inferred value_flatten(Float32, x)
end
@testset "Float16" begin
_v, _unflatten = value_flatten(Float16, x)
@test typeof(_v) === Vector{Float16}
@test default_equality(value(x), _unflatten(_v); atol=1e-2)
@test _unflatten(_v) isa T

# Check that everything infers properly.
check_inferred && @inferred value_flatten(Float16, x)
end
end

return nothing
end

function test_parameter_interface(x; check_inferred::Bool=true)

# Parameters need to be flatten-able.
Expand Down
40 changes: 40 additions & 0 deletions test/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,43 @@
test_flatten_interface(Dict(:a => 4.0, :b => 5.0); check_inferred=false)
end
end

@testset "value_flatten" begin

@testset "Reals" begin
test_value_flatten_interface(1.0)

@testset "Integers" begin
test_value_flatten_interface(1)
@test isempty(first(value_flatten(1)))
end
end

@testset "AbstractArrays" begin
test_value_flatten_interface(randn(10))
test_value_flatten_interface(randn(5, 4))
test_value_flatten_interface([randn(5) for _ in 1:3])
end

@testset "SparseMatrixCSC" begin
test_value_flatten_interface(sprand(10, 10, 0.5))
end

@testset "Tuple" begin
test_value_flatten_interface((1.0, 2.0); check_inferred=tuple_infers)

test_value_flatten_interface(
(1.0, (2.0, 3.0), randn(5)); check_inferred=tuple_infers,
)
end

@testset "NamedTuple" begin
test_value_flatten_interface(
(a=1.0, b=(2.0, 3.0), c=(e=5.0,)); check_inferred=tuple_infers,
)
end

@testset "Dict" begin
test_value_flatten_interface(Dict(:a => 4.0, :b => 5.0); check_inferred=false)
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Zygote
using SparseArrays

using ParameterHandling: value
using ParameterHandling.TestUtils: test_flatten_interface, test_parameter_interface
using ParameterHandling.TestUtils: test_flatten_interface, test_value_flatten_interface, test_parameter_interface

const tuple_infers = VERSION < v"1.5" ? false : true

Expand Down

0 comments on commit 8cff65d

Please sign in to comment.