Skip to content

Commit 36661e3

Browse files
authored
Avoid cartesian iteration where possible. (#454)
1 parent 53a7b5c commit 36661e3

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

src/host/broadcast.jl

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,53 @@ end
5151
bc′ = Broadcast.preprocess(dest, bc)
5252

5353
# grid-stride kernel
54-
function broadcast_kernel(ctx, dest, bc′, nelem)
55-
i = 0
56-
while i < nelem
57-
i += 1
58-
I = @cartesianidx(dest, i)
59-
@inbounds dest[I] = bc′[I]
54+
function broadcast_kernel(ctx, dest, ::Val{Is}, bc′, nelem) where Is
55+
j = 0
56+
while j < nelem
57+
j += 1
58+
59+
i = @linearidx(dest, j)
60+
61+
# cartesian indexing is slow, so avoid it if possible
62+
if isa(IndexStyle(dest), IndexCartesian) || isa(IndexStyle(bc′), IndexCartesian)
63+
# this performs an integer division, which is expensive. to make it possible
64+
# for the compiler to optimize it away, we put the iterator in the type
65+
# domain so that the indices are available at compile time. note that LLVM
66+
# only seems to replace pow2 divisions (with bitshifts), but other back-ends
67+
# may be smarter and replace arbitrary divisions by bit operations.
68+
#
69+
# also see maleadt/StaticCartesian.jl, which implements this in Julia,
70+
# but does not result in an additional speed-up on tested back-ends.
71+
#
72+
# in addition, we use @inbounds to avoid bounds checks, but we also need to
73+
# inform the compiler about the bounds that we are assuming. this is done
74+
# using the assume intrinsic, and in the case of Metal yields an 8x speed-up.
75+
assume(1 <= i <= length(Is))
76+
I = @inbounds Is[i]
77+
end
78+
79+
val = if isa(IndexStyle(bc′), IndexCartesian)
80+
@inbounds bc′[I]
81+
else
82+
@inbounds bc′[i]
83+
end
84+
85+
if isa(IndexStyle(dest), IndexCartesian)
86+
@inbounds dest[I] = val
87+
else
88+
@inbounds dest[i] = val
89+
end
6090
end
6191
return
6292
end
6393
elements = length(dest)
6494
elements_per_thread = typemax(Int)
65-
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1;
95+
Is = CartesianIndices(dest)
96+
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(Is), bc′, 1;
6697
elements, elements_per_thread)
6798
config = launch_configuration(backend(dest), heuristic;
6899
elements, elements_per_thread)
69-
gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
100+
gpu_call(broadcast_kernel, dest, Val(Is), bc′, config.elements_per_thread;
70101
threads=config.threads, blocks=config.blocks)
71102

72103
return dest

src/host/math.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
function Base.clamp!(A::AnyGPUArray, low, high)
44
gpu_call(A, low, high) do ctx, A, low, high
5-
I = @cartesianidx A
5+
I = @linearidx A
66
A[I] = clamp(A[I], low, high)
77
return
88
end

0 commit comments

Comments
 (0)