Skip to content

Commit 47c3ff7

Browse files
committed
Towards MadNLP
1 parent 24110a8 commit 47c3ff7

File tree

10 files changed

+174
-86
lines changed

10 files changed

+174
-86
lines changed

lib/mkl/array.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
55
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}
66

77
mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
8-
handle::matrix_handle_t
8+
handle::Union{Nothing, matrix_handle_t}
99
rowPtr::oneVector{Ti}
1010
colVal::oneVector{Ti}
1111
nzVal::oneVector{Tv}
@@ -14,7 +14,7 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1414
end
1515

1616
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
17-
handle::matrix_handle_t
17+
handle::Union{Nothing, matrix_handle_t}
1818
colPtr::oneVector{Ti}
1919
rowVal::oneVector{Ti}
2020
nzVal::oneVector{Tv}
@@ -23,7 +23,7 @@ mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
2323
end
2424

2525
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
26-
handle::matrix_handle_t
26+
handle::Union{Nothing, matrix_handle_t}
2727
rowInd::oneVector{Ti}
2828
colInd::oneVector{Ti}
2929
nzVal::oneVector{Tv}

lib/mkl/wrappers_blas.jl

Lines changed: 42 additions & 42 deletions
Large diffs are not rendered by default.

lib/mkl/wrappers_sparse.jl

Lines changed: 81 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
function sparse_release_matrix_handle(A::oneAbstractSparseMatrix)
2-
queue = global_queue(context(A.nzVal), device(A.nzVal))
3-
handle_ptr = Ref{matrix_handle_t}(A.handle)
4-
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
2+
if A.handle !== nothing
3+
try
4+
queue = global_queue(context(A.nzVal), device(A.nzVal))
5+
handle_ptr = Ref{matrix_handle_t}(A.handle)
6+
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
7+
# Only synchronize after successful release to ensure completion
8+
synchronize(queue)
9+
catch err
10+
# Don't let finalizer errors crash the program
11+
@warn "Error releasing sparse matrix handle" exception=err
12+
end
13+
end
514
end
615

716
for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32),
@@ -13,20 +22,55 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
1322
(:onemklZsparse_set_csr_data , :ComplexF64, :Int32),
1423
(:onemklZsparse_set_csr_data_64, :ComplexF64, :Int64))
1524
@eval begin
16-
function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty})
25+
26+
# Build a oneSparseMatrixCSR directly from device-resident CSR buffers.
#
# Arguments:
#   rowPtr, colVal, nzVal -- CSR buffers already stored in oneVectors
#   dims                  -- (m, n) size of the matrix
#
# Returns the wrapped matrix. When either dimension is zero the `handle` field
# is `nothing` and no finalizer is attached; otherwise a oneMKL matrix handle is
# created, bound to the buffers, and released by `sparse_release_matrix_handle`
# when the object is finalized.
function oneSparseMatrixCSR(
        rowPtr::oneVector{$intty}, colVal::oneVector{$intty},
        nzVal::oneVector{$elty}, dims::NTuple{2, Int}
    )
    m, n = dims
    nnzA = length(nzVal)

    # Empty matrix: do not create a handle at all. Initialising one here and
    # then storing `nothing` would leave it without an owner and without a
    # finalizer, i.e. it would never be released.
    if m == 0 || n == 0
        return oneSparseMatrixCSR{$elty, $intty}(nothing, rowPtr, colVal, nzVal, (m, n), nnzA)
    end

    queue = global_queue(context(nzVal), device(nzVal))
    handle_ptr = Ref{matrix_handle_t}()
    onemklXsparse_init_matrix_handle(handle_ptr)
    $fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
    dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
    finalizer(sparse_release_matrix_handle, dA)
    return dA
