diff --git a/src/dspbase.jl b/src/dspbase.jl index fbcb5a7d7..4f6067006 100644 --- a/src/dspbase.jl +++ b/src/dspbase.jl @@ -171,6 +171,7 @@ function deconv(b::StridedVector{T}, a::StridedVector{T}) where T end + """ _zeropad!(padded::AbstractVector, u::AbstractVector, @@ -216,6 +217,7 @@ end padded end + """ _zeropad(u, padded_size, [data_dest, data_region]) @@ -320,12 +322,30 @@ end Transform the smaller convolution input to frequency domain, and return it in a new array. However, the contents of `buff` may be modified. """ -@inline function os_filter_transform!(buff::AbstractArray{<:Real}, p) - p * buff + +@inline function os_filter_transform!(A::NTuple{<:Any, AbstractArray{<:Real}}, p) + fA = p * A[1] + for a in A[2:end] + fA .*= (p * a) + end + return fA +end + +@inline function os_filter_transform!(buff::Tuple{AbstractArray{<:Real}}, p) + p * buff[1] end -@inline function os_filter_transform!(buff::AbstractArray{<:Complex}, p!) - copy(p! * buff) # p operates in place on buff +@inline function os_filter_transform!(A::NTuple{<:Any, AbstractArray{<:Complex}}, p!) + fA = p! * A[1] + for a in A[2:end] + fA .*= (p! * a) + end + return copy(fA) +end + +@inline function os_filter_transform!(buff::Tuple{AbstractArray{<:Complex}}, p!) + copy(p! * buff[1]) + end """ @@ -352,6 +372,7 @@ end p! * buff # p! operates in place on buff buff .*= filter_fd ip! * buff # ip! operates in place on buff + end # Used by `unsafe_conv_kern_os!` to handle blocks of input data that need to be padded. @@ -489,6 +510,7 @@ end function unsafe_conv_kern_os!(out, u::AbstractArray{<:Any, N}, v, + A, su, sv, sout, @@ -511,7 +533,9 @@ function unsafe_conv_kern_os!(out, # Transform the smaller filter _zeropad!(tdbuff, v) - filter_fd = os_filter_transform!(tdbuff, p) + filter_fd = os_filter_transform!((tdbuff, + (_zeropad(a, size(tdbuff)) + for a in A)...), p) filter_fd .*= 1 / prod(nffts) # Normalize once for brfft # block indices for center blocks, which need no padding @@ -607,148 +631,121 @@ function unsafe_conv_kern_os!(out, end function _conv_kern_fft!(out, - u::AbstractArray{T, N}, - v::AbstractArray{T, N}, - su, - sv, + A::NTuple{<:Any, AbstractArray{T, N}}, outsize, nffts) where {T<:Real, N} - padded = _zeropad(u, nffts) + padded = _zeropad(A[1], nffts) p = plan_rfft(padded) - uf = p * padded - _zeropad!(padded, v) - vf = p * padded - uf .*= vf - raw_out = irfft(uf, nffts[1]) + ftA = p * padded + for a in A[2:end] + _zeropad!(padded, a) + ftA .*= p * padded + end + raw_out = irfft(ftA, nffts[1]) copyto!(out, CartesianIndices(out), raw_out, CartesianIndices(UnitRange.(1, outsize))) end -function _conv_kern_fft!(out, u, v, su, sv, outsize, nffts) - upad = _zeropad(u, nffts) - vpad = _zeropad(v, nffts) - p! = plan_fft!(upad) - p! * upad # Operates in place on upad - p! * vpad - upad .*= vpad - ifft!(upad) + +function _conv_kern_fft!(out, A::NTuple{<:Any, AbstractArray{T, N}}, + outsize, nffts) where {T<:Complex{<:Real}, N} + Apad = [_zeropad(a, nffts) for a in A] + p! = plan_fft!(Apad[1]) + for a in Apad + p! * a + end + Apad[1] .*= .*(Apad[2:end]...) + ifft!(Apad[1]) copyto!(out, CartesianIndices(out), - upad, + Apad[1], CartesianIndices(UnitRange.(1, outsize))) end - -# v should be smaller than u for good performance -function _conv_fft!(out, u, v, su, sv, outsize) - os_nffts = map(optimalfftfiltlength, sv, su) +function _conv_fft!(out, A::Tuple{<:AbstractArray, <:AbstractArray}, S, outsize) + os_nffts = map(optimalfftfiltlength, S[2], S[1]) if any(os_nffts .< outsize) - unsafe_conv_kern_os!(out, u, v, su, sv, outsize, os_nffts) + unsafe_conv_kern_os!(out, A[1], A[2], (), S[1], S[2], outsize, os_nffts) else nffts = nextfastfft(outsize) - _conv_kern_fft!(out, u, v, su, sv, outsize, nffts) + _conv_kern_fft!(out, A, outsize, nffts) end end - +# A should be in ascending order of size for best performance +function _conv_fft!(out, A, S, outsize) + sv = outsize .- S[1] .+ 1 + os_nffts = map(optimalfftfiltlength, sv, S[1]) + if any(os_nffts .< outsize) + unsafe_conv_kern_os!(out, + A[1], A[2], A[3:end], S[1], sv, + outsize, os_nffts) + else + nffts = nextfastfft(outsize) + _conv_kern_fft!(out, A, outsize, nffts) + end +end # For arrays with weird offsets -function _conv_similar(u, outsize, axesu, axesv) - out_offsets = first.(axesu) .+ first.(axesv) +function _conv_similar(u, outsize, axes...) + out_offsets = .+([first.(ax) for ax in axes]...) out_axes = UnitRange.(out_offsets, out_offsets .+ outsize .- 1) similar(u, out_axes) end + function _conv_similar( - u, outsize, ::NTuple{<:Any, Base.OneTo{Int}}, ::NTuple{<:Any, Base.OneTo{Int}} -) + u, outsize, ::NTuple{<:Any, Base.OneTo{Int}}...) similar(u, outsize) end -_conv_similar(u, v, outsize) = _conv_similar(u, outsize, axes(u), axes(v)) +_conv_similar(A, outsize) = _conv_similar(A[1], outsize, + [axes(u) for u in A]...) # Does convolution, will not switch argument order -function _conv!(out, u, v, su, sv, outsize) +function _conv!(out, A, S, outsize) # TODO: Add spatial / time domain algorithm - _conv_fft!(out, u, v, su, sv, outsize) + _conv_fft!(out, A, S, outsize) end # Does convolution, will not switch argument order -function _conv(u, v, su, sv) - outsize = su .+ sv .- 1 - out = _conv_similar(u, v, outsize) - _conv!(out, u, v, su, sv, outsize) +function _conv_sz(A, S) + outsize = .+(S...) .- (length(S) - 1) + out = _conv_similar(A, outsize) + _conv!(out, A, S, outsize) end + # May switch argument order """ - conv(u,v) + conv(u, v, ...) Convolution of two arrays. Uses either FFT convolution or overlap-save, -depending on the size of the input. `u` and `v` can be N-dimensional arrays, +depending on the size of the input. Accepts any number of arrays to convolve +together can be N-dimensional arrays, with arbitrary indexing offsets, but their axes must be a `UnitRange`. """ -function conv(u::AbstractArray{T, N}, - v::AbstractArray{T, N}) where {T<:BLAS.BlasFloat, N} - su = size(u) - sv = size(v) - if prod(su) >= prod(sv) - _conv(u, v, su, sv) - else - _conv(v, u, sv, su) - end -end - -function conv(u::AbstractArray{<:BLAS.BlasFloat, N}, - v::AbstractArray{<:BLAS.BlasFloat, N}) where N - fu, fv = promote(u, v) - conv(fu, fv) +function _conv(A::AbstractArray...) + maxnd = max([ndims(a) for a in A]...) + return conv([cat(a, dims=maxnd) for a in A]...) end -conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}) where {N} = - round.(Int, conv(float(u), float(v))) - -conv(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N} = - conv(float(u), float(v)) - -function conv(u::AbstractArray{<:Number, N}, - v::AbstractArray{<:BLAS.BlasFloat, N}) where N - conv(float(u), v) +_conv(A::AbstractArray{<:Number, N}...) where {N} = conv(promote(A...)...) +_conv(A::AbstractArray{<:Integer}...) = round.(Int, conv([float(a) for a in A]...)) +function _conv(A::AbstractArray{<:BLAS.BlasFloat, N}...) where N + sizes = size.(A) + _conv_sz(A, sizes) end -function conv(u::AbstractArray{<:BLAS.BlasFloat, N}, - v::AbstractArray{<:Number, N}) where N - conv(u, float(v)) -end +# conv must have at least 2 inputs +conv(A::AbstractArray, B::AbstractArray, C::AbstractArray...) = _conv(A, B, C...) -function conv(A::AbstractArray{<:Number, M}, - B::AbstractArray{<:Number, N}) where {M, N} - if (M < N) - conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B) - else - @assert M > N - conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M}) - end -end -""" - conv(u,v,A) -2-D convolution of the matrix `A` with the 2-D separable kernel generated by -the vectors `u` and `v`. -Uses 2-D FFT algorithm. -""" +# warn about old conv(u, v, A) function conv(u::AbstractVector{T}, v::AbstractVector{T}, A::AbstractMatrix{T}) where T - # Arbitrary indexing offsets not implemented - @assert !Base.has_offset_axes(u, v, A) - m = length(u)+size(A,1)-1 - n = length(v)+size(A,2)-1 - B = zeros(T, m, n) - B[1:size(A,1),1:size(A,2)] = A - u = fft([u;zeros(T,m-length(u))]) - v = fft([v;zeros(T,n-length(v))]) - C = ifft(fft(B) .* (u * transpose(v))) - if T <: Real - return real(C) - end - return C + # TODO this is inefficient + @warn "seperable convolution as conv(u::Vector, v::Vector, A::Matrix) is "\ + "no longer supported, use conv(u, transpose(v), A ) if that is what"\ + "you intend" + conv(cat(u, dims=2), cat(v, dims=2), A) end diff --git a/test/dsp.jl b/test/dsp.jl index 91dc942d7..3ead71a7b 100644 --- a/test/dsp.jl +++ b/test/dsp.jl @@ -68,6 +68,13 @@ end @test_throws MethodError conv([sin], [cos]) end + @testset "old-style-seperable-warns" begin + a = [1, 2] + b = [3, 4] + c = [1 2; 3 4] + @test conv(a, b, c) == conv(conv(a, b), c) + @test_warn "seperable" conv(a, b, c) + end @testset "conv-2D" begin a =[1 2 1; @@ -111,31 +118,6 @@ end @test conv(offset_arr, b) == OffsetArray(expectation, 0:3, 0:3) end - @testset "seperable conv" begin - u = [1, 2, 3, 2, 1] - v = [6, 7, 3, 2] - A = [1 2 3 4 5 6 7; - 8 9 10 11 12 13 14; - 15 16 17 18 19 20 21; - 22 23 24 25 26 27 28] - exp = [6 19 35 53 71 89 107 77 33 14; - 60 148 217 285 339 393 447 315 134 56; - 204 478 658 822 930 1038 1146 798 338 140; - 468 1062 1400 1684 1828 1972 2116 1456 614 252; - 636 1426 1848 2188 2332 2476 2620 1792 754 308; - 624 1388 1778 2082 2190 2298 2406 1638 688 280; - 354 785 1001 1167 1221 1275 1329 903 379 154; - 132 292 371 431 449 467 485 329 138 56] - @test_broken conv(u, v, A) == exp - - fu = convert(Array{Float64}, u) - fv = convert(Array{Float64}, v) - fA = convert(Array{Float64}, A) - fexp = convert(Array{Float64}, exp) - @test conv(fu, fv, fA) ≈ fexp - - end - @testset "conv-ND" begin # is it safe to assume that if conv works for # int/float/complex in 1 and 2 D, it does in ND? @@ -191,10 +173,10 @@ end sv, v = os_test_data(eltype, nv, N) sout = su .+ sv .- 1 out = _conv_similar(u, sout, axes(u), axes(v)) - unsafe_conv_kern_os!(out, u, v, su, sv, sout, nffts) + unsafe_conv_kern_os!(out, u, v, (), su, sv, sout, nffts) os_out = copy(out) fft_nfft = nextfastfft(sout) - _conv_kern_fft!(out, u, v, su, sv, sout, fft_nfft) + _conv_kern_fft!(out, (u, v), sout, fft_nfft) @test out ≈ os_out end Ns = [1, 2, 3] @@ -202,6 +184,7 @@ end nlarge = 128 regular_nsmall = [12, 128] + for numdim in Ns for elt in eltypes for nsmall in regular_nsmall @@ -227,6 +210,70 @@ end # three blocks need to be padded in the following case: test_os(Float64, 25, 4, Val{1}(), 16) end + + + @testset "N-arg-conv" begin + u = [1, 2, 3, 2, 1] + v = transpose([6, 7, 3, 2]) + A = [1 2 3 4 5 6 7; + 8 9 10 11 12 13 14; + 15 16 17 18 19 20 21; + 22 23 24 25 26 27 28] + exp = [6 19 35 53 71 89 107 77 33 14; + 60 148 217 285 339 393 447 315 134 56; + 204 478 658 822 930 1038 1146 798 338 140; + 468 1062 1400 1684 1828 1972 2116 1456 614 252; + 636 1426 1848 2188 2332 2476 2620 1792 754 308; + 624 1388 1778 2082 2190 2298 2406 1638 688 280; + 354 785 1001 1167 1221 1275 1329 903 379 154; + 132 292 371 431 449 467 485 329 138 56] + @test conv(u, v, A) == exp + + fu = convert(Array{Float64}, u) + fv = convert(Array{Float64}, v) + fA = convert(Array{Float64}, A) + fexp = convert(Array{Float64}, exp) + @test conv(fu, fv, fA) ≈ fexp + + function compare_to_naive(arrs...) + function ncv(arrs...) + if length(arrs) == 1 + arrs[1] + else + ncv(conv(arrs[1], arrs[2]), arrs[3:end]...) + end + end + tconv = @elapsed rconv = conv(arrs...) + tncv = @elapsed rncv = ncv(arrs...) + @test rconv ≈ rncv + end + + compare_to_naive([1 2 1; + 2 3 1; + 1 2 1], + [3 2; + 0 1], + convert(Array, reshape(1:27, (3, 3, 3))), + [1 2 3 4 5 6 7; + 8 9 10 11 12 13 14; + 15 16 17 18 19 20 21; + 22 23 24 25 26 27 28]) + compare_to_naive(u, transpose(v), A, fu, fv, fA, fexp) + compare_to_naive(u, transpose(v), A, fu, fv, fA, fexp) + compare_to_naive(u, v, A) + compare_to_naive(fu, fv, fA) + compare_to_naive(u, v, A, v, u, A, u, v, v, v, transpose(v), transpose(u)) + compare_to_naive(repeat(exp, 300), + A, exp, v, u) + compare_to_naive(repeat(fexp, 300), + A, exp, fv, u) + compare_to_naive(repeat(complex.(fexp, 140), 30), + A, exp, fv, u) + compare_to_naive(repeat([1, 2, 3, 4, 5, 5, 6, 7 ,9, 10], 100000), + [1, 2, 3, 4], [5, 6, 7, 8]) + end + + end @testset "xcorr" begin diff --git a/test/runtests.jl b/test/runtests.jl index d418e2750..a1d31b976 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,5 +10,8 @@ testfiles = [ "dsp.jl", "util.jl", "windows.jl", "filter_conversion.jl", seed!(1776) for testfile in testfiles - eval(:(@testset $testfile begin include($testfile) end)) + @testset "$testfile" begin + time = @elapsed eval(:(include($testfile))) + @show time + end end