Skip to content

Commit 2628014

Browse files
author
Gautham Ganapathy
committed
Fix for OOM in EfficientDet
Summary: The changes for a unified instruction cache was causing EfficientDet to go OOM (T66916). This was caused by the input layout flag not being provided correctly to the SubcomputationGraphCache. FIX T68951 Reviewers: alfiee, jackh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, zigmasb Reviewed By: alfiee, jackh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, zigmasb Maniphest Tasks: T68951 Differential Revision: https://phabricator.sourcevertex.net/D79039
1 parent cf0511c commit 2628014

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

tensorflow/compiler/plugin/poplar/driver/ops/map_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ StatusOr<DriverProgramSequence> CreateFunctionOp(
666666
HloComputation* comp = inst->to_apply();
667667

668668
bool keep_input_layouts = false;
669-
if (IsFunction(inst)) {
669+
if (IsFunction(inst) || IsCall(inst)) {
670670
keep_input_layouts = GetFunctionKeepInputLayouts(inst);
671671
}
672672

tensorflow/compiler/plugin/poplar/driver/tools/util.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,10 @@ bool IsFunction(const HloInstruction* inst) {
457457
return CallConfigHasType(inst, PoplarBackendConfig::CallConfig::Function);
458458
}
459459

460+
bool IsCall(const HloInstruction* inst) {
461+
return CallConfigHasType(inst, PoplarBackendConfig::CallConfig::Call);
462+
}
463+
460464
bool IsMultiConv(const HloInstruction* inst) {
461465
return CallConfigHasType(inst, PoplarBackendConfig::CallConfig::MultiConv);
462466
}

tensorflow/compiler/plugin/poplar/driver/tools/util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ bool IsPipelineStageBackward(const HloInstruction*);
173173
bool IsPipelineStageRecomputation(const HloInstruction*);
174174
bool IsResourceUpdate(const HloInstruction*);
175175
bool IsFunction(const HloInstruction*);
176+
bool IsCall(const HloInstruction*);
176177
bool IsMultiConv(const HloInstruction*);
177178
bool IsPipelineOp(const HloInstruction*);
178179
bool IsBatchSerializedPipelineOp(const HloInstruction*);

0 commit comments

Comments
 (0)