Skip to content

Commit

Permalink
Inline and use recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Mar 2, 2024
1 parent 83fc6c6 commit a55877e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 27 deletions.
8 changes: 4 additions & 4 deletions src/Adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ If we want this to work with custom structures, we need to extend `adapt_structu
julia> adapt(IntegerLessAdaptor(), MyStructure(42))
MyStructure(42.0)
"""
adapt(to, x) = adapt_structure(to, x)
@inline adapt(to, x) = adapt_structure(to, x)

"""
adapt(to)
Create a function that adapts its argument according to `to`.
If no specific adaptions have been registered for `to`, the returned function will be equivalent to `identity`.
"""
adapt(to) = Base.Fix1(adapt, to)
@inline adapt(to) = Base.Fix1(adapt, to)
if VERSION < v"1.9.0-DEV.857"
@eval function adapt(to::Type{T}) where {T}
(@isdefined T) || return Base.Fix1(adapt, to)
Expand All @@ -54,8 +54,8 @@ if VERSION < v"1.9.0-DEV.857"
end
end

adapt_structure(to, x) = adapt_storage(to, x)
adapt_storage(to, x) = x
@inline adapt_structure(to, x) = adapt_storage(to, x)
@inline adapt_storage(to, x) = x

# structure rules
include("base.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

## arrays

Adapt.adapt_storage(::Type{Array}, xs::AT) where {AT<:AbstractArray} =
@inline Adapt.adapt_storage(::Type{Array}, xs::AT) where {AT<:AbstractArray} =
convert(Array, xs)

# if an element type is specified, convert to it
Adapt.adapt_storage(::Type{<:Array{T}}, xs::AT) where {T, AT<:AbstractArray} =
@inline Adapt.adapt_storage(::Type{<:Array{T}}, xs::AT) where {T, AT<:AbstractArray} =
convert(Array{T}, xs)

# NOTE: this flattens all <:AbstractArray leaves, e.g., Base.Slice(1:2) -> [1,2].
16 changes: 12 additions & 4 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
# predefined adaptors for working with types from the Julia standard library

adapt_structure(to, xs::Union{Tuple,NamedTuple}) = map(adapt(to), xs)
@inline adapt_structure(to, xs::Union{Tuple,NamedTuple}) = map(x->adapt(to,x), xs)

@inline _radapt_structure(to, xs::Tuple) =
(adapt(to, first(xs)), _radapt_structure(to, Base.tail(xs))...)
@inline _radapt_structure(to, xs::Tuple{}) = ()
@inline _radapt_structure(to, xs::Tuple{<:Any}) =
(adapt(to, first(xs)), )
@inline adapt_structure(to, xs::Tuple) = _radapt_structure(to, xs)
@inline adapt_structure(to, xs::NamedTuple) = map(x->adapt(to,x), xs)


## Closures

# two things can be captured: static parameters, and actual values (fields)

@eval function adapt_structure(to, f::F) where {F<:Function}
@eval @inline function adapt_structure(to, f::F) where {F<:Function}
# how many type parameters does this function have?
# each captured value will have one (with the exception of boxed values)
num_type_params = length(F.parameters)
Expand Down Expand Up @@ -45,8 +53,8 @@ end

import Base.Broadcast: Broadcasted, Extruded

adapt_structure(to, bc::Broadcasted{Style}) where Style =
@inline adapt_structure(to, bc::Broadcasted{Style}) where Style =
Broadcasted{Style}(adapt(to, bc.f), adapt(to, bc.args), bc.axes)

adapt_structure(to, ex::Extruded) =
@inline adapt_structure(to, ex::Extruded) =
Extruded(adapt(to, ex.x), ex.keeps, ex.defaults)
2 changes: 1 addition & 1 deletion src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ of `obj` and constructs a new instance of `T` using the default constuctor `T(..
macro adapt_structure(T)
names = fieldnames(Core.eval(__module__, T))
quote
function Adapt.adapt_structure(to, obj::$(esc(T)))
@inline function Adapt.adapt_structure(to, obj::$(esc(T)))
$(esc(T))($([:(Adapt.adapt_structure(to, obj.$name)) for name in names]...))
end
end
Expand Down
32 changes: 16 additions & 16 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,48 @@ permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm

export WrappedArray, parent_type, unwrap_type

adapt_structure(to, A::SubArray) =
@inline adapt_structure(to, A::SubArray) =
SubArray(adapt(to, parent(A)), adapt(to, parentindices(A)))
function adapt_structure(to, A::PermutedDimsArray)
@inline function adapt_structure(to, A::PermutedDimsArray)
perm = permutation(A)
iperm = invperm(perm)
A′ = adapt(to, parent(A))
PermutedDimsArray{eltype(A′),ndims(A′),perm,iperm,typeof(A′)}(A′)
end
adapt_structure(to, A::Base.ReshapedArray) =
@inline adapt_structure(to, A::Base.ReshapedArray) =
Base.reshape(adapt(to, parent(A)), size(A))
@static if isdefined(Base, :NonReshapedReinterpretArray)
adapt_structure(to, A::Base.NonReshapedReinterpretArray) =
@inline adapt_structure(to, A::Base.NonReshapedReinterpretArray) =
Base.reinterpret(eltype(A), adapt(to, parent(A)))
adapt_structure(to, A::Base.ReshapedReinterpretArray) =
@inline adapt_structure(to, A::Base.ReshapedReinterpretArray) =
Base.reinterpret(reshape, eltype(A), adapt(to, parent(A)))
else
adapt_structure(to, A::Base.ReinterpretArray) =
@inline adapt_structure(to, A::Base.ReinterpretArray) =
Base.reinterpret(eltype(A), adapt(to, parent(A)))
end
@eval function adapt_structure(to, A::Base.LogicalIndex{T}) where T
@eval @inline function adapt_structure(to, A::Base.LogicalIndex{T}) where T
# prevent re-calculating the count of booleans during LogicalIndex construction
mask = adapt(to, A.mask)
$(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum)))
end

