Skip to content

Commit

Permalink
add rasters.sample
Browse files Browse the repository at this point in the history
  • Loading branch information
tiemvanderdeure committed Oct 11, 2024
1 parent 184cda2 commit 04a5d2e
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 2 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Proj = "c94c279d-25a6-4763-9509-64d165bea63e"
RasterDataSources = "3cb90ccd-e1b6-4867-9617-4276c8b2ca36"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ZarrDatasets = "519a4cdf-1362-424a-9ea1-b1d782dbb24b"

[extensions]
Expand All @@ -44,6 +45,7 @@ RastersMakieExt = "Makie"
RastersNCDatasetsExt = "NCDatasets"
RastersProjExt = "Proj"
RastersRasterDataSourcesExt = "RasterDataSources"
RastersStatsBaseExt = "StatsBase"
RastersZarrDatasetsExt = "ZarrDatasets"

[compat]
Expand All @@ -70,13 +72,14 @@ NCDatasets = "0.13, 0.14"
OffsetArrays = "1"
ProgressMeter = "1"
Proj = "1.7.2"
RasterDataSources = "0.5.7, 0.6"
RasterDataSources = "0.7"
RecipesBase = "0.7, 0.8, 1.0"
Reexport = "0.2, 1.0"
SafeTestsets = "0.1"
Setfield = "0.6, 0.7, 0.8, 1"
Shapefile = "0.10, 0.11"
Statistics = "1"
StatsBase = "0.34"
Test = "1"
ZarrDatasets = "0.1"
julia = "1.10"
Expand All @@ -97,8 +100,10 @@ Proj = "c94c279d-25a6-4763-9509-64d165bea63e"
RasterDataSources = "3cb90ccd-e1b6-4867-9617-4276c8b2ca36"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Shapefile = "8e980c4a-a4fe-5da2-b3a7-4b4b0353a2f4"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "ArchGDAL", "CFTime", "CoordinateTransformations", "DataFrames", "GeoDataFrames", "GeometryBasics", "GRIBDatasets", "NCDatasets", "Plots", "Proj", "RasterDataSources", "SafeTestsets", "Shapefile", "Statistics", "Test", "ZarrDatasets"]
test = ["Aqua", "ArchGDAL", "CFTime", "CoordinateTransformations", "DataFrames", "GeoDataFrames", "GeometryBasics", "GRIBDatasets", "NCDatasets", "Plots", "Proj", "RasterDataSources", "SafeTestsets", "Shapefile", "StableRNGs", "Statistics", "StatsBase", "Test", "ZarrDatasets"]
17 changes: 17 additions & 0 deletions ext/RastersStatsBaseExt/RastersStatsBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module RastersStatsBaseExt

@static if isdefined(Base, :get_extension) # julia < 1.9
using Rasters, StatsBase
else
using ..Rasters, ..StatsBase
end
using StatsBase.Random

const RA = Rasters

import Rasters: _True, _False, _booltype
import Rasters.DimensionalData as DD

include("sample.jl")

end # Module
92 changes: 92 additions & 0 deletions ext/RastersStatsBaseExt/sample.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
Rasters.sample(x::RA.RasterStackOrArray, n::Integer; kw...) = Rasters.sample(Random.GLOBAL_RNG, x, n; kw...)
@inline function Rasters.sample(
rng::Random.AbstractRNG, x::RA.RasterStackOrArray, n::Integer;
geometry = true, index = false, names=RA._names(x), name=names, skipmissing = false,
replace = true, ordered = false, weights = nothing, weightstype::Type{<:StatsBase.AbstractWeights} = StatsBase.Weights
)
na = DD._astuple(name)
_sample(rng, x, n;
dims=DD.dims(x, RA.DEFAULT_POINT_ORDER),
names=NamedTuple{na}(na),
# These keywords are converted to _True/_False for type stability later on
# The @inline above helps constant propagation of the Bools
geometry=_booltype(geometry),
index=_booltype(index),
skipmissing=_booltype(skipmissing),
weights,
replace,
ordered,
weightstype
)

end
function _sample(rng, x, n; dims, names::NamedTuple{K}, geometry, index, skipmissing, weights, replace, ordered, weightstype) where K
indices = sample_indices(rng, x, n, skipmissing, weights, replace, ordered, weightstype)
tuplepoint = map(first, dims)
T = _srowtype(x, tuplepoint; geometry, index, skipmissing, names)
rows = Vector{T}(undef, n)
points = DimPoints(dims)

for i in 1:n
idx = indices[i]
props = if x isa Raster
NamedTuple{K,Tuple{eltype(x)}}((x[idx],))
else
NamedTuple(x[idx])[K]
end
point = geometry isa _True ? points[idx] : nothing
rows[i] = RA._maybe_add_fields(T, props, point, idx)
end

