Skip to content

Commit

Permalink
slowly cleaning up
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jul 13, 2024
1 parent 5faaa07 commit a1ab54a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 95 deletions.
1 change: 0 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ end
return prev
end
tup = (subs...,)
@show TT, subs, tup
return NamedTuple{TT.parameters[1],typeof(tup)}(tup)
end

Expand Down
90 changes: 9 additions & 81 deletions src/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,8 @@ for (jlop, hloop, RT) in (
(:(Base.max), :maximum, :ElType),
(:(Base.:+), :add, :ElType),
(:(Base.:-), :subtract, :ElType),
(:(Base.:*), :multiply, :ElType),
(:(Base.:/), :divide, :ElType),
)
@eval begin
function $jlop(
Expand Down Expand Up @@ -546,55 +548,6 @@ function elem_apply(f, args::Vararg{Any, Nargs}) where Nargs
return traced2_result
end

for (jlop, hloop, RT) in (
(:(Base.min), :minimum, :ElType),
(:(Base.max), :maximum, :ElType),
(:(Base.:+), :add, :ElType),
(:(Base.add_sum), :add, :ElType),
(:(Base.:-), :subtract, :ElType),
(:(Base.:*), :multiply, :ElType),
(:(Base.:/), :divide, :ElType),
)
@eval begin
function elem_apply(
::typeof($jlop),
lhs::TracedRArray{ElType,Shape,N},
rhs::TracedRArray{ElType,Shape,N},
) where {ElType,Shape,N}
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function elem_apply(
::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs
) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function elem_apply(
::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end
end

for (jlop, hloop, hlocomp, RT) in (
(:(Base.:(==)), :compare, "EQ", :ElType),
(:(Base.:(!=)), :compare, "NE", :ElType),
Expand Down Expand Up @@ -664,38 +617,13 @@ for (jlop, hloop, hlocomp, RT) in (
end
end

function elem_apply(::typeof(identity), lhs)
return lhs
end
for (jlop, hloop) in (
(:(Base.:-), :negate),
(:(Base.sin), :sine),
(:(Base.cos), :cosine),
(:(Base.tanh), :tanh),
(:(Base.FastMath.tanh_fast), :tanh),
(:(Base.exp), :exponential),
(:(Base.FastMath.exp_fast), :exponential),
(:(Base.log), :log),
(:(Base.sqrt), :sqrt),
)
@eval begin
function elem_apply(
::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
return TracedRArray{ElType,Shape,N}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)
)
end
end
end

function elem_apply(::Type{T}, lhs::TracedRArray{ElType,Shape,N}) where {T,ElType,Shape,N}
inTy = MLIR.IR.type(lhs.mlir_data)
outTy = MLIR.IR.TensorType(Base.size(inTy), MLIR.IR.Type(T))
return TracedRArray{T,Shape,N}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.convert(lhs.mlir_data; result=outTy), 1)
)
end
# function elem_apply(::Type{T}, lhs::TracedRArray{ElType,Shape,N}) where {T,ElType,Shape,N}
# inTy = MLIR.IR.type(lhs.mlir_data)
# outTy = MLIR.IR.TensorType(Base.size(inTy), MLIR.IR.Type(T))
# return TracedRArray{T,Shape,N}(
# (), MLIR.IR.result(MLIR.Dialects.stablehlo.convert(lhs.mlir_data; result=outTy), 1)
# )
# end

Cassette.overdub(context::TraceCtx, f::typeof(elem_apply), args...) = f(args...)

Expand Down
28 changes: 15 additions & 13 deletions test/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@ function list(x::T...) where {T}
return l
end

function Reactant.make_tracer(
seen::IdDict, prev::RT, path, mode
) where {T,RT<:MockLinkedList{T}}
TT = Reactant.traced_type(T, (), Val(mode))
return MockLinkedList{TT}(
Reactant.make_tracer(seen, prev.head, Reactant.append_path(path, :head), mode),
if !isnothing(prev.tail)
Reactant.make_tracer(seen, prev.tail, Reactant.append_path(path, :tail), mode)
else
nothing
end,
)
end
# function Reactant.make_tracer(
# seen::IdDict, prev::RT, path, mode
# ) where {T,RT<:MockLinkedList{T}}
# TT = Reactant.traced_type(T, (), Val(mode))
# return MockLinkedList{TT}(
# Reactant.make_tracer(seen, prev.head, Reactant.append_path(path, :head), mode),
# if !isnothing(prev.tail)
# Reactant.make_tracer(seen, prev.tail, Reactant.append_path(path, :tail), mode)
# else
# nothing
# end,
# )
# end

function Base.sum(x::MockLinkedList{T}) where {T}
if isnothing(x.tail)
Expand Down Expand Up @@ -93,6 +93,8 @@ end
x = [rand(2, 2) for _ in 1:2]
x2 = list(x...)
x3 = Reactant.make_tracer(IdDict(), x2, (), Reactant.ArrayToConcrete)
@show x2
@show x3

# TODO this should be able to run without problems, but crashes
@test_broken begin
Expand Down

0 comments on commit a1ab54a

Please sign in to comment.