Skip to content

Commit dedb993

Browse files
nnethercote authored and LegNeato committed
Improve and rename inbounds! macro.
The `inbounds!` macro has two versions: one that takes only an upper bound, and one that takes both a lower and an upper bound. This commit removes the first version and changes the second to take a range, which is more concise, more flexible, and clearer. It also renames the macro to `in_range`, which makes sense given that the bounds are now specified via a Rust range expression (e.g. `0..=1024`). Note: some of the ranges are incorrect, and will be fixed in the next commit.
1 parent 0e864be commit dedb993

File tree

1 file changed

+19
-30
lines changed

1 file changed

+19
-30
lines changed

crates/cuda_std/src/thread.rs

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -89,26 +89,15 @@ extern "C" {
8989
}
9090

9191
#[cfg(target_os = "cuda")]
92-
macro_rules! inbounds {
93-
// the bounds were taken mostly from the cuda C++ programming guide, i also
94-
// double-checked with what cuda clang does by checking its emitted llvm ir's scalar metadata
95-
($func_name:ident, $bound:expr) => {{
92+
macro_rules! in_range {
93+
// The bounds were taken mostly from the CUDA C++ Programming Guide. I also
94+
// double-checked with what CUDA clang does by checking its emitted LLVM IR's scalar metadata.
95+
($func_name:ident, $range:expr) => {{
9696
let val = unsafe { $func_name() };
97-
if val > $bound {
98-
// SAFETY: this condition is declared unreachable by compute capability max bound
97+
if !$range.contains(&val) {
98+
// SAFETY: this condition is declared unreachable by compute capability max bound.
9999
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
100-
// we do this to potentially allow for better optimizations by LLVM
101-
unsafe { core::hint::unreachable_unchecked() }
102-
} else {
103-
val
104-
}
105-
}};
106-
($func_name:ident, $lower_bound:expr, $upper_bound:expr) => {{
107-
let val = unsafe { $func_name() };
108-
if !($lower_bound..=$upper_bound).contains(&val) {
109-
// SAFETY: this condition is declared unreachable by compute capability max bound
110-
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
111-
// we do this to potentially allow for better optimizations by LLVM
100+
// We do this to potentially allow for better optimizations by LLVM.
112101
unsafe { core::hint::unreachable_unchecked() }
113102
} else {
114103
val
@@ -119,73 +108,73 @@ macro_rules! inbounds {
119108
#[gpu_only]
120109
#[inline(always)]
121110
pub fn thread_idx_x() -> u32 {
122-
inbounds!(__nvvm_thread_idx_x, 1024)
111+
in_range!(__nvvm_thread_idx_x, 0..=1024)
123112
}
124113

125114
#[gpu_only]
126115
#[inline(always)]
127116
pub fn thread_idx_y() -> u32 {
128-
inbounds!(__nvvm_thread_idx_y, 1024)
117+
in_range!(__nvvm_thread_idx_y, 0..=1024)
129118
}
130119

131120
#[gpu_only]
132121
#[inline(always)]
133122
pub fn thread_idx_z() -> u32 {
134-
inbounds!(__nvvm_thread_idx_z, 64)
123+
in_range!(__nvvm_thread_idx_z, 0..=64)
135124
}
136125

137126
#[gpu_only]
138127
#[inline(always)]
139128
pub fn block_idx_x() -> u32 {
140-
inbounds!(__nvvm_block_idx_x, 2147483647)
129+
in_range!(__nvvm_block_idx_x, 0..=2147483647)
141130
}
142131

143132
#[gpu_only]
144133
#[inline(always)]
145134
pub fn block_idx_y() -> u32 {
146-
inbounds!(__nvvm_block_idx_y, 65535)
135+
in_range!(__nvvm_block_idx_y, 0..=65535)
147136
}
148137

149138
#[gpu_only]
150139
#[inline(always)]
151140
pub fn block_idx_z() -> u32 {
152-
inbounds!(__nvvm_block_idx_z, 65535)
141+
in_range!(__nvvm_block_idx_z, 0..=65535)
153142
}
154143

155144
#[gpu_only]
156145
#[inline(always)]
157146
pub fn block_dim_x() -> u32 {
158-
inbounds!(__nvvm_block_dim_x, 1, 1025)
147+
in_range!(__nvvm_block_dim_x, 1..=1025)
159148
}
160149

161150
#[gpu_only]
162151
#[inline(always)]
163152
pub fn block_dim_y() -> u32 {
164-
inbounds!(__nvvm_block_dim_y, 1, 1025)
153+
in_range!(__nvvm_block_dim_y, 1..=1025)
165154
}
166155

167156
#[gpu_only]
168157
#[inline(always)]
169158
pub fn block_dim_z() -> u32 {
170-
inbounds!(__nvvm_block_dim_z, 1, 65)
159+
in_range!(__nvvm_block_dim_z, 1..=65)
171160
}
172161

173162
#[gpu_only]
174163
#[inline(always)]
175164
pub fn grid_dim_x() -> u32 {
176-
inbounds!(__nvvm_grid_dim_x, 1, 2147483648)
165+
in_range!(__nvvm_grid_dim_x, 1..=2147483648)
177166
}
178167

179168
#[gpu_only]
180169
#[inline(always)]
181170
pub fn grid_dim_y() -> u32 {
182-
inbounds!(__nvvm_grid_dim_y, 1, 65536)
171+
in_range!(__nvvm_grid_dim_y, 1..=65536)
183172
}
184173

185174
#[gpu_only]
186175
#[inline(always)]
187176
pub fn grid_dim_z() -> u32 {
188-
inbounds!(__nvvm_grid_dim_z, 1, 65536)
177+
in_range!(__nvvm_grid_dim_z, 1..=65536)
189178
}
190179

191180
/// Gets the 3d index of the thread currently executing the kernel.

0 commit comments

Comments
 (0)