1- /*****************************************************************************
2- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3-
4- Licensed under the Apache License, Version 2.0 (the "License");
5- you may not use this file except in compliance with the License.
6- You may obtain a copy of the License at
7-
8- http://www.apache.org/licenses/LICENSE-2.0
9-
10- Unless required by applicable law or agreed to in writing, software
11- distributed under the License is distributed on an "AS IS" BASIS,
12- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13- See the License for the specific language governing permissions and
14- limitations under the License.
1+ /*****************************************************************************
2+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ http://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
1515******************************************************************************/
1616
1717using NumSharp ;
1818using System ;
1919using System . Collections ;
2020using System . Collections . Generic ;
2121using System . Linq ;
22+ using System . Numerics ;
2223using System . Runtime . InteropServices ;
2324using System . Text ;
2425
@@ -31,26 +32,26 @@ public class BaseSession
3132 protected bool _closed ;
3233 protected int _current_version ;
3334 protected byte [ ] _target ;
34- protected IntPtr _session ;
35- public Status Status ;
35+ protected IntPtr _session ;
36+ public Status Status ;
3637 public Graph graph => _graph ;
3738
3839 public BaseSession ( string target = "" , Graph g = null , SessionOptions opts = null )
39- {
40+ {
4041 _graph = g is null ? ops . get_default_graph ( ) : g ;
4142
4243 _target = UTF8Encoding . UTF8 . GetBytes ( target ) ;
4344
4445 SessionOptions newOpts = null ;
45- if ( opts == null )
46+ if ( opts == null )
4647 newOpts = c_api . TF_NewSessionOptions ( ) ;
4748
4849 Status = new Status ( ) ;
4950
5051 _session = c_api . TF_NewSession ( _graph , opts ?? newOpts , Status ) ;
5152
5253 // dispose newOpts
53- if ( opts == null )
54+ if ( opts == null )
5455 c_api . TF_DeleteSessionOptions ( newOpts ) ;
5556
5657 Status . Check ( true ) ;
@@ -63,7 +64,7 @@ public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
6364
6465 public virtual NDArray run ( object fetches , Hashtable feed_dict = null )
6566 {
66- var feed_items = feed_dict == null ? new FeedItem [ 0 ] :
67+ var feed_items = feed_dict == null ? new FeedItem [ 0 ] :
6768 feed_dict . Keys . OfType < object > ( ) . Select ( key => new FeedItem ( key , feed_dict [ key ] ) ) . ToArray ( ) ;
6869 return _run ( fetches , feed_items ) ;
6970 }
@@ -86,57 +87,8 @@ private NDArray _run(object fetches, FeedItem[] feed_dict = null)
8687 foreach ( var ( subfeed , subfeed_val ) in feed_fn ( feed ) )
8788 {
8889 var subfeed_t = _graph . as_graph_element ( subfeed , allow_tensor : true , allow_operation : false ) ;
89- var subfeed_dtype = subfeed_t . dtype . as_numpy_datatype ( ) ;
90-
91- switch ( subfeed_val )
92- {
93- case IntPtr val :
94- feed_dict_tensor [ subfeed_t ] = val ;
95- break ;
96- case NDArray val :
97- feed_dict_tensor [ subfeed_t ] = val ;
98- break ;
99- case float val :
100- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
101- break ;
102- case double val :
103- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
104- break ;
105- case short val :
106- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
107- break ;
108- case int val :
109- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
110- break ;
111- case long val :
112- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
113- break ;
114- case long [ ] val :
115- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
116- break ;
117- case int [ ] val :
118- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
119- break ;
120- case string val :
121- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
122- break ;
123- case byte [ ] val :
124- feed_dict_tensor [ subfeed_t ] = np . array ( val ) ;
125- break ;
126- case char [ ] val :
127- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
128- break ;
129- case bool val :
130- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
131- break ;
132- case bool [ ] val :
133- feed_dict_tensor [ subfeed_t ] = ( NDArray ) val ;
134- break ;
135- default :
136- Console . WriteLine ( $ "can't handle data type of subfeed_val") ;
137- throw new NotImplementedException ( "_run subfeed" ) ;
138- }
139-
90+ //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
91+ feed_dict_tensor [ subfeed_t ] = subfeed_val ;
14092 feed_map [ subfeed_t . name ] = ( subfeed_t , subfeed_val ) ;
14193 }
14294 }
@@ -175,26 +127,78 @@ private NDArray _run(object fetches, FeedItem[] feed_dict = null)
175127 /// </returns>
176128 private NDArray [ ] _do_run ( List < Operation > target_list , List < Tensor > fetch_list , Dictionary < object , object > feed_dict )
177129 {
178- var feeds = feed_dict . Select ( x =>
130+ var feeds = feed_dict . Select ( x =>
179131 {
180132 if ( x . Key is Tensor tensor )
181133 {
182134 switch ( x . Value )
183135 {
184- case IntPtr pointer :
185- return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , pointer ) ;
186- case Tensor t1 :
187- return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , t1 ) ;
188- case NDArray nd :
189- return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( nd , tensor . dtype ) ) ;
190- case int intVal :
191- return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( intVal ) ) ;
192- case float floatVal :
193- return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( floatVal ) ) ;
194- case double doubleVal :
195- return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( doubleVal ) ) ;
136+ #if _REGEN
137+ % types = [ "sbyte" , "byte" , "short" , "ushort" , "int" , "uint" , "long" , "ulong" , "float" , "double" , "Complex" ]
138+ % foreach types%
139+ case #1 v:
140+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
141+ case #1 [ ] v :
142+ return new KeyValuePair< TF_Output, Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
143+ %
144+ #else
145+ case sbyte v:
146+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
147+ case sbyte [ ] v:
148+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
149+ case byte v:
150+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
151+ case byte [ ] v:
152+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
153+ case short v:
154+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
155+ case short [ ] v:
156+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
157+ case ushort v:
158+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
159+ case ushort [ ] v:
160+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
161+ case int v:
162+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
163+ case int [ ] v:
164+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
165+ case uint v:
166+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
167+ case uint [ ] v:
168+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
169+ case long v:
170+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
171+ case long [ ] v:
172+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
173+ case ulong v:
174+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
175+ case ulong [ ] v:
176+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
177+ case float v:
178+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
179+ case float [ ] v:
180+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
181+ case double v:
182+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
183+ case double [ ] v:
184+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
185+ case Complex v:
186+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
187+ case Complex [ ] v:
188+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
189+ #endif
190+ case bool v:
191+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( ( byte ) ( v ? 1 : 0 ) , TF_DataType . TF_BOOL ) ) ;
192+ case string v:
193+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
194+ case IntPtr v:
195+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
196+ case Tensor v:
197+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , v ) ;
198+ case NDArray v:
199+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v , tensor . dtype ) ) ;
196200 default :
197- throw new NotImplementedException ( "feed_dict data type" ) ;
201+ throw new NotImplementedException ( $ "feed_dict data type { ( x . Value ? . GetType ( ) . Name ?? "<null>" ) } ") ;
198202 }
199203 }
200204 throw new NotImplementedException ( "_do_run.feed_dict" ) ;
0 commit comments