@@ -216,6 +216,7 @@ _catsize(x::AbstractArray) = size(x)
216216
217217function rrule (:: typeof (hcat), Xs... )
218218 Y = hcat (Xs... ) # note that Y always has 1-based indexing, even if X isa OffsetArray
219+ Base. require_one_based_indexing (Y)
219220 ndimsY = Val (ndims (Y)) # this avoids closing over Y, Val() is essential for type-stability
220221 sizes = map (_catsize, Xs) # this avoids closing over Xs
221222 project_Xs = map (ProjectTo, Xs)
@@ -233,15 +234,10 @@ function rrule(::typeof(hcat), Xs...)
233234 d > ndimsX ? 1 : (:)
234235 end
235236 end
236- dX = if ndimsX > 0
237- # Here InplaceableThunk breaks @inferred, removed for now
238- # InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...]))
239- dY[ind... ]
240- else
241- # This is a hack to perhaps avoid GPU scalar indexing
242- sum (view (dY, ind... ))
243- end
244- return project (dX)
237+ InplaceableThunk (
238+ dX -> dX .+ = view (dY, ind... ),
239+ @thunk project (@allowscalar dY[ind... ])
240+ )
245241 end
246242 return (NoTangent (), dXs... )
247243 end
@@ -253,6 +249,8 @@ function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVecto
253249end
254250
255251function rrule (:: typeof (reduce), :: typeof (hcat), As:: AbstractVector{<:AbstractVecOrMat} )
252+ Y = reduce (hcat, As)
253+ Base. require_one_based_indexing (Y)
256254 widths = map (A -> size (A,2 ), As)
257255 function reduce_hcat_pullback_2 (dY)
258256 hi = Ref (0 )
@@ -263,7 +261,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe
263261 end
264262 return (NoTangent (), NoTangent (), dAs)
265263 end
266- return reduce (hcat, As) , reduce_hcat_pullback_2
264+ return Y , reduce_hcat_pullback_2
267265end
268266
269267function rrule (:: typeof (reduce), :: typeof (hcat), As:: AbstractVector{<:AbstractVector} )
286284
287285function rrule (:: typeof (vcat), Xs... )
288286 Y = vcat (Xs... )
287+ Base. require_one_based_indexing (Y)
289288 ndimsY = Val (ndims (Y))
290289 sizes = map (_catsize, Xs)
291290 project_Xs = map (ProjectTo, Xs)
@@ -303,13 +302,10 @@ function rrule(::typeof(vcat), Xs...)
303302 d > ndimsX ? 1 : (:)
304303 end
305304 end
306- dX = if ndimsX > 0
307- # InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
308- dY[ind... ]
309- else
310- sum (view (dY, ind... ))
311- end
312- return project (dX)
305+ InplaceableThunk (
306+ dX -> dX .+ = view (dY, ind... ),
307+ @thunk project (@allowscalar dY[ind... ])
308+ )
313309 end
314310 return (NoTangent (), dXs... )
315311 end
322318
323319function rrule (:: typeof (reduce), :: typeof (vcat), As:: AbstractVector{<:AbstractVecOrMat} )
324320 Y = reduce (vcat, As)
321+ Base. require_one_based_indexing (Y)
325322 ndimsY = Val (ndims (Y))
326323 heights = map (A -> size (A,1 ), As)
327324 function reduce_vcat_pullback (dY)
349346
350347function rrule (:: typeof (cat), Xs... ; dims)
351348 Y = cat (Xs... ; dims= dims)
349+ Base. require_one_based_indexing (Y)
352350 cdims = dims isa Val ? Int (_val (dims)) : dims isa Integer ? Int (dims) : Tuple (dims)
353351 ndimsY = Val (ndims (Y))
354352 sizes = map (_catsize, Xs)
@@ -368,13 +366,10 @@ function rrule(::typeof(cat), Xs...; dims)
368366 for d in cdims
369367 prev[d] += get (sizeX, d, 1 )
370368 end
371- dX = if ndimsX > 0
372- # InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...))
373- dY[index... ]
374- else
375- sum (view (dY, index... ))
376- end
377- return project (dX)
369+ InplaceableThunk (
370+ dX -> dX .+ = view (dY, index... ),
371+ @thunk project (@allowscalar dY[index... ])
372+ )
378373 end
379374 return (NoTangent (), dXs... )
380375 end
0 commit comments