Skip to content

Commit 10a343a

Browse files
Gautham Ganapathy authored and georgepaw committed
Store information about the runtime the executable was compiled for
Summary: FIX T41132 Reviewers: #tensorflow, simonl, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, jakeh, georgep Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep Maniphest Tasks: T41132 Differential Revision: https://phabricator.sourcevertex.net/D48456
1 parent 7157de0 commit 10a343a

File tree

5 files changed

+140
-48
lines changed

5 files changed

+140
-48
lines changed

tensorflow/compiler/plugin/poplar/driver/poplar_compiler.cc

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,15 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
12201220
VLOG(1) << "Created " << replication_factor << " replica IPU graph.";
12211221
}
12221222

1223+
const int64 num_IPUs = target.getNumIPUs();
1224+
const std::string target_type = poplar::toString(target.getTargetType());
1225+
const std::string target_arch =
1226+
target.getTargetType() == poplar::TargetType::IPU
1227+
? target.getTargetArchString()
1228+
: "";
1229+
const bool gateway_mode = target.getGatewayMode();
1230+
const bool supports_remote_buffers = poplar_executor->SupportsRemoteBuffers();
1231+
12231232
resources.progress_bar->Start();
12241233

12251234
{
@@ -1739,11 +1748,33 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
17391748
poplar::OptionFlags options_to_serialize =
17401749
poplar_executor->GetReportExecutionFlags();
17411750

1751+
auto& annotations = resources.annotations;
1752+
17421753
TF_RETURN_IF_ERROR(PoplarExecutableCore::Serialize(
1743-
filenames, exec, resources.annotations, replication_factor,
1744-
options_to_serialize, logging_cycle_count,
1745-
resources.streams_indices.GetAssignedIds(),
1746-
resources.streams_indices.CheckpointFeedsOrder()));
1754+
filenames, exec, options_to_serialize,
1755+
PoplarExecutableInfo{
1756+
num_IPUs,
1757+
target_type,
1758+
target_arch,
1759+
gateway_mode,
1760+
supports_remote_buffers,
1761+
replication_factor,
1762+
annotations.infeed_infos,
1763+
annotations.outfeed_infos,
1764+
annotations.send_infos,
1765+
annotations.recv_infos,
1766+
annotations.host_embedding_lookup_infos,
1767+
annotations.host_embedding_update_infos,
1768+
annotations.host_embedding_notify_infos,
1769+
annotations.remote_parameter_infos,
1770+
annotations.entry_input_infos,
1771+
annotations.feed_input_infos,
1772+
annotations.entry_output_infos,
1773+
annotations.feed_output_infos,
1774+
logging_cycle_count,
1775+
resources.streams_indices.GetAssignedIds(),
1776+
resources.streams_indices.CheckpointFeedsOrder(),
1777+
}));
17471778

17481779
if (in_precompile_mode) {
17491780
LOG(INFO) << "A pre-compiled Poplar program has been saved to "
@@ -1806,7 +1837,13 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
18061837
std::move(resources.annotations.stream_infos),
18071838
std::move(resources.annotations.stream_meta_infos),
18081839
PoplarExecutableInfo{
1809-
replication_factor, std::move(resources.annotations.infeed_infos),
1840+
num_IPUs,
1841+
target_type,
1842+
target_arch,
1843+
gateway_mode,
1844+
supports_remote_buffers,
1845+
replication_factor,
1846+
std::move(resources.annotations.infeed_infos),
18101847
std::move(resources.annotations.outfeed_infos),
18111848
std::move(resources.annotations.send_infos),
18121849
std::move(resources.annotations.recv_infos),
@@ -1818,7 +1855,8 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
18181855
std::move(resources.annotations.feed_input_infos),
18191856
std::move(resources.annotations.entry_output_infos),
18201857
std::move(resources.annotations.feed_output_infos),
1821-
logging_cycle_count, resources.streams_indices.GetAssignedIds(),
1858+
logging_cycle_count,
1859+
resources.streams_indices.GetAssignedIds(),
18221860
resources.streams_indices.CheckpointFeedsOrder()});
18231861

18241862
return executable_core;

tensorflow/compiler/plugin/poplar/driver/poplar_executable.cc

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "tensorflow/compiler/plugin/poplar/driver/tools/tracepoint.h"
3131
#include "tensorflow/compiler/plugin/poplar/driver/tools/util.h"
3232
#include "tensorflow/compiler/plugin/poplar/driver/xla_ipu_common.h"
33+
#include "tensorflow/core/platform/stacktrace.h"
3334

