@@ -68,7 +68,11 @@ public Tensor __call__(Tensor inputs,
6868 throw new NotImplementedException ( "" ) ;
6969 }
7070
71- protected virtual void add_weight ( )
71+ protected virtual void add_weight ( string name ,
72+ int [ ] shape ,
73+ TF_DataType dtype = TF_DataType . DtInvalid ,
74+ IInitializer initializer = null ,
75+ bool ? trainable = null )
7276 {
7377 var default_graph = ops . get_default_graph ( ) ;
7478 Graph init_graph = null ;
@@ -84,7 +88,9 @@ protected virtual void add_weight()
8488 existing_variables = variables . global_variables ( ) . ToArray ( ) ;
8589 }
8690
87- var dtype = TF_DataType . TF_FLOAT ;
91+ if ( dtype == TF_DataType . DtInvalid )
92+ dtype = TF_DataType . TF_FLOAT ;
93+
8894 _set_scope ( ) ;
8995 var reuse = built || ( _reuse != null && _reuse . Value ) ;
9096 Python . with ( tf . variable_scope ( _scope ,
@@ -94,8 +100,19 @@ protected virtual void add_weight()
94100 _current_scope = scope ;
95101 Python . with ( ops . name_scope ( _name_scope ( ) ) , delegate
96102 {
97-
98-
103+ base . add_weight ( name ,
104+ shape ,
105+ dtype : dtype ,
106+ initializer : initializer ,
107+ trainable : trainable ,
108+ getter : ( name1 , shape1 , dtype1 , initializer1 , trainable1 ) =>
109+ {
110+ return tf . get_variable ( name1 ,
111+ shape : new TensorShape ( shape1 ) ,
112+ dtype : dtype1 ,
113+ initializer : initializer1 ,
114+ trainable : trainable1 ) ;
115+ } ) ;
99116 } ) ;
100117 } ) ;
101118 }
0 commit comments