From 4056c71627be247a5165b3bdb11040690b090893 Mon Sep 17 00:00:00 2001 From: Pietro Vertechi Date: Fri, 14 Oct 2022 17:03:45 +0200 Subject: [PATCH] refactor finding consistent value (#252) * refactor finding consistent value * add internal docs * add doc compat * remove outdated docstring * fix inferrability on older julia --- docs/Project.toml | 2 ++ docs/src/reference.md | 2 ++ src/structarray.jl | 17 ++++++----------- src/utils.jl | 12 ++++++++++++ test/runtests.jl | 20 ++++++++++++++------ 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 78f2a7bf..b741cfe1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,8 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +Documenter = "0.27" PooledArrays = "1" diff --git a/docs/src/reference.md b/docs/src/reference.md index bf7a456f..d0195e94 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -57,4 +57,6 @@ StructArrays.map_params StructArrays.buildfromschema StructArrays.bypass_constructor StructArrays.iscompatible +StructArrays.maybe_convert_elt +StructArrays.findconsistentvalue ``` \ No newline at end of file diff --git a/src/structarray.jl b/src/structarray.jl index 3fa2029e..d4bf529f 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -14,12 +14,10 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N} components::C function StructArray{T, N, C}(c) where {T, N, C<:Tup} - isempty(c) && error("only eltypes with fields are supported") - ax = axes(first(c)) - length(ax) == N || error("wrong number of dimensions") - map(tail(c)) do ci - axes(ci) == ax || error("all field arrays must have same shape") - end + isempty(c) && throw(ArgumentError("only eltypes with fields are supported")) + ax = findconsistentvalue(axes, c) + (ax === nothing) && throw(ArgumentError("all component arrays must have the same shape")) + length(ax) == N || throw(ArgumentError("wrong number of dimensions")) new{T, N, C, index_type(c)}(c) end end @@ -119,9 +117,6 @@ Construct a `StructArray` from slices of `A` along `dims`. The `unwrap` keyword argument is a function that determines whether to recursively convert fields of type `FT` to `StructArray`s. -!!! compat "Julia 1.1" - This function requires at least Julia 1.1. - ```julia-repl julia> X = [1.0 2.0; 3.0 4.0] 2×2 Array{Float64,2}: @@ -369,8 +364,8 @@ end end function Base.parentindices(s::StructArray) - res = parentindices(component(s, 1)) - all(c -> parentindices(c) == res, components(s)) || throw(ArgumentError("inconsistent parentindices of components")) + res = findconsistentvalue(parentindices, components(s)) + (res === nothing) && throw(ArgumentError("inconsistent parentindices of components")) return res end diff --git a/src/utils.jl b/src/utils.jl index 5f4e9936..f1ade86a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -196,3 +196,15 @@ By default, this calls `convert(T, x)`; however, you can specialize it for other maybe_convert_elt(::Type{T}, vals) where T = convert(T, vals) maybe_convert_elt(::Type{T}, vals::Tuple) where T = T <: Tuple ? convert(T, vals) : vals # assignment of fields by position maybe_convert_elt(::Type{T}, vals::NamedTuple) where T = T<:NamedTuple ? convert(T, vals) : vals # assignment of fields by name + +""" + findconsistentvalue(f, componenents::Union{Tuple, NamedTuple}) + +Compute the unique value that `f` takes on each `component ∈ componenents`. +If not all values are equal, return `nothing`. Otherwise, return the unique value. +""" +function findconsistentvalue(f::F, (col, cols...)::Tup) where F + val = f(col) + isconsistent = mapfoldl(isequal(val)∘f, &, cols; init=true) + return ifelse(isconsistent, val, nothing) +end diff --git a/test/runtests.jl b/test/runtests.jl index 2c11b34f..4693ca1b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -105,6 +105,14 @@ end @test StructArrays.strip_params(Tuple{Int}) == Tuple @test StructArrays.astuple(NamedTuple{(:a,), Tuple{Float64}}) == Tuple{Float64} @test StructArrays.strip_params(NamedTuple{(:a,), Tuple{Float64}}) == NamedTuple{(:a,)} + + cols = (a=rand(2), b=rand(2), c=rand(2)) + @test StructArrays.findconsistentvalue(length, cols) == 2 + @test StructArrays.findconsistentvalue(length, Tuple(cols)) == 2 + + cols = (a=rand(2), b=rand(2), c=rand(3)) + @test isnothing(StructArrays.findconsistentvalue(length, cols)) + @test isnothing(StructArrays.findconsistentvalue(length, Tuple(cols))) end @testset "indexstyle" begin @@ -439,8 +447,8 @@ end @test isequal(t.a, [1, missing]) @test eltype(t) <: NamedTuple{(:a,)} - @test_throws ErrorException StructArray([nothing]) - @test_throws ErrorException StructArray([1, 2, 3]) + @test_throws ArgumentError StructArray([nothing]) + @test_throws ArgumentError StructArray([1, 2, 3]) end @testset "tuple case" begin @@ -460,10 +468,10 @@ end @test getproperty(t, 1) == [2] @test getproperty(t, 2) == [3.0] - @test_throws ErrorException StructArray(([1, 2], [3])) + @test_throws ArgumentError StructArray(([1, 2], [3])) - @test_throws ErrorException StructArray{Tuple{}}(()) - @test_throws ErrorException StructArray{Tuple{}, 1, Tuple{}}(()) + @test_throws ArgumentError StructArray{Tuple{}}(()) + @test_throws ArgumentError StructArray{Tuple{}, 1, Tuple{}}(()) end @testset "constructor from slices" begin @@ -503,7 +511,7 @@ end @test t1 == StructArray((a=[1.2], b=["test"])) @test t2 == StructArray{Pair{Float64, String}}(([1.2], ["test"])) - @test_throws ErrorException StructArray(a=[1, 2], b=[3]) + @test_throws ArgumentError StructArray(a=[1, 2], b=[3]) end @testset "complex" begin