//! Extension trait for [`f32`] and [`f64`], providing high-level wrappers on top of
//! raw libdevice intrinsics from [`intrinsics`](crate::intrinsics).
33
4+ use cuda_std_macros:: gpu_only;
5+
6+ #[ cfg( target_arch = "nvptx64" ) ]
47use crate :: intrinsics as raw;
58
69// allows us to add new functions to the trait at any time without needing a new major version.
@@ -71,26 +74,32 @@ pub trait FloatExt: Sized + private::Sealed {
7174}
7275
7376impl FloatExt for f64 {
77+ #[ gpu_only]
7478 fn cospi ( self ) -> Self {
7579 unsafe { raw:: cospi ( self ) }
7680 }
7781
82+ #[ gpu_only]
7883 fn error_function ( self ) -> Self {
7984 unsafe { raw:: erf ( self ) }
8085 }
8186
87+ #[ gpu_only]
8288 fn complementary_error_function ( self ) -> Self {
8389 unsafe { raw:: erfc ( self ) }
8490 }
8591
92+ #[ gpu_only]
8693 fn inv_complementary_error_function ( self ) -> Self {
8794 unsafe { raw:: erfcinv ( self ) }
8895 }
8996
97+ #[ gpu_only]
9098 fn scaled_complementary_error_function ( self ) -> Self {
9199 unsafe { raw:: erfcx ( self ) }
92100 }
93101
102+ #[ gpu_only]
94103 fn frexp ( self ) -> ( Self , i32 ) {
95104 let mut exp = 0 ;
96105 unsafe {
@@ -99,55 +108,68 @@ impl FloatExt for f64 {
99108 }
100109 }
101110
111+ #[ gpu_only]
102112 fn unbiased_exp ( self ) -> i32 {
103113 unsafe { raw:: ilogb ( self ) }
104114 }
105115
116+ #[ gpu_only]
106117 fn j0 ( self ) -> Self {
107118 unsafe { raw:: j0 ( self ) }
108119 }
109120
121+ #[ gpu_only]
110122 fn j1 ( self ) -> Self {
111123 unsafe { raw:: j1 ( self ) }
112124 }
113125
126+ #[ gpu_only]
114127 fn jn ( self , order : i32 ) -> Self {
115128 unsafe { raw:: jn ( order, self ) }
116129 }
117130
131+ #[ gpu_only]
118132 fn ldexp ( self , exp : i32 ) -> Self {
119133 unsafe { raw:: ldexp ( self , exp) }
120134 }
121135
136+ #[ gpu_only]
122137 fn log_gamma ( self ) -> Self {
123138 unsafe { raw:: lgamma ( self ) }
124139 }
125140
141+ #[ gpu_only]
126142 fn log1p ( self ) -> Self {
127143 unsafe { raw:: log1p ( self ) }
128144 }
129145
146+ #[ gpu_only]
130147 fn norm_cdf ( self ) -> Self {
131148 unsafe { raw:: normcdf ( self ) }
132149 }
133150
151+ #[ gpu_only]
134152 fn inv_norm_cdf ( self ) -> Self {
135153 unsafe { raw:: normcdfinv ( self ) }
136154 }
137155
156+ #[ gpu_only]
138157 fn rcbrt ( self ) -> Self {
139158 unsafe { raw:: rcbrt ( self ) }
140159 }
141160
161+ #[ gpu_only]
142162 fn saturate ( self ) -> Self {
143163 // this intrinsic doesnt actually exit on f64, so implement it as clamp on f64
144164 self . clamp ( 0.0 , 1.0 )
145165 }
146166
167+ #[ gpu_only]
147168 fn scale_by_n ( self , exp : i32 ) -> Self {
148169 unsafe { raw:: scalbn ( self , exp) }
149170 }
150171
172+ #[ gpu_only]
151173 fn sincospi ( self ) -> ( Self , Self ) {
152174 let mut sin = 0.0 ;
153175 let mut cos = 0.0 ;
@@ -157,48 +179,59 @@ impl FloatExt for f64 {
157179 ( sin, cos)
158180 }
159181
182+ #[ gpu_only]
160183 fn sinpi ( self ) -> Self {
161184 unsafe { raw:: sinpi ( self ) }
162185 }
163186
187+ #[ gpu_only]
164188 fn gamma ( self ) -> Self {
165189 unsafe { raw:: tgamma ( self ) }
166190 }
167191
192+ #[ gpu_only]
168193 fn y0 ( self ) -> Self {
169194 unsafe { raw:: y0 ( self ) }
170195 }
171196
197+ #[ gpu_only]
172198 fn y1 ( self ) -> Self {
173199 unsafe { raw:: y1 ( self ) }
174200 }
175201
202+ #[ gpu_only]
176203 fn yn ( self , order : i32 ) -> Self {
177204 unsafe { raw:: yn ( order, self ) }
178205 }
179206}
180207
181208impl FloatExt for f32 {
209+ #[ gpu_only]
182210 fn cospi ( self ) -> Self {
183211 unsafe { raw:: cospif ( self ) }
184212 }
185213
214+ #[ gpu_only]
186215 fn error_function ( self ) -> Self {
187216 unsafe { raw:: erff ( self ) }
188217 }
189218
219+ #[ gpu_only]
190220 fn complementary_error_function ( self ) -> Self {
191221 unsafe { raw:: erfcf ( self ) }
192222 }
193223
224+ #[ gpu_only]
194225 fn inv_complementary_error_function ( self ) -> Self {
195226 unsafe { raw:: erfcinvf ( self ) }
196227 }
197228
229+ #[ gpu_only]
198230 fn scaled_complementary_error_function ( self ) -> Self {
199231 unsafe { raw:: erfcxf ( self ) }
200232 }
201233
234+ #[ gpu_only]
202235 fn frexp ( self ) -> ( Self , i32 ) {
203236 let mut exp = 0 ;
204237 unsafe {
@@ -207,54 +240,67 @@ impl FloatExt for f32 {
207240 }
208241 }
209242
243+ #[ gpu_only]
210244 fn unbiased_exp ( self ) -> i32 {
211245 unsafe { raw:: ilogbf ( self ) }
212246 }
213247
248+ #[ gpu_only]
214249 fn j0 ( self ) -> Self {
215250 unsafe { raw:: j0f ( self ) }
216251 }
217252
253+ #[ gpu_only]
218254 fn j1 ( self ) -> Self {
219255 unsafe { raw:: j1f ( self ) }
220256 }
221257
258+ #[ gpu_only]
222259 fn jn ( self , order : i32 ) -> Self {
223260 unsafe { raw:: jnf ( order, self ) }
224261 }
225262
263+ #[ gpu_only]
226264 fn ldexp ( self , exp : i32 ) -> Self {
227265 unsafe { raw:: ldexpf ( self , exp) }
228266 }
229267
268+ #[ gpu_only]
230269 fn log_gamma ( self ) -> Self {
231270 unsafe { raw:: lgammaf ( self ) }
232271 }
233272
273+ #[ gpu_only]
234274 fn log1p ( self ) -> Self {
235275 unsafe { raw:: log1pf ( self ) }
236276 }
237277
278+ #[ gpu_only]
238279 fn norm_cdf ( self ) -> Self {
239280 unsafe { raw:: normcdff ( self ) }
240281 }
241282
283+ #[ gpu_only]
242284 fn inv_norm_cdf ( self ) -> Self {
243285 unsafe { raw:: normcdfinvf ( self ) }
244286 }
245287
288+ #[ gpu_only]
246289 fn rcbrt ( self ) -> Self {
247290 unsafe { raw:: rcbrtf ( self ) }
248291 }
249292
293+ #[ gpu_only]
250294 fn saturate ( self ) -> Self {
251295 unsafe { raw:: saturatef ( self ) }
252296 }
253297
298+ #[ gpu_only]
254299 fn scale_by_n ( self , exp : i32 ) -> Self {
255300 unsafe { raw:: scalbnf ( self , exp) }
256301 }
257302
303+ #[ gpu_only]
258304 fn sincospi ( self ) -> ( Self , Self ) {
259305 let mut sin = 0.0 ;
260306 let mut cos = 0.0 ;
@@ -264,22 +310,27 @@ impl FloatExt for f32 {
264310 ( sin, cos)
265311 }
266312
313+ #[ gpu_only]
267314 fn sinpi ( self ) -> Self {
268315 unsafe { raw:: sinpif ( self ) }
269316 }
270317
318+ #[ gpu_only]
271319 fn gamma ( self ) -> Self {
272320 unsafe { raw:: tgammaf ( self ) }
273321 }
274322
323+ #[ gpu_only]
275324 fn y0 ( self ) -> Self {
276325 unsafe { raw:: y0f ( self ) }
277326 }
278327
328+ #[ gpu_only]
279329 fn y1 ( self ) -> Self {
280330 unsafe { raw:: y1f ( self ) }
281331 }
282332
333+ #[ gpu_only]
283334 fn yn ( self , order : i32 ) -> Self {
284335 unsafe { raw:: ynf ( order, self ) }
285336 }
0 commit comments