@@ -5,43 +5,58 @@ export allowscalar, @allowscalar, @disallowscalar, assertscalar
5
5
6
6
# mechanism to disallow scalar operations
7
7
8
- @enum ScalarIndexing ScalarAllowed ScalarWarned ScalarDisallowed
9
-
10
- const scalar_allowed = Ref (ScalarWarned)
11
- const scalar_warned = Ref (false )
8
+ @enum ScalarIndexing ScalarAllowed ScalarWarn ScalarWarned ScalarDisallowed
12
9
13
10
"""
14
11
allowscalar(allow=true, warn=true)
12
+ allowscalar(allow=true, warn=true) do end
15
13
16
14
Configure whether scalar indexing is allowed depending on the value of `allow`.
17
15
18
16
If allowed, `warn` can be set to throw a single warning instead. Calling this function will
19
17
reset the state of the warning, and throw a new warning on subsequent scalar iteration.
18
+
19
+ For temporary changes, use the do-block version, or [`@allowscalar`](@ref).
20
20
"""
21
21
function allowscalar (allow:: Bool = true , warn:: Bool = true )
22
- scalar_warned[] = false
23
- scalar_allowed[] = if allow && ! warn
22
+ val = if allow && ! warn
24
23
ScalarAllowed
25
24
elseif allow
26
- ScalarWarned
25
+ ScalarWarn
27
26
else
28
27
ScalarDisallowed
29
28
end
29
+
30
+ task_local_storage (:ScalarIndexing , val)
30
31
return
31
32
end
32
33
34
+ @doc (@doc allowscalar) ->
35
+ function allowscalar (f:: Base.Callable , allow:: Bool = true , warn:: Bool = false )
36
+ val = if allow && ! warn
37
+ ScalarAllowed
38
+ elseif allow
39
+ ScalarWarn
40
+ else
41
+ ScalarDisallowed
42
+ end
43
+
44
+ task_local_storage (f, :ScalarIndexing , val)
45
+ end
46
+
33
47
"""
34
48
assertscalar(op::String)
35
49
36
50
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
37
51
error will be thrown ([`allowscalar`](@ref)).
38
52
"""
39
53
function assertscalar (op = " operation" )
40
- if scalar_allowed[] == ScalarDisallowed
54
+ val = get (task_local_storage (), :ScalarIndexing , ScalarWarn)
55
+ if val == ScalarDisallowed
41
56
error (" $op is disallowed" )
42
- elseif scalar_allowed[] == ScalarWarned && ! scalar_warned[]
57
+ elseif val == ScalarWarn
43
58
@warn " Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`"
44
- scalar_warned[] = true
59
+ task_local_storage ( :ScalarIndexing , ScalarWarned)
45
60
end
46
61
return
47
62
end
@@ -59,34 +74,21 @@ fine-grained expressions.
59
74
"""
60
75
macro allowscalar (ex)
61
76
quote
62
- local prev = scalar_allowed[]
63
- scalar_allowed[] = ScalarAllowed
64
- local ret = $ (esc (ex))
65
- scalar_allowed[] = prev
66
- ret
77
+ task_local_storage (:ScalarIndexing , ScalarAllowed) do
78
+ $ (esc (ex))
79
+ end
67
80
end
68
81
end
69
82
70
83
@doc (@doc @allowscalar ) ->
71
84
macro disallowscalar (ex)
72
85
quote
73
- local prev = scalar_allowed[]
74
- scalar_allowed[] = ScalarDisallowed
75
- local ret = $ (esc (ex))
76
- scalar_allowed[] = prev
77
- ret
86
+ task_local_storage (:ScalarIndexing , ScalarDisallowed) do
87
+ $ (esc (ex))
88
+ end
78
89
end
79
90
end
80
91
81
- @doc (@doc @allowscalar ) ->
82
- function allowscalar (f:: Base.Callable , allow:: Bool = true , warn:: Bool = false )
83
- prev = scalar_allowed[]
84
- allowscalar (allow, warn)
85
- ret = f ()
86
- scalar_allowed[] = prev
87
- ret
88
- end
89
-
90
92
91
93
# basic indexing with integers
92
94
0 commit comments