Skip to content

Commit

Permalink
move plans into Algorith.FFT
Browse files Browse the repository at this point in the history
  • Loading branch information
IanButterworth committed Jan 13, 2024
1 parent 08e34f2 commit a2b92eb
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 22 deletions.
12 changes: 4 additions & 8 deletions demo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@ function do_work(mat)
indices = Int.(round.((LinRange(frame_start, n_frames - frame_dist, n_pairs))))
factored_kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
frame_filt = deepcopy(@view mat[:, :, frame_start])
r = CPU1(Algorithm.FFT())
p1 = plan_rfft(frame_filt)
p2, krn = ImageFiltering.kernel_plan_rfft(frame_filt, factored_kernel)
B = p1 * frame_filt
B .*= conj!(p2 * krn)
p3 = plan_irfft(B, length(axes(frame_filt, 1)))
plan = (p1, p2, p3)
r = CPU1(ImageFiltering.planned_fft(frame_filt, factored_kernel))
for i in indices
frame = @view mat[:, :, i]
imfilter!(r, frame_filt, frame, factored_kernel, NoPad(Pad(:replicate)); plan)
imfilter!(r, frame_filt, frame, factored_kernel, NoPad())
# TODO: make it work for cases with padding.. so planned_fft needs to do the padding transforms on a copy of A
# imfilter!(r, frame_filt, frame, factored_kernel)
end
return
end
Expand Down
8 changes: 7 additions & 1 deletion src/ImageFiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,16 @@ function Base.transpose(A::StaticOffsetArray{T,2}) where T
end

module Algorithm
import FFTW
# deliberately don't export these, but it's expected that they
# will be used as Algorithm.FFT(), etc.
abstract type Alg end
"Filter using the Fast Fourier Transform" struct FFT <: Alg end
"Filter using the Fast Fourier Transform" struct FFT <: Alg
plan1::Union{FFTW.rFFTWPlan,Nothing}
plan2::Union{FFTW.rFFTWPlan,Nothing}
plan3::Union{FFTW.AbstractFFTs.ScaledPlan,Nothing}
end
FFT() = FFT(nothing, nothing, nothing)
"Filter using a direct algorithm" struct FIR <: Alg end
"Cache-efficient filtering using tiles" struct FIRTiled{N} <: Alg
tilesize::Dims{N}
Expand Down
33 changes: 20 additions & 13 deletions src/imfilter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -797,43 +797,39 @@ function imfilter!(r::AbstractCPU{FFT},
out::AbstractArray{S,N},
img::AbstractArray{T,N},
kernel::AbstractArray{K,N},
border::NoPad;
plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,K,N}
imfilter!(r, out, img, (kernel,), border; plan)
border::NoPad) where {S,T,K,N}
imfilter!(r, out, img, (kernel,), border)
end

function imfilter!(r::AbstractCPU{FFT},
out::AbstractArray{S,N},
A::AbstractArray{T,N},
kernel::Tuple{AbstractArray},
border::NoPad;
plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,N}
_imfilter_fft!(r, out, A, kernel, border; plan) # ambiguity resolution
border::NoPad) where {S,T,N}
_imfilter_fft!(r, out, A, kernel, border) # ambiguity resolution
end
function imfilter!(r::AbstractCPU{FFT},
out::AbstractArray{S,N},
A::AbstractArray{T,N},
kernel::Tuple{AbstractArray,Vararg{AbstractArray}},
border::NoPad;
plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,N}
_imfilter_fft!(r, out, A, kernel, border; plan)
border::NoPad) where {S,T,N}
_imfilter_fft!(r, out, A, kernel, border)
end

function _imfilter_fft!(r::AbstractCPU{FFT},
out::AbstractArray{S,N},
A::AbstractArray{T,N},
kernel::Tuple{AbstractArray,Vararg{AbstractArray}},
border::NoPad;
plan::Union{Tuple{Any,Any,Any},Nothing}=nothing) where {S,T,N}
border::NoPad) where {S,T,N}
kern = samedims(A, kernelconv(kernel...))
krn = FFTView(zeros(eltype(kern), map(length, axes(A))))
for I in CartesianIndices(axes(kern))
krn[I] = kern[I]
end
Af = if plan === nothing
Af = if any(isnothing, (r.settings.plan1, r.settings.plan2, r.settings.plan3))
filtfft(A, krn)
else
filtfft(plan[1], A, plan[2], krn, plan[3])
filtfft(r.settings.plan1, A, r.settings.plan2, krn, r.settings.plan3)
end
Af =
if map(first, axes(out)) == map(first, axes(Af))
Expand All @@ -849,6 +845,17 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
out
end

function planned_fft(A::AbstractArray{T,N}, kernel::Tuple{AbstractArray,Vararg{AbstractArray}}) where {T,N}
p1 = plan_rfft(A)
kern = samedims(A, kernelconv(kernel...))
krn = FFTView(zeros(eltype(kern), map(length, axes(A))))
p2 = plan_rfft(krn)
B = p1 * A
B .*= conj!(p2 * krn)
p3 = plan_irfft(B, length(axes(A, 1)))
return Algorithm.FFT(p1, p2, p3)
end

function kernel_plan_rfft(A::AbstractArray{T,N}, kernel::Tuple{AbstractArray,Vararg{AbstractArray}}) where {T,N}
kern = samedims(A, kernelconv(kernel...))
krn = FFTView(zeros(eltype(kern), map(length, axes(A))))
Expand Down

0 comments on commit a2b92eb

Please sign in to comment.