Skip to content

Commit

Permalink
Refactor: Accept all types of categorical columns in OneHot and Levels (
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv authored Oct 23, 2023
1 parent 482ca2e commit 4256d53
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 102 deletions.
2 changes: 1 addition & 1 deletion src/TableTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ using Random
using CoDa

using TransformsBase: Transform, Identity, →
using DataScienceTraits: SciType, Continuous, coerce
using ColumnSelectors: ColumnSelector, SingleColumnSelector
using ColumnSelectors: AllSelector, Column, selector, selectsingle
using DataScienceTraits: SciType, Continuous, Categorical, coerce
using Unitful: AbstractQuantity, AffineQuantity, AffineUnits, Units
using Distributions: ContinuousUnivariateDistribution, Normal
using StatsBase: AbstractWeights, Weights, sample
Expand Down
24 changes: 0 additions & 24 deletions src/assertions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,3 @@ function (assertion::SciTypeAssertion{T})(table) where {T}
@assert elscitype(x) <: T "the elements of the column '$nm' are not of scientific type $T"
end
end

"""
ColumnTypeAssertion{T}(selector = AllSelector())
Asserts that the columns in the `selector` have a type `T`.
"""
struct ColumnTypeAssertion{T,S<:ColumnSelector}
  # selector that is called on the table's column names to pick the columns to check
  selector::S
end

# Only `T` has to be spelled out by the caller; the selector parameter
# is taken from the concrete type of the argument.
function ColumnTypeAssertion{T}(selector::ColumnSelector) where {T}
  return ColumnTypeAssertion{T,typeof(selector)}(selector)
end

# Without an explicit selector, fall back to `AllSelector()`.
function ColumnTypeAssertion{T}() where {T}
  return ColumnTypeAssertion{T}(AllSelector())
end

# Apply the assertion to `table`: every column picked by `assertion.selector`
# must have a type that is a subtype of `T`.
#
# Throws an `AssertionError` naming the first offending column. When all
# selected columns pass, the loop falls through and `nothing` is returned.
function (assertion::ColumnTypeAssertion{T})(table) where {T}
  cols = Tables.columns(table)
  names = Tables.columnnames(cols)
  snames = assertion.selector(names)

  for nm in snames
    x = Tables.getcolumn(cols, nm)
    # Explicit throw instead of `@assert`: this check validates user input,
    # and `@assert` may be disabled at some optimization levels. The exception
    # type and message are kept so callers catching `AssertionError` still work.
    x isa T || throw(AssertionError("the column '$nm' is not of type $T"))
  end
end
35 changes: 19 additions & 16 deletions src/transforms/levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,35 @@ Levels(pairs::Pair{C}...; ordered=nothing) where {C<:Column} =

Levels(; kwargs...) = throw(ArgumentError("cannot create Levels transform without arguments"))

assertions(transform::Levels) = [ColumnTypeAssertion{CategoricalArray}(transform.selector)]
assertions(transform::Levels) = [SciTypeAssertion{Categorical}(transform.selector)]

isrevertible(::Type{<:Levels}) = true

# Build the function that undoes the `categorical` conversion for a column `x`.
# Generic fallback: the reverted column is converted with `Array`.
function _revfun(x)
  toarray(y) = Array(y)
  return toarray
end

# Categorical input: remember the original levels and ordering so the
# reverted column is rebuilt with both.
function _revfun(x::CategoricalArray)
  xlevels = levels(x)
  xordered = isordered(x)
  restore(y) = categorical(y, levels=xlevels, ordered=xordered)
  return restore
end

function applyfeat(transform::Levels, feat, prep)
cols = Tables.columns(feat)
names = Tables.columnnames(cols)
snames = transform.selector(names)
ordered = transform.ordered(snames)
tlevels = transform.levels
leveldict = Dict(zip(snames, transform.levels))

results = map(names) do nm
x = Tables.getcolumn(cols, nm)
results = map(names) do name
x = Tables.getcolumn(cols, name)

if nm ∈ snames
o = nm ∈ ordered
l = tlevels[findfirst(==(nm), snames)]
if name ∈ snames
o = name ∈ ordered
l = leveldict[name]
y = categorical(x, levels=l, ordered=o)

xl, xo = levels(x), isordered(x)
revfunc = y -> categorical(y, levels=xl, ordered=xo)
revfun = _revfun(x)
y, revfun
else
y, revfunc = x, identity
x, identity
end

y, revfunc
end

columns, fcache = first.(results), last.(results)
Expand All @@ -67,9 +70,9 @@ function revertfeat(::Levels, newfeat, fcache)
cols = Tables.columns(newfeat)
names = Tables.columnnames(cols)

columns = map(names, fcache) do nm, revfunc
x = Tables.getcolumn(cols, nm)
revfunc(x)
columns = map(names, fcache) do name, revfun
y = Tables.getcolumn(cols, name)
revfun(y)
end

𝒯 = (; zip(names, columns)...)
Expand Down
34 changes: 21 additions & 13 deletions src/transforms/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,54 +26,62 @@ end

OneHot(col::Column; categ=false) = OneHot(selector(col), categ)

assertions(transform::OneHot) = [ColumnTypeAssertion{CategoricalArray}(transform.selector)]
assertions(transform::OneHot) = [SciTypeAssertion{Categorical}(transform.selector)]

isrevertible(::Type{<:OneHot}) = true

# Return the column as a categorical array together with the function
# used to rebuild the original column on revert.
# Generic fallback: convert with `categorical`; revert leaves the values as they are.
function _categ(x)
  return categorical(x), identity
end

# Already categorical: keep the column untouched and capture its levels and
# ordering so that revert rebuilds a categorical array with the same settings.
function _categ(x::CategoricalArray)
  xlevels = levels(x)
  xordered = isordered(x)
  restore(y) = categorical(y, levels=xlevels, ordered=xordered)
  return x, restore
end

function applyfeat(transform::OneHot, feat, prep)
cols = Tables.columns(feat)
names = Tables.columnnames(cols) |> collect
columns = Any[Tables.getcolumn(cols, nm) for nm in names]

name = selectsingle(transform.selector, names)
ind = findfirst(==(name), names)
x = columns[ind]
x, revfun = _categ(columns[ind])

xl = levels(x)
onehot = map(xl) do l
xlevels = levels(x)
onehot = map(xlevels) do l
nm = Symbol("$(name)_$l")
while nm ∈ names
nm = Symbol("$(nm)_")
end
nm, x .== l
end

newnms, newcols = first.(onehot), last.(onehot)
newnames = first.(onehot)
newcolumns = last.(onehot)

# convert to categorical arrays if necessary
newcols = transform.categ ? categorical.(newcols, levels=[false, true]) : newcols
newcolumns = transform.categ ? categorical.(newcolumns, levels=[false, true]) : newcolumns

splice!(names, ind, newnms)
splice!(columns, ind, newcols)
splice!(names, ind, newnames)
splice!(columns, ind, newcolumns)

inds = ind:(ind + length(newnms) - 1)
inds = ind:(ind + length(newnames) - 1)

𝒯 = (; zip(names, columns)...)
newfeat = 𝒯 |> Tables.materializer(feat)
newfeat, (name, inds, xl, isordered(x))
newfeat, (name, inds, xlevels, revfun)
end

function revertfeat(::OneHot, newfeat, fcache)
cols = Tables.columns(newfeat)
names = Tables.columnnames(cols) |> collect
columns = Any[Tables.getcolumn(cols, nm) for nm in names]

oname, inds, levels, ordered = fcache
x = map(zip(columns[inds]...)) do row
oname, inds, levels, revfun = fcache
y = map(zip(columns[inds]...)) do row
levels[findfirst(==(true), row)]
end

ocolumn = categorical(x; levels, ordered)
ocolumn = revfun(y)

splice!(names, inds, [oname])
splice!(columns, inds, [ocolumn])
Expand Down
7 changes: 0 additions & 7 deletions test/assertions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,4 @@
selector = CS.selector([:b, :e, :f])
assertion = TT.SciTypeAssertion{DST.Categorical}(selector)
@test_throws AssertionError assertion(table)

selector = CS.selector([:e, :f])
assertion = TT.ColumnTypeAssertion{CategoricalArray}(selector)
@test isnothing(assertion(table))
selector = CS.selector([:b, :e, :f])
assertion = TT.ColumnTypeAssertion{CategoricalArray}(selector)
@test_throws AssertionError assertion(table)
end
17 changes: 10 additions & 7 deletions test/transforms/levels.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
@testset "Levels" begin
a = categorical(rand([true, false], 50))
b = categorical(rand(["y", "n"], 50))
c = categorical(rand(1:3, 50))
a = Bool[1, 0, 1, 0, 1, 1]
b = ["n", "y", "n", "y", "y", "y"]
c = [2, 3, 1, 2, 1, 3]
t = Table(; a, b, c)

T = Levels(2 => ["n", "y", "m"])
n, c = apply(T, t)
@test levels(n.b) == ["n", "y", "m"]
@test isordered(n.b) == false
tₒ = revert(T, n, c)
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

T = Levels(:b => ["n", "y", "m"], :c => 1:4, ordered=[:c])
Expand All @@ -18,6 +19,7 @@
@test levels(n.c) == [1, 2, 3, 4]
@test isordered(n.c) == true
tₒ = revert(T, n, c)
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

T = Levels("b" => ["n", "y", "m"], "c" => 1:4, ordered=["b"])
Expand All @@ -27,6 +29,7 @@
@test levels(n.c) == [1, 2, 3, 4]
@test isordered(n.c) == false
tₒ = revert(T, n, c)
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

a = categorical(["yes", "no", "no", "no", "yes"])
Expand Down Expand Up @@ -87,9 +90,9 @@
tₒ = revert(T, n, c)
@test isordered(tₒ.a) == false

a = rand([true, false], 50)
b = categorical(rand(["y", "n"], 50))
c = categorical(rand(1:3, 50))
a = [0.1, 0.1, 0.2, 0.2, 0.1, 0.2]
b = ["n", "y", "n", "y", "y", "y"]
c = [2, 3, 1, 2, 1, 3]
t = Table(; a, b, c)

# throws: Levels without arguments
Expand All @@ -102,7 +105,7 @@
@test_throws AssertionError apply(T, t)

# throws: non categorical column
T = Levels(:a => [true, false], ordered=[:a])
T = Levels(:a => [0.1, 0.2, 0.3], ordered=[:a])
@test_throws AssertionError apply(T, t)

# throws: invalid ordered column selection
Expand Down
77 changes: 43 additions & 34 deletions test/transforms/onehot.jl
Original file line number Diff line number Diff line change
@@ -1,70 +1,79 @@
@testset "OneHot" begin
a = categorical(Bool[0, 1, 1, 0, 1, 1])
b = categorical(["m", "f", "m", "m", "m", "f"])
c = categorical([3, 2, 2, 1, 1, 3])
t = Table(; a, b, c)
a = Bool[0, 1, 1, 0, 1, 1]
b = ["m", "f", "m", "m", "m", "f"]
c = [3, 2, 2, 1, 1, 3]
d = categorical(a)
e = categorical(b)
f = categorical(c)
t = Table(; a, b, c, d, e, f)

T = OneHot(1; categ=true)
n, c = apply(T, t)
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c)
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c, :d, :e, :f)
@test n.a_false == categorical(Bool[1, 0, 0, 1, 0, 0])
@test n.a_true == categorical(Bool[0, 1, 1, 0, 1, 1])
@test n.a_false isa CategoricalVector{Bool}
@test n.a_true isa CategoricalVector{Bool}
tₒ = revert(T, n, c)
@test t == tₒ
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

