11using System ;
22using System . Collections . Generic ;
3+ using System . Linq ;
34using System . Text ;
45
56namespace Tensorflow
@@ -69,14 +70,15 @@ private void _init_from_args(object initial_value,
6970 {
7071
7172 }
73+ // Or get the initial value from a Tensor or Python object.
7274 else
7375 {
7476 _initial_value = ops . convert_to_tensor ( initial_value , name : "initial_value" ) ;
75- }
7677
77- var shape = _initial_value . shape ;
78- dtype = _initial_value . dtype ;
79- _variable = gen_state_ops . variable_v2 ( shape , dtype , name ) ;
78+ var shape = _initial_value . shape ;
79+ dtype = _initial_value . dtype ;
80+ _variable = gen_state_ops . variable_v2 ( shape , dtype , name ) ;
81+ }
8082
8183 // Manually overrides the variable's shape with the initial value's.
8284 if ( validate_shape )
@@ -87,8 +89,9 @@ private void _init_from_args(object initial_value,
8789 // If 'initial_value' makes use of other variables, make sure we don't
8890 // have an issue if these other variables aren't initialized first by
8991 // using their initialized_value() method.
92+ var _initial_value2 = _try_guard_against_uninitialized_dependencies ( _initial_value ) ;
9093
91- _initializer_op = gen_state_ops . assign ( _variable , _initial_value , validate_shape ) . op ;
94+ _initializer_op = gen_state_ops . assign ( _variable , _initial_value2 , validate_shape ) . op ;
9295
9396 if ( ! String . IsNullOrEmpty ( caching_device ) )
9497 {
@@ -112,5 +115,51 @@ public Tensor _AsTensor()
112115 {
113116 return _snapshot ;
114117 }
118+
119+ /// <summary>
120+ /// Attempt to guard against dependencies on uninitialized variables.
121+ /// </summary>
122+ /// <param name="initial_value"></param>
123+ private Tensor _try_guard_against_uninitialized_dependencies ( Tensor initial_value )
124+ {
125+ return _safe_initial_value_from_tensor ( initial_value , new Dictionary < string , Operation > ( ) ) ;
126+ }
127+
128+ /// <summary>
129+ /// Replace dependencies on variables with their initialized values.
130+ /// </summary>
131+ /// <param name="tensor">A `Tensor`. The tensor to replace.</param>
132+ /// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
133+ /// <returns>A `Tensor` compatible with `tensor`.</returns>
134+ private Tensor _safe_initial_value_from_tensor ( Tensor tensor , Dictionary < string , Operation > op_cache )
135+ {
136+ var op = tensor . op ;
137+ var new_op = op_cache . ContainsKey ( op . Name ) ? op_cache [ op . Name ] : null ;
138+ if ( new_op == null )
139+ {
140+ new_op = _safe_initial_value_from_op ( op , op_cache ) ;
141+ op_cache [ op . Name ] = new_op ;
142+ }
143+ return new_op . outputs [ tensor . value_index ] ;
144+ }
145+
146+ private Operation _safe_initial_value_from_op ( Operation op , Dictionary < string , Operation > op_cache )
147+ {
148+ var op_type = op . node_def . Op ;
149+ switch ( op_type )
150+ {
151+ case "IsVariableInitialized" :
152+ case "VarIsInitializedOp" :
153+ case "ReadVariableOp" :
154+ return op ;
155+ case "Variable" :
156+ case "VariableV2" :
157+ case "VarHandleOp" :
158+ break ;
159+ }
160+
161+ // Recursively build initializer expressions for inputs.
162+ return op ;
163+ }
115164 }
116165}
0 commit comments