From a2b92eb2215635b9aa6abc28cc7a94a014e29617 Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Sat, 13 Jan 2024 12:35:52 -0500 Subject: [PATCH] move plans into Algorithm.FFT --- demo.jl | 12 ++++-------- src/ImageFiltering.jl | 8 +++++++- src/imfilter.jl | 33 ++++++++++++++++++++------------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/demo.jl b/demo.jl index 3c128b5..985a362 100644 --- a/demo.jl +++ b/demo.jl @@ -13,16 +13,12 @@ function do_work(mat) indices = Int.(round.((LinRange(frame_start, n_frames - frame_dist, n_pairs)))) factored_kernel = ImageFiltering.factorkernel(Kernel.LoG(1)) frame_filt = deepcopy(@view mat[:, :, frame_start]) - r = CPU1(Algorithm.FFT()) - p1 = plan_rfft(frame_filt) - p2, krn = ImageFiltering.kernel_plan_rfft(frame_filt, factored_kernel) - B = p1 * frame_filt - B .*= conj!(p2 * krn) - p3 = plan_irfft(B, length(axes(frame_filt, 1))) - plan = (p1, p2, p3) + r = CPU1(ImageFiltering.planned_fft(frame_filt, factored_kernel)) for i in indices frame = @view mat[:, :, i] - imfilter!(r, frame_filt, frame, factored_kernel, NoPad(Pad(:replicate)); plan) + imfilter!(r, frame_filt, frame, factored_kernel, NoPad()) + # TODO: make it work for cases with padding.. so planned_fft needs to do the padding transforms on a copy of A + # imfilter!(r, frame_filt, frame, factored_kernel) end return end diff --git a/src/ImageFiltering.jl b/src/ImageFiltering.jl index bbb164e..dd80eb4 100644 --- a/src/ImageFiltering.jl +++ b/src/ImageFiltering.jl @@ -50,10 +50,16 @@ function Base.transpose(A::StaticOffsetArray{T,2}) where T end module Algorithm + import FFTW # deliberately don't export these, but it's expected that they # will be used as Algorithm.FFT(), etc. 
abstract type Alg end - "Filter using the Fast Fourier Transform" struct FFT <: Alg end + "Filter using the Fast Fourier Transform" struct FFT <: Alg + plan1::Union{FFTW.rFFTWPlan,Nothing} + plan2::Union{FFTW.rFFTWPlan,Nothing} + plan3::Union{FFTW.AbstractFFTs.ScaledPlan,Nothing} + end + FFT() = FFT(nothing, nothing, nothing) "Filter using a direct algorithm" struct FIR <: Alg end "Cache-efficient filtering using tiles" struct FIRTiled{N} <: Alg tilesize::Dims{N} diff --git a/src/imfilter.jl b/src/imfilter.jl index 46299b4..9229f43 100644 --- a/src/imfilter.jl +++ b/src/imfilter.jl @@ -797,43 +797,39 @@ function imfilter!(r::AbstractCPU{FFT}, out::AbstractArray{S,N}, img::AbstractArray{T,N}, kernel::AbstractArray{K,N}, - border::NoPad; - plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,K,N} - imfilter!(r, out, img, (kernel,), border; plan) + border::NoPad) where {S,T,K,N} + imfilter!(r, out, img, (kernel,), border) end function imfilter!(r::AbstractCPU{FFT}, out::AbstractArray{S,N}, A::AbstractArray{T,N}, kernel::Tuple{AbstractArray}, - border::NoPad; - plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,N} - _imfilter_fft!(r, out, A, kernel, border; plan) # ambiguity resolution + border::NoPad) where {S,T,N} + _imfilter_fft!(r, out, A, kernel, border) # ambiguity resolution end function imfilter!(r::AbstractCPU{FFT}, out::AbstractArray{S,N}, A::AbstractArray{T,N}, kernel::Tuple{AbstractArray,Vararg{AbstractArray}}, - border::NoPad; - plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,N} - _imfilter_fft!(r, out, A, kernel, border; plan) + border::NoPad) where {S,T,N} + _imfilter_fft!(r, out, A, kernel, border) end function _imfilter_fft!(r::AbstractCPU{FFT}, out::AbstractArray{S,N}, A::AbstractArray{T,N}, kernel::Tuple{AbstractArray,Vararg{AbstractArray}}, - border::NoPad; - plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,N} + border::NoPad) where {S,T,N} kern = samedims(A, kernelconv(kernel...)) krn = 
FFTView(zeros(eltype(kern), map(length, axes(A)))) for I in CartesianIndices(axes(kern)) krn[I] = kern[I] end - Af = if plan === nothing + Af = if any(isnothing, (r.settings.plan1, r.settings.plan2, r.settings.plan3)) filtfft(A, krn) else - filtfft(plan[1], A, plan[2], krn, plan[3]) + filtfft(r.settings.plan1, A, r.settings.plan2, krn, r.settings.plan3) end Af = if map(first, axes(out)) == map(first, axes(Af)) @@ -849,6 +845,17 @@ function _imfilter_fft!(r::AbstractCPU{FFT}, out end +function planned_fft(A::AbstractArray{T,N}, kernel::Tuple{AbstractArray,Vararg{AbstractArray}}) where {T,N} + p1 = plan_rfft(A) + kern = samedims(A, kernelconv(kernel...)) + krn = FFTView(zeros(eltype(kern), map(length, axes(A)))) + p2 = plan_rfft(krn) + B = p1 * A + B .*= conj!(p2 * krn) + p3 = plan_irfft(B, length(axes(A, 1))) + return Algorithm.FFT(p1, p2, p3) +end + function kernel_plan_rfft(A::AbstractArray{T,N}, kernel::Tuple{AbstractArray,Vararg{AbstractArray}}) where {T,N} kern = samedims(A, kernelconv(kernel...)) krn = FFTView(zeros(eltype(kern), map(length, axes(A))))