@@ -365,6 +365,19 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
365
365
return C
366
366
end
367
367
368
+ function gemm! (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSC{T} , B:: CuSparseMatrixCSC{T} ,
369
+ beta:: Number , C:: CuSparseMatrixCSC{T} , index:: SparseChar , algo:: cusparseSpGEMMAlg_t = CUSPARSE_SPGEMM_DEFAULT) where {T}
370
+ # C = AB <---> Cᵀ = BᵀAᵀ
371
+ Aᵀ = CuSparseMatrixCSR (A. colPtr, A. rowVal, A. nzVal, reverse (size (A)))
372
+ Bᵀ = CuSparseMatrixCSR (B. colPtr, B. rowVal, B. nzVal, reverse (size (B)))
373
+ Cᵀ = CuSparseMatrixCSR (C. colPtr, C. rowVal, C. nzVal, reverse (size (C)))
374
+ gemm! (transb, transa, alpha, Bᵀ, Aᵀ, beta, Cᵀ, index, algo)
375
+ # If BᵀAᵀ and Cᵀ have the same sparsity pattern, C is already updated after the gemm! call.
376
+ # If BᵀAᵀ and Cᵀ don't have the same sparsity pattern, Cᵀ is reallocated and C must be updated.
377
+ C = CuSparseMatrixCSC (Cᵀ. rowPtr, Cᵀ. colVal, Cᵀ. nzVal, size (C))
378
+ return C
379
+ end
380
+
368
381
function gemm (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSR{T} ,
369
382
B:: CuSparseMatrixCSR{T} , index:: SparseChar , algo:: cusparseSpGEMMAlg_t = CUSPARSE_SPGEMM_DEFAULT) where {T}
370
383
m,k = size (A)
@@ -424,17 +437,30 @@ function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
424
437
return C
425
438
end
426
439
427
- function gemm (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSR{T} , B:: CuSparseMatrixCSR{T} ,
428
- beta:: Number , C:: CuSparseMatrixCSR{T} , index:: SparseChar , algo:: cusparseSpGEMMAlg_t = CUSPARSE_SPGEMM_DEFAULT; same_pattern:: Bool = false ) where {T}
440
+ function gemm (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSC{T} ,
441
+ B:: CuSparseMatrixCSC{T} , index:: SparseChar , algo:: cusparseSpGEMMAlg_t = CUSPARSE_SPGEMM_DEFAULT) where {T}
442
+ # C = AB <---> Cᵀ = BᵀAᵀ
443
+ Aᵀ = CuSparseMatrixCSR (A. colPtr, A. rowVal, A. nzVal, reverse (size (A)))
444
+ Bᵀ = CuSparseMatrixCSR (B. colPtr, B. rowVal, B. nzVal, reverse (size (B)))
445
+ Cᵀ = gemm (transb, transa, alpha, Bᵀ, Aᵀ, index, algo)
446
+ C = CuSparseMatrixCSC (Cᵀ. rowPtr, Cᵀ. colVal, Cᵀ. nzVal, reverse (size (Cᵀ)))
447
+ return C
448
+ end
429
449
430
- if same_pattern
431
- D = copy (C)
432
- gemm! (transa, transb, alpha, A, B, beta, D, index, algo)
433
- else
434
- AB = gemm (transa, transb, one (T), A, B, index, algo)
435
- D = geam (alpha, AB, beta, C, index)
450
+ for SparseMatrixType in (:CuSparseMatrixCSC , :CuSparseMatrixCSR )
451
+ @eval begin
452
+ function gemm (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: $SparseMatrixType{T} , B:: $SparseMatrixType{T} ,
453
+ beta:: Number , C:: $SparseMatrixType{T} , index:: SparseChar , algo:: cusparseSpGEMMAlg_t = CUSPARSE_SPGEMM_DEFAULT; same_pattern:: Bool = false ) where {T}
454
+ if same_pattern
455
+ D = copy (C)
456
+ gemm! (transa, transb, alpha, A, B, beta, D, index, algo)
457
+ else
458
+ AB = gemm (transa, transb, one (T), A, B, index, algo)
459
+ D = geam (alpha, AB, beta, C, index)
460
+ end
461
+ return D
462
+ end
436
463
end
437
- return D
438
464
end
439
465
440
466
function sv! (transa:: SparseChar , uplo:: SparseChar , diag:: SparseChar ,
0 commit comments