Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] A way to add kwarg support to #164 #165

Draft
wants to merge 7 commits into
base: subtape
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,25 @@ const LOGGING = Ref(false)
abstract type AbstractInstruction end
const RawTape = Vector{AbstractInstruction}

separate_kwargs(args...; kwargs...) = (args, values(kwargs))

function _infer(f, args_type)
# `code_typed` returns a vector: [Pair{Core.CodeInfo, DataType}]
ir0 = code_typed(f, Tuple{args_type...}, optimize=false)[1][1]
return ir0
end

resolve_globalref(var) = var
resolve_globalref(var::Core.GlobalRef) = getproperty(var.mod, var.name)

function mark_kwarg_func_as_nonprimitive(ir::Core.CodeInfo)
line = ir.code[end - 1]
Meta.isexpr(line, :call) || error("Expected a call expression")
f = resolve_globalref(line.args[1])
@debug "Marking $f as non-primitive"
@eval is_primitive(::typeof($f)) = false
end

const Bindings = Vector{Any}

mutable struct TapedFunction{F, TapeType}
Expand All @@ -59,26 +72,31 @@ mutable struct TapedFunction{F, TapeType}
retval_binding_slot::Int # 0 indicates the function has not returned
deepcopy_types::Type # use a Union type for multiple types

function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T}
function TapedFunction{F, T}(_f::F, _args...; cache=false, deepcopy_types=Union{}, _kwargs...) where {F, T}
f, args = make_kwcall_maybe(_f, _args...; _kwargs...)
args_type = _accurate_typeof.(args)
cache_key = (f, deepcopy_types, args_type...)

cache_key = (f, deepcopy_types, args_type...)
if cache && haskey(TRCache, cache_key) # use cache
cached_tf = TRCache[cache_key]::TapedFunction{F, T}
cached_tf = TRCache[cache_key]::TapedFunction{typeof(f), T}
tf = copy(cached_tf)
tf.counter = 1
return tf
end
ir = _infer(f, args_type)
if iskwcall(f)
mark_kwarg_func_as_nonprimitive(ir)
end
binding_values, slots, tape = translate!(RawTape(), ir)

tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types)
# TODO: Make this use `kwcall` instead.
tf = new{typeof(f), T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types)
TRCache[cache_key] = tf # set cache
return tf
end

TapedFunction(f, args...; cache=false, deepcopy_types=Union{}) =
TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types)
TapedFunction(f, args...; cache=false, deepcopy_types=Union{}, kwargs...) =
TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types, kwargs...)

function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1}
new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape,
Expand All @@ -91,6 +109,19 @@ end
const TRCache = LRU{Tuple, TapedFunction}(maxsize=10)
const CompiledTape = Vector{FunctionWrapper{Nothing, Tuple{TapedFunction}}}

# TODO: Make this work on pre-1.9 Julia.
iskwcall(f) = false
iskwcall(f::typeof(Core.kwcall)) = true
iskwcall(tf::TapedFunction) = tf.func === Core.kwcall
function make_kwcall_maybe(f, args...; kwargs...)
return if length(kwargs) > 0
args, kwargs = separate_kwargs(args...; kwargs...)
Core.kwcall, (kwargs, f, args...)
else
f, args
end
end

function Base.convert(::Type{CompiledTape}, tape::RawTape)
ctape = CompiledTape(undef, length(tape))
for idx in 1:length(tape)
Expand Down
9 changes: 6 additions & 3 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ BASE_COPY_TYPES = Union{Array, Ref}

# NOTE: evaluating model without a trace, see
# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329
function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref by default.
function TapedTask(f, args...; deepcopy_types=nothing, kwargs...) # deepcoy Array and Ref by default.
if isnothing(deepcopy_types)
deepcopy = BASE_COPY_TYPES
else
deepcopy = Union{BASE_COPY_TYPES, deepcopy_types}
end
tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy)
tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy, kwargs...)
args = last(make_kwcall_maybe(f, args...; kwargs...))
TapedTask(tf, args...)
end

Expand Down Expand Up @@ -169,7 +170,9 @@ Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()

# copy the task

function Base.copy(t::TapedTask; args=())
function Base.copy(t::TapedTask; args=(), kwargs=())
args = last(make_kwcall_maybe(func(t), args...; kwargs...))

length(args) > 0 && t.tf.counter >1 &&
error("can't copy started task with new arguments")
tf = copy(t.tf)
Expand Down