T = OneHot(:b; categ=true)
n, c = apply(T, t)
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c)
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c, :d, :e, :f)
@test n.b_f == categorical(Bool[0, 1, 0, 0, 0, 1])
@test n.b_m == categorical(Bool[1, 0, 1, 1, 1, 0])
@test n.b_f isa CategoricalVector{Bool}
@test n.b_m isa CategoricalVector{Bool}
tₒ = revert(T, n, c)
@test t == tₒ
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

T = OneHot("c"; categ=true)
n, c = apply(T, t)
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3)
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3, :d, :e, :f)
@test n.c_1 == categorical(Bool[0, 0, 0, 1, 1, 0])
@test n.c_2 == categorical(Bool[0, 1, 1, 0, 0, 0])
@test n.c_3 == categorical(Bool[1, 0, 0, 0, 0, 1])
@test n.c_1 isa CategoricalVector{Bool}
@test n.c_2 isa CategoricalVector{Bool}
@test n.c_3 isa CategoricalVector{Bool}
tₒ = revert(T, n, c)
@test t == tₒ
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

T = OneHot(1; categ=false)
T = OneHot(4; categ=false)
n, c = apply(T, t)
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c)
@test n.a_false == Bool[1, 0, 0, 1, 0, 0]
@test n.a_true == Bool[0, 1, 1, 0, 1, 1]
@test Tables.columnnames(n) == (:a, :b, :c, :d_false, :d_true, :e, :f)
@test n.d_false == Bool[1, 0, 0, 1, 0, 0]
@test n.d_true == Bool[0, 1, 1, 0, 1, 1]
tₒ = revert(T, n, c)
@test t == tₒ
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

