Skip to content

Commit

Permalink
[CUSPARSE] Add CuSparseMatrixCSC * CuSparseMatrixCSC (#1663)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Nov 9, 2022
1 parent 583b948 commit 108e75f
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 10 deletions.
44 changes: 35 additions & 9 deletions lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,19 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
return C
end

function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T},
beta::Number, C::CuSparseMatrixCSC{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
# C = AB <---> Cᵀ = BᵀAᵀ
Aᵀ = CuSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
Bᵀ = CuSparseMatrixCSR(B.colPtr, B.rowVal, B.nzVal, reverse(size(B)))
Cᵀ = CuSparseMatrixCSR(C.colPtr, C.rowVal, C.nzVal, reverse(size(C)))
gemm!(transb, transa, alpha, Bᵀ, Aᵀ, beta, Cᵀ, index, algo)
# If BᵀAᵀ and Cᵀ have the same sparsity pattern, C is already updated after the gemm! call.
# If BᵀAᵀ and Cᵀ don't have the same sparsity pattern, Cᵀ is reallocated and C must be updated.
C = CuSparseMatrixCSC(Cᵀ.rowPtr, Cᵀ.colVal, Cᵀ.nzVal, size(C))
return C
end

function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
B::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
m,k = size(A)
Expand Down Expand Up @@ -424,17 +437,30 @@ function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
return C
end

function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T},
beta::Number, C::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT; same_pattern::Bool=false) where {T}
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
B::CuSparseMatrixCSC{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
# C = AB <---> Cᵀ = BᵀAᵀ
Aᵀ = CuSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
Bᵀ = CuSparseMatrixCSR(B.colPtr, B.rowVal, B.nzVal, reverse(size(B)))
Cᵀ = gemm(transb, transa, alpha, Bᵀ, Aᵀ, index, algo)
C = CuSparseMatrixCSC(Cᵀ.rowPtr, Cᵀ.colVal, Cᵀ.nzVal, reverse(size(Cᵀ)))
return C
end

if same_pattern
D = copy(C)
gemm!(transa, transb, alpha, A, B, beta, D, index, algo)
else
AB = gemm(transa, transb, one(T), A, B, index, algo)
D = geam(alpha, AB, beta, C, index)
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
@eval begin
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::$SparseMatrixType{T}, B::$SparseMatrixType{T},
beta::Number, C::$SparseMatrixType{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT; same_pattern::Bool=false) where {T}
if same_pattern
D = copy(C)
gemm!(transa, transb, alpha, A, B, beta, D, index, algo)
else
AB = gemm(transa, transb, one(T), A, B, index, algo)
D = geam(alpha, AB, beta, C, index)
end
return D
end
end
return D
end

function sv!(transa::SparseChar, uplo::SparseChar, diag::SparseChar,
Expand Down
9 changes: 9 additions & 0 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
end
end

for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
@eval begin
function Base.:(*)(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T <: BlasFloat}
CUSPARSE.version() < v"11.1.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
gemm('N', 'N', one(T), A, B, 'O')
end
end
end

for op in (:(+), :(-))
@eval begin
Base.$op(A::CuSparseVector{T}, B::CuSparseVector{T}) where {T <: BlasFloat} = axpby(one(T), A, $(op)(one(T)), B, 'O')
Expand Down
5 changes: 4 additions & 1 deletion test/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,13 @@ end # CUSPARSE.version >= 11.3.0

if CUSPARSE.version() >= v"11.1.1"

SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT])
SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT],
CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT])
if CUSPARSE.version() >= v"11.6.0"
push!(SPGEMM_ALGOS[CuSparseMatrixCSR], CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC,
CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC)
push!(SPGEMM_ALGOS[CuSparseMatrixCSC], CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC,
CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC)
end

for SparseMatrixType in keys(SPGEMM_ALGOS)
Expand Down
20 changes: 20 additions & 0 deletions test/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,26 @@ using LinearAlgebra, SparseArrays
end
end

# SpGEMM was added in CUSPARSE v"11.1.1"
if CUSPARSE.version() >= v"11.1.1"
for SparseMatrixType in (CuSparseMatrixCSC, CuSparseMatrixCSR)
@testset "$SparseMatrixType -- A * B $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
n = 10
k = 15
m = 20
A = sprand(elty, m, k, 0.2)
B = sprand(elty, k, n, 0.5)

dA = SparseMatrixType(A)
dB = SparseMatrixType(B)

C = A * B
dC = dA * dB
@test C collect(dC)
end
end
end

@testset "$f(A)±$h(B) $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64],
f in (identity, transpose), #adjoint),
h in (identity, transpose)#, adjoint)
Expand Down

0 comments on commit 108e75f

Please sign in to comment.