Skip to content

Commit

Permalink
Merge pull request #205 from tcarion/diskarrays
Browse files Browse the repository at this point in the history
DiskArrays for `Variable`'s
  • Loading branch information
Alexander-Barth authored Sep 29, 2023
2 parents 36d2d51 + fc4e099 commit 7576323
Show file tree
Hide file tree
Showing 13 changed files with 89 additions and 155 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ CFTime = "179af706-886a-5703-950a-314cd64e0468"
CommonDataModel = "1fbeeb36-5f17-413c-809b-666fb144f157"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
NetCDF_jll = "7243133f-43d8-5620-bbf4-c2c921802cf3"
NetworkOptions = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -24,10 +25,10 @@ julia = "1.3"

[extras]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"

[targets]
test = ["Dates", "Test", "Random", "Printf", "IntervalSets"]
5 changes: 5 additions & 0 deletions src/NCDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ import CommonDataModel: AbstractDataset, AbstractVariable,
groupnames, group, defGroup,
dimnames, dim, defDim,
attribnames, attrib, defAttrib
import DiskArrays
import DiskArrays: readblock!, writeblock!, eachchunk, haschunks
using DiskArrays: @implement_diskarray

function __init__()
NetCDF_jll.is_available() && init_certificate_authority()
Expand Down Expand Up @@ -65,6 +68,8 @@ include("ncgen.jl")
include("select.jl")
include("precompile.jl")

@implement_diskarray NCDatasets.Variable

export CatArrays
export CFTime
export daysinmonth, daysinyear, yearmonthday, yearmonth, monthday
Expand Down
2 changes: 1 addition & 1 deletion src/cfvariable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ function _range_indices_dest(of,v,rest...)
end
range_indices_dest(ri...) = _range_indices_dest((),ri...)

function Base.getindex(v::Union{CFVariable,Variable,MFVariable,SubVariable},indices::Union{Int,Colon,AbstractRange{<:Integer},Vector{Int}}...)
function Base.getindex(v::Union{MFVariable,SubVariable},indices::Union{Int,Colon,AbstractRange{<:Integer},Vector{Int}}...)
@debug "transform vector of indices to ranges"

