Skip to content

Commit 395baba

Browse files
cpuhrschmeta-codesync[bot]
authored andcommitted
Version guard for AccumulateGrad::add (#1768)
Summary: Pull Request resolved: #1768 When running `fbpkg build //monarch:monarch_dev` on 3279935a10 we encounter a compile time error that seems to stem from compiling with different versions of PyTorch. We might want to update our build chain. It seems monarch is once compiled against a third part 2.8.0 pytorch and once against the fbcode version. This version guard helps with this version drift if that's intended. Reviewed By: pzhan9, zdevito, soulitzer Differential Revision: D86363543 fbshipit-source-id: cb344e3363595cfe6aed22a27ac8eb8075fae351
1 parent e08b3c5 commit 395baba

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

python/monarch/gradient/_gradient_generator.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
#include <torch/csrc/autograd/python_function.h> // @manual=//caffe2:torch_extension
2424
#include <torch/extension.h> // @manual=//caffe2:torch_extension
2525

26+
#define TORCH_VERSION_NEWER_THAN(major, minor, patch) \
27+
((TORCH_VERSION_MAJOR > (major)) || \
28+
(TORCH_VERSION_MAJOR == (major) && TORCH_VERSION_MINOR > (minor)) || \
29+
(TORCH_VERSION_MAJOR == (major) && TORCH_VERSION_MINOR == (minor) && \
30+
TORCH_VERSION_PATCH > (patch)))
31+
2632
using torch::autograd::Edge;
2733
using torch::autograd::InputBuffer;
2834
using torch::autograd::Node;
@@ -420,12 +426,20 @@ struct GradientGenerator {
420426
DEBUG_PRINT(
421427
"// add: " << node->node->name()
422428
<< ", input_nr=" << static_cast<int>(input_nr) << "\n");
429+
#if TORCH_VERSION_NEWER_THAN(2, 8, 0)
423430
realInputBuffer(node).add(
424431
input_nr,
425432
check_and_reduce(node->node, input_nr, std::move(t)),
426433
std::nullopt,
427434
std::nullopt,
428435
node->node);
436+
#else
437+
realInputBuffer(node).add(
438+
input_nr,
439+
check_and_reduce(node->node, input_nr, std::move(t)),
440+
std::nullopt,
441+
std::nullopt);
442+
#endif
429443
}
430444

431445
InputBuffer& realInputBuffer(NodeState* state) {

0 commit comments

Comments
 (0)