|
14 | 14 | #include "comms/ctran/tests/CtranStandaloneUTUtils.h" |
15 | 15 | #include "comms/utils/cvars/nccl_cvars.h" |
16 | 16 |
|
| 17 | +#include "comms/ctran/algos/AllReduce/AllReduceImpl.h" |
| 18 | + |
17 | 19 | namespace ctran::testing { |
18 | 20 |
|
19 | 21 | using AllReduceTestParam = std::tuple<std::string, enum NCCL_ALLREDUCE_ALGO>; |
| 22 | +using AllReduceMinMsgSizeTestParam = std::tuple<size_t, commDataType_t>; |
| 23 | + |
| 24 | +enum class CtranAllReduceRingMinSizeTestOpt { |
| 25 | + expect_sufficient, |
| 26 | + expect_insufficient, |
| 27 | +}; |
20 | 28 |
|
21 | 29 | class CtranAllReduceTest |
22 | 30 | : public CtranStandaloneMultiRankBaseTest, |
@@ -250,4 +258,156 @@ INSTANTIATE_TEST_SUITE_P( |
250 | 258 | return std::get<0>(info.param); |
251 | 259 | }); |
252 | 260 |
|
| 261 | +// Test fixture for ctring minimum message size validation |
| 262 | +class CtranAllReduceRingMinSizeTest |
| 263 | + : public CtranStandaloneMultiRankBaseTest, |
| 264 | + public ::testing::WithParamInterface<AllReduceMinMsgSizeTestParam> { |
| 265 | + protected: |
| 266 | + static constexpr int kDefaultNumRanks = 4; |
| 267 | + static_assert(kDefaultNumRanks % 2 == 0); |
| 268 | + static constexpr commRedOp_t kReduceOpType = commSum; |
| 269 | + |
| 270 | + void SetUp() override { |
| 271 | + setenv("NCCL_COMM_STATE_DEBUG_TOPO", "nolocal", 1); |
| 272 | + setenv("NCCL_IGNORE_TOPO_LOAD_FAILURE", "1", 1); |
| 273 | + CtranStandaloneMultiRankBaseTest::SetUp(); |
| 274 | + } |
| 275 | + |
| 276 | + void startWorkers(int numRanks = kDefaultNumRanks) { |
| 277 | + std::vector<std::shared_ptr<::ctran::utils::Abort>> aborts; |
| 278 | + aborts.reserve(numRanks); |
| 279 | + for (int i = 0; i < numRanks; ++i) { |
| 280 | + aborts.push_back(ctran::utils::createAbort(/*enabled=*/true)); |
| 281 | + } |
| 282 | + CtranStandaloneMultiRankBaseTest::startWorkers(numRanks, /*aborts=*/aborts); |
| 283 | + } |
| 284 | + |
| 285 | + void runTest( |
| 286 | + size_t count, |
| 287 | + commDataType_t dt, |
| 288 | + enum CtranAllReduceRingMinSizeTestOpt testOpt, |
| 289 | + int numRanks = kDefaultNumRanks) { |
| 290 | + startWorkers(numRanks); |
| 291 | + for (int rank = 0; rank < numRanks; ++rank) { |
| 292 | + run(rank, [this, count, dt, testOpt](PerRankState& state) { |
| 293 | + ASSERT_TRUE(ctranAllReduceSupport(state.ctranComm.get())); |
| 294 | + |
| 295 | + size_t bufferSize = count * commTypeSize(dt); |
| 296 | + if (bufferSize < CTRAN_MIN_REGISTRATION_SIZE) { |
| 297 | + bufferSize = CTRAN_MIN_REGISTRATION_SIZE; |
| 298 | + } |
| 299 | + |
| 300 | + void* srcHandle; |
| 301 | + void* dstHandle; |
| 302 | + ASSERT_EQ( |
| 303 | + commSuccess, |
| 304 | + state.ctranComm->ctran_->commRegister( |
| 305 | + state.srcBuffer, bufferSize, &srcHandle)); |
| 306 | + ASSERT_EQ( |
| 307 | + commSuccess, |
| 308 | + state.ctranComm->ctran_->commRegister( |
| 309 | + state.dstBuffer, bufferSize, &dstHandle)); |
| 310 | + |
| 311 | + if (testOpt == CtranAllReduceRingMinSizeTestOpt::expect_sufficient) { |
| 312 | + // Should not throw when count >= nRanks |
| 313 | + EXPECT_NO_THROW({ |
| 314 | + auto res = ctranAllReduceRing( |
| 315 | + state.srcBuffer, |
| 316 | + state.dstBuffer, |
| 317 | + count, |
| 318 | + dt, |
| 319 | + kReduceOpType, |
| 320 | + state.ctranComm.get(), |
| 321 | + state.stream); |
| 322 | + EXPECT_EQ(res, commSuccess); |
| 323 | + }); |
| 324 | + } else { |
| 325 | + // Expect ctran::utils::Exception when count < nRanks |
| 326 | + EXPECT_THROW( |
| 327 | + { |
| 328 | + ctranAllReduceRing( |
| 329 | + state.srcBuffer, |
| 330 | + state.dstBuffer, |
| 331 | + count, |
| 332 | + dt, |
| 333 | + kReduceOpType, |
| 334 | + state.ctranComm.get(), |
| 335 | + state.stream); |
| 336 | + }, |
| 337 | + ctran::utils::Exception); |
| 338 | + } |
| 339 | + |
| 340 | + // ensure async execution completion and no error |
| 341 | + EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(state.stream)); |
| 342 | + |
| 343 | + // deregistering will happen after streamSync below |
| 344 | + ASSERT_EQ( |
| 345 | + commSuccess, state.ctranComm->ctran_->commDeregister(dstHandle)); |
| 346 | + ASSERT_EQ( |
| 347 | + commSuccess, state.ctranComm->ctran_->commDeregister(srcHandle)); |
| 348 | + }); |
| 349 | + } |
| 350 | + } |
| 351 | +}; |
| 352 | + |
| 353 | +TEST_P(CtranAllReduceRingMinSizeTest, InsufficientElements_1Element) { |
| 354 | + auto [numRanks, dt] = GetParam(); |
| 355 | + ASSERT_FALSE(numRanks <= 1) << "Need at least 2 ranks for this test"; |
| 356 | + runTest( |
| 357 | + 1, dt, CtranAllReduceRingMinSizeTestOpt::expect_insufficient, numRanks); |
| 358 | +} |
| 359 | + |
| 360 | +TEST_P(CtranAllReduceRingMinSizeTest, InsufficientElements_NRanksMinus1) { |
| 361 | + auto [numRanks, dt] = GetParam(); |
| 362 | + ASSERT_FALSE(numRanks <= 1) << "Need at least 2 ranks for this test"; |
| 363 | + runTest( |
| 364 | + numRanks - 1, |
| 365 | + dt, |
| 366 | + CtranAllReduceRingMinSizeTestOpt::expect_insufficient, |
| 367 | + numRanks); |
| 368 | +} |
| 369 | + |
| 370 | +TEST_P(CtranAllReduceRingMinSizeTest, SufficientElements_ExactlyNRanks) { |
| 371 | + auto [numRanks, dt] = GetParam(); |
| 372 | + XLOG(INFO) << "SufficientElements_ExactlyNRanks: numRanks: " << numRanks |
| 373 | + << ", dt: " << dt; |
| 374 | + runTest( |
| 375 | + numRanks, |
| 376 | + dt, |
| 377 | + CtranAllReduceRingMinSizeTestOpt::expect_sufficient, |
| 378 | + numRanks); |
| 379 | +} |
| 380 | + |
| 381 | +TEST_P(CtranAllReduceRingMinSizeTest, SufficientElements_NRanksPlus1) { |
| 382 | + auto [numRanks, dt] = GetParam(); |
| 383 | + runTest( |
| 384 | + numRanks + 1, |
| 385 | + dt, |
| 386 | + CtranAllReduceRingMinSizeTestOpt::expect_sufficient, |
| 387 | + numRanks); |
| 388 | +} |
| 389 | + |
| 390 | +TEST_P(CtranAllReduceRingMinSizeTest, SufficientElements_LargeMessage) { |
| 391 | + auto [numRanks, dt] = GetParam(); |
| 392 | + runTest( |
| 393 | + 1024, dt, CtranAllReduceRingMinSizeTestOpt::expect_sufficient, numRanks); |
| 394 | +} |
| 395 | + |
| 396 | +INSTANTIATE_TEST_SUITE_P( |
| 397 | + AllDataTypes, |
| 398 | + CtranAllReduceRingMinSizeTest, |
| 399 | + ::testing::Values( |
| 400 | + std::make_tuple<>(2, commFloat), |
| 401 | + std::make_tuple<>(2, commInt32), |
| 402 | + std::make_tuple<>(2, commInt8), |
| 403 | + std::make_tuple<>(4, commFloat), |
| 404 | + std::make_tuple<>(4, commInt32), |
| 405 | + std::make_tuple<>(4, commInt8), |
| 406 | + std::make_tuple<>(6, commFloat), |
| 407 | + std::make_tuple<>(6, commInt32), |
| 408 | + std::make_tuple<>(6, commInt8), |
| 409 | + std::make_tuple<>(8, commFloat), |
| 410 | + std::make_tuple<>(8, commInt32), |
| 411 | + std::make_tuple<>(8, commInt8))); |
| 412 | + |
253 | 413 | } // namespace ctran::testing |
0 commit comments