Skip to content

Commit 598a9e5

Browse files
arttianezhu and meta-codesync[bot]
authored and committed
make Gpe host side wait cancellable
Summary: As the title says. abortCtrl is not held by KernelElem because KernelElem is a struct shared by both the host and device sides. Arguably we could create a host-side wrapper class for KernelElem, but that seems like overkill, so only a new function is added.
Reviewed By: saifhhasan
Differential Revision: D85371956
fbshipit-source-id: cdad607103d81735b87328e36e660167789f4caf
1 parent 898a95b commit 598a9e5

File tree

3 files changed

+118
-1
lines changed

3 files changed

+118
-1
lines changed

comms/ctran/gpe/CtranGpeDev.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "comms/ctran/algos/CtranAlgoArgDev.h"
1010
#include "comms/ctran/algos/CtranAlgoDev.h"
11+
#include "comms/ctran/utils/Abort.h"
1112
#include "comms/utils/commSpecs.h"
1213

1314
#ifdef CTRAN_DISABLE_TCPDM
@@ -139,6 +140,7 @@ struct alignas(16) KernelElem {
139140
// need make progress. It can be safely called only when algorithm ensures no
140141
// network progress is needed.
141142
void wait(int groupId = -1);
143+
void wait(std::shared_ptr<ctran::utils::Abort> abort, int groupId = -1);
142144
};
143145

144146
template <>

