@@ -21,6 +21,7 @@ limitations under the License.
2121
2222#include " tensorflow/compiler/plugin/poplar/driver/tensor.h"
2323#include " tensorflow/compiler/xla/service/hlo_casting_utils.h"
24+ #include " tensorflow/compiler/xla/service/hlo_evaluator.h"
2425#include " tensorflow/compiler/xla/service/hlo_instruction.h"
2526#include " tensorflow/core/lib/core/errors.h"
2627
@@ -34,6 +35,63 @@ namespace pgf = poputil::graphfn;
3435namespace xla {
3536namespace poplarplugin {
3637namespace {
38+ absl::optional<std::vector<unsigned >> GetConstantIndices (
39+ HloInstruction* const indices) {
40+ // For a given input Literal, this lambda reshapes the literal to
41+ // be of rank-1, if possible.
42+ auto get_val = [](const Literal& val) -> absl::optional<Literal> {
43+ if (val.shape ().rank () == 1 ) {
44+ return absl::optional<Literal>(val.Clone ());
45+ }
46+
47+ if (val.shape ().rank () == 0 ) {
48+ return absl::optional<Literal>(val.Reshape ({1 }).ValueOrDie ());
49+ }
50+
51+ // If the product of the dims is equal to the number of elements,
52+ // then the Literal can be reshaped to be rank-1.
53+ const auto dims = val.shape ().dimensions ();
54+ const auto max_dim = absl::c_max_element (dims);
55+ if (max_dim == dims.end ()) {
56+ return absl::nullopt ;
57+ }
58+
59+ const auto dim_product =
60+ absl::c_accumulate (dims, 1 , std::multiplies<int64>());
61+
62+ if (*max_dim != dim_product) {
63+ return absl::nullopt ;
64+ }
65+
66+ return absl::optional<Literal>(val.Reshape ({*max_dim}).ValueOrDie ());
67+ };
68+
69+ if (indices->IsConstant ()) {
70+ auto val_opt = get_val (indices->literal ());
71+ if (!val_opt.has_value ()) {
72+ return absl::nullopt ;
73+ }
74+
75+ auto value = LiteralVectorToNativeType<unsigned >(val_opt.value ());
76+ return absl::optional<std::vector<unsigned >>(value.ValueOrDie ());
77+ }
78+
79+ auto cloned_indices = indices->Clone ();
80+ Literal result;
81+ HloEvaluator evaluator (/* max_loop_iterations=*/ 0 );
82+
83+ if (!evaluator.TryEvaluate (cloned_indices.get (), &result)) {
84+ return absl::nullopt ;
85+ }
86+
87+ auto val_opt = get_val (result);
88+ if (!val_opt.has_value ()) {
89+ return absl::nullopt ;
90+ }
91+
92+ auto value = LiteralVectorToNativeType<unsigned >(val_opt.value ());
93+ return absl::optional<std::vector<unsigned >>(value.ValueOrDie ());
94+ }
3795
3896StatusOr<poplar::Tensor> CreateInputTensor (
3997 poplar::Graph& graph, const popops::SlicePlan& plan,
@@ -105,10 +163,19 @@ class MultiSliceOp : public PoplarOpDef {
105163 TF_ASSIGN_OR_RETURN (poplar::OptionFlags opts,
106164 GetSliceOptionsForInst (inst, res));
107165
108- poplar::Tensor output = popops::multiSlice (
109- graph, input,
110- indices.flatten ().expand ({1 }).reinterpret (poplar::UNSIGNED_INT), {0 },
111- {1 }, seq, *plan, opts, {debug_info, " output" });
166+ const auto constant_indices = GetConstantIndices (inst->operands ().at (1 ));
167+
168+ poplar::Tensor output;
169+ if (!constant_indices.has_value ()) {
170+ output = popops::multiSlice (
171+ graph, input,
172+ indices.flatten ().expand ({1 }).reinterpret (poplar::UNSIGNED_INT), {0 },
173+ {1 }, seq, *plan, opts, {debug_info, " output" });
174+ } else {
175+ output = popops::multiSlice (graph, input, constant_indices.value (), {0 },
176+ seq, {debug_info, " output" });
177+ }
178+
112179 auto poplar_output_shape = PoplarShapeFromXlaShape (output_shape);
113180
114181 // Unflatten the output:
@@ -177,9 +244,17 @@ Status MultiUpdateInternal(
177244 popops::multiUpdate (graph, operand, expanded_updates, unsigned_indices, {0 },
178245 {1 }, prog, plan, opts, {debug_name_and_id});
179246 } else {
180- popops::multiUpdateAdd (graph, operand, expanded_updates, unsigned_indices,
181- *scale, {0 }, {1 }, prog, plan, opts,
182- {debug_name_and_id});
247+ const auto constant_indices = GetConstantIndices (inst->operands ().at (1 ));
248+
249+ if (constant_indices.has_value ()) {
250+ popops::multiUpdateAdd (graph, operand, expanded_updates,
251+ constant_indices.value (), *scale, 0 , prog,
252+ {debug_name_and_id});
253+ } else {
254+ popops::multiUpdateAdd (graph, operand, expanded_updates, unsigned_indices,
255+ *scale, {0 }, {1 }, prog, plan, opts,
256+ {debug_name_and_id});
257+ }
183258 }
184259
185260 return Status::OK ();
0 commit comments