Skip to content

Commit d1c744b

Browse files
committed
Support static indices in multiSlice and multiUpdateAdd.
Summary: Automatically call into the static indices overloads when constant indices are found. Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep, vladimirm, jakeh Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep, jakeh Subscribers: jakeh Maniphest Tasks: T46631 Differential Revision: https://phabricator.sourcevertex.net/D55351
1 parent 11e7d64 commit d1c744b

File tree

3 files changed

+390
-7
lines changed

3 files changed

+390
-7
lines changed

tensorflow/compiler/plugin/poplar/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2431,6 +2431,20 @@ xla_test(
24312431
],
24322432
)
24332433

2434+
# Test target built from tests/multi_slice_update_constant_indices_test.cc.
# It is restricted to the "poplar" backend and compiled with -fexceptions.
# NOTE(review): presumably covers the constant-indices paths of multiSlice /
# multiUpdateAdd added in multi_slice.cc by the same commit - confirm against
# the test source.
xla_test(
2435+
name = "multi_slice_update_constant_indices_test",
2436+
size = "medium",
2437+
srcs = ["tests/multi_slice_update_constant_indices_test.cc"],
2438+
backends = ["poplar"],
2439+
copts = ["-fexceptions"],
2440+
deps = [
2441+
":test_utils",
2442+
"//tensorflow/compiler/xla:test",
2443+
"//tensorflow/compiler/xla/tests:hlo_test_base",
2444+
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
2445+
],
2446+
)
2447+
24342448
xla_test(
24352449
name = "dynamic_slice_test",
24362450
size = "medium",
@@ -5020,6 +5034,7 @@ test_suite(
50205034
"multi_ipu_test",
50215035
"multi_run_test",
50225036
"multi_slice_combiner_test",
5037+
"multi_slice_update_constant_indices_test",
50235038
"multi_update_apply_test",
50245039
"multi_update_combiner_test",
50255040
"multi_update_scale_apply_test",

tensorflow/compiler/plugin/poplar/driver/ops/custom_ops/popops/multi_slice.cc

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121

2222
#include "tensorflow/compiler/plugin/poplar/driver/tensor.h"
2323
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
24+
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
2425
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
2526
#include "tensorflow/core/lib/core/errors.h"
2627

@@ -34,6 +35,63 @@ namespace pgf = poputil::graphfn;
3435
namespace xla {
3536
namespace poplarplugin {
3637
namespace {
38+
// Tries to resolve `indices` to a list of index values known at compile time.
//
// Returns the flattened index values when `indices` is either an HLO constant
// or can be folded to a literal by the HloEvaluator, and that literal is
// effectively one-dimensional (rank 0, rank 1, or a higher rank in which the
// largest dimension accounts for every element). Returns absl::nullopt in all
// other cases so that callers can fall back to the dynamic-indices code path.
absl::optional<std::vector<unsigned>> GetConstantIndices(
    HloInstruction* const indices) {
  // Reshapes `val` to rank-1 when this can be done unambiguously; yields
  // absl::nullopt otherwise.
  auto get_val = [](const Literal& val) -> absl::optional<Literal> {
    if (val.shape().rank() == 1) {
      return absl::optional<Literal>(val.Clone());
    }

    if (val.shape().rank() == 0) {
      // A scalar holds exactly one element, so this reshape preserves the
      // element count.
      return absl::optional<Literal>(val.Reshape({1}).ValueOrDie());
    }

    const auto dims = val.shape().dimensions();
    const auto max_dim = absl::c_max_element(dims);
    if (max_dim == dims.end()) {
      return absl::nullopt;
    }

    // The accumulator must be seeded with an int64: seeding it with a plain
    // `1` makes the accumulator an `int` and narrows the product on every
    // step.
    const int64 dim_product =
        absl::c_accumulate(dims, int64{1}, std::multiplies<int64>());

    // The literal can only be viewed as a vector if the largest dimension
    // equals the total number of elements.
    if (*max_dim != dim_product) {
      return absl::nullopt;
    }

    // The element count is preserved by construction (see the check above).
    return absl::optional<Literal>(val.Reshape({*max_dim}).ValueOrDie());
  };

  // Flattens `literal` and converts it to a vector of unsigned indices.
  // A failed conversion is reported as "not constant" rather than aborting:
  // using static indices is only an optimisation, so the dynamic path remains
  // a valid fallback.
  auto to_indices =
      [&get_val](
          const Literal& literal) -> absl::optional<std::vector<unsigned>> {
    const auto flat = get_val(literal);
    if (!flat.has_value()) {
      return absl::nullopt;
    }

    auto converted = LiteralVectorToNativeType<unsigned>(flat.value());
    if (!converted.ok()) {
      return absl::nullopt;
    }

    return absl::optional<std::vector<unsigned>>(converted.ValueOrDie());
  };

  if (indices->IsConstant()) {
    return to_indices(indices->literal());
  }

  // Not a literal constant: try to constant-fold a clone of the instruction.
  // Loop evaluation is disabled (max_loop_iterations = 0).
  auto cloned_indices = indices->Clone();
  Literal result;
  HloEvaluator evaluator(/*max_loop_iterations=*/0);

  if (!evaluator.TryEvaluate(cloned_indices.get(), &result)) {
    return absl::nullopt;
  }

  return to_indices(result);
}
3795

3896
StatusOr<poplar::Tensor> CreateInputTensor(
3997
poplar::Graph& graph, const popops::SlicePlan& plan,
@@ -105,10 +163,19 @@ class MultiSliceOp : public PoplarOpDef {
105163
TF_ASSIGN_OR_RETURN(poplar::OptionFlags opts,
106164
GetSliceOptionsForInst(inst, res));
107165

108-
poplar::Tensor output = popops::multiSlice(
109-
graph, input,
110-
indices.flatten().expand({1}).reinterpret(poplar::UNSIGNED_INT), {0},
111-
{1}, seq, *plan, opts, {debug_info, "output"});
166+
const auto constant_indices = GetConstantIndices(inst->operands().at(1));
167+
168+
poplar::Tensor output;
169+
if (!constant_indices.has_value()) {
170+
output = popops::multiSlice(
171+
graph, input,
172+
indices.flatten().expand({1}).reinterpret(poplar::UNSIGNED_INT), {0},
173+
{1}, seq, *plan, opts, {debug_info, "output"});
174+
} else {
175+
output = popops::multiSlice(graph, input, constant_indices.value(), {0},
176+
seq, {debug_info, "output"});
177+
}
178+
112179
auto poplar_output_shape = PoplarShapeFromXlaShape(output_shape);
113180

114181
// Unflatten the output:
@@ -177,9 +244,17 @@ Status MultiUpdateInternal(
177244
popops::multiUpdate(graph, operand, expanded_updates, unsigned_indices, {0},
178245
{1}, prog, plan, opts, {debug_name_and_id});
179246
} else {
180-
popops::multiUpdateAdd(graph, operand, expanded_updates, unsigned_indices,
181-
*scale, {0}, {1}, prog, plan, opts,
182-
{debug_name_and_id});
247+
const auto constant_indices = GetConstantIndices(inst->operands().at(1));
248+
249+
if (constant_indices.has_value()) {
250+
popops::multiUpdateAdd(graph, operand, expanded_updates,
251+
constant_indices.value(), *scale, 0, prog,
252+
{debug_name_and_id});
253+
} else {
254+
popops::multiUpdateAdd(graph, operand, expanded_updates, unsigned_indices,
255+
*scale, {0}, {1}, prog, plan, opts,
256+
{debug_name_and_id});
257+
}
183258
}
184259

185260
return Status::OK();

0 commit comments

Comments
 (0)