comms/ctran/gpe/CtranGpeImpl.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,14 @@ void KernelElem::wait(int groupId) {
779779
}
780780
}
781781
782+
// Host-side wait that can be cancelled through an abort control.
//
// Polls until either isComplete(groupId) returns true or abort->Test()
// returns true, yielding the CPU between polls. The function returns in both
// cases without reporting which condition ended the wait; a caller that needs
// to tell them apart should re-check isComplete() / abort->Test() afterwards.
//
// A null `abort` is tolerated: no abort check is performed and the call then
// only returns once isComplete(groupId) is true, instead of dereferencing a
// null pointer.
void KernelElem::wait(std::shared_ptr<ctran::utils::Abort> abort, int groupId) {
  // Borrow the raw pointer once; the by-value shared_ptr parameter keeps the
  // pointee alive for the whole call.
  ctran::utils::Abort* const abortCtrl = abort.get();

  // wait for all thread blocks to complete
  while (!this->isComplete(groupId)) {
    // Completion is checked first, so the abort flag is only consulted while
    // the element is still incomplete.
    if (abortCtrl != nullptr && abortCtrl->Test()) {
      return;
    }
    // friendly spin so we don't hog CPU
    std::this_thread::yield();
  }
}
789+
782790
KernelElemPool::KernelElemPool(size_t capacity) : capacity_(capacity) {
783791
FB_CUDACHECKTHROW(cudaHostAlloc(
784792
&this->memPtr_,

comms/ctran/gpe/tests/KernelElemPoolUT.cc

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
// Copyright (c) Meta Platforms, Inc. and affiliates.
22

3+
#include <csignal>
4+
#include <thread>
5+
36
#include <gmock/gmock.h>
47
#include <gtest/gtest.h>
5-
#include <csignal>
8+
69
#include "comms/ctran/Ctran.h"
710
#include "comms/ctran/gpe/CtranGpeDev.h"
811
#include "comms/ctran/gpe/CtranGpeImpl.h"
912
#include "comms/ctran/utils/CudaWrap.h"
1013
// FIXME [REBASE]: update the path once moved to fbcode/comms
1114
#include "comms/ctran/gpe/tests/KernelElemPoolUTKernels.h"
1215
#include "comms/ctran/tests/CtranXPlatUtUtils.h"
16+
#include "comms/ctran/utils/Abort.h"
1317

1418
class KernelElemPoolTest : public ::testing::Test {
1519
public:
@@ -395,3 +399,106 @@ TEST_F(KernelElemPoolAbortTest, CheckFreeWithInvalidNGroups) {
395399
ASSERT_EXIT(elem->isFree(), testing::KilledBySignal(SIGABRT), "")
396400
<< "Expect abort when revoking with invalid ngroups";
397401
}
402+
403+
// Exercises the wait() overload that takes no abort control: a helper thread
// flips every group status to DONE after a short delay, and the test expects
// the element to report complete once wait() has returned.
TEST_F(KernelElemPoolTest, WaitWithoutAbortCtrl) {
  constexpr int kPoolCapacity = 10;
  constexpr int kNumGroups = 5;

  auto pool = std::make_unique<KernelElemPool>(kPoolCapacity);
  ASSERT_NE(pool, nullptr);

  KernelElem* const kelem = pool->pop(kNumGroups);
  ASSERT_NE(kelem, nullptr);

  // Start with every group marked INUSE to simulate work still in flight.
  for (int g = 0; g < kNumGroups; ++g) {
    kelem->status[g] = KernelElem::ElemStatus::INUSE;
  }

  // Helper thread: after 50ms, mark every group DONE.
  std::thread completer([kelem]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    for (int g = 0; g < kNumGroups; ++g) {
      kelem->status[g] = KernelElem::ElemStatus::DONE;
    }
  });

  kelem->wait();
  EXPECT_TRUE(kelem->isComplete());

  completer.join();

  // Release the element and let the pool reclaim it.
  kelem->free();
  pool->reclaim();
}
432+
433+
// Exercises the abort-aware wait() overload when the abort control is enabled
// but never set: a helper thread flips every group status to DONE after a
// short delay, and the test expects the element to be complete and the abort
// flag to remain untriggered once wait() has returned.
TEST_F(KernelElemPoolTest, WaitWithAbortCtrlWithoutSet) {
  constexpr int kPoolCapacity = 10;
  constexpr int kNumGroups = 5;

  auto pool = std::make_unique<KernelElemPool>(kPoolCapacity);
  ASSERT_NE(pool, nullptr);

  KernelElem* const kelem = pool->pop(kNumGroups);
  ASSERT_NE(kelem, nullptr);

  // Create an enabled abort control that has not been triggered.
  auto abortHandle = ctran::utils::createAbort(/*enabled=*/true);
  EXPECT_TRUE(abortHandle->Enabled());
  EXPECT_FALSE(abortHandle->Test());

  // Start with every group marked INUSE to simulate ongoing work.
  for (int g = 0; g < kNumGroups; ++g) {
    kelem->status[g] = KernelElem::ElemStatus::INUSE;
  }

  // Helper thread: after 50ms, mark every group DONE.
  std::thread completer([kelem]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    for (int g = 0; g < kNumGroups; ++g) {
      kelem->status[g] = KernelElem::ElemStatus::DONE;
    }
  });

  kelem->wait(abortHandle);
  EXPECT_TRUE(kelem->isComplete());
  EXPECT_FALSE(abortHandle->Test());

  completer.join();

  // Release the element and let the pool reclaim it.
  kelem->free();
  pool->reclaim();
}
468+
469+
// Exercises the abort-aware wait() overload when nothing ever marks the
// element DONE: a helper thread sets the abort control after a short delay,
// and the test expects wait() to return with the element still incomplete and
// the abort flag set.
TEST_F(KernelElemPoolTest, WaitWithAbortCtrlUnblockOnSet) {
  constexpr int kPoolCapacity = 10;
  constexpr int kNumGroups = 5;

  auto pool = std::make_unique<KernelElemPool>(kPoolCapacity);
  ASSERT_NE(pool, nullptr);

  KernelElem* const kelem = pool->pop(kNumGroups);
  ASSERT_NE(kelem, nullptr);

  // Create an enabled abort control that has not been triggered.
  auto abortHandle = ctran::utils::createAbort(/*enabled=*/true);
  EXPECT_TRUE(abortHandle->Enabled());
  EXPECT_FALSE(abortHandle->Test());

  // Mark every group INUSE; no thread in this test ever sets them to DONE.
  for (int g = 0; g < kNumGroups; ++g) {
    kelem->status[g] = KernelElem::ElemStatus::INUSE;
  }

  // Helper thread: after 50ms, trigger the abort control.
  std::thread aborter([abortHandle]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    abortHandle->Set();
  });

  kelem->wait(abortHandle);
  EXPECT_FALSE(kelem->isComplete());
  EXPECT_TRUE(abortHandle->Test());

  aborter.join();

  // Clean up: the element never completed, so its group statuses are put
  // back to RESET by hand before asking the pool to reclaim.
  for (int g = 0; g < kNumGroups; ++g) {
    kelem->status[g] = KernelElem::ElemStatus::RESET;
  }
  pool->reclaim();
}

0 commit comments

Comments
 (0)