end
45+
46+
# Build a oneSparseMatrixCSC directly from device-resident CSC buffers.
#
# Arguments:
#   colPtr, rowVal, nzVal -- CSC buffers already stored in oneVectors
#   dims                  -- (m, n) size of the matrix
#
# Returns the wrapped matrix. When either dimension is zero the `handle` field
# is `nothing` and no finalizer is attached; otherwise a oneMKL matrix handle is
# created, bound to the buffers, and released by `sparse_release_matrix_handle`
# when the object is finalized.
function oneSparseMatrixCSC(
        colPtr::oneVector{$intty}, rowVal::oneVector{$intty},
        nzVal::oneVector{$elty}, dims::NTuple{2, Int}
    )
    m, n = dims
    nnzA = length(nzVal)

    # Empty matrix: do not create a handle at all. Initialising one here and
    # then storing `nothing` would leave it without an owner and without a
    # finalizer, i.e. it would never be released.
    if m == 0 || n == 0
        return oneSparseMatrixCSC{$elty, $intty}(nothing, colPtr, rowVal, nzVal, (m, n), nnzA)
    end

    queue = global_queue(context(nzVal), device(nzVal))
    handle_ptr = Ref{matrix_handle_t}()
    onemklXsparse_init_matrix_handle(handle_ptr)
    # CSC of A is CSR of Aᵀ, hence the swapped (n, m) dimensions.
    $fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal)
    dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
    finalizer(sparse_release_matrix_handle, dA)
    return dA
end
65+
66+
67+
function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty})
1968
m, n = size(A)
2069
At = SparseMatrixCSC(A |> transpose)
2170
rowPtr = oneVector{$intty}(At.colptr)
2271
colVal = oneVector{$intty}(At.rowval)
2372
nzVal = oneVector{$elty}(At.nzval)
24-
nnzA = length(At.nzval)
25-
queue = global_queue(context(nzVal), device())
26-
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
27-
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
28-
finalizer(sparse_release_matrix_handle, dA)
29-
return dA
73+
return oneSparseMatrixCSR(rowPtr, colVal, nzVal, (m, n))
3074
end
3175

3276
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
@@ -37,18 +81,11 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
3781
end
3882

3983
function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
40-
handle_ptr = Ref{matrix_handle_t}()
41-
onemklXsparse_init_matrix_handle(handle_ptr)
4284
m, n = size(A)
4385
colPtr = oneVector{$intty}(A.colptr)
4486
rowVal = oneVector{$intty}(A.rowval)
4587
nzVal = oneVector{$elty}(A.nzval)
46-
nnzA = length(A.nzval)
47-
queue = global_queue(context(nzVal), device())
48-
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
49-
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
50-
finalizer(sparse_release_matrix_handle, dA)
51-
return dA
88+
return oneSparseMatrixCSC(colPtr, rowVal, nzVal, (m, n))
5289
end
5390

5491
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
@@ -77,10 +114,14 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
77114
colInd = oneVector{$intty}(col)
78115
nzVal = oneVector{$elty}(val)
79116
nnzA = length(val)
80-
queue = global_queue(context(nzVal), device())
81-
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
82-
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
83-
finalizer(sparse_release_matrix_handle, dA)
117+
queue = global_queue(context(nzVal), device(nzVal))
118+
if m != 0 && n != 0
119+
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
120+
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
121+
finalizer(sparse_release_matrix_handle, dA)
122+
else
123+
dA = oneSparseMatrixCOO{$elty, $intty}(nothing, rowInd, colInd, nzVal, (m,n), nnzA)
124+
end
84125
return dA
85126
end
86127

@@ -105,7 +146,7 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
105146
beta::Number,
106147
y::oneStridedVector{$elty})
107148

108-
queue = global_queue(context(x), device())
149+
queue = global_queue(context(x), device(x))
109150
$fname(sycl_queue(queue), trans, alpha, A.handle, x, beta, y)
110151
y
111152
end
@@ -140,8 +181,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
140181
beta::Number,
141182
y::oneStridedVector{$elty})
142183

143-
queue = global_queue(context(x), device())
144-
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
184+
queue = global_queue(context(x), device(x))
185+
m, n = size(A)
186+
if m != 0 && n != 0
187+
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
188+
end
145189
y
146190
end
147191
end
@@ -173,7 +217,7 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
173217
beta = conj(beta)
174218
end
175219

176-
queue = global_queue(context(x), device())
220+
queue = global_queue(context(x), device(x))
177221
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
178222

