@@ -26,94 +26,78 @@ namespace xla {
2626namespace poplarplugin {
2727namespace {
2828
29- se::Platform* GetReferencePlatform () {
30- auto result = PlatformUtil::GetPlatform (" interpreter" );
31- return result.ValueOrDie ();
29+ StatusOr<Literal> GetTestInputs (PrimitiveType type, bool updates = false ) {
30+ unsigned rows = updates ? 4 : 10 ;
31+ auto input_shape = ShapeUtil::MakeShape (F32, {rows, 4 });
32+ Literal input (input_shape);
33+ input.Populate <float >([](const xla::DimensionVector& index) {
34+ return 1 .0f * index[1 ] + index[0 ];
35+ });
36+
37+ if (type != F32) {
38+ TF_ASSIGN_OR_RETURN (input, input.Convert (type));
39+ }
40+ return input;
3241}
33-
34- se::Platform* GetTestPlatform () {
35- auto platform = se::MultiPlatformManager::PlatformWithName (" Poplar" );
36- EXPECT_TRUE (platform.ok ());
37-
38- auto * p = dynamic_cast <xp::PoplarPlatform*>(platform.ValueOrDie ());
39-
40- xla::poplarplugin::IpuOptions options;
41- options.set_creator_id (IpuOptionsCreator::IPU_UTILS);
42-
43- EXPECT_EQ (p->ConfigurePoplarDevices (options), Status::OK ());
44- return p;
42+ Literal GetTestIndices () {
43+ auto indices_shape = ShapeUtil::MakeShape (S32, {4 });
44+ Literal indices (indices_shape);
45+ indices.PopulateR1 <int32>({0 , 2 , 4 , 8 });
46+ return indices;
4547}
4648
47- class MultiSliceUpdateConstantIndicesTest : public HloTestBase {
48- public:
49- MultiSliceUpdateConstantIndicesTest ()
50- : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {}
51-
52- static Literal GetTestInputs (bool updates = false ) {
53- unsigned rows = updates ? 4 : 10 ;
54- auto input_shape = ShapeUtil::MakeShape (F32, {rows, 4 });
55- Literal input (input_shape);
56- input.Populate <float >([](const xla::DimensionVector& index) {
57- return 1 .0f * index[1 ] + index[0 ];
58- });
59- return input;
60- }
49+ StatusOr<Literal> GetScale (PrimitiveType type) {
50+ auto scale_shape = ShapeUtil::MakeShape (F32, {1 });
51+ Literal scale (scale_shape);
52+ scale.PopulateR1 <float >({2.0 });
6153
62- static Literal GetTestIndices () {
63- auto indices_shape = ShapeUtil::MakeShape (S32, {4 });
64- Literal indices (indices_shape);
65- indices.PopulateR1 <int32>({0 , 2 , 4 , 8 });
66- return indices;
54+ if (type != F32) {
55+ TF_ASSIGN_OR_RETURN (scale, scale.Convert (type));
6756 }
57+ return scale;
58+ }
6859
69- static Literal GetScale () {
70- auto scale_shape = ShapeUtil::MakeShape (F32, {1 });
71- Literal scale (scale_shape);
72- scale.PopulateR1 <float >({2.0 });
73- return scale;
74- }
60+ Status VerifySlices (const Literal& result) {
61+ TF_ASSIGN_OR_RETURN (auto inputs, GetTestInputs (F32));
62+ auto indices = GetTestIndices ();
7563
76- static void VerifySlices (const Literal& result) {
77- auto inputs = GetTestInputs ();
78- auto indices = GetTestIndices ();
79-
80- const auto slice_shape = ShapeUtil::GetSubshape (result.shape (), {0 });
81- ShapeUtil::ForEachIndex (
82- slice_shape, [&](absl::Span<const int64> output_index) {
83- EXPECT_EQ (output_index.size (), 2 );
84- auto value = result.Get <float >(output_index, {0 });
85- auto idx = indices.Get <int32>({output_index[0 ], 0 });
86- auto input_value = inputs.Get <float >({idx, output_index[1 ]});
87- EXPECT_EQ (value, input_value);
88- return true ;
89- });
90- }
64+ ShapeUtil::ForEachIndex (
65+ result.shape (), [&](absl::Span<const int64> output_index) {
66+ auto value = result.Get <float >(output_index);
67+ auto idx = indices.Get <int32>({output_index[0 ], 0 });
68+ auto input_value = inputs.Get <float >({idx, output_index[1 ]});
69+ EXPECT_EQ (value, input_value);
70+ return true ;
71+ });
72+ return Status::OK ();
73+ }
9174
92- static void VerifyUpdates (const Literal& result) {
93- auto inputs = GetTestInputs ();
94- auto updates = GetTestInputs (true );
95- auto scale = GetScale ().Get <float >({0 });
96- auto indices = GetTestIndices ();
97- auto indices_data = indices.data <int >();
98-
99- const auto slice_shape = ShapeUtil::GetSubshape (result.shape (), {0 });
100- ShapeUtil::ForEachIndex (
101- slice_shape, [&](absl::Span<const int64> output_index) {
102- EXPECT_EQ (output_index.size (), 2 );
103- auto value = result.Get <float >(output_index, {0 });
104- for (size_t i = 0 ; i < indices_data.size (); i++) {
105- auto idx = indices_data.at (i);
106- if (output_index[0 ] == idx) {
107- auto input_value = inputs.Get <float >({idx, output_index[1 ]});
108- auto update_value = updates.Get <float >({i, output_index[1 ]});
109- EXPECT_EQ (value, scale * update_value + input_value);
110- break ;
111- }
75+ Status VerifyUpdates (const Literal& result) {
76+ TF_ASSIGN_OR_RETURN (auto inputs, GetTestInputs (F32));
77+ TF_ASSIGN_OR_RETURN (auto updates, GetTestInputs (F32, true ));
78+ TF_ASSIGN_OR_RETURN (auto scale, GetScale (F32));
79+ auto indices = GetTestIndices ();
80+ auto scale_data = scale.Get <float >({0 });
81+ auto indices_data = indices.data <int >();
82+
83+ ShapeUtil::ForEachIndex (
84+ result.shape (), [&](absl::Span<const int64> output_index) {
85+ auto value = result.Get <float >(output_index);
86+ for (size_t i = 0 ; i < indices_data.size (); i++) {
87+ auto idx = indices_data.at (i);
88+ if (output_index[0 ] == idx) {
89+ auto input_value = inputs.Get <float >({idx, output_index[1 ]});
90+ auto update_value = updates.Get <float >({i, output_index[1 ]});
91+ EXPECT_EQ (value, scale_data * update_value + input_value);
92+ break ;
11293 }
113- return true ;
114- });
115- }
116- };
94+ }
95+ return true ;
96+ });
97+ return Status::OK ();
98+ }
99+
100+ using MultiSliceUpdateConstantIndicesTest = HloTestBase;
117101
118102TEST_F (MultiSliceUpdateConstantIndicesTest, SliceNonConstantIndices) {
119103 std::string hlo_string = R"(
@@ -122,8 +106,7 @@ HloModule main
122106ENTRY main {
123107 input = f32[10,4] parameter(0)
124108 indices = s32[4] parameter(1)
125- slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
126- ROOT t = (f32[4,4]) tuple(slices)
109+ ROOT slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
127110}
128111)" ;
129112
@@ -138,7 +121,7 @@ ENTRY main {
138121 EXPECT_TRUE (custom_ops_replaced);
139122
140123 // Input to be sliced and indices to slice at.
141- auto inputs = GetTestInputs ();
124+ TF_ASSERT_OK_AND_ASSIGN ( auto inputs, GetTestInputs (F32) );
142125 auto indices = GetTestIndices ();
143126
144127 // Execute.
@@ -148,10 +131,10 @@ ENTRY main {
148131 std::move (
149132 ParseAndReturnVerifiedModule (hlo_string, config).ValueOrDie ()),
150133 {&inputs, &indices}));
151- ASSERT_TRUE (result.shape ().IsTuple ());
134+ ASSERT_TRUE (result.shape ().IsArray ());
152135
153136 // Verify output.
154- VerifySlices (result);
137+ TF_ASSERT_OK ( VerifySlices (result) );
155138}
156139
157140TEST_F (MultiSliceUpdateConstantIndicesTest, SliceConstantIndices) {
@@ -161,8 +144,7 @@ HloModule main
161144ENTRY main {
162145 input = f32[10,4] parameter(0)
163146 indices = s32[4] constant({0, 2, 4, 8})
164- slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
165- ROOT t = (f32[4,4]) tuple(slices)
147+ ROOT slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
166148}
167149)" ;
168150
@@ -177,7 +159,7 @@ ENTRY main {
177159 EXPECT_TRUE (custom_ops_replaced);
178160
179161 // Input to be sliced.
180- auto inputs = GetTestInputs ();
162+ TF_ASSERT_OK_AND_ASSIGN ( auto inputs, GetTestInputs (F32) );
181163
182164 // Execute.
183165 TF_ASSERT_OK_AND_ASSIGN (
@@ -186,10 +168,10 @@ ENTRY main {
186168 std::move (
187169 ParseAndReturnVerifiedModule (hlo_string, config).ValueOrDie ()),
188170 {&inputs}));
189- ASSERT_TRUE (result.shape ().IsTuple ());
171+ ASSERT_TRUE (result.shape ().IsArray ());
190172
191173 // Verify output.
192- VerifySlices (result);
174+ TF_ASSERT_OK ( VerifySlices (result) );
193175}
194176
195177TEST_F (MultiSliceUpdateConstantIndicesTest, UpdateNonConstantIndices) {
@@ -206,8 +188,7 @@ ENTRY main {
206188 big_zero = f32[10, 4] broadcast(zero), dimensions={}
207189
208190 update = f32[10, 4] custom-call(big_zero, indices, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"indices_are_sorted\":false}\n"
209- sum = f32[10, 4] add(update, input)
210- ROOT t = (f32[10, 4]) tuple(sum)
191+ ROOT sum = f32[10, 4] add(update, input)
211192}
212193)" ;
213194
@@ -222,10 +203,10 @@ ENTRY main {
222203 EXPECT_TRUE (custom_ops_replaced);
223204
224205 // Input to be sliced and indices to slice at.
225- auto inputs = GetTestInputs ();
226- auto updates = GetTestInputs (true );
206+ TF_ASSERT_OK_AND_ASSIGN (auto inputs, GetTestInputs (F32));
207+ TF_ASSERT_OK_AND_ASSIGN (auto updates, GetTestInputs (F32, true ));
208+ TF_ASSERT_OK_AND_ASSIGN (auto scale, GetScale (F32));
227209 auto indices = GetTestIndices ();
228- auto scale = GetScale ();
229210
230211 // Execute.
231212 TF_ASSERT_OK_AND_ASSIGN (
@@ -234,58 +215,80 @@ ENTRY main {
234215 std::move (
235216 ParseAndReturnVerifiedModule (hlo_string, config).ValueOrDie ()),
236217 {&inputs, &indices, &updates, &scale}));
237- ASSERT_TRUE (result.shape ().IsTuple ());
218+ ASSERT_TRUE (result.shape ().IsArray ());
238219
239220 // Verify output.
240- VerifyUpdates (result);
221+ TF_ASSERT_OK ( VerifyUpdates (result) );
241222}
242223
243- TEST_F (MultiSliceUpdateConstantIndicesTest, UpdateConstantIndices) {
244- std::string hlo_string = R"(
224+ struct MultiUpdateAddTestSpec {
225+ PrimitiveType element_type;
226+
227+ std::string GetHlo () const {
228+ const std::string hlo_string = R"(
245229HloModule main
246230
247231ENTRY main {
248- input = f32 [10, 4] parameter(0)
249- updates = f32 [4, 4] parameter(1)
250- scale = f32 [] parameter(2)
232+ input = $element_type [10, 4] parameter(0)
233+ updates = $element_type [4, 4] parameter(1)
234+ scale = $element_type [] parameter(2)
251235
252236 indices = s32[4, 1] constant({{0}, {2}, {4}, {8}})
253237
254- zero = f32 [] constant(0)
255- big_zero = f32 [10, 4] broadcast(zero), dimensions={}
238+ zero = $element_type [] constant(0)
239+ big_zero = $element_type [10, 4] broadcast(zero), dimensions={}
256240
257- update = f32[10, 4] custom-call(big_zero, indices, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"indices_are_sorted\":false}\n"
258- sum = f32[10, 4] add(update, input)
259- ROOT t = (f32[10, 4]) tuple(sum)
241+ update = $element_type[10, 4] custom-call(big_zero, indices, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"indices_are_sorted\":false}\n"
242+ ROOT sum = $element_type[10, 4] add(update, input)
260243}
261244)" ;
245+ return tensorflow::str_util::StringReplace (
246+ hlo_string, " $element_type" ,
247+ primitive_util::LowercasePrimitiveTypeName (element_type), true );
248+ }
249+ };
262250
263- HloModuleConfig config;
264- config.set_debug_options (GetDebugOptionsForTest ());
251+ std::ostream& operator <<(std::ostream& os, const MultiUpdateAddTestSpec& spec) {
252+ return os << " {element_type: " << spec.element_type << " }" ;
253+ }
254+
255+ class MultiUpdateAddTest
256+ : public HloTestBase,
257+ public ::testing::WithParamInterface<MultiUpdateAddTestSpec> {};
265258
259+ INSTANTIATE_TEST_SUITE_P (
260+ MultiUpdateAddTestCases, MultiUpdateAddTest,
261+ ::testing::ValuesIn (std::vector<MultiUpdateAddTestSpec>{{F32}, {F16}}));
262+
263+ TEST_P (MultiUpdateAddTest, DoTest) {
264+ auto param = GetParam ();
266265 TF_ASSERT_OK_AND_ASSIGN (auto module ,
267- ParseAndReturnVerifiedModule (hlo_string, config ));
266+ ParseAndReturnVerifiedModule (param. GetHlo () ));
268267
269268 TF_ASSERT_OK_AND_ASSIGN (bool custom_ops_replaced,
270269 CustomOpReplacer ().Run (module .get ()));
271270 EXPECT_TRUE (custom_ops_replaced);
272271
273272 // Input to be sliced and indices to slice at.
274- auto inputs = GetTestInputs ();
275- auto updates = GetTestInputs (true );
276- auto scale = GetScale ();
273+ TF_ASSERT_OK_AND_ASSIGN (auto inputs, GetTestInputs (param.element_type ));
274+ TF_ASSERT_OK_AND_ASSIGN (auto updates,
275+ GetTestInputs (param.element_type , /* update=*/ true ));
276+ TF_ASSERT_OK_AND_ASSIGN (auto scale, GetScale (param.element_type ));
277277
278278 // Execute.
279279 TF_ASSERT_OK_AND_ASSIGN (
280280 Literal result,
281281 Execute (
282- std::move (
283- ParseAndReturnVerifiedModule (hlo_string, config).ValueOrDie ()),
282+ std::move (ParseAndReturnVerifiedModule (param.GetHlo ()).ValueOrDie ()),
284283 {&inputs, &updates, &scale}));
285- ASSERT_TRUE (result.shape ().IsTuple ());
284+ ASSERT_TRUE (result.shape ().IsArray ());
285+
286+ if (param.element_type != F32) {
287+ TF_ASSERT_OK_AND_ASSIGN (result, result.Convert (F32));
288+ }
286289
287290 // Verify output.
288- VerifyUpdates (result);
291+ TF_ASSERT_OK ( VerifyUpdates (result) );
289292}
290293
291294} // namespace
0 commit comments