@@ -466,7 +466,7 @@ padmodclamp_replace(s, store, inside=false) = s
466
466
padmodclamp_replace (ex:: Expr , store, inside= false ) =
467
467
if ex. head == :(= ) && @capture_ (ex. args[1 ], A_[inds__])
468
468
# This tricky case is 𝛥A[pad(i,2)] = 𝛥A[pad(i,2)] + ...
469
- Aex, fun = padmodclamp_pair (A, inds, store)
469
+ Aex, fun = padmodclamp_pair (A, inds, store, true )
470
470
right = if fun != identity
471
471
padmodclamp_replace (ex. args[2 ], store, true )
472
472
else
@@ -481,7 +481,7 @@ padmodclamp_replace(ex::Expr, store, inside=false) =
481
481
Expr (ex. head, args... )
482
482
end
483
483
484
- padmodclamp_pair (A, inds, store) = begin
484
+ padmodclamp_pair (A, inds, store, assign = false ) = begin
485
485
nopadif = []
486
486
inds4 = map (enumerate (inds)) do (d,ex)
487
487
isexpr (ex, :call ) || return ex
@@ -509,7 +509,9 @@ padmodclamp_pair(A, inds, store) = begin
509
509
for c2 in nopadif[2 : end ]
510
510
cond = :($ cond & $ c2)
511
511
end
512
- if store. padkeyword == TYP # default
512
+ if assign # for gradients, this wraps 𝛥A[pad(i,2)] = 𝛥A[pad(i,2)] + ...
513
+ ex -> :($ cond && $ ex)
514
+ elseif store. padkeyword == TYP # default, pad with zero
513
515
ex -> :($ cond ? $ ex : zero (eltype ($ A)))
514
516
else
515
517
ex -> :($ cond ? $ ex : $ convert ($ eltype ($ A), $ (store. padkeyword)))
0 commit comments