Skip to content

Commit

Permalink
Various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Feb 11, 2025
1 parent 69dbf02 commit bfdd4a0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
21 changes: 13 additions & 8 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export concatenate, concatenate!
@compat public Concatenated, cat_offset!, cat_offset1!, copy_or_fill!

using Base: promote_eltypeof
using .DerivableInterfaces: AbstractInterface, interface
using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface

"""
Concatenated{Interface,Dims,Args<:Tuple}
Expand Down Expand Up @@ -66,7 +66,11 @@ end
dims(::Concatenated{A,D}) where {A,D} = D
DerivableInterfaces.interface(cat::Concatenated) = cat.interface

concatenated(args...; dims) = Concatenated(args, Val(dims))
concatenated(args...; dims) = Concatenated(Val(dims), args)

function Base.convert(::Type{Concatenated{NewInterface}}, cat::Concatenated{<:Any,Dims,Args}) where {NewInterface,Dims,Args}
return Concatenated{NewInterface}(cat.dims, cat.args)::Concatenated{NewInterface,Dims,Args}
end

# allocating the destination container
# ------------------------------------
Expand All @@ -93,7 +97,7 @@ Concatenate the supplied `args` along dimensions `dims`.
See also [`concatenate!`](@ref).
"""
concatenate(args...; dims) = Base.materialize(concatenated(dims, args...))
concatenate(args...; dims) = Base.materialize(concatenated(args...; dims))
Base.materialize(cat::Concatenated) = copy(cat)

"""
Expand All @@ -111,15 +115,16 @@ Base.copy(cat::Concatenated) = copyto!(similar(cat), cat)

# default falls back to replacing interface with Nothing
# this permits specializing on typeof(dest) without ambiguities
@inline Base.copyto!(dest, cat::Concatenated) =
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
@inline Base.copyto!(dest::AbstractArray, cat::Concatenated) =
copyto!(dest, convert(Concatenated{Nothing}, cat))

function Base.copyto!(dest::AbstractArray, cat::Concatenated{Nothing})
# if concatenation along multiple directions, holes need to be zero.
catdims = Base.dims2cat(dims(cat))
count(!iszero, catdims)::Int > 1 && zero!(dest)

shape = cat_size_shape(catdims, cat.args...)
shape = Base.cat_size_shape(catdims, cat.args...)
offsets = ntuple(zero, ndims(dest))
return cat_offset!(dest, shape, catdims, offsets, cat.args...)
end
Expand All @@ -130,16 +135,16 @@ end
# at a time via cat_offset1! to avoid having to write too many specializations
function cat_offset!(dest, shape, catdims, offsets, x, X...)
dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x)
return cat_offset!(dest, shape, newoffsets, X...)
return cat_offset!(dest, shape, catdims, newoffsets, X...)
end
cat_offset!(dest, shape, catdims, offsets) = dest

# this is the typical specialization point, which is no longer vararg.
# it simply computes indices and calls out to copy_or_fill!, so if that
# pattern works you can also overload that function
function cat_offset1!(dest, shape, catdims, offsets, x)
inds = ntuple(length(offests)) do i
(i length(catdims) && catdims[i]) ? offsets[i] + axes(x, i) : 1:shape[i]
inds = ntuple(length(offsets)) do i
(i length(catdims) && catdims[i]) ? offsets[i] .+ axes(x, i) : 1:shape[i]
end
copy_or_fill!(dest, inds, x)
newoffsets = ntuple(length(offsets)) do i
Expand Down
3 changes: 1 addition & 2 deletions test/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ struct SparseArrayDOK{T,N} <: AbstractArray{T,N}
end
storage(a::SparseArrayDOK) = a.storage
Base.size(a::SparseArrayDOK) = a.size
Base.similar(::Type{SparseArrayDOK{T}}, axes) = SparseArrayDOK{T}(undef, axes)
function SparseArrayDOK{T}(size::Int...) where {T}
N = length(size)
return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size)
Expand Down Expand Up @@ -269,6 +268,6 @@ DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()
# DerivableInterfaces the interface for the type.
@derive AnySparseArrayDOK AbstractArrayOps

Base._cat(dims, args::SparseArrayDOK...) = DerivableInterfaces.concatenate(args...; dims)
Base._cat(dims, args::SparseArrayDOK...) = DerivableInterfaces.Concatenate.concatenate(args...; dims)

end

0 comments on commit bfdd4a0

Please sign in to comment.