Skip to content

Commit ea3b734

Browse files
Updating TF to support 2 prng seeds.
Summary: Updating PoplarCompiler to generate a replica identical prng seed during compilation. This seed gets used by the BaseVisitor to set the appropriate StochasticRoundingMethod for each instruction as it gets lowered. Seeds are then switched as we encounter instructions that require different methods. The seed state is stored in CompilerResources to keep it consistent during lowering, since multiple visitors will be used to lower the whole program. Seed switching is behind the experimental_prng_stability flag and so disabled by default. Test Plan: CI Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep, jakeh, samuelh Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep, samuelh Subscribers: alfiee, samuelh Maniphest Tasks: T41157 Differential Revision: https://phabricator.sourcevertex.net/D51984
1 parent 10b2aa7 commit ea3b734

File tree

10 files changed

+188
-8
lines changed

10 files changed

+188
-8
lines changed

tensorflow/compiler/plugin/poplar/driver/compiler_resources.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#include "tensorflow/compiler/plugin/poplar/driver/compiler_information.h"
3737
#include "tensorflow/compiler/plugin/poplar/driver/config.pb.h"
3838
#include "tensorflow/compiler/plugin/poplar/driver/passes/convolution_classifier.h"
39+
#include "tensorflow/compiler/plugin/poplar/driver/prng_seed_state.h"
3940
#include "tensorflow/compiler/plugin/poplar/driver/tools/execution_counter_util.h"
4041
#include "tensorflow/compiler/plugin/poplar/driver/tools/generic_graph_caching.h"
4142
#include "tensorflow/compiler/plugin/poplar/driver/tools/mapping_helper.h"
@@ -186,6 +187,8 @@ struct CompilerResources {
186187
// The implementation of the progress bar.
187188
std::unique_ptr<ProgressBarBase> progress_bar;
188189

190+
PrngSeedState prng_seed_state;
191+
189192
CompilerResources(
190193
HloModule* module, const CompilerInformation& information,
191194
const poplar::OptionFlags& conv_options,

tensorflow/compiler/plugin/poplar/driver/poplar_compiler.cc

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,20 @@ limitations under the License.
2323

2424
#include <algorithm>
2525
#include <fstream>
26+
#include <gcl/Collectives.hpp>
2627
#include <gcl/TileAllocation.hpp>
2728
#include <limits>
2829
#include <mutex>
2930
#include <popfloat/experimental/codelets.hpp>
3031
#include <poplar/CSRFunctions.hpp>
3132
#include <poplar/CodeletFileType.hpp>
3233
#include <poplar/CycleCount.hpp>
34+
#include <poplar/RandomSeed.hpp>
3335
#include <poplar/exceptions.hpp>
3436
#include <poplar/replication_factor.hpp>
3537
#include <poplin/codelets.hpp>
3638
#include <popnn/codelets.hpp>
39+
#include <popops/Cast.hpp>
3740
#include <popops/codelets.hpp>
3841
#include <poprand/RandomGen.hpp>
3942
#include <poprand/codelets.hpp>
@@ -537,7 +540,7 @@ HloPrintOptions GetPrintOptions() {
537540
}
538541

539542
StatusOr<poplar::program::Program> InitializeSeed(
540-
poplar::Graph& graph, int replication_factor,
543+
poplar::Graph& graph, int replication_factor, CompilerResources& resources,
541544
const poplar::DebugContext& debug_context = {"__seed"}) {
542545
PoplarOpDefDebugInfo debug_info(debug_context, "InitializeSeed");
543546

@@ -561,7 +564,29 @@ StatusOr<poplar::program::Program> InitializeSeed(
561564
initializer.GetData(ShapeUtil::MakeShape(U32, {2})));
562565
TF_RETURN_IF_ERROR(SetInitialTensorValue(graph, seed, literal));
563566
}
564-
poprand::setSeed(graph, seed, 0, seq, {debug_info, "set"});
567+
568+
if (replication_factor > 1 && resources.enable_experimental_prng_stability) {
569+
// When running a replicated graph we need an additional seed, which is
570+
// identical for all IPUs. This gets used to ensure determinism when
571+
// performing stochastic rounding.
572+
poplar::Tensor identical_seed =
573+
// Using MEAN to preserve the seed value incase it's been explicitly set
574+
// to be the same for each replica.
575+
popops::cast(graph,
576+
gcl::allReduceCrossReplica(
577+
graph, popops::cast(graph, seed, poplar::FLOAT, seq),
578+
popops::CollectiveOperator::MEAN, seq,
579+
{debug_info, "allReduceSeed"}),
580+
poplar::UNSIGNED_INT, seq);
581+
582+
resources.prng_seed_state =
583+
PrngSeedState::SetupSeeds(graph, identical_seed, seed, seq);
584+
} else {
585+
resources.prng_seed_state = PrngSeedState::SetupSeed(graph, seed, seq);
586+
}
587+
588+
resources.prng_seed_state.ChangeStochasticRoundingMethod(
589+
StochasticRoundingMethod_DifferingSeeds, seq);
565590

566591
return seq;
567592
}
@@ -1679,6 +1704,11 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
16791704
TF_RETURN_IF_ERROR(CreatePoplarGraphs(resources, module, poplar_executor));
16801705
auto& main_graph = GetMasterGraph(resources);
16811706

1707+
// Set up the random seed.
1708+
TF_ASSIGN_OR_RETURN(
1709+
auto seed_setup,
1710+
InitializeSeed(main_graph, replication_factor, resources));
1711+
16821712
EntryVisitor visitor(resources, entry);
16831713
try {
16841714
Tracepoint tracepoint("PoplarGraphConstruction");
@@ -1720,9 +1750,6 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
17201750
main_program.add(poplar::program::Sync(poplar::SyncType::GLOBAL));
17211751
}
17221752

1723-
// Set up the random seed.
1724-
TF_ASSIGN_OR_RETURN(auto seed_setup,
1725-
InitializeSeed(main_graph, replication_factor));
17261753
main_program.add(seed_setup);
17271754

17281755
// Set up the floating point control register if required

tensorflow/compiler/plugin/poplar/driver/prng_seed_state.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,23 @@ std::unique_ptr<poputil::graphfn::TensorFunction> CreateChangeHwSeedsFn(
4545
}
4646
} // namespace
4747

