@@ -34,6 +34,7 @@ public string Name
3434 public ConcreteFunction ( string name )
3535 {
3636 func_graph = new FuncGraph ( name ) ;
37+ func_graph . as_default ( ) ;
3738 }
3839
3940 public ConcreteFunction ( FuncGraph graph , Dictionary < string , string > attrs )
@@ -48,37 +49,36 @@ public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
4849 string func_name = $ "autograph_{ Guid . NewGuid ( ) } _{ func . Method . Name } ";
4950
5051 // IntPtr func_handle;
51- using ( var graph = new FuncGraph ( func_name ) )
52- {
53- var input = tf . placeholder ( dtype ) ;
54- var output = func ( input ) ;
55-
56- var opers = graph . _nodes_by_name . Values . Select ( x => x as Operation ) . ToArray ( ) ;
57- _handle = graph . ToGraph ( opers ,
58- new [ ] { input } ,
59- new [ ] { output } ,
60- null ) ;
61- }
52+ using var graph = new FuncGraph ( func_name ) ;
53+ graph . as_default ( ) ;
54+ var input = tf . placeholder ( dtype ) ;
55+ var output = func ( input ) ;
56+
57+ var opers = graph . _nodes_by_name . Values . Select ( x => x as Operation ) . ToArray ( ) ;
58+ _handle = graph . ToGraph ( opers ,
59+ new [ ] { input } ,
60+ new [ ] { output } ,
61+ null ) ;
6262 }
6363
6464 public ConcreteFunction ( Func < Tensor , IDatasetV2 > func , TF_DataType dtype )
6565 {
6666 string func_name = $ "autograph_{ Guid . NewGuid ( ) } _{ func . Method . Name } ";
6767
6868 // IntPtr func_handle;
69- using ( var graph = new FuncGraph ( func_name ) )
70- {
71- var input = tf . placeholder ( dtype ) ;
72- var output = func ( input ) ;
69+ using var graph = new FuncGraph ( func_name ) ;
70+ graph . as_default ( ) ;
7371
74- OutputStructure = output . structure ;
72+ var input = tf . placeholder ( dtype ) ;
73+ var output = func ( input ) ;
7574
76- var opers = graph . _nodes_by_name . Values . Select ( x => x as Operation ) . ToArray ( ) ;
77- _handle = graph . ToGraph ( opers ,
78- new [ ] { input } ,
79- new [ ] { output . variant_tensor } ,
80- null ) ;
81- }
75+ OutputStructure = output . structure ;
76+
77+ var opers = graph . _nodes_by_name . Values . Select ( x => x as Operation ) . ToArray ( ) ;
78+ _handle = graph . ToGraph ( opers ,
79+ new [ ] { input } ,
80+ new [ ] { output . variant_tensor } ,
81+ null ) ;
8282 }
8383
8484 public ConcreteFunction ( Func < Tensor , ( Tensor , Tensor ) , ( Tensor , Tensor ) > func ,
@@ -87,22 +87,22 @@ public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
8787 string func_name = $ "autograph_{ Guid . NewGuid ( ) } _{ func . Method . Name } ";
8888
8989 // IntPtr func_handle;
90- using ( var graph = new FuncGraph ( func_name ) )
91- {
92- var input1 = tf . placeholder ( dtypes [ 0 ] , shape : shapes [ 0 ] , name : "args" ) ;
93- var input2 = tf . placeholder ( dtypes [ 1 ] , shape : shapes [ 1 ] , name : "args" ) ;
94- var input3 = tf . placeholder ( dtypes [ 2 ] , shape : shapes [ 2 ] , name : "args" ) ;
95- var outputs = func ( input1 , ( input2 , input3 ) ) ;
96-
97- Outputs = new [ ] { outputs . Item1 , outputs . Item2 } ;
98- OutputStructure = new [ ] { outputs . Item1 . ToTensorSpec ( ) , outputs . Item2 . ToTensorSpec ( ) } ;
99-
100- var opers = graph . _nodes_by_name . Values . Select ( x => x as Operation ) . ToArray ( ) ;
101- _handle = graph . ToGraph ( opers ,
102- new [ ] { input1 , input2 , input3 } ,
103- new [ ] { outputs . Item1 , outputs . Item2 } ,
104- null ) ;
105- }
90+ using var graph = new FuncGraph ( func_name ) ;
91+ graph . as_default ( ) ;
92+
93+ var input1 = tf . placeholder ( dtypes [ 0 ] , shape : shapes [ 0 ] , name : "args" ) ;
94+ var input2 = tf . placeholder ( dtypes [ 1 ] , shape : shapes [ 1 ] , name : "args" ) ;
95+ var input3 = tf . placeholder ( dtypes [ 2 ] , shape : shapes [ 2 ] , name : "args" ) ;
96+ var outputs = func ( input1 , ( input2 , input3 ) ) ;
97+
98+ Outputs = new [ ] { outputs . Item1 , outputs . Item2 } ;
99+ OutputStructure = new [ ] { outputs . Item1 . ToTensorSpec ( ) , outputs . Item2 . ToTensorSpec ( ) } ;
100+
101+ var opers = graph . _nodes_by_name . Values . Select ( x => x as Operation ) . ToArray ( ) ;
102+ _handle = graph . ToGraph ( opers ,
103+ new [ ] { input1 , input2 , input3 } ,
104+ new [ ] { outputs . Item1 , outputs . Item2 } ,
105+ null ) ;
106106 }
107107
108108 public void ToGraph ( Tensors inputs , Tensors outputs )
0 commit comments