@@ -89,26 +89,15 @@ extern "C" {
8989}
9090
9191#[ cfg( target_os = "cuda" ) ]
92- macro_rules! inbounds {
93- // the bounds were taken mostly from the cuda C++ programming guide, i also
94- // double-checked with what cuda clang does by checking its emitted llvm ir's scalar metadata
95- ( $func_name: ident, $bound : expr) => { {
92+ macro_rules! in_range {
93+ // The bounds were taken mostly from the cuda C++ programming guide. I also
94+ // double-checked with what cuda clang does by checking its emitted llvm ir's scalar metadata.
95+ ( $func_name: ident, $range : expr) => { {
9696 let val = unsafe { $func_name( ) } ;
97- if val > $bound {
98- // SAFETY: this condition is declared unreachable by compute capability max bound
97+ if !$range . contains ( & val ) {
98+ // SAFETY: this condition is declared unreachable by compute capability max bound.
9999 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
100- // we do this to potentially allow for better optimizations by LLVM
101- unsafe { core:: hint:: unreachable_unchecked( ) }
102- } else {
103- val
104- }
105- } } ;
106- ( $func_name: ident, $lower_bound: expr, $upper_bound: expr) => { {
107- let val = unsafe { $func_name( ) } ;
108- if !( $lower_bound..=$upper_bound) . contains( & val) {
109- // SAFETY: this condition is declared unreachable by compute capability max bound
110- // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
111- // we do this to potentially allow for better optimizations by LLVM
100+ // We do this to potentially allow for better optimizations by LLVM.
112101 unsafe { core:: hint:: unreachable_unchecked( ) }
113102 } else {
114103 val
@@ -119,73 +108,73 @@ macro_rules! inbounds {
119108#[ gpu_only]
120109#[ inline( always) ]
121110pub fn thread_idx_x ( ) -> u32 {
122- inbounds ! ( __nvvm_thread_idx_x, 1024 )
111+ in_range ! ( __nvvm_thread_idx_x, 0 ..= 1024 )
123112}
124113
125114#[ gpu_only]
126115#[ inline( always) ]
127116pub fn thread_idx_y ( ) -> u32 {
128- inbounds ! ( __nvvm_thread_idx_y, 1024 )
117+ in_range ! ( __nvvm_thread_idx_y, 0 ..= 1024 )
129118}
130119
131120#[ gpu_only]
132121#[ inline( always) ]
133122pub fn thread_idx_z ( ) -> u32 {
134- inbounds ! ( __nvvm_thread_idx_z, 64 )
123+ in_range ! ( __nvvm_thread_idx_z, 0 ..= 64 )
135124}
136125
137126#[ gpu_only]
138127#[ inline( always) ]
139128pub fn block_idx_x ( ) -> u32 {
140- inbounds ! ( __nvvm_block_idx_x, 2147483647 )
129+ in_range ! ( __nvvm_block_idx_x, 0 ..= 2147483647 )
141130}
142131
143132#[ gpu_only]
144133#[ inline( always) ]
145134pub fn block_idx_y ( ) -> u32 {
146- inbounds ! ( __nvvm_block_idx_y, 65535 )
135+ in_range ! ( __nvvm_block_idx_y, 0 ..= 65535 )
147136}
148137
149138#[ gpu_only]
150139#[ inline( always) ]
151140pub fn block_idx_z ( ) -> u32 {
152- inbounds ! ( __nvvm_block_idx_z, 65535 )
141+ in_range ! ( __nvvm_block_idx_z, 0 ..= 65535 )
153142}
154143
155144#[ gpu_only]
156145#[ inline( always) ]
157146pub fn block_dim_x ( ) -> u32 {
158- inbounds ! ( __nvvm_block_dim_x, 1 , 1025 )
147+ in_range ! ( __nvvm_block_dim_x, 1 ..= 1025 )
159148}
160149
161150#[ gpu_only]
162151#[ inline( always) ]
163152pub fn block_dim_y ( ) -> u32 {
164- inbounds ! ( __nvvm_block_dim_y, 1 , 1025 )
153+ in_range ! ( __nvvm_block_dim_y, 1 ..= 1025 )
165154}
166155
167156#[ gpu_only]
168157#[ inline( always) ]
169158pub fn block_dim_z ( ) -> u32 {
170- inbounds ! ( __nvvm_block_dim_z, 1 , 65 )
159+ in_range ! ( __nvvm_block_dim_z, 1 ..= 65 )
171160}
172161
173162#[ gpu_only]
174163#[ inline( always) ]
175164pub fn grid_dim_x ( ) -> u32 {
176- inbounds ! ( __nvvm_grid_dim_x, 1 , 2147483648 )
165+ in_range ! ( __nvvm_grid_dim_x, 1 ..= 2147483648 )
177166}
178167
179168#[ gpu_only]
180169#[ inline( always) ]
181170pub fn grid_dim_y ( ) -> u32 {
182- inbounds ! ( __nvvm_grid_dim_y, 1 , 65536 )
171+ in_range ! ( __nvvm_grid_dim_y, 1 ..= 65536 )
183172}
184173
185174#[ gpu_only]
186175#[ inline( always) ]
187176pub fn grid_dim_z ( ) -> u32 {
188- inbounds ! ( __nvvm_grid_dim_z, 1 , 65536 )
177+ in_range ! ( __nvvm_grid_dim_z, 1 ..= 65536 )
189178}
190179
191180/// Gets the 3d index of the thread currently executing the kernel.
0 commit comments