sz_source = size(v)
Expand Down
14 changes: 13 additions & 1 deletion src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,20 +489,32 @@ function Base.write(dest::NCDataset, src::AbstractDataset;
# end
end

# Index range used on the destination side for one dimension.
# `ind` is first normalized against `dimlength`; for a dimension listed in
# `unlimdims` the result runs from the first normalized index up to
# `dimlength` with unit step, otherwise the normalized index is returned as is.
function _destindex(ind, dimname, dimlength, unlimdims)
    normalized = _normalizeindex(dimlength, ind)
    return dimname in unlimdims ? (normalized[1]:dimlength) : normalized
end
# Upper bound for dimension `dimname`: the last element of the subset stored
# in `idimensions` when one is given, otherwise the full `dimlength`.
function _maxrange(dimname, idimensions, dimlength)
    haskey(idimensions, dimname) || return dimlength
    return idimensions[dimname][end]
end

# loop over variables
for varname in include
(varname ∈ exclude) && continue
@debug "Writing variable $varname..."

cfvar = src[varname]
cfsz = size(cfvar)
dimension_names = dimnames(cfvar)
var = cfvar.var
# indices for subset
index = ntuple(i -> torange(get(idimensions,dimension_names[i],:)),length(dimension_names))
destindex = ntuple(i -> _destindex(index[i], dimension_names[i], _maxrange(dimension_names[i], idimensions, cfsz[i]), unlimited_dims), length(dimension_names))

destvar = defVar(dest, varname, eltype(var), dimension_names; attrib = attribs(cfvar))
# copy data
destvar.var[:] = cfvar.var[index...]
destvar.var[destindex...] = cfvar.var[index...]
end

# loop over all global attributes
Expand Down
2 changes: 1 addition & 1 deletion src/subvariable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ close(ds)
```
"""
Base.view(v::AbstractVariable,indices::Union{Int,Colon,AbstractVector{Int}}...) = SubVariable(v,indices...)
Base.view(v::Union{CFVariable, DeferVariable, MFCFVariable},indices::Union{Int,Colon,AbstractVector{Int}}...) = SubVariable(v,indices...)
Base.view(v::SubVariable,indices::CartesianIndex) = view(v,indices.I...)
Base.view(v::SubVariable,indices::CartesianIndices) = view(v,indices.indices...)

Expand Down
176 changes: 46 additions & 130 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,85 +345,68 @@ end
nomissing(a::AbstractArray,value) = a
export nomissing


function Base.getindex(v::Variable,indexes::Int...)
# This method needs to be duplicated instead of using a Union; otherwise a DiskArrays fallback is called instead, which hurts performance
# (see https://github.com/Alexander-Barth/NCDatasets.jl/pull/205#issuecomment-1589575041)
function readblock!(v::Variable, aout, indexes::TI...) where TI <: Union{AbstractUnitRange,StepRange}
datamode(v.ds)
return nc_get_var1(eltype(v),v.ds.ncid,v.varid,[i-1 for i in indexes[ndims(v):-1:1]])
_readblock!(v, aout, indexes...)
return aout
end

function Base.setindex!(v::Variable{T,N},data,indexes::Int...) where N where T
@debug "$(@__LINE__)"
datamode(v.ds)
# use zero-based indexes and reversed order
nc_put_var1(v.ds.ncid,v.varid,[i-1 for i in indexes[ndims(v):-1:1]],T(data))
return data
end
# Forward a homogeneous set of indexes to `_read_data_from_nc!`: one method
# accepts only `AbstractUnitRange` indexes, the other only `StepRange` indexes.
_readblock!(v::Variable, aout, indexes::AbstractUnitRange...) = _read_data_from_nc!(v, aout, indexes...)
_readblock!(v::Variable, aout, indexes::StepRange...) = _read_data_from_nc!(v, aout, indexes...)

function Base.getindex(v::Variable{T,N},indexes::Colon...) where {T,N}
datamode(v.ds)
data = Array{T,N}(undef,size(v))
nc_get_var!(v.ds.ncid,v.varid,data)
# `readblock!` for a call without any index (e.g. a NetCDF scalar variable).
# Fills `aout` in place and returns it.
function readblock!(v::Variable, aout)
    # Same precondition as the indexed `readblock!` method: make sure that the
    # file is in data mode before reading.
    datamode(v.ds)
    _read_data_from_nc!(v, aout)
    return aout
end

# special case for scalar NetCDF variable
if N == 0
return data[]
else
return data
end
# Read the single element addressed by the integer `indexes` and broadcast it
# into `aout`.
function _read_data_from_nc!(v::Variable, aout, indexes::Int...)
    # use zero-based indexes and reversed order
    start = [i - 1 for i in reverse(indexes)]
    aout .= nc_get_var1(eltype(v), v.ds.ncid, v.varid, start)
    return aout
end

function Base.setindex!(v::Variable{T,N},data::T,indexes::Colon...) where {T,N}
@debug "setindex! colon $data"
datamode(v.ds) # make sure that the file is in data mode
tmp = fill(data,size(v))
nc_put_var(v.ds.ncid,v.varid,tmp)
return data
# Read the block selected by `indexes` (all `UnitRange{Int}` or all
# `StepRange{Int,Int}`) into `aout`.
function _read_data_from_nc!(v::Variable{T,N}, aout, indexes::TR...) where {T,N} where TR <: Union{StepRange{Int,Int},UnitRange{Int}}
    # `ncsub` also returns the Julia-side shape as a fourth value; it is not
    # needed here because the caller supplies the output buffer.
    start, count, stride = ncsub(indexes)
    return nc_get_vars!(v.ds.ncid, v.varid, start, count, stride, aout)
end

# union types cannot be used to avoid ambiguity
for data_type = [Number, String, Char]
@eval begin
# call to v .= 123
function Base.setindex!(v::Variable{T,N},data::$data_type) where {T,N}
@debug "setindex! $data"
datamode(v.ds) # make sure that the file is in data mode
tmp = fill(convert(T,data),size(v))
nc_put_var(v.ds.ncid,v.varid,tmp)
return data
end

Base.setindex!(v::Variable,data::$data_type,indexes::Colon...) = setindex!(v::Variable,data)

function Base.setindex!(v::Variable{T,N},data::$data_type,indexes::StepRange{Int,Int}...) where {T,N}
datamode(v.ds) # make sure that the file is in data mode
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])
tmp = fill(convert(T,data),jlshape)
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,tmp)
return data
end
end
# Generic fallback: read the block selected by any mix of integer, colon and
# integer-range indexes into `aout`.
function _read_data_from_nc!(v::Variable{T,N}, aout, indexes::Union{Int,Colon,AbstractRange{<:Integer}}...) where {T,N}
    # Only start/count/stride are needed; the output buffer is supplied by the
    # caller, so the shape of the slice does not have to be computed here.
    start, count, stride = ncsub2(size(v), indexes...)
    return nc_get_vars!(v.ds.ncid, v.varid, start, count, stride, aout)
end

function Base.setindex!(v::Variable{T,N},data::AbstractArray{T,N},indexes::Colon...) where {T,N}
datamode(v.ds) # make sure that the file is in data mode
# Call without any index: read the element at index 1.
_read_data_from_nc!(v::Variable, aout) = _read_data_from_nc!(v, aout, 1)

nc_put_var(v.ds.ncid,v.varid,data)
# `writeblock!` method for `Variable`: write `data` at the given range
# `indexes` and return `data`.
function writeblock!(v::Variable, data, indexes::TI...) where TI <: Union{AbstractUnitRange,StepRange}
# make sure that the file is in data mode
datamode(v.ds)
_write_data_to_nc(v, data, indexes...)
return data
end

function Base.setindex!(v::Variable{T,N},data::AbstractArray{T2,N},indexes::Colon...) where {T,T2,N}
datamode(v.ds) # make sure that the file is in data mode
tmp =
if T <: Integer
round.(T,data)
else
convert(Array{T,N},data)
end
# Write the first element of `data`, converted to the variable's element type
# `T`, at the position given by the integer `indexes`.
function _write_data_to_nc(v::Variable{T,N},data,indexes::Int...) where {T,N}
    # use zero-based indexes and reversed order
    start = [i - 1 for i in reverse(indexes)]
    value = T(data[1])
    return nc_put_var1(v.ds.ncid, v.varid, start, value)
end

nc_put_var(v.ds.ncid,v.varid,tmp)
return data
# Call without any index: write at index 1.
_write_data_to_nc(v::Variable, data) = _write_data_to_nc(v, data, 1)

# Write `data`, converted element-wise to the variable's element type `T`, to
# the block selected by the `StepRange` `indexes`.
function _write_data_to_nc(v::Variable{T, N}, data, indexes::StepRange{Int,Int}...) where {T, N}
    # `ncsub` also returns the Julia-side shape as a fourth value, which is
    # not needed for writing.
    start, count, stride = ncsub(indexes)
    return nc_put_vars(v.ds.ncid, v.varid, start, count, stride, T.(data))
end

# Generic range fallback: rewrite `indexes` into a form handled by another
# `_write_data_to_nc` method and call it.
function _write_data_to_nc(v::Variable, data, indexes::Union{AbstractRange{<:Integer}}...)
    # the ranges together select exactly one element
    single_element = prod(length.(indexes)) == 1
    ind = if single_element
        # collapse every range to its first (only) index
        first.(indexes)
    else
        normalizeindexes(size(v), indexes)
    end
    return _write_data_to_nc(v, data, ind...)
end

# Chunking queries (`eachchunk`, `haschunks`) for `Variable` and `CFVariable`.
# Chunk size of `v`, dispatched on the trait value returned by `haschunks(v)`.
getchunksize(v::Variable) = getchunksize(haschunks(v),v)
# Chunked: second element of `chunking(v)` (presumably the per-dimension chunk
# sizes; confirm against `chunking`).
getchunksize(::DiskArrays.Chunked, v::Variable) = chunking(v)[2]
# getchunksize(::DiskArrays.Unchunked, v::Variable) = DiskArrays.estimate_chunksize(v)
# Unchunked: the full size of the variable is used as the chunk size.
getchunksize(::DiskArrays.Unchunked, v::Variable) = size(v)
# A `CFVariable` forwards both queries to its `var` field.
eachchunk(v::CFVariable) = eachchunk(v.var)
haschunks(v::CFVariable) = haschunks(v.var)
eachchunk(v::Variable) = DiskArrays.GridChunks(v, Tuple(getchunksize(v)))
# `Unchunked` when the first element of `chunking(v)` is `:contiguous`,
# `Chunked` otherwise.
haschunks(v::Variable) = (chunking(v)[1] == :contiguous ? DiskArrays.Unchunked() : DiskArrays.Chunked())

# Convert one index into a range with an explicit step of 1; `n` is the length
# of the dimension (only used for `Colon`).
# `Base.OneTo(k)` becomes `1:1:k`.
_normalizeindex(n,ind::Base.OneTo) = 1:1:ind.stop
# `:` becomes `1:1:n`.
_normalizeindex(n,ind::Colon) = 1:1:n
# An integer `i` becomes the one-element range `i:1:i`.
_normalizeindex(n,ind::Int) = ind:1:ind
Expand Down Expand Up @@ -477,72 +460,5 @@ end
return start,count,stride
end

function Base.getindex(v::Variable{T,N},indexes::TR...) where {T,N} where TR <: Union{StepRange{Int,Int},UnitRange{Int}}
start,count,stride,jlshape = ncsub(indexes[1:N])
data = Array{T,N}(undef,jlshape)

datamode(v.ds)
nc_get_vars!(v.ds.ncid,v.varid,start,count,stride,data)
return data
end

function Base.setindex!(v::Variable{T,N},data::T,indexes::StepRange{Int,Int}...) where {T,N}
datamode(v.ds) # make sure that the file is in data mode
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])
tmp = fill(data,jlshape)
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,tmp)
return data
end

function Base.setindex!(v::Variable{T,N},data::Array{T,N},indexes::StepRange{Int,Int}...) where {T,N}
datamode(v.ds) # make sure that the file is in data mode
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,data)
return data
end

# data can be Array{T2,N} or BitArray{N}
function Base.setindex!(v::Variable{T,N},data::AbstractArray,indexes::StepRange{Int,Int}...) where {T,N}
datamode(v.ds) # make sure that the file is in data mode
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])

tmp = convert(Array{T,ndims(data)},data)
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,tmp)

return data
end




function Base.getindex(v::Variable{T,N},indexes::Union{Int,Colon,AbstractRange{<:Integer}}...) where {T,N}
sz = size(v)
start,count,stride = ncsub2(sz,indexes...)
jlshape = _shape_after_slice(sz,indexes...)
data = Array{T}(undef,jlshape)

datamode(v.ds)
nc_get_vars!(v.ds.ncid,v.varid,start,count,stride,data)

return data
end

# NetCDF scalars indexed as []
Base.getindex(v::Variable{T, 0}) where T = v[1]



function Base.setindex!(v::Variable,data,indexes::Union{Int,Colon,AbstractRange{<:Integer}}...)
ind = normalizeindexes(size(v),indexes)

# make arrays out of scalars (arrays can have zero dimensions)
if (ndims(data) == 0) && !(data isa AbstractArray)
data = fill(data,length.(ind))
end

return v[ind...] = data
end


Base.getindex(v::Union{MFVariable,DeferVariable,Variable},ci::CartesianIndices) = v[ci.indices...]
Base.setindex!(v::Union{MFVariable,DeferVariable,Variable},data,ci::CartesianIndices) = setindex!(v,data,ci.indices...)
Base.getindex(v::Union{MFVariable,DeferVariable},ci::CartesianIndices) = v[ci.indices...]
Base.setindex!(v::Union{MFVariable,DeferVariable},data,ci::CartesianIndices) = setindex!(v,data,ci.indices...)
2 changes: 1 addition & 1 deletion test/perf/generate_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ncv1 = defVar(ds,"v1", UInt8, ("longitude", "latitude", "time"), fillvalue = UIn
for n = 1:sz[3]
@show n
ncv1[:,:,n] = rand(1:100,sz[1],sz[2])
ncv1[:,1,n] = missing
ncv1[:,1,n] .= missing
end

close(ds)
12 changes: 6 additions & 6 deletions test/test_check_size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ defVar(ds, "w", Float64, ("x", "Time"))

for i in 1:10
ds["Time"][i] = i
ds["a"][:,i] = 1
@test_throws NCDatasets.NetCDFError ds["u"][:,i] = collect(1:9)
@test_throws NCDatasets.NetCDFError ds["v"][:,i] = collect(1:11)
@test_throws NCDatasets.NetCDFError ds["w"][:,i] = reshape(collect(1:20), 10, 2)
ds["a"][:,i] .= 1
@test_throws DimensionMismatch ds["u"][:,i] = collect(1:9)
@test_throws DimensionMismatch ds["v"][:,i] = collect(1:11)
@test_throws DimensionMismatch ds["w"][:,i] = reshape(collect(1:20), 10, 2)

# ignore singleton dimension
ds["w"][:,i] = reshape(collect(1:10), 1, 1, 10, 1)
Expand All @@ -29,11 +29,11 @@ end
ds["w"][:,:] = ones(10,10)

# w should grow along the unlimited dimension
ds["w"][:,:] = ones(10,15)
ds["w"][:,1:15] = ones(10,15)
@test size(ds["w"]) == (10,15)

# w cannot grow along a fixed dimension
@test_throws NCDatasets.NetCDFError ds["w"][:,:] = ones(11,15)
@test_throws DimensionMismatch ds["w"][:,:] = ones(11,15)

# NetCDF: Index exceeds dimension bound
@test_throws NCDatasets.NetCDFError ds["u"][100,100]
Expand Down
2 changes: 1 addition & 1 deletion test/test_corner_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ b = dropdims([1.], dims=(1,))
NCDataset(fname,"c") do ds
time = defDim(ds,"time",Inf)
v = defVar(ds,"temp",Float32,("time",))
ds["temp"][1:1] = b
ds["temp"][1] = b
@test ds["temp"][1] == 1
end

Expand Down
2 changes: 1 addition & 1 deletion test/test_scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ for (T,data) in ((Float32,123.f0),
end

NCDataset(filename,"r") do ds
v2 = ds["scalar"][:]
v2 = ds["scalar"][1]
@test v2 == data
end
rm(filename)
Expand Down
8 changes: 4 additions & 4 deletions test/test_variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ NCDatasets.NCDataset(filename,"c") do ds

v = NCDatasets.defVar(ds,"small",Float64,("lon","lat"))
# @test_throws Union{NCDatasets.NetCDFError,DimensionMismatch} v[:] = zeros(sz[1]+1,sz[2])
@test_throws NCDatasets.NetCDFError v[1:sz[1],1:sz[2]] = zeros(sz[1]+1,sz[2])
@test_throws DimensionMismatch v[1:sz[1],1:sz[2]] = zeros(sz[1]+1,sz[2])
@test_throws NCDatasets.NetCDFError v[sz[1]+1,1] = 1
@test_throws NCDatasets.NetCDFError v[-1,1] = 1

Expand Down Expand Up @@ -79,7 +79,7 @@ NCDataset(filename,"c") do ds
"units" => "degree_Celsius",
"long_name" => "Temperature"
])
@test ds["temp"][:] == data
@test ds["temp"][:] == data[:]
@test eltype(ds["temp"].var) == Int32
@test ds.dim["lon"] == sz[1]
@test ds.dim["lat"] == sz[2]
Expand Down Expand Up @@ -150,10 +150,10 @@ NCDataset(filename,"c") do ds
end

defVar(ds,"scalar",123.)
@test ds["scalar"][:] == 123.
@test ds["scalar"][1] == 123.

# test indexing with symbols #101
@test ds[:scalar][:] == 123.
@test ds[:scalar][1] == 123.
end
rm(filename)

Expand Down
2 changes: 1 addition & 1 deletion test/test_variable_unlim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ NCDatasets.NCDataset(filename,"c") do ds

for j = 1:sz[2]
data[:,j] .= T(j)
v[:,j] = T(j)
v[:,j] = fill(T(j), sz[1])
end

@test all(v[:,:] == data)
Expand Down
Loading

0 comments on commit 7576323

Please sign in to comment.