Commit e4cc26b

Revert "Revert "Add AllocatingOutput extensions for the instructions which use HasTensorAllocation target.""
This reverts commit 71c2eab.
1 parent 71c2eab · commit e4cc26b

4 files changed (+11, -9 lines)


tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/hlo_extensions.cc

Lines changed: 8 additions & 6 deletions

@@ -28,16 +28,18 @@ namespace xla {
 namespace poplarplugin {
 namespace {
 
-void RegisterSharedExtensions(HloOpcode opcode) {
+void RegisterAllocatingOutputExtensions(HloOpcode opcode) {
   auto allocating_output = [](const HloInstruction*) { return true; };
   RegisterHloInstructionExtension<AllocatingOutputExtension>(opcode,
                                                              allocating_output);
 }
-REGISTER_HLO_INST_EXTENSIONS(kConstant, RegisterSharedExtensions);
-REGISTER_HLO_INST_EXTENSIONS(kInfeed, RegisterSharedExtensions);
-REGISTER_HLO_INST_EXTENSIONS(kParameter, RegisterSharedExtensions);
-REGISTER_HLO_INST_EXTENSIONS(kReduceWindow, RegisterSharedExtensions);
-REGISTER_HLO_INST_EXTENSIONS(kRng, RegisterSharedExtensions);
+REGISTER_HLO_INST_EXTENSIONS(kConstant, RegisterAllocatingOutputExtensions);
+REGISTER_HLO_INST_EXTENSIONS(kInfeed, RegisterAllocatingOutputExtensions);
+REGISTER_HLO_INST_EXTENSIONS(kParameter, RegisterAllocatingOutputExtensions);
+REGISTER_HLO_INST_EXTENSIONS(kReduceWindow, RegisterAllocatingOutputExtensions);
+REGISTER_HLO_INST_EXTENSIONS(kRng, RegisterAllocatingOutputExtensions);
+REGISTER_HLO_INST_EXTENSIONS(kSelectAndScatter,
+                             RegisterAllocatingOutputExtensions);
 
 void RegisterReduceExtensions(HloOpcode opcode) {
   auto allocating_output = [](const HloInstruction*) { return true; };
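For context, the pattern above pairs an opcode with a callback that reports whether the instruction allocates its own output tensor, so any further opcode would be wired up the same way. A minimal illustrative sketch follows; the kSort opcode is a hypothetical example and is not part of this commit:

// Illustrative sketch only: kSort is a hypothetical opcode chosen for the
// example, reusing the registration helper shown in the diff above.
REGISTER_HLO_INST_EXTENSIONS(kSort, RegisterAllocatingOutputExtensions);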

tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/inter_ipu_copy.cc

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ absl::flat_hash_set<int64_t> HloInterIpuCopy::AllocatingIndices() const {
   return {};
 }
 
-bool HloInterIpuCopy::AllocatingOutput() const { return false; }
+bool HloInterIpuCopy::AllocatingOutput() const { return true; }
 
 absl::flat_hash_map<int64_t, int64_t> HloInterIpuCopy::LayoutDependencies()
     const {

tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/stateless_random.cc

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ absl::flat_hash_set<int64_t> HloStatelessRandom::AllocatingIndices() const {
   return {};
 }
 
-bool HloStatelessRandom::AllocatingOutput() const { return false; }
+bool HloStatelessRandom::AllocatingOutput() const { return true; }
 
 absl::flat_hash_map<int64_t, int64_t> HloStatelessRandom::LayoutDependencies()
     const {

tensorflow/compiler/plugin/poplar/tests/conv_graph_caching_sharded_test.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def testConvolutionsDontMatchDifferentDevices(self):
       # Note how there are two convolutions
       ok = [
          '*OnTileCopy*', 'vs/conv2d/Conv2D/convolution.*',
-         'Copy_*vs/conv2d/Conv2D/convolution', 'vs/conv2d_1/Conv2D/convolution'
+         'vs/conv2d_1/Conv2D/convolution'
      ]
      self.assert_all_compute_sets_and_list(report, ok)
 
