Skip to content

Commit 9890ce4

Browse files
committed
StaticMultiUpdate - make sure scale is FP32
Summary: Fix T51326
Test Plan: CI, added new tests
Reviewers: jackh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Reviewed By: jackh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Maniphest Tasks: T51326
Differential Revision: https://phabricator.sourcevertex.net/D56866
1 parent 29087bc commit 9890ce4

File tree

3 files changed

+129
-118
lines changed

3 files changed

+129
-118
lines changed

tensorflow/compiler/plugin/poplar/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2433,7 +2433,7 @@ xla_test(
24332433

24342434
xla_test(
24352435
name = "multi_slice_update_constant_indices_test",
2436-
size = "medium",
2436+
size = "large",
24372437
srcs = ["tests/multi_slice_update_constant_indices_test.cc"],
24382438
backends = ["poplar"],
24392439
copts = ["-fexceptions"],

tensorflow/compiler/plugin/poplar/driver/ops/custom_ops/popops/multi_slice.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "tensorflow/core/lib/core/errors.h"
2727

2828
#include <poplar/DebugContext.hpp>
29+
#include <popops/Cast.hpp>
2930
#include <popops/DynamicSlice.hpp>
3031
#include <popops/ElementWise.hpp>
3132
#include <poputil/GraphFunction.hpp>
@@ -247,8 +248,15 @@ Status MultiUpdateInternal(
247248
const auto constant_indices = GetConstantIndices(inst->operands().at(1));
248249

249250
if (constant_indices.has_value()) {
251+
poplar::Tensor scale_casted = *scale;
252+
if (operand.elementType() == poplar::HALF &&
253+
scale_casted.elementType() == poplar::HALF) {
254+
VLOG(2) << "Casting static multi update scale to F32";
255+
scale_casted = popops::cast(graph, scale_casted, poplar::FLOAT, prog,
256+
{debug_name_and_id, "ScaleCast"});
257+
}
250258
popops::multiUpdateAdd(graph, operand, expanded_updates,
251-
constant_indices.value(), *scale, 0, prog,
259+
constant_indices.value(), scale_casted, 0, prog,
252260
{debug_name_and_id});
253261
} else {
254262
popops::multiUpdateAdd(graph, operand, expanded_updates, unsigned_indices,

tensorflow/compiler/plugin/poplar/tests/multi_slice_update_constant_indices_test.cc

Lines changed: 119 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -26,94 +26,78 @@ namespace xla {
2626
namespace poplarplugin {
2727
namespace {
2828

29-
se::Platform* GetReferencePlatform() {
30-
auto result = PlatformUtil::GetPlatform("interpreter");
31-
return result.ValueOrDie();
29+
StatusOr<Literal> GetTestInputs(PrimitiveType type, bool updates = false) {
30+
unsigned rows = updates ? 4 : 10;
31+
auto input_shape = ShapeUtil::MakeShape(F32, {rows, 4});
32+
Literal input(input_shape);
33+
input.Populate<float>([](const xla::DimensionVector& index) {
34+
return 1.0f * index[1] + index[0];
35+
});
36+
37+
if (type != F32) {
38+
TF_ASSIGN_OR_RETURN(input, input.Convert(type));
39+
}
40+
return input;
3241
}
33-
34-
se::Platform* GetTestPlatform() {
35-
auto platform = se::MultiPlatformManager::PlatformWithName("Poplar");
36-
EXPECT_TRUE(platform.ok());
37-
38-
auto* p = dynamic_cast<xp::PoplarPlatform*>(platform.ValueOrDie());
39-
40-
xla::poplarplugin::IpuOptions options;
41-
options.set_creator_id(IpuOptionsCreator::IPU_UTILS);
42-
43-
EXPECT_EQ(p->ConfigurePoplarDevices(options), Status::OK());
44-
return p;
42+
Literal GetTestIndices() {
43+
auto indices_shape = ShapeUtil::MakeShape(S32, {4});
44+
Literal indices(indices_shape);
45+
indices.PopulateR1<int32>({0, 2, 4, 8});
46+
return indices;
4547
}
4648

47-
class MultiSliceUpdateConstantIndicesTest : public HloTestBase {
48-
public:
49-
MultiSliceUpdateConstantIndicesTest()
50-
: HloTestBase(GetTestPlatform(), GetReferencePlatform()) {}
51-
52-
static Literal GetTestInputs(bool updates = false) {
53-
unsigned rows = updates ? 4 : 10;
54-
auto input_shape = ShapeUtil::MakeShape(F32, {rows, 4});
55-
Literal input(input_shape);
56-
input.Populate<float>([](const xla::DimensionVector& index) {
57-
return 1.0f * index[1] + index[0];
58-
});
59-
return input;
60-
}
49+
StatusOr<Literal> GetScale(PrimitiveType type) {
50+
auto scale_shape = ShapeUtil::MakeShape(F32, {1});
51+
Literal scale(scale_shape);
52+
scale.PopulateR1<float>({2.0});
6153

62-
static Literal GetTestIndices() {
63-
auto indices_shape = ShapeUtil::MakeShape(S32, {4});
64-
Literal indices(indices_shape);
65-
indices.PopulateR1<int32>({0, 2, 4, 8});
66-
return indices;
54+
if (type != F32) {
55+
TF_ASSIGN_OR_RETURN(scale, scale.Convert(type));
6756
}
57+
return scale;
58+
}
6859

69-
static Literal GetScale() {
70-
auto scale_shape = ShapeUtil::MakeShape(F32, {1});
71-
Literal scale(scale_shape);
72-
scale.PopulateR1<float>({2.0});
73-
return scale;
74-
}
60+
Status VerifySlices(const Literal& result) {
61+
TF_ASSIGN_OR_RETURN(auto inputs, GetTestInputs(F32));
62+
auto indices = GetTestIndices();
7563

76-
static void VerifySlices(const Literal& result) {
77-
auto inputs = GetTestInputs();
78-
auto indices = GetTestIndices();
79-
80-
const auto slice_shape = ShapeUtil::GetSubshape(result.shape(), {0});
81-
ShapeUtil::ForEachIndex(
82-
slice_shape, [&](absl::Span<const int64> output_index) {
83-
EXPECT_EQ(output_index.size(), 2);
84-
auto value = result.Get<float>(output_index, {0});
85-
auto idx = indices.Get<int32>({output_index[0], 0});
86-
auto input_value = inputs.Get<float>({idx, output_index[1]});
87-
EXPECT_EQ(value, input_value);
88-
return true;
89-
});
90-
}
64+
ShapeUtil::ForEachIndex(
65+
result.shape(), [&](absl::Span<const int64> output_index) {
66+
auto value = result.Get<float>(output_index);
67+
auto idx = indices.Get<int32>({output_index[0], 0});
68+
auto input_value = inputs.Get<float>({idx, output_index[1]});
69+
EXPECT_EQ(value, input_value);
70+
return true;
71+
});
72+
return Status::OK();
73+
}
9174

92-
static void VerifyUpdates(const Literal& result) {
93-
auto inputs = GetTestInputs();
94-
auto updates = GetTestInputs(true);
95-
auto scale = GetScale().Get<float>({0});
96-
auto indices = GetTestIndices();
97-
auto indices_data = indices.data<int>();
98-
99-
const auto slice_shape = ShapeUtil::GetSubshape(result.shape(), {0});
100-
ShapeUtil::ForEachIndex(
101-
slice_shape, [&](absl::Span<const int64> output_index) {
102-
EXPECT_EQ(output_index.size(), 2);
103-
auto value = result.Get<float>(output_index, {0});
104-
for (size_t i = 0; i < indices_data.size(); i++) {
105-
auto idx = indices_data.at(i);
106-
if (output_index[0] == idx) {
107-
auto input_value = inputs.Get<float>({idx, output_index[1]});
108-
auto update_value = updates.Get<float>({i, output_index[1]});
109-
EXPECT_EQ(value, scale * update_value + input_value);
110-
break;
111-
}
75+
Status VerifyUpdates(const Literal& result) {
76+
TF_ASSIGN_OR_RETURN(auto inputs, GetTestInputs(F32));
77+
TF_ASSIGN_OR_RETURN(auto updates, GetTestInputs(F32, true));
78+
TF_ASSIGN_OR_RETURN(auto scale, GetScale(F32));
79+
auto indices = GetTestIndices();
80+
auto scale_data = scale.Get<float>({0});
81+
auto indices_data = indices.data<int>();
82+
83+
ShapeUtil::ForEachIndex(
84+
result.shape(), [&](absl::Span<const int64> output_index) {
85+
auto value = result.Get<float>(output_index);
86+
for (size_t i = 0; i < indices_data.size(); i++) {
87+
auto idx = indices_data.at(i);
88+
if (output_index[0] == idx) {
89+
auto input_value = inputs.Get<float>({idx, output_index[1]});
90+
auto update_value = updates.Get<float>({i, output_index[1]});
91+
EXPECT_EQ(value, scale_data * update_value + input_value);
92+
break;
11293
}
113-
return true;
114-
});
115-
}
116-
};
94+
}
95+
return true;
96+
});
97+
return Status::OK();
98+
}
99+
100+
using MultiSliceUpdateConstantIndicesTest = HloTestBase;
117101

118102
TEST_F(MultiSliceUpdateConstantIndicesTest, SliceNonConstantIndices) {
119103
std::string hlo_string = R"(
@@ -122,8 +106,7 @@ HloModule main
122106
ENTRY main {
123107
input = f32[10,4] parameter(0)
124108
indices = s32[4] parameter(1)
125-
slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
126-
ROOT t = (f32[4,4]) tuple(slices)
109+
ROOT slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
127110
}
128111
)";
129112

@@ -138,7 +121,7 @@ ENTRY main {
138121
EXPECT_TRUE(custom_ops_replaced);
139122

140123
// Input to be sliced and indices to slice at.
141-
auto inputs = GetTestInputs();
124+
TF_ASSERT_OK_AND_ASSIGN(auto inputs, GetTestInputs(F32));
142125
auto indices = GetTestIndices();
143126

144127
// Execute.
@@ -148,10 +131,10 @@ ENTRY main {
148131
std::move(
149132
ParseAndReturnVerifiedModule(hlo_string, config).ValueOrDie()),
150133
{&inputs, &indices}));
151-
ASSERT_TRUE(result.shape().IsTuple());
134+
ASSERT_TRUE(result.shape().IsArray());
152135

153136
// Verify output.
154-
VerifySlices(result);
137+
TF_ASSERT_OK(VerifySlices(result));
155138
}
156139

157140
TEST_F(MultiSliceUpdateConstantIndicesTest, SliceConstantIndices) {
@@ -161,8 +144,7 @@ HloModule main
161144
ENTRY main {
162145
input = f32[10,4] parameter(0)
163146
indices = s32[4] constant({0, 2, 4, 8})
164-
slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
165-
ROOT t = (f32[4,4]) tuple(slices)
147+
ROOT slices = f32[4,4] custom-call(input, indices), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
166148
}
167149
)";
168150

@@ -177,7 +159,7 @@ ENTRY main {
177159
EXPECT_TRUE(custom_ops_replaced);
178160

179161
// Input to be sliced.
180-
auto inputs = GetTestInputs();
162+
TF_ASSERT_OK_AND_ASSIGN(auto inputs, GetTestInputs(F32));
181163

182164
// Execute.
183165
TF_ASSERT_OK_AND_ASSIGN(
@@ -186,10 +168,10 @@ ENTRY main {
186168
std::move(
187169
ParseAndReturnVerifiedModule(hlo_string, config).ValueOrDie()),
188170
{&inputs}));
189-
ASSERT_TRUE(result.shape().IsTuple());
171+
ASSERT_TRUE(result.shape().IsArray());
190172

191173
// Verify output.
192-
VerifySlices(result);
174+
TF_ASSERT_OK(VerifySlices(result));
193175
}
194176

195177
TEST_F(MultiSliceUpdateConstantIndicesTest, UpdateNonConstantIndices) {
@@ -206,8 +188,7 @@ ENTRY main {
206188
big_zero = f32[10, 4] broadcast(zero), dimensions={}
207189
208190
update = f32[10, 4] custom-call(big_zero, indices, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"indices_are_sorted\":false}\n"
209-
sum = f32[10, 4] add(update, input)
210-
ROOT t = (f32[10, 4]) tuple(sum)
191+
ROOT sum = f32[10, 4] add(update, input)
211192
}
212193
)";
213194

@@ -222,10 +203,10 @@ ENTRY main {
222203
EXPECT_TRUE(custom_ops_replaced);
223204

224205
// Input to be updated, the updates and scale to apply, and the indices to update at.
225-
auto inputs = GetTestInputs();
226-
auto updates = GetTestInputs(true);
206+
TF_ASSERT_OK_AND_ASSIGN(auto inputs, GetTestInputs(F32));
207+
TF_ASSERT_OK_AND_ASSIGN(auto updates, GetTestInputs(F32, true));
208+
TF_ASSERT_OK_AND_ASSIGN(auto scale, GetScale(F32));
227209
auto indices = GetTestIndices();
228-
auto scale = GetScale();
229210

230211
// Execute.
231212
TF_ASSERT_OK_AND_ASSIGN(
@@ -234,58 +215,80 @@ ENTRY main {
234215
std::move(
235216
ParseAndReturnVerifiedModule(hlo_string, config).ValueOrDie()),
236217
{&inputs, &indices, &updates, &scale}));
237-
ASSERT_TRUE(result.shape().IsTuple());
218+
ASSERT_TRUE(result.shape().IsArray());
238219

239220
// Verify output.
240-
VerifyUpdates(result);
221+
TF_ASSERT_OK(VerifyUpdates(result));
241222
}
242223

243-
TEST_F(MultiSliceUpdateConstantIndicesTest, UpdateConstantIndices) {
244-
std::string hlo_string = R"(
224+
struct MultiUpdateAddTestSpec {
225+
PrimitiveType element_type;
226+
227+
std::string GetHlo() const {
228+
const std::string hlo_string = R"(
245229
HloModule main
246230
247231
ENTRY main {
248-
input = f32[10, 4] parameter(0)
249-
updates = f32[4, 4] parameter(1)
250-
scale = f32[] parameter(2)
232+
input = $element_type[10, 4] parameter(0)
233+
updates = $element_type[4, 4] parameter(1)
234+
scale = $element_type[] parameter(2)
251235
252236
indices = s32[4, 1] constant({{0}, {2}, {4}, {8}})
253237
254-
zero = f32[] constant(0)
255-
big_zero = f32[10, 4] broadcast(zero), dimensions={}
238+
zero = $element_type[] constant(0)
239+
big_zero = $element_type[10, 4] broadcast(zero), dimensions={}
256240
257-
update = f32[10, 4] custom-call(big_zero, indices, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"indices_are_sorted\":false}\n"
258-
sum = f32[10, 4] add(update, input)
259-
ROOT t = (f32[10, 4]) tuple(sum)
241+
update = $element_type[10, 4] custom-call(big_zero, indices, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"indices_are_sorted\":false}\n"
242+
ROOT sum = $element_type[10, 4] add(update, input)
260243
}
261244
)";
245+
return tensorflow::str_util::StringReplace(
246+
hlo_string, "$element_type",
247+
primitive_util::LowercasePrimitiveTypeName(element_type), true);
248+
}
249+
};
262250

263-
HloModuleConfig config;
264-
config.set_debug_options(GetDebugOptionsForTest());
251+
std::ostream& operator<<(std::ostream& os, const MultiUpdateAddTestSpec& spec) {
252+
return os << "{element_type: " << spec.element_type << "}";
253+
}
254+
255+
class MultiUpdateAddTest
256+
: public HloTestBase,
257+
public ::testing::WithParamInterface<MultiUpdateAddTestSpec> {};
265258

259+
INSTANTIATE_TEST_SUITE_P(
260+
MultiUpdateAddTestCases, MultiUpdateAddTest,
261+
::testing::ValuesIn(std::vector<MultiUpdateAddTestSpec>{{F32}, {F16}}));
262+
263+
TEST_P(MultiUpdateAddTest, DoTest) {
264+
auto param = GetParam();
266265
TF_ASSERT_OK_AND_ASSIGN(auto module,
267-
ParseAndReturnVerifiedModule(hlo_string, config));
266+
ParseAndReturnVerifiedModule(param.GetHlo()));
268267

269268
TF_ASSERT_OK_AND_ASSIGN(bool custom_ops_replaced,
270269
CustomOpReplacer().Run(module.get()));
271270
EXPECT_TRUE(custom_ops_replaced);
272271

273272
// Input to be updated, plus the updates and scale to apply (indices are constant in the HLO).
274-
auto inputs = GetTestInputs();
275-
auto updates = GetTestInputs(true);
276-
auto scale = GetScale();
273+
TF_ASSERT_OK_AND_ASSIGN(auto inputs, GetTestInputs(param.element_type));
274+
TF_ASSERT_OK_AND_ASSIGN(auto updates,
275+
GetTestInputs(param.element_type, /*updates=*/true));
276+
TF_ASSERT_OK_AND_ASSIGN(auto scale, GetScale(param.element_type));
277277

278278
// Execute.
279279
TF_ASSERT_OK_AND_ASSIGN(
280280
Literal result,
281281
Execute(
282-
std::move(
283-
ParseAndReturnVerifiedModule(hlo_string, config).ValueOrDie()),
282+
std::move(ParseAndReturnVerifiedModule(param.GetHlo()).ValueOrDie()),
284283
{&inputs, &updates, &scale}));
285-
ASSERT_TRUE(result.shape().IsTuple());
284+
ASSERT_TRUE(result.shape().IsArray());
285+
286+
if (param.element_type != F32) {
287+
TF_ASSERT_OK_AND_ASSIGN(result, result.Convert(F32));
288+
}
286289

287290
// Verify output.
288-
VerifyUpdates(result);
291+
TF_ASSERT_OK(VerifyUpdates(result));
289292
}
290293

291294
} // namespace

0 commit comments

Comments
 (0)