3435
namespace xla {
3536
namespace poplarplugin {
@@ -74,6 +75,13 @@ PoplarExecutableInfo FromProto(const PoplarExecutableProto& proto,
7475
poplar::OptionFlags* engine_options) {
7576
PoplarExecutableInfo info;
7677

78+
auto& ertc = proto.embedded_runtime_config();
79+
info.num_IPUs = ertc.num_ipus();
80+
info.target_type = ertc.target_type();
81+
info.target_arch = ertc.target_arch();
82+
info.gateway_mode = ertc.gateway_mode();
83+
info.supports_remote_buffers = ertc.supports_remote_buffers();
84+
7785
info.replication_factor = proto.replication_factor();
7886

7987
for (const auto& infeed : proto.infeeds()) {
@@ -166,6 +174,13 @@ PoplarExecutableProto ToProto(const PoplarExecutableInfo& info,
166174
const poplar::OptionFlags& poplar_options = {}) {
167175
PoplarExecutableProto proto;
168176

177+
auto ertc = proto.mutable_embedded_runtime_config();
178+
ertc->set_num_ipus(info.num_IPUs);
179+
ertc->set_target_type(info.target_type);
180+
ertc->set_target_arch(info.target_arch);
181+
ertc->set_gateway_mode(info.gateway_mode);
182+
ertc->set_supports_remote_buffers(info.supports_remote_buffers);
183+
169184
proto.set_replication_factor(info.replication_factor);
170185

171186
for (const auto& infeed : info.infeed_infos) {
@@ -261,7 +276,7 @@ PoplarExecutableProto ToProto(const PoplarExecutableInfo& info,
261276

262277
// Items that don't need deserialising.
263278
for (const auto& input_info : info.entry_input_infos) {
264-
auto input = proto.mutable_signature()->add_inputs();
279+
auto input = ertc->mutable_signature()->add_inputs();
265280
input->set_name(input_info.name);
266281
input->set_handle(input_info.handle);
267282
input->set_argument(input_info.argument);
@@ -270,7 +285,7 @@ PoplarExecutableProto ToProto(const PoplarExecutableInfo& info,
270285
}
271286

272287
for (const auto& streamed_input_info : info.feed_input_infos) {
273-
auto input = proto.mutable_signature()->add_streamed_inputs();
288+
auto input = ertc->mutable_signature()->add_streamed_inputs();
274289
input->set_name(streamed_input_info.name);
275290
input->set_handle(streamed_input_info.handle);
276291
input->set_argument(streamed_input_info.argument);
@@ -279,15 +294,15 @@ PoplarExecutableProto ToProto(const PoplarExecutableInfo& info,
279294
}
280295

281296
for (const auto& output_info : info.entry_output_infos) {
282-
auto output = proto.mutable_signature()->add_outputs();
297+
auto output = ertc->mutable_signature()->add_outputs();
283298
output->set_name(output_info.name);
284299
output->set_handle(output_info.handle);
285300
output->set_tuple_index(output_info.tuple_index);
286301
(*output->mutable_shape()) = output_info.shape.ToProto();
287302
}
288303

289304
for (const auto& streamed_output_info : info.feed_output_infos) {
290-
auto output = proto.mutable_signature()->add_streamed_outputs();
305+
auto output = ertc->mutable_signature()->add_streamed_outputs();
291306
output->set_name(streamed_output_info.name);
292307
output->set_handle(streamed_output_info.handle);
293308
output->set_tuple_index(streamed_output_info.tuple_index);
@@ -354,32 +369,10 @@ PoplarExecutableCore::Deserialize(
354369

355370
/*static*/ Status PoplarExecutableCore::Serialize(
356371
const ModuleFilenames& filenames, const poplar::Executable& executable,
357-
const CompilerAnnotations& annotations, uint32 replication_count,
358-
const poplar::OptionFlags& opts, bool logging_cycle_count,
359-
const VerifiedStreamsIndices::KeyIdMappings& mappings,
360-
const std::vector<string>& checkpoint_feeds_order) {
372+
const poplar::OptionFlags& opts, const PoplarExecutableInfo& info) {
361373
TENSORFLOW_TRACEPOINT();
362374

363-
const PoplarExecutableProto proto = ToProto(
364-
PoplarExecutableInfo{
365-
replication_count,
366-
annotations.infeed_infos,
367-
annotations.outfeed_infos,
368-
annotations.send_infos,
369-
annotations.recv_infos,
370-
annotations.host_embedding_lookup_infos,
371-
annotations.host_embedding_update_infos,
372-
annotations.host_embedding_notify_infos,
373-
annotations.remote_parameter_infos,
374-
annotations.entry_input_infos,
375-
annotations.feed_input_infos,
376-
annotations.entry_output_infos,
377-
annotations.feed_output_infos,
378-
logging_cycle_count,
379-
mappings,
380-
checkpoint_feeds_order,
381-
},
382-
opts);
375+
const PoplarExecutableProto proto = ToProto(info, opts);
383376

384377
return PoplarExecutableBinaryFile::Write(
385378
filenames.CachedExecutableFilename(), proto,

tensorflow/compiler/plugin/poplar/driver/poplar_executable.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ struct CompilerAnnotations;
4040
struct CompilerResources;
4141

4242
struct PoplarExecutableInfo {
43+
int64 num_IPUs;
44+
std::string target_type;
45+
std::string target_arch;
46+
bool gateway_mode;
47+
bool supports_remote_buffers;
48+
4349
uint32 replication_factor;
4450
CanonicalInfeedInfos infeed_infos;
4551
CanonicalOutfeedInfos outfeed_infos;
@@ -152,12 +158,8 @@ class PoplarExecutableCore {
152158

153159
static Status Serialize(const ModuleFilenames& filenames,
154160
const poplar::Executable& executable,
155-
const CompilerAnnotations& annotations,
156-
uint32 replication_count,
157161
const poplar::OptionFlags& opts,
158-
bool logging_cycle_count,
159-
const VerifiedStreamsIndices::KeyIdMappings& mappings,
160-
const std::vector<string>& checkpoint_feeds_order);
162+
const PoplarExecutableInfo& info);
161163

162164
static Status Export(const ModuleFilenames& filenames,
163165
const poplar::Executable& executable,

tensorflow/compiler/plugin/poplar/driver/poplar_executable.proto

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ message Signature {
7575
repeated Output streamed_outputs = 4;
7676
}
7777

78+
// Describes a compiled Poplar executable for the embedded application
// runtime: its functional signature plus the properties of the target it was
// compiled for. Stored in PoplarExecutableProto.embedded_runtime_config.
message EmbeddedRuntimeConfig {
79+
// The functional signature of the poplar executable.
80+
Signature signature = 1;
81+
82+
// Information about the runtime the executable was compiled for.
83+
// Number of IPUs of the compilation target.
int64 num_IPUs = 2;
84+
// String form of the Poplar target type; the application runtime accepts
// "IPU", "IPU_MODEL" and "CPU".
string target_type = 3;
85+
// Architecture string of the compilation target. The compiler leaves this
// empty when the target type is not IPU.
string target_arch = 4;
86+
// Gateway mode setting reported by the compilation target.
bool gateway_mode = 5;
87+
// Whether the executor that compiled the executable reported remote buffer
// support.
bool supports_remote_buffers = 6;
88+
}
89+
7890
message PoplarExecutableProto {
7991

8092
// The number of replicas
@@ -104,6 +116,5 @@ message PoplarExecutableProto {
104116

105117
bool logging_cycle_count = 13;
106118

107-
// The functional signature of the poplar executable.
108-
Signature signature = 14;
119+
EmbeddedRuntimeConfig embedded_runtime_config = 14;
109120
};

tensorflow/compiler/plugin/poplar/kernels/datastream/application_runtime.cc

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,49 @@ namespace {
101101
const char APPLICATION_RUNTIME_RESOURCE_CONTAINER[] =
102102
"ApplicationRuntimeResourceContainer";
103103

104-
poplar::Device GetIpuDevice(int64 replication_factor = 1) {
104+
bool ParsePoplarTargetType(const std::string target_type_string,
105+
poplar::TargetType& target_type) {
106+
if (target_type_string == "IPU") {
107+
target_type = poplar::TargetType::IPU;
108+
return true;
109+
} else if (target_type_string == "IPU_MODEL") {
110+
target_type = poplar::TargetType::IPU_MODEL;
111+
return true;
112+
} else if (target_type_string == "CPU") {
113+
target_type = poplar::TargetType::CPU;
114+
return true;
115+
}
116+
return false;
117+
}
118+
119+
// Acquires a Poplar device matching the runtime configuration recorded in a
// compiled executable and checks that the device can actually run it.
//
// Arguments:
//   target_type: kind of Poplar target to request from the device manager.
//   target_arch_string: architecture string stored with the executable. The
//     compiler only records it for TargetType::IPU and stores an empty string
//     for every other target type.
//   num_IPUs: number of IPUs passed to DeviceManager::getDevices.
//   gateway_mode: forwarded as the "gatewayMode" device option.
//   supports_remote_buffers: whether the executable was compiled with remote
//     buffer support available.
//
// Returns the attached device. Fails with InvalidArgument when the device
// lacks required remote buffer support or, for IPU targets, when its
// architecture differs from the one the executable was compiled for. Any
// error from PoplarExecutor::AttachToPoplarDevice is propagated unchanged.
StatusOr<poplar::Device> GetIpuDevice(const poplar::TargetType target_type,
                                      const std::string& target_arch_string,
                                      const int64 num_IPUs,
                                      const bool gateway_mode,
                                      const bool supports_remote_buffers) {
  poplar::DeviceManager manager = poplar::DeviceManager::createDeviceManager();
  auto devices =
      manager.getDevices(target_type, num_IPUs,
                         {{"gatewayMode", gateway_mode ? "true" : "false"}});
  TF_ASSIGN_OR_RETURN(std::size_t device_idx,
                      PoplarExecutor::AttachToPoplarDevice(devices, 0, true));
  auto& device = devices[device_idx];

  if (supports_remote_buffers && !device.supportsRemoteBuffers()) {
    return errors::InvalidArgument(
        "The compiled TensorFlow executable requires remote buffer support, "
        "but it is not available on "
        "this device");
  }

  // The architecture string is only recorded for real IPU targets (it is
  // serialized as "" otherwise), so comparing it for IPU_MODEL/CPU targets
  // would reject every such executable.
  if (target_type == poplar::TargetType::IPU) {
    const auto& device_target_arch_string =
        device.getTarget().getTargetArchString();
    if (device_target_arch_string != target_arch_string) {
      return errors::InvalidArgument(absl::StrFormat(
          "The target architecture for the compiled executable (%s) does not "
          "match device's target architecture (%s)",
          target_arch_string, device_target_arch_string));
    }
  }

  // `device` refers to an element of `devices`, not a local object, so an
  // explicit move is needed to hand it to the caller.
  return std::move(device);
}
112149

@@ -148,7 +185,7 @@ class IOConfig {
148185
IOConfig() = default;
149186

150187
void ParsePoplarExecutableProto(PoplarExecutableProto& executable_proto) {
151-
ParseSignature(executable_proto.signature());
188+
ParseSignature(executable_proto.embedded_runtime_config().signature());
152189
}
153190

154191
const IOGroup& GetInputs() const { return inputs_; }
@@ -347,12 +384,23 @@ class ApplicationRuntime : public OpKernel {
347384
OP_REQUIRES_OK(ctx, ctx->GetAttr("engine_name", &engine_name_));
348385

349386
if (!GetEngines().contains(engine_name_)) {
350-
auto device = GetIpuDevice();
351-
352387
PoplarExecutableProto proto;
353388
poplar::Executable executable =
354389
PoplarExecutableBinaryFile::Read(filename_, &proto).ValueOrDie();
355390

391+
auto& ertc = proto.embedded_runtime_config();
392+
const std::string target_type_string = ertc.target_type();
393+
poplar::TargetType target_type;
394+
OP_REQUIRES(ctx, ParsePoplarTargetType(target_type_string, target_type),
395+
errors::InvalidArgument(absl::StrFormat(
396+
"Invalid target type %s", target_type_string)));
397+
398+
auto status_or_device =
399+
GetIpuDevice(target_type, ertc.target_arch(), ertc.num_ipus(),
400+
ertc.gateway_mode(), ertc.supports_remote_buffers());
401+
OP_REQUIRES_OK(ctx, status_or_device.status());
402+
auto& device = status_or_device.ValueOrDie();
403+
356404
auto& io_config = resources_.IOCfg();
357405
io_config.ParsePoplarExecutableProto(proto);
358406
VerifyExecutable(proto);

0 commit comments

Comments
 (0)