Commit f35e9c8

Fixing RTS for LAMB
Summary:
These changes update the elementwise clustering algorithm so that RTS works with the LAMB optimiser. RTS was not working with LAMB because the clusters in LAMB had intermediate scalar values which were not included in them.

In the old algorithm, all scalars were excluded by default to prevent otherwise unrelated clusters from being joined via hyperparameters. There was a mechanism to allow certain scalar values into clusters:
- Recursively search the operands for all non-scalar values which the scalar depends on.
- If all of the non-scalar values have the same shape as the top of the cluster, add the scalar to the cluster.

The problem with this was that in LAMB with BERT, there are scalars which depend on a large number of non-scalar values. The old algorithm doesn't consider whether the ops it's looking at are actually clusterable (replica-identical and of a supported op type), so it ends up checking lots of irrelevant ops and failing.

With the new algorithm, all ops in a given cluster have the same shape when the clusters are first created. The merging step is then used to allow for clusters with differently shaped intermediate values. This guarantees that only clusterable ops get considered. Since the clusters are already constructed, we can more robustly determine whether one cluster contains intermediates for another.

Changes to the algorithm:
- Renaming some functions to be less confusing.
- Using `ReplicaIdenticalDataflowAnalysis` to identify more accurately which ops are replica identical. This allowed the removal of code which precalculated which fusion computations are clusterable.
- Moving all logic which determines whether an op is clusterable into one function (`IsClusterable`).
- Tightening the requirements for which ops can be added to a cluster during cluster creation. The shape of the op must match the shape of the top of the cluster. This means that initially, all ops in a cluster have the same shape.
- Changing the merging logic so that clusters will normally only be merged if their top elements have the same shape. The exception is when a cluster `b` is surrounded by another cluster `a` (`a` directly uses the outputs of `b` and `b`'s inputs are reachable from `a`), in which case they are merged even if the shapes don't match. This essentially allows clusters to have differently shaped intermediate values.

Test Plan:
There are existing tests (`resource_update_elementwise_clustering_test.cc` and `replicated_resource_update_elementwise_clustering_test.cc`). I have added a test based on the issue which was preventing RTS from working with LAMB. The test fails before the changes and passes after.

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, gauthamg, georgep

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, gauthamg

Subscribers: gauthamg

Maniphest Tasks: T65700

Differential Revision: https://phabricator.sourcevertex.net/D72891
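The creation and merging rules described in the summary can be sketched roughly as below. This is an illustrative model only, not the code from this commit: `ToyOp`, `ToyCluster`, `CanAddDuringCreation`, `IsSurroundedBy` and `ShouldMerge` are hypothetical names, the two boolean fields stand in for the result of `ReplicaIdenticalDataflowAnalysis` and the op-type check, and "reachable from `a`" is simplified to "produced directly by an op in `a`".

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical stand-ins for HLO shapes and instructions.
using Shape = std::vector<int>;  // {} is a scalar, {4} a vector of length 4

struct ToyOp {
  Shape shape;
  std::vector<std::size_t> operands;  // indices into the op list
  bool replica_identical;             // assumed to come from the dataflow analysis
  bool supported_type;                // assumed op-type check
};

struct ToyCluster {
  Shape top_shape;
  std::set<std::size_t> ops;  // indices of the ops in this cluster
};

// Single place that decides whether an op may be clustered at all.
bool IsClusterable(const ToyOp& op) {
  return op.replica_identical && op.supported_type;
}

// Creation step: only clusterable ops with the shape of the cluster top.
bool CanAddDuringCreation(const ToyCluster& cluster, const ToyOp& op) {
  return IsClusterable(op) && op.shape == cluster.top_shape;
}

// True if some op in `a` directly consumes an op in `b`.
bool UsesOutputsOf(const ToyCluster& a, const ToyCluster& b,
                   const std::vector<ToyOp>& graph) {
  for (std::size_t user : a.ops) {
    for (std::size_t operand : graph[user].operands) {
      if (b.ops.count(operand) != 0) return true;
    }
  }
  return false;
}

// Simplification: every operand entering `b` from outside is an op in `a`.
bool InputsComeFrom(const ToyCluster& b, const ToyCluster& a,
                    const std::vector<ToyOp>& graph) {
  for (std::size_t op : b.ops) {
    for (std::size_t operand : graph[op].operands) {
      if (b.ops.count(operand) == 0 && a.ops.count(operand) == 0) return false;
    }
  }
  return true;
}

bool IsSurroundedBy(const ToyCluster& b, const ToyCluster& a,
                    const std::vector<ToyOp>& graph) {
  return UsesOutputsOf(a, b, graph) && InputsComeFrom(b, a, graph);
}

// Merge step: same top shape, or `b` holds intermediates of `a`.
bool ShouldMerge(const ToyCluster& a, const ToyCluster& b,
                 const std::vector<ToyOp>& graph) {
  return a.top_shape == b.top_shape || IsSurroundedBy(b, a, graph);
}
```

In this model a scalar that sits between two vector ops of the same cluster fails `CanAddDuringCreation` but passes `ShouldMerge`, which is the behaviour the summary describes for LAMB's intermediate scalars.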
1 parent 35f4537 commit f35e9c8

10 files changed: +274 / -258 lines changed

tensorflow/compiler/plugin/poplar/driver/passes/replicated_resource_update_elementwise_clustering.cc

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ ReplicatedResourceUpdateElementwiseClustering::AddClusterInput(
   // Lower the all reduce into the cluster if all its users will be in the
   // cluster too.
   const bool lower_all_reduce = IsGlobalAllReduce(cluster_input) &&
-                                cluster.AllUsersIn(cluster_input) &&
+                                cluster.ContainsAllUsersOf(cluster_input) &&
                                 status_or_collective_op.ok();
 
   if (lower_all_reduce) {
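The renamed predicate reads as "the cluster contains every user of this instruction". A minimal sketch of that check follows, using hypothetical stand-in types (`ToyInstruction`, `ToyCluster`, `ShouldLowerAllReduce`) rather than the real `HloInstruction` and `ElementwiseCluster` classes, whose implementations are not shown in this diff.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an HLO instruction: a name plus the names of
// the instructions that consume its output.
struct ToyInstruction {
  std::string name;
  std::vector<std::string> users;
};

// Hypothetical stand-in for a cluster: just the set of member names.
struct ToyCluster {
  std::set<std::string> members;

  // True when every user of `inst` is a member of this cluster.
  // Note: vacuously true for an instruction with no users; the real
  // method may treat that case differently.
  bool ContainsAllUsersOf(const ToyInstruction& inst) const {
    for (const std::string& user : inst.users) {
      if (members.count(user) == 0) return false;
    }
    return true;
  }
};

// Mirrors the condition in the hunk: only lower the all-reduce into the
// cluster if nothing outside the cluster consumes its result.
bool ShouldLowerAllReduce(const ToyCluster& cluster,
                          const ToyInstruction& input, bool is_all_reduce) {
  return is_all_reduce && cluster.ContainsAllUsersOf(input);
}
```

If any user lives outside the cluster, the all-reduce result must stay available there, so the sketch declines to lower it.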

tensorflow/compiler/plugin/poplar/driver/passes/resource_update_elementwise_clustering.cc

Lines changed: 8 additions & 14 deletions
@@ -110,8 +110,7 @@ ResourceUpdateElementwiseClustering::CreateValidator(
 
 StatusOr<std::vector<ElementwiseCluster>>
 ResourceUpdateElementwiseClustering::GetClustersIn(
-    HloInstruction* const call,
-    const absl::flat_hash_set<const HloComputation*>& elementwise_comps) const {
+    HloInstruction* const call) const {
   CHECK(IsRepeatLoop(call) || IsPipelineOp(call));
   HloComputation* call_comp = call->to_apply();
   // Make sure that the root of the call op is a tuple instruction.
@@ -149,9 +148,9 @@ ResourceUpdateElementwiseClustering::GetClustersIn(
   }
 
   auto validator = CreateValidator(resource_update_comp);
-  TF_ASSIGN_OR_RETURN(std::vector<ElementwiseCluster> clusters,
-                      ElementwiseCluster::GetClustersIn(
-                          resource_update, elementwise_comps, *validator));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<ElementwiseCluster> clusters,
+      ElementwiseCluster::GetClustersIn(resource_update, *validator));
   // Try to print some helpful warnings if things don't look right.
   TF_RETURN_IF_ERROR(
       ValidateResourceUpdateAndClusters(resource_update, clusters));
@@ -239,7 +238,7 @@ StatusOr<HloInstruction*> ResourceUpdateElementwiseClustering::AddClusterInput(
   // Lower the all reduce into the cluster if all its users will be in the
   // cluster too.
   const bool lower_all_reduce =
-      IsAllReduce(cluster_input) && cluster.AllUsersIn(cluster_input);
+      IsAllReduce(cluster_input) && cluster.ContainsAllUsersOf(cluster_input);
 
   if (lower_all_reduce) {
     HloInstruction* input = cluster_input->mutable_operand(0);
@@ -398,10 +397,9 @@ Status ResourceUpdateElementwiseClustering::UpdateClusterBackendConfig(
 }
 
 StatusOr<bool> ResourceUpdateElementwiseClustering::RewriteCall(
-    HloModule* module, HloInstruction* call,
-    const absl::flat_hash_set<const HloComputation*>& elementwise_comps) const {
+    HloModule* module, HloInstruction* call) const {
   TF_ASSIGN_OR_RETURN(std::vector<ElementwiseCluster> clusters,
-                      GetClustersIn(call, elementwise_comps));
+                      GetClustersIn(call));
 
   if (clusters.empty()) {
     VLOG(3) << "No clusters found.";
@@ -466,13 +464,9 @@ StatusOr<bool> ResourceUpdateElementwiseClustering::Run(HloModule* module) {
 
   TF_RETURN_IF_ERROR(RunDataflowAnalysis(module));
 
-  const absl::flat_hash_set<const HloComputation*> elementwise_comps =
-      ElementwiseCluster::GetElementwiseClusterableComputations(module);
-
   bool module_changed = false;
   for (auto call : to_optimize) {
-    TF_ASSIGN_OR_RETURN(auto changed,
-                        RewriteCall(module, call, elementwise_comps));
+    TF_ASSIGN_OR_RETURN(auto changed, RewriteCall(module, call));
     if (changed) {
       module_changed = true;
     }

tensorflow/compiler/plugin/poplar/driver/passes/resource_update_elementwise_clustering.h

Lines changed: 2 additions & 6 deletions
@@ -55,9 +55,7 @@ class ResourceUpdateElementwiseClustering : public HloModulePass {
   // Get clusters inside of the call, where the call has to be a repeat loop or
   // a pipeline.
   StatusOr<std::vector<ElementwiseCluster>> GetClustersIn(
-      HloInstruction* const call,
-      const absl::flat_hash_set<const HloComputation*>& elementwise_comps)
-      const;
+      HloInstruction* const call) const;
 
   // Outline the provided cluster - returns the call instruction to the cluster.
   StatusOr<HloInstruction*> OutlineCluster(ElementwiseCluster& cluster) const;
@@ -103,9 +101,7 @@ class ResourceUpdateElementwiseClustering : public HloModulePass {
       const HloInstruction* ru, std::vector<ElementwiseCluster> clusters) const;
 
  private:
-  StatusOr<bool> RewriteCall(HloModule* module, HloInstruction* call,
-                             const absl::flat_hash_set<const HloComputation*>&
-                                 elementwise_comps) const;
+  StatusOr<bool> RewriteCall(HloModule* module, HloInstruction* call) const;
 
   StatusOr<HloInstruction*> AddClusterInputToOutlinedComputation(
       int64_t param_idx, const ElementwiseCluster& cluster,

tensorflow/compiler/plugin/poplar/driver/tools/BUILD

Lines changed: 1 addition & 0 deletions
@@ -1015,6 +1015,7 @@ poplar_cc_library(
         ":matcher_predicates",
         ":offloading_util",
         ":pipeline_util",
+        ":replica_identical_dataflow_analysis",
         ":util",
         "//tensorflow/compiler/plugin/poplar/driver/tools/custom_ops:all_gather",
         "//tensorflow/compiler/plugin/poplar/driver/tools/custom_ops:reduce_scatter",
