Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIR with duration #89

Merged
merged 16 commits into from
Dec 19, 2024
Merged
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Effects = "8f03c58b-bd97-4933-a826-f71b64d2cca2"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Expand Down Expand Up @@ -47,10 +49,10 @@ RobustModels = "d6ea1423-9682-4bbd-952f-b1577cbf8c98"

[extensions]
UnfoldBSplineKitExt = "BSplineKit"
UnfoldCUDAExt = "CUDA"
UnfoldKrylovExt = ["Krylov", "CUDA"]
UnfoldMixedModelsExt = "MixedModels"
UnfoldRobustModelsExt = "RobustModels"
UnfoldCUDAExt = "CUDA"

[compat]
BSplineKit = "0.16,0.17"
Expand All @@ -62,6 +64,8 @@ DocStringExtensions = "0.9"
Effects = "0.1,1"
FileIO = "1"
GLM = "1"
ImageTransformations = "0.10.1"
Interpolations = "0.15.1"
IterativeSolvers = "0.9"
JLD2 = "0.5"
Krylov = "0.9"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MakieThemes = "e296ed71-da82-5faf-88ab-0034a9761098"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
Expand Down
58 changes: 58 additions & 0 deletions docs/literate/HowTo/FIRduration.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using Unfold
using Interpolations

using UnfoldSim
using UnfoldMakie, CairoMakie
using DataFrames
using DisplayAs # hide

data, evts = UnfoldSim.predef_eeg(sfreq = 10, n_repeats = 1)

evts.duration = 5:24


# putting `scale_duration = Interpolation.Linear()` will introduce a Cameron-Hassall 2022 PNAS- Style basisfunction, that scales with the `:duration` column
basisfunction = firbasis(τ = (-1, 2), sfreq = 5, scale_duration = Interpolations.Linear())

# Two examples with `duration = 10`
Unfold.kernel(basisfunction, [0, 10])
# and `duration = 20`
Unfold.kernel(basisfunction, [0, 20])

# let's fit a model
f = @formula 0 ~ 1 + condition
bf_vec = [Any => (f, basisfunction)]
m = fit(UnfoldModel, bf_vec, evts, data; eventfields = [:latency, :duration]);