48+
/*static*/ PrngSeedState PrngSeedState::SetupSeed(
49+
poplar::Graph& graph, poplar::Tensor& seed,
50+
poplar::program::Sequence& seq) {
51+
const poplar::DebugContext& debug_context = {"__seed"};
52+
PoplarOpDefDebugInfo debug_info(debug_context, "InitializeSeed");
53+
54+
poprand::setSeed(graph, seed, 0, seq, {debug_info, "set"});
55+
auto differing_hw_seed = poplar::getHwSeeds(graph, seq, {debug_info, "get"});
56+
57+
// We want the behaviour to be consistent whether we're running
58+
// with single or multiple seeds, so even when there's no replication
59+
// we pretend that there's a separate identical seed and do everything
60+
// else as normal.
61+
return PrngSeedState(graph, StochasticRoundingMethod_DifferingSeeds,
62+
differing_hw_seed, differing_hw_seed);
63+
}
64+
4865
/*static*/ PrngSeedState PrngSeedState::SetupSeeds(
4966
poplar::Graph& graph, poplar::Tensor& identical_seed,
5067
poplar::Tensor& differing_seed, poplar::program::Sequence& seq) {
@@ -57,7 +74,7 @@ std::unique_ptr<poputil::graphfn::TensorFunction> CreateChangeHwSeedsFn(
5774

5875
poprand::setSeed(graph, differing_seed, 0, seq, {debug_info, "setDistinct"});
5976
auto differing_hw_seed =
60-
poplar::getHwSeeds(graph, seq, {debug_info, "getIdenticalHw"});
77+
poplar::getHwSeeds(graph, seq, {debug_info, "getDistinctHw"});
6178

6279
// Speciifying DifferingSeeds since the last seed set was the differing one.
6380
return PrngSeedState(graph, StochasticRoundingMethod_DifferingSeeds,

tensorflow/compiler/plugin/poplar/driver/prng_seed_state.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ namespace poplarplugin {
3232
// flag.
3333
class PrngSeedState {
3434
public:
35+
// Create a PrngSeedState from single or multiple seeds. This corresponds
36+
// to whether we're running with replication (2 seeds) or not (1 seed).
37+
static PrngSeedState SetupSeed(poplar::Graph& graph, poplar::Tensor& seed,
38+
poplar::program::Sequence& seq);
3539
static PrngSeedState SetupSeeds(poplar::Graph& graph,
3640
poplar::Tensor& identical_seed,
3741
poplar::Tensor& differing_seed,

tensorflow/compiler/plugin/poplar/driver/visitors/visitor_arithmetic_expr.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ ArithmeticExprVisitor::ArithmeticExprVisitor(
4646
const poplar::DebugNameAndId& debug_name_and_id)
4747
: BaseVisitor(res, debug_name_and_id),
4848
inputs_(std::move(inputs)),
49-
caller_(caller) {}
49+
caller_(caller) {
50+
// We don need to change the seed as we visit instructions, since we will not
51+
// be executing them individually but as fused unit via a single poplar call.
52+
// Hence the seed we need to use will be set by the calling
53+
// instruction/visitor.
54+
allow_seed_changes_ = false;
55+
}
5056

5157
StatusOr<std::unique_ptr<popops::expr::Expr>>
5258
ArithmeticExprVisitor::FindExpressionInput(const HloInstruction* inst) {

tensorflow/compiler/plugin/poplar/driver/visitors/visitor_base.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,32 @@ Status BaseVisitor::Preprocess(HloInstruction* inst) {
8888
stochastic_rounding_enabled_ = new_stochastic_rounding_enabled;
8989
}
9090

91+
if (allow_seed_changes_ && stochastic_rounding_enabled_) {
92+
poplar::DebugNameAndId debug_name_and_id{"changeSRMethod"};
93+
poplar::program::Sequence seq({}, debug_name_and_id);
94+
95+
const auto new_sr_method =
96+
poplar_backend_config.stochastic_rounding_method();
97+
if (resources_.prng_seed_state.ChangeStochasticRoundingMethod(
98+
new_sr_method, seq, debug_name_and_id)) {
99+
AddSequenceForInstruction(inst, seq);
100+
}
101+
102+
VLOG(3) << "Using SR method "
103+
<< StochasticRoundingMethod_Name(
104+
resources_.prng_seed_state.GetStochasticRoundingMethod())
105+
<< " for instruction '" << inst->name() << "'";
106+
}
107+
91108
return Status::OK();
92109
}
93110

94111
BaseVisitor::BaseVisitor(CompilerResources& resources,
95112
const poplar::DebugNameAndId& debug_name_and_id)
96113
: resources_(resources),
97114
dnai_(debug_name_and_id),
98-
execution_counters_(resources, debug_name_and_id) {
115+
execution_counters_(resources, debug_name_and_id),
116+
allow_seed_changes_(resources.enable_experimental_prng_stability) {
99117
stochastic_rounding_enabled_ =
100118
resources_.global_floating_point_behaviour.esr();
101119

tensorflow/compiler/plugin/poplar/driver/visitors/visitor_base.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ class BaseVisitor : public DfsHloVisitor {
207207
// Scope execution counters.
208208
ExecutionCounters execution_counters_;
209209

210+
// Control whether changing seeds is allowed during instruction lowering,
211+
// used to improve prng stability.
212+
bool allow_seed_changes_ = false;
213+
210214
private:
211215
Status CreateSequenceGroupedByInstruction(
212216
const HloInstruction* inst, const poplar::program::Sequence& seq);

tensorflow/python/ipu/tests/functional_ops_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def body(x, labels):
258258
# pylint: disable=line-too-long
259259
ok = [
260260
'__seed/set/setMasterSeed',
261+
'__seed/get/getSeeds',
261262
'matmul/dot*/Conv_1',
262263
'add_0/fusion/Op/Add',
263264
'Sigmoid/sigmoid/Nonlinearity',

tensorflow/python/ipu/tests/multi_conv_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def inputs_fn():
311311
'ipu/Mean*/reduce*/Reduce',
312312
'ipu/add',
313313
'__seed/set/setMasterSeed',
314+
'__seed/get/getSeeds',
314315
'[cC]opy',
315316
]
316317
_compare_ipu_to_cpu(self,

tensorflow/python/ipu/tests/replicated_seed_control_test.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@
1818

1919
from tensorflow.python import ipu
2020
from tensorflow.python.client import session as session_lib
21+
from tensorflow.python.framework import constant_op
2122
from tensorflow.python.framework import test_util
23+
from tensorflow.python.framework import random_seed
24+
from tensorflow.python.data.ops import dataset_ops
2225
from tensorflow.python.ops import array_ops
2326
from tensorflow.python.ops import math_ops
27+
from tensorflow.python.ops import random_ops
28+
from tensorflow.compiler.plugin.poplar.ops import gen_popops_ops
2429
from tensorflow.compiler.plugin.poplar.tests import test_utils as tu
2530
from tensorflow.python.platform import googletest
2631

@@ -170,6 +175,100 @@ def test_identical_replica_seeds(self):
170175
self.assertNotAllEqual(res1, res1_second)
171176
self.assertNotAllEqual(res2, res2_second)
172177

178+
@tu.test_uses_ipus(num_ipus=2)
179+
@test_util.deprecated_graph_mode_only
180+
def test_experimental_identical_seeds(self):
181+
inp = array_ops.placeholder(np.float32, [10])
182+
183+
# This should produce the same result on each IPU
184+
# as the stochastic rounding should be performed
185+
# using the same seed.
186+
with ipu.scopes.ipu_scope('/device:IPU:0'):
187+
cast1 = math_ops.cast(inp, dtype=np.float16)
188+
out0 = gen_popops_ops.ipu_all_gather(cast1, replication_factor=2)
189+
190+
# Configure the hardware
191+
config = IPUConfig()
192+
config.auto_select_ipus = [2]
193+
config.experimental.enable_prng_stability = True
194+
tu.add_hw_ci_connection_options(config)
195+
# Enable stochastic rounding
196+
config.floating_point_behaviour.esr = True
197+
config.configure_ipu_system()
198+
199+
in_data = np.array([0.1] * 10)
200+
201+
with session_lib.Session() as sess:
202+
res = sess.run(out0, {inp: in_data}).astype(np.float32)
203+
# Compare the result of each IPU
204+
self.assertAllEqual(res[0], res[1])
205+
206+
@tu.test_uses_ipus(num_ipus=4)
207+
@test_util.deprecated_graph_mode_only
208+
def test_experimental_identical_compute(self):
209+
config = IPUConfig()
210+
config.auto_select_ipus = [4]
211+
config.experimental.enable_prng_stability = True
212+
tu.add_hw_ci_connection_options(config)
213+
# Enable stochastic rounding
214+
config.floating_point_behaviour.esr = True
215+
config.configure_ipu_system()
216+
217+
random_seed.set_seed(1234)
218+
dataset = dataset_ops.Dataset.from_tensor_slices([[[0.1, 0.2, 0.3],
219+
[0.4, 0.5, 0.6]]])
220+
dataset = dataset.repeat()
221+
222+
infeed = ipu.ipu_infeed_queue.IPUInfeedQueue(dataset)
223+
outfeed = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
224+
225+
noise = random_ops.random_uniform([2, 3], dtype=np.float32, seed=1)
226+
227+
def my_net():
228+
# Perform some arbitrary simple compute that invokes SR.
229+
def body(noise, infeed_value):
230+
infeed_value = ipu.cross_replica_ops.assume_equal_across_replicas(
231+
infeed_value)
232+
233+
noise_f16 = math_ops.cast(noise, dtype=np.float16)
234+
infeed_val_f16 = math_ops.cast(infeed_value, dtype=np.float16)
235+
236+
result = infeed_val_f16 + noise_f16
237+
238+
const = constant_op.constant(0.01,
239+
shape=result.shape,
240+
dtype=np.float32)
241+
const_f16 = math_ops.cast(const, dtype=np.float16)
242+
243+
result = result * const_f16
244+
result_f32 = math_ops.cast(result, dtype=np.float32)
245+
246+
out = outfeed.enqueue(result_f32)
247+
return (result_f32, out)
248+
249+
r = ipu.loops.repeat(10, body, [noise], infeed)
250+
return r
251+
252+
with ipu.scopes.ipu_scope("/device:IPU:0"):
253+
res = ipu.ipu_compiler.compile(my_net, inputs=[])
254+
255+
outqueue = outfeed.dequeue()
256+
with session_lib.Session() as sess:
257+
sess.run(infeed.initializer)
258+
sess.run(res)
259+
results = sess.run(outqueue)
260+
261+
# This contains a value per replica for each iteration of the loop. We want to test that
262+
# for each iteration those values are all equal, since each replica is using the same data and
263+
# we've enabled experimental prng stability.
264+
self.assertEqual(len(results), 10)
265+
for replica_results in results:
266+
replicas_equal = [
267+
replica_results[0].tolist() == result.tolist()
268+
for result in replica_results
269+
]
270+
self.assertTrue(all(replicas_equal))
271+
173272

174273
if __name__ == "__main__":
175274
googletest.main()

0 commit comments

Comments
 (0)