Skip to content

Commit

Permalink
do not unnecessarily return a union in extract (#790)
Browse files Browse the repository at this point in the history
* extract tweaks

* just write kw... in docs

Co-authored-by: Rafael Schouten <[email protected]>

* use Type{G}

Co-authored-by: Rafael Schouten <[email protected]>

* add a missing where G

* merge _rowtype changes from other branch

* optimization for extract points with skipmissing

---------

Co-authored-by: Rafael Schouten <[email protected]>
  • Loading branch information
tiemvanderdeure and rafaqz authored Oct 12, 2024
1 parent b677b05 commit 8cf4942
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 80 deletions.
106 changes: 29 additions & 77 deletions src/methods/extract.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
using DimensionalData.Lookups: _True, _False

_booltype(x) = x ? _True() : _False()
istrue(::_True) = true
istrue(::_False) = false

"""
extract(x, data; atol)
extract(x, data; kw...)
Extracts the value of `Raster` or `RasterStack` at given points, returning
an iterable of `NamedTuple` with properties for `:geometry` and raster or
Expand Down Expand Up @@ -78,7 +72,7 @@ function extract end
end

function _extract(A::RasterStackOrArray, geom::Missing, names, kw...)
T = _rowtype(A, geom; names, kw...)
T = _extractrowtype(A, geom; names, kw...)
[_maybe_add_fields(T, map(_ -> missing, names), missing, missing)]
end
function _extract(A::RasterStackOrArray, geom; names, kw...)
Expand All @@ -89,9 +83,9 @@ function _extract(A::RasterStackOrArray, ::Nothing, data;
)
geoms = _get_geometries(data, geometrycolumn)
T = if istrue(skipmissing)
_rowtype(A, nonmissingtype(eltype(geoms)); names, skipmissing, kw...)
_extractrowtype(A, nonmissingtype(eltype(geoms)); names, skipmissing, kw...)
else
_rowtype(A, eltype(geoms); names, skipmissing, kw...)
_extractrowtype(A, eltype(geoms); names, skipmissing, kw...)
end
# Handle empty / all missing cases
(length(geoms) > 0 && any(!ismissing, geoms)) || return T[]
Expand All @@ -101,21 +95,25 @@ function _extract(A::RasterStackOrArray, ::Nothing, data;
# We need to split out points from other geoms
# TODO this will fail with mixed point/geom vectors
if trait1 isa GI.PointTrait
rows = Vector{T}(undef, length(geoms))
if istrue(skipmissing)
T2 = _extractrowtype(A, eltype(geoms), Tuple{Int64, Int64};
names, skipmissing = _False(), kw...)
rows = Vector{T2}(undef, length(geoms))
j = 1
for i in eachindex(geoms)
g = geoms[i]
ismissing(g) && continue
e = _extract_point(T, A, g, skipmissing; names, kw...)
e = _extract_point(T2, A, g, skipmissing; names, kw...)
if !ismissing(e)
rows[j] = e
j += 1
end
nothing
end
deleteat!(rows, j:length(rows))
rows = T === T2 ? rows : T.(rows)
else
rows = Vector{T}(undef, length(geoms))
for i in eachindex(geoms)
g = geoms[i]
rows[i] = _extract_point(T, A, g, skipmissing; names, kw...)::T
Expand Down Expand Up @@ -143,14 +141,14 @@ end
function _extract(A::RasterStackOrArray, ::GI.AbstractMultiPointTrait, geom;
skipmissing, kw...
)
T = _rowtype(A, GI.getpoint(geom, 1); names, skipmissing, kw...)
T = _extractrowtype(A, GI.getpoint(geom, 1); names, skipmissing, kw...)
rows = (_extract_point(T, A, p, skipmissing; kw...) for p in GI.getpoint(geom))
return skipmissing isa _True ? collect(_skip_missing_rows(rows, _missingval_or_missing(A), names)) : collect(rows)
end
function _extract(A::RasterStackOrArray, ::GI.PointTrait, geom;
skipmissing, kw...
)
T = _rowtype(A, geom; names, skipmissing, kw...)
T = _extractrowtype(A, geom; names, skipmissing, kw...)
_extract_point(T, A, geom, skipmissing; kw...)
end
function _extract(A::RasterStackOrArray, t::GI.AbstractGeometryTrait, geom;
Expand All @@ -168,7 +166,7 @@ function _extract(A::RasterStackOrArray, t::GI.AbstractGeometryTrait, geom;
else
GI.x(p), GI.y(p)
end
T = _rowtype(A, tuplepoint; names, skipmissing, kw...)
T = _extractrowtype(A, tuplepoint; names, skipmissing, kw...)
B = boolmask(geom; to=template, kw...)
offset = CartesianIndex(map(x -> first(x) - 1, parentindices(parent(template))))
# Add a row for each pixel that is `true` in the mask
Expand Down Expand Up @@ -290,70 +288,24 @@ Base.@assume_effects :total function _maybe_add_fields(::Type{T}, props::NamedTu
end
end

_names(A::AbstractRaster) = (Symbol(name(A)),)
_names(A::AbstractRasterStack) = keys(A)

@inline _nametypes(::Raster{T}, ::NamedTuple{Names}, skipmissing::_True) where {T,Names} = (nonmissingtype(T),)
@inline _nametypes(::Raster{T}, ::NamedTuple{Names}, skipmissing::_False) where {T,Names} = (Union{Missing,T},)
# This only compiles away when generated
@generated function _nametypes(
::RasterStack{<:Any,T}, ::NamedTuple{PropNames}, skipmissing::_True
) where {T<:NamedTuple{StackNames,Types},PropNames} where {StackNames,Types}
nt = NamedTuple{StackNames}(map(nonmissingtype, Types.parameters))
return values(nt[PropNames])
end
@generated function _nametypes(
::RasterStack{<:Any,T}, ::NamedTuple{PropNames}, skipmissing::_False
) where {T<:NamedTuple{StackNames,Types},PropNames} where {StackNames,Types}
nt = NamedTuple{StackNames}(map(T -> Union{Missing,T}, Types.parameters))
return values(nt[PropNames])
end

# _rowtype returns the complete NamedTuple type for a point row
# This code is entirely for types stability and performance.
_rowtype(x, g; kw...) = _rowtype(x, typeof(g); kw...)
_rowtype(x, g::Type; geometry, index, skipmissing, names, kw...) =
_rowtype(x, g, geometry, index, skipmissing, names)
function _rowtype(x, ::Type{G}, geometry::_True, index::_True, skipmissing::_True, names::NamedTuple{Names}) where {G,Names}
keys = (:geometry, :index, Names...,)
types = Tuple{G,Tuple{Int,Int},_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
end
function _rowtype(x, ::Type{G}, geometry::_True, index::_False, skipmissing::_True, names::NamedTuple{Names}) where {G,Names}
keys = (:geometry, Names...,)
types = Tuple{G,_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
end
function _rowtype(x, ::Type{G}, geometry::_False, index::_True, skipmissing::_True, names::NamedTuple{Names}) where {G,Names}
keys = (:index, Names...,)
types = Tuple{Tuple{Int,Int},_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
end
function _rowtype(x, ::Type{G}, geometry::_False, index::_False, skipmissing::_True, names::NamedTuple{Names}) where {G,Names}
keys = Names
types = Tuple{_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
end
function _rowtype(x, ::Type{G}, geometry::_True, index::_True, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
keys = (:geometry, :index, names...,)
types = Tuple{Union{Missing,G},Union{Missing,Tuple{Int,Int}},_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
end
function _rowtype(x, ::Type{G}, geometry::_True, index::_False, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
keys = (:geometry, Names...,)
types = Tuple{Union{Missing,G},_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
end
function _rowtype(x, ::Type{G}, geometry::_False, index::_True, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
keys = (:index, Names...,)
types = Tuple{Union{Missing,Tuple{Int,Int}},_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
# Entry point for computing the row type used by `extract` when no index type
# has been chosen yet. The index column type is picked here: a plain
# `Tuple{Int, Int}` when `skipmissing` is true, otherwise the same tuple type
# widened with `Missing`. Everything else is forwarded unchanged.
function _extractrowtype(x, g; geometry, index, skipmissing, names, kw...)
    I = istrue(skipmissing) ? Tuple{Int, Int} : Union{Missing, Tuple{Int, Int}}
    return _extractrowtype(x, g, I; geometry, index, skipmissing, names, kw...)
end
function _rowtype(x, ::Type{G}, geometry::_False, index::_False, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
keys = Names
types = Tuple{_nametypes(x, names, skipmissing)...}
NamedTuple{keys,types}
# Variant that receives a geometry value `g` together with an explicit index
# type `I`. The geometry column type is `typeof(g)`, passed through
# `nonmissingtype` when `skipmissing` is true, and the call is then repeated
# with that type in place of the value.
function _extractrowtype(x, g, ::Type{I}; geometry, index, skipmissing, names, kw...) where I
    G = istrue(skipmissing) ? nonmissingtype(typeof(g)) : typeof(g)
    return _extractrowtype(x, G, I; geometry, index, skipmissing, names, kw...)
end
# Terminal method: both the geometry type `G` and the index type `I` are known,
# so the row type is built by `_rowtype`. Extra keywords collected in `kw` are
# accepted here but are not forwarded.
function _extractrowtype(
    x, ::Type{G}, ::Type{I}; geometry, index, skipmissing, names, kw...
) where {G, I}
    return _rowtype(x, G, I; geometry, index, skipmissing, names)
end

# Lazily drop every row that contains at least one `missing` value.
# This method applies when the second argument is `missing`; `names` is
# accepted for signature compatibility and is not used here.
@inline function _skip_missing_rows(rows, ::Missing, names)
    return Iterators.filter(row -> all(!ismissing, row), rows)
end
Expand Down
60 changes: 60 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,63 @@ function _no_memory_error(f, bytes)
"""
return error(msg)
end


# `_rowtype` returns the complete `NamedTuple` type for a point row.
# This code exists entirely for type stability and performance.
# It is used in `extract` and `Rasters.sample`.
# Names of the value columns contributed by `A`:
# a one-element tuple holding the raster's name as a `Symbol`, or the keys of a stack.
function _names(A::AbstractRaster)
    return (Symbol(name(A)),)
end
function _names(A::AbstractRasterStack)
    return keys(A)
end

using DimensionalData.Lookups: _True, _False
# Turn a runtime `Bool` into a `_True()` / `_False()` marker object so that
# other methods can select behaviour through dispatch on the marker type.
function _booltype(x)
    if x
        return _True()
    else
        return _False()
    end
end

# Recover a plain `Bool` from a `_True` / `_False` marker.
istrue(::_False) = false
istrue(::_True) = true

# Assemble the concrete `NamedTuple` type of one output row.
# `G` is the geometry column type and `I` the index column type; the `geometry`
# and `index` flags decide whether those columns are present at all. Column names
# come from `_rowkeys` and column types from `_rowtypes`.
function _rowtype(x, ::Type{G}, ::Type{I}; geometry, index, skipmissing, names) where {G, I}
    colnames = _rowkeys(geometry, index, names)
    coltypes = _rowtypes(x, G, I, geometry, index, skipmissing, names)
    return NamedTuple{colnames,coltypes}
end

# Column types of a row as a `Tuple` type, in the order geometry, index, values.
# The fourth and fifth positional arguments are the `geometry` and `index` flags;
# a `_True` flag includes the corresponding type (`G` or `I`), a `_False` flag
# omits it. The value column types always come from `_nametypes`.

# geometry and index both present
function _rowtypes(
    x, ::Type{G}, ::Type{I}, ::_True, ::_True, skipmissing, names::NamedTuple{Names}
) where {G,I,Names}
    valuetypes = _nametypes(x, names, skipmissing)
    return Tuple{G,I,valuetypes...}
end
# geometry only
function _rowtypes(
    x, ::Type{G}, ::Type{I}, ::_True, ::_False, skipmissing, names::NamedTuple{Names}
) where {G,I,Names}
    valuetypes = _nametypes(x, names, skipmissing)
    return Tuple{G,valuetypes...}
end
# index only
function _rowtypes(
    x, ::Type{G}, ::Type{I}, ::_False, ::_True, skipmissing, names::NamedTuple{Names}
) where {G,I,Names}
    valuetypes = _nametypes(x, names, skipmissing)
    return Tuple{I,valuetypes...}
end
# value columns only
function _rowtypes(
    x, ::Type{G}, ::Type{I}, ::_False, ::_False, skipmissing, names::NamedTuple{Names}
) where {G,I,Names}
    valuetypes = _nametypes(x, names, skipmissing)
    return Tuple{valuetypes...}
end

# Tuple of value-column types for the requested names.
#
# For a single `Raster{T}` there is exactly one column: `T` with `Missing`
# removed when `skipmissing` is `_True`, or `T` widened with `Missing` when it
# is `_False`.
@inline function _nametypes(::Raster{T}, ::NamedTuple{Names}, ::_True) where {T,Names}
    return (nonmissingtype(T),)
end
@inline function _nametypes(::Raster{T}, ::NamedTuple{Names}, ::_False) where {T,Names}
    return (Union{Missing,T},)
end

# For a `RasterStack` the per-name types are read from the stack's `NamedTuple`
# type parameter, adjusted for `Missing`, and then selected/ordered by
# `PropNames`. These methods are `@generated` because, per the original note,
# the lookup only compiles away when generated.
@generated function _nametypes(
    ::RasterStack{<:Any,T}, ::NamedTuple{PropNames}, skipmissing::_True
) where {T<:NamedTuple{StackNames,Types},PropNames} where {StackNames,Types}
    stripped = map(nonmissingtype, Types.parameters)
    byname = NamedTuple{StackNames}(stripped)
    return values(byname[PropNames])
end
@generated function _nametypes(
    ::RasterStack{<:Any,T}, ::NamedTuple{PropNames}, skipmissing::_False
) where {T<:NamedTuple{StackNames,Types},PropNames} where {StackNames,Types}
    widened = map(S -> Union{Missing,S}, Types.parameters)
    byname = NamedTuple{StackNames}(widened)
    return values(byname[PropNames])
end

# Column names of a row as a tuple of `Symbol`s: `:geometry` and/or `:index`
# when the matching flag is `_True`, followed by the names carried in the type
# of the `names` NamedTuple.
function _rowkeys(::_True, ::_True, ::NamedTuple{Names}) where Names
    return (:geometry, :index, Names...)
end
function _rowkeys(::_True, ::_False, ::NamedTuple{Names}) where Names
    return (:geometry, Names...)
end
function _rowkeys(::_False, ::_True, ::NamedTuple{Names}) where Names
    return (:index, Names...)
end
function _rowkeys(::_False, ::_False, ::NamedTuple{Names}) where Names
    return Names
end
6 changes: 3 additions & 3 deletions test/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ createpoint(args...) = ArchGDAL.createpoint(args...)
(index = (1, 2), test = 2,)
])
# NamedTuple (reversed) points - tests a Table that iterates over points
T = @NamedTuple{geometry::Union{Missing,@NamedTuple{Y::Float64,X::Float64}},test::Union{Missing,Int64}}
T = @NamedTuple{geometry::Union{@NamedTuple{Y::Float64,X::Float64}},test::Union{Missing,Int64}}
@test all(extract(rast, [(Y=0.1, X=9.0), (Y=0.2, X=10.0), (Y=0.3, X=10.0)]) .=== T[
(geometry = (Y = 0.1, X = 9.0), test = 1)
(geometry = (Y = 0.2, X = 10.0), test = 4)
Expand All @@ -412,7 +412,7 @@ createpoint(args...) = ArchGDAL.createpoint(args...)
])
# Extract a polygon
p = ArchGDAL.createpolygon([[[8.0, 0.0], [11.0, 0.0], [11.0, 0.4], [8.0, 0.0]]])
T = @NamedTuple{geometry::Union{Missing,Tuple{Float64,Float64}},test::Union{Missing,Int64}}
T = @NamedTuple{geometry::Union{Tuple{Float64,Float64}},test::Union{Missing,Int64}}
@test all(extract(rast_m, p) .=== T[
(geometry = (9.0, 0.1), test = 1)
(geometry = (10.0, 0.1), test = 3)
Expand Down Expand Up @@ -447,7 +447,7 @@ createpoint(args...) = ArchGDAL.createpoint(args...)
(index = (2, 1), test = 3)
(index = (2, 2), test = missing)
])
T = @NamedTuple{geometry::Union{Missing,Tuple{Float64,Float64}},index::Union{Missing,Tuple{Int,Int}},test::Union{Missing,Int64}}
T = @NamedTuple{geometry::Union{Tuple{Float64,Float64}},index::Union{Missing,Tuple{Int,Int}},test::Union{Missing,Int64}}
@test all(extract(rast_m, p; index=true) .=== T[
(geometry = (9.0, 0.1), index = (1, 1), test = 1)
(geometry = (10.0, 0.1), index = (2, 1), test = 3)
Expand Down

0 comments on commit 8cf4942

Please sign in to comment.