@@ -105,21 +105,43 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
105105 precondition (
106106 input. shape [ positiveAxis] == offset. shape [ 0 ] ,
107107 " The number of features of the input and the offset doesn't match. " )
108- var ( offset, scale) = { x in ( x. offset, x. scale) } ( self )
109- if positiveAxis != input. rank - 1 {
110- var broadcastShape = TensorShape ( [ Int] ( repeating: 1 , count: input. rank) )
111- broadcastShape [ positiveAxis] = input. shape [ positiveAxis]
112- offset = offset. reshaped ( to: broadcastShape)
113-
114- scale = scale. reshaped ( to: broadcastShape)
115- }
108+ // var (offset, scale) = {x in (x.offset, x.scale) }(self)
109+ // if positiveAxis != input.rank - 1 {
110+ // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
111+ // broadcastShape[positiveAxis] = input.shape[positiveAxis]
112+ // offset = offset.reshaped(to: broadcastShape)
113+ // scale = scale.reshaped(to: broadcastShape)
114+ // }
115+ let offsetOriginal = self . offset
116+ let scaleOriginal = self . scale
117+ let ( offset, scale) = Self . _sr13263workaround ( offset: offsetOriginal,
118+ scale: scaleOriginal,
119+ input: input,
120+ positiveAxis: positiveAxis)
116121 switch Context . local. learningPhase {
117122 case . training:
118123 return doTraining ( input, offset: offset, scale: scale, axis: positiveAxis)
119124 case . inference:
120125 return doInference ( input, offset: offset, scale: scale)
121126 }
122127 }
128+
129+ @inline ( never)
130+ @differentiable ( reverse) // if the function is `public` or `internal`, the compiler crashes
131+ private static func _sr13263workaround(
132+ offset: Tensor < Scalar > ,
133+ scale: Tensor < Scalar > ,
134+ input: Tensor < Scalar > ,
135+ positiveAxis: Int
136+ ) -> ( Tensor < Scalar > , Tensor < Scalar > ) {
137+ if positiveAxis != input. rank - 1 {
138+ var broadcastShape = TensorShape ( [ Int] ( repeating: 1 , count: input. rank) )
139+ broadcastShape [ positiveAxis] = input. shape [ positiveAxis]
140+ return ( offset. reshaped ( to: broadcastShape) , scale. reshaped ( to: broadcastShape) )
141+ } else {
142+ return ( offset, scale)
143+ }
144+ }
123145
124146 private func doTraining(
125147 _ input: Tensor < Scalar > , offset: Tensor < Scalar > , scale: Tensor < Scalar > , axis: Int
0 commit comments