@@ -216,6 +216,7 @@ _catsize(x::AbstractArray) = size(x)
216
216
217
217
function rrule (:: typeof (hcat), Xs... )
218
218
Y = hcat (Xs... ) # note that Y always has 1-based indexing, even if X isa OffsetArray
219
+ Base. require_one_based_indexing (Y)
219
220
ndimsY = Val (ndims (Y)) # this avoids closing over Y, Val() is essential for type-stability
220
221
sizes = map (_catsize, Xs) # this avoids closing over Xs
221
222
project_Xs = map (ProjectTo, Xs)
@@ -233,15 +234,10 @@ function rrule(::typeof(hcat), Xs...)
233
234
d > ndimsX ? 1 : (:)
234
235
end
235
236
end
236
- dX = if ndimsX > 0
237
- # Here InplaceableThunk breaks @inferred, removed for now
238
- # InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...]))
239
- dY[ind... ]
240
- else
241
- # This is a hack to perhaps avoid GPU scalar indexing
242
- sum (view (dY, ind... ))
243
- end
244
- return project (dX)
237
+ InplaceableThunk (
238
+ dX -> dX .+ = view (dY, ind... ),
239
+ @thunk project (@allowscalar dY[ind... ])
240
+ )
245
241
end
246
242
return (NoTangent (), dXs... )
247
243
end
@@ -253,6 +249,8 @@ function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVecto
253
249
end
254
250
255
251
function rrule (:: typeof (reduce), :: typeof (hcat), As:: AbstractVector{<:AbstractVecOrMat} )
252
+ Y = reduce (hcat, As)
253
+ Base. require_one_based_indexing (Y)
256
254
widths = map (A -> size (A,2 ), As)
257
255
function reduce_hcat_pullback_2 (dY)
258
256
hi = Ref (0 )
@@ -263,7 +261,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe
263
261
end
264
262
return (NoTangent (), NoTangent (), dAs)
265
263
end
266
- return reduce (hcat, As) , reduce_hcat_pullback_2
264
+ return Y , reduce_hcat_pullback_2
267
265
end
268
266
269
267
function rrule (:: typeof (reduce), :: typeof (hcat), As:: AbstractVector{<:AbstractVector} )
286
284
287
285
function rrule (:: typeof (vcat), Xs... )
288
286
Y = vcat (Xs... )
287
+ Base. require_one_based_indexing (Y)
289
288
ndimsY = Val (ndims (Y))
290
289
sizes = map (_catsize, Xs)
291
290
project_Xs = map (ProjectTo, Xs)
@@ -303,13 +302,10 @@ function rrule(::typeof(vcat), Xs...)
303
302
d > ndimsX ? 1 : (:)
304
303
end
305
304
end
306
- dX = if ndimsX > 0
307
- # InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
308
- dY[ind... ]
309
- else
310
- sum (view (dY, ind... ))
311
- end
312
- return project (dX)
305
+ InplaceableThunk (
306
+ dX -> dX .+ = view (dY, ind... ),
307
+ @thunk project (@allowscalar dY[ind... ])
308
+ )
313
309
end
314
310
return (NoTangent (), dXs... )
315
311
end
322
318
323
319
function rrule (:: typeof (reduce), :: typeof (vcat), As:: AbstractVector{<:AbstractVecOrMat} )
324
320
Y = reduce (vcat, As)
321
+ Base. require_one_based_indexing (Y)
325
322
ndimsY = Val (ndims (Y))
326
323
heights = map (A -> size (A,1 ), As)
327
324
function reduce_vcat_pullback (dY)
349
346
350
347
function rrule (:: typeof (cat), Xs... ; dims)
351
348
Y = cat (Xs... ; dims= dims)
349
+ Base. require_one_based_indexing (Y)
352
350
cdims = dims isa Val ? Int (_val (dims)) : dims isa Integer ? Int (dims) : Tuple (dims)
353
351
ndimsY = Val (ndims (Y))
354
352
sizes = map (_catsize, Xs)
@@ -368,13 +366,10 @@ function rrule(::typeof(cat), Xs...; dims)
368
366
for d in cdims
369
367
prev[d] += get (sizeX, d, 1 )
370
368
end
371
- dX = if ndimsX > 0
372
- # InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...))
373
- dY[index... ]
374
- else
375
- sum (view (dY, index... ))
376
- end
377
- return project (dX)
369
+ InplaceableThunk (
370
+ dX -> dX .+ = view (dY, index... ),
371
+ @thunk project (@allowscalar dY[index... ])
372
+ )
378
373
end
379
374
return (NoTangent (), dXs... )
380
375
end
0 commit comments