Skip to content

Commit 69e30e0

Browse files
committed
Test AK accumulate
1 parent 4ef96cc commit 69e30e0

File tree

5 files changed, with 14 additions and 8 deletions.

5 files changed, with 14 additions and 8 deletions.

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "dde4c033-4e86-420c-a63e-0dd931031962"
33
version = "1.5.1"
44

55
[deps]
6+
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
67
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
89
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -32,6 +33,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3233
SpecialFunctionsExt = "SpecialFunctions"
3334

3435
[compat]
36+
AcceleratedKernels = "0.3.3"
3537
Adapt = "4"
3638
BFloat16s = "0.5"
3739
CEnum = "0.4, 0.5"

src/Metal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
1212
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1313
import ObjectiveC: is_macos, darwin_version, macos_version
1414
import KernelAbstractions
15+
import AcceleratedKernels as AK
1516
using ScopedValues
1617

1718
include("version.jl")

src/accumulate.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,27 +170,26 @@ end
170170
## Base interface
171171

172172
Base._accumulate!(op, output::WrappedMtlArray, input::WrappedMtlVector, dims::Nothing, init::Nothing) =
173-
scan!(op, output, input; dims=1)
173+
@inline AK.accumulate!(op, output, input; dims, init=AK.neutral_element(op, eltype(output)), alg=AK.ScanPrefixes())
174174

175175
Base._accumulate!(op, output::WrappedMtlArray, input::WrappedMtlArray, dims::Integer, init::Nothing) =
176-
scan!(op, output, input; dims=dims)
177-
176+
@inline AK.accumulate!(op, output, input; dims, init=AK.neutral_element(op, eltype(output)), alg=AK.ScanPrefixes())
178177
Base._accumulate!(op, output::WrappedMtlArray, input::MtlVector, dims::Nothing, init::Some) =
179-
scan!(op, output, input; dims=1, init=init)
178+
@inline AK.accumulate!(op, output, input; dims, init=something(init), alg=AK.ScanPrefixes())
180179

181180
Base._accumulate!(op, output::WrappedMtlArray, input::WrappedMtlArray, dims::Integer, init::Some) =
182-
scan!(op, output, input; dims=dims, init=init)
181+
@inline AK.accumulate!(op, output, input; dims, init=something(init), alg=AK.ScanPrefixes())
183182

184-
Base.accumulate_pairwise!(op, result::WrappedMtlVector, v::WrappedMtlVector) = accumulate!(op, result, v)
183+
Base.accumulate_pairwise!(op, result::WrappedMtlVector, v::WrappedMtlVector) = @inline AK.accumulate!(op, result, v; init=AK.neutral_element(op, eltype(result)), alg=AK.ScanPrefixes())
185184

186185
# default behavior unless dims are specified by the user
187186
function Base.accumulate(op, A::WrappedMtlArray;
188187
dims::Union{Nothing,Integer}=nothing, kw...)
188+
nt = values(kw)
189189
if dims === nothing && !(A isa AbstractVector)
190190
# This branch takes care of the cases not handled by `_accumulate!`.
191-
return reshape(accumulate(op, A[:]; kw...), size(A))
191+
return reshape(AK.accumulate(op, A[:]; init = (:init in keys(kw) ? nt.init : AK.neutral_element(op, eltype(A))), alg=AK.ScanPrefixes()), size(A))
192192
end
193-
nt = values(kw)
194193
if isempty(kw)
195194
out = similar(A, Base.promote_op(op, eltype(A), eltype(A)))
196195
elseif keys(nt) === (:init,)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
23
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
34
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
45
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
@@ -10,6 +11,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1011
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
ObjectiveC = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9"
14+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1315
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1416
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1517
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Pkg
2+
Pkg.develop("AcceleratedKernels")
13
using Distributed
24
using Dates
35
using Metal

0 commit comments

Comments (0)