Skip to content

Commit 108e75f

Browse files
authored
[CUSPARSE] Add CuSparseMatrixCSC * CuSparseMatrixCSC (#1663)
1 parent 583b948 commit 108e75f

File tree

4 files changed

+68
-10
lines changed

4 files changed

+68
-10
lines changed

lib/cusparse/generic.jl

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,19 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
365365
return C
366366
end
367367

368+
function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T},
369+
beta::Number, C::CuSparseMatrixCSC{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
370+
# C = AB <---> Cᵀ = BᵀAᵀ
371+
Aᵀ = CuSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
372+
Bᵀ = CuSparseMatrixCSR(B.colPtr, B.rowVal, B.nzVal, reverse(size(B)))
373+
Cᵀ = CuSparseMatrixCSR(C.colPtr, C.rowVal, C.nzVal, reverse(size(C)))
374+
gemm!(transb, transa, alpha, Bᵀ, Aᵀ, beta, Cᵀ, index, algo)
375+
# If BᵀAᵀ and Cᵀ have the same sparsity pattern, C is already updated after the gemm! call.
376+
# If BᵀAᵀ and Cᵀ don't have the same sparsity pattern, Cᵀ is reallocated and C must be updated.
377+
C = CuSparseMatrixCSC(Cᵀ.rowPtr, Cᵀ.colVal, Cᵀ.nzVal, size(C))
378+
return C
379+
end
380+
368381
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
369382
B::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
370383
m,k = size(A)
@@ -424,17 +437,30 @@ function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
424437
return C
425438
end
426439

427-
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T},
428-
beta::Number, C::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT; same_pattern::Bool=false) where {T}
440+
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
441+
B::CuSparseMatrixCSC{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
442+
# C = AB <---> Cᵀ = BᵀAᵀ
443+
Aᵀ = CuSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
444+
Bᵀ = CuSparseMatrixCSR(B.colPtr, B.rowVal, B.nzVal, reverse(size(B)))
445+
Cᵀ = gemm(transb, transa, alpha, Bᵀ, Aᵀ, index, algo)
446+
C = CuSparseMatrixCSC(Cᵀ.rowPtr, Cᵀ.colVal, Cᵀ.nzVal, reverse(size(Cᵀ)))
447+
return C
448+
end
429449

430-
if same_pattern
431-
D = copy(C)
432-
gemm!(transa, transb, alpha, A, B, beta, D, index, algo)
433-
else
434-
AB = gemm(transa, transb, one(T), A, B, index, algo)
435-
D = geam(alpha, AB, beta, C, index)
450+
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
451+
@eval begin
452+
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::$SparseMatrixType{T}, B::$SparseMatrixType{T},
453+
beta::Number, C::$SparseMatrixType{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT; same_pattern::Bool=false) where {T}
454+
if same_pattern
455+
D = copy(C)
456+
gemm!(transa, transb, alpha, A, B, beta, D, index, algo)
457+
else
458+
AB = gemm(transa, transb, one(T), A, B, index, algo)
459+
D = geam(alpha, AB, beta, C, index)
460+
end
461+
return D
462+
end
436463
end
437-
return D
438464
end
439465

440466
function sv!(transa::SparseChar, uplo::SparseChar, diag::SparseChar,

lib/cusparse/interfaces.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,15 @@ for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
110110
end
111111
end
112112

113+
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
114+
@eval begin
115+
function Base.:(*)(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T <: BlasFloat}
116+
CUSPARSE.version() < v"11.1.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
117+
gemm('N', 'N', one(T), A, B, 'O')
118+
end
119+
end
120+
end
121+
113122
for op in (:(+), :(-))
114123
@eval begin
115124
Base.$op(A::CuSparseVector{T}, B::CuSparseVector{T}) where {T <: BlasFloat} = axpby(one(T), A, $(op)(one(T)), B, 'O')

test/cusparse/generic.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,13 @@ end # CUSPARSE.version >= 11.3.0
308308

309309
if CUSPARSE.version() >= v"11.1.1"
310310

311-
SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT])
311+
SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT],
312+
CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT])
312313
if CUSPARSE.version() >= v"11.6.0"
313314
push!(SPGEMM_ALGOS[CuSparseMatrixCSR], CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC,
314315
CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC)
316+
push!(SPGEMM_ALGOS[CuSparseMatrixCSC], CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC,
317+
CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC)
315318
end
316319

317320
for SparseMatrixType in keys(SPGEMM_ALGOS)

test/cusparse/interfaces.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,26 @@ using LinearAlgebra, SparseArrays
5252
end
5353
end
5454

55+
# SpGEMM was added in CUSPARSE v"11.1.1"
56+
if CUSPARSE.version() >= v"11.1.1"
57+
for SparseMatrixType in (CuSparseMatrixCSC, CuSparseMatrixCSR)
58+
@testset "$SparseMatrixType -- A * B $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
59+
n = 10
60+
k = 15
61+
m = 20
62+
A = sprand(elty, m, k, 0.2)
63+
B = sprand(elty, k, n, 0.5)
64+
65+
dA = SparseMatrixType(A)
66+
dB = SparseMatrixType(B)
67+
68+
C = A * B
69+
dC = dA * dB
70+
@test C collect(dC)
71+
end
72+
end
73+
end
74+
5575
@testset "$f(A)±$h(B) $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64],
5676
f in (identity, transpose), #adjoint),
5777
h in (identity, transpose)#, adjoint)

0 commit comments

Comments
 (0)