Skip to content

Commit

Permalink
tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Jun 15, 2019
1 parent 662ce33 commit b634651
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
46 changes: 24 additions & 22 deletions docs/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,22 @@ which is what `mapcols` does, has some overhead:
using BenchmarkTools
mat1k = rand(3,1000);

@btime mapreduce(fun, hcat, eachcol($mat1k)) # 1.522 ms
@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms

@btime mapcols(fun, $mat1k) # 399.016 μs
@btime MapCols{3}(fun, $mat1k) # 15.564 μs
@btime MapCols(fun, $mat1k) # 16.774 μs without size

@btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k); # 372.705 ms
@btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 70.203 ms
@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 146.561 μs, 330.51 KiB
@btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 20.018 ms, 3.82 MiB
@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 245.550 μs
@btime mapreduce(fun, hcat, eachcol($mat1k)) # 1.522 ms, 11.80 MiB
@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms, 329.92 KiB

@btime mapcols(fun, $mat1k) # 399.016 μs, 219.02 KiB
@btime MapCols{3}(fun, $mat1k) # 15.564 μs, 47.16 KiB
@btime MapCols(fun, $mat1k) # 16.774 μs (without slice size)

@btime ForwardDiff.gradient(m -> sum(mapslices(fun, m, dims=1)), $mat1k); # 329.305 ms
@btime Tracker.gradient(m -> sum(mapcols(fun, m)), $mat1k); # 70.203 ms
@btime Tracker.gradient(m -> sum(MapCols{3}(fun, m)), $mat1k); # 51.129 μs, 282.92 KiB
@btime Zygote.gradient(m -> sum(mapcols(fun, m)), $mat1k); # 20.454 ms, 3.52 MiB
@btime Zygote.gradient(m -> sum(MapCols{3}(fun, m)), $mat1k); # 28.229 μs, 164.63 KiB
```

For such a simple function, timing `sum(sin, MapCols{3}(fun, m))` takes 3 to 10 times longer!

## Other packages

This package also provides Zygote gradients for the slice/glue functions in
Expand All @@ -53,13 +55,13 @@ which can be used to write many mapslices-like operations.

```julia
using TensorCast
@cast [i,j] := fun(mat[:,j])[i] # same as mapcols
@cast [i,j] := fun(mat[:,j])[i] # same as mapcols

tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]

@btime tcm($mat1k) # 407.176 μs
@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k); # 19.086 ms
@btime tcm($mat1k) # 427.907 μs
@btime Zygote.gradient(m -> sum(tcm(m)), $mat1k); # 18.358 ms
```

Similar gradients work for the Slice/Align functions in
Expand All @@ -69,11 +71,11 @@ so it defines these too:
```julia
using JuliennedArrays
jumap(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
jumap(fun, mat) # same as mapcols
jumap(fun, mat) # same as mapcols
Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1]

@btime jumap(fun, $mat1k); # 408.259 μs
@btime Zygote.gradient(m -> sum(sin, jumap(fun, m)), $mat1k); # 18.638 ms
@btime jumap(fun, $mat1k); # 421.061 μs
@btime Zygote.gradient(m -> sum(jumap(fun, m)), $mat1k); # 18.383 ms
```

That's a 2-line gradient definition, so borrowing it may be easier than depending on this package.
Expand Down Expand Up @@ -102,11 +104,11 @@ Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), kay)[1]
This is quite efficient, and seems to go well with multi-threading:

```julia
@btime MapCols{2}(g, $kay, 1:5) # 1.423 ms
@btime ThreadMapCols{2}(g, $kay, 1:5) # 713.748 μs
@btime MapCols{2}(g, $kay, 1:5) # 1.394 ms
@btime ThreadMapCols{2}(g, $kay, 1:5) # 697.333 μs

@btime Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), $kay)[1] # 2.535 ms
@btime Tracker.gradient(k -> sum(sin, ThreadMapCols{2}(g, k, 1:5)), $kay)[1] # 1.333 ms
@btime Tracker.gradient(k -> sum(MapCols{2}(g, k, 1:5)), $kay)[1] # 2.561 ms
@btime Tracker.gradient(k -> sum(ThreadMapCols{2}(g, k, 1:5)), $kay)[1] # 1.344 ms

Threads.nthreads() == 4 # on my 2/4-core laptop
```
Expand Down
7 changes: 5 additions & 2 deletions src/SliceMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ _MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =

function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
d == size(M,1) || error("expected M with $d columns")
k = size(M,2)

A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))

dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
Expand All @@ -121,8 +123,9 @@ function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::V
Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))

function back(ΔZ)
∇M = zeros(eltype(data(ΔZ)), size(M))
@inbounds for c=1:size(M,2)
S = promote_type(T, eltype(data(ΔZ)))
∇M = zeros(S, size(M))
@inbounds for c=1:k
part = ForwardDiff.partials.(C[c])
for r=1:d
for i=1:size(ΔZ,1)
Expand Down

0 comments on commit b634651

Please sign in to comment.