Skip to content

Commit 748c31f

Browse files
author
Frederik Mellbye
committed
Add popdist::run to prevent HST in TF2
Reviewers: christiana, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep Maniphest Tasks: T63781 Differential Revision: https://phabricator.sourcevertex.net/D72533
1 parent 7ccafa5 commit 748c31f

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

tensorflow/compiler/plugin/poplar/driver/poplar_executor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include <utility>
2626
#include <vector>
2727

28+
#include <popdist/backend.hpp>
2829
#include <poplar/DeviceManager.hpp>
2930
#include <poplar/IPUModel.hpp>
3031
#include <poplar/StreamCallback.hpp>
@@ -3788,7 +3789,7 @@ Status PoplarExecutor::ExecuteEngineImpl(se::DeviceMemoryBase* result_buffer,
37883789

37893790
// Run the main engine
37903791
current_engine_->enableExecutionProfiling();
3791-
current_engine_->run(PoplarProgramType::MAIN_SEQUENCE);
3792+
popdist::run(*current_engine_, PoplarProgramType::MAIN_SEQUENCE);
37923793

37933794
StopIOThreads();
37943795

tensorflow/compiler/plugin/poplar/tests/poprun_basic_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828

2929

3030
class PoprunBasicTest(test_util.TensorFlowTestCase): # pylint: disable=abstract-method
31+
@classmethod
32+
def setUpClass(cls):
33+
popdist.init()
34+
3135
@tu.test_uses_ipus(num_ipus=4)
3236
@test_util.deprecated_graph_mode_only
3337
def test_cross_replica_sum(self):

tensorflow/python/ipu/tests/poprun_replica_partitioning_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def cross_replica_reduce(self, grad):
7171

7272

7373
class PoprunReplicaPartitioningTest(test.TestCase):
74+
@classmethod
75+
def setUpClass(cls):
76+
popdist.init()
77+
7478
def _compare_partitioned_to_non_partitioned(self, stages, repeat_count,
7579
gradient_accumulation_count,
7680
dataset_fn, optimizer_fn):

0 commit comments

Comments
 (0)