return rows
end

function sample_indices(rng, x, n, skipmissing::_False, weights::Nothing, replace, ordered, weightstype)
StatsBase.sample(rng, CartesianIndices(x), n; replace, ordered)
end
function sample_indices(rng, x, n, skipmissing::_True, weights::Nothing, replace, ordered, weightstype)
wts = weightstype(vec(boolmask(x)))
StatsBase.sample(rng, CartesianIndices(x), wts, n; replace, ordered)
end
function sample_indices(rng, x, n, skipmissing::_False, weights::AbstractDimArray, replace, ordered, weightstype)
wts = if dims(weights) == dims(x)
weights
else
@d ones(eltype(weights), dims(x)) .* weights
end |> vec |> weightstype
StatsBase.sample(rng, CartesianIndices(x), wts, n; replace, ordered)
end
function sample_indices(rng, x, n, skipmissing::_True, weights::AbstractDimArray, replace, ordered, weightstype)
wts = weightstype(vec(@d boolmask(x) .* weights))
StatsBase.sample(rng, CartesianIndices(x), wts, n; replace, ordered)
end

# Determine the row type, making use of some of extract machinery
_srowtype(x, g; kw...) = _srowtype(x, typeof(g); kw...)
function _srowtype(x, g::Type; geometry, index, skipmissing, names, kw...)
keys = RA._rowkeys(geometry, index, names)
types = _srowtypes(x, g, geometry, index, skipmissing, names)
NamedTuple{keys,types}
end
function _srowtypes(x, ::Type{G}, geometry::_True, index::_True, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
Tuple{G,Tuple{Int,Int},_nametypes(x, names)...}
end
function _srowtypes(x, ::Type{G}, geometry::_False, index::_True, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
Tuple{Tuple{Int,Int},_nametypes(x, names)...}
end
function _srowtypes(x, ::Type{G}, geometry::_True, index::_False, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
Tuple{G,_nametypes(x, names)...}
end
function _srowtypes(x, ::Type{G}, geometry::_False, index::_False, skipmissing::_False, names::NamedTuple{Names}) where {G,Names}
Tuple{_nametypes(x, names)...}
end
# fallback
_srowtypes(x, T, geometry, index, skipmissing::_True, names) = RA._rowtypes(x, T, geometry, index, skipmissing, names)
# adapted from extract code
@inline _nametypes(::Raster{T}, ::NamedTuple{Names}) where {T,Names} = (T,)
function _nametypes(::RasterStack{<:Any,T}, ::NamedTuple{PropNames}) where
{T<:NamedTuple{StackNames,Types},PropNames} where {StackNames,Types}
nt = NamedTuple{StackNames}(Types.parameters)
return values(nt[PropNames])
end

42 changes: 42 additions & 0 deletions src/extensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,48 @@ Note that cellarea returns the area in square m, while cellsize still uses squar
return cellarea(args...; kw..., radius = 6371.0088)
end

"""
Rasters.sample([rng], x, n; [geometry, index, name, skipmissing, replace, ordered, weights])
Sample from a `Raster` or `RasterStack` with additional options that match those provided by [`extract`](@ref)
Run `using StatsBase` to make this method available.
Note that this function is not exported to avoid confusion with StatsBase.sample
# Keywords
- `geometry`: include `:geometry` in returned `NamedTuple`, `true` by default.
- `index`: include `:index` of the `CartesianIndex` in returned `NamedTuple`, `false` by default.
- `name`: a `Symbol` or `Tuple` of `Symbol` corresponding to layer/s of a `RasterStack` to extract. All layers by default.
- `skipmissing`: skip missing points automatically.
- `replace`: sample with replacement, `true` by default.
- `ordered`: sample in order, `false` by default.
- `weights`: A DimArray that matches one or more of the dimensions of `x` with weights for sampling.
- `weightstype`: a `StatsBase.AbstractWeights`. Defaults to `StatsBase.Weights`.
# Example
This code draws 10 samples from the raster `myraster`, with weights adjusted by cell size.
```julia
using Rasters, Rasters.Lookups, Proj, StatsBase
xdim = X(Projected(90.0:10.0:120; sampling=Intervals(Start()), crs=EPSG(4326)))
ydim = Y(Projected(0.0:10.0:50; sampling=Intervals(Start()), crs=EPSG(4326)))
myraster = rand(xdim, ydim)
Rasters.sample(myraster, 5; weights = cellarea(myraster))
# output
5-element Vector{@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}}:
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 10.0), 0.7360504790189618))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 30.0), 0.5447657183842469))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 30.0), 0.5447657183842469))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((90.0, 10.0), 0.7360504790189618))
@NamedTuple{geometry::Tuple{Float64, Float64}, ::Union{Missing, Float64}}(((110.0, 10.0), 0.5291143028176258))
```
"""
sample(args...; kw...) = throw_extension_error(sample, "StatsBase", :RastersStatsBaseExt, args)



# Other shared stubs
function layerkeys end
function smapseries end
Expand Down
81 changes: 81 additions & 0 deletions test/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,84 @@ end
0.0 0.0 missing
missing missing missing], dims=3))
end

