Skip to content

Commit b2d8980

Browse files
author
Milly Fowden
committed
Fix --use_on_demand flag and add test
Summary: --use_on_demand's previous impl was not equivalent to ON_DEMAND connection type: it would try to connect to IPUs before compilation rather than after. This is because it wasn't overriding the config's device connection type, so the program still thought we were using ALWAYS in some places. This diff changes it so that the flag makes us override the connection type in our config to ON_DEMAND at configuration time. The rest is free. Done at C++ level so users of the kernel directly can also use the flag. This means the flag will now make TF connect to IPUs after compilation. Ref T69999 Test Plan: Added test to check our internal config is correct when not using and using the flag. Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, krzysztofk, gauthamg, mateuszk Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, krzysztofk, gauthamg, mateuszk Subscribers: harrym Maniphest Tasks: T69999 Differential Revision: https://phabricator.sourcevertex.net/D80874
1 parent 8694236 commit b2d8980

File tree

3 files changed

+104
-6
lines changed

3 files changed

+104
-6
lines changed

tensorflow/compiler/plugin/poplar/driver/poplar_executor.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,8 +1562,7 @@ Status PoplarExecutor::AttachToPoplarDevice() {
15621562
}
15631563

15641564
const bool wait_for_device =
1565-
ConnectionType() == IpuDeviceConnectionType::ON_DEMAND ||
1566-
PoplarXlaFlags::Get().use_on_demand;
1565+
ConnectionType() == IpuDeviceConnectionType::ON_DEMAND;
15671566
const bool use_ipu_model = PoplarXlaFlags::Get().use_ipu_model;
15681567

15691568
try {
@@ -1802,10 +1801,21 @@ Status PoplarExecutor::CreatePoplarTarget() {
18021801

18031802
Status PoplarExecutor::ConfigurePoplarDevice(const IpuOptions& cfg) {
18041803
TENSORFLOW_TRACEPOINT();
1804+
1805+
// Override config options with global flags.
1806+
IpuOptions overridden_cfg = cfg;
1807+
if (PoplarXlaFlags::Get().use_on_demand) {
1808+
VLOG(1) << "Overriding device connection type to ON_DEMAND due to "
1809+
"--use_on_demand flag";
1810+
overridden_cfg.set_device_connection_type(
1811+
IpuDeviceConnectionType::ON_DEMAND);
1812+
}
1813+
18051814
bool has_user_config = (current_config_.device_config_size() > 0);
1806-
if (!DeviceConfigurationsEqual(cfg, current_config_) && has_user_config) {
1815+
if (!DeviceConfigurationsEqual(overridden_cfg, current_config_) &&
1816+
has_user_config) {
18071817
XLA_VLOG_LINES(1, "Current config: " + current_config_.DebugString() +
1808-
"\nNew config: " + cfg.DebugString());
1818+
"\nNew config: " + overridden_cfg.DebugString());
18091819
return FailedPrecondition(
18101820
"IPU system configuration has already been set in this process, "
18111821
"but it should have been reset automatically by the call to "
@@ -1822,12 +1832,13 @@ Status PoplarExecutor::ConfigurePoplarDevice(const IpuOptions& cfg) {
18221832
<< ordinal_ << " is already configured: staying attached to it.";
18231833
}
18241834
}
1825-
current_config_ = cfg;
1835+
current_config_ = overridden_cfg;
18261836
configured_ = true;
18271837

18281838
if (!ipu_.DeviceAttached()) {
18291839
TF_RETURN_IF_ERROR(CreatePoplarTarget());
1830-
if (cfg.device_connection_type() == IpuDeviceConnectionType::ALWAYS) {
1840+
if (overridden_cfg.device_connection_type() ==
1841+
IpuDeviceConnectionType::ALWAYS) {
18311842
TF_RETURN_IF_ERROR(AttachToPoplarDevice());
18321843
}
18331844
}

tensorflow/python/ipu/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,24 @@ tf_py_test(
10381038
],
10391039
)
10401040

1041+
tf_py_test(
1042+
name = "utils_use_on_demand_test",
1043+
size = "medium",
1044+
srcs = ["tests/utils_use_on_demand_test.py"],
1045+
# Shard count set equal to the number of tests as one process per test is
1046+
# required
1047+
shard_count = 2,
1048+
deps = [
1049+
"//tensorflow:tensorflow_py",
1050+
"//tensorflow/compiler/tests:xla_test",
1051+
"//tensorflow/python:framework_test_lib",
1052+
"//tensorflow/python/compiler/xla:compiler_py",
1053+
"//tensorflow/python/ipu:ipu_lib",
1054+
"//tensorflow/python/ipu/test_utils",
1055+
"//tensorflow/python/keras",
1056+
],
1057+
)
1058+
10411059
tf_py_test(
10421060
name = "utils_use_synthetic_data_for_test",
10431061
size = "medium",
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# =============================================================================
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import os
20+
21+
from tensorflow.compiler.plugin.poplar.driver.config_pb2 import IpuOptions
22+
from tensorflow.compiler.plugin.poplar.ops import gen_ipu_ops
23+
from tensorflow.python.client import session as session_lib
24+
from tensorflow.python.framework import ops
25+
from tensorflow.python.framework import test_util
26+
from tensorflow.python.ipu import config
27+
from tensorflow.python.ipu.utils import DeviceConnectionType
28+
from tensorflow.python.platform import googletest
29+
from tensorflow.python.platform import test
30+
import tensorflow.compat.v1 as tf
31+
tf.disable_v2_behavior()
32+
33+
34+
class UseOnDemandTest(test_util.TensorFlowTestCase):
35+
def _configure(self, configured_type, expected_type):
36+
cfg = config.IPUConfig()
37+
cfg.device_connection.type = configured_type
38+
cfg.configure_ipu_system()
39+
40+
# Get the current config.
41+
g = ops.Graph()
42+
with g.as_default():
43+
with ops.device("CPU"):
44+
with session_lib.Session(graph=g) as s:
45+
configurations = s.run(gen_ipu_ops.ipu_get_configuration())
46+
47+
self.assertEqual(len(configurations), 1)
48+
actual_cfg = IpuOptions()
49+
actual_cfg.ParseFromString(configurations[0])
50+
51+
self.assertEqual(actual_cfg.device_connection_type, expected_type)
52+
53+
@test_util.deprecated_graph_mode_only
54+
def testNoFlag(self):
55+
self._configure(DeviceConnectionType.ALWAYS,
56+
DeviceConnectionType.ALWAYS.value)
57+
58+
@test_util.deprecated_graph_mode_only
59+
def testFlag(self):
60+
flags = os.environ.get("TF_POPLAR_FLAGS", "") + ' --use_on_demand'
61+
with test.mock.patch.dict("os.environ", {"TF_POPLAR_FLAGS": flags}):
62+
self._configure(DeviceConnectionType.ALWAYS,
63+
DeviceConnectionType.ON_DEMAND.value)
64+
65+
66+
if __name__ == "__main__":
67+
os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=1' +
68+
os.environ.get('TF_XLA_FLAGS', ''))
69+
googletest.main()

0 commit comments

Comments
 (0)