## currently bugged for small matrices
## plot_designmatrix(designmatrix(m))
## thus using
heatmap(Matrix(modelmatrix(m))')
# As one can see, the designmatrix is nicely scaled

# We can predict overlap-corrected results
p = predict(m; overlap = false)[1]
heatmap(p[1, :, :])
# note the `missings` which are displayed as white pixels.

# ## Block-design predictors
# In contrast, it is also possible to put `scale_duration = true` - which wil not scale the matrix as before, but introduce a step-function.

# putting `scale_duration = Interpolation.Linear()` will introduce a Cameron-Hassall 2022 PNAS- Style basisfunction, that scales with the `:duration` column
basisfunction = firbasis(τ = (-1, 2), sfreq = 5, scale_duration = true)
# Two examples with `duration = 10`
Unfold.kernel(basisfunction, [0, 10])
# and `duration = 20`
Unfold.kernel(basisfunction, [0, 20])

# let's fit a model
f = @formula 0 ~ 1 + condition
bf_vec = [Any => (f, basisfunction)]
m = fit(UnfoldModel, bf_vec, evts, data; eventfields = [:latency, :duration]);


heatmap(Matrix(modelmatrix(m))')
# as one can see, now the designmatrix is not stretched - but rather "block"-ed
p = predict(m; overlap = false)[1]
heatmap(p[1, :, :])
9 changes: 5 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ makedocs(
"index.md",
"Installing Julia + Unfold.jl" => "installation.md",
"Tutorials" => [
"Mass univariate LM" => "tutorials/lm_mu.md",
"LM overlap correction" => "tutorials/lm_overlap.md",
"Mass univariate Mixed Model" => "tutorials/lmm_mu.md",
"LMM + overlap correction" => "tutorials/lmm_overlap.md",
"rERP (mass univariate)" => "tutorials/lm_mu.md",
"rERP (overlap correction)" => "tutorials/lm_overlap.md",
"lmmERP (mass univariate)" => "tutorials/lmm_mu.md",
"lmmERP (overlap correction)" => "tutorials/lmm_overlap.md",
],
"HowTo" => [
"Multiple events" => "HowTo/multiple_events.md",
Expand All @@ -39,6 +39,7 @@ makedocs(
#"Time domain basis functions"=>"generated/HowTo/timesplines.md",
"P-values for mixedModels" => "HowTo/lmm_pvalues.md",
"Save and load Unfold models" => "generated/HowTo/unfold_io.md",
"Duration-scaled basisfunctions (Hassall-style)" => "generated/HowTo/FIRduration.md",
"🐍 Import EEG with PyMNE.jl" => "HowTo/pymne.md",
"🐍 Calling Unfold.jl directly from Python" => "generated/HowTo/juliacall_unfold.md",
],
Expand Down
2 changes: 2 additions & 0 deletions ext/UnfoldBSplineKitExt/basisfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ function splinekernel(e, times, nsplines)
basis = BSplineKit.BSplineBasis(BSplineOrder(4), breakpoints) # 4= cubic
return sparse(splFunction(times, basis))
end

Unfold.width(b::SplineBasis) = length(b.colnames)
2 changes: 2 additions & 0 deletions src/Unfold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import Term: Tree # to display Dicts
using PooledArrays
using TypedTables # DataFrames loose the pooled array, so we have to do it differently for now...

using Interpolations # for FIR duration scaling
using ImageTransformations # for FIR duration scaling
#using Tullio
#using BSplineKit # for spline predictors

Expand Down
123 changes: 94 additions & 29 deletions src/basisfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ See FIRBasis for an examples

a BasisFunction should implement:
- kernel() # kernel(b::BasisFunction,sample) => returns the designmatrix for that event
- height() # number of samples in continuous time
- height() # number of samples in continuous time, NaN if not defined
- width() # number of coefficient columns (e.g. HRF 1 to 3, FIR=height(),except if interpolate=true )

- colnames() # unique names of expanded columns
Expand Down Expand Up @@ -39,9 +39,20 @@ mutable struct FIRBasis <: BasisFunction
shift_onset::Int64
"should we linearly interpolate events not on full samples?"
interpolate::Bool
"should we scale kernel to the duration? If yes, with which method"
scale_duration::Any
#FIRBasis(times, name, shift_onset, interpolate, scale_duration::Bool) = new(
# times,
# name,
# shift_onset,
# interpolate,
# scale_duration ? Interpolations.Linear() : false,
#)
end
# default dont interpolate
FIRBasis(times, name, shift_onset) = FIRBasis(times, name, shift_onset, false)
FIRBasis(times, name, shift_onset, interpolate) =
FIRBasis(times, name, shift_onset, interpolate, false)
@deprecate FIRBasis(kernel::Function, times, name, shift_onset) FIRBasis(
times,
name,
Expand Down Expand Up @@ -83,8 +94,15 @@ end
$(SIGNATURES)
Generate a sparse FIR basis around the *τ* timevector at sampling rate *sfreq*. This is useful if you cannot make any assumptions on the shape of the event responses. If unrounded events are supplied, they are split between samples. E.g. event-latency = 1.2 will result in a "0.8" and a "0.2" entry.


Advanced: second input can be duration in samples - careful: `times(firbasis)` always assumes duration = 1. Therefore,
issues with LMM and predict will appear!

# keyword arguments
`interpolate` (Bool, default false): if true, interpolates events between samples linearly. This results in `predict` functions to return a trailling 0
`interpolate` (Bool, default false): if true, interpolates events between samples linearly. This results in `predict` functions to return a trailling 0`
`scale_duration` (Union{Bool,Interpolations-Interpolator}, default false):
if true, scales the response by the fit-kwargs `eventfields` second entry. That is, the FIR becomes a stepfunction instead of a impulse response.
if Interpolations.interpolator, e.g. `Interpolations.Linear()` - uses the fit-kwargs `eventfields` second entry to stretch the FIR kernel based on `imresize`. This implements Hassall

# Examples
Generate a FIR basis function from -0.1s to 0.3s at 100Hz
Expand All @@ -97,17 +115,29 @@ julia> f(103.3)
```

"""
function firbasis(τ, sfreq, name = ""; interpolate = false)
function firbasis(
τ,
sfreq,
name = "";
interpolate = false,
#max_duration = nothing,
scale_duration = false,
)
τ = round_times(τ, sfreq)
if interpolate
# stop + 1 step, because we support fractional event-timings
τ = (τ[1], τ[2] + 1 ./ sfreq)
end
#if !isnothing(max_duration)
# τ = (τ[1], max_duration)
#τ = (τ[1], τ[2] + max_height./sfreq)

#end
times = range(τ[1], stop = τ[2], step = 1 ./ sfreq)

shift_onset = Int64(floor(τ[1] * sfreq))

return FIRBasis(times, name, shift_onset, interpolate)
return FIRBasis(times, name, shift_onset, interpolate, scale_duration)

end
# cant multiple dispatch on optional arguments
Expand All @@ -119,39 +149,58 @@ firbasis(; τ, sfreq, name = "", kwargs...) = firbasis(τ, sfreq, name; kwargs..
"""
$(SIGNATURES)
Calculate a sparse firbasis

second input can be duration in samples - careful: `times(firbasis)` always assumes duration = 1. Therefore,
issues with LMM and predict will appear!

# Examples

```julia-repl
julia> f = firkernel(103.3,range(-0.1,step=0.01,stop=0.31))
julia> f_dur = firkernel([103.3 4],range(-0.1,step=0.01,stop=0.31))
```
"""
function firkernel(e, times; interpolate = false)
@assert ndims(e) <= 1 #either single onset or a row vector where we will take the first one :)
if size(e, 1) > 1
# XXX we will soon assume that the second entry would be the duration
e = Float64(e[1])
end
ksize = length(times) # kernelsize
if interpolate
eboth = [1 .- (e .% 1) e .% 1]
eboth[isapprox.(eboth, 0, atol = 1e-15)] .= 0
return spdiagm(
ksize + 1,
ksize,
0 => repeat([eboth[1]], ksize),
-1 => repeat([eboth[2]], ksize),
)
function firkernel(ev, times; interpolate = false, scale_duration = false)
@assert ndims(ev) <= 1 #either single onset or a row vector where we will take the first one :)

# duration is 1 for FIR and duration is duration (in samples!) if e is vector
e = interpolate ? ev[1] : 1
dur = (scale_duration == true && size(ev, 1) > 1) ? Int.(ceil(ev[2])) : 1

#scale_duration == false => 1
#scale_duration => true
#dur = (scale_duration == true || (scale_duration != false || size(ev, 1) > 1) ? Int.(ceil(ev[2])) : 1


ksize = interpolate ? length(times) - 1 : length(times) #isnothing(maxsize) ? length(times) : max_height # kernelsize

# if interpolatethe first and last entry is split, depending on whether we have "half" events
eboth = interpolate ? [1 .- (e .% 1) e .% 1] : [1]
#@show eboth dur
values = [eboth[1]; repeat([e], dur - 1); eboth[2:end]]
values[isapprox.(values, 0, atol = 1e-15)] .= 0 # keep sparsity pattern

# without interpolation, remove last entry again, which should be 0 anyway
# values = interpolate ? values : values[1:end-1] # commented out if the eboth[2:end] trick works

# build the matrix, we define it by diagonals which go from 0, -1, -2 ...
pairs = [x => repeat([y], ksize) for (x, y) in zip(.-range(0, dur), values)]
#return pairs, ksize
#@debug pairs
x_single = spdiagm(ksize + length(pairs) - 1, ksize, pairs...)
if scale_duration == false || scale_duration == true
return x_single
else
#eboth = Int(round(e))
return spdiagm(ksize, ksize, 0 => repeat([1], ksize))
#@show "imresize"
return imresize(x_single, ratio = (ev[2] / ksize, 1), method = scale_duration)
end


end





"""
$(SIGNATURES)
Generate a Hemodynamic-Response-Functio (HRF) basis with inverse-samplingrate "TR" (=1/FS)
Expand Down Expand Up @@ -222,19 +271,35 @@ basisname(fs::Vector{<:FormulaTerm}) = [name(f.rhs.basisfunction) for f in fs]
basisname(uf::UnfoldModel) = basisname(formulas(uf))
basisname(uf::UnfoldLinearModel) = first.(design(uf)) # for linear models we dont save it in the formula

kernel(basis::FIRBasis, e) =
basis.interpolate ? firkernel(e, basis.times[1:end-1]; interpolate = true) :
firkernel(e, basis.times; interpolate = false)
kernel(basis::FIRBasis, e) = firkernel(
e,
basis.times[1:end];
interpolate = basis.interpolate,
scale_duration = basis.scale_duration,
)


times(basis::BasisFunction) = basis.times
name(basis::BasisFunction) = basis.name

StatsModels.width(basis::BasisFunction) = height(basis)
StatsModels.width(basis::FIRBasis) = basis.interpolate ? height(basis) - 1 : height(basis)
height(basis::FIRBasis) = length(times(basis))
function StatsModels.width(basis::FIRBasis)
if basis.scale_duration == false#isa(basis.scale_duration, Bool)
if basis.interpolate
return height(basis) - 1
else
return height(basis)
end
else
return length(times(basis))
end
end
height(basis::BasisFunction) = length(times(basis))

height(basis::FIRBasis) = isa(basis.scale_duration, Bool) ? length(times(basis)) : NaN

StatsModels.width(basis::HRFBasis) = 1
times(basis::HRFBasis) = NaN
times(basis::HRFBasis) = NaN # I guess this could also return 1:32 or something?

"""
$(SIGNATURES)
Expand Down
Loading
Loading