11function sparse_release_matrix_handle (A:: oneAbstractSparseMatrix )
2- queue = global_queue (context (A. nzVal), device (A. nzVal))
3- handle_ptr = Ref {matrix_handle_t} (A. handle)
4- onemklXsparse_release_matrix_handle (sycl_queue (queue), handle_ptr)
2+ if A. handle != = nothing
3+ try
4+ queue = global_queue (context (A. nzVal), device (A. nzVal))
5+ handle_ptr = Ref {matrix_handle_t} (A. handle)
6+ onemklXsparse_release_matrix_handle (sycl_queue (queue), handle_ptr)
7+ # Only synchronize after successful release to ensure completion
8+ synchronize (queue)
9+ catch err
10+ # Don't let finalizer errors crash the program
11+ @warn " Error releasing sparse matrix handle" exception= err
12+ end
13+ end
514end
615
716for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32 ),
@@ -13,20 +22,55 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
1322 (:onemklZsparse_set_csr_data , :ComplexF64 , :Int32 ),
1423 (:onemklZsparse_set_csr_data_64 , :ComplexF64 , :Int64 ))
1524 @eval begin
16- function oneSparseMatrixCSR (A:: SparseMatrixCSC{$elty, $intty} )
25+
26+ function oneSparseMatrixCSR (
27+ rowPtr:: oneVector{$intty} , colVal:: oneVector{$intty} ,
28+ nzVal:: oneVector{$elty} , dims:: NTuple{2, Int}
29+ )
30+ handle_ptr = Ref {matrix_handle_t} ()
31+ onemklXsparse_init_matrix_handle (handle_ptr)
32+ m, n = dims
33+ nnzA = length (nzVal)
34+ queue = global_queue (context (nzVal), device (nzVal))
35+ # Don't update handle if matrix is empty
36+ if m != 0 && n != 0
37+ $ fname (sycl_queue (queue), handle_ptr[], m, n, ' O' , rowPtr, colVal, nzVal)
38+ dA = oneSparseMatrixCSR {$elty, $intty} (handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
39+ finalizer (sparse_release_matrix_handle, dA)
40+ else
41+ dA = oneSparseMatrixCSR {$elty, $intty} (nothing , rowPtr, colVal, nzVal, (m, n), nnzA)
42+ end
43+ return dA
44+ end
45+
46+ function oneSparseMatrixCSC (
47+ colPtr:: oneVector{$intty} , rowVal:: oneVector{$intty} ,
48+ nzVal:: oneVector{$elty} , dims:: NTuple{2, Int}
49+ )
50+ queue = global_queue (context (nzVal), device (nzVal))
1751 handle_ptr = Ref {matrix_handle_t} ()
1852 onemklXsparse_init_matrix_handle (handle_ptr)
53+ m, n = dims
54+ nnzA = length (nzVal)
55+ # Don't update handle if matrix is empty
56+ if m != 0 && n != 0
57+ $ fname (sycl_queue (queue), handle_ptr[], n, m, ' O' , colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
58+ dA = oneSparseMatrixCSC {$elty, $intty} (handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
59+ finalizer (sparse_release_matrix_handle, dA)
60+ else
61+ dA = oneSparseMatrixCSC {$elty, $intty} (nothing , colPtr, rowVal, nzVal, (m, n), nnzA)
62+ end
63+ return dA
64+ end
65+
66+
67+ function oneSparseMatrixCSR (A:: SparseMatrixCSC{$elty, $intty} )
1968 m, n = size (A)
2069 At = SparseMatrixCSC (A |> transpose)
2170 rowPtr = oneVector {$intty} (At. colptr)
2271 colVal = oneVector {$intty} (At. rowval)
2372 nzVal = oneVector {$elty} (At. nzval)
24- nnzA = length (At. nzval)
25- queue = global_queue (context (nzVal), device ())
26- $ fname (sycl_queue (queue), handle_ptr[], m, n, ' O' , rowPtr, colVal, nzVal)
27- dA = oneSparseMatrixCSR {$elty, $intty} (handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
28- finalizer (sparse_release_matrix_handle, dA)
29- return dA
73+ return oneSparseMatrixCSR (rowPtr, colVal, nzVal, (m, n))
3074 end
3175
3276 function SparseArrays. SparseMatrixCSC (A:: oneSparseMatrixCSR{$elty, $intty} )
@@ -37,18 +81,11 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
3781 end
3882
3983 function oneSparseMatrixCSC (A:: SparseMatrixCSC{$elty, $intty} )
40- handle_ptr = Ref {matrix_handle_t} ()
41- onemklXsparse_init_matrix_handle (handle_ptr)
4284 m, n = size (A)
4385 colPtr = oneVector {$intty} (A. colptr)
4486 rowVal = oneVector {$intty} (A. rowval)
4587 nzVal = oneVector {$elty} (A. nzval)
46- nnzA = length (A. nzval)
47- queue = global_queue (context (nzVal), device ())
48- $ fname (sycl_queue (queue), handle_ptr[], n, m, ' O' , colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
49- dA = oneSparseMatrixCSC {$elty, $intty} (handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
50- finalizer (sparse_release_matrix_handle, dA)
51- return dA
88+ return oneSparseMatrixCSC (colPtr, rowVal, nzVal, (m, n))
5289 end
5390
5491 function SparseArrays. SparseMatrixCSC (A:: oneSparseMatrixCSC{$elty, $intty} )
@@ -77,10 +114,14 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
77114 colInd = oneVector {$intty} (col)
78115 nzVal = oneVector {$elty} (val)
79116 nnzA = length (val)
80- queue = global_queue (context (nzVal), device ())
81- $ fname (sycl_queue (queue), handle_ptr[], m, n, nnzA, ' O' , rowInd, colInd, nzVal)
82- dA = oneSparseMatrixCOO {$elty, $intty} (handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
83- finalizer (sparse_release_matrix_handle, dA)
117+ queue = global_queue (context (nzVal), device (nzVal))
118+ if m != 0 && n != 0
119+ $ fname (sycl_queue (queue), handle_ptr[], m, n, nnzA, ' O' , rowInd, colInd, nzVal)
120+ dA = oneSparseMatrixCOO {$elty, $intty} (handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
121+ finalizer (sparse_release_matrix_handle, dA)
122+ else
123+ dA = oneSparseMatrixCOO {$elty, $intty} (nothing , rowInd, colInd, nzVal, (m,n), nnzA)
124+ end
84125 return dA
85126 end
86127
@@ -105,7 +146,7 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
105146 beta:: Number ,
106147 y:: oneStridedVector{$elty} )
107148
108- queue = global_queue (context (x), device ())
149+ queue = global_queue (context (x), device (x ))
109150 $ fname (sycl_queue (queue), trans, alpha, A. handle, x, beta, y)
110151 y
111152 end
@@ -140,8 +181,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
140181 beta:: Number ,
141182 y:: oneStridedVector{$elty} )
142183
143- queue = global_queue (context (x), device ())
144- $ fname (sycl_queue (queue), flip_trans (trans), alpha, A. handle, x, beta, y)
184+ queue = global_queue (context (x), device (x))
185+ m, n = size (A)
186+ if m != 0 && n != 0
187+ $ fname (sycl_queue (queue), flip_trans (trans), alpha, A. handle, x, beta, y)
188+ end
145189 y
146190 end
147191 end
@@ -173,7 +217,7 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
173217 beta = conj (beta)
174218 end
175219
176- queue = global_queue (context (x), device ())
220+ queue = global_queue (context (x), device (x ))
177221 $ fname (sycl_queue (queue), flip_trans (trans), alpha, A. handle, x, beta, y)
178222
179223 if trans == ' C'
@@ -217,7 +261,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
217261 nrhs = size (B, 2 )
218262 ldb = max (1 ,stride (B,2 ))
219263 ldc = max (1 ,stride (C,2 ))
220- queue = global_queue (context (C), device ())
264+ queue = global_queue (context (C), device (C ))
221265 $ fname (sycl_queue (queue), ' C' , transa, transb, alpha, A. handle, B, nrhs, ldb, beta, C, ldc)
222266 C
223267 end
@@ -254,7 +298,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
254298 nrhs = size (B, 2 )
255299 ldb = max (1 ,stride (B,2 ))
256300 ldc = max (1 ,stride (C,2 ))
257- queue = global_queue (context (C), device ())
301+ queue = global_queue (context (C), device (C ))
258302 $ fname (sycl_queue (queue), ' C' , flip_trans (transa), transb, alpha, A. handle, B, nrhs, ldb, beta, C, ldc)
259303 C
260304 end
@@ -289,7 +333,7 @@ for (fname, elty) in (
289333 nrhs = size (B, 2 )
290334 ldb = max (1 , stride (B, 2 ))
291335 ldc = max (1 , stride (C, 2 ))
292- queue = global_queue (context (C), device ())
336+ queue = global_queue (context (C), device (C ))
293337
294338 # Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C)
295339 # Prepare conj(C) in-place and conj(B) into a temporary if needed
@@ -359,7 +403,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
359403 beta:: Number ,
360404 y:: oneStridedVector{$elty} )
361405
362- queue = global_queue (context (y), device ())
406+ queue = global_queue (context (y), device (y ))
363407 $ fname (sycl_queue (queue), uplo, alpha, A. handle, x, beta, y)
364408 y
365409 end
@@ -379,7 +423,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
379423 beta:: Number ,
380424 y:: oneStridedVector{$elty} )
381425
382- queue = global_queue (context (y), device ())
426+ queue = global_queue (context (y), device (y ))
383427 $ fname (sycl_queue (queue), flip_uplo (uplo), alpha, A. handle, x, beta, y)
384428 y
385429 end
@@ -400,7 +444,7 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
400444 beta:: Number ,
401445 y:: oneStridedVector{$elty} )
402446
403- queue = global_queue (context (y), device ())
447+ queue = global_queue (context (y), device (y ))
404448 $ fname (sycl_queue (queue), uplo, trans, diag, alpha, A. handle, x, beta, y)
405449 y
406450 end
@@ -442,7 +486,7 @@ for (fname, elty) in (
442486 " Convert to oneSparseMatrixCSR format instead."
443487 )
444488 )
445- queue = global_queue (context (y), device ())
489+ queue = global_queue (context (y), device (y ))
446490 $ fname (sycl_queue (queue), uplo, flip_trans (trans), diag, alpha, A. handle, x, beta, y)
447491 return y
448492 end
@@ -475,7 +519,7 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
475519 x:: oneStridedVector{$elty} ,
476520 y:: oneStridedVector{$elty} )
477521
478- queue = global_queue (context (y), device ())
522+ queue = global_queue (context (y), device (y ))
479523 $ fname (sycl_queue (queue), uplo, trans, diag, alpha, A. handle, x, y)
480524 y
481525 end
@@ -512,7 +556,7 @@ for (fname, elty) in (
512556 " Convert to oneSparseMatrixCSR format instead."
513557 )
514558 )
515- queue = global_queue (context (y), device ())
559+ queue = global_queue (context (y), device (y ))
516560 onemklXsparse_optimize_trsv (sycl_queue (queue), uplo, flip_trans (trans), diag, A. handle)
517561 return A
518562 end
@@ -555,7 +599,7 @@ for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
555599 nrhs = size (X, 2 )
556600 ldx = max (1 ,stride (X,2 ))
557601 ldy = max (1 ,stride (Y,2 ))
558- queue = global_queue (context (Y), device ())
602+ queue = global_queue (context (Y), device (Y ))
559603 $ fname (sycl_queue (queue), ' C' , transA, transX, uplo, diag, alpha, A. handle, X, nrhs, ldx, Y, ldy)
560604 Y
561605 end
@@ -614,7 +658,7 @@ for (fname, elty) in (
614658 nrhs = size (X, 2 )
615659 ldx = max (1 , stride (X, 2 ))
616660 ldy = max (1 , stride (Y, 2 ))
617- queue = global_queue (context (Y), device ())
661+ queue = global_queue (context (Y), device (Y ))
618662 $ fname (sycl_queue (queue), ' C' , flip_trans (transA), transX, uplo, diag, alpha, A. handle, X, nrhs, ldx, Y, ldy)
619663 return Y
620664 end
0 commit comments