|
51 | 51 | bc′ = Broadcast.preprocess(dest, bc)
|
52 | 52 |
|
53 | 53 | # grid-stride kernel
|
54 |
| - function broadcast_kernel(ctx, dest, bc′, nelem) |
55 |
| - i = 0 |
56 |
| - while i < nelem |
57 |
| - i += 1 |
58 |
| - I = @cartesianidx(dest, i) |
59 |
| - @inbounds dest[I] = bc′[I] |
| 54 | + function broadcast_kernel(ctx, dest, ::Val{Is}, bc′, nelem) where Is |
| 55 | + j = 0 |
| 56 | + while j < nelem |
| 57 | + j += 1 |
| 58 | + |
| 59 | + i = @linearidx(dest, j) |
| 60 | + |
| 61 | + # cartesian indexing is slow, so avoid it if possible |
| 62 | + if isa(IndexStyle(dest), IndexCartesian) || isa(IndexStyle(bc′), IndexCartesian) |
| 63 | + # this performs an integer division, which is expensive. to make it possible |
| 64 | + # for the compiler to optimize it away, we put the iterator in the type |
| 65 | + # domain so that the indices are available at compile time. note that LLVM |
| 66 | + # only seems to replace power-of-two divisions (with bit shifts), but other back-ends |
| 67 | + # may be smarter and replace arbitrary divisions by bit operations. |
| 68 | + # |
| 69 | + # also see maleadt/StaticCartesian.jl, which implements this in Julia, |
| 70 | + # but does not result in an additional speed-up on tested back-ends. |
| 71 | + # |
| 72 | + # in addition, we use @inbounds to avoid bounds checks, but we also need to |
| 73 | + # inform the compiler about the bounds that we are assuming. this is done |
| 74 | + # using the assume intrinsic, and in case of Metal yields an 8x speed-up. |
| 75 | + assume(1 <= i <= length(Is)) |
| 76 | + I = @inbounds Is[i] |
| 77 | + end |
| 78 | + |
| 79 | + val = if isa(IndexStyle(bc′), IndexCartesian) |
| 80 | + @inbounds bc′[I] |
| 81 | + else |
| 82 | + @inbounds bc′[i] |
| 83 | + end |
| 84 | + |
| 85 | + if isa(IndexStyle(dest), IndexCartesian) |
| 86 | + @inbounds dest[I] = val |
| 87 | + else |
| 88 | + @inbounds dest[i] = val |
| 89 | + end |
60 | 90 | end
|
61 | 91 | return
|
62 | 92 | end
|
63 | 93 | elements = length(dest)
|
64 | 94 | elements_per_thread = typemax(Int)
|
65 |
| - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1; |
| 95 | + Is = CartesianIndices(dest) |
| 96 | + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(Is), bc′, 1; |
66 | 97 | elements, elements_per_thread)
|
67 | 98 | config = launch_configuration(backend(dest), heuristic;
|
68 | 99 | elements, elements_per_thread)
|
69 |
| - gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread; |
| 100 | + gpu_call(broadcast_kernel, dest, Val(Is), bc′, config.elements_per_thread; |
70 | 101 | threads=config.threads, blocks=config.blocks)
|
71 | 102 |
|
72 | 103 | return dest
|
|
0 commit comments