@@ -38,36 +38,88 @@ namespace xla {
3838namespace m = match;
3939namespace poplarplugin {
4040namespace {
41- // TODO(T45278) popops::multiUpdate and popops::multiUpdateAdd only supports the
42- // 2D case.
43- bool CheckValidMultiUpdateAttributes (const HloScatterInstruction* inst) {
44- const Shape operand_shape = inst->operand (0 )->shape ();
45- const Shape indices_shape = inst->operand (1 )->shape ();
46- const Shape updates_shape = inst->operand (2 )->shape ();
47- const auto dim_numbers = inst->scatter_dimension_numbers ();
48- const auto update_window_dims = dim_numbers.update_window_dims ();
49- const auto inserted_window_dims = dim_numbers.inserted_window_dims ();
50- const auto scatter_dims_to_operand_dims =
41+ absl::optional<int64> GetScatterDimension (
42+ int64 rank, absl::Span<const int64> update_window_dims) {
43+ std::vector<int64> all_dims (rank);
44+ absl::c_iota (all_dims, 0 );
45+
46+ std::vector<int64> scatter_dims;
47+ absl::c_set_difference (all_dims, update_window_dims,
48+ std::inserter (scatter_dims, scatter_dims.begin ()));
49+ if (scatter_dims.size () == 1 ) {
50+ return scatter_dims[0 ];
51+ }
52+ return absl::nullopt ;
53+ }
54+
55+ bool CheckValidMultiUpdateAttributes (const HloInstruction* inst) {
56+ const Shape& operand_shape = inst->operand (0 )->shape ();
57+ const Shape& indices_shape = inst->operand (1 )->shape ();
58+ const Shape& updates_shape = inst->operand (2 )->shape ();
59+ const auto & dim_numbers = inst->scatter_dimension_numbers ();
60+ const auto & update_window_dims = dim_numbers.update_window_dims ();
61+ const auto & inserted_window_dims = dim_numbers.inserted_window_dims ();
62+ const auto & scatter_dims_to_operand_dims =
5163 dim_numbers.scatter_dims_to_operand_dims ();
5264 const auto index_vector_dim = dim_numbers.index_vector_dim ();
5365 const uint64 index_dim_size =
5466 indices_shape.rank () == index_vector_dim
5567 ? 1
5668 : indices_shape.dimensions (index_vector_dim);
57- return operand_shape.rank () == 2 && index_dim_size == 1 &&
58- scatter_dims_to_operand_dims.size () == 1 &&
59- scatter_dims_to_operand_dims[0 ] == 0 &&
60- inserted_window_dims.size () == 1 && inserted_window_dims[0 ] == 0 &&
61- update_window_dims.size () == 1 &&
62- update_window_dims[0 ] == (updates_shape.rank () - 1 );
69+
70+ if (updates_shape.rank () == 0 ) {
71+ return false ;
72+ }
73+
74+ if (updates_shape.rank () != operand_shape.rank ()) {
75+ return false ;
76+ }
77+
78+ if (index_dim_size != 1 ) {
79+ return false ;
80+ }
81+
82+ if (update_window_dims.size () != (updates_shape.rank () - 1 )) {
83+ return false ;
84+ }
85+
86+ auto scatter_dimension_opt = GetScatterDimension (
87+ updates_shape.rank (), AsInt64Slice (update_window_dims));
88+ if (!scatter_dimension_opt) {
89+ return false ;
90+ }
91+
92+ const int64 scatter_dimension = *scatter_dimension_opt;
93+
94+ if (ShapeUtil::DeleteDimension (scatter_dimension, updates_shape) !=
95+ ShapeUtil::DeleteDimension (scatter_dimension, operand_shape)) {
96+ return false ;
97+ }
98+
99+ if (scatter_dims_to_operand_dims.size () != 1 ) {
100+ return false ;
101+ }
102+
103+ if (scatter_dims_to_operand_dims[0 ] != 0 ) {
104+ return false ;
105+ }
106+
107+ if (inserted_window_dims.size () != 1 ) {
108+ return false ;
109+ }
110+
111+ if (inserted_window_dims[0 ] != 0 ) {
112+ return false ;
113+ }
114+
115+ return true ;
63116}
64117
65118bool IsMultiUpdateScatter (const HloInstruction* inst) {
66119 if (inst->opcode () == HloOpcode::kScatter ) {
67120 const HloScatterInstruction* scatter = Cast<HloScatterInstruction>(inst);
68121 const HloInstruction* root = inst->to_apply ()->root_instruction ();
69- return Match (root, m::Parameter (1 )) &&
70- CheckValidMultiUpdateAttributes (scatter);
122+ return Match (root, m::Parameter (1 ));
71123 }
72124 return false ;
73125}
@@ -76,8 +128,7 @@ bool IsMultiUpdateAddScatter(const HloInstruction* inst) {
76128 if (inst->opcode () == HloOpcode::kScatter ) {
77129 const HloScatterInstruction* scatter = Cast<HloScatterInstruction>(inst);
78130 const HloInstruction* root = inst->to_apply ()->root_instruction ();
79- return Match (root, m::Add (m::Parameter (0 ), m::Parameter (1 ))) &&
80- CheckValidMultiUpdateAttributes (scatter);
131+ return Match (root, m::Add (m::Parameter (0 ), m::Parameter (1 )));
81132 }
82133 return false ;
83134}
@@ -86,32 +137,26 @@ bool IsConvertableScatter(const HloInstruction* inst) {
86137 return IsMultiUpdateScatter (inst) || IsMultiUpdateAddScatter (inst);
87138}
88139
89- StatusOr<HloInstruction*> MoveInstructionDimensionToBack (
90- HloInstruction* inst, std::size_t dim_to_move) {
140+ std::vector<int64> GetPermutation ( const HloInstruction* inst,
141+ std::size_t dim_to_move) {
91142 std::vector<int64> permutation (inst->shape ().rank ());
92- for (size_t i = 0 , next_idx = 0 ; i != permutation.size (); ++i) {
143+ permutation[0 ] = dim_to_move;
144+ for (size_t i = 0 , next_idx = 1 ; i != permutation.size (); ++i) {
93145 if (i != dim_to_move) {
94146 permutation[next_idx++] = i;
95147 }
96148 }
97- permutation.back () = dim_to_move;
98- TF_ASSIGN_OR_RETURN (HloInstruction * new_inst,
99- MakeTransposeHlo (inst, permutation));
100- inst->SetupDerivedInstruction (new_inst);
101- return new_inst;
149+ return permutation;
102150}
103151
104- StatusOr<HloInstruction*> CollapseAllButLastDimension (HloInstruction* inst) {
152+ StatusOr<HloInstruction*> CollapseAllButZerothDimension (HloInstruction* inst) {
105153 HloComputation* computation = inst->parent ();
106154 const Shape inst_shape = inst->shape ();
107- const int64 last_dim = inst_shape.dimensions (inst_shape. rank () - 1 );
108- const Shape new_inst_shape = ShapeUtil::MakeShape (
155+ const int64 zero_dim = inst_shape.dimensions (0 );
156+ const Shape new_shape = ShapeUtil::MakeShape (
109157 inst_shape.element_type (),
110- {ShapeUtil::ElementsIn (inst_shape) / last_dim, last_dim});
111- HloInstruction* new_inst = computation->AddInstruction (
112- HloInstruction::CreateReshape (new_inst_shape, inst));
113- inst->SetupDerivedInstruction (new_inst);
114- return new_inst;
158+ {zero_dim, ShapeUtil::ElementsIn (inst_shape) / zero_dim});
159+ return MakeReshapeHlo (new_shape, inst);
115160}
116161
117162StatusOr<bool > ReplaceScatter (HloInstruction* scatter) {
@@ -120,38 +165,36 @@ StatusOr<bool> ReplaceScatter(HloInstruction* scatter) {
120165 HloComputation* computation = scatter->parent ();
121166 auto dim_numbers = scatter->scatter_dimension_numbers ();
122167 const int64 index_vector_dim = dim_numbers.index_vector_dim ();
168+ const auto & update_window_dims = dim_numbers.update_window_dims ();
123169 const bool is_update_add = IsMultiUpdateAddScatter (scatter);
124- int64 update_dim = dim_numbers.update_window_dims ()[0 ];
125170
126171 HloInstruction* operand = scatter->mutable_operand (0 );
127172 HloInstruction* indices = scatter->mutable_operand (1 );
128173 HloInstruction* updates = scatter->mutable_operand (2 );
129174
130- // If the indices are scalar then this is a dynamic-update-slice kind of
131- // scatter.
132- if (ShapeUtil::IsScalar (indices->shape ())) {
133- TF_ASSIGN_OR_RETURN (updates, PrependDegenerateDims (updates, 1 ));
134- update_dim += 1 ;
135- }
136-
137175 // Reshape the indices into a 2D shape [num_lookups, 1].
138176 const Shape& indices_shape = indices->shape ();
139177 const Shape new_indices_shape = ShapeUtil::MakeShape (
140178 indices_shape.element_type (), {ShapeUtil::ElementsIn (indices_shape), 1 });
141179 TF_ASSIGN_OR_RETURN (indices, MakeReshapeHlo (new_indices_shape, indices));
142180
143- // Move the update_dim to the back.
144- if ((updates->shape ().rank () - 1 ) != update_dim) {
145- TF_ASSIGN_OR_RETURN (updates,
146- MoveInstructionDimensionToBack (updates, update_dim));
147- update_dim = updates->shape ().rank () - 1 ;
148- }
181+ const int64 scatter_dimension = *GetScatterDimension (
182+ updates->shape ().rank (), AsInt64Slice (update_window_dims));
149183
150- // Collapse all but the last dimension of updates.
151- if (updates->shape ().rank () != 2 ) {
152- TF_ASSIGN_OR_RETURN (updates, CollapseAllButLastDimension (updates));
153- update_dim = 1 ;
154- }
184+ const std::vector<int64> permutation =
185+ GetPermutation (operand, scatter_dimension);
186+ const std::vector<int64> invert_permutation =
187+ InvertPermutations<int64>(permutation);
188+
189+ // Move the scatter dimension to the front.
190+ TF_ASSIGN_OR_RETURN (operand, MakeTransposeHlo (operand, permutation));
191+ TF_ASSIGN_OR_RETURN (updates, MakeTransposeHlo (updates, permutation));
192+
193+ const Shape pre_flatten_operand_shape = operand->shape ();
194+
195+ // Collapse all but the zeroth dimension.
196+ TF_ASSIGN_OR_RETURN (operand, CollapseAllButZerothDimension (operand));
197+ TF_ASSIGN_OR_RETURN (updates, CollapseAllButZerothDimension (updates));
155198
156199 HloInstruction* multi_update;
157200 if (is_update_add) {
@@ -160,12 +203,20 @@ StatusOr<bool> ReplaceScatter(HloInstruction* scatter) {
160203 computation->AddInstruction (HloInstruction::CreateConstant (
161204 LiteralUtil::One (scatter->shape ().element_type ())));
162205 multi_update = computation->AddInstruction (CreateMultiUpdateAdd (
163- scatter ->shape (), {operand, indices, updates, one}, update_dim ));
206+ operand ->shape (), {operand, indices, updates, one}));
164207 } else {
165- multi_update = computation->AddInstruction (CreateMultiUpdate (
166- scatter ->shape (), {operand, indices, updates}, update_dim ));
208+ multi_update = computation->AddInstruction (
209+ CreateMultiUpdate (operand ->shape (), {operand, indices, updates}));
167210 }
168211 scatter->SetupDerivedInstruction (multi_update);
212+
213+ // Uncollapse the dimensions.
214+ TF_ASSIGN_OR_RETURN (multi_update,
215+ MakeReshapeHlo (pre_flatten_operand_shape, multi_update));
216+ // Undo the transpose.
217+ TF_ASSIGN_OR_RETURN (multi_update,
218+ MakeTransposeHlo (multi_update, invert_permutation));
219+
169220 TF_RETURN_IF_ERROR (computation->ReplaceInstruction (scatter, multi_update));
170221 return true ;
171222}
@@ -185,7 +236,7 @@ StatusOr<bool> ScatterSimplifier::Run(HloModule* module) {
185236 // operands.
186237 auto insts = comp->MakeInstructionPostOrder ();
187238 for (HloInstruction* inst : insts) {
188- if (IsConvertableScatter (inst)) {
239+ if (IsConvertableScatter (inst) && CheckValidMultiUpdateAttributes (inst) ) {
189240 TF_ASSIGN_OR_RETURN (bool replaced, ReplaceScatter (inst));
190241 changed |= replaced;
191242 }
0 commit comments