|
1 | 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | 2 |
|
| 3 | +#include <csignal> |
| 4 | +#include <thread> |
| 5 | + |
3 | 6 | #include <gmock/gmock.h> |
4 | 7 | #include <gtest/gtest.h> |
5 | | -#include <csignal> |
| 8 | + |
6 | 9 | #include "comms/ctran/Ctran.h" |
7 | 10 | #include "comms/ctran/gpe/CtranGpeDev.h" |
8 | 11 | #include "comms/ctran/gpe/CtranGpeImpl.h" |
9 | 12 | #include "comms/ctran/utils/CudaWrap.h" |
10 | 13 | // FIXME [REBASE]: update the path once moved to fbcode/comms |
11 | 14 | #include "comms/ctran/gpe/tests/KernelElemPoolUTKernels.h" |
12 | 15 | #include "comms/ctran/tests/CtranXPlatUtUtils.h" |
| 16 | +#include "comms/ctran/utils/Abort.h" |
13 | 17 |
|
14 | 18 | class KernelElemPoolTest : public ::testing::Test { |
15 | 19 | public: |
@@ -395,3 +399,106 @@ TEST_F(KernelElemPoolAbortTest, CheckFreeWithInvalidNGroups) { |
395 | 399 | ASSERT_EXIT(elem->isFree(), testing::KilledBySignal(SIGABRT), "") |
396 | 400 | << "Expect abort when revoking with invalid ngroups"; |
397 | 401 | } |
| 402 | + |
| 403 | +TEST_F(KernelElemPoolTest, WaitWithoutAbortCtrl) { |
| 404 | + constexpr int poolSize = 10; |
| 405 | + auto elemPool = std::make_unique<KernelElemPool>(poolSize); |
| 406 | + ASSERT_NE(elemPool, nullptr); |
| 407 | + |
| 408 | + constexpr int ngroups = 5; |
| 409 | + KernelElem* elem = elemPool->pop(ngroups); |
| 410 | + ASSERT_NE(elem, nullptr); |
| 411 | + |
| 412 | + // Set elem status to simulate completion |
| 413 | + for (int i = 0; i < ngroups; i++) { |
| 414 | + elem->status[i] = KernelElem::ElemStatus::INUSE; |
| 415 | + } |
| 416 | + |
| 417 | + std::thread completer([elem, ngroups]() { |
| 418 | + std::this_thread::sleep_for(std::chrono::milliseconds(50)); |
| 419 | + for (int i = 0; i < ngroups; i++) { |
| 420 | + elem->status[i] = KernelElem::ElemStatus::DONE; |
| 421 | + } |
| 422 | + }); |
| 423 | + elem->wait(); |
| 424 | + EXPECT_TRUE(elem->isComplete()); |
| 425 | + |
| 426 | + completer.join(); |
| 427 | + |
| 428 | + // Clean up |
| 429 | + elem->free(); |
| 430 | + elemPool->reclaim(); |
| 431 | +} |
| 432 | + |
| 433 | +TEST_F(KernelElemPoolTest, WaitWithAbortCtrlWithoutSet) { |
| 434 | + constexpr int poolSize = 10; |
| 435 | + auto elemPool = std::make_unique<KernelElemPool>(poolSize); |
| 436 | + ASSERT_NE(elemPool, nullptr); |
| 437 | + |
| 438 | + constexpr int ngroups = 5; |
| 439 | + KernelElem* elem = elemPool->pop(ngroups); |
| 440 | + ASSERT_NE(elem, nullptr); |
| 441 | + |
| 442 | + // Enable abortCtrl |
| 443 | + auto abortCtrl = ctran::utils::createAbort(/*enabled=*/true); |
| 444 | + EXPECT_TRUE(abortCtrl->Enabled()); |
| 445 | + EXPECT_FALSE(abortCtrl->Test()); |
| 446 | + |
| 447 | + // Set elem status to simulate ongoing work |
| 448 | + for (int i = 0; i < ngroups; i++) { |
| 449 | + elem->status[i] = KernelElem::ElemStatus::INUSE; |
| 450 | + } |
| 451 | + |
| 452 | + std::thread completer([elem, ngroups]() { |
| 453 | + std::this_thread::sleep_for(std::chrono::milliseconds(50)); |
| 454 | + for (int i = 0; i < ngroups; i++) { |
| 455 | + elem->status[i] = KernelElem::ElemStatus::DONE; |
| 456 | + } |
| 457 | + }); |
| 458 | + elem->wait(abortCtrl); |
| 459 | + EXPECT_TRUE(elem->isComplete()); |
| 460 | + EXPECT_FALSE(abortCtrl->Test()); |
| 461 | + |
| 462 | + completer.join(); |
| 463 | + |
| 464 | + // Clean up |
| 465 | + elem->free(); |
| 466 | + elemPool->reclaim(); |
| 467 | +} |
| 468 | + |
| 469 | +TEST_F(KernelElemPoolTest, WaitWithAbortCtrlUnblockOnSet) { |
| 470 | + constexpr int poolSize = 10; |
| 471 | + auto elemPool = std::make_unique<KernelElemPool>(poolSize); |
| 472 | + ASSERT_NE(elemPool, nullptr); |
| 473 | + |
| 474 | + constexpr int ngroups = 5; |
| 475 | + KernelElem* elem = elemPool->pop(ngroups); |
| 476 | + ASSERT_NE(elem, nullptr); |
| 477 | + |
| 478 | + // Enable abortCtrl |
| 479 | + auto abortCtrl = ctran::utils::createAbort(/*enabled=*/true); |
| 480 | + EXPECT_TRUE(abortCtrl->Enabled()); |
| 481 | + EXPECT_FALSE(abortCtrl->Test()); |
| 482 | + |
| 483 | + // Set elem status to simulate ongoing work that won't complete |
| 484 | + for (int i = 0; i < ngroups; i++) { |
| 485 | + elem->status[i] = KernelElem::ElemStatus::INUSE; |
| 486 | + } |
| 487 | + |
| 488 | + std::thread aborter([abortCtrl]() { |
| 489 | + std::this_thread::sleep_for(std::chrono::milliseconds(50)); |
| 490 | + abortCtrl->Set(); |
| 491 | + }); |
| 492 | + elem->wait(abortCtrl); |
| 493 | + EXPECT_FALSE(elem->isComplete()); |
| 494 | + EXPECT_TRUE(abortCtrl->Test()); |
| 495 | + |
| 496 | + aborter.join(); |
| 497 | + |
| 498 | + // Clean up - since abort is set and element is not complete, |
| 499 | + // we need to manually set status to allow free |
| 500 | + for (int i = 0; i < ngroups; i++) { |
| 501 | + elem->status[i] = KernelElem::ElemStatus::RESET; |
| 502 | + } |
| 503 | + elemPool->reclaim(); |
| 504 | +} |
0 commit comments