Skip to content

Commit

Permalink
Preserve more FieldArrays with parametric eltype.
Browse files Browse the repository at this point in the history
And return a `MArray` for mutable `FieldArray`
  • Loading branch information
N5N3 committed Aug 1, 2022
1 parent def8fc2 commit c6e2676
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
38 changes: 22 additions & 16 deletions src/FieldArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ consider defining `similar_type` as in the `FieldVector` example.
yyyy::Float64
end
"""
abstract type FieldArray{N, T, D} <: StaticArray{N, T, D} end
abstract type FieldArray{N,T,D} <: StaticArray{N,T,D} end

"""
abstract FieldMatrix{N1, N2, T} <: FieldArray{Tuple{N1, N2}, 2}
Expand Down Expand Up @@ -84,7 +84,7 @@ you may consider using the alternative
4.0 5.0 6.0;
7.0 8.0 9.0])
"""
abstract type FieldMatrix{N1, N2, T} <: FieldArray{Tuple{N1, N2}, T, 2} end
abstract type FieldMatrix{N1,N2,T} <: FieldArray{Tuple{N1,N2},T,2} end

"""
abstract FieldVector{N, T} <: FieldArray{Tuple{N}, 1}
Expand All @@ -108,11 +108,11 @@ array operations as in the example below.
StaticArrays.similar_type(::Type{<:Vec3D}, ::Type{T}, s::Size{(3,)}) where {T} = Vec3D{T}
"""
abstract type FieldVector{N, T} <: FieldArray{Tuple{N}, T, 1} end
abstract type FieldVector{N,T} <: FieldArray{Tuple{N},T,1} end

@inline (::Type{FA})(x::Tuple) where {FA <: FieldArray} = construct_type(FA, x)(x...)
@inline (::Type{FA})(x::Tuple) where {FA<:FieldArray} = construct_type(FA, x)(x...)

function construct_type(::Type{FA}, x) where {FA <: FieldArray}
function construct_type(::Type{FA}, x) where {FA<:FieldArray}
has_size(FA) || error("$FA has no static size!")
length_match_size(FA, x)
return adapt_eltype(FA, x)
Expand All @@ -125,14 +125,20 @@ Base.cconvert(::Type{<:Ptr}, a::FieldArray) = Base.RefValue(a)
Base.unsafe_convert(::Type{Ptr{T}}, m::Base.RefValue{FA}) where {N,T,D,FA<:FieldArray{N,T,D}} =
Ptr{T}(Base.unsafe_convert(Ptr{FA}, m))

# We can automatically preserve FieldArrays in array operations which do not
# change their eltype or Size. This should cover all non-parametric FieldArray,
# but for those which are parametric on the eltype the user will still need to
# overload similar_type themselves.
similar_type(::Type{A}, ::Type{T}, S::Size) where {N, T, A<:FieldArray{N, T}} =
_fieldarray_similar_type(A, T, S, Size(A))

# Extra layer of dispatch to match NewSize and OldSize
_fieldarray_similar_type(A, T, NewSize::S, OldSize::S) where {S} = A
_fieldarray_similar_type(A, T, NewSize, OldSize) =
default_similar_type(T, NewSize, length_val(NewSize))
# We can preserve FieldArrays in array operations which do not change their `Size` and `eltype`.
# FieldArrays with parametric `eltype` would be adapted to the new `eltype` automatically.
# Otherwise, we fallback to `S/MArray` based on it's mutability.
function similar_type(::Type{A}, ::Type{T}, S::Size) where {T,A<:FieldArray}
A′ = Base.typeintersect(base_type(A), StaticArray{Tuple{Tuple(S)...},T,length(S)})
isabstracttype(A′) || A′ === Union{} || return A′
if ismutabletype(A)
return mutable_similar_type(T, S, length_val(S))
else
return default_similar_type(T, S, length_val(S))
end
end

@pure base_type(@nospecialize(T::Type)) = Base.unwrap_unionall(T).name.wrapper
if VERSION < v"1.7"
@pure ismutabletype(@nospecialize(T::Type)) = Base.unwrap_unionall(T).mutable
end
5 changes: 2 additions & 3 deletions test/FieldMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
yy::T
end

StaticArrays.similar_type(::Type{<:Tensor2x2}, ::Type{T}, s::Size{(2,2)}) where {T} = Tensor2x2{T}
end)

p = Tensor2x2(0.0, 0.0, 0.0, 0.0)
Expand All @@ -83,8 +82,8 @@

@test @inferred(similar_type(Tensor2x2{Float64})) == Tensor2x2{Float64}
@test @inferred(similar_type(Tensor2x2{Float64}, Float32)) == Tensor2x2{Float32}
@test @inferred(similar_type(Tensor2x2{Float64}, Size(3,3))) == SMatrix{3,3,Float64,9}
@test @inferred(similar_type(Tensor2x2{Float64}, Float32, Size(4,4))) == SMatrix{4,4,Float32,16}
@test @inferred(similar_type(Tensor2x2{Float64}, Size(3, 3))) == MMatrix{3,3,Float64,9}
@test @inferred(similar_type(Tensor2x2{Float64}, Float32, Size(4, 4))) == MMatrix{4,4,Float32,16}

# eltype promotion
@test Tuple(@inferred(Tensor2x2(1., 2, 3, 4f0))) === (1.,2.,3.,4.)
Expand Down
7 changes: 3 additions & 4 deletions test/FieldVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
y::T
end

StaticArrays.similar_type(::Type{<:Point2D}, ::Type{T}, s::Size{(2,)}) where {T} = Point2D{T}
end)

p = Point2D(0.0, 0.0)
Expand All @@ -86,8 +85,8 @@

@test @inferred(similar_type(Point2D{Float64})) == Point2D{Float64}
@test @inferred(similar_type(Point2D{Float64}, Float32)) == Point2D{Float32}
@test @inferred(similar_type(Point2D{Float64}, Size(4))) == SVector{4,Float64}
@test @inferred(similar_type(Point2D{Float64}, Float32, Size(4))) == SVector{4,Float32}
@test @inferred(similar_type(Point2D{Float64}, Size(4))) == MVector{4,Float64}
@test @inferred(similar_type(Point2D{Float64}, Float32, Size(4))) == MVector{4,Float32}

# eltype promotion
@test Point2D(1f0, 2) isa Point2D{Float32}
Expand Down Expand Up @@ -122,7 +121,7 @@
# No similar_type defined - test fallback codepath
end)

@test @inferred(similar_type(FVT{Float64}, Float32)) == SVector{2,Float32} # Fallback code path
@test @inferred(similar_type(FVT{Float64}, Float32)) == FVT{Float32}
@test @inferred(similar_type(FVT{Float64}, Size(2))) == FVT{Float64}
@test @inferred(similar_type(FVT{Float64}, Size(3))) == SVector{3,Float64}
@test @inferred(similar_type(FVT{Float64}, Float32, Size(3))) == SVector{3,Float32}
Expand Down

0 comments on commit c6e2676

Please sign in to comment.