Commit ee82dcd

hakosgeorgepaw authored and committed
Add application compile op
Summary:
This adds the `IPUApplicationCompile` op that compiles a function and returns a string with the path to the compiled executable. It also supports freezing the variables so that they become constants embedded into the executable. This is not really tied to the application runtime per se, but it seems good to limit the initial scope. Ref. T41635. TF2.4 only.

Test Plan: Added new tests and integrated with the existing application runtime test.

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, gauthamg, jakeh, georgep

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, jakeh, georgep

Subscribers: georgep

Maniphest Tasks: T41635

Differential Revision: https://phabricator.sourcevertex.net/D48034
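
A rough usage sketch of the intent (not taken from this diff): the Python wrapper in `ops/application_compile_op.py` is added by this commit but its contents are not shown here, so the helper name and signature below are assumptions, not confirmed code.

```python
# Hypothetical usage sketch. The wrapper name and signature are assumptions;
# only the underlying `IPUApplicationCompile` op is defined in this commit.
import tensorflow as tf
from tensorflow.python.ipu.ops import application_compile_op

w = tf.Variable(tf.ones([2, 2]))  # resource variable; can be frozen into the executable


def model(x):
  return tf.matmul(x, w)


x = tf.zeros([2, 2], dtype=tf.float32)

# Compile `model` for the IPU and serialise the Poplar executable to the given
# path; the op's string output is that path.
path = application_compile_op.experimental_application_compile_op(
    model, inputs=[x], output_path="/tmp/model.poplar_exec")
print(path)
```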
1 parent 818520b commit ee82dcd

File tree

12 files changed: +582 −60 lines


tensorflow/compiler/plugin/poplar/BUILD

Lines changed: 7 additions & 2 deletions
@@ -956,7 +956,8 @@ cc_library(
 cc_library(
     name = "kernels",
     srcs = [
-        "kernels/datastream/application_runtime.cc",
+        "kernels/application_runtime/application_compile.cc",
+        "kernels/application_runtime/application_runtime.cc",
         "kernels/datastream/dataset_benchmark.cc",
         "kernels/datastream/feeds.cc",
         "kernels/datastream/host_embedding.cc",
@@ -1016,6 +1017,7 @@ cc_library(
         ":driver",
         ":xla_util",
         "//tensorflow/compiler/jit:xla_device",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/plugin/poplar/kernels/dataset:kernels",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -1372,7 +1374,10 @@ tf_custom_op_py_library(
 
 cc_library(
     name = "application_runtime",
-    srcs = ["ops/datastream/application_runtime.cc"],
+    srcs = [
+        "ops/application_runtime/application_compile.cc",
+        "ops/application_runtime/application_runtime.cc",
+    ],
     deps = [
         "//tensorflow/core:framework",
     ],

tensorflow/compiler/plugin/poplar/kernels/application_runtime/application_compile.cc

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/plugin/poplar/driver/poplar_executable.h"
#include "tensorflow/compiler/plugin/poplar/driver/poplar_platform.h"
#include "tensorflow/compiler/plugin/poplar/kernels/ipu_kernels_common.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

namespace {

Status BuildCompilationCache(OpKernelContext* ctx, se::Platform* platform,
                             XlaCompilationCache** out_cache) {
  xla::LocalClientOptions client_options;
  client_options.set_platform(platform);
  client_options.set_intra_op_parallelism_threads(
      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
  TF_ASSIGN_OR_RETURN(
      auto* client, xla::ClientLibrary::GetOrCreateLocalClient(client_options));
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice("IPU", &registration)) {
    return errors::InvalidArgument("No JIT device registered for IPU");
  }

  *out_cache = new XlaCompilationCache(
      client, DeviceType(registration->compilation_device_name));
  return Status::OK();
}

xla::StatusOr<xla::LocalExecutable*> CompileExecutable(
    OpKernelContext* ctx, const NameAttrList& function, se::Platform* platform,
    absl::Span<const Tensor* const> inputs,
    absl::Span<const VariableInfo> variable_infos,
    absl::Span<const int> constants) {
  auto* resource_manager = ctx->resource_manager();
  if (!resource_manager) {
    return errors::Internal("Resource manager not found");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(resource_manager->LookupOrCreate<XlaCompilationCache>(
      resource_manager->default_container(), "ipu_application_compile_cache",
      &cache, [&](XlaCompilationCache** cache) {
        return BuildCompilationCache(ctx, platform, cache);
      }));
  core::ScopedUnref cache_ref(cache);

  const auto* function_library = ctx->function_library();
  if (!function_library) {
    return errors::Internal("Function library not found");
  }

  const auto* flib_def = function_library->GetFunctionLibraryDefinition();
  const auto* func_def = CHECK_NOTNULL(flib_def)->Find(function.name());
  if (!func_def) {
    return errors::Internal("Function not found: " + function.name());
  }

  VLOG(1) << "Compiling function: " << DebugString(*func_def);

  XlaCompiler::Options options;
  options.client = cache->client();
  options.device_type = cache->device_type();
  options.flib_def = flib_def;
  options.graph_def_version = function_library->graph_def_version();

  se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                              platform);
  options.device_allocator = &tf_allocator_adapter;

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  compile_options.always_return_tuple = false;

  // IPU Specific - store the names of all inputs.
  std::vector<std::string> mangled_input_names(inputs.size());
  for (int64 i = 0; i != inputs.size(); ++i) {
    mangled_input_names[i] = ctx->op_kernel().requested_input(i);
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<XlaCompiler::Argument> arguments,
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constants, inputs, variable_infos, mangled_input_names));

  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;
  TF_RETURN_IF_ERROR(cache->Compile(options, function, arguments,
                                    compile_options,
                                    XlaCompilationCache::CompileMode::kStrict,
                                    &compilation_result, &executable));
  return executable;
}

}  // namespace

class IPUApplicationCompile : public OpKernel {
 public:
  explicit IPUApplicationCompile(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &function_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_indices", &resource_indices_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("constant_indices", &constant_indices_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("executable_output_path", &executable_output_path_));
  }

  void Compute(OpKernelContext* ctx) {
    auto platform_or_status =
        se::MultiPlatformManager::PlatformWithName("Poplar");
    OP_REQUIRES_OK(ctx, platform_or_status.status());
    auto* platform = platform_or_status.ValueOrDie();

    std::vector<const Tensor*> inputs = InputsFromContext(ctx);
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(ctx, GetVariableInfosFromInputs(
                            ctx->resource_manager(), ctx->device(), inputs,
                            resource_indices_, &variable_infos));

    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));

    auto executable_or_status = CompileExecutable(
        ctx, function_, platform, inputs, variable_infos, constant_indices_);
    OP_REQUIRES_OK(ctx, executable_or_status.status());

    auto* poplar_executable =
        dynamic_cast<xla::poplarplugin::PoplarExecutable*>(
            executable_or_status.ValueOrDie()->executable());
    OP_REQUIRES(ctx, poplar_executable != nullptr,
                errors::Internal("Missing Poplar executable"));

    OP_REQUIRES_OK(ctx, poplar_executable->Serialize(executable_output_path_));
    ctx->set_output(0, Tensor(executable_output_path_));
  }

 private:
  NameAttrList function_;
  std::string executable_output_path_;
  std::vector<int> constant_indices_;
  std::vector<int> resource_indices_;

  TF_DISALLOW_COPY_AND_ASSIGN(IPUApplicationCompile);
};

// We register the op both for CPU and IPU to make it easier to use, as we then
// can handle any colocation requirements from variables etc. The function will
// be compiled for IPU regardless of the device placement of the op itself.
REGISTER_KERNEL_BUILDER(Name("IPUApplicationCompile").Device(DEVICE_CPU),
                        IPUApplicationCompile);
REGISTER_KERNEL_BUILDER(Name("IPUApplicationCompile").Device(DEVICE_XLA_IPU),
                        IPUApplicationCompile);

}  // namespace tensorflow

tensorflow/compiler/plugin/poplar/ops/application_runtime/application_compile.cc

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"

namespace tensorflow {

REGISTER_OP("IPUApplicationCompile")
    .Input("args: Targs")
    .Attr("Targs: list(type) >= 0")
    .Attr("resource_indices: list(int) >= 0")
    .Attr("constant_indices: list(int) >= 0")
    .Attr("executable_output_path: string")
    .Output("output: string")
    .Attr("function: func")
    // Compilation cache is stateful.
    .SetIsStateful();

}  // namespace tensorflow
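
The registered interface above can be inspected from Python through the standard op-def registry. A minimal sketch, assuming a Graphcore TensorFlow 2.4 build with the Poplar plugin linked in (otherwise the lookup returns None):

```python
# Minimal sketch: inspect the registered OpDef for IPUApplicationCompile.
# Assumes a Graphcore TensorFlow 2.4 build where this op is compiled in.
from tensorflow.python.framework import op_def_registry

op_def = op_def_registry.get("IPUApplicationCompile")
assert op_def is not None, "IPUApplicationCompile is not registered in this build"

print(op_def.name)                    # IPUApplicationCompile
print([a.name for a in op_def.attr])  # Targs, resource_indices, constant_indices,
                                      # executable_output_path, function
print(op_def.is_stateful)             # True - the compilation cache is stateful
```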

tensorflow/python/framework/func_graph.py

Lines changed: 4 additions & 1 deletion
@@ -509,7 +509,10 @@ def _capture_by_value(
             compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
             context.context())
       else:
-        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
+        # Make sure the name is unique in the outer graph.
+        outer_graph = ops.get_default_graph()
+        name = outer_graph.unique_name(name)
+        op = outer_graph._create_op_internal(  # pylint: disable=protected-access
             op_type,
             uncaptured_inputs,
             dtypes,
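
The fix leans on `Graph.unique_name`, which deduplicates requested op names within a graph, so creating the captured-value op in the outer graph can no longer collide with an existing op of the same name. A quick illustration of the behaviour the change relies on:

```python
# Graph.unique_name returns the name unchanged the first time and a
# deduplicated variant ("name_1", "name_2", ...) on later requests.
import tensorflow as tf

g = tf.Graph()
print(g.unique_name("capture"))  # capture
print(g.unique_name("capture"))  # capture_1
```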

tensorflow/python/ipu/BUILD

Lines changed: 15 additions & 0 deletions
@@ -13,6 +13,7 @@ load("@local_config_ipu_horovod//:build_defs_horovod.bzl", "if_horovod", "poprun
 py_library(
     name = "ipu_ops_lib",
     srcs = [
+        "ops/application_compile_op.py",
         "ops/cross_replica_ops.py",
         "ops/embedding_ops.py",
         "ops/functional_ops.py",
@@ -291,11 +292,23 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "application_compile_test",
+    size = "large",
+    srcs = ["tests/application_compile_test.py"],
+    deps = [
+        "//tensorflow/compiler/plugin/poplar:test_utils_py",
+        "//tensorflow/compiler/tests:xla_test",
+        "//tensorflow/python/ipu:ipu_lib",
+    ],
+)
+
 tf_py_test(
     name = "application_runtime_test",
     size = "large",
     srcs = ["tests/application_runtime_test.py"],
     shard_count = 4,
+    tags = ["hw_poplar_test"],
     deps = [
         "//tensorflow/compiler/plugin/poplar:ipu_ops_py",
         "//tensorflow/compiler/plugin/poplar:test_utils_py",
@@ -1913,6 +1926,8 @@ tf_py_test(
 test_suite(
     name = "all_tests",
     tests = [
+        "application_compile_test",
+        "application_runtime_test",
         "assume_equal_test",
         "candidate_sampler_test",
         "config_test",

tensorflow/python/ipu/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 # pylint: disable=wildcard-import,unused-import
 from tensorflow.python.ipu.ops import all_to_all_op
 from tensorflow.python.ipu.ops import all_to_all_op_grad
+from tensorflow.python.ipu.ops import application_compile_op
 from tensorflow.python.ipu.ops import custom_ops
 from tensorflow.python.ipu.ops import cross_replica_ops
 from tensorflow.python.ipu.ops import cross_replica_ops_grad