179223
if trans == 'C'
@@ -217,7 +261,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
217261
nrhs = size(B, 2)
218262
ldb = max(1,stride(B,2))
219263
ldc = max(1,stride(C,2))
220-
queue = global_queue(context(C), device())
264+
queue = global_queue(context(C), device(C))
221265
$fname(sycl_queue(queue), 'C', transa, transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
222266
C
223267
end
@@ -254,7 +298,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
254298
nrhs = size(B, 2)
255299
ldb = max(1,stride(B,2))
256300
ldc = max(1,stride(C,2))
257-
queue = global_queue(context(C), device())
301+
queue = global_queue(context(C), device(C))
258302
$fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
259303
C
260304
end
@@ -289,7 +333,7 @@ for (fname, elty) in (
289333
nrhs = size(B, 2)
290334
ldb = max(1, stride(B, 2))
291335
ldc = max(1, stride(C, 2))
292-
queue = global_queue(context(C), device())
336+
queue = global_queue(context(C), device(C))
293337

294338
# Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C)
295339
# Prepare conj(C) in-place and conj(B) into a temporary if needed
@@ -359,7 +403,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
359403
beta::Number,
360404
y::oneStridedVector{$elty})
361405

362-
queue = global_queue(context(y), device())
406+
queue = global_queue(context(y), device(y))
363407
$fname(sycl_queue(queue), uplo, alpha, A.handle, x, beta, y)
364408
y
365409
end
@@ -379,7 +423,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
379423
beta::Number,
380424
y::oneStridedVector{$elty})
381425

382-
queue = global_queue(context(y), device())
426+
queue = global_queue(context(y), device(y))
383427
$fname(sycl_queue(queue), flip_uplo(uplo), alpha, A.handle, x, beta, y)
384428
y
385429
end
@@ -400,7 +444,7 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
400444
beta::Number,
401445
y::oneStridedVector{$elty})
402446

403-
queue = global_queue(context(y), device())
447+
queue = global_queue(context(y), device(y))
404448
$fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, beta, y)
405449
y
406450
end
@@ -442,7 +486,7 @@ for (fname, elty) in (
442486
"Convert to oneSparseMatrixCSR format instead."
443487
)
444488
)
445-
queue = global_queue(context(y), device())
489+
queue = global_queue(context(y), device(y))
446490
$fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, beta, y)
447491
return y
448492
end
@@ -475,7 +519,7 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
475519
x::oneStridedVector{$elty},
476520
y::oneStridedVector{$elty})
477521

