@@ -52,6 +52,7 @@ public class Layer : AutoTrackable
5252 protected InputSpec input_spec ;
5353 protected bool supports_masking ;
5454 protected List < VariableV1 > _trainable_weights ;
55+ protected List < VariableV1 > _non_trainable_weights ;
5556 private string _name ;
5657 public string name => _name ;
5758 protected string _base_name ;
@@ -84,6 +85,7 @@ public Layer(bool trainable = true,
8485
8586 _init_set_name ( name ) ;
8687 _trainable_weights = new List < VariableV1 > ( ) ;
88+ _non_trainable_weights = new List < VariableV1 > ( ) ;
8789 _compute_previous_mask = false ;
8890 _updates = new List < Operation > ( ) ;
8991
@@ -103,6 +105,7 @@ public Layer(bool trainable = true,
103105
104106 public ( Tensor , Tensor ) __call__ ( Tensor [ ] inputs ,
105107 Tensor training = null ,
108+ Tensor state = null ,
106109 VariableScope scope = null )
107110 {
108111 var input_list = inputs ;
@@ -139,7 +142,9 @@ public Layer(bool trainable = true,
139142 // overridden).
140143 _maybe_build ( inputs [ 0 ] ) ;
141144
142- ( input , outputs ) = call ( inputs [ 0 ] , training : training ) ;
145+ ( input , outputs ) = call ( inputs [ 0 ] ,
146+ training : training ,
147+ state : state ) ;
143148 ( input , outputs ) = _set_connectivity_metadata_ ( input , outputs ) ;
144149 _handle_activity_regularization ( inputs [ 0 ] , outputs ) ;
145150 _set_mask_metadata ( inputs [ 0 ] , outputs , null ) ;
@@ -173,7 +178,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
173178 return null ;
174179 }
175180
176- protected virtual ( Tensor , Tensor ) call ( Tensor inputs , Tensor training = null )
181+ protected virtual ( Tensor , Tensor ) call ( Tensor inputs , Tensor training = null , Tensor state = null )
177182 {
178183 return ( inputs , inputs ) ;
179184 }
@@ -233,7 +238,10 @@ protected virtual VariableV1 add_weight(string name,
233238 initializer : initializer ,
234239 trainable : trainable . Value ) ;
235240 //backend.track_variable(variable);
236- _trainable_weights . Add ( variable ) ;
241+ if ( trainable == true )
242+ _trainable_weights . Add ( variable ) ;
243+ else
244+ _non_trainable_weights . Add ( variable ) ;
237245
238246 return variable ;
239247 }
0 commit comments