Skip to content

Commit 734fe29

Browse files
committed
Be able to set Keras session options #712
1 parent 536293d commit 734fe29

File tree

10 files changed

+141
-27
lines changed

10 files changed

+141
-27
lines changed

docs/RELEASE.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Release Notes
2+
3+
**Thanks to our Contributors!**
4+
5+
This release contains contributions from many people at SciSharp as well as the external contributors.
6+
7+
**Release Date 01/09/2021**
8+
9+
### TensorFlow.Binding v0.32.0
10+
11+
* Fix input `dtype` for `MapDataset`.
12+
* Fix `image_dataset_from_directory` function.
13+
* Fix `tf.transpose`.
14+
* Add `array_ops.where_v2`, `array_ops.select_v2`, `array_ops.softplus`.
15+
* Add `dataset.dataset_cardinality`.
16+
17+
### TensorFlow.Keras v0.3.0
18+
19+
* Fix `weight` init value for `double` type in `compute_weighted_loss`.
20+
* Add `MeanSquaredError `, `MeanAbsolutePercentageError `, `MeanAbsoluteError` and `MeanSquaredLogarithmicError` loss functions.
21+
* `Sequential` model API works.
22+
* Add `ShellProgressBar` to show training progress better.
23+
24+
25+

src/TensorFlowNET.Core/Contexts/Context.Config.cs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System;
1818
using System.Diagnostics;
19+
using System.Linq;
1920

2021
namespace Tensorflow.Contexts
2122
{
@@ -24,24 +25,28 @@ namespace Tensorflow.Contexts
2425
/// </summary>
2526
public sealed partial class Context
2627
{
27-
ConfigProto _config;
28-
29-
ConfigProto config()
28+
public ConfigProto Config { get; set; } = new ConfigProto
3029
{
31-
var config = new ConfigProto()
30+
GpuOptions = new GPUOptions
3231
{
33-
LogDevicePlacement = _log_device_placement,
34-
GpuOptions = _compute_gpu_options()
35-
};
32+
}
33+
};
3634

37-
return config;
35+
ConfigProto MergeConfig()
36+
{
37+
Config.LogDevicePlacement = _log_device_placement;
38+
// var gpu_options = _compute_gpu_options();
39+
// Config.GpuOptions.AllowGrowth = gpu_options.AllowGrowth;
40+
return Config;
3841
}
3942

4043
GPUOptions _compute_gpu_options()
4144
{
45+
// By default, TensorFlow maps nearly all of the GPU memory of all GPUs
46+
// https://www.tensorflow.org/guide/gpu
4247
return new GPUOptions()
4348
{
44-
49+
AllowGrowth = get_memory_growth("GPU")
4550
};
4651
}
4752
}

src/TensorFlowNET.Core/Contexts/Context.Device.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ limitations under the License.
2020
using Tensorflow.Eager;
2121
using static Tensorflow.Binding;
2222
using Google.Protobuf;
23+
using Tensorflow.Device;
24+
using System.Collections.Generic;
2325

2426
namespace Tensorflow.Contexts
2527
{
@@ -30,6 +32,7 @@ public sealed partial class Context
3032
{
3133
ContextDevicePlacementPolicy _device_policy;
3234
bool _log_device_placement;
35+
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();
3336

3437
public void log_device_placement(bool enable)
3538
{
@@ -38,5 +41,53 @@ public void log_device_placement(bool enable)
3841
_log_device_placement = enable;
3942
// _thread_local_data.function_call_options = null;
4043
}
44+
45+
public bool get_memory_growth(string device_type)
46+
{
47+
foreach(var map in _memory_growth_map)
48+
{
49+
if (map.Key.DeviceType == device_type)
50+
return map.Value;
51+
}
52+
return false;
53+
}
54+
55+
public void set_memory_growth(PhysicalDevice device, bool enable)
56+
{
57+
_memory_growth_map[device] = enable;
58+
}
59+
60+
public PhysicalDevice[] list_physical_devices(string device_type = null)
61+
{
62+
using var opts = c_api.TFE_NewContextOptions();
63+
using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle);
64+
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle);
65+
tf.Status.Check(true);
66+
67+
int num_devices = c_api.TF_DeviceListCount(devices);
68+
var results = new List<PhysicalDevice>();
69+
for (int i = 0; i < num_devices; ++i)
70+
{
71+
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle));
72+
tf.Status.Check(true);
73+
74+
if (dev_type.StartsWith("XLA"))
75+
continue;
76+
77+
if (device_type == null || dev_type == device_type)
78+
{
79+
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle);
80+
tf.Status.Check(true);
81+
82+
results.Add(new PhysicalDevice
83+
{
84+
DeviceName = dev_name,
85+
DeviceType = dev_type
86+
});
87+
}
88+
}
89+
90+
return results.ToArray();
91+
}
4192
}
4293
}

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ public void ensure_initialized()
5757
if (initialized)
5858
return;
5959

