diff --git a/Project.toml b/Project.toml index 0793db8..2e535e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,16 +1,14 @@ name = "Yota" uuid = "cd998857-8626-517d-b929-70ad188a48f0" authors = ["Andrei Zhabinski "] -version = "0.4.3" +version = "0.4.4" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Cassette = "7057c7e9-c182-5462-911a-8362d720325c" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Espresso = "6912e4f1-e036-58b0-9138-08d1e6358ea9" IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" -JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -18,10 +16,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] CUDA = "1.2, 2.3" -Cassette = "0.2.6, 0.3" ChainRulesCore = "0.9.5" Distributions = "0.23.2" Espresso = "0.6.0" IRTools = "0.4.0" -JuliaInterpreter = "0.7.2" julia = "1.4" diff --git a/README.md b/README.md index 190d2cf..d5f5c8f 100644 --- a/README.md +++ b/README.md @@ -145,13 +145,6 @@ compile!(tape) # 492.063 ns (2 allocations: 144 bytes) ``` -Note that `trace()` is an alias to `irtrace()` - IRTools-based tracer. As of Yota 0.4.0, two other tracers are available: - - * `ctrace()`, based on [Cassette.jl](https://github.com/jrevels/Cassette.jl) - * `itrace()`, based on [JuliaInterpreter.jl](https://github.com/JuliaDebug/JuliaInterpreter.jl) - -These tracers can be used for experimental purposes, but **their reliability or even existence is not guaranteed in future**. For any long-term code please use alias `trace()` which always points to the most recent and well-tested implementation. - ## CUDA support `CuArray` is fully supported. If you encounter an issue with CUDA arrays which you don't have with ordinary arrays, please file a bug. diff --git a/src/core.jl b/src/core.jl index 0945836..2246c5f 100644 --- a/src/core.jl +++ b/src/core.jl @@ -12,7 +12,7 @@ include("helpers.jl") include("devices.jl") include("tape.jl") include("tapeutils.jl") -include("trace/trace.jl") +include("trace.jl") include("diffrules/diffrules.jl") include("grad.jl") include("compile.jl") diff --git a/src/devices.jl b/src/devices.jl index 4d1a8de..876d62e 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -9,28 +9,14 @@ end GPU() = GPU(1) - -""" -Check if the argument is of type CuArray. Doesn't require CuArrays.jl to be loaded -""" -is_cuarray(x) = startswith(string(typeof(x)), "CuArray") - -# function has_cuda_inputs(tape::Tape) -# res = false -# for op in tape -# if op isa Input && op.val isa CuArray -# res = true -# break -# end -# end -# return res -# end +is_cuarray(x) = x isa CuArray # currently GPU's ID is just a placeholder guess_device(args) = any(is_cuarray, args) ? GPU(1) : CPU() device_of(A) = A isa CuArray ? GPU(1) : CPU() + """ Retrieve function compatible with specified device @@ -56,3 +42,6 @@ to_device(device::CPU, f::Function, args) = f (device::CPU)(x) = to_device(device, x) (device::GPU)(x) = to_device(device, x) + + +to_same_device(A, example) = device_of(example)(A) diff --git a/src/helpers.jl b/src/helpers.jl index 973a392..49b482d 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -93,11 +93,13 @@ function unbroadcast_prod_x(x::AbstractArray, y::AbstractArray, Δ) end unbroadcast_prod_y(x::AbstractArray, y::AbstractArray, Δ) = unbroadcast_prod_x(y, x, Δ) -device_like(example, a) = (device = guess_device([example]); device(a)) -unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1] -unbroadcast_prod_x(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_x(x, device_like(x, [y]), Δ) -unbroadcast_prod_y(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_y(x, device_like(x, [y]), Δ)[1] -unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y(device_like(y, [x]), y, Δ) +# device_like(example, a) = (device = guess_device([example]); device(a)) + +# unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1] +unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(to_same_device([x], y), y, Δ)[1] +unbroadcast_prod_x(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_x(x, to_same_device([y], x), Δ) +unbroadcast_prod_y(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_y(x, to_same_device([y], x), Δ)[1] +unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y(to_same_device([x], y), y, Δ) untranspose_vec(ds::Transpose{T, <:AbstractVector{T}}) where T = transpose(ds) diff --git a/src/trace/irtools.jl b/src/trace.jl similarity index 83% rename from src/trace/irtools.jl rename to src/trace.jl index a7170bf..014b9e2 100644 --- a/src/trace/irtools.jl +++ b/src/trace.jl @@ -1,3 +1,45 @@ +function __new__(T, args...) + # note: we also add __new__() to the list of primitives so it's not overdubbed recursively + if T <: NamedTuple + return T(args) + else + return T(args...) + end +end + + +__tuple__(args...) = tuple(args...) +__getfield__(args...) = getfield(args...) + + +function module_functions(modl) + res = Vector{Function}() + for s in Base.names(modl; all=true) + isdefined(modl, s) || continue + fn = getfield(modl, s) + if fn isa Function # && match(r"^[a-z#]+$", string(s)) != nothing + push!(res, fn) + end + end + return res +end + +const PRIMITIVES = Set{Any}(vcat( + module_functions(Base), + module_functions(Core), + module_functions(Core.Intrinsics), + [Broadcast.materialize, Broadcast.broadcasted, Colon(), (:), + Base.not_int, + # our own special functions + __new__, __tuple__, __getfield__, namedtuple, guess_device])); + + +################################################################################ +################################################################################ +# IRTools-based Tracer # +################################################################################ +################################################################################ + import IRTools import IRTools: IR, @dynamo, self, insertafter! @@ -204,3 +246,6 @@ function irtrace(f, args...; primitives=PRIMITIVES, optimize=true) end return val, tape end + + +trace = irtrace diff --git a/src/trace/cassette.jl b/src/trace/cassette.jl deleted file mode 100644 index a700a17..0000000 --- a/src/trace/cassette.jl +++ /dev/null @@ -1,118 +0,0 @@ -using Cassette -using Cassette: Tagged, tag, untag, istagged, metadata, hasmetadata, - enabletagging, @overdub, overdub, canrecurse, similarcontext, fallback - - -######################################################################## -# SETTING THE SCENE # -######################################################################## - -Cassette.@context TraceCtx - -# allow assiciation of Int values with TraceCtx -Cassette.metadatatype(::Type{<:TraceCtx}, ::DataType) = Int -Cassette.hastagging(::Type{<:TraceCtx}) = true - - -######################################################################## -# CUSTOM PASS # -######################################################################## - - -is_gref_call(a, fn_name) = a isa GlobalRef && a.name == fn_name - - -function prepare_ir(::Type{<:TraceCtx}, reflection::Cassette.Reflection) - ir = reflection.code_info - Cassette.replace_match!(x -> Base.Meta.isexpr(x, :new), ir.code) do x - return Expr(:call, __new__, x.args...) - end - Cassette.replace_match!(x -> Base.Meta.isexpr(x, :call) && is_gref_call(x.args[1], :tuple), ir.code) do x - return Expr(:call, __tuple__, x.args[2:end]...) - end - Cassette.replace_match!(x -> Base.Meta.isexpr(x, :call) && is_gref_call(x.args[1], :getfield), ir.code) do x - return Expr(:call, __getfield__, x.args[2:end]...) - end - return ir -end - -@runonce const prepare_pass = Cassette.@pass prepare_ir - - -######################################################################## -# TRACE # -######################################################################## - -struct TapeBox - tape::Tape - primitives::Set{Any} -end - - -""" -Trace function execution using provided arguments. -Returns calculated value and a tape. - -``` -foo(x) = 2.0x + 1.0 -val, tape = trace(foo, 4.0) -``` -""" -function ctrace(f, args...; primitives=PRIMITIVES, optimize=true) - # create tape - tape = Tape(guess_device(args)) - box = TapeBox(tape, primitives) - ctx = enabletagging(TraceCtx(metadata=box, pass=prepare_pass), f) - tagged_args = Vector(undef, length(args)) - for (i, x) in enumerate(args) - id = record!(tape, Input, x) - tagged_args[i] = tag(x, ctx, i) - end - # trace f with tagged arguments - tagged_val = overdub(ctx, f, tagged_args...) - val = untag(tagged_val, ctx) - tape.resultid = metadata(tagged_val, ctx) - if optimize - tape = simplify(tape) - end - return val, tape -end - - -function with_free_args_as_constants(ctx::TraceCtx, tape::Tape, args) - new_args = [] - for x in args - if istagged(x, ctx) - push!(new_args, x) - else - # x = x isa Function ? device_function(ctx.metadata.tape.device, x) : x - id = record!(tape, Constant, x) - x = tag(x, ctx, id) - push!(new_args, x) - end - end - return new_args -end - - -function Cassette.overdub(ctx::TraceCtx, fargs...) - f, args = fargs[1], fargs[2:end] - fv = istagged(f, ctx) ? untag(f, ctx) : f - tape = ctx.metadata.tape - primitives = ctx.metadata.primitives - if fv in primitives - fargs = with_free_args_as_constants(ctx, tape, fargs) - farg_ids = [metadata(x, ctx) for x in fargs] - farg_ids = Int[id isa Cassette.NoMetaData ? -1 : id for id in farg_ids] - # execute call - retval = fallback(ctx, [untag(x, ctx) for x in fargs]...) - # record to the tape and tag with a newly created ID - ret_id = record!(tape, Call, retval, fv, farg_ids[2:end]) - retval = tag(retval, ctx, ret_id) - elseif canrecurse(ctx, fv, args...) - retval = Cassette.recurse(ctx, fargs...) - else - retval = fallback(ctx, fargs...) - end - return retval -end diff --git a/src/trace/interp.jl b/src/trace/interp.jl deleted file mode 100644 index b09bd6d..0000000 --- a/src/trace/interp.jl +++ /dev/null @@ -1,141 +0,0 @@ -import JuliaInterpreter -import JuliaInterpreter: enter_call, step_expr!, @lookup, Frame, SSAValue, SlotNumber - - -getexpr(frame::Frame, pc::Int) = frame.framecode.src.code[pc] -current_expr(frame::Frame) = getexpr(frame, frame.pc) - - -""" -Split JuliaInterpreter call expression into a tuple of 2 elements: - - * callable (e.g. function or callable struct) & its arguments - * vars on the tape corresponding to these function & args - -If arguments include free parameters (not SlotNumber or SSAValue), these are recorded -to the tape as constants -""" -function split_int_call!(tape::Tape, frame::Frame, frame_vars::Dict, ex) - arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args - # for whatever reason JuliaInterpreter wraps some nodes in the original code into QuoteNode - arr = [isa(x, QuoteNode) ? x.value : x for x in arr] - cfargs = [x isa Symbol ? x : @lookup(frame, x) for x in arr] - cvars = Vector{Int}(undef, length(cfargs)) - for (i, x) in enumerate(arr) - # if isa(x, JuliaInterpreter.SlotNumber) || isa(x, JuliaInterpreter.SSAValue) - if haskey(frame_vars, x) - cvars[i] = frame_vars[x] - else - val = x isa Symbol ? x : @lookup(frame, x) - id = record!(tape, Constant, val) - cvars[i] = id - if val != x - # if constant appeared to be a SlotNumber or SSAValue - # store its mapping into frame_vars - frame_vars[x] = id - end - end - end - return cfargs, cvars -end - - -""" -Given a Frame and current expression, extract LHS location (SlotNumber or SSAValue) -""" -get_location(frame::Frame, ex) = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(frame.pc) - -is_int_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call)) -is_int_assign_expr(ex) = Meta.isexpr(ex, :(=)) && (isa(ex.args[2], SlotNumber) || isa(ex.args[2], SSAValue)) - -is_interesting_expr(ex) = is_int_call_expr(ex) || is_int_assign_expr(ex) || Meta.isexpr(ex, :return) - - -function itrace!(tape::Tape, fargs, vars; primitives) - frame = enter_call(fargs...) - frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i) => v for (i, v) in enumerate(vars)) - # f might be a callable struct, so we need to record it and add to frame_vars - # f_id = record!(tape, Constant, f) # should not be constant! should caller record it? - # frame_vars[SlotNumber(1)] = f - is_interesting_expr(current_expr(frame)) || step_expr!(frame) # skip non-call expressions - ex = current_expr(frame) - while !Meta.isexpr(ex, :return) - if is_int_assign_expr(ex) - lhs, rhs = ex.args - frame_vars[lhs] = frame_vars[rhs] - step_expr!(frame) - elseif is_int_call_expr(ex) - # read as "current function & arguments", "current variables" - cfargs, cvars = split_int_call!(tape, frame, frame_vars, ex) - cf = cfargs[1] - loc = get_location(frame, ex) - # there are several special cases such as NamedTuples and constructors - # we replace these with calls to special helper functions - if cf isa UnionAll && cf <: NamedTuple - # replace cf with namedtuple function, adjust arguments - names = collect(cf.body.parameters)[1] - cf = namedtuple - insert!(cfargs, 2, names) - names_var_id = record!(tape, Constant, names) - cvars = [names_var_id; cvars] - elseif cf isa DataType - # constructor, replace with a call to __new__ which we know how to differentiate - T = cf - cf = __new__ - insert!(cfargs, 2, T) - T_var_id = record!(tape, Constant, T) - cvars = [T_var_id; cvars] - elseif cf == Base.tuple - cf = __tuple__ - elseif cf == Base.getfield || (cf == Base.getindex && isa(tape[cvars[2]].val, NamedTuple)) - # similar to constuctors, there's a special case for __getfield__ in backprop - cf = __getfield__ - end - # if current function is a primitive of a built-in, write it to the tape - # otherwise recurse into the current function - if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction) - step_expr!(frame) - retval = @lookup(frame, loc) - ret_id = record!(tape, Call, retval, cf, cvars[2:end]) - frame_vars[loc] = ret_id # for slots it may overwrite old mapping - else - try - retval, ret_id = itrace!(tape, cfargs, cvars; primitives=primitives) - catch - println("Failed to trace through function $cf") - rethrow() - end - frame_vars[loc] = ret_id # for slots it may overwrite old mapping - step_expr!(frame) # can we avoid this double execution? - end - else - step_expr!(frame) - end - ex = current_expr(frame) - end - retval = @lookup(frame, ex.args[1]) - ret_id = frame_vars[ex.args[1]] - return retval, ret_id # return var ID of a result variable -end - - -""" -Trace function f with arguments args using JuliaInterpreter -""" -function itrace(f, args...; primitives=PRIMITIVES, optimize=true) - tape = Tape(guess_device(args)) - # record arguments as input variables - fargs = Vector(undef, length(args) + 1) - vars = Vector{Int}(undef, length(args) + 1) - for (i, arg) in enumerate([f, args...]) - id = record!(tape, Input, arg) - fargs[i] = arg - vars[i] = i - end - val, resultid = itrace!(tape, fargs, vars; primitives=primitives) - tape.resultid = resultid - if optimize - tape = simplify(tape) - end - return val, tape -end diff --git a/src/trace/trace.jl b/src/trace/trace.jl deleted file mode 100644 index e7d3474..0000000 --- a/src/trace/trace.jl +++ /dev/null @@ -1,44 +0,0 @@ -function __new__(T, args...) - # @show T - # @show args - # note: we also add __new__() to the list of primitives so it's not overdubbed recursively - if T <: NamedTuple - return T(args) - else - return T(args...) - end -end - - -__tuple__(args...) = tuple(args...) -__getfield__(args...) = getfield(args...) - - -function module_functions(modl) - res = Vector{Function}() - for s in Base.names(modl; all=true) - isdefined(modl, s) || continue - fn = getfield(modl, s) - if fn isa Function # && match(r"^[a-z#]+$", string(s)) != nothing - push!(res, fn) - end - end - return res -end - -const PRIMITIVES = Set{Any}(vcat( - module_functions(Base), - module_functions(Core), - module_functions(Core.Intrinsics), - [Broadcast.materialize, Broadcast.broadcasted, Colon(), (:), - Base.not_int, - # our own special functions - __new__, __tuple__, __getfield__, namedtuple, guess_device])); - - -include("cassette.jl") -# include("interp.jl") -include("irtools.jl") - - -trace = irtrace diff --git a/test/runtests.jl b/test/runtests.jl index ca3206e..96aba04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using Test using Yota -using Yota: Tape, Input, Call, Constant, trace, ctrace, irtrace, play!, transform, binarize_ops +using Yota: Tape, Input, Call, Constant, trace, play!, transform, binarize_ops using Yota: ∇mean, setfield_nested!, copy_with, simplegrad, remove_unused using Yota: find_field_source_var, unwind_iterate, eliminate_common using Yota: unvcat, unhcat, uncat, ungetindex!, ungetindex @@ -10,9 +10,7 @@ using CUDA CUDA.allowscalar(false) -include("test_trace_cassette.jl") -include("test_trace_irtools.jl") -# include("test_trace_interp.jl") # doesn't work for Julia 1.4, maybe will drop support at all +include("test_trace.jl") include("gradcheck.jl") include("test_helpers.jl") include("test_simple.jl") diff --git a/test/test_trace_cassette.jl b/test/test_trace.jl similarity index 70% rename from test/test_trace_cassette.jl rename to test/test_trace.jl index 52b1d2f..01ddd44 100644 --- a/test/test_trace_cassette.jl +++ b/test/test_trace.jl @@ -6,17 +6,17 @@ non_primitive(x) = 2x + 1 non_primitive_caller(x) = sin(non_primitive(x)) -@testset "ctracer: calls" begin - val, tape = ctrace(inc_mul, 2.0, 3.0) +@testset "tracer: calls" begin + val, tape = trace(inc_mul, 2.0, 3.0) @test val == inc_mul(2.0, 3.0) @test length(tape) == 5 @test tape[3] isa Constant end -@testset "ctracer: bcast" begin +@testset "tracer: bcast" begin A = rand(3) B = rand(3) - val, tape = ctrace(inc_mul, A, B) + val, tape = trace(inc_mul, A, B) @test val == inc_mul(A, B) # broadcasting may be lowered to different forms, # so making no assumptions regarding the tape @@ -25,10 +25,10 @@ end @test val == inc_mul2(A, B) end -@testset "ctracer: primitives" begin +@testset "tracer: primitives" begin x = 3.0 - val1, tape1 = ctrace(non_primitive_caller, x) - val2, tape2 = ctrace(non_primitive_caller, x; primitives=Set([non_primitive, sin])) + val1, tape1 = trace(non_primitive_caller, x) + val2, tape2 = trace(non_primitive_caller, x; primitives=Set([non_primitive, sin])) @test val1 == val2 @test any(op isa Call && op.fn == (*) for op in tape1) diff --git a/test/test_trace_interp.jl b/test/test_trace_interp.jl deleted file mode 100644 index 7810ff7..0000000 --- a/test/test_trace_interp.jl +++ /dev/null @@ -1,30 +0,0 @@ -@testset "itracer: calls" begin - val, tape = itrace(inc_mul, 2.0, 3.0) - @test val == inc_mul(2.0, 3.0) - @test length(tape) == 5 - @test tape[3] isa Constant -end - -@testset "itracer: bcast" begin - A = rand(3) - B = rand(3) - val, tape = itrace(inc_mul, A, B) - @test val == inc_mul(A, B) - # broadcasting may be lowered to different forms, - # so making no assumptions regarding the tape - - val, tape = itrace(inc_mul2, A, B) - @test val == inc_mul2(A, B) -end - -@testset "itracer: primitives" begin - x = 3.0 - val1, tape1 = itrace(non_primitive_caller, x) - val2, tape2 = itrace(non_primitive_caller, x; primitives=Set([non_primitive, sin])) - - @test val1 == val2 - @test any(op isa Call && op.fn == (*) for op in tape1) - @test tape2[2].fn == non_primitive - @test tape2[3].fn == sin - -end diff --git a/test/test_trace_irtools.jl b/test/test_trace_irtools.jl index 2c2f2c5..01ddd44 100644 --- a/test/test_trace_irtools.jl +++ b/test/test_trace_irtools.jl @@ -1,26 +1,34 @@ -@testset "irtracer: calls" begin - val, tape = irtrace(inc_mul, 2.0, 3.0) +inc_mul(a::Real, b::Real) = a * (b + 1.0) +inc_mul(A::AbstractArray, B::AbstractArray) = inc_mul.(A, B) +inc_mul2(A::AbstractArray, B::AbstractArray) = A .* (B .+ 1) + +non_primitive(x) = 2x + 1 +non_primitive_caller(x) = sin(non_primitive(x)) + + +@testset "tracer: calls" begin + val, tape = trace(inc_mul, 2.0, 3.0) @test val == inc_mul(2.0, 3.0) @test length(tape) == 5 @test tape[3] isa Constant end -@testset "irtracer: bcast" begin +@testset "tracer: bcast" begin A = rand(3) B = rand(3) - val, tape = irtrace(inc_mul, A, B) + val, tape = trace(inc_mul, A, B) @test val == inc_mul(A, B) # broadcasting may be lowered to different forms, # so making no assumptions regarding the tape - val, tape = irtrace(inc_mul2, A, B) + val, tape = trace(inc_mul2, A, B) @test val == inc_mul2(A, B) end -@testset "irtracer: primitives" begin +@testset "tracer: primitives" begin x = 3.0 - val1, tape1 = irtrace(non_primitive_caller, x) - val2, tape2 = irtrace(non_primitive_caller, x; primitives=Set([non_primitive, sin])) + val1, tape1 = trace(non_primitive_caller, x) + val2, tape2 = trace(non_primitive_caller, x; primitives=Set([non_primitive, sin])) @test val1 == val2 @test any(op isa Call && op.fn == (*) for op in tape1)