Commit f13daef
Cherry-pick from upstream to fix handling of constant resources

Summary: This commit cherry-picks the following upstream changes:

  [TF2XLA] Support must-be-constant resource variables for compilation

  Performs an explicit copy at runtime from device to host if needed.

  PiperOrigin-RevId: 341491694
  Change-Id: If4a6c0c76a1110637a06e96595c6013c8fac17e5

  Remove `platform` field from shaped buffer.

  This further simplifies the ShapedBuffer object, as it has no logic of
  its own that uses the platform field. Note that this CL also removed a
  sanity check in the allocation tracker; we can add that check back if
  needed by keeping track of `platform` inside the allocation tracker as
  a side map.

  PiperOrigin-RevId: 339938197
  Change-Id: I090e603927ed3fccdb51254f972b3af2e1ec1470

  [TF2XLA] [NFC] Provide a more informative error message when
  encountering a must-be-constant in a resource variable

  PiperOrigin-RevId: 339779174
  Change-Id: Ic364413f6794393cc15d1ccab5a5cb127560f2d1

It also includes fixes for merge conflicts and API changes. This change
is required for shape inference.

Fixes T43533. TF2.4 only.

Test Plan: CI

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, hakons, jakeh

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, hakons

Maniphest Tasks: T43533

Differential Revision: https://phabricator.sourcevertex.net/D49203
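As a usage sketch (not part of the diff): the pattern this commit enables, distilled from the tests added below. It assumes the TF 2.4-era public API, where the XLA JIT flag on tf.function is spelled experimental_compile:

  import tensorflow as tf

  # Floats force device placement; their values are later read back as
  # compile-time constants.
  a = tf.Variable(50.0)
  b = tf.Variable(2.0)

  @tf.function(experimental_compile=True)
  def f(x):
    # The target shape of reshape must be a compile-time constant. With
    # this change it may be read out of a resource variable; the value is
    # copied device-to-host at compile time if needed.
    return tf.reshape(x, [tf.cast(a, tf.int32), tf.cast(b, tf.int32)])

  out = f(tf.random.normal([10, 10]))
  assert out.shape == (50, 2)

Before this change, compilation of such a function would fail because the must-be-constant reshape dimensions arrive in a resource variable rather than as a literal constant.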
1 parent 559ff28 commit f13daef

34 files changed (+307, −118 lines)

tensorflow/compiler/jit/get_compiler_ir.cc

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
 
   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
       XlaComputationLaunchContext::BuildXlaCompilerArguments(
-          constant_arg_indices, inputs, variable_infos);
+          constant_arg_indices, inputs, variable_infos, dev);
   TF_RETURN_IF_ERROR(args.status());
 
   switch (stage) {
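GetCompilerIr backs tf.function's IR-inspection hook, so the device threaded through here is also exercised when dumping compiler IR. A sketch, under the assumption that the experimental_get_compiler_ir hook of TF 2.4 is available (it is not part of this diff):

  import tensorflow as tf

  a = tf.Variable(50.0)

  @tf.function(experimental_compile=True)
  def f(x):
    # The reshape target shape is read out of a resource variable.
    return tf.reshape(x, [tf.cast(a, tf.int32), -1])

  x = tf.random.normal([10, 10])
  # The variable's value should now be folded into the HLO as a constant
  # instead of being rejected as a non-constant resource input.
  print(f.experimental_get_compiler_ir(x)(stage="hlo"))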

tensorflow/compiler/jit/kernels/xla_ops.cc

Lines changed: 2 additions & 3 deletions
@@ -213,7 +213,8 @@ static Status CompileToLocalExecutable(
 
   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
       XlaComputationLaunchContext::BuildXlaCompilerArguments(
-          constants, inputs, variable_infos, mangled_input_names);
+          constants, inputs, variable_infos,
+          static_cast<Device*>(ctx->device()), mangled_input_names);
   TF_RETURN_IF_ERROR(args.status());
   return cache->Compile(options, function, *args, compile_options,
                         lazy ? XlaCompilationCache::CompileMode::kLazy
@@ -252,8 +253,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
 
-  VLOG(1) << "Executing XLA Computation...";
-
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator = GetAllocator(
       &tf_allocator_adapter, ctx->device(),

tensorflow/compiler/jit/xla_compilation_cache.cc

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,7 @@ XlaCompilationCache::BuildSignature(
   for (const XlaCompiler::Argument& arg : args) {
     switch (arg.kind) {
       case XlaCompiler::Argument::kConstant:
+      case XlaCompiler::Argument::kConstantResource:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
@@ -488,6 +489,7 @@ Status XlaCompilationCache::CompileImpl(
        argument_input_indices.push_back(i);
        break;
      }
+      case XlaCompiler::Argument::kConstantResource:
      case XlaCompiler::Argument::kResource: {
        resource_input_indices.push_back(i);
        resource_input_initialized.push_back(arg.initialized);

tensorflow/compiler/jit/xla_compile_on_demand_op.cc

Lines changed: 2 additions & 1 deletion
@@ -152,7 +152,8 @@ Status XlaCompileOnDemandOp::Compile(
        ctx, variables_indices, variable_infos, variable_args));
 
    args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
-        constant_input_indices, inputs, variable_infos);
+        constant_input_indices, inputs, variable_infos,
+        static_cast<Device*>(ctx->device()));
    TF_RETURN_IF_ERROR(args.status());
  }
 

tensorflow/compiler/jit/xla_launch_util.cc

Lines changed: 37 additions & 10 deletions
@@ -426,7 +426,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
      ShapedBuffer buffer(
          xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
          xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
-          output.platform(), output.device_ordinal());
+          output.device_ordinal());
      buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                       /*source_base_index=*/{},
                                       /*target_base_index=*/{0});
@@ -564,12 +564,27 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
 XlaComputationLaunchContext::BuildXlaCompilerArguments(
    absl::Span<int const> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs,
-    absl::Span<VariableInfo const> variable_args,
+    absl::Span<VariableInfo const> variable_args, Device* device,
    const std::vector<std::string>& mangled_input_names) {
  CHECK(absl::c_is_sorted(must_be_constant_idxs));
  std::vector<XlaCompiler::Argument> out;
  out.resize(inputs.size());
 
+  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
+  DeviceContext* device_context = nullptr;
+  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
+  bool using_default_context = false;
+  auto cleanup = xla::MakeCleanup([&] {
+    if (device_context != nullptr && !using_default_context) {
+      device_context->Unref();
+    }
+  });
+  if (device_context == nullptr) {
+    using_default_context = true;
+    auto* dev_info = device->tensorflow_gpu_device_info();
+    if (dev_info) device_context = dev_info->default_context;
+  }
+
  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
@@ -589,14 +604,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
      arg.name = mangled_input_names.at(input_num);
    }
 
-    if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
-      // Handles compile-time constants.
-      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
-      arg.kind = XlaCompiler::Argument::kConstant;
-      arg.type = input->dtype();
-      arg.shape = input->shape();
-      arg.constant_value = *input;
-    } else if (variable_info_lookup.count(input_num)) {
+    if (variable_info_lookup.count(input_num)) {
      // Handles resource variables.
      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
      const VariableInfo& variable = *variable_info_lookup[input_num];
@@ -617,6 +625,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }
+
+      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
+        const Tensor* value = variable.var()->tensor();
+        Tensor value_on_host(value->dtype(), value->shape());
+        if (!device_context) {
+          value_on_host = *value;
+        } else {
+          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
+              value, "", device, &value_on_host));
+        }
+        arg.kind = XlaCompiler::Argument::kConstantResource;
+        arg.constant_value = value_on_host;
+      }
+    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+      arg.kind = XlaCompiler::Argument::kConstant;
+      arg.type = input->dtype();
+      arg.shape = input->shape();
+      arg.constant_value = *input;
    } else {
      // Normal inputs.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);

tensorflow/compiler/jit/xla_launch_util.h

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ class XlaComputationLaunchContext {
  BuildXlaCompilerArguments(
      absl::Span<int const> must_be_constant_idxs,
      absl::Span<const Tensor* const> inputs,
-      absl::Span<VariableInfo const> variable_args,
+      absl::Span<VariableInfo const> variable_args, Device* device,
      const std::vector<std::string>& mangled_input_names = {});
 
  // Add all inputs within `ctx` as XLA arguments (returned by arguments()).

tensorflow/compiler/plugin/poplar/kernels/application_runtime/application_compile.cc

Lines changed: 2 additions & 1 deletion
@@ -99,7 +99,8 @@ xla::StatusOr<xla::LocalExecutable*> CompileExecutable(
  TF_ASSIGN_OR_RETURN(
      std::vector<XlaCompiler::Argument> arguments,
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
-          constants, inputs, variable_infos, mangled_input_names));
+          constants, inputs, variable_infos,
+          static_cast<Device*>(ctx->device()), mangled_input_names));
 
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

tensorflow/compiler/plugin/poplar/tests/graph_compile_io_map_test.cc

Lines changed: 5 additions & 5 deletions
@@ -436,11 +436,11 @@ main {
 
  Shape shape = ShapeUtil::MakeShape(F32, {2});
 
-  ShapedBuffer arg0(shape, shape, platform, 0);
+  ShapedBuffer arg0(shape, shape, 0);
  arg0.set_buffer(buf0, {});
-  ShapedBuffer arg1(shape, shape, platform, 0);
+  ShapedBuffer arg1(shape, shape, 0);
  arg1.set_buffer(buf1, {});
-  ShapedBuffer arg2(shape, shape, platform, 0);
+  ShapedBuffer arg2(shape, shape, 0);
  arg2.set_buffer(buf2, {});
 
  std::vector<const ShapedBuffer*> args = {&arg0, &arg1, &arg2};
@@ -549,9 +549,9 @@ main {
 
  Shape shape = ShapeUtil::MakeShape(F32, {2});
 
-  ShapedBuffer arg0(shape, shape, platform, 0);
+  ShapedBuffer arg0(shape, shape, 0);
  arg0.set_buffer(buf, {});
-  ShapedBuffer arg1(shape, shape, platform, 0);
+  ShapedBuffer arg1(shape, shape, 0);
  arg1.set_buffer(buf, {});
 
  std::vector<const ShapedBuffer*> args = {&arg0, &arg1};

tensorflow/compiler/plugin/poplar/tests/variable_test.py

Lines changed: 72 additions & 0 deletions
@@ -19,7 +19,11 @@
 import test_utils as tu
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.eager import def_function
 from tensorflow.python.platform import googletest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ipu.config import IPUConfig
 from tensorflow.python.ops import array_ops
@@ -796,6 +800,74 @@ def testNonModifiedResourceIsNotOverwrittenInPlaceOp(self):
        [], "w should be copied to device once and "
        "that should be the only io event")
 
+  def testGetConstantOutOfResourceVariable(self):
+    with ops.device("/device:IPU:0"):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(experimental_compile=True)
+      def f(x):
+        return array_ops.reshape(
+            x,
+            [math_ops.cast(a, dtypes.int32),
+             math_ops.cast(b, dtypes.int32)])
+
+      # OK since the value is known at compile time.
+      out = f(random_ops.random_normal([10, 10]))
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+
+  def testGetConstantOutOfResourceVariableAfterWrite(self):
+    with ops.device("/device:IPU:0"):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(experimental_compile=True)
+      def f(x, val1, val2):
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return array_ops.reshape(
+            x,
+            [math_ops.cast(a, dtypes.int32),
+             math_ops.cast(b, dtypes.int32)])
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # Returns an error, since the value known at compile time was overridden.
+      with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                  'concrete values at compile time'):
+        f(random_ops.random_normal([10, 10]), val1, val2)
+
+  def testGetConstantOutOfResourceVariableBeforeWrite(self):
+    with ops.device("/device:IPU:0"):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(experimental_compile=True)
+      def f(x, val1, val2):
+        out = array_ops.reshape(
+            x,
+            [math_ops.cast(a, dtypes.int32),
+             math_ops.cast(b, dtypes.int32)])
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return out
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # OK since the write happens after the reshape.
+      out = f(random_ops.random_normal([10, 10]), val1, val2)
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+
 
 if __name__ == "__main__":
  os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=1 ' +

tensorflow/compiler/tf2xla/graph_compiler.cc

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
    switch (expressions[i]->kind()) {
      case XlaExpression::Kind::kConstant:
        arg.kind = XlaCompiler::Argument::kConstant;
-        arg.constant_value = expressions[i]->constant_value();
+        arg.constant_value = *expressions[i]->constant_value();
        break;
      case XlaExpression::Kind::kXlaOp:
        if (arg_must_be_compile_time_constant[i]) {
