@@ -209,10 +209,14 @@ function frule((_, ẋs...), ::typeof(hcat), xs...)
209209 return hcat (xs... ), hcat (_instantiate_zeros (ẋs, xs)... )
210210end
211211
212- function rrule (:: typeof (hcat), Xs:: Union{AbstractArray, Number} ...)
212+ # All the [hv]cat functions treat anything that's not an array as a scalar.
213+ _catsize (x) = ()
214+ _catsize (x:: AbstractArray ) = size (x)
215+
216+ function rrule (:: typeof (hcat), Xs... )
213217 Y = hcat (Xs... ) # note that Y always has 1-based indexing, even if X isa OffsetArray
214218 ndimsY = Val (ndims (Y)) # this avoids closing over Y, Val() is essential for type-stability
215- sizes = map (size , Xs) # this avoids closing over Xs
219+ sizes = map (_catsize , Xs) # this avoids closing over Xs
216220 project_Xs = map (ProjectTo, Xs)
217221 function hcat_pullback (ȳ)
218222 dY = unthunk (ȳ)
@@ -279,10 +283,10 @@ function frule((_, ẋs...), ::typeof(vcat), xs...)
279283 return vcat (xs... ), vcat (_instantiate_zeros (ẋs, xs)... )
280284end
281285
282- function rrule (:: typeof (vcat), Xs:: Union{AbstractArray, Number} ...)
286+ function rrule (:: typeof (vcat), Xs... )
283287 Y = vcat (Xs... )
284288 ndimsY = Val (ndims (Y))
285- sizes = map (size , Xs)
289+ sizes = map (_catsize , Xs)
286290 project_Xs = map (ProjectTo, Xs)
287291 function vcat_pullback (ȳ)
288292 dY = unthunk (ȳ)
@@ -342,11 +346,11 @@ function frule((_, ẋs...), ::typeof(cat), xs...; dims)
342346 return cat (xs... ; dims), cat (_instantiate_zeros (ẋs, xs)... ; dims)
343347end
344348
345- function rrule (:: typeof (cat), Xs:: Union{AbstractArray, Number} ...; dims)
349+ function rrule (:: typeof (cat), Xs... ; dims)
346350 Y = cat (Xs... ; dims= dims)
347351 cdims = dims isa Val ? Int (_val (dims)) : dims isa Integer ? Int (dims) : Tuple (dims)
348352 ndimsY = Val (ndims (Y))
349- sizes = map (size , Xs)
353+ sizes = map (_catsize , Xs)
350354 project_Xs = map (ProjectTo, Xs)
351355 function cat_pullback (ȳ)
352356 dY = unthunk (ȳ)
@@ -384,11 +388,11 @@ function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...)
384388 return hvcat (rows, xs... ), hvcat (rows, _instantiate_zeros (ẋs, xs)... )
385389end
386390
387- function rrule (:: typeof (hvcat), rows, values:: Union{AbstractArray, Number} ...)
391+ function rrule (:: typeof (hvcat), rows, values... )
388392 Y = hvcat (rows, values... )
389393 cols = size (Y,2 )
390394 ndimsY = Val (ndims (Y))
391- sizes = map (size , values)
395+ sizes = map (_catsize , values)
392396 project_Vs = map (ProjectTo, values)
393397 function hvcat_pullback (dY)
394398 prev = fill (0 , 2 )
0 commit comments