
Commit 353c4f3

jakeh-gc authored and georgepaw committed
Use Poplar host function.
Summary:
What's changed:
- Use Poplar Host Functions to tell Poplar when we have a stream copy with a host data dependency.
- Allows `opt_flags.set("streamCallbacks.maxLookahead", "unlimited");` to be safely set by default, improving throughput.
- There's a slight change to the UserOp API.
- In testing, it remains ABI compatible with any old ones, if any even exist.

Resolves T50668

Test Plan: CI tests - particularly user ops and host embedding tests.

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, babakk

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, babakk

Subscribers: jorgec, babakk

Maniphest Tasks: T50668

Differential Revision: https://phabricator.sourcevertex.net/D56667
1 parent 9890ce4 commit 353c4f3
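For illustration only, a minimal sketch of the engine-option setting mentioned in the summary, assuming a typical Poplar setup in which `graph`, `program` and `device` already exist (none of this code is part of the commit):

#include <poplar/Device.hpp>
#include <poplar/Engine.hpp>
#include <poplar/Graph.hpp>
#include <poplar/OptionFlags.hpp>
#include <poplar/Program.hpp>

// Sketch: run a program with unlimited stream-callback lookahead. With host
// functions describing the host data dependencies, this option can be enabled
// by default without risking mis-ordered stream callbacks.
void RunWithUnlimitedLookahead(poplar::Graph& graph,
                               const poplar::program::Program& program,
                               poplar::Device& device) {
  poplar::OptionFlags opt_flags;
  opt_flags.set("streamCallbacks.maxLookahead", "unlimited");

  poplar::Engine engine(graph, program, opt_flags);
  engine.load(device);
  engine.run(0);
}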

File tree

8 files changed: +114, -276 lines


tensorflow/compiler/plugin/poplar/docs/custom_codelet.rst

Lines changed: 2 additions & 2 deletions
@@ -617,9 +617,9 @@ The signature of the callback function is:
 
   extern "C"
   void Callback(
-      const std::vector<void*>& data,
+      const std::vector<const void*>& data,
       const std::vector<std::uint32_t>& number_of_elements,
-      std::vector<void*>& outputs,
+      const std::vector<void*>& outputs,
       const std::string& attributes,
       const std::string& name);

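For reference, a hedged sketch of a user-op callback written against the updated signature above; the float element type and the copy it performs are illustrative assumptions:

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical host read/write callback: copies the first input buffer to the
// first output buffer, assuming both hold float data of the same length.
extern "C" void Callback(const std::vector<const void*>& data,
                         const std::vector<std::uint32_t>& number_of_elements,
                         const std::vector<void*>& outputs,
                         const std::string& attributes,
                         const std::string& name) {
  (void)attributes;  // unused in this sketch
  (void)name;        // unused in this sketch
  const float* in = static_cast<const float*>(data[0]);
  float* out = static_cast<float*>(outputs[0]);
  std::memcpy(out, in, number_of_elements[0] * sizeof(float));
}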
tensorflow/compiler/plugin/poplar/driver/compiler_annotations.h

Lines changed: 13 additions & 20 deletions
@@ -18,6 +18,8 @@ limitations under the License.
 
 #include <map>
 #include <set>
+#include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -198,28 +200,18 @@ struct StreamCopyInfo {
 // assosiated with that instruction.
 using StreamInfos = std::unordered_map<std::string, std::list<StreamCopyInfo>>;
 
-// Stream meta info contains the information relating to the setup of the output
-// streams. We need to know how many outputs there are and how much data to
-// allocate in each buffer.
-struct StreamCopyMetaInfo {
-  StreamCopyMetaInfo() {}
-  StreamCopyMetaInfo(const HloInstruction* inst, std::uint32_t input_count)
-      : parent_instruction(inst), num_inputs(input_count) {}
+struct HostFunctionInfo {
+  using FunctionType = std::function<void(const std::vector<const void*>& input,
+                                          const std::vector<void*>& outputs)>;
 
-  // The instruction the user op came from. We use this as a unique identifier
-  // for the inputs/outputs so we can sort the input/outputs by operation.
   const HloInstruction* parent_instruction;
-
-  // Track all of the output streams, we do this so we can allocate them in
-  // advance.
-  std::list<StreamCopyInfo*> output_stream_info;
-
-  // The number of inputs this operation has.
-  std::uint32_t num_inputs;
+  std::string handle;
+  std::vector<Shape> input_shapes;
+  std::vector<Shape> output_shapes;
+  FunctionType function;
 };
 
-// We track one metainfo struct for each stream copy which the user has added.
-using StreamMetaInfos = std::unordered_map<std::string, StreamCopyMetaInfo>;
+using HostFunctionInfos = std::unordered_map<std::string, HostFunctionInfo>;
 
 // This structure contains all information which we generate that pertains
 // to the XLA graph, as opposed to the poplar lowering of that graph.
@@ -236,8 +228,6 @@ struct CompilerAnnotations {
 
   StreamInfos stream_infos;
 
-  StreamMetaInfos stream_meta_infos;
-
   SendRecvInfos send_infos;
   SendRecvInfos recv_infos;
 
@@ -265,6 +255,9 @@ struct CompilerAnnotations {
   OutputInfos entry_output_infos;
   // Feed output descriptions.
   OutputInfos feed_output_infos;
+
+  // Host function information.
+  HostFunctionInfos host_function_infos;
 };
 
 inline Status AddInfeedInfo(CompilerAnnotations& compiler_annotations,

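As a hedged sketch of how the new struct is intended to be populated (it relies on the `HostFunctionInfo`/`HostFunctionInfos` declarations above; `inst`, `handle` and the two shapes are assumed to come from the surrounding lowering code):

// Describe one host function: the instruction it came from, the handle it is
// registered under, its input/output shapes, and the host-side callable.
HostFunctionInfo info;
info.parent_instruction = inst;
info.handle = handle;
info.input_shapes = {input_shape};
info.output_shapes = {output_shape};
info.function = [](const std::vector<const void*>& inputs,
                   const std::vector<void*>& outputs) {
  // Host-side work goes here: read from inputs[i], write into outputs[j].
};

// The annotations keep one entry per handle.
HostFunctionInfos host_function_infos;
host_function_infos[info.handle] = info;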
tensorflow/compiler/plugin/poplar/driver/ops/custom_ops/host_embedding.cc

Lines changed: 14 additions & 29 deletions
@@ -130,25 +130,13 @@ class HostEmbeddingLookupOp : public PoplarOpDef {
         {inst->name(), inst->EmbeddingId(), inst->operand(0)->shape(),
          output_shape, inst->SplittingStrategy()});
 
-    auto index_buffer = graph.addDeviceToHostFIFO(
-        inst->name() + inst->EmbeddingId() + "_indices", indices.elementType(),
-        indices.numElements());
+    auto lookup_fn = graph.addHostFunction(
+        /* handle = */ inst->name() + inst->EmbeddingId(),
+        /* inputs = */ {{indices.elementType(), indices.numElements()}},
+        /* outputs = */ {{output.elementType(), output.numElements()}});
 
-    auto activation_fifo = graph.addHostToDeviceFIFO(
-        inst->name() + inst->EmbeddingId() + "_activations",
-        output.elementType(), output.numElements());
-
-    // Send the indices to the host.
-    seq.add(poplar::program::Copy(indices, index_buffer, false,
-                                  {debug_name_and_id}));
-
-    // Sync to avoid any stream merging due to host-side data dependecy.
-    seq.add(
-        poplar::program::Sync(poplar::SyncType::INTERNAL, {debug_name_and_id}));
-
-    // Read the values from the host.
-    seq.add(poplar::program::Copy(activation_fifo, output, false,
-                                  {debug_name_and_id}));
+    seq.add(poplar::program::Call(lookup_fn, {indices}, {output},
+                                  {debug_name_and_id, "call"}));
 
     TF_CHECK_OK(AddOutputTensor(tensor_map, inst, 0, output));
 
@@ -417,18 +405,15 @@ class HostEmbeddingUpdateOp : public PoplarOpDef {
         {inst->name(), inst->EmbeddingId(), inst->operand(2)->shape(),
          inst->operand(1)->shape()});
 
-    auto index_buffer = graph.addDeviceToHostFIFO(
-        inst->name() + inst->EmbeddingId() + "_indices", indices.elementType(),
-        indices.numElements());
-
-    auto grad_fifo =
-        graph.addDeviceToHostFIFO(inst->name() + inst->EmbeddingId() + "_grads",
-                                  grads.elementType(), grads.numElements());
+    auto update_fn = graph.addHostFunction(
+        /* handle = */ inst->name() + inst->EmbeddingId(),
+        /* inputs = */
+        {{indices.elementType(), indices.numElements()},
+         {grads.elementType(), grads.numElements()}},
+        /* outputs = */ {});
 
-    seq.add(poplar::program::Copy(indices, index_buffer, false,
-                                  {debug_name_and_id}));
-    seq.add(
-        poplar::program::Copy(grads, grad_fifo, false, {debug_name_and_id}));
+    seq.add(poplar::program::Call(update_fn, {indices, grads}, {},
+                                  {debug_name_and_id, "call"}));
 
     return seq;
   }

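The device-side pattern used above, shown standalone as a hedged sketch; the handle string is hypothetical, and `graph`, `indices` and `output` are assumed to be an existing graph and already-mapped tensors:

#include <poplar/Graph.hpp>
#include <poplar/Program.hpp>

poplar::program::Sequence LowerLookup(poplar::Graph& graph,
                                      const poplar::Tensor& indices,
                                      const poplar::Tensor& output) {
  poplar::program::Sequence seq;

  // Declare the host function: one input (the indices) and one output (the
  // looked-up activations), each described by element type and element count.
  auto lookup_fn = graph.addHostFunction(
      "embedding_lookup",  // hypothetical handle
      {{indices.elementType(), indices.numElements()}},
      {{output.elementType(), output.numElements()}});

  // A single Call replaces the old Copy / Sync / Copy sequence, so Poplar can
  // see the host data dependency between the outbound and inbound transfers.
  seq.add(poplar::program::Call(lookup_fn, {indices}, {output}, {"lookup"}));

  return seq;
}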
tensorflow/compiler/plugin/poplar/driver/ops/custom_ops/poputil/user_op.cc

Lines changed: 39 additions & 76 deletions
@@ -71,9 +71,9 @@ class UserOpImpl : public PoplarOpDef {
   // the operation on the host by reading the raw bytes, executing the on the
   // host, then copying back.
   void (*as_function_host_rw_ptr)(
-      const std::vector<void*>& data,
+      const std::vector<const void*>& data,
       const std::vector<std::uint32_t>& number_of_elements,
-      std::vector<void*>& outputs, const std::string& attributes,
+      const std::vector<void*>& outputs, const std::string& attributes,
      const std::string& debugPrefix);
 
  // Convert the function pointer to each of the types of function we could
@@ -113,61 +113,33 @@ class UserOpImpl : public PoplarOpDef {
    // stream we create the StreamCopyInfo structures to communicate with
    // poplar_executor which does the actual linking of the streams.
    if (is_user_read_write) {
-      // A wrapper around the user functor which we finally call down to.
-      auto functor = [=](std::vector<void*>& data,
-                         std::vector<std::uint32_t>& number_of_elements,
-                         std::vector<void*>& outputs) {
-        as_function_host_rw_ptr(data, number_of_elements, outputs, attributes,
-                                instruction_name);
-      };
-
-      // We add a map of user ops to their owned streams.
-      res.annotations.stream_infos.insert({instruction_name, {}});
-      std::list<StreamCopyInfo>& info_list =
-          res.annotations.stream_infos[instruction_name];
-
-      // Allocate a stream info
-      res.annotations.stream_meta_infos[instruction_name] = {inst,
-                                                             number_of_inputs};
-      StreamCopyMetaInfo& meta_info =
-          res.annotations.stream_meta_infos[instruction_name];
+      auto& host_function_info =
+          res.annotations.host_function_infos[instruction_name];
+      host_function_info.parent_instruction = user_op_inst;
+      host_function_info.handle = instruction_name;
+
+      std::vector<poplar::Tensor> inputs;
+      std::vector<poplar::Graph::HostFunctionArgument> in_args;
+      std::vector<poplar::Graph::HostFunctionArgument> out_args;
+      std::vector<std::uint32_t> in_args_elems;
+      inputs.resize(user_op_inst->NumInputs());
+      outputs.resize(number_of_outputs);
 
+      // Collect the input tensors and input arg descriptions.
      for (std::uint32_t i = 0; i < user_op_inst->NumInputs(); ++i) {
        // Get the poplar tensor.
        TF_ASSIGN_OR_RETURN(poplar::Tensor in,
                            FindInstructionInput(tensor_map, res, inst, i, seq,
                                                 {debug_info}, false));
 
-        // Give each input a stream identifier based on the instruction name.
-        const std::string stream_name =
-            instruction_name + "_read_" + std::to_string(i);
-
-        // Create a datastream for the input tensor.
-        poplar::DataStream stream = graph.addDeviceToHostFIFO(
-            stream_name, in.elementType(), in.numElements());
-
-        // Allocate this structure to communicate to the executor, which
-        // callbacks to register to which input tensors.
-        const uint32_t num_elements = static_cast<uint32_t>(in.numElements());
-        const uint32_t type_size = static_cast<uint32_t>(
-            graph.getTarget().getTypeSize(in.elementType()));
-        StreamCopyInfo info{inst, stream_name, num_elements,
-                            type_size, i, functor};
-        info_list.push_back(info);
-
-        // Copy from the tensor into the host stream. We will later attach a
-        // callback to this.
-        seq.add(poplar::program::Copy(in, stream, false, {debug_info}));
-      }
-
-      // Add an ontile sync to stop the copies from host being merged with the
-      // above as there is an invisble dependency in the callback.
-      seq.add(poplar::program::Sync(poplar::SyncType::INTERNAL, {debug_info}));
+        in_args.emplace_back(in.elementType(), in.numElements());
+        in_args_elems.push_back(in.numElements());
+        host_function_info.input_shapes.push_back(inst->operand(i)->shape());
 
-      outputs.resize(number_of_outputs);
+        inputs[i] = in;
+      }
 
-      // Now go over and add a copy from the device back to the host for each
-      // output.
+      // Collect the output tensors and output arg descriptions.
      for (std::uint32_t output_index = 0; output_index != number_of_outputs;
           output_index++) {
        xla::Shape shape = output_shape.tuple_shapes()[output_index];
@@ -177,36 +149,27 @@ class UserOpImpl : public PoplarOpDef {
                            AddTensor(graph, TensorLocation{inst, output_index}, shape, res,
                                      tensor_map, {debug_info, "output"}));
 
-        // Add stream ID for each output tensor.
-        const std::string stream_name =
-            instruction_name + "_write_" + std::to_string(output_index);
-
-        // Copy from the host into these new tensors.
-        poplar::DataStream stream =
-            graph.addHostToDeviceFIFO(stream_name, output_tensor.elementType(),
-                                      output_tensor.numElements());
-
-        // Allocate this structure to communicate to the executor so the
-        // executor knows how much memory to allocate for the callback to write
-        // into.
-        const uint32_t num_elements =
-            static_cast<uint32_t>(output_tensor.numElements());
-        const uint32_t type_size = static_cast<uint32_t>(
-            graph.getTarget().getTypeSize(output_tensor.elementType()));
-        StreamCopyInfo info{inst, stream_name, num_elements, type_size,
-                            output_index};
-        info_list.push_back(std::move(info));
-
-        // Store a reference to this stream copy info.
-        StreamCopyInfo* ref = &info_list.back();
-        meta_info.output_stream_info.push_back(ref);
-
-        // Add the copy to the graph.
-        seq.add(
-            poplar::program::Copy(stream, output_tensor, false, {debug_info}));
-
+        out_args.emplace_back(output_tensor.elementType(),
+                              output_tensor.numElements());
+        host_function_info.output_shapes.push_back(shape);
        outputs[output_index] = output_tensor;
      }
+
+      host_function_info.function = [as_function_host_rw_ptr, in_args_elems,
+                                     attributes, instruction_name](
+                                        const std::vector<const void*>& input,
+                                        const std::vector<void*>& outputs) {
+        as_function_host_rw_ptr(input, in_args_elems, outputs, attributes,
+                                instruction_name);
+      };
+
+      // Create the host function
+      auto user_fn_device =
+          graph.addHostFunction(instruction_name, in_args, out_args);
+
+      // Add the device call to the host function.
+      seq.add(
+          poplar::program::Call(user_fn_device, inputs, outputs, debug_info));
    } else {
      if (!is_gradient) {
        std::vector<poplar::Tensor> inputs(user_op_inst->NumInputs());

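For clarity, a standalone sketch of the adaptor this diff builds: the raw user entry point is captured together with the per-input element counts, the attribute string and the op name, so the stored callable only needs the input/output buffers that Poplar hands to the host function (`MakeHostFunction` is a hypothetical helper name, not part of the commit):

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// The host read/write entry point exported by a user-op library, with the
// signature documented in custom_codelet.rst above.
using UserHostFn = void (*)(const std::vector<const void*>& data,
                            const std::vector<std::uint32_t>& number_of_elements,
                            const std::vector<void*>& outputs,
                            const std::string& attributes,
                            const std::string& name);

using HostFunctionType =
    std::function<void(const std::vector<const void*>&,
                       const std::vector<void*>&)>;

HostFunctionType MakeHostFunction(UserHostFn user_fn,
                                  std::vector<std::uint32_t> element_counts,
                                  std::string attributes, std::string name) {
  // Capture everything the user function needs except the raw buffers.
  return [user_fn, element_counts, attributes, name](
             const std::vector<const void*>& inputs,
             const std::vector<void*>& outputs) {
    user_fn(inputs, element_counts, outputs, attributes, name);
  };
}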
tensorflow/compiler/plugin/poplar/driver/poplar_compiler.cc

Lines changed: 1 addition & 1 deletion
@@ -1965,7 +1965,7 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
           is_scalar_elementwise_graph,
           /*loaded_from_cache=*/false, std::move(remaped_output),
           std::move(resources.annotations.stream_infos),
-          std::move(resources.annotations.stream_meta_infos),
+          std::move(resources.annotations.host_function_infos),
           PoplarExecutableInfo{
               num_IPUs,
               target_type,

tensorflow/compiler/plugin/poplar/driver/poplar_executable.cc

Lines changed: 3 additions & 3 deletions
@@ -42,7 +42,7 @@ PoplarExecutableCore::PoplarExecutableCore(
     std::vector<std::vector<Literal>> constant_literal_output,
     bool is_remap_graph, bool is_scalar_elementwise_graph,
     bool loaded_from_cache, std::vector<uint64> remaped_output,
-    StreamInfos&& stream_infos, StreamMetaInfos&& stream_meta_info,
+    StreamInfos&& stream_infos, HostFunctionInfos&& host_function_infos,
     PoplarExecutableInfo&& info)
     : poplar_engine_(std::move(engine)),
       input_output_aliasing_map_(std::move(input_output_aliasing_map)),
@@ -53,7 +53,7 @@ PoplarExecutableCore::PoplarExecutableCore(
       loaded_from_cache_(loaded_from_cache),
       remaped_output_(std::move(remaped_output)),
       stream_infos_(std::move(stream_infos)),
-      stream_meta_infos_(std::move(stream_meta_info)),
+      host_function_infos_(std::move(host_function_infos)),
       info_(std::move(info)) {
   TENSORFLOW_TRACEPOINT();
   PopulateCollectiveBalanceReorderHostRerrangements();
@@ -421,7 +421,7 @@ PoplarExecutableCore::Deserialize(
       /*is_remap_graph=*/false,
       /*is_scalar_elementwise_graph=*/false,
       /*loaded_from_cache=*/true, std::vector<uint64>{}, StreamInfos{},
-      StreamMetaInfos{}, std::move(info));
+      HostFunctionInfos{}, std::move(info));
 
   return executable_core;
 }

tensorflow/compiler/plugin/poplar/driver/poplar_executable.h

Lines changed: 10 additions & 10 deletions
@@ -81,7 +81,7 @@ class PoplarExecutableCore {
       std::vector<std::vector<Literal>> constant_literal_output,
       bool is_remap_graph, bool is_scalar_elementwise_graph,
       bool loaded_from_cache, std::vector<uint64> remaped_output,
-      StreamInfos&& stream_infos, StreamMetaInfos&& stream_meta_info,
+      StreamInfos&& stream_infos, HostFunctionInfos&& host_function_infos,
       PoplarExecutableInfo&& info);
 
   ~PoplarExecutableCore();
@@ -120,6 +120,10 @@ class PoplarExecutableCore {
     return info_.remote_parameter_infos;
   }
 
+  const HostFunctionInfos& GetHostFunctionInfos() const {
+    return host_function_infos_;
+  }
+
   const RemoteParameterHostRearrangements&
   GetRemoteParameterHostRearrangements() const {
     return info_.remote_parameter_host_rearrangements;
@@ -133,10 +137,6 @@ class PoplarExecutableCore {
 
   const StreamInfos& GetStreamInfos() const { return stream_infos_; }
 
-  const StreamMetaInfos& GetStreamMetaInfos() const {
-    return stream_meta_infos_;
-  }
-
   const SendRecvInfos& GetSendInfos() const { return info_.send_infos; }
 
   const SendRecvInfos& GetRecvInfos() const { return info_.recv_infos; }
@@ -202,7 +202,7 @@ class PoplarExecutableCore {
 
   // User op info that is not serialized.
   StreamInfos stream_infos_;
-  StreamMetaInfos stream_meta_infos_;
+  HostFunctionInfos host_function_infos_;
 
   // All the other info that is serialized.
   PoplarExecutableInfo info_;
@@ -273,12 +273,12 @@ class PoplarExecutable : public Executable {
     return executable_core_->GetRemoteParameterInfos();
   }
 
-  const StreamInfos& GetStreamInfos() const {
-    return executable_core_->GetStreamInfos();
+  const HostFunctionInfos& GetHostFunctionInfos() const {
+    return executable_core_->GetHostFunctionInfos();
   }
 
-  const StreamMetaInfos& GetStreamMetaInfos() const {
-    return executable_core_->GetStreamMetaInfos();
+  const StreamInfos& GetStreamInfos() const {
+    return executable_core_->GetStreamInfos();
   }
 
   const SendRecvInfos& GetSendInfos() const {
