diff --git a/docs/intro.md b/docs/intro.md index 80c40d2..0292503 100644 --- a/docs/intro.md +++ b/docs/intro.md @@ -30,20 +30,22 @@ which is what `mapcols` does, has some overhead: using BenchmarkTools mat1k = rand(3,1000); -@btime mapreduce(fun, hcat, eachcol($mat1k)) # 1.522 ms -@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms - -@btime mapcols(fun, $mat1k) # 399.016 μs -@btime MapCols{3}(fun, $mat1k) # 15.564 μs -@btime MapCols(fun, $mat1k) # 16.774 μs without size - -@btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k); # 372.705 ms -@btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 70.203 ms -@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 146.561 μs, 330.51 KiB -@btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 20.018 ms, 3.82 MiB -@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 245.550 μs +@btime mapreduce(fun, hcat, eachcol($mat1k)) # 1.522 ms, 11.80 MiB +@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms, 329.92 KiB + +@btime mapcols(fun, $mat1k) # 399.016 μs, 219.02 KiB +@btime MapCols{3}(fun, $mat1k) # 15.564 μs, 47.16 KiB +@btime MapCols(fun, $mat1k) # 16.774 μs (without slice size) + +@btime ForwardDiff.gradient(m -> sum(mapslices(fun, m, dims=1)), $mat1k); # 329.305 ms +@btime Tracker.gradient(m -> sum(mapcols(fun, m)), $mat1k); # 70.203 ms +@btime Tracker.gradient(m -> sum(MapCols{3}(fun, m)), $mat1k); # 51.129 μs, 282.92 KiB +@btime Zygote.gradient(m -> sum(mapcols(fun, m)), $mat1k); # 20.454 ms, 3.52 MiB +@btime Zygote.gradient(m -> sum(MapCols{3}(fun, m)), $mat1k); # 28.229 μs, 164.63 KiB ``` +For such a simple function, timing the gradient of `sum(sin, MapCols{3}(fun, m))` instead takes 3 to 10 times longer than that of `sum(MapCols{3}(fun, m))`! + ## Other packages This package also provides Zygote gradients for the slice/glue functions in @@ -53,13 +55,13 @@ which can be used to write many mapslices-like operations. 
```julia using TensorCast -@cast [i,j] := fun(mat[:,j])[i] # same as mapcols +@cast [i,j] := fun(mat[:,j])[i] # same as mapcols tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i] Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1] -@btime tcm($mat1k) # 407.176 μs -@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k); # 19.086 ms +@btime tcm($mat1k) # 427.907 μs +@btime Zygote.gradient(m -> sum(tcm(m)), $mat1k); # 18.358 ms ``` Similar gradients work for the Slice/Align functions in @@ -69,11 +71,11 @@ so it defines these too: ```julia using JuliennedArrays jumap(f,m) = Align(map(f, Slices(m, True(), False())), True(), False()) -jumap(fun, mat) # same as mapcols +jumap(fun, mat) # same as mapcols Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1] -@btime jumap(fun, $mat1k); # 408.259 μs -@btime Zygote.gradient(m -> sum(sin, jumap(fun, m)), $mat1k); # 18.638 ms +@btime jumap(fun, $mat1k); # 421.061 μs +@btime Zygote.gradient(m -> sum(jumap(fun, m)), $mat1k); # 18.383 ms ``` That's a 2-line gradient definition, so borrowing it may be easier than depending on this package. 
@@ -102,11 +104,11 @@ Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), kay)[1] This is quite efficient, and seems to go well with multi-threading: ```julia -@btime MapCols{2}(g, $kay, 1:5) # 1.423 ms -@btime ThreadMapCols{2}(g, $kay, 1:5) # 713.748 μs +@btime MapCols{2}(g, $kay, 1:5) # 1.394 ms +@btime ThreadMapCols{2}(g, $kay, 1:5) # 697.333 μs -@btime Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), $kay)[1] # 2.535 ms -@btime Tracker.gradient(k -> sum(sin, ThreadMapCols{2}(g, k, 1:5)), $kay)[1] # 1.333 ms +@btime Tracker.gradient(k -> sum(MapCols{2}(g, k, 1:5)), $kay)[1] # 2.561 ms +@btime Tracker.gradient(k -> sum(ThreadMapCols{2}(g, k, 1:5)), $kay)[1] # 1.344 ms Threads.nthreads() == 4 # on my 2/4-core laptop ``` diff --git a/src/SliceMap.jl b/src/SliceMap.jl index eda926f..3e54c43 100644 --- a/src/SliceMap.jl +++ b/src/SliceMap.jl @@ -113,6 +113,8 @@ _MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) = function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d} d == size(M,1) || error("expected M with $d columns") + k = size(M,2) + A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M))) dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval)) @@ -121,8 +123,9 @@ function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::V Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C)) function back(ΔZ) - ∇M = zeros(eltype(data(ΔZ)), size(M)) - @inbounds for c=1:size(M,2) + S = promote_type(T, eltype(data(ΔZ))) + ∇M = zeros(S, size(M)) + @inbounds for c=1:k part = ForwardDiff.partials.(C[c]) for r=1:d for i=1:size(ΔZ,1)