Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the InverseFunctions.inverse function in Functional transform #227

Merged
merged 8 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ CoDa = "5900dafe-f573-5c72-b367-76665857777b"
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
DataScienceTraits = "6cb2f572-2d2b-4ba6-bdb3-e710fa044d6c"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NelderMead = "2f6b4ddb-b4ff-44c0-b59b-2ab99302f970"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Expand All @@ -28,6 +29,7 @@ CoDa = "1.2"
ColumnSelectors = "0.1"
DataScienceTraits = "0.1"
Distributions = "0.25"
InverseFunctions = "0.1"
LinearAlgebra = "1.9"
NelderMead = "0.4"
PrettyTables = "2"
Expand All @@ -36,6 +38,6 @@ Statistics = "1.9"
StatsBase = "0.33, 0.34"
Tables = "1.6"
Transducers = "0.4"
TransformsBase = "1.2"
TransformsBase = "1.3"
Unitful = "1.17"
julia = "1.9"
5 changes: 3 additions & 2 deletions src/TableTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ using ColumnSelectors: AllSelector, Column, selector, selectsingle
using DataScienceTraits: SciType, Continuous, Categorical, coerce
using Unitful: AbstractQuantity, AffineQuantity, AffineUnits, Units
using Distributions: ContinuousUnivariateDistribution, Normal
using InverseFunctions: NoInverse, inverse as invfun
using StatsBase: AbstractWeights, Weights, sample
using Transducers: tcollect
using NelderMead: optimise

import Distributions: quantile, cdf
import TransformsBase: assertions, isrevertible, preprocess
import TransformsBase: apply, revert, reapply
import TransformsBase: assertions, isrevertible, isinvertible
import TransformsBase: apply, revert, reapply, preprocess, inverse

include("assertions.jl")
include("tabletraits.jl")
Expand Down
87 changes: 30 additions & 57 deletions src/transforms/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,98 +3,71 @@
# ------------------------------------------------------------------

"""
Functional(func)
Functional(fun)

The transform that applies a `func` elementwise.
The transform that applies the function `fun` elementwise.

Functional(col₁ => func₁, col₂ => func₂, ..., colₙ => funcₙ)
Functional(col₁ => fun₁, col₂ => fun₂, ..., colₙ => funₙ)

Apply the corresponding `funcᵢ` function to each `colᵢ` column.
Apply the corresponding `funᵢ` function to each `colᵢ` column.

# Examples

```julia
Functional(cos)
Functional(sin)
Functional(1 => cos, 2 => sin)
Functional(:a => cos, :b => sin)
Functional("a" => cos, "b" => sin)
Functional(exp)
Functional(log)
Functional(1 => exp, 2 => log)
Functional(:a => exp, :b => log)
Functional("a" => exp, "b" => log)
```
"""
struct Functional{S<:ColumnSelector,F} <: StatelessFeatureTransform
selector::S
func::F
fun::F
end

Functional(func) = Functional(AllSelector(), func)
Functional(fun) = Functional(AllSelector(), fun)

Functional(pairs::Pair{C}...) where {C<:Column} = Functional(selector(first.(pairs)), last.(pairs))

Functional() = throw(ArgumentError("cannot create Functional transform without arguments"))

# known invertible functions
inverse(::typeof(log)) = exp
inverse(::typeof(exp)) = log
inverse(::typeof(cos)) = acos
inverse(::typeof(acos)) = cos
inverse(::typeof(sin)) = asin
inverse(::typeof(asin)) = sin
inverse(::typeof(cosd)) = acosd
inverse(::typeof(acosd)) = cosd
inverse(::typeof(sind)) = asind
inverse(::typeof(asind)) = sind
inverse(::typeof(identity)) = identity
isrevertible(transform::Functional) = isinvertible(transform)

# fallback to nothing
inverse(::Any) = nothing
_hasinverse(f) = !(invfun(f) isa NoInverse)

isrevertible(transform::Functional{AllSelector}) = !isnothing(inverse(transform.func))
isinvertible(transform::Functional{AllSelector}) = _hasinverse(transform.fun)
isinvertible(transform::Functional) = all(_hasinverse, transform.fun)

isrevertible(transform::Functional) = all(!isnothing, inverse.(transform.func))
inverse(transform::Functional{AllSelector}) = Functional(transform.selector, invfun(transform.fun))
inverse(transform::Functional) = Functional(transform.selector, invfun.(transform.fun))

_funcdict(func, names) = Dict(nm => func for nm in names)
_funcdict(func::Tuple, names) = Dict(names .=> func)
_fundict(transform::Functional{AllSelector}, names) = Dict(nm => transform.fun for nm in names)
_fundict(transform::Functional, names) = Dict(zip(names, transform.fun))

function applyfeat(transform::Functional, feat, prep)
cols = Tables.columns(feat)
names = Tables.columnnames(cols)
snames = transform.selector(names)
funcs = _funcdict(transform.func, snames)
fundict = _fundict(transform, snames)

columns = map(names) do nm
x = Tables.getcolumn(cols, nm)
if nm ∈ snames
func = funcs[nm]
y = func.(x)
columns = map(names) do name
x = Tables.getcolumn(cols, name)
if name ∈ snames
fun = fundict[name]
map(fun, x)
else
y = x
x
end
y
end

