Skip to content

Commit 7f68400

Browse files
authored
Merge pull request #346 from JuliaGPU/tb/scalar_indexing_tls
Make scalar indexing toggle task-local.
2 parents 17236cc + f3cb9f0 commit 7f68400

File tree

1 file changed

+31
-29
lines changed

1 file changed

+31
-29
lines changed

src/host/indexing.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,58 @@ export allowscalar, @allowscalar, @disallowscalar, assertscalar
55

66
# mechanism to disallow scalar operations
77

8-
@enum ScalarIndexing ScalarAllowed ScalarWarned ScalarDisallowed
9-
10-
const scalar_allowed = Ref(ScalarWarned)
11-
const scalar_warned = Ref(false)
8+
@enum ScalarIndexing ScalarAllowed ScalarWarn ScalarWarned ScalarDisallowed
129

1310
"""
1411
allowscalar(allow=true, warn=true)
12+
allowscalar(allow=true, warn=true) do end
1513
1614
Configure whether scalar indexing is allowed depending on the value of `allow`.
1715
1816
If allowed, `warn` can be set to throw a single warning instead. Calling this function will
1917
reset the state of the warning, and throw a new warning on subsequent scalar iteration.
18+
19+
For temporary changes, use the do-block version, or [`@allowscalar`](@ref).
2020
"""
2121
function allowscalar(allow::Bool=true, warn::Bool=true)
22-
scalar_warned[] = false
23-
scalar_allowed[] = if allow && !warn
22+
val = if allow && !warn
2423
ScalarAllowed
2524
elseif allow
26-
ScalarWarned
25+
ScalarWarn
2726
else
2827
ScalarDisallowed
2928
end
29+
30+
task_local_storage(:ScalarIndexing, val)
3031
return
3132
end
3233

34+
@doc (@doc allowscalar) ->
35+
function allowscalar(f::Base.Callable, allow::Bool=true, warn::Bool=false)
36+
val = if allow && !warn
37+
ScalarAllowed
38+
elseif allow
39+
ScalarWarn
40+
else
41+
ScalarDisallowed
42+
end
43+
44+
task_local_storage(f, :ScalarIndexing, val)
45+
end
46+
3347
"""
3448
assertscalar(op::String)
3549
3650
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
3751
error will be thrown ([`allowscalar`](@ref)).
3852
"""
3953
function assertscalar(op = "operation")
40-
if scalar_allowed[] == ScalarDisallowed
54+
val = get(task_local_storage(), :ScalarIndexing, ScalarWarn)
55+
if val == ScalarDisallowed
4156
error("$op is disallowed")
42-
elseif scalar_allowed[] == ScalarWarned && !scalar_warned[]
57+
elseif val == ScalarWarn
4358
@warn "Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`"
44-
scalar_warned[] = true
59+
task_local_storage(:ScalarIndexing, ScalarWarned)
4560
end
4661
return
4762
end
@@ -59,34 +74,21 @@ fine-grained expressions.
5974
"""
6075
macro allowscalar(ex)
6176
quote
62-
local prev = scalar_allowed[]
63-
scalar_allowed[] = ScalarAllowed
64-
local ret = $(esc(ex))
65-
scalar_allowed[] = prev
66-
ret
77+
task_local_storage(:ScalarIndexing, ScalarAllowed) do
78+
$(esc(ex))
79+
end
6780
end
6881
end
6982

7083
@doc (@doc @allowscalar) ->
7184
macro disallowscalar(ex)
7285
quote
73-
local prev = scalar_allowed[]
74-
scalar_allowed[] = ScalarDisallowed
75-
local ret = $(esc(ex))
76-
scalar_allowed[] = prev
77-
ret
86+
task_local_storage(:ScalarIndexing, ScalarDisallowed) do
87+
$(esc(ex))
88+
end
7889
end
7990
end
8091

81-
@doc (@doc @allowscalar) ->
82-
function allowscalar(f::Base.Callable, allow::Bool=true, warn::Bool=false)
83-
prev = scalar_allowed[]
84-
allowscalar(allow, warn)
85-
ret = f()
86-
scalar_allowed[] = prev
87-
ret
88-
end
89-
9092

9193
# basic indexing with integers
9294

0 commit comments

Comments
 (0)