|
6 | 6 | #include "comms/ctran/Ctran.h" |
7 | 7 | #include "comms/ctran/gpe/CtranGpeDev.h" |
8 | 8 | #include "comms/ctran/gpe/CtranGpeImpl.h" |
| 9 | +#include "comms/ctran/utils/CudaWrap.h" |
9 | 10 | // FIXME [REBASE]: update the path once moved to fbcode/comms |
10 | 11 | #include "comms/ctran/gpe/tests/KernelElemPoolUTKernels.h" |
11 | 12 | #include "comms/ctran/tests/CtranXPlatUtUtils.h" |
@@ -149,12 +150,17 @@ TEST_F(KernelElemPoolTest, PostRevokeComplete) { |
149 | 150 | dim3 grid = {ngroups, 1, 1}; |
150 | 151 | dim3 blocks = {640, 1, 1}; |
151 | 152 | void* args[] = {&elemList, &unuseIdx}; |
| 153 | + |
| 154 | + CUDACHECK_TEST(cudaFuncSetAttribute( |
| 155 | + (void*)KElemPostRevokeKernel, |
| 156 | + cudaFuncAttributeMaxDynamicSharedMemorySize, |
| 157 | + sizeof(CtranAlgoDeviceState))); |
152 | 158 | CUDACHECK_TEST(cudaLaunchKernel( |
153 | 159 | reinterpret_cast<void*>(KElemPostRevokeKernel), |
154 | 160 | grid, |
155 | 161 | blocks, |
156 | 162 | args, |
157 | | - 0, |
| 163 | + sizeof(CtranAlgoDeviceState), |
158 | 164 | 0)); |
159 | 165 |
|
160 | 166 | // Now host side posts the elems, revoke only 1 elem in the middle |
@@ -224,8 +230,18 @@ TEST_F(KernelElemPoolTest, PostWait) { |
224 | 230 | dim3 grid = {ngroups, 1, 1}; |
225 | 231 | dim3 blocks = {640, 1, 1}; |
226 | 232 | void* args[] = {&elem, &count, &vec1, &vec2}; |
| 233 | + |
| 234 | + CUDACHECK_TEST(cudaFuncSetAttribute( |
| 235 | + (void*)KElemPostWaitKernel, |
| 236 | + cudaFuncAttributeMaxDynamicSharedMemorySize, |
| 237 | + sizeof(CtranAlgoDeviceState))); |
227 | 238 | CUDACHECK_TEST(cudaLaunchKernel( |
228 | | - reinterpret_cast<void*>(KElemPostWaitKernel), grid, blocks, args, 0, 0)); |
| 239 | + reinterpret_cast<void*>(KElemPostWaitKernel), |
| 240 | + grid, |
| 241 | + blocks, |
| 242 | + args, |
| 243 | + sizeof(CtranAlgoDeviceState), |
| 244 | + 0)); |
229 | 245 |
|
230 | 246 | // Host side posts the elems |
231 | 247 | elem->post(); |
@@ -282,12 +298,17 @@ TEST_F(KernelElemPoolTest, PostMultiGroupSets) { |
282 | 298 | dim3 grid = {nGroups * nGroupsSets, 1, 1}; |
283 | 299 | dim3 blocks = {256, 1, 1}; |
284 | 300 | void* args[] = {&elemList, &countPerGroupSet, &nGroupsSets, &vec1, &vec2}; |
| 301 | + |
| 302 | + CUDACHECK_TEST(cudaFuncSetAttribute( |
| 303 | + (void*)KElemPostMultiGroupsKernel, |
| 304 | + cudaFuncAttributeMaxDynamicSharedMemorySize, |
| 305 | + sizeof(CtranAlgoDeviceState))); |
285 | 306 | CUDACHECK_TEST(cudaLaunchKernel( |
286 | 307 | reinterpret_cast<void*>(KElemPostMultiGroupsKernel), |
287 | 308 | grid, |
288 | 309 | blocks, |
289 | 310 | args, |
290 | | - 0, |
| 311 | + sizeof(CtranAlgoDeviceState), |
291 | 312 | 0)); |
292 | 313 |
|
293 | 314 | auto elem = elemList; |
|
0 commit comments