Skip to content

Commit d21d1bf

Browse files
committed
Change inputs of f8_convert ops from a tuple to simple inputs.
Summary: Using tuples was breaking pipelining as the ops were being assigned the wrong sharding. Removing tuple inputs removes this issue and makes the code more consistent with matmul and conv f8 ops. Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee Maniphest Tasks: T71684 Differential Revision: https://phabricator.sourcevertex.net/D78466
1 parent b5f4972 commit d21d1bf

File tree

7 files changed

+54
-97
lines changed

7 files changed

+54
-97
lines changed

tensorflow/compiler/plugin/poplar/driver/ops/custom_ops/popops/f8_convert.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ class ConvertToF8Op : public PoplarOpDef {
3535
const poplar::DebugContext& debug_context) override {
3636
PoplarOpDefDebugInfo debug_info(debug_context, "Fp8Convert");
3737
DriverProgramSequence seq(debug_info);
38-
auto inputs = FindInstructionInputs(tensor_map, res, inst, 0, seq,
39-
debug_info, /*expand_aliasing=*/true);
40-
CHECK_EQ(inputs.size(), 2);
41-
DriverTensor input = inputs[0].AsTensor();
42-
DriverTensor input_metadata = inputs[1].AsTensor();
38+
TF_ASSIGN_OR_RETURN(
39+
auto input, FindInstructionInput(tensor_map, res, inst, 0, seq,
40+
debug_info, /*expand_aliasing=*/true));
41+
TF_ASSIGN_OR_RETURN(
42+
auto input_metadata,
43+
FindInstructionInput(tensor_map, res, inst, 1, seq, debug_info,
44+
/*expand_aliasing=*/true));
4345
// We can't reinterpret to either the QUARTER_METADATA or the QUARTER type.
4446
// Instead, clone them and copy raw unsigned char data over.
4547
// This copy will be elided by poplar.

tensorflow/compiler/plugin/poplar/driver/tensor.cc

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,38 +1419,33 @@ StatusOr<DriverTensor> FindF8InstructionInput(
14191419
const poplar::DebugNameAndId& debug_name_and_id, bool expand_aliasing) {
14201420
const HloInstruction* operand = inst->operand(input);
14211421

1422-
TensorOrRemoteBufferVector inputs = GetTensorsMaybeExpand(
1423-
map, res, operand, seq, expand_aliasing, debug_name_and_id, 0, 2);
1424-
1425-
if (inputs.size() == 0) {
1426-
return tensorflow::errors::Unknown(
1427-
StrCat("[Poplar] Couldn't find input ", input, " for ", inst->name()));
1428-
}
1422+
TF_ASSIGN_OR_RETURN(
1423+
auto u8_data,
1424+
FindInstructionInput(map, res, inst, 0, seq, debug_name_and_id,
1425+
/*expand_aliasing=*/true));
1426+
// return u8_data;  // NOTE(review): leftover debug short-circuit — should be removed before landing
1427+
TF_ASSIGN_OR_RETURN(
1428+
auto u8_metadata,
1429+
FindInstructionInput(map, res, inst, 1, seq, debug_name_and_id,
1430+
/*expand_aliasing=*/true));
14291431

1430-
CHECK_EQ(inputs.size(), 2);
14311432
auto& graph =
14321433
GetGraphWithOutputIndex(res, operand, /*flattened_output_tuple_index=*/0);
1433-
CHECK(&graph == &GetGraphWithOutputIndex(res, operand,
1434-
/*flattened_output_tuple_index=*/1));
1435-
poplar::Graph& poplar_graph = graph;
1436-
14371434
// We can't reinterpret to either the QUARTER_METADATA or the QUARTER type.
14381435
// Instead, clone them and copy raw unsigned char data over.
14391436
// Those copies will be elided by poplar.
14401437

1441-
DriverTensor u8_data = inputs[0].AsTensor();
1442-
DriverTensor u8_metadata = inputs[1].AsTensor();
1443-
auto f8_metadata = poplar_graph.clone(
1438+
auto f8_metadata = graph.clone(
14441439
poplar::QUARTER_METADATA, u8_metadata.reshape({1}), debug_name_and_id,
14451440
poplar::TensorCloneMethod::PRESERVE_ORDER_AND_ALIASES);
1446-
auto f8_data = poplar_graph.clone(
1447-
poplar::QUARTER, f8_metadata, u8_data, debug_name_and_id,
1448-
poplar::TensorCloneMethod::PRESERVE_ORDER_AND_ALIASES);
1441+
auto f8_data =
1442+
graph.clone(poplar::QUARTER, f8_metadata, u8_data, debug_name_and_id,
1443+
poplar::TensorCloneMethod::PRESERVE_ORDER_AND_ALIASES);
14491444
seq.add(poplar::program::Copy(
14501445
u8_metadata, f8_metadata.reinterpret(poplar::UNSIGNED_CHAR)));
14511446
seq.add(poplar::program::Copy(u8_data,
14521447
f8_data.reinterpret(poplar::UNSIGNED_CHAR)));
1453-
return DriverTensor(f8_data);
1448+
return f8_data;
14541449
}
14551450

14561451
TensorOrRemoteBufferVector FindInstructionInputs(

tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/f8_convert.cc

Lines changed: 13 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,83 +23,46 @@ namespace xla {
2323
namespace poplarplugin {
2424

2525
HloConvertFromF8Instruction::HloConvertFromF8Instruction(
26-
const Shape& shape, HloInstruction* operand)
27-
: HloF8ConvertInstruction(shape, operand) {
28-
CHECK_EQ(shape, GetShape(operand));
29-
}
30-
31-
Shape HloConvertFromF8Instruction::GetShape(const HloInstruction* operand) {
32-
// Result shape is f16[<input-dimensions>].
33-
const Shape& op_shape = operand->shape();
34-
CHECK(op_shape.IsTuple());
35-
36-
// Expect data to be in U8.
37-
const Shape& input_shape = op_shape.tuple_shapes(0);
38-
CHECK(input_shape.element_type() == U8);
39-
40-
const Shape& metadata_shape = op_shape.tuple_shapes(1);
41-
CHECK(metadata_shape.element_type() == U8);
42-
CHECK(ShapeUtil::IsScalar(metadata_shape));
43-
44-
// The only supported type now is F16.
45-
return ShapeUtil::MakeShape(F16, input_shape.dimensions());
46-
}
26+
const Shape& shape, HloInstruction* data, HloInstruction* metadata)
27+
: HloF8ConvertInstruction(shape, {data, metadata}) {}
4728

4829
std::unique_ptr<HloInstruction>
4930
HloConvertFromF8Instruction::CloneWithNewOperandsImpl(
5031
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
5132
HloCloneContext*) const {
52-
CHECK_EQ(new_operands.size(), 1);
33+
CHECK_EQ(new_operands.size(), 2);
5334
return std::unique_ptr<HloInstruction>(
54-
new HloConvertFromF8Instruction(shape, new_operands[0]));
35+
new HloConvertFromF8Instruction(shape, new_operands[0], new_operands[1]));
5536
}
5637

5738
HloConvertToF8Instruction::HloConvertToF8Instruction(const Shape& shape,
58-
HloInstruction* operand)
59-
: HloF8ConvertInstruction(shape, operand) {
60-
CHECK(ShapeUtil::Compatible(shape, GetShape(operand)));
61-
}
62-
63-
Shape HloConvertToF8Instruction::GetShape(const HloInstruction* operand) {
64-
// Result shape is (u8[<input-dimensions>], u8 metadata).
65-
const Shape& op_shape = operand->shape();
66-
CHECK(op_shape.IsTuple());
67-
68-
// The only supported type now is F16.
69-
const Shape& input_shape = op_shape.tuple_shapes(0);
70-
CHECK(input_shape.element_type() == F16);
71-
72-
const Shape& metadata_shape = op_shape.tuple_shapes(1);
73-
CHECK(metadata_shape.element_type() == U8);
74-
CHECK(ShapeUtil::IsScalar(metadata_shape));
75-
76-
return ShapeUtil::MakeTupleShape(
77-
{ShapeUtil::MakeShape(U8, input_shape.dimensions()), metadata_shape});
78-
}
39+
HloInstruction* data,
40+
HloInstruction* metadata)
41+
: HloF8ConvertInstruction(shape, {data, metadata}) {}
7942

8043
std::unique_ptr<HloInstruction>
8144
HloConvertToF8Instruction::CloneWithNewOperandsImpl(
8245
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
8346
HloCloneContext*) const {
84-
CHECK_EQ(new_operands.size(), 1);
47+
CHECK_EQ(new_operands.size(), 2);
8548
return std::unique_ptr<HloInstruction>(
86-
new HloConvertToF8Instruction(shape, new_operands[0]));
49+
new HloConvertToF8Instruction(shape, new_operands[0], new_operands[1]));
8750
}
8851

8952
namespace {
9053
StatusOr<std::unique_ptr<HloInstruction>>
9154
HloConvertFromF8InstructionFactoryFunc(HloCustomCallInstruction* call) {
92-
return std::unique_ptr<HloInstruction>(
93-
new HloConvertFromF8Instruction(call->shape(), call->mutable_operand(0)));
55+
return std::unique_ptr<HloInstruction>(new HloConvertFromF8Instruction(
56+
call->shape(), call->mutable_operand(0), call->mutable_operand(1)));
9457
}
9558

9659
static HloPoplarInstructionFactory fp8_convert_from_factory(
9760
PoplarOp::ConvertFromF8, HloConvertFromF8InstructionFactoryFunc);
9861

9962
StatusOr<std::unique_ptr<HloInstruction>> HloConvertToF8InstructionFactoryFunc(
10063
HloCustomCallInstruction* call) {
101-
return std::unique_ptr<HloInstruction>(
102-
new HloConvertToF8Instruction(call->shape(), call->mutable_operand(0)));
64+
return std::unique_ptr<HloInstruction>(new HloConvertToF8Instruction(
65+
call->shape(), call->mutable_operand(0), call->mutable_operand(1)));
10366
}
10467

10568
static HloPoplarInstructionFactory fp8_convert_to_factory(

tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/f8_convert.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ namespace poplarplugin {
3131
template <PoplarOp Op>
3232
class HloF8ConvertInstruction : public HloPoplarInstruction {
3333
public:
34-
HloF8ConvertInstruction(const Shape& shape, HloInstruction* operand)
35-
: HloPoplarInstruction(shape, {operand}, Op) {}
34+
HloF8ConvertInstruction(const Shape& shape,
35+
absl::Span<HloInstruction* const> operands)
36+
: HloPoplarInstruction(shape, operands, Op) {}
3637

3738
absl::flat_hash_set<int64_t> AllocatingIndices() const override { return {}; }
3839
bool AllocatingOutput() const override { return false; }
@@ -67,13 +68,10 @@ class HloF8ConvertInstruction : public HloPoplarInstruction {
6768
class HloConvertFromF8Instruction
6869
: public HloF8ConvertInstruction<PoplarOp::ConvertFromF8> {
6970
public:
70-
HloConvertFromF8Instruction(const Shape& shape, HloInstruction* operand);
71-
explicit HloConvertFromF8Instruction(HloInstruction* operand)
72-
: HloConvertFromF8Instruction(GetShape(operand), operand) {}
71+
HloConvertFromF8Instruction(const Shape& shape, HloInstruction* data,
72+
HloInstruction* metadata);
7373

7474
private:
75-
Shape GetShape(const HloInstruction* operand);
76-
7775
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
7876
const Shape& shape, absl::Span<HloInstruction* const>,
7977
HloCloneContext*) const override;
@@ -84,13 +82,10 @@ std::unique_ptr<HloInstruction> CreateConvertToF8Instruction(
8482
class HloConvertToF8Instruction
8583
: public HloF8ConvertInstruction<PoplarOp::ConvertToF8> {
8684
public:
87-
HloConvertToF8Instruction(const Shape& shape, HloInstruction* operand);
88-
explicit HloConvertToF8Instruction(HloInstruction* operand)
89-
: HloConvertToF8Instruction(GetShape(operand), operand) {}
85+
HloConvertToF8Instruction(const Shape& shape, HloInstruction* data,
86+
HloInstruction* metadata);
9087

9188
private:
92-
Shape GetShape(const HloInstruction* operand);
93-
9489
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
9590
const Shape& shape, absl::Span<HloInstruction* const>,
9691
HloCloneContext*) const override;

tensorflow/compiler/plugin/poplar/kernels/popops/f8_convert.cc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,9 @@ class IpuConvertToF8Op : public XlaOpKernel, public IpuOpKernel {
120120
xla::Shape output_shape = xla::ShapeUtil::MakeTupleShape(
121121
{output_data_shape, output_metadata_shape});
122122

123-
auto packed_input = xla::Tuple(b, {input_data, input_metadata});
124-
auto output_tuple =
125-
xla::CustomCall(b, PoplarOp_Name(PoplarOp::ConvertToF8), {packed_input},
126-
{output_shape}, attribute_map_.Serialise());
123+
auto output_tuple = xla::CustomCall(
124+
b, PoplarOp_Name(PoplarOp::ConvertToF8), {input_data, input_metadata},
125+
output_shape, attribute_map_.Serialise());
127126

128127
ctx->SetOutput(0, xla::GetTupleElement(output_tuple, 0));
129128
ctx->SetOutput(1, xla::GetTupleElement(output_tuple, 1));
@@ -155,10 +154,9 @@ class IpuConvertFromF8Op : public XlaOpKernel, public IpuOpKernel {
155154
xla::Shape output_shape;
156155
OP_REQUIRES_OK(
157156
ctx, TensorShapeToXLAShape(DT_HALF, ctx->InputShape(0), &output_shape));
158-
xla::XlaOp input = xla::Tuple(b, {input_data, input_metadata});
159-
auto output =
160-
xla::CustomCall(b, PoplarOp_Name(PoplarOp::ConvertFromF8), {input},
161-
{output_shape}, attribute_map_.Serialise());
157+
auto output = xla::CustomCall(b, PoplarOp_Name(PoplarOp::ConvertFromF8),
158+
{input_data, input_metadata}, output_shape,
159+
attribute_map_.Serialise());
162160

163161
ctx->SetOutput(0, output);
164162
}

tensorflow/compiler/plugin/poplar/kernels/popops/fp8_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class PoplinF8ConvOp : public XlaOpKernel, IpuOpKernel {
143143

144144
OP_REQUIRES(
145145
ctx, op_type != PoplarOp::Unknown,
146-
xla::InvalidArgument("Unsupported F8 Convolution Dimension ", D));
146+
xla::InvalidArgument("Unsupported F8 Convolution Dimension %d", D));
147147

148148
auto call_output =
149149
xla::CustomCall(ctx->builder(), PoplarOp_Name(op_type), args, out_shape,

tensorflow/compiler/plugin/poplar/tests/f8_test.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ TEST_F(Fp8Test, TestConvert) {
3535
3636
ENTRY main {
3737
input = (f16[2,2], u8[]) parameter(0)
38-
input.fp8 = (u8[2,2], u8[]) custom-call(input), custom_call_target="ConvertToF8"
39-
input.fp = f16[2,2] custom-call(input.fp8), custom_call_target="ConvertFromF8"
38+
input.1 = f16[2,2] get-tuple-element(input), index=0
39+
input.2 = u8[] get-tuple-element(input), index=1
40+
input.fp8 = (u8[2,2], u8[]) custom-call(input.1, input.2), custom_call_target="ConvertToF8"
41+
input.fp8.1 = u8[2,2] get-tuple-element(input.fp8), index=0
42+
input.fp8.2 = u8[] get-tuple-element(input.fp8), index=1
43+
input.fp = f16[2,2] custom-call(input.fp8.1, input.fp8.2), custom_call_target="ConvertFromF8"
4044
ROOT root = ((f16[2,2], u8[]), (u8[2,2], u8[]), f16[2,2]) tuple(input, input.fp8, input.fp)
4145
}
4246
)";

0 commit comments

Comments
 (0)