@@ -132,6 +132,7 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
132132 if ( fused )
133133 {
134134 outputs = _fused_batch_norm ( inputs , training : training ) ;
135+ return outputs ;
135136 }
136137
137138 throw new NotImplementedException ( "BatchNormalization call" ) ;
@@ -142,7 +143,7 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
142143 var beta = this . beta ;
143144 var gamma = this . gamma ;
144145
145- Func < ( Tensor , Tensor , Tensor ) > _fused_batch_norm_training = ( ) =>
146+ Func < Tensor [ ] > _fused_batch_norm_training = ( ) =>
146147 {
147148 return tf . nn . fused_batch_norm (
148149 inputs ,
@@ -152,7 +153,7 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
152153 data_format : _data_format ) ;
153154 } ;
154155
155- Func < ( Tensor , Tensor , Tensor ) > _fused_batch_norm_inference = ( ) =>
156+ Func < Tensor [ ] > _fused_batch_norm_inference = ( ) =>
156157 {
157158 return tf . nn . fused_batch_norm (
158159 inputs ,
@@ -165,9 +166,41 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
165166 data_format : _data_format ) ;
166167 } ;
167168
168- tf_utils . smart_cond ( training , _fused_batch_norm_training , _fused_batch_norm_inference ) ;
169+ var results = tf_utils . smart_cond ( training , _fused_batch_norm_training , _fused_batch_norm_inference ) ;
170+ var ( output , mean , variance ) = ( results [ 0 ] , results [ 1 ] , results [ 2 ] ) ;
171+ var training_value = tf_utils . constant_value ( training ) ;
169172
170- throw new NotImplementedException ( "_fused_batch_norm" ) ;
173+ Tensor momentum_tensor ;
174+ if ( training_value == null )
175+ {
176+ momentum_tensor = tf_utils . smart_cond ( training ,
177+ ( ) => new float [ ] { momentum } , ( ) => new float [ ] { 1.0f } ) [ 0 ] ;
178+ }
179+ else
180+ {
181+ momentum_tensor = ops . convert_to_tensor ( momentum ) ;
182+ }
183+
184+ if ( training_value == null )
185+ {
186+ var mean_update = _assign_moving_average ( moving_mean , mean , momentum_tensor ) ;
187+ var variance_update = _assign_moving_average ( moving_variance , variance , momentum_tensor ) ;
188+ add_update ( new Tensor [ ] { mean_update } , inputs : true ) ;
189+ add_update ( new Tensor [ ] { variance_update } , inputs : true ) ;
190+ }
191+
192+ return output ;
193+ }
194+
195+ public Tensor _assign_moving_average ( RefVariable variable , Tensor value , Tensor momentum )
196+ {
197+ return Python . with ( ops . name_scope ( null , "AssignMovingAvg" , new { variable , value , momentum } ) , scope =>
198+ {
199+ // var cm = ops.colocate_with(variable);
200+ var decay = ops . convert_to_tensor ( 1.0f - momentum , name : "decay" ) ;
201+ var update_delta = ( variable - math_ops . cast ( value , variable . dtype ) ) * decay ;
202+ return state_ops . assign_sub ( variable , update_delta , name : scope ) ;
203+ } ) ;
171204 }
172205 }
173206}
0 commit comments