Skip to content

Commit 515ffb3

Browse files
Regina8023meta-codesync[bot]
authored andcommitted
Fix KernelElem test hang caused by reading garbage Fault Tolerance state
Summary: Similar to D85974514, the KernelElem test kernels rely on proper initialized value of `shmDevState.enableCancellableWaits` because they call `KernelTestHostAbort` in the kernel. Manually set the value before running anything in the test kernel. Reviewed By: arttianezhu Differential Revision: D86158576 fbshipit-source-id: f7dfbb065bda209faa2bb47ed5850d554c97cf5d
1 parent f80ce55 commit 515ffb3

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

comms/ctran/gpe/tests/KernelElemPoolUT.cc

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "comms/ctran/Ctran.h"
77
#include "comms/ctran/gpe/CtranGpeDev.h"
88
#include "comms/ctran/gpe/CtranGpeImpl.h"
9+
#include "comms/ctran/utils/CudaWrap.h"
910
// FIXME [REBASE]: update the path once moved to fbcode/comms
1011
#include "comms/ctran/gpe/tests/KernelElemPoolUTKernels.h"
1112
#include "comms/ctran/tests/CtranXPlatUtUtils.h"
@@ -149,12 +150,17 @@ TEST_F(KernelElemPoolTest, PostRevokeComplete) {
149150
dim3 grid = {ngroups, 1, 1};
150151
dim3 blocks = {640, 1, 1};
151152
void* args[] = {&elemList, &unuseIdx};
153+
154+
CUDACHECK_TEST(cudaFuncSetAttribute(
155+
(void*)KElemPostRevokeKernel,
156+
cudaFuncAttributeMaxDynamicSharedMemorySize,
157+
sizeof(CtranAlgoDeviceState)));
152158
CUDACHECK_TEST(cudaLaunchKernel(
153159
reinterpret_cast<void*>(KElemPostRevokeKernel),
154160
grid,
155161
blocks,
156162
args,
157-
0,
163+
sizeof(CtranAlgoDeviceState),
158164
0));
159165

160166
// Now host side posts the elems, revoke only 1 elem in the middle
@@ -224,8 +230,18 @@ TEST_F(KernelElemPoolTest, PostWait) {
224230
dim3 grid = {ngroups, 1, 1};
225231
dim3 blocks = {640, 1, 1};
226232
void* args[] = {&elem, &count, &vec1, &vec2};
233+
234+
CUDACHECK_TEST(cudaFuncSetAttribute(
235+
(void*)KElemPostWaitKernel,
236+
cudaFuncAttributeMaxDynamicSharedMemorySize,
237+
sizeof(CtranAlgoDeviceState)));
227238
CUDACHECK_TEST(cudaLaunchKernel(
228-
reinterpret_cast<void*>(KElemPostWaitKernel), grid, blocks, args, 0, 0));
239+
reinterpret_cast<void*>(KElemPostWaitKernel),
240+
grid,
241+
blocks,
242+
args,
243+
sizeof(CtranAlgoDeviceState),
244+
0));
229245

230246
// Host side posts the elems
231247
elem->post();
@@ -282,12 +298,17 @@ TEST_F(KernelElemPoolTest, PostMultiGroupSets) {
282298
dim3 grid = {nGroups * nGroupsSets, 1, 1};
283299
dim3 blocks = {256, 1, 1};
284300
void* args[] = {&elemList, &countPerGroupSet, &nGroupsSets, &vec1, &vec2};
301+
302+
CUDACHECK_TEST(cudaFuncSetAttribute(
303+
(void*)KElemPostMultiGroupsKernel,
304+
cudaFuncAttributeMaxDynamicSharedMemorySize,
305+
sizeof(CtranAlgoDeviceState)));
285306
CUDACHECK_TEST(cudaLaunchKernel(
286307
reinterpret_cast<void*>(KElemPostMultiGroupsKernel),
287308
grid,
288309
blocks,
289310
args,
290-
0,
311+
sizeof(CtranAlgoDeviceState),
291312
0));
292313

293314
auto elem = elemList;

comms/ctran/gpe/tests/KernelElemPoolUTKernels.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Meta Platforms, Inc. and affiliates.
22

33
#include "comms/ctran/algos/DevCommon.cuh"
4+
#include "comms/ctran/algos/DevShmState.cuh"
45
// FIXME [REBASE]: update the path once moved to fbcode/comms
56
#include "comms/ctran/gpe/tests/KernelElemPoolUTKernels.h"
67

@@ -13,6 +14,9 @@ __global__ void KElemConsumerKernel(KernelElem* elemList) {
1314
}
1415

1516
__global__ void KElemPostRevokeKernel(KernelElem* elemList, int unuseIdx) {
17+
// TODO(T243528798): remove this preload of devstate by splitting h2d/d2h
18+
// channels.
19+
shmDevState.enableCancellableWaits = false;
1620
KernelElem* elem = elemList;
1721
int i = 0;
1822
while (elem) {
@@ -42,6 +46,9 @@ __global__ void KElemPostRevokeKernel(KernelElem* elemList, int unuseIdx) {
4246

4347
__global__ void
4448
KElemPostWaitKernel(KernelElem* elem, size_t count, int* vec1, int* vec2) {
49+
// TODO(T243528798): remove this preload of devstate by splitting h2d/d2h
50+
// channels.
51+
shmDevState.enableCancellableWaits = false;
4552
bool revoked = false;
4653
elemWaitPostOrRevokeByGroup(elem, blockIdx.x, &revoked);
4754

@@ -63,6 +70,9 @@ __global__ void KElemPostMultiGroupsKernel(
6370
int nGroupSets,
6471
int* vec1,
6572
int* vec2) {
73+
// TODO(T243528798): remove this preload of devstate by splitting h2d/d2h
74+
// channels.
75+
shmDevState.enableCancellableWaits = false;
6676
bool revoked = false;
6777
auto nGroupsPerSet = gridDim.x / nGroupSets;
6878
auto groupSetId = blockIdx.x / nGroupsPerSet;

0 commit comments

Comments
 (0)