Skip to content

Commit

Permalink
Use MulAdd's CZeros to remove duplicate zero-fill (#424)
Browse files Browse the repository at this point in the history
* Use MulAdd's CZeros in non-blas mat-vec

* Use MulAdd's CZero in blas mat-vec and mat-mat
  • Loading branch information
jishnub authored Feb 2, 2024
1 parent 03a54ea commit 13b11fc
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 51 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BandedMatrices"
uuid = "aae01518-5342-5314-be14-df237901396f"
version = "1.4.1"
version = "1.5.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -17,7 +17,7 @@ BandedMatricesSparseArraysExt = "SparseArrays"

[compat]
Aqua = "0.8"
ArrayLayouts = "1.5.3"
ArrayLayouts = "1.6.0"
Documenter = "1"
FillArrays = "1.3"
GenericLinearAlgebra = "0.3"
Expand Down
40 changes: 21 additions & 19 deletions src/banded/gbmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ function _num_zeroband_l(A)
return Al+Au+1
end

function gbmm!(tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatrix{T}, β::T, C::AbstractMatrix{T}) where {T<:BlasFloat}
function gbmm!(tA::Char, tB::Char, α::T, A::AbstractMatrix{T},
B::AbstractMatrix{T}, β::T, C::AbstractMatrix{T},
Czero=false) where {T<:BlasFloat}
if tA 'N' || tB 'N'
error("Only 'N' flag is supported.")
end
Expand All @@ -227,23 +229,23 @@ function gbmm!(tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatri

# prune zero bands
if (-Al > Au) || (-Bl > Bu) # A or B has empty bands
fill!(C, zero(T))
Czero || fill!(C, zero(T))
return C
elseif Al < 0
_fill_lmul!(β, @views(C[max(1,Bn+Al-1):Am, :]))
gbmm!('N', 'N', α, view(A, :, 1-Al:An), view(B, 1-Al:An, :), β, C)
_fill_lmul!(β, @views(C[max(1,Bn+Al-1):Am, :]), Czero)
gbmm!('N', 'N', α, view(A, :, 1-Al:An), view(B, 1-Al:An, :), β, C, Czero)
return C
elseif Au < 0
_fill_lmul!(β, @views(C[1:-Au,:]))
gbmm!('N', 'N', α, view(A, 1-Au:Am,:), B, β, view(C, 1-Au:Am,:))
_fill_lmul!(β, @views(C[1:-Au,:]), Czero)
gbmm!('N', 'N', α, view(A, 1-Au:Am,:), B, β, view(C, 1-Au:Am,:), Czero)
return C
elseif Bl < 0
_fill_lmul!(β, @views(C[:, 1:-Bl]))
gbmm!('N', 'N', α, A, view(B, :, 1-Bl:Bn), β, view(C, :, 1-Bl:Bn))
_fill_lmul!(β, @views(C[:, 1:-Bl]), Czero)
gbmm!('N', 'N', α, A, view(B, :, 1-Bl:Bn), β, view(C, :, 1-Bl:Bn), Czero)
return C
elseif Bu < 0
_fill_lmul!(β, @views(C[:, max(1,Am+Bu-1):Bn]))
gbmm!('N', 'N', α, view(A, :, 1-Bu:Bm), view(B, 1-Bu:Bm, :), β, C)
_fill_lmul!(β, @views(C[:, max(1,Am+Bu-1):Bn]), Czero)
gbmm!('N', 'N', α, view(A, :, 1-Bu:Bm), view(B, 1-Bu:Bm, :), β, C, Czero)
return C
elseif C̃u < Cu
Au_r, Bu_r = _num_zeroband_u(A), _num_zeroband_u(B)
Expand All @@ -252,13 +254,13 @@ function gbmm!(tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatri
B_data = bandeddata(B)

if Au-Au_r < -Al || Bu - Bu_r < -Bl
_fill_lmul!(β, C)
_fill_lmul!(β, C, Czero)
return C
end

= _BandedMatrix(@views(A_data[Au_r+1:end,:]), n, Al, Au-Au_r)
= _BandedMatrix(@views(B_data[Bu_r+1:end,:]), ν, Bl, Bu-Bu_r)
gbmm!('N', 'N', α, Ã, B̃, β, C)
gbmm!('N', 'N', α, Ã, B̃, β, C, Czero)
return C
elseif C̃l < Cl
Al_r, Bl_r = _num_zeroband_l(A), _num_zeroband_l(B)
Expand All @@ -267,13 +269,13 @@ function gbmm!(tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatri
B_data = bandeddata(B)

if Al-Al_r < -Au || Bl - Bl_r < -Bu
_fill_lmul!(β, C)
_fill_lmul!(β, C, Czero)
return C
end

= _BandedMatrix(@views(A_data[1:end-Al_r,:]), n, Al-Al_r, Au)
= _BandedMatrix(@views(B_data[1:end-Bl_r,:]), ν, Bl-Bl_r, Bu)
gbmm!('N', 'N', α, Ã, B̃, β, C)
gbmm!('N', 'N', α, Ã, B̃, β, C, Czero)
return C
end

Expand All @@ -282,16 +284,16 @@ function gbmm!(tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatri
C̃_data = bandeddata(C)

# scale extra bands
_fill_lmul!(β, view(C̃_data, 1:min(C̃u-Cu,size(C̃_data,1)),:))
_fill_lmul!(β, view(C̃_data, (C̃u+Cl+1)+1:size(C̃_data,1),:))
_fill_lmul!(β, view(C̃_data, 1:min(C̃u-Cu,size(C̃_data,1)),:), Czero)
_fill_lmul!(β, view(C̃_data, (C̃u+Cl+1)+1:size(C̃_data,1),:), Czero)
C_data = view(C̃_data, (C̃u-Cu+1):(C̃u+Cl+1), :) # shift to bands we will write to
_gbmm!(α, A_data, B_data, β, C_data, (n,ν,m), (Al, Au), (Bl, Bu), (Cl, Cu))
_gbmm!(α, A_data, B_data, β, C_data, (n,ν,m), (Al, Au), (Bl, Bu), (Cl, Cu), Czero)

C
end


function _gbmm!::T, A_data, B_data, β, C_data, (n,ν,m), (Al, Au), (Bl, Bu), (Cl, Cu)) where T
function _gbmm!::T, A_data, B_data, β, C_data, (n,ν,m), (Al, Au), (Bl, Bu), (Cl, Cu), Czero) where T
a = pointer(A_data)
b = pointer(B_data)
c = pointer(C_data)
Expand Down Expand Up @@ -334,5 +336,5 @@ function _gbmm!(α::T, A_data, B_data, β, C_data, (n,ν,m), (Al, Au), (Bl, Bu),
end

# scale columns of C by β that aren't impacted by α*A*B
_fill_lmul!(β, view(C_data, :, ν+Bu+1:min(m,n+Cu)))
_fill_lmul!(β, view(C_data, :, ν+Bu+1:min(m,n+Cu)), Czero)
end
57 changes: 29 additions & 28 deletions src/generic/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,72 +23,72 @@ banded_gbmv!(tA, α, A, x, β, y) =
α, bandeddata(A), x, β, y)


@inline function _banded_gbmv!(tA, α, A, x, β, y)
@inline function _banded_gbmv!(tA, α, A, x, β, y, yzero=false)
#= Some BLAS implementations throw warnings
with zero-sized arrays, so we handle
these cases separately.
=#
length(y) == 0 && return y
if length(x) == 0
_fill_rmul!(y, β)
_fill_rmul!(y, β, yzero)
else
xc = Base.unalias(y, x)
banded_gbmv!(tA, α, A, xc, β, y)
end
return y
end

function _banded_muladd!(α, A, x::AbstractVector, β, y)
function _banded_muladd!(α, A, x::AbstractVector, β, y, yzero)
m, n = size(A)
l, u = bandwidths(A)
if -l > u # no bands
_fill_rmul!(y, β)
_fill_rmul!(y, β, yzero)
elseif l < 0 # with u >= -l > 0, that is, all bands lie above the diagonal
# E.g. (l,u) = (-1,2)
# set lview = 0 and uview = u + l >= 0
_banded_gbmv!('N', α, view(A, :, 1-l:n), view(x, 1-l:n), β, y)
_banded_gbmv!('N', α, view(A, :, 1-l:n), view(x, 1-l:n), β, y, yzero)
elseif u < 0 # with -l <= u < 0, that is, all bands lie below the diagnoal.
# E.g. (l,u) = (2,-1)
# set lview = l + u >= 0 and uview = 0
_fill_rmul!(@view(y[1:-u]), β)
_banded_gbmv!('N', α, view(A, 1-u:m, :), x, β, view(y, 1-u:m))
_fill_rmul!(@view(y[1:-u]), β, yzero)
_banded_gbmv!('N', α, view(A, 1-u:m, :), x, β, view(y, 1-u:m), yzero)
y
else
_banded_gbmv!('N', α, A, x, β, y)
_banded_gbmv!('N', α, A, x, β, y, yzero)
end
end

function materialize!(M::BlasMatMulVecAdd{<:BandedColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout,<:BlasFloat})
checkdimensions(M)
_banded_muladd!(M.α, M.A, M.B, M.β, M.C)
_banded_muladd!(M.α, M.A, M.B, M.β, M.C, M.Czero)
end

function _banded_muladd_row!(tA, α, At, x, β, y)
function _banded_muladd_row!(tA, α, At, x, β, y, yzero=false)
n, m = size(At)
u, l = bandwidths(At)
if -l > u # no bands
_fill_rmul!(y, β)
_fill_rmul!(y, β, yzero)
elseif l < 0
_banded_gbmv!(tA, α, view(At, 1-l:n, :,), view(x, 1-l:n), β, y)
_banded_gbmv!(tA, α, view(At, 1-l:n, :,), view(x, 1-l:n), β, y, yzero)
elseif u < 0
_fill_rmul!(@view(y[1:-u]), β)
_banded_gbmv!(tA, α, view(At, :, 1-u:m), x, β, view(y, 1-u:m))
_fill_rmul!(@view(y[1:-u]), β, yzero)
_banded_gbmv!(tA, α, view(At, :, 1-u:m), x, β, view(y, 1-u:m), yzero)
y
else
_banded_gbmv!(tA, α, At, x, β, y)
_banded_gbmv!(tA, α, At, x, β, y, yzero)
end
end

function materialize!(M::BlasMatMulVecAdd{<:BandedRowMajor,<:AbstractStridedLayout,<:AbstractStridedLayout,<:BlasFloat})
checkdimensions(M)
α, A, x, β, y = M.α, M.A, M.B, M.β, M.C
_banded_muladd_row!('T', α, transpose(A), x, β, y)
α, A, x, β, y, yzero = M.α, M.A, M.B, M.β, M.C, M.Czero
_banded_muladd_row!('T', α, transpose(A), x, β, y, yzero)
end

function materialize!(M::BlasMatMulVecAdd{<:ConjLayout{<:BandedRowMajor},<:AbstractStridedLayout,<:AbstractStridedLayout,<:BlasComplex})
checkdimensions(M)
α, A, x, β, y = M.α, M.A, M.B, M.β, M.C
_banded_muladd_row!('C', α, A', x, β, y)
α, A, x, β, y, yzero = M.α, M.A, M.B, M.β, M.C, M.Czero
_banded_muladd_row!('C', α, A', x, β, y, yzero)
end


Expand All @@ -100,7 +100,7 @@ end
@inline function materialize!(M::MatMulVecAdd{<:AbstractBandedLayout})
checkdimensions(M)
α,A,B,β,C = M.α,M.A,M.B,M.β,M.C
_fill_rmul!(C, β)
_fill_rmul!(M, β)
@inbounds for j = intersect(rowsupport(A), colsupport(B))
for k = colrange(A,j)
C[k] += inbands_getindex(A,k,j) * B[j] * α
Expand All @@ -113,7 +113,7 @@ end
checkdimensions(M)
α,At,B,β,C = M.α,M.A,M.B,M.β,M.C
A = transpose(At)
_fill_rmul!(C, β)
_fill_rmul!(M, β)

@inbounds for j = rowsupport(A)
for k = intersect(colrange(A,j), colsupport(B))
Expand All @@ -127,7 +127,7 @@ end
checkdimensions(M)
α,Ac,B,β,C = M.α,M.A,M.B,M.β,M.C
A = Ac'
_fill_rmul!(C, β)
_fill_rmul!(M, β)
@inbounds for j = rowsupport(A)
for k = intersect(colrange(A,j), colsupport(B))
C[j] += inbands_getindex(A,k,j)' * B[k] * α
Expand Down Expand Up @@ -174,17 +174,18 @@ end
const ConjOrBandedLayout = Union{AbstractBandedLayout,ConjLayout{<:AbstractBandedLayout}}
const ConjOrBandedColumnMajor = Union{<:BandedColumnMajor,ConjLayout{<:BandedColumnMajor}}

function _banded_muladd!::T, A, B::AbstractMatrix, β, C) where T
gbmm!('N', 'N', α, A, B, β, C)
function _banded_muladd!::T, A, B::AbstractMatrix, β, C, Czero=false) where T
gbmm!('N', 'N', α, A, B, β, C, Czero)
C
end

materialize!(M::BlasMatMulMatAdd{<:AbstractBandedLayout,<:AbstractBandedLayout,<:BandedColumnMajor}) =
materialize!(MulAdd(M.α, convert(DefaultBandedMatrix,M.A), convert(DefaultBandedMatrix,M.B), M.β, M.C))
materialize!(MulAdd(M.α, convert(DefaultBandedMatrix,M.A), convert(DefaultBandedMatrix,M.B),
M.β, M.C; Czero = M.Czero))

function materialize!(M::BlasMatMulMatAdd{<:BandedColumnMajor,<:BandedColumnMajor,<:BandedColumnMajor})
checkdimensions(M)
_banded_muladd!(M.α, M.A, M.B, M.β, M.C)
_banded_muladd!(M.α, M.A, M.B, M.β, M.C, M.Czero)
end


Expand Down Expand Up @@ -244,7 +245,7 @@ function materialize!(M::MatMulMatAdd{<:BandedColumns, <:AbstractStridedLayout,
α, β, A, B, C = M.α, M.β, M.A, M.B, M.C

if iszero(α)
rmul!(C, β)
_fill_rmul!(M, β)
else
for (colC, colB) in zip(eachcol(C), eachcol(B))
mul!(colC, A, colB, α, β)
Expand All @@ -259,7 +260,7 @@ function materialize!(M::MatMulMatAdd{<:AbstractStridedLayout, <:BandedColumns,
α, β, A, B, C = M.α, M.β, M.A, M.B, M.C

if iszero(α)
rmul!(C, β)
_fill_rmul!(M, β)
else
for (rowC, rowA) in zip(eachrow(C), eachrow(A))
mul!(rowC, transpose(B), rowA, α, β)
Expand Down
5 changes: 3 additions & 2 deletions src/generic/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ function sumbandwidths(A::AbstractMatrix, B::AbstractMatrix)
end


_fill_lmul!(β, A::AbstractArray{T}) where T = iszero(β) ? zero!(A) : lmul!(β, A)
_fill_rmul!(A::AbstractArray{T}, β) where T = iszero(β) ? zero!(A) : rmul!(A, β)
_fill_lmul!(β, A::AbstractArray, Azero=false) = iszero(β) ? (Azero ? A : zero!(A)) : lmul!(β, A)
_fill_rmul!(A::AbstractArray, β, Azero=false) = iszero(β) ? (Azero ? A : zero!(A)) : rmul!(A, β)
_fill_rmul!(M::MulAdd, β) = _fill_rmul!(M.C, β, M.Czero)

0 comments on commit 13b11fc

Please sign in to comment.