Skip to content

Commit 89b2318

Browse files
author
Michael Abbott
committed
add safe keyword, opposite of unsafe(left/right)
1 parent ebf5c70 commit 89b2318

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

src/macro.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ function parse_options(exs...)
137137
)
138138
expr = nothing
139139
nograd = Symbol[]
140+
safe = Symbol[]
140141
ranges = Tuple[]
141142
for ex in exs
142143
# Actual options:
@@ -160,6 +161,16 @@ function parse_options(exs...)
160161
throw("this accepts nograd=A or nograd=(A,B,C)")
161162
end
162163

164+
# Safe keyword
165+
elseif isexpr(ex, :(=)) && ex.args[1] == :safe
166+
if ex.args[2] isa Symbol
167+
push!(safe, ex.args[2])
168+
elseif isexpr(ex.args[2], :tuple)
169+
append!(safe, ex.args[2].args)
170+
else
171+
throw("this accepts safe=i or safe=(i,j,k)")
172+
end
173+
163174
# Ranges specified outside:
164175
elseif isexpr(ex, :call) && ex.args[1] in [:in, :]
165176
push!(ranges, (ex.args[2], ex.args[3]))
@@ -201,6 +212,7 @@ function parse_options(exs...)
201212
cuda=opts[:cuda],
202213
tensor=opts[:tensor],
203214
nograd=nograd,
215+
safe=safe,
204216
), ranges, expr
205217
end
206218

@@ -586,7 +598,7 @@ detectunsafe(expr, list, store) = MacroTools_postwalk(expr) do ex
586598
MacroTools_postwalk(i) do x
587599
@capture_(x, B_[inner__]) || return x
588600
# Now we have found an array which indexes another one, mark its indices unsafe
589-
append!(list, filter(j -> j isa Symbol, inner))
601+
append!(list, setdiff(filter(j -> j isa Symbol, inner), store.safe))
590602
unique!(list)
591603
# and don't compute a gradient for the inner array
592604
B isa Symbol && push!(store.nograd, B)

0 commit comments

Comments
 (0)