Skip to content

Commit a96c5bc

Browse files
nnethercoteLegNeato
authored andcommitted
Don't call intrinsics in 3d dim/idx functions.
Instead call the Rust functions that have the range constraints. That way the 3d version get the same range constraints as the 1d versions. It also avoids the need for some `unsafe` blocks.
1 parent ab4b28b commit a96c5bc

File tree

1 file changed

+4
-28
lines changed

1 file changed

+4
-28
lines changed

crates/cuda_std/src/thread.rs

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -193,54 +193,30 @@ pub fn grid_dim_z() -> u32 {
193193
#[gpu_only]
194194
#[inline(always)]
195195
pub fn thread_idx() -> UVec3 {
196-
unsafe {
197-
UVec3::new(
198-
__nvvm_thread_idx_x(),
199-
__nvvm_thread_idx_y(),
200-
__nvvm_thread_idx_z(),
201-
)
202-
}
196+
UVec3::new(thread_idx_x(), thread_idx_y(), thread_idx_z())
203197
}
204198

205199
/// Gets the 3d index of the block that the thread currently executing the kernel is located in.
206200
#[gpu_only]
207201
#[inline(always)]
208202
pub fn block_idx() -> UVec3 {
209-
unsafe {
210-
UVec3::new(
211-
__nvvm_block_idx_x(),
212-
__nvvm_block_idx_y(),
213-
__nvvm_block_idx_z(),
214-
)
215-
}
203+
UVec3::new(block_idx_x(), block_idx_y(), block_idx_z())
216204
}
217205

218206
/// Gets the 3d layout of the thread blocks executing this kernel. In other words,
219207
/// how many threads exist in each thread block in every direction.
220208
#[gpu_only]
221209
#[inline(always)]
222210
pub fn block_dim() -> UVec3 {
223-
unsafe {
224-
UVec3::new(
225-
__nvvm_block_dim_x(),
226-
__nvvm_block_dim_y(),
227-
__nvvm_block_dim_z(),
228-
)
229-
}
211+
UVec3::new(block_dim_x(), block_dim_y(), block_dim_z())
230212
}
231213

232214
/// Gets the 3d layout of the block grids executing this kernel. In other words,
233215
/// how many thread blocks exist in each grid in every direction.
234216
#[gpu_only]
235217
#[inline(always)]
236218
pub fn grid_dim() -> UVec3 {
237-
unsafe {
238-
UVec3::new(
239-
__nvvm_grid_dim_x(),
240-
__nvvm_grid_dim_y(),
241-
__nvvm_grid_dim_z(),
242-
)
243-
}
219+
UVec3::new(grid_dim_x(), grid_dim_y(), grid_dim_z())
244220
}
245221

246222
/// Gets the overall thread index, accounting for 1d/2d/3d block/grid dimensions. This

0 commit comments

Comments
 (0)