adapt_structure(to, A::LinearAlgebra.Adjoint) =
@inline adapt_structure(to, A::LinearAlgebra.Adjoint) =
LinearAlgebra.adjoint(adapt(to, parent(A)))
adapt_structure(to, A::LinearAlgebra.Transpose) =
@inline adapt_structure(to, A::LinearAlgebra.Transpose) =
LinearAlgebra.transpose(adapt(to, parent(A)))
adapt_structure(to, A::LinearAlgebra.LowerTriangular) =
@inline adapt_structure(to, A::LinearAlgebra.LowerTriangular) =
LinearAlgebra.LowerTriangular(adapt(to, parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) =
@inline adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) =
LinearAlgebra.UnitLowerTriangular(adapt(to, parent(A)))
adapt_structure(to, A::LinearAlgebra.UpperTriangular) =
@inline adapt_structure(to, A::LinearAlgebra.UpperTriangular) =
LinearAlgebra.UpperTriangular(adapt(to, parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) =
@inline adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) =
LinearAlgebra.UnitUpperTriangular(adapt(to, parent(A)))
adapt_structure(to, A::LinearAlgebra.Diagonal) =
@inline adapt_structure(to, A::LinearAlgebra.Diagonal) =
LinearAlgebra.Diagonal(adapt(to, parent(A)))
adapt_structure(to, A::LinearAlgebra.Tridiagonal) =
@inline adapt_structure(to, A::LinearAlgebra.Tridiagonal) =
LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
adapt_structure(to, A::LinearAlgebra.Symmetric) =
@inline adapt_structure(to, A::LinearAlgebra.Symmetric) =
LinearAlgebra.Symmetric(adapt(to, parent(A)))


Expand Down

0 comments on commit a55877e

Please sign in to comment.