Skip to content

Commit

Permalink
refactor finding consistent value (#252)
Browse files Browse the repository at this point in the history
* refactor finding consistent value

* add internal docs

* add doc compat

* remove outdated docstring

* fix inferrability on older julia
  • Loading branch information
piever authored Oct 14, 2022
1 parent 0933432 commit 4056c71
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
Documenter = "0.27"
PooledArrays = "1"
2 changes: 2 additions & 0 deletions docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@ StructArrays.map_params
StructArrays.buildfromschema
StructArrays.bypass_constructor
StructArrays.iscompatible
StructArrays.maybe_convert_elt
StructArrays.findconsistentvalue
```
17 changes: 6 additions & 11 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N}
components::C

function StructArray{T, N, C}(c) where {T, N, C<:Tup}
isempty(c) && error("only eltypes with fields are supported")
ax = axes(first(c))
length(ax) == N || error("wrong number of dimensions")
map(tail(c)) do ci
axes(ci) == ax || error("all field arrays must have same shape")
end
isempty(c) && throw(ArgumentError("only eltypes with fields are supported"))
ax = findconsistentvalue(axes, c)
(ax === nothing) && throw(ArgumentError("all component arrays must have the same shape"))
length(ax) == N || throw(ArgumentError("wrong number of dimensions"))
new{T, N, C, index_type(c)}(c)
end
end
Expand Down Expand Up @@ -119,9 +117,6 @@ Construct a `StructArray` from slices of `A` along `dims`.
The `unwrap` keyword argument is a function that determines whether to
recursively convert fields of type `FT` to `StructArray`s.
!!! compat "Julia 1.1"
This function requires at least Julia 1.1.
```julia-repl
julia> X = [1.0 2.0; 3.0 4.0]
2×2 Array{Float64,2}:
Expand Down Expand Up @@ -369,8 +364,8 @@ end
end

function Base.parentindices(s::StructArray)
res = parentindices(component(s, 1))
all(c -> parentindices(c) == res, components(s)) || throw(ArgumentError("inconsistent parentindices of components"))
res = findconsistentvalue(parentindices, components(s))
(res === nothing) && throw(ArgumentError("inconsistent parentindices of components"))
return res
end

Expand Down
12 changes: 12 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,15 @@ By default, this calls `convert(T, x)`; however, you can specialize it for other
maybe_convert_elt(::Type{T}, vals) where T = convert(T, vals)
maybe_convert_elt(::Type{T}, vals::Tuple) where T = T <: Tuple ? convert(T, vals) : vals # assignment of fields by position
maybe_convert_elt(::Type{T}, vals::NamedTuple) where T = T<:NamedTuple ? convert(T, vals) : vals # assignment of fields by name

"""
findconsistentvalue(f, componenents::Union{Tuple, NamedTuple})
Compute the unique value that `f` takes on each `component ∈ componenents`.
If not all values are equal, return `nothing`. Otherwise, return the unique value.
"""
function findconsistentvalue(f::F, (col, cols...)::Tup) where F
val = f(col)
isconsistent = mapfoldl(isequal(val)f, &, cols; init=true)
return ifelse(isconsistent, val, nothing)
end
20 changes: 14 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ end
@test StructArrays.strip_params(Tuple{Int}) == Tuple
@test StructArrays.astuple(NamedTuple{(:a,), Tuple{Float64}}) == Tuple{Float64}
@test StructArrays.strip_params(NamedTuple{(:a,), Tuple{Float64}}) == NamedTuple{(:a,)}

cols = (a=rand(2), b=rand(2), c=rand(2))
@test StructArrays.findconsistentvalue(length, cols) == 2
@test StructArrays.findconsistentvalue(length, Tuple(cols)) == 2

cols = (a=rand(2), b=rand(2), c=rand(3))
@test isnothing(StructArrays.findconsistentvalue(length, cols))
@test isnothing(StructArrays.findconsistentvalue(length, Tuple(cols)))
end

@testset "indexstyle" begin
Expand Down Expand Up @@ -439,8 +447,8 @@ end
@test isequal(t.a, [1, missing])
@test eltype(t) <: NamedTuple{(:a,)}

@test_throws ErrorException StructArray([nothing])
@test_throws ErrorException StructArray([1, 2, 3])
@test_throws ArgumentError StructArray([nothing])
@test_throws ArgumentError StructArray([1, 2, 3])
end

@testset "tuple case" begin
Expand All @@ -460,10 +468,10 @@ end
@test getproperty(t, 1) == [2]
@test getproperty(t, 2) == [3.0]

@test_throws ErrorException StructArray(([1, 2], [3]))
@test_throws ArgumentError StructArray(([1, 2], [3]))

@test_throws ErrorException StructArray{Tuple{}}(())
@test_throws ErrorException StructArray{Tuple{}, 1, Tuple{}}(())
@test_throws ArgumentError StructArray{Tuple{}}(())
@test_throws ArgumentError StructArray{Tuple{}, 1, Tuple{}}(())
end

@testset "constructor from slices" begin
Expand Down Expand Up @@ -503,7 +511,7 @@ end
@test t1 == StructArray((a=[1.2], b=["test"]))
@test t2 == StructArray{Pair{Float64, String}}(([1.2], ["test"]))

@test_throws ErrorException StructArray(a=[1, 2], b=[3])
@test_throws ArgumentError StructArray(a=[1, 2], b=[3])
end

@testset "complex" begin
Expand Down

2 comments on commit 4056c71

@piever
Copy link
Collaborator Author

@piever piever commented on 4056c71 Oct 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/70486

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.13 -m "<description of version>" 4056c71627be247a5165b3bdb11040690b090893
git push origin v0.6.13

Please sign in to comment.