Skip to content

Commit

Permalink
rm Weighted
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Jun 15, 2019
1 parent ce17dd5 commit 662ce33
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 15 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
WeightedArrays = "379a43df-f81c-573e-83a6-069eb6c11a71"

[compat]
julia = "1"
Expand Down
19 changes: 5 additions & 14 deletions src/SliceMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module SliceMap

export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols

using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
using MacroTools, Requires, TensorCast, JuliennedArrays

using Tracker
using Tracker: TrackedMatrix, track, @grad, data
Expand All @@ -25,9 +25,6 @@ They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
tmapcols(f::Function, M, args...) = _mapcols(threadmap, f, M, args...)

_mapcols(map::Function, f::Function, M::WeightedMatrix, args...) =
Weighted(_mapcols(map, f, M.array, args...), M.weights, M.opt)

_mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
reduce(hcat, map(col -> surevec(f(col, args...)), eachcol(M)))

Expand Down Expand Up @@ -83,7 +80,7 @@ end

#========== Forward, Static ==========#

using StaticArrays, ForwardDiff, WeightedArrays
using StaticArrays, ForwardDiff

struct MapCols{d} end

Expand All @@ -95,12 +92,9 @@ Their length `d = size(M,1)` should ideally be provided for type-stability, but
The gradient for Tracker and Zygote uses `ForwardDiff` on each slice.
"""
MapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatrix} =
MapCols(f::Function, M::AbstractMatrix, args...) =
MapCols{size(M,1)}(f, M, args...)

MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
Weighted(MapCols{d}(f, M.array, args...), M.weights, M.opt)

MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
_MapCols(map, f, M, Val(d), args...)

Expand Down Expand Up @@ -220,7 +214,7 @@ end
# What KissThreading does is much more complicated, perhaps worth investigating:
# https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl

# BTW I do the first one because some diffeq maps are infer to ::Any
# BTW I do the first one because some diffeq maps infer to ::Any,
# else you could use Core.Compiler.return_type(f, Tuple{eltype(x)})

"""
Expand Down Expand Up @@ -260,12 +254,9 @@ struct ThreadMapCols{d} end
Like `MapCols` but with multi-threading!
"""
ThreadMapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatrix} =
ThreadMapCols(f::Function, M::AbstractMatrix, args...) =
ThreadMapCols{size(M,1)}(f, M, args...)

ThreadMapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
Weighted(ThreadMapCols{d}(f, M.array, args...), M.weights, M.opt)

ThreadMapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
_MapCols(threadmap, f, M, Val(d), args...)

Expand Down

0 comments on commit 662ce33

Please sign in to comment.