Skip to content

Commit 32be6f5

Browse files
committed
Support nD scatters with multi update(add)
Summary: Fix T45278
Test Plan: CI, added new test
Reviewers: jakeh, vladimirm, georgew, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Reviewed By: jakeh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Maniphest Tasks: T45278
Differential Revision: https://phabricator.sourcevertex.net/D53132
1 parent 436cc37 commit 32be6f5

File tree

3 files changed

+152
-60
lines changed

3 files changed

+152
-60
lines changed

tensorflow/compiler/plugin/poplar/driver/passes/scatter_simplifier.cc

Lines changed: 110 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -38,36 +38,88 @@ namespace xla {
3838
namespace m = match;
3939
namespace poplarplugin {
4040
namespace {
41-
// TODO(T45278) popops::multiUpdate and popops::multiUpdateAdd only supports the
42-
// 2D case.
43-
bool CheckValidMultiUpdateAttributes(const HloScatterInstruction* inst) {
44-
const Shape operand_shape = inst->operand(0)->shape();
45-
const Shape indices_shape = inst->operand(1)->shape();
46-
const Shape updates_shape = inst->operand(2)->shape();
47-
const auto dim_numbers = inst->scatter_dimension_numbers();
48-
const auto update_window_dims = dim_numbers.update_window_dims();
49-
const auto inserted_window_dims = dim_numbers.inserted_window_dims();
50-
const auto scatter_dims_to_operand_dims =
41+
absl::optional<int64> GetScatterDimension(
42+
int64 rank, absl::Span<const int64> update_window_dims) {
43+
std::vector<int64> all_dims(rank);
44+
absl::c_iota(all_dims, 0);
45+
46+
std::vector<int64> scatter_dims;
47+
absl::c_set_difference(all_dims, update_window_dims,
48+
std::inserter(scatter_dims, scatter_dims.begin()));
49+
if (scatter_dims.size() == 1) {
50+
return scatter_dims[0];
51+
}
52+
return absl::nullopt;
53+
}
54+
55+
bool CheckValidMultiUpdateAttributes(const HloInstruction* inst) {
56+
const Shape& operand_shape = inst->operand(0)->shape();
57+
const Shape& indices_shape = inst->operand(1)->shape();
58+
const Shape& updates_shape = inst->operand(2)->shape();
59+
const auto& dim_numbers = inst->scatter_dimension_numbers();
60+
const auto& update_window_dims = dim_numbers.update_window_dims();
61+
const auto& inserted_window_dims = dim_numbers.inserted_window_dims();
62+
const auto& scatter_dims_to_operand_dims =
5163
dim_numbers.scatter_dims_to_operand_dims();
5264
const auto index_vector_dim = dim_numbers.index_vector_dim();
5365
const uint64 index_dim_size =
5466
indices_shape.rank() == index_vector_dim
5567
? 1
5668
: indices_shape.dimensions(index_vector_dim);
57-
return operand_shape.rank() == 2 && index_dim_size == 1 &&
58-
scatter_dims_to_operand_dims.size() == 1 &&
59-
scatter_dims_to_operand_dims[0] == 0 &&
60-
inserted_window_dims.size() == 1 && inserted_window_dims[0] == 0 &&
61-
update_window_dims.size() == 1 &&
62-
update_window_dims[0] == (updates_shape.rank() - 1);
69+
70+
if (updates_shape.rank() == 0) {
71+
return false;
72+
}
73+
74+
if (updates_shape.rank() != operand_shape.rank()) {
75+
return false;
76+
}
77+
78+
if (index_dim_size != 1) {
79+
return false;
80+
}
81+
82+
if (update_window_dims.size() != (updates_shape.rank() - 1)) {
83+
return false;
84+
}
85+
86+
auto scatter_dimension_opt = GetScatterDimension(
87+
updates_shape.rank(), AsInt64Slice(update_window_dims));
88+
if (!scatter_dimension_opt) {
89+
return false;
90+
}
91+
92+
const int64 scatter_dimension = *scatter_dimension_opt;
93+
94+
if (ShapeUtil::DeleteDimension(scatter_dimension, updates_shape) !=
95+
ShapeUtil::DeleteDimension(scatter_dimension, operand_shape)) {
96+
return false;
97+
}
98+
99+
if (scatter_dims_to_operand_dims.size() != 1) {
100+
return false;
101+
}
102+
103+
if (scatter_dims_to_operand_dims[0] != 0) {
104+
return false;
105+
}
106+
107+
if (inserted_window_dims.size() != 1) {
108+
return false;
109+
}
110+
111+
if (inserted_window_dims[0] != 0) {
112+
return false;
113+
}
114+
115+
return true;
63116
}
64117

65118
bool IsMultiUpdateScatter(const HloInstruction* inst) {
66119
if (inst->opcode() == HloOpcode::kScatter) {
67120
const HloScatterInstruction* scatter = Cast<HloScatterInstruction>(inst);
68121
const HloInstruction* root = inst->to_apply()->root_instruction();
69-
return Match(root, m::Parameter(1)) &&
70-
CheckValidMultiUpdateAttributes(scatter);
122+
return Match(root, m::Parameter(1));
71123
}
72124
return false;
73125
}
@@ -76,8 +128,7 @@ bool IsMultiUpdateAddScatter(const HloInstruction* inst) {
76128
if (inst->opcode() == HloOpcode::kScatter) {
77129
const HloScatterInstruction* scatter = Cast<HloScatterInstruction>(inst);
78130
const HloInstruction* root = inst->to_apply()->root_instruction();
79-
return Match(root, m::Add(m::Parameter(0), m::Parameter(1))) &&
80-
CheckValidMultiUpdateAttributes(scatter);
131+
return Match(root, m::Add(m::Parameter(0), m::Parameter(1)));
81132
}
82133
return false;
83134
}
@@ -86,32 +137,26 @@ bool IsConvertableScatter(const HloInstruction* inst) {
86137
return IsMultiUpdateScatter(inst) || IsMultiUpdateAddScatter(inst);
87138
}
88139

89-
StatusOr<HloInstruction*> MoveInstructionDimensionToBack(
90-
HloInstruction* inst, std::size_t dim_to_move) {
140+
std::vector<int64> GetPermutation(const HloInstruction* inst,
141+
std::size_t dim_to_move) {
91142
std::vector<int64> permutation(inst->shape().rank());
92-
for (size_t i = 0, next_idx = 0; i != permutation.size(); ++i) {
143+
permutation[0] = dim_to_move;
144+
for (size_t i = 0, next_idx = 1; i != permutation.size(); ++i) {
93145
if (i != dim_to_move) {
94146
permutation[next_idx++] = i;
95147
}
96148
}
97-
permutation.back() = dim_to_move;
98-
TF_ASSIGN_OR_RETURN(HloInstruction * new_inst,
99-
MakeTransposeHlo(inst, permutation));
100-
inst->SetupDerivedInstruction(new_inst);
101-
return new_inst;
149+
return permutation;
102150
}
103151

104-
StatusOr<HloInstruction*> CollapseAllButLastDimension(HloInstruction* inst) {
152+
StatusOr<HloInstruction*> CollapseAllButZerothDimension(HloInstruction* inst) {
105153
HloComputation* computation = inst->parent();
106154
const Shape inst_shape = inst->shape();
107-
const int64 last_dim = inst_shape.dimensions(inst_shape.rank() - 1);
108-
const Shape new_inst_shape = ShapeUtil::MakeShape(
155+
const int64 zero_dim = inst_shape.dimensions(0);
156+
const Shape new_shape = ShapeUtil::MakeShape(
109157
inst_shape.element_type(),
110-
{ShapeUtil::ElementsIn(inst_shape) / last_dim, last_dim});
111-
HloInstruction* new_inst = computation->AddInstruction(
112-
HloInstruction::CreateReshape(new_inst_shape, inst));
113-
inst->SetupDerivedInstruction(new_inst);
114-
return new_inst;
158+
{zero_dim, ShapeUtil::ElementsIn(inst_shape) / zero_dim});
159+
return MakeReshapeHlo(new_shape, inst);
115160
}
116161

117162
StatusOr<bool> ReplaceScatter(HloInstruction* scatter) {
@@ -120,38 +165,36 @@ StatusOr<bool> ReplaceScatter(HloInstruction* scatter) {
120165
HloComputation* computation = scatter->parent();
121166
auto dim_numbers = scatter->scatter_dimension_numbers();
122167
const int64 index_vector_dim = dim_numbers.index_vector_dim();
168+
const auto& update_window_dims = dim_numbers.update_window_dims();
123169
const bool is_update_add = IsMultiUpdateAddScatter(scatter);
124-
int64 update_dim = dim_numbers.update_window_dims()[0];
125170

126171
HloInstruction* operand = scatter->mutable_operand(0);
127172
HloInstruction* indices = scatter->mutable_operand(1);
128173
HloInstruction* updates = scatter->mutable_operand(2);
129174

130-
// If the indices are scalar then this is a dynamic-update-slice kind of
131-
// scatter.
132-
if (ShapeUtil::IsScalar(indices->shape())) {
133-
TF_ASSIGN_OR_RETURN(updates, PrependDegenerateDims(updates, 1));
134-
update_dim += 1;
135-
}
136-
137175
// Reshape the indices into a 2D shape [num_lookups, 1].
138176
const Shape& indices_shape = indices->shape();
139177
const Shape new_indices_shape = ShapeUtil::MakeShape(
140178
indices_shape.element_type(), {ShapeUtil::ElementsIn(indices_shape), 1});
141179
TF_ASSIGN_OR_RETURN(indices, MakeReshapeHlo(new_indices_shape, indices));
142180

143-
// Move the update_dim to the back.
144-
if ((updates->shape().rank() - 1) != update_dim) {
145-
TF_ASSIGN_OR_RETURN(updates,
146-
MoveInstructionDimensionToBack(updates, update_dim));
147-
update_dim = updates->shape().rank() - 1;
148-
}
181+
const int64 scatter_dimension = *GetScatterDimension(
182+
updates->shape().rank(), AsInt64Slice(update_window_dims));
149183

150-
// Collapse all but the last dimension of updates.
151-
if (updates->shape().rank() != 2) {
152-
TF_ASSIGN_OR_RETURN(updates, CollapseAllButLastDimension(updates));
153-
update_dim = 1;
154-
}
184+
const std::vector<int64> permutation =
185+
GetPermutation(operand, scatter_dimension);
186+
const std::vector<int64> invert_permutation =
187+
InvertPermutations<int64>(permutation);
188+
189+
// Move the scatter dimension to the front.
190+
TF_ASSIGN_OR_RETURN(operand, MakeTransposeHlo(operand, permutation));
191+
TF_ASSIGN_OR_RETURN(updates, MakeTransposeHlo(updates, permutation));
192+
193+
const Shape pre_flatten_operand_shape = operand->shape();
194+
195+
// Collapse all but the zeroth dimension.
196+
TF_ASSIGN_OR_RETURN(operand, CollapseAllButZerothDimension(operand));
197+
TF_ASSIGN_OR_RETURN(updates, CollapseAllButZerothDimension(updates));
155198

156199
HloInstruction* multi_update;
157200
if (is_update_add) {
@@ -160,12 +203,20 @@ StatusOr<bool> ReplaceScatter(HloInstruction* scatter) {
160203
computation->AddInstruction(HloInstruction::CreateConstant(
161204
LiteralUtil::One(scatter->shape().element_type())));
162205
multi_update = computation->AddInstruction(CreateMultiUpdateAdd(
163-
scatter->shape(), {operand, indices, updates, one}, update_dim));
206+
operand->shape(), {operand, indices, updates, one}));
164207
} else {
165-
multi_update = computation->AddInstruction(CreateMultiUpdate(
166-
scatter->shape(), {operand, indices, updates}, update_dim));
208+
multi_update = computation->AddInstruction(
209+
CreateMultiUpdate(operand->shape(), {operand, indices, updates}));
167210
}
168211
scatter->SetupDerivedInstruction(multi_update);
212+
213+
// Uncollapse the dimensions.
214+
TF_ASSIGN_OR_RETURN(multi_update,
215+
MakeReshapeHlo(pre_flatten_operand_shape, multi_update));
216+
// Undo the transpose.
217+
TF_ASSIGN_OR_RETURN(multi_update,
218+
MakeTransposeHlo(multi_update, invert_permutation));
219+
169220
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(scatter, multi_update));
170221
return true;
171222
}
@@ -185,7 +236,7 @@ StatusOr<bool> ScatterSimplifier::Run(HloModule* module) {
185236
// operands.
186237
auto insts = comp->MakeInstructionPostOrder();
187238
for (HloInstruction* inst : insts) {
188-
if (IsConvertableScatter(inst)) {
239+
if (IsConvertableScatter(inst) && CheckValidMultiUpdateAttributes(inst)) {
189240
TF_ASSIGN_OR_RETURN(bool replaced, ReplaceScatter(inst));
190241
changed |= replaced;
191242
}

tensorflow/compiler/plugin/poplar/tests/multi_update_combiner_test.cc

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ limitations under the License.
1414
==============================================================================*/
1515
#include "tensorflow/compiler/plugin/poplar/driver/passes/multi_update_combiner.h"
1616
#include "tensorflow/compiler/plugin/poplar/driver/compiler_annotations.h"
17+
#include "tensorflow/compiler/plugin/poplar/driver/passes/poplar_algebraic_simplifier.h"
1718
#include "tensorflow/compiler/plugin/poplar/driver/passes/scatter_simplifier.h"
1819
#include "tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/multi_slice.h"
1920
#include "tensorflow/compiler/plugin/poplar/driver/tools/data_initializer.h"
@@ -71,6 +72,7 @@ main {
7172

7273
ScatterSimplifier sc;
7374
EXPECT_TRUE(sc.Run(module).ValueOrDie());
75+
EXPECT_TRUE(PoplarAlgebraicSimplifier().Run(module).ValueOrDie());
7476
EXPECT_EQ(GetNumMultiUpdateAdds(module->entry_computation()), 2);
7577
HloCSE cse(false);
7678
EXPECT_TRUE(cse.Run(module).ValueOrDie());
@@ -166,6 +168,7 @@ main {
166168

167169
ScatterSimplifier sc;
168170
EXPECT_TRUE(sc.Run(module).ValueOrDie());
171+
EXPECT_TRUE(PoplarAlgebraicSimplifier().Run(module).ValueOrDie());
169172
EXPECT_EQ(GetNumMultiUpdateAdds(module->entry_computation()), 3);
170173
MultiUpdateCombiner mu_combiner(annotations);
171174
int64 execution_count = -1;
@@ -233,6 +236,7 @@ main {
233236

234237
ScatterSimplifier sc;
235238
EXPECT_TRUE(sc.Run(module).ValueOrDie());
239+
EXPECT_TRUE(PoplarAlgebraicSimplifier().Run(module).ValueOrDie());
236240
EXPECT_EQ(GetNumMultiUpdateAdds(module->entry_computation()), 3);
237241
MultiUpdateCombiner mu_combiner(annotations);
238242
int64 execution_count = -1;
@@ -304,6 +308,7 @@ main {
304308

305309
ScatterSimplifier sc;
306310
EXPECT_TRUE(sc.Run(module).ValueOrDie());
311+
EXPECT_TRUE(PoplarAlgebraicSimplifier().Run(module).ValueOrDie());
307312
EXPECT_EQ(GetNumMultiUpdateAdds(module->entry_computation()), 2);
308313
MultiUpdateCombiner mu_combiner(annotations);
309314
int64 execution_count = -1;
@@ -366,6 +371,7 @@ main {
366371

367372
ScatterSimplifier sc;
368373
EXPECT_TRUE(sc.Run(module).ValueOrDie());
374+
EXPECT_TRUE(PoplarAlgebraicSimplifier().Run(module).ValueOrDie());
369375
EXPECT_EQ(GetNumMultiUpdateAdds(module->entry_computation()), 2);
370376
HloCSE cse(false);
371377
cse.Run(module).ValueOrDie();
@@ -419,6 +425,7 @@ main {
419425

420426
ScatterSimplifier sc;
421427
EXPECT_TRUE(sc.Run(module).ValueOrDie());
428+
EXPECT_TRUE(PoplarAlgebraicSimplifier().Run(module).ValueOrDie());
422429
EXPECT_EQ(GetNumMultiUpdateAdds(module->entry_computation()), 2);
423430
HloCSE cse(false);
424431
cse.Run(module).ValueOrDie();

tensorflow/compiler/plugin/poplar/tests/scatter_simplifier_test.cc

Lines changed: 35 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,6 @@ main {
109109
EXPECT_TRUE(sc.Run(module).ValueOrDie());
110110
auto root = module->entry_computation()->root_instruction();
111111
EXPECT_EQ(GetNumMultiUpdates(module->entry_computation()), 2);
112-
auto mu0 = Cast<HloMultiUpdateInstruction>(root->operand(0));
113112
}
114113

115114
TEST_F(ScatterSimplifierTest, TestNotValid) {
@@ -141,6 +140,41 @@ main {
141140
EXPECT_EQ(GetNumMultiUpdates(module->entry_computation()), 0);
142141
}
143142

143+
TEST_F(ScatterSimplifierTest, TestMultiUpdateAddsMultiDim) {
144+
std::string hlo_string = R"(
145+
HloModule top
146+
scatter-combiner {
147+
p0 = f32[] parameter(0)
148+
p1 = f32[] parameter(1)
149+
ROOT add = f32[] add(p0, p1)
150+
}
151+
152+
main {
153+
arg0 = s32[15] parameter(0)
154+
arg1 = f32[15,10,10,1] parameter(1)
155+
zero = f32[] constant(0)
156+
big_zero = f32[2000,10,10,1] broadcast(zero), dimensions={}
157+
ROOT s1 = f32[2000,10,10,1] scatter(big_zero, arg0, arg1), update_window_dims={1,2,3}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=scatter-combiner
158+
}
159+
)";
160+
161+
HloModuleConfig config;
162+
config.set_debug_options(GetDebugOptionsForTest());
163+
164+
TF_ASSERT_OK_AND_ASSIGN(auto module,
165+
ParseAndReturnVerifiedModule(hlo_string, config));
166+
167+
EXPECT_TRUE(ScatterSimplifier().Run(module.get()).ValueOrDie());
168+
auto root = module->entry_computation()->root_instruction();
169+
EXPECT_EQ(GetNumMultiUpdateAdds(module->entry_computation()), 1);
170+
EXPECT_TRUE(Match(
171+
root,
172+
m::Transpose(m::Reshape(m::CustomCall(
173+
m::Reshape(m::Transpose(m::Broadcast(m::ConstantScalar(0)))),
174+
m::Reshape(m::Parameter(0)),
175+
m::Reshape(m::Transpose(m::Parameter(1))), m::ConstantScalar(1))))));
176+
}
177+
144178
} // namespace
145179
} // namespace poplarplugin
146180
} // namespace xla

0 commit comments

Comments
 (0)