using StableRNGs, StatsBase
test = rebuild(ga; name = :test)
@testset "sample" begin
# test that all keywords work and return the same thing as extract
@test all(Rasters.sample(StableRNG(123), test, 2) .=== extract(test, [(2.0,2.0), (1.0,2.0)]))
@test all(Rasters.sample(StableRNG(123), st2, 2) .=== extract(st2, [(2,2), (1,2)]))
@test all(Rasters.sample(StableRNG(123), test, 2; geometry = false) .=== extract(test, [(2.0,2.0), (1.0,2.0)]; geometry = false))
@test all(Rasters.sample(StableRNG(123), st2, 2; geometry = false) .=== extract(st2, [(2,2), (1,2)]; geometry = false))

@test all(Rasters.sample(StableRNG(123), test, 2, skipmissing = true) .=== extract(test, [(2.0,1.0), (2.0,1.0)], skipmissing = true))
@test all(Rasters.sample(StableRNG(123), st2, 2, skipmissing = true) .=== extract(st2, [(2,1), (2,1)], skipmissing = true))

@test all(
Rasters.sample(StableRNG(123), test, 2, skipmissing = true, index = true) .===
extract(test, [(2.0,1.0), (2.0,1.0)], skipmissing = true, index = true))
@test all(
Rasters.sample(StableRNG(123), st2, 2, skipmissing = true, index = true) .===
extract(st2, [(2,1), (2,1)], skipmissing = true, index = true))

kws = [(:geometry => false,), (:index => true,), ()]
for kw in kws
@test all(
Rasters.sample(StableRNG(123), test, 2; skipmissing = true, kw...) .===
extract(test, [(2.0,1.0), (2.0,1.0)]; skipmissing = true, kw...)
)

@test all(
Rasters.sample(StableRNG(123), st2, 2; skipmissing = true, kw...) .==
extract(st2, [(2,1), (2,1)]; skipmissing = true, kw...)
)
end
@test all(Rasters.sample(StableRNG(123), st2, 2, name = (:a,)) .=== extract(st2, [(2,2), (1,2)], name = (:a,)))

# in this case extract and sample always return different types
@test eltype(Rasters.sample(StableRNG(123), test, 2, index = true)) != eltype(extract(test, [(2.0,1.0), (2.0,1.0)], index = true))
@test eltype(Rasters.sample(StableRNG(123), st2, 2, index = true)) != eltype(extract(st2, [(2,1), (2,1)], index = true))

@test all(
Rasters.sample(StableRNG(123), test, 2, weights = DimArray([1,1000], X(1:2)), skipmissing = true) .===
[
(geometry = (2.0,1.0), test = 2.0f0)
(geometry = (2.0,1.0), test = 2.0f0)
]
)

@test all(
Rasters.sample(StableRNG(123), test, 2, weights = DimArray([1,1000], X(1:2)), skipmissing = true, weightstype = StatsBase.FrequencyWeights) .===
[
(geometry = (2.0,1.0), test = 2.0f0)
(geometry = (2.0,1.0), test = 2.0f0)
]
)

@test all(
Rasters.sample(StableRNG(123), test, 2, skipmissing = true, replace = false) .===
[
(geometry = (2.0,1.0), test = 2.0f0)
(geometry = (1.0,2.0), test = 7.0f0)
]
)
@test all(
Rasters.sample(StableRNG(123), test, 2, skipmissing = true, replace = false, ordered = true) .===
[
(geometry = (2.0,1.0), test = 2.0f0)
(geometry = (1.0,2.0), test = 7.0f0)
]
)
@test_throws "strictly positive" Rasters.sample(StableRNG(123), test, 3, skipmissing = true, replace = false)
@test_throws "Cannot draw" Rasters.sample(StableRNG(123), test, 5, replace = false)
end

eltype(test)
test2 = replace_missing(test, 0.0)



Rasters._nametypes(test2, (test = :test,), Rasters._False())
typeof((a = 1,))
concretent((a = 1,))
concretent(nt::T) where T <: NamedTuple{StackNames,Types} where {StackNames,Types} = NamedTuple{StackNames}(Types.parameters)

0 comments on commit 04a5d2e

Please sign in to comment.