@@ -129,22 +129,27 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint,
129129 Scalar. RawSignificand: FixedWidthInteger {
130130 /// Performs Glorot uniform initialization for the specified shape, creating a tensor by
131131 /// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
132- /// where limit is `sqrt(6 / (fanIn + fanOut))`.
132+ /// where limit is `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of
133+ /// input and output features multiplied by the receptive field if present.
133134 ///
134135 /// - Parameters:
135136 /// - shape: The dimensions of the tensor.
136137 /// - generator: Random number generator to use.
137138 ///
138139 init < G: RandomNumberGenerator > ( glorotUniform shape: TensorShape , generator: inout G ) {
139- let fanIn = shape [ shape. count - 2 ]
140- let fanOut = shape [ shape. count - 1 ]
140+ let spatialDimCount = shape. count - 2
141+ let receptiveField = shape [ 0 ..< spatialDimCount] . contiguousSize
142+ let fanIn = shape [ shape. count - 2 ] * receptiveField
143+ let fanOut = shape [ shape. count - 1 ] * receptiveField
141144 let minusOneToOne = 2 * Tensor( randomUniform: shape, generator: & generator) - 1
142145 self = sqrt ( Scalar ( 6 ) / Scalar( fanIn + fanOut) ) * minusOneToOne
143146 }
144147
145148 /// Creates a tensor by performing Glorot uniform initialization for the specified shape,
146149 /// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
147- /// where limit is `sqrt(6 / (fanIn + fanOut))`, using the default random number generator.
150+ /// generated by the default random number generator, where limit is
151+ /// `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of input and output
152+ /// features multiplied by the receptive field if present.
148153 ///
149154 /// - Parameters:
150155 /// - shape: The dimensions of the tensor.
0 commit comments