𝒯 = (; zip(names, columns)...)
newfeat = 𝒯 |> Tables.materializer(feat)
return newfeat, (snames, funcs)

newfeat, nothing
end

function revertfeat(transform::Functional, newfeat, fcache)
cols = Tables.columns(newfeat)
names = Tables.columnnames(cols)

snames, funcs = fcache

columns = map(names) do nm
y = Tables.getcolumn(cols, nm)
if nm ∈ snames
func = funcs[nm]
invfunc = inverse(func)
x = invfunc.(y)
else
x = y
end
x
end

𝒯 = (; zip(names, columns)...)
𝒯 |> Tables.materializer(newfeat)
ofeat, _ = applyfeat(inverse(transform), newfeat, nothing)
ofeat
end
10 changes: 5 additions & 5 deletions test/metadata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
mtₒ = revert(T, mn, mc)
@test mtₒ == mt

T = Functional(sin)
T = Functional(exp)
mn, mc = apply(T, mt)
tn, tc = apply(T, t)
@test mn.meta == m
Expand All @@ -33,7 +33,7 @@
@test mtₒ.meta == mt.meta
@test Tables.matrix(mtₒ.table) ≈ Tables.matrix(mt.table)

T = (Functional(sin) → MinMax()) ⊔ Center()
T = (Functional(exp) → MinMax()) ⊔ Center()
mn, mc = apply(T, mt)
tn, tc = apply(T, t)
@test mn.meta == m
Expand Down Expand Up @@ -68,7 +68,7 @@
mtₒ = revert(T, mn, mc)
@test mtₒ == mt

T = Functional(cos)
T = Functional(exp)
mn, mc = apply(T, mt)
tn, tc = apply(T, t)
@test mn.meta == VarMeta(m.data .+ 2)
Expand All @@ -79,7 +79,7 @@

# first revertible branch has two transforms,
# so metadata is increased by 2 + 2 = 4
T = (Functional(sin) → MinMax()) ⊔ Center()
T = (Functional(exp) → MinMax()) ⊔ Center()
mn, mc = apply(T, mt)
tn, tc = apply(T, t)
@test mn.meta == VarMeta(m.data .+ 4)
Expand All @@ -90,7 +90,7 @@

# first revertible branch has one transform,
# so metadata is increased by 2
T = Center() ⊔ (Functional(sin) → MinMax())
T = Center() ⊔ (Functional(exp) → MinMax())
mn, mc = apply(T, mt)
tn, tc = apply(T, t)
@test mn.meta == VarMeta(m.data .+ 2)
Expand Down
18 changes: 9 additions & 9 deletions test/shows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,18 +317,18 @@
end

@testset "Functional" begin
T = Functional(sin)
T = Functional(log)

# compact mode
iostr = sprint(show, T)
@test iostr == "Functional(all, sin)"
@test iostr == "Functional(all, log)"

# full mode
iostr = sprint(show, MIME("text/plain"), T)
@test iostr == """
Functional transform
├─ selector = all
└─ func = sin"""
└─ fun = log"""
end

@testset "EigenAnalysis" begin
Expand Down Expand Up @@ -419,31 +419,31 @@
@testset "ParallelTableTransform" begin
t1 = Scale(low=0.3, high=0.6)
t2 = EigenAnalysis(:VDV)
t3 = Functional(cos)
t3 = Functional(exp)
pipeline = t1 ⊔ t2 ⊔ t3

# compact mode
iostr = sprint(show, pipeline)
@test iostr == "Scale(all, 0.3, 0.6) ⊔ EigenAnalysis(:VDV, nothing, 1.0) ⊔ Functional(all, cos)"
@test iostr == "Scale(all, 0.3, 0.6) ⊔ EigenAnalysis(:VDV, nothing, 1.0) ⊔ Functional(all, exp)"

# full mode
iostr = sprint(show, MIME("text/plain"), pipeline)
@test iostr == """
ParallelTableTransform
├─ Scale(all, 0.3, 0.6)
├─ EigenAnalysis(:VDV, nothing, 1.0)
└─ Functional(all, cos)"""
└─ Functional(all, exp)"""

# parallel and sequential
f1 = ZScore()
f2 = Scale()
f3 = Functional(cos)
f3 = Functional(exp)
f4 = Interquartile()
pipeline = (f1 → f2) ⊔ (f3 → f4)

# compact mode
iostr = sprint(show, pipeline)
@test iostr == "ZScore(all) → Scale(all, 0.25, 0.75) ⊔ Functional(all, cos) → Scale(all, 0.25, 0.75)"
@test iostr == "ZScore(all) → Scale(all, 0.25, 0.75) ⊔ Functional(all, exp) → Scale(all, 0.25, 0.75)"

# full mode
iostr = sprint(show, MIME("text/plain"), pipeline)
Expand All @@ -453,7 +453,7 @@
│ ├─ ZScore(all)
│ └─ Scale(all, 0.25, 0.75)
└─ SequentialTransform
├─ Functional(all, cos)
├─ Functional(all, exp)
└─ Scale(all, 0.25, 0.75)"""
end
end
Loading
Loading