From a4e8efc19c6455ccbaf4275e759b94298ad7b737 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 17 Jan 2024 12:08:03 +0530 Subject: [PATCH] Fix adjtransvec multiplication with zerosmatrix --- src/fillalgebra.jl | 5 ++--- test/runtests.jl | 54 +++++++++++++++++++++++++++------------------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index a887eca2..928c04a7 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -102,9 +102,8 @@ for MT in (:(AbstractMatrix{T}), :(Transpose{<:Any, <:AbstractMatrix{T}}), :(Adj :(AbstractTriangular{T})) @eval *(a::$MT, b::AbstractZerosVector) where {T} = mult_zeros(a, b) end -for MT in (:(Transpose{<:Any, <:AbstractVector}), :(Adjoint{<:Any, <:AbstractVector})) - @eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b) -end +*(a::Transpose{<:Any, <:AbstractVector}, b::AbstractZerosMatrix) = transpose(transpose(b) * parent(a)) +*(a::Adjoint{<:Any, <:AbstractVector}, b::AbstractZerosMatrix) = adjoint(adjoint(b) * parent(a)) *(a::AbstractZerosMatrix, b::AbstractVector) = mult_zeros(a, b) function lmul_diag(a::Diagonal, b) diff --git a/test/runtests.jl b/test/runtests.jl index 34c77108..70ed9df7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1633,8 +1633,8 @@ end @test transpose(A) * Zeros(mA) ≡ Zeros(nA) @test A' * Zeros(mA) ≡ Zeros(nA) - @test transpose(a) * Zeros(la, 3) ≡ Zeros(1,3) - @test a' * Zeros(la,3) ≡ Zeros(1,3) + @test transpose(a) * Zeros(la, 3) ≡ transpose(Zeros(3)) + @test a' * Zeros(la,3) ≡ adjoint(Zeros(3)) @test Zeros(la)' * Transpose(Adjoint(a)) == 0.0 @@ -1701,30 +1701,40 @@ end @test (1:5)'E == (1.0:5)' @test E*E ≡ E + n = 10 + k = 12 + m = 15 for T in (Float64, ComplexF64) - fv = T == Float64 ? Float64(1.6) : ComplexF64(1.6, 1.3) - n = 10 - k = 12 - m = 15 - fillvec = Fill(fv, k) - fillmat = Fill(fv, k, m) - A = rand(ComplexF64, n, k) - @test A*fillvec ≈ A*Array(fillvec) - @test A*fillmat ≈ A*Array(fillmat) - A = rand(ComplexF64, k, n) - @test transpose(A)*fillvec ≈ transpose(A)*Array(fillvec) - @test transpose(A)*fillmat ≈ transpose(A)*Array(fillmat) - @test adjoint(A)*fillvec ≈ adjoint(A)*Array(fillvec) - @test adjoint(A)*fillmat ≈ adjoint(A)*Array(fillmat) - - # inplace C = F * B' * alpha + C * beta + Ank = rand(T, n, k) + Akn = rand(T, k, n) + Ak = rand(T, k) + + fv = T == Float64 ? T(1.6) : T(1.6, 1.3) + + for (fillvec, fillmat) in ((Fill(fv, k), Fill(fv, k, m)), + (Ones(T, k), Ones(T, k, m)), + (Zeros(T, k), Zeros(T, k, m))) + + Afillvec = Array(fillvec) + Afillmat = Array(fillmat) + @test Ank * fillvec ≈ Ank * Afillvec + @test Ank * fillmat ≈ Ank * Afillmat + + for A in (Akn, Ak) + @test transpose(A)*fillvec ≈ transpose(A)*Afillvec + @test transpose(A)*fillmat ≈ transpose(A)*Afillmat + @test adjoint(A)*fillvec ≈ adjoint(A)*Afillvec + @test adjoint(A)*fillmat ≈ adjoint(A)*Afillmat + end + end + + # inplace C = F * A' * alpha + C * beta F = Fill(fv, m, k) - A = Array(F) - B = rand(T, n, k) + M = Array(F) C = rand(T, m, n) @testset for f in (adjoint, transpose) - @test mul!(copy(C), F, f(B)) ≈ mul!(copy(C), A, f(B)) - @test mul!(copy(C), F, f(B), 1.0, 2.0) ≈ mul!(copy(C), A, f(B), 1.0, 2.0) + @test mul!(copy(C), F, f(Ank)) ≈ mul!(copy(C), M, f(Ank)) + @test mul!(copy(C), F, f(Ank), 1.0, 2.0) ≈ mul!(copy(C), M, f(Ank), 1.0, 2.0) end end