60-
_config = config();
61-
var config_str = _config.ToByteArray();
62-
60+
Config = MergeConfig();
61+
FunctionCallOptions.Config = Config;
62+
var config_str = Config.ToByteArray();
6363
using var opts = new ContextOptions();
6464
using var status = new Status();
6565
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
@@ -82,7 +82,9 @@ public void end_step()
8282
/// <returns></returns>
8383
[DebuggerStepThrough]
8484
public bool executing_eagerly()
85-
=> context_switches.Current().EagerMode;
85+
{
86+
return context_switches.Current().EagerMode;
87+
}
8688

8789
public bool is_build_function()
8890
=> context_switches.Current().IsBuildingFunction;

src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,17 @@
22
using System.Collections.Generic;
33
using System.Text;
44
using Google.Protobuf;
5-
using Google.Protobuf.Collections;
5+
using static Tensorflow.Binding;
66

77
namespace Tensorflow.Contexts
88
{
99
public class FunctionCallOptions
1010
{
11+
public ConfigProto Config { get; set; }
12+
1113
public string config_proto_serialized()
1214
{
13-
var config = new ConfigProto
14-
{
15-
AllowSoftPlacement = true,
16-
};
17-
return config.ToByteString().ToStringUtf8();
15+
return Config.ToByteString().ToStringUtf8();
1816
}
1917
}
2018
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Device
6+
{
7+
public class PhysicalDevice
8+
{
9+
public string DeviceName { get; set; }
10+
public string DeviceType { get; set; }
11+
12+
public override string ToString()
13+
=> $"{DeviceType}: {DeviceName}";
14+
}
15+
}

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,13 @@ public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals
380380
[DllImport(TensorFlowLibName)]
381381
public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);
382382

383+
/// <summary>
384+
/// Clears the internal caches in the TFE context. Useful when reseeding random ops.
385+
/// </summary>
386+
/// <param name="ctx">TFE_Context*</param>
387+
[DllImport(TensorFlowLibName)]
388+
public static extern void TFE_ContextClearCaches(SafeContextHandle ctx);
389+
383390
/// <summary>
384391
///
385392
/// </summary>
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using static Tensorflow.Binding;
5+
using Tensorflow.Device;
46

57
namespace Tensorflow.Framework
68
{
79
public class ConfigImpl
810
{
11+
/// <summary>
12+
/// Return a list of physical devices visible to the host runtime.
13+
/// </summary>
14+
/// <param name="device_type">CPU, GPU, TPU</param>
15+
/// <returns></returns>
16+
public PhysicalDevice[] list_physical_devices(string device_type = null)
17+
=> tf.Context.list_physical_devices(device_type: device_type);
918

19+
public Experimental experimental => new Experimental();
20+
21+
public class Experimental
22+
{
23+
public void set_memory_growth(PhysicalDevice device, bool enable)
24+
=> tf.Context.set_memory_growth(device, enable);
25+
}
1026
}
1127
}

src/TensorFlowNET.Keras/Datasets/Cifar10.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,8 @@ public DatasetPass load_data()
124124
string Download()
125125
{
126126
var dst = Path.Combine(Path.GetTempPath(), dest_folder);
127-
Directory.CreateDirectory(dst);
128-
129127
Web.Download(origin_folder + file_name, dst, file_name);
130-
Compress.ExtractTGZ(Path.Combine(Path.GetTempPath(), file_name), dst);
128+
Compress.ExtractTGZ(Path.Combine(dst, file_name), dst);
131129

132130
return Path.Combine(dst, "cifar-10-batches-py");
133131
}

src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,8 @@ public class GpuLeakByCNN
1616
[Benchmark]
1717
public void Run()
1818
{
19-
tf.debugging.set_log_device_placement(true);
20-
21-
var a = tf.constant(3.0);
22-
var b = tf.constant(2.0);
23-
var c = tf.multiply(a, b);
19+
// tf.debugging.set_log_device_placement(true);
20+
tf.Context.Config.GpuOptions.AllowGrowth = true;
2421

2522
int num = 50, width = 64, height = 64;
2623
// if width = 128, height = 128, the exception occurs faster

0 commit comments

Comments
 (0)