Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistently check matrix sizes in matmul #1152

Merged
merged 9 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 14 additions & 30 deletions src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ end

# B .= A * B
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
_muldiag_size_check(size(A), size(B))
matmul_size_check(size(A), size(B))
(; dv, ev) = A
if A.uplo == 'U'
for k in axes(B,2)
Expand All @@ -518,7 +518,7 @@ function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
end
# B .= D * B
function lmul!(D::Diagonal, B::Bidiagonal)
_muldiag_size_check(size(D), size(B))
matmul_size_check(size(D), size(B))
(; dv, ev) = B
isL = B.uplo == 'L'
dv[1] = D.diag[1] * dv[1]
Expand All @@ -530,7 +530,7 @@ function lmul!(D::Diagonal, B::Bidiagonal)
end
# B .= B * A
function rmul!(B::AbstractMatrix, A::Bidiagonal)
_muldiag_size_check(size(A), size(B))
matmul_size_check(size(A), size(B))
(; dv, ev) = A
if A.uplo == 'U'
for k in reverse(axes(dv,1)[2:end])
Expand All @@ -555,7 +555,7 @@ function rmul!(B::AbstractMatrix, A::Bidiagonal)
end
# B .= B * D
function rmul!(B::Bidiagonal, D::Diagonal)
_muldiag_size_check(size(B), size(D))
matmul_size_check(size(B), size(D))
(; dv, ev) = B
isU = B.uplo == 'U'
dv[1] *= D.diag[1]
Expand All @@ -566,22 +566,6 @@ function rmul!(B::Bidiagonal, D::Diagonal)
return B
end

@noinline function check_A_mul_B!_sizes((mC, nC)::NTuple{2,Integer}, (mA, nA)::NTuple{2,Integer}, (mB, nB)::NTuple{2,Integer})
# check for matching sizes in one column of B and C
check_A_mul_B!_sizes((mC,), (mA, nA), (mB,))
# ensure that the number of columns in B and C match
if nB != nC
throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match"))
end
end
@noinline function check_A_mul_B!_sizes((mC,)::Tuple{Integer}, (mA, nA)::NTuple{2,Integer}, (mB,)::Tuple{Integer})
if mA != mC
throw(DimensionMismatch(lazy"first dimension of A, $mA, and first dimension of output C, $mC, must match"))
elseif nA != mB
throw(DimensionMismatch(lazy"second dimension of A, $nA, and first dimension of B, $mB, must match"))
end
end

# function to get the internally stored vectors for Bidiagonal and [Sym]Tridiagonal
# to avoid allocations in _mul! below (#24324, #24578)
_diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du
Expand All @@ -603,7 +587,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
_bibimul!(C, A, B, _add)
function _bibimul!(C, A, B, _add)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
# We use `_rmul_or_fill!` instead of `_modify!` here since using
Expand Down Expand Up @@ -851,7 +835,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
_rmul_or_fill!(C, _add.beta) # see the same use above
Expand Down Expand Up @@ -894,7 +878,7 @@ end

function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
_rmul_or_fill!(C, _add.beta) # see the same use above
Expand Down Expand Up @@ -924,7 +908,7 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
end

function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand Down Expand Up @@ -957,7 +941,7 @@ end

function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
require_one_based_indexing(C, B)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
nA = size(A,1)
nB = size(B,2)
(iszero(nA) || iszero(nB)) && return C
Expand Down Expand Up @@ -1027,7 +1011,7 @@ end

