Skip to content

Commit

Permalink
Fixed aliasing bug for size field
Browse files Browse the repository at this point in the history
  • Loading branch information
Dominic Perno committed Jan 17, 2025
1 parent 9f743e3 commit f5f2268
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions src/interface/lazy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ Base.eltype(::Type{<:LazyTensor{T}}) where {T} = T
Base.eltype(tns::LazyTensor) = eltype(typeof(tns))
fill_value(tns::LazyTensor) = tns.fill_value

Base.size(::LazyTensor) =
throw(ErrorException("Base.size is not supported for LazyTensor. Call `compute()` first."))
Base.size(tns::LazyTensor) = tns.size

Base.getindex(::LazyTensor, i...) = throw(ErrorException("Lazy indexing with named indices is not supported. Call `compute()` first."))

Expand Down Expand Up @@ -62,19 +61,19 @@ function LazyTensor{T}(arr::Base.AbstractArrayOrBroadcasted) where {T}
name = alias(gensym(:A))
idxs = [field(gensym(:i)) for _ in 1:ndims(arr)]
extrude = ntuple(n -> size(arr, n) == 1, ndims(arr))
size = ntuple(n -> size(arr, n), ndims(arr))
shape = ntuple(n -> size(arr, n), ndims(arr))
tns = subquery(name, table(immediate(arr), idxs...))
LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, size, fill_value(arr))
LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, shape, fill_value(arr))
end
LazyTensor(arr::AbstractTensor) = LazyTensor{eltype(arr)}(arr)
LazyTensor(swizzle_arr::SwizzleArray{dims, <:Tensor}) where {dims} = permutedims(LazyTensor(swizzle_arr.body), dims)
function LazyTensor{T}(arr::AbstractTensor) where {T}
name = alias(gensym(:A))
idxs = [field(gensym(:i)) for _ in 1:ndims(arr)]
extrude = ntuple(n -> size(arr)[n] == 1, ndims(arr))
size = ntuple(n -> size(arr, n), ndims(arr))
shape = ntuple(n -> size(arr, n), ndims(arr))
tns = subquery(name, table(immediate(arr), idxs...))
LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, size, fill_value(arr))
LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, shape, fill_value(arr))
end
LazyTensor{T}(swizzle_arr::SwizzleArray{dims, <:Tensor}) where {T, dims} = permutedims(LazyTensor{T}(swizzle_arr.body), dims)
LazyTensor(data::LazyTensor) = data
Expand Down Expand Up @@ -140,11 +139,11 @@ end
function Base.reduce(op, arg::LazyTensor{T, N}; dims=:, init = initial_value(op, T)) where {T, N}
dims = dims == Colon() ? (1:N) : collect(dims)
extrude = ((arg.extrude[n] for n in 1:N if !(n in dims))...,)
size = ((arg.size[n] for n in 1:N if !(n in dims))...,)
shape = ((arg.size[n] for n in 1:N if !(n in dims))...,)
fields = [field(gensym(:i)) for _ in 1:N]
S = fixpoint_type(op, init, eltype(arg))
data = aggregate(immediate(op), immediate(init), relabel(arg.data, fields), fields[dims]...)
LazyTensor{S}(identify(data), extrude, size, init)
LazyTensor{S}(identify(data), extrude, shape, init)
end

tensordot(A::LazyTensor, B::Union{AbstractTensor, AbstractArray}, idxs; kwargs...) = tensordot(A, LazyTensor(B), idxs; kwargs...)
Expand All @@ -167,7 +166,7 @@ function tensordot(A::LazyTensor{T1, N1}, B::LazyTensor{T2, N2}, idxs; mult_op=*
extrude = ((A.extrude[n] for n in 1:N1 if !(n in A_idxs))...,
(B.extrude[n] for n in 1:N2 if !(n in B_idxs))...,)

size = ((A.size[n] for n in 1:N1 if !(n in A_idxs))...,
shape = ((A.size[n] for n in 1:N1 if !(n in A_idxs))...,
(B.size[n] for n in 1:N2 if !(n in B_idxs))...,)
A_fields = [field(gensym(:i)) for _ in 1:N1]
B_fields = [field(gensym(:i)) for _ in 1:N2]
Expand All @@ -180,7 +179,7 @@ function tensordot(A::LazyTensor{T1, N1}, B::LazyTensor{T2, N2}, idxs; mult_op=*
AB_reduce = aggregate(immediate(add_op), immediate(init), AB, reduce_fields...)
T = return_type(DefaultAlgebra(), mult_op, T1, T2)
S = fixpoint_type(add_op, init, T)
return LazyTensor{S}(identify(AB_reduce), extrude, size, init)
return LazyTensor{S}(identify(AB_reduce), extrude, shape, init)
end

struct LazyStyle{N} <: BroadcastStyle end
Expand Down Expand Up @@ -252,9 +251,9 @@ function Base.copy(bc::Broadcasted{LazyStyle{N}}) where {N}
idxs = [field(gensym(:i)) for _ in 1:N]
data = reorder(broadcast_to_query(bc_lgc, idxs), idxs)
extrude = ntuple(n -> broadcast_to_extrude(bc_lgc, n), N)
size = ntuple(n -> broadcast_to_size(bc_lgc, n), N)
shape = ntuple(n -> broadcast_to_size(bc_lgc, n), N)
def = broadcast_to_default(bc_lgc)
return LazyTensor{broadcast_to_eltype(bc)}(identify(data), extrude, size, def)
return LazyTensor{broadcast_to_eltype(bc)}(identify(data), extrude, shape, def)
end

function Base.copyto!(::LazyTensor, ::Any)
Expand Down

0 comments on commit f5f2268

Please sign in to comment.