478-
queue = global_queue(context(y), device())
522+
queue = global_queue(context(y), device(y))
479523
$fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, y)
480524
y
481525
end
@@ -512,7 +556,7 @@ for (fname, elty) in (
512556
"Convert to oneSparseMatrixCSR format instead."
513557
)
514558
)
515-
queue = global_queue(context(y), device())
559+
queue = global_queue(context(y), device(y))
516560
onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
517561
return A
518562
end
@@ -555,7 +599,7 @@ for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
555599
nrhs = size(X, 2)
556600
ldx = max(1,stride(X,2))
557601
ldy = max(1,stride(Y,2))
558-
queue = global_queue(context(Y), device())
602+
queue = global_queue(context(Y), device(Y))
559603
$fname(sycl_queue(queue), 'C', transA, transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
560604
Y
561605
end
@@ -614,7 +658,7 @@ for (fname, elty) in (
614658
nrhs = size(X, 2)
615659
ldx = max(1, stride(X, 2))
616660
ldy = max(1, stride(Y, 2))
617-
queue = global_queue(context(Y), device())
661+
queue = global_queue(context(Y), device(Y))
618662
$fname(sycl_queue(queue), 'C', flip_trans(transA), transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
619663
return Y
620664
end

src/indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ function Base.findall(bools::oneArray{Bool})
2020
I = keytype(bools)
2121

2222
indices = cumsum(reshape(bools, prod(size(bools))))
23-
oneL0.synchronize()
2423

2524
n = isempty(indices) ? 0 : @allowscalar indices[end]
2625

2726
ys = oneArray{I}(undef, n)
2827

2928
if n > 0
30-
@oneapi items = length(bools) _ker!(ys, bools, indices)
29+
kernel = @oneapi launch=false _ker!(ys, bools, indices)
30+
group_size = launch_configuration(kernel)
31+
kernel(ys, bools, indices; items=group_size, groups=cld(length(bools), group_size))
3132
end
32-
oneL0.synchronize()
33-
unsafe_free!(indices)
33+
# unsafe_free!(indices)
3434

3535
return ys
3636
end

src/oneAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ include("utils.jl")
7070
include("oneAPIKernels.jl")
7171
import .oneAPIKernels: oneAPIBackend
7272
include("accumulate.jl")
73+
include("sorting.jl")
7374
include("indexing.jl")
7475
export oneAPIBackend
7576

src/sorting.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Sort a oneArray in place by forwarding to `AK.sort!` (all keyword arguments
# are passed through unchanged) and hand the same array back to the caller.
function Base.sort!(x::oneArray; kwargs...)
    AK.sort!(x; kwargs...)
    return x
end
2+
# Compute the sorting permutation of `x` into the preallocated index array `ix`
# by forwarding to `AK.sortperm!` (keyword arguments passed through), then
# return `ix`.
function Base.sortperm!(ix::oneArray, x::oneArray; kwargs...)
    AK.sortperm!(ix, x; kwargs...)
    return ix
end
3+
# Allocating variant of `sortperm!`: create a device index array holding
# 1:length(x) and let `sortperm!` fill it with the permutation for `x`.
function Base.sortperm(x::oneArray; kwargs...)
    perm = oneArray(1:length(x))
    return sortperm!(perm, x; kwargs...)
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2020
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2121
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"
22+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
2223
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"

test/indexing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,18 @@ using oneAPI
1717
data = oneArray(collect(1:6))
1818
mask = oneArray(Bool[true, false, true, false, false, true])
1919
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]
20+
21+
# Test with array larger than 1024 to trigger multiple groups
22+
large_size = 2048
23+
large_mask = oneArray(rand(Bool, large_size))
24+
large_result_gpu = Array(findall(large_mask))
25+
large_result_cpu = findall(Array(large_mask))
26+
@test large_result_gpu == large_result_cpu
27+
28+
# Test with even larger array to ensure robustness
29+
very_large_size = 5000
30+
very_large_mask = oneArray(fill(true, very_large_size)) # all true for predictable result
31+
very_large_result_gpu = Array(findall(very_large_mask))
32+
very_large_result_cpu = findall(fill(true, very_large_size))
33+
@test very_large_result_gpu == very_large_result_cpu
2034
end

test/onemkl.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,10 @@ end
10901090
B = oneSparseMatrixCSR(A)
10911091
A2 = SparseMatrixCSC(B)
10921092
@test A == A2
1093+
C = oneSparseMatrixCSR(B.rowPtr, B.colVal, B.nzVal, size(B))
1094+
A3 = SparseMatrixCSC(C)
1095+
@test A == A3
1096+
D = oneSparseMatrixCSR(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
10931097
end
10941098
end
10951099

@@ -1101,6 +1105,10 @@ end
11011105
B = oneSparseMatrixCSC(A)
11021106
A2 = SparseMatrixCSC(B)
11031107
@test A == A2
1108+
C = oneSparseMatrixCSC(A.colptr |> oneVector, A.rowval |> oneVector, A.nzval |> oneVector, size(A))
1109+
A3 = SparseMatrixCSC(C)
1110+
@test A == A3
1111+
D = oneSparseMatrixCSC(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
11041112
end
11051113
end
11061114

test/sorting.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using Test
2+
using oneAPI
3+
4+
@testset "sorting" begin
    host = [3, 1, 4, 1, 5]

    # In-place ascending sort on the device.
    ascending = oneArray(host)
    sort!(ascending)
    @test Array(ascending) == [1, 1, 3, 4, 5]

    # In-place descending sort via the `rev` keyword.
    descending = oneArray(host)
    sort!(descending, rev = true)
    @test Array(descending) == [5, 4, 3, 1, 1]

    # The device permutation must match the CPU reference for the same input.
    @test Array(sortperm(oneArray(host))) == sortperm(host)

    # Same comparison with reversed ordering.
    @test Array(sortperm(oneArray(host), rev = true)) == sortperm(host, rev = true)
end

0 commit comments

Comments
 (0)