diff --git a/src/Chain.jl b/src/Chain.jl index 74f43cb..8942bfa 100644 --- a/src/Chain.jl +++ b/src/Chain.jl @@ -56,6 +56,10 @@ function insert_first_arg(e::Expr, firstarg; assignment = false) elseif head == :call && length(args) > 0 if length(args) ≥ 2 && Meta.isexpr(args[2], :parameters) Expr(head, args[1:2]..., firstarg, args[3:end]...) + elseif args[1] == :splat + Expr(:call, args[2], Expr(:..., firstarg)) + elseif Meta.isexpr(args[1], :call) && args[1].args[1] == :splat + Expr(:call, args[1].args[2], Expr(:..., firstarg), args[2:end]...) else Expr(head, args[1], firstarg, args[2:end]...) end @@ -278,8 +282,16 @@ function replace_underscores(expr::Expr, replacement) # for all other expressions, their arguments are checked for underscores recursively # and replaced if any are found else + is_splat = false + if !isempty(expr.args) && expr.args[1] isa Expr && expr.args[1].args[1] == :splat + is_splat = true + replacement = Expr(:..., replacement) + end newargs = map(x -> replace_underscores(x, replacement), expr.args) found_underscore = any(first.(newargs)) + if is_splat && found_underscore + newargs[1] = (newargs[1][1], expr.args[1].args[2]) + end newexpr = Expr(expr.head, last.(newargs)...) end return found_underscore, newexpr