@@ -16,16 +16,23 @@ public class FuncGraph : Graph
1616 Graph outer_graph ;
1717 public Graph OuterGraph => outer_graph ;
1818
19- string func_name ;
20-
2119 // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
2220 IntPtr func_handle ;
23- public string FuncName => func_name ;
21+ public string FuncName => _graph_key ;
2422
2523 public Tensors Inputs { get ; set ; }
2624 public Tensors Outputs { get ; set ; }
2725 public Dictionary < string , string > Attrs { get ; set ; }
2826
27+ public Dictionary < long , ( Tensor , Tensor ) > _captures
28+ = new Dictionary < long , ( Tensor , Tensor ) > ( ) ;
29+
30+ public Tensor [ ] external_captures ( )
31+ => _captures . Select ( x => x . Value . Item1 ) . ToArray ( ) ;
32+
33+ public Tensor [ ] internal_captures ( )
34+ => _captures . Select ( x => x . Value . Item2 ) . ToArray ( ) ;
35+
2936 // new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
3037 // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray();
3138
@@ -35,7 +42,7 @@ public class FuncGraph : Graph
3542 public FuncGraph ( string name ) : base ( )
3643 {
3744 outer_graph = ops . get_default_graph ( ) ;
38- func_name = name ;
45+ _graph_key = name ;
3946
4047 tf . Context . graph_mode ( ) ;
4148 as_default ( ) ;
@@ -44,7 +51,7 @@ public FuncGraph(string name) : base()
4451 public FuncGraph ( IntPtr handle , string name , Dictionary < string , string > attrs ) : base ( )
4552 {
4653 outer_graph = ops . get_default_graph ( ) ;
47- func_name = name ;
54+ _graph_key = name ;
4855 Attrs = attrs ;
4956 // Will to test if FuncGraph has memory leak
5057 // c_api.TF_DeleteGraph(_handle);
@@ -60,7 +67,7 @@ public IntPtr ToGraph(Operation[] opers,
6067 {
6168 using var status = new Status ( ) ;
6269 func_handle = c_api . TF_GraphToFunction ( _handle ,
63- func_name ,
70+ _graph_key ,
6471 false ,
6572 opers . Length ,
6673 opers . Select ( x => ( IntPtr ) x ) . ToArray ( ) ,
@@ -82,7 +89,7 @@ public IntPtr ToGraph(Operation[] opers,
8289 c_api . TFE_ContextAddFunction ( tf . Context . Handle , func_handle , status . Handle ) ;
8390 status . Check ( true ) ;
8491
85- func_name = c_api . StringPiece ( c_api . TF_FunctionName ( func_handle ) ) ;
92+ _graph_key = c_api . StringPiece ( c_api . TF_FunctionName ( func_handle ) ) ;
8693
8794 Inputs = inputs ;
8895 // mark_as_return
@@ -131,7 +138,7 @@ Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataTyp
131138 Tensor _capture_helper ( Tensor tensor , string name , TensorShape shape = null )
132139 {
133140 Tensor placeholder = null ;
134- if ( ! _captures . Contains ( tensor . Id ) )
141+ if ( ! _captures . ContainsKey ( tensor . Id ) )
135142 {
136143 placeholder = _create_substitute_placeholder ( tensor ,
137144 name : name ,
@@ -141,7 +148,7 @@ Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
141148 }
142149 else
143150 {
144- placeholder = ( ( ( Tensor , Tensor ) ) _captures [ tensor . Id ] ) . Item2 ;
151+ placeholder = _captures [ tensor . Id ] . Item2 ;
145152 }
146153
147154 BackwardFunction _backward_function_wrapper = ( output_grads , unneeded_gradients ) =>
@@ -160,7 +167,7 @@ Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
160167
161168 void add_capture ( Tensor tensor , Tensor placeholder )
162169 {
163- _captures [ tensor . Id ] = ( tensor , placeholder ) ;
170+ _captures . Add ( tensor . Id , ( tensor , placeholder ) ) ;
164171 if ( Inputs == null )
165172 Inputs = new Tensors ( placeholder ) ;
166173 else
0 commit comments