Skip to content

Commit

Permalink
Refactor Filter implementation: avoid Tables.rowtable (#240)
Browse files Browse the repository at this point in the history
* Refactor 'Filter' implementation: avoid 'Tables.rowtable'

* Update tests
  • Loading branch information
eliascarv authored Nov 28, 2023
1 parent 169ad55 commit 3ea33a5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 17 deletions.
31 changes: 20 additions & 11 deletions src/transforms/filter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,38 @@ function preprocess(transform::Filter, feat)
end

# Apply the `Filter` transform to the table `feat`.
#
# `prep` is unpacked below into `sinds` (indices of the selected rows) and
# `rinds` (indices of the rejected rows). The function returns a 2-tuple:
# the selected rows materialized with `Tables.materializer(feat)`, and the
# cache `(rinds, rrows)`, which is the pair `revertfeat(::Filter, ...)`
# unpacks from its `fcache` argument.
#
# NOTE(review): this listing appears to show both sides of the diff at once
# (the page title says the commit avoids `Tables.rowtable`). The
# `Tables.rowtable` / `view` lines look like the removed row-based version and
# the `Tables.subset` lines like their replacement; as written, `srows` and
# `rrows` are simply assigned twice and the second assignment wins.
function applyfeat(::Filter, feat, prep)
# collect all rows
# (row-based path; presumably the removed side of the diff)
rows = Tables.rowtable(feat)

# preprocessed indices
sinds, rinds = prep

# select/reject rows
# (views into the collected `rows`; overwritten by the assignments below)
srows = view(rows, sinds)
rrows = view(rows, rinds)
# selected/rejected rows
# (taken directly from `feat`, requesting views via `viewhint=true`)
srows = Tables.subset(feat, sinds, viewhint=true)
rrows = Tables.subset(feat, rinds, viewhint=true)

# rebuild a table from the selected rows using the materializer of `feat`
newfeat = srows |> Tables.materializer(feat)

# result plus the cache needed to re-insert the rejected rows on revert
newfeat, (rinds, rrows)
end

# Revert the `Filter` transform: re-insert the rejected rows stored in
# `fcache` into the filtered table `newfeat`.
#
# `fcache` is unpacked below into `rinds` (indices at which rejected rows are
# inserted) and `rrows` (the rejected rows themselves). The result is
# materialized with `Tables.materializer(newfeat)`.
#
# NOTE(review): this listing appears to show both sides of the diff at once.
# The `rows`-based statements look like the removed row-table implementation,
# and the `for` loop over `zip(rinds, rrows)` has no matching `end` here, so
# the block is not balanced as written. The column-based statements look like
# the replacement.
function revertfeat(::Filter, newfeat, fcache)
# collect all rows
# (row-based path; presumably the removed side of the diff)
rows = Tables.rowtable(newfeat)
# column access object and column names of the filtered table
cols = Tables.columns(newfeat)
names = Tables.columnnames(cols)

rinds, rrows = fcache
# row-based re-insertion: put each rejected row back at its index
# NOTE(review): no closing `end` for this loop is visible in the listing
for (i, row) in zip(rinds, rrows)
insert!(rows, i, row)

# columns with selected rows
# (each column is `collect`ed into its own array, which `insert!` mutates below)
columns = map(names) do name
collect(Tables.getcolumn(cols, name))
end

# insert rejected rows into columns
# NOTE(review): inserting one index at a time presumably relies on `rinds`
# being in ascending order so earlier insertions do not shift later targets;
# confirm against `preprocess`
rrcols = Tables.columns(rrows)
for (name, x) in zip(names, columns)
r = Tables.getcolumn(rrcols, name)
for (i, v) in zip(rinds, r)
insert!(x, i, v)
end
end

# row-based result (presumably the removed side of the diff)
rows |> Tables.materializer(newfeat)
# column-based result: named tuple of columns, materialized like `newfeat`
𝒯 = (; zip(names, columns)...)
𝒯 |> Tables.materializer(newfeat)
end
7 changes: 5 additions & 2 deletions src/transforms/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ function preprocess(transform::Sample, feat)
else
sample(rng, inds, weights, size; replace, ordered)
end

# rejected indices
rinds = setdiff(inds, sinds)

sinds, rinds
Expand All @@ -67,7 +69,7 @@ function applyfeat(::Sample, feat, prep)
# preprocessed indices
sinds, rinds = prep

# selected and removed rows
# selected/rejected rows
srows = Tables.subset(feat, sinds, viewhint=true)
rrows = Tables.subset(feat, rinds, viewhint=true)

Expand All @@ -78,6 +80,7 @@ end
function revertfeat(::Sample, newfeat, fcache)
cols = Tables.columns(newfeat)
names = Tables.columnnames(cols)

sinds, rinds, rrows = fcache

# columns with selected rows in original order
Expand All @@ -87,7 +90,7 @@ function revertfeat(::Sample, newfeat, fcache)
[y[i] for i in uinds]
end

# insert removed rows into columns
# insert rejected rows into columns
rrcols = Tables.columns(rrows)
for (name, x) in zip(names, columns)
r = Tables.getcolumn(rrcols, name)
Expand Down
10 changes: 10 additions & 0 deletions test/transforms/filter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,14 @@
@test Tables.isrowtable(n)
rtₒ = revert(T, n, c)
@test rt == rtₒ

# performance tests
trng = MersenneTwister(2) # test rng
x = rand(trng, 100_000)
y = rand(trng, 100_000)
c = CoDaArray((a=rand(trng, 100_000), b=rand(trng, 100_000), c=rand(trng, 100_000)))
t = (; x, y, c)

T = Filter(row -> row.x > 0.5)
@test @elapsed(apply(T, t)) < 0.5
end
9 changes: 5 additions & 4 deletions test/transforms/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,11 @@
@test isapprox(count(==(trows[6]), nrows) / 10_000, 6 / 21, atol=0.01)

# performance tests
x = rand(100_000)
y = rand(100_000)
c = CoDaArray((a=rand(100_000), b=rand(100_000), c=rand(100_000)))
t = Table(; x, y, c)
trng = MersenneTwister(2) # test rng
x = rand(trng, 100_000)
y = rand(trng, 100_000)
c = CoDaArray((a=rand(trng, 100_000), b=rand(trng, 100_000), c=rand(trng, 100_000)))
t = (; x, y, c)

T = Sample(10_000)
@test @elapsed(apply(T, t)) < 0.5
Expand Down

0 comments on commit 3ea33a5

Please sign in to comment.