function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
m = size(B,2)
(iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
Expand Down Expand Up @@ -1063,7 +1047,7 @@ end

function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
m, n = size(A)
(iszero(m) || iszero(n)) && return C
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand Down Expand Up @@ -1093,7 +1077,7 @@ _mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =
_dibimul!(C, A, B, _add)
function _dibimul!(C, A, B, _add)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
# ensure that we fill off-band elements in the destination
Expand Down Expand Up @@ -1137,7 +1121,7 @@ function _dibimul!(C, A, B, _add)
end
function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
# ensure that we fill off-band elements in the destination
Expand Down Expand Up @@ -1168,7 +1152,7 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
C
end
function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
check_A_mul_B!_sizes(size(C), size(A), size(B))
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
n == 0 && return C
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand Down
43 changes: 11 additions & 32 deletions src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,39 +322,18 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation

function _muldiag_size_check(szA::NTuple{2,Integer}, szB::Tuple{Integer,Vararg{Integer}})
nA = szA[2]
mB = szB[1]
@noinline throw_dimerr(szB::NTuple{2}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match first dimension of B, $mB"))
@noinline throw_dimerr(szB::NTuple{1}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of D, $nA, does not match length of V, $mB"))
nA == mB || throw_dimerr(szB, nA, mB)
return nothing
end
# the output matrix should have the same size as the non-diagonal input matrix or vector
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch(lazy"output matrix has size: $szC, but should have size $szA"))
function _size_check_out(szC::NTuple{2}, szA::NTuple{2}, szB::NTuple{2})
(szC[1] == szA[1] && szC[2] == szB[2]) || throw_dimerr(szC, (szA[1], szB[2]))
end
function _size_check_out(szC::NTuple{1}, szA::NTuple{2}, szB::NTuple{1})
szC[1] == szA[1] || throw_dimerr(szC, (szA[1],))
end
function _muldiag_size_check(szC::Tuple{Vararg{Integer}}, szA::Tuple{Vararg{Integer}}, szB::Tuple{Vararg{Integer}})
_muldiag_size_check(szA, szB)
_size_check_out(szC, szA, szB)
end

function (*)(Da::Diagonal, Db::Diagonal)
_muldiag_size_check(size(Da), size(Db))
matmul_size_check(size(Da), size(Db))
return Diagonal(Da.diag .* Db.diag)
end

function (*)(D::Diagonal, V::AbstractVector)
_muldiag_size_check(size(D), size(V))
matmul_size_check(size(D), size(V))
return D.diag .* V
end

function rmul!(A::AbstractMatrix, D::Diagonal)
_muldiag_size_check(size(A), size(D))
matmul_size_check(size(A), size(D))
for I in CartesianIndices(A)
row, col = Tuple(I)
@inbounds A[row, col] *= D.diag[col]
Expand All @@ -363,7 +342,7 @@ function rmul!(A::AbstractMatrix, D::Diagonal)
end
# T .= T * D
function rmul!(T::Tridiagonal, D::Diagonal)
_muldiag_size_check(size(T), size(D))
matmul_size_check(size(T), size(D))
(; dl, d, du) = T
d[1] *= D.diag[1]
for i in axes(dl,1)
Expand All @@ -375,7 +354,7 @@ function rmul!(T::Tridiagonal, D::Diagonal)
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
_muldiag_size_check(size(D), size(B))
matmul_size_check(size(D), size(B))
for I in CartesianIndices(B)
row = I[1]
@inbounds B[I] = D.diag[row] * B[I]
Expand All @@ -386,7 +365,7 @@ end
# in-place multiplication with a diagonal
# T .= D * T
function lmul!(D::Diagonal, T::Tridiagonal)
_muldiag_size_check(size(D), size(T))
matmul_size_check(size(D), size(T))
(; dl, d, du) = T
d[1] = D.diag[1] * d[1]
for i in axes(dl,1)
Expand Down Expand Up @@ -507,7 +486,7 @@ end
# specialize the non-trivial case
function _mul_diag!(out, A, B, alpha, beta)
require_one_based_indexing(out, A, B)
_muldiag_size_check(size(out), size(A), size(B))
matmul_size_check(size(out), size(A), size(B))
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
Expand All @@ -532,14 +511,14 @@ _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number
_mul_diag!(C, Da, Db, alpha, beta)

function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
_muldiag_size_check(size(Da), size(A))
_muldiag_size_check(size(A), size(Db))
matmul_size_check(size(Da), size(A))
matmul_size_check(size(A), size(Db))
return broadcast(*, Da.diag, A, permutedims(Db.diag))
end

function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
_muldiag_size_check(size(Da), size(Db))
_muldiag_size_check(size(Db), size(Dc))
matmul_size_check(size(Da), size(Db))
matmul_size_check(size(Db), size(Dc))
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
end

Expand Down
Loading
Loading