Skip to content

Commit 9a58ebb

Browse files
committed
BaseSession: feeddict values are not auto-converted to NDArray any more (was a waste of time and memory)
1 parent c3ae8b1 commit 9a58ebb

File tree

1 file changed

+89
-85
lines changed

1 file changed

+89
-85
lines changed

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 89 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
1-
/*****************************************************************************
2-
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3-
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
7-
8-
http://www.apache.org/licenses/LICENSE-2.0
9-
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
1515
******************************************************************************/
1616

1717
using NumSharp;
1818
using System;
1919
using System.Collections;
2020
using System.Collections.Generic;
2121
using System.Linq;
22+
using System.Numerics;
2223
using System.Runtime.InteropServices;
2324
using System.Text;
2425

@@ -31,26 +32,26 @@ public class BaseSession
3132
protected bool _closed;
3233
protected int _current_version;
3334
protected byte[] _target;
34-
protected IntPtr _session;
35-
public Status Status;
35+
protected IntPtr _session;
36+
public Status Status;
3637
public Graph graph => _graph;
3738

3839
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
39-
{
40+
{
4041
_graph = g is null ? ops.get_default_graph() : g;
4142

4243
_target = UTF8Encoding.UTF8.GetBytes(target);
4344

4445
SessionOptions newOpts = null;
45-
if (opts == null)
46+
if (opts == null)
4647
newOpts = c_api.TF_NewSessionOptions();
4748

4849
Status = new Status();
4950

5051
_session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status);
5152

5253
// dispose newOpts
53-
if (opts == null)
54+
if (opts == null)
5455
c_api.TF_DeleteSessionOptions(newOpts);
5556

5657
Status.Check(true);
@@ -63,7 +64,7 @@ public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
6364

6465
public virtual NDArray run(object fetches, Hashtable feed_dict = null)
6566
{
66-
var feed_items = feed_dict == null ? new FeedItem[0] :
67+
var feed_items = feed_dict == null ? new FeedItem[0] :
6768
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
6869
return _run(fetches, feed_items);
6970
}
@@ -86,57 +87,8 @@ private NDArray _run(object fetches, FeedItem[] feed_dict = null)
8687
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
8788
{
8889
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
89-
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
90-
91-
switch (subfeed_val)
92-
{
93-
case IntPtr val:
94-
feed_dict_tensor[subfeed_t] = val;
95-
break;
96-
case NDArray val:
97-
feed_dict_tensor[subfeed_t] = val;
98-
break;
99-
case float val:
100-
feed_dict_tensor[subfeed_t] = (NDArray)val;
101-
break;
102-
case double val:
103-
feed_dict_tensor[subfeed_t] = (NDArray)val;
104-
break;
105-
case short val:
106-
feed_dict_tensor[subfeed_t] = (NDArray)val;
107-
break;
108-
case int val:
109-
feed_dict_tensor[subfeed_t] = (NDArray)val;
110-
break;
111-
case long val:
112-
feed_dict_tensor[subfeed_t] = (NDArray)val;
113-
break;
114-
case long[] val:
115-
feed_dict_tensor[subfeed_t] = (NDArray)val;
116-
break;
117-
case int[] val:
118-
feed_dict_tensor[subfeed_t] = (NDArray)val;
119-
break;
120-
case string val:
121-
feed_dict_tensor[subfeed_t] = (NDArray)val;
122-
break;
123-
case byte[] val:
124-
feed_dict_tensor[subfeed_t] = np.array(val);
125-
break;
126-
case char[] val:
127-
feed_dict_tensor[subfeed_t] = (NDArray)val;
128-
break;
129-
case bool val:
130-
feed_dict_tensor[subfeed_t] = (NDArray)val;
131-
break;
132-
case bool[] val:
133-
feed_dict_tensor[subfeed_t] = (NDArray)val;
134-
break;
135-
default:
136-
Console.WriteLine($"can't handle data type of subfeed_val");
137-
throw new NotImplementedException("_run subfeed");
138-
}
139-
90+
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
91+
feed_dict_tensor[subfeed_t] = subfeed_val;
14092
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
14193
}
14294
}
@@ -175,26 +127,78 @@ private NDArray _run(object fetches, FeedItem[] feed_dict = null)
175127
/// </returns>
176128
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
177129
{
178-
var feeds = feed_dict.Select(x =>
130+
var feeds = feed_dict.Select(x =>
179131
{
180132
if (x.Key is Tensor tensor)
181133
{
182134
switch (x.Value)
183135
{
184-
case IntPtr pointer:
185-
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer);
186-
case Tensor t1:
187-
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
188-
case NDArray nd:
189-
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd, tensor.dtype));
190-
case int intVal:
191-
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
192-
case float floatVal:
193-
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
194-
case double doubleVal:
195-
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
136+
#if _REGEN
137+
%types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
138+
%foreach types%
139+
case #1 v:
140+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
141+
case #1[] v:
142+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
143+
%
144+
#else
145+
case sbyte v:
146+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
147+
case sbyte[] v:
148+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
149+
case byte v:
150+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
151+
case byte[] v:
152+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
153+
case short v:
154+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
155+
case short[] v:
156+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
157+
case ushort v:
158+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
159+
case ushort[] v:
160+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
161+
case int v:
162+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
163+
case int[] v:
164+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
165+
case uint v:
166+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
167+
case uint[] v:
168+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
169+
case long v:
170+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
171+
case long[] v:
172+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
173+
case ulong v:
174+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
175+
case ulong[] v:
176+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
177+
case float v:
178+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
179+
case float[] v:
180+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
181+
case double v:
182+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
183+
case double[] v:
184+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
185+
case Complex v:
186+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
187+
case Complex[] v:
188+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
189+
#endif
190+
case bool v:
191+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
192+
case string v:
193+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
194+
case IntPtr v:
195+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
196+
case Tensor v:
197+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
198+
case NDArray v:
199+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
196200
default:
197-
throw new NotImplementedException("feed_dict data type");
201+
throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
198202
}
199203
}
200204
throw new NotImplementedException("_do_run.feed_dict");

0 commit comments

Comments
 (0)