@@ -137,6 +137,7 @@ function parse_options(exs...)
137
137
)
138
138
expr = nothing
139
139
nograd = Symbol[]
140
+ safe = Symbol[]
140
141
ranges = Tuple[]
141
142
for ex in exs
142
143
# Actual options:
@@ -160,6 +161,16 @@ function parse_options(exs...)
160
161
throw (" this accepts nograd=A or nograd=(A,B,C)" )
161
162
end
162
163
164
+ # Safe keyword
165
+ elseif isexpr (ex, :(= )) && ex. args[1 ] == :safe
166
+ if ex. args[2 ] isa Symbol
167
+ push! (safe, ex. args[2 ])
168
+ elseif isexpr (ex. args[2 ], :tuple )
169
+ append! (safe, ex. args[2 ]. args)
170
+ else
171
+ throw (" this accepts safe=i or safe=(i,j,k)" )
172
+ end
173
+
163
174
# Ranges specified outside:
164
175
elseif isexpr (ex, :call ) && ex. args[1 ] in [:in , :∈ ]
165
176
push! (ranges, (ex. args[2 ], ex. args[3 ]))
@@ -201,6 +212,7 @@ function parse_options(exs...)
201
212
cuda= opts[:cuda ],
202
213
tensor= opts[:tensor ],
203
214
nograd= nograd,
215
+ safe= safe,
204
216
), ranges, expr
205
217
end
206
218
@@ -586,7 +598,7 @@ detectunsafe(expr, list, store) = MacroTools_postwalk(expr) do ex
586
598
MacroTools_postwalk (i) do x
587
599
@capture_ (x, B_[inner__]) || return x
588
600
# Now we have found an array which indexes another one, mark its indices unsafe
589
- append! (list, filter (j -> j isa Symbol, inner))
601
+ append! (list, setdiff ( filter (j -> j isa Symbol, inner), store . safe ))
590
602
unique! (list)
591
603
# and don't compute a gradient for the inner array
592
604
B isa Symbol && push! (store. nograd, B)
0 commit comments