From 44b0fc4e9e4cfb70d8c9c24f578a4ee03f961d71 Mon Sep 17 00:00:00 2001 From: tiemvanderdeure Date: Tue, 15 Oct 2024 09:56:45 +0200 Subject: [PATCH] deal with type stability better --- .../RastersStatsBaseExt.jl | 2 +- ext/RastersStatsBaseExt/sample.jl | 72 ++++++++++--------- src/methods/extract.jl | 10 +-- 3 files changed, 41 insertions(+), 43 deletions(-) diff --git a/ext/RastersStatsBaseExt/RastersStatsBaseExt.jl b/ext/RastersStatsBaseExt/RastersStatsBaseExt.jl index d1f770dc..fd6f1e91 100644 --- a/ext/RastersStatsBaseExt/RastersStatsBaseExt.jl +++ b/ext/RastersStatsBaseExt/RastersStatsBaseExt.jl @@ -5,7 +5,7 @@ using StatsBase.Random const RA = Rasters -import Rasters: _True, _False, _booltype +import Rasters: _True, _False, _booltype, istrue import Rasters.DimensionalData as DD include("sample.jl") diff --git a/ext/RastersStatsBaseExt/sample.jl b/ext/RastersStatsBaseExt/sample.jl index 13b42172..bd71edf9 100644 --- a/ext/RastersStatsBaseExt/sample.jl +++ b/ext/RastersStatsBaseExt/sample.jl @@ -13,60 +13,62 @@ Rasters.sample(x::RA.RasterStackOrArray, n::Integer; kw...) = Rasters.sample(Ran ) na = DD._astuple(name) geometry, geometrytype, dims = _geometrytype(x, geometry) - # x = x isa RA.AbstractRasterStack ? x[na] : x - _sample(rng, x, n; - dims=DD.dims(x, RA.DEFAULT_POINT_ORDER), - names=NamedTuple{na}(na), + + _sample(rng, x, n, + dims, + NamedTuple{na}(na), + geometry, + geometrytype, # These keywords are converted to _True/_False for type stability later on - # The @inline above helps constant propagation of the Bools - geometry=geometry, - geometrytype=geometrytype, - index=_booltype(index), - skipmissing=_booltype(skipmissing), + _booltype(index), + _booltype(skipmissing), weights, + weightstype, replace, - ordered, - weightstype + ordered ) end function _sample( - rng, x, n; - dims, names::NamedTuple{K}, geometry, geometrytype, index, skipmissing, weights, replace, ordered, weightstype -) where K + rng, x, n, + dims, names::NamedTuple{K}, geometry, ::Type{G}, index, skipmissing, weights, weightstype, replace, ordered, +) where {K, G} indices = sample_indices(rng, x, n, skipmissing, weights, replace, ordered, weightstype) - points = DimPoints(dims) - T = RA._rowtype(x, geometrytype; geometry, index, skipmissing, skipinvalid = _True(), names) + T = RA._rowtype(x, G; geometry, index, skipmissing, skipinvalid = _True(), names) x2 = x isa AbstractRasterStack ? x[K] : RasterStack(NamedTuple{K}((x,))) - _getindices(T, x2, points, indices) + return _getindices(T, x2, dims, indices) end -_getindices(::Type{T}, x, points, indices) where T = - broadcast(I -> _getindex(T, x, points, I), indices) - -_getindex(::Type{T}, x::AbstractRasterStack{<:Any, NT}, points, idx) where {T, NT} = - RA._maybe_add_fields(T, NT(x[RA.commondims(idx, x)]), points[RA.commondims(idx, points)], val(idx)) +_getindices(::Type{T}, x, dims, indices) where {T} = + broadcast(I -> _getindex(T, x, dims, I), indices) -function sample_indices(rng, x, n, skipmissing::_False, weights::Nothing, replace, ordered, weightstype) - StatsBase.sample(rng, RA.DimIndices(x), n; replace, ordered) +function _getindex(::Type{T}, x::AbstractRasterStack{<:Any, NT}, dims, idx) where {T, NT} + RA._maybe_add_fields( + T, + NT(x[RA.commondims(idx, x)]), + DimPoints(dims)[RA.commondims(idx, dims)], + val(idx) + ) end -function sample_indices(rng, x, n, skipmissing::_True, weights::Nothing, replace, ordered, weightstype) - wts = weightstype(vec(boolmask(x))) - StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered) + +function sample_indices(rng, x, n, skipmissing, weights::Nothing, replace, ordered, _) + if istrue(skipmissing) + wts = StatsBase.Weights(vec(boolmask(x))) + StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered) + else + StatsBase.sample(rng, RA.DimIndices(x), n; replace, ordered) + end end -function sample_indices(rng, x, n, skipmissing::_False, weights::AbstractDimArray, replace, ordered, weightstype) - wts = if dims(weights) == dims(x) +function sample_indices(rng, x, n, skipmissing, weights::AbstractDimArray, replace, ordered, ::Type{W}) where W + wts = if istrue(skipmissing) + @d boolmask(x) .* weights + elseif dims(weights) == dims(x) weights else @d ones(eltype(weights), dims(x)) .* weights - end |> vec |> weightstype + end |> vec |> W StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered) end -function sample_indices(rng, x, n, skipmissing::_True, weights::AbstractDimArray, replace, ordered, weightstype) - wts = weightstype(vec(@d boolmask(x) .* weights)) - StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered) -end - function _geometrytype(x, geometry::Bool) if geometry error("Specify a geometry type by setting `geometry` to a Tuple or NamedTuple of Dimensions. E.g. `geometry = (X, Y)`") diff --git a/src/methods/extract.jl b/src/methods/extract.jl index b9b95470..ef41b04e 100644 --- a/src/methods/extract.jl +++ b/src/methods/extract.jl @@ -91,15 +91,13 @@ function _extract(A::RasterStackOrArray, ::Nothing, data; # We need to split out points from other geoms # TODO this will fail with mixed point/geom vectors if trait1 isa GI.PointTrait + rows = Vector{T}(undef, length(geoms)) if istrue(skipmissing) - T2 = _rowtype(A, eltype(geoms); - names, skipinvalid = _True(), skipmissing = _False(), kw...) - rows = Vector{T2}(undef, length(geoms)) j = 1 for i in eachindex(geoms) g = geoms[i] ismissing(g) && continue - e = _extract_point(T2, A, g, skipmissing; names, kw...) + e = _extract_point(T, A, g, skipmissing; names, kw...) if !ismissing(e) rows[j] = e j += 1 @@ -107,9 +105,7 @@ function _extract(A::RasterStackOrArray, ::Nothing, data; nothing end deleteat!(rows, j:length(rows)) - rows = T === T2 ? rows : T.(rows) else - rows = Vector{T}(undef, length(geoms)) for i in eachindex(geoms) g = geoms[i] rows[i] = _extract_point(T, A, g, skipmissing; names, kw...)::T @@ -281,7 +277,7 @@ Base.@assume_effects :total function _maybe_add_fields(::Type{T}, props::NamedTu :index in K ? merge((; geometry=point, index=I), props) : merge((; geometry=point), props) else :index in K ? merge((; index=I), props) : props - end + end |> T end @inline _skip_missing_rows(rows, ::Missing, names) =