T = OneHot(:b; categ=false)
T = OneHot(:e; categ=false)
n, c = apply(T, t)
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c)
@test n.b_f == Bool[0, 1, 0, 0, 0, 1]
@test n.b_m == Bool[1, 0, 1, 1, 1, 0]
@test Tables.columnnames(n) == (:a, :b, :c, :d, :e_f, :e_m, :f)
@test n.e_f == Bool[0, 1, 0, 0, 0, 1]
@test n.e_m == Bool[1, 0, 1, 1, 1, 0]
tₒ = revert(T, n, c)
@test t == tₒ
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

T = OneHot("c"; categ=false)
T = OneHot("f"; categ=false)
n, c = apply(T, t)
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3)
@test n.c_1 == Bool[0, 0, 0, 1, 1, 0]
@test n.c_2 == Bool[0, 1, 1, 0, 0, 0]
@test n.c_3 == Bool[1, 0, 0, 0, 0, 1]
@test Tables.columnnames(n) == (:a, :b, :c, :d, :e, :f_1, :f_2, :f_3)
@test n.f_1 == Bool[0, 0, 0, 1, 1, 0]
@test n.f_2 == Bool[0, 1, 1, 0, 0, 0]
@test n.f_3 == Bool[1, 0, 0, 0, 0, 1]
tₒ = revert(T, n, c)
@test t == tₒ
@test Tables.schema(tₒ) == Tables.schema(t)
@test tₒ == t

# name formatting
b = categorical(["m", "f", "m", "m", "m", "f"])
b_f = rand(10)
b_m = rand(10)
b_f = rand(6)
b_m = rand(6)
t = Table(; b, b_f, b_m)

T = OneHot(:b; categ=false)
Expand All @@ -76,10 +85,10 @@
@test t == tₒ

b = categorical(["m", "f", "m", "m", "m", "f"])
b_f = rand(10)
b_m = rand(10)
b_f_ = rand(10)
b_m_ = rand(10)
b_f = rand(6)
b_m = rand(6)
b_f_ = rand(6)
b_m_ = rand(6)
t = Table(; b, b_f, b_m, b_f_, b_m_)

T = OneHot(:b; categ=false)
Expand All @@ -91,8 +100,8 @@
@test t == tₒ

# throws
a = categorical(Bool[0, 1, 1, 0, 1, 1])
b = ["m", "f", "m", "m", "m", "f"]
a = Bool[0, 1, 1, 0, 1, 1]
b = rand(6)
t = Table(; a, b)

# non categorical column
Expand Down

0 comments on commit 4256d53

Please sign in to comment.