Skip to content

Commit

Permalink
Fix adjtransvec multiplication with zerosmatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Apr 25, 2024
1 parent b0ee65f commit a4e8efc
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
5 changes: 2 additions & 3 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ for MT in (:(AbstractMatrix{T}), :(Transpose{<:Any, <:AbstractMatrix{T}}), :(Adj
:(AbstractTriangular{T}))
@eval *(a::$MT, b::AbstractZerosVector) where {T} = mult_zeros(a, b)
end
for MT in (:(Transpose{<:Any, <:AbstractVector}), :(Adjoint{<:Any, <:AbstractVector}))
@eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b)
end
*(a::Transpose{<:Any, <:AbstractVector}, b::AbstractZerosMatrix) = transpose(transpose(b) * parent(a))
*(a::Adjoint{<:Any, <:AbstractVector}, b::AbstractZerosMatrix) = adjoint(adjoint(b) * parent(a))
*(a::AbstractZerosMatrix, b::AbstractVector) = mult_zeros(a, b)

function lmul_diag(a::Diagonal, b)
Expand Down
54 changes: 32 additions & 22 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1633,8 +1633,8 @@ end
@test transpose(A) * Zeros(mA) Zeros(nA)
@test A' * Zeros(mA) Zeros(nA)

@test transpose(a) * Zeros(la, 3) Zeros(1,3)
@test a' * Zeros(la,3) Zeros(1,3)
@test transpose(a) * Zeros(la, 3) transpose(Zeros(3))
@test a' * Zeros(la,3) adjoint(Zeros(3))

@test Zeros(la)' * Transpose(Adjoint(a)) == 0.0

Expand Down Expand Up @@ -1701,30 +1701,40 @@ end
@test (1:5)'E == (1.0:5)'
@test E*E E

n = 10
k = 12
m = 15
for T in (Float64, ComplexF64)
fv = T == Float64 ? Float64(1.6) : ComplexF64(1.6, 1.3)
n = 10
k = 12
m = 15
fillvec = Fill(fv, k)
fillmat = Fill(fv, k, m)
A = rand(ComplexF64, n, k)
@test A*fillvec A*Array(fillvec)
@test A*fillmat A*Array(fillmat)
A = rand(ComplexF64, k, n)
@test transpose(A)*fillvec transpose(A)*Array(fillvec)
@test transpose(A)*fillmat transpose(A)*Array(fillmat)
@test adjoint(A)*fillvec adjoint(A)*Array(fillvec)
@test adjoint(A)*fillmat adjoint(A)*Array(fillmat)

# inplace C = F * B' * alpha + C * beta
Ank = rand(T, n, k)
Akn = rand(T, k, n)
Ak = rand(T, k)

fv = T == Float64 ? T(1.6) : T(1.6, 1.3)

for (fillvec, fillmat) in ((Fill(fv, k), Fill(fv, k, m)),
(Ones(T, k), Ones(T, k, m)),
(Zeros(T, k), Zeros(T, k, m)))

Afillvec = Array(fillvec)
Afillmat = Array(fillmat)
@test Ank * fillvec Ank * Afillvec
@test Ank * fillmat Ank * Afillmat

for A in (Akn, Ak)
@test transpose(A)*fillvec transpose(A)*Afillvec
@test transpose(A)*fillmat transpose(A)*Afillmat
@test adjoint(A)*fillvec adjoint(A)*Afillvec
@test adjoint(A)*fillmat adjoint(A)*Afillmat
end
end

# inplace C = F * A' * alpha + C * beta
F = Fill(fv, m, k)
A = Array(F)
B = rand(T, n, k)
M = Array(F)
C = rand(T, m, n)
@testset for f in (adjoint, transpose)
@test mul!(copy(C), F, f(B)) mul!(copy(C), A, f(B))
@test mul!(copy(C), F, f(B), 1.0, 2.0) mul!(copy(C), A, f(B), 1.0, 2.0)
@test mul!(copy(C), F, f(Ank)) mul!(copy(C), M, f(Ank))
@test mul!(copy(C), F, f(Ank), 1.0, 2.0) mul!(copy(C), M, f(Ank), 1.0, 2.0)
end
end

Expand Down

0 comments on commit a4e8efc

Please sign in to comment.