Skip to content

Commit

Permalink
deal with type stability better
Browse files Browse the repository at this point in the history
  • Loading branch information
tiemvanderdeure committed Oct 15, 2024
1 parent e18a67d commit 44b0fc4
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 43 deletions.
2 changes: 1 addition & 1 deletion ext/RastersStatsBaseExt/RastersStatsBaseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using StatsBase.Random

const RA = Rasters

import Rasters: _True, _False, _booltype
import Rasters: _True, _False, _booltype, istrue
import Rasters.DimensionalData as DD

include("sample.jl")
Expand Down
72 changes: 37 additions & 35 deletions ext/RastersStatsBaseExt/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,60 +13,62 @@ Rasters.sample(x::RA.RasterStackOrArray, n::Integer; kw...) = Rasters.sample(Ran
)
na = DD._astuple(name)
geometry, geometrytype, dims = _geometrytype(x, geometry)
# x = x isa RA.AbstractRasterStack ? x[na] : x
_sample(rng, x, n;
dims=DD.dims(x, RA.DEFAULT_POINT_ORDER),
names=NamedTuple{na}(na),

_sample(rng, x, n,
dims,
NamedTuple{na}(na),
geometry,
geometrytype,
# These keywords are converted to _True/_False for type stability later on
# The @inline above helps constant propagation of the Bools
geometry=geometry,
geometrytype=geometrytype,
index=_booltype(index),
skipmissing=_booltype(skipmissing),
_booltype(index),
_booltype(skipmissing),
weights,
weightstype,
replace,
ordered,
weightstype
ordered
)

end
function _sample(
rng, x, n;
dims, names::NamedTuple{K}, geometry, geometrytype, index, skipmissing, weights, replace, ordered, weightstype
) where K
rng, x, n,
dims, names::NamedTuple{K}, geometry, ::Type{G}, index, skipmissing, weights, weightstype, replace, ordered,
) where {K, G}
indices = sample_indices(rng, x, n, skipmissing, weights, replace, ordered, weightstype)
points = DimPoints(dims)
T = RA._rowtype(x, geometrytype; geometry, index, skipmissing, skipinvalid = _True(), names)
T = RA._rowtype(x, G; geometry, index, skipmissing, skipinvalid = _True(), names)
x2 = x isa AbstractRasterStack ? x[K] : RasterStack(NamedTuple{K}((x,)))
_getindices(T, x2, points, indices)
return _getindices(T, x2, dims, indices)
end

_getindices(::Type{T}, x, points, indices) where T =
broadcast(I -> _getindex(T, x, points, I), indices)

_getindex(::Type{T}, x::AbstractRasterStack{<:Any, NT}, points, idx) where {T, NT} =
RA._maybe_add_fields(T, NT(x[RA.commondims(idx, x)]), points[RA.commondims(idx, points)], val(idx))
_getindices(::Type{T}, x, dims, indices) where {T} =
broadcast(I -> _getindex(T, x, dims, I), indices)

function sample_indices(rng, x, n, skipmissing::_False, weights::Nothing, replace, ordered, weightstype)
StatsBase.sample(rng, RA.DimIndices(x), n; replace, ordered)
function _getindex(::Type{T}, x::AbstractRasterStack{<:Any, NT}, dims, idx) where {T, NT}
RA._maybe_add_fields(
T,
NT(x[RA.commondims(idx, x)]),
DimPoints(dims)[RA.commondims(idx, dims)],
val(idx)
)
end
function sample_indices(rng, x, n, skipmissing::_True, weights::Nothing, replace, ordered, weightstype)
wts = weightstype(vec(boolmask(x)))
StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered)

function sample_indices(rng, x, n, skipmissing, weights::Nothing, replace, ordered, _)
if istrue(skipmissing)
wts = StatsBase.Weights(vec(boolmask(x)))
StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered)
else
StatsBase.sample(rng, RA.DimIndices(x), n; replace, ordered)
end
end
function sample_indices(rng, x, n, skipmissing::_False, weights::AbstractDimArray, replace, ordered, weightstype)
wts = if dims(weights) == dims(x)
function sample_indices(rng, x, n, skipmissing, weights::AbstractDimArray, replace, ordered, ::Type{W}) where W
wts = if istrue(skipmissing)
@d boolmask(x) .* weights
elseif dims(weights) == dims(x)
weights
else
@d ones(eltype(weights), dims(x)) .* weights
end |> vec |> weightstype
end |> vec |> W
StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered)
end
function sample_indices(rng, x, n, skipmissing::_True, weights::AbstractDimArray, replace, ordered, weightstype)
wts = weightstype(vec(@d boolmask(x) .* weights))
StatsBase.sample(rng, RA.DimIndices(x), wts, n; replace, ordered)
end

function _geometrytype(x, geometry::Bool)
if geometry
error("Specify a geometry type by setting `geometry` to a Tuple or NamedTuple of Dimensions. E.g. `geometry = (X, Y)`")
Expand Down
10 changes: 3 additions & 7 deletions src/methods/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,25 +91,21 @@ function _extract(A::RasterStackOrArray, ::Nothing, data;
# We need to split out points from other geoms
# TODO this will fail with mixed point/geom vectors
if trait1 isa GI.PointTrait
rows = Vector{T}(undef, length(geoms))
if istrue(skipmissing)
T2 = _rowtype(A, eltype(geoms);
names, skipinvalid = _True(), skipmissing = _False(), kw...)
rows = Vector{T2}(undef, length(geoms))
j = 1
for i in eachindex(geoms)
g = geoms[i]
ismissing(g) && continue
e = _extract_point(T2, A, g, skipmissing; names, kw...)
e = _extract_point(T, A, g, skipmissing; names, kw...)
if !ismissing(e)
rows[j] = e
j += 1
end
nothing
end
deleteat!(rows, j:length(rows))
rows = T === T2 ? rows : T.(rows)
else
rows = Vector{T}(undef, length(geoms))
for i in eachindex(geoms)
g = geoms[i]
rows[i] = _extract_point(T, A, g, skipmissing; names, kw...)::T
Expand Down Expand Up @@ -281,7 +277,7 @@ Base.@assume_effects :total function _maybe_add_fields(::Type{T}, props::NamedTu
:index in K ? merge((; geometry=point, index=I), props) : merge((; geometry=point), props)
else
:index in K ? merge((; index=I), props) : props
end
end |> T
end

@inline _skip_missing_rows(rows, ::Missing, names) =
Expand Down

0 comments on commit 44b0fc4

Please sign in to comment.