@@ -479,13 +479,20 @@ impl Builder<'_, '_> {
479479 let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
480480
481481 let glsl = self . ext_inst . borrow_mut ( ) . import_glsl ( self ) ;
482- let find_xsb = |arg| {
482+ let find_xsb = |arg, offset : i32 | {
483483 if trailing {
484- self . emit ( )
484+ let lsb = self
485+ . emit ( )
485486 . ext_inst ( u32, None , glsl, GLOp :: FindILsb as u32 , [ Operand :: IdRef (
486487 arg,
487488 ) ] )
488- . unwrap ( )
489+ . unwrap ( ) ;
490+ if offset == 0 {
491+ lsb
492+ } else {
493+ let const_offset = self . constant_i32 ( self . span ( ) , offset) . def ( self ) ;
494+ self . emit ( ) . i_add ( u32, None , const_offset, lsb) . unwrap ( )
495+ }
489496 } else {
490497 // rust is always unsigned, so FindUMsb
491498 let msb_bit = self
@@ -496,25 +503,21 @@ impl Builder<'_, '_> {
496503 . unwrap ( ) ;
497504 // the glsl op returns the Msb bit, not the amount of leading zeros of this u32
498505 // leading zeros = 31 - Msb bit
499- let u32_31 = self . constant_u32 ( self . span ( ) , 31 ) . def ( self ) ;
500- self . emit ( ) . i_sub ( u32, None , u32_31 , msb_bit) . unwrap ( )
506+ let const_offset = self . constant_i32 ( self . span ( ) , 31 - offset ) . def ( self ) ;
507+ self . emit ( ) . i_sub ( u32, None , const_offset , msb_bit) . unwrap ( )
501508 }
502509 } ;
503510
504511 let converted = match bits {
505512 8 | 16 => {
513+ let arg = self . emit ( ) . u_convert ( u32, None , arg. def ( self ) ) . unwrap ( ) ;
506514 if trailing {
507- let arg = self . emit ( ) . u_convert ( u32, None , arg. def ( self ) ) . unwrap ( ) ;
508- find_xsb ( arg)
515+ find_xsb ( arg, 0 )
509516 } else {
510- let arg = arg. def ( self ) ;
511- let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
512- let xsb = find_xsb ( arg) ;
513- let subtrahend = self . constant_u32 ( self . span ( ) , 32 - bits) . def ( self ) ;
514- self . emit ( ) . i_sub ( u32, None , xsb, subtrahend) . unwrap ( )
517+ find_xsb ( arg, bits as i32 - 32 )
515518 }
516519 }
517- 32 => find_xsb ( arg. def ( self ) ) ,
520+ 32 => find_xsb ( arg. def ( self ) , 0 ) ,
518521 64 => {
519522 let u32_0 = self . constant_int ( u32, 0 ) . def ( self ) ;
520523 let u32_32 = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
@@ -527,20 +530,17 @@ impl Builder<'_, '_> {
527530 . unwrap ( ) ;
528531 let higher = self . emit ( ) . u_convert ( u32, None , higher) . unwrap ( ) ;
529532
530- let lower_bits = find_xsb ( lower) ;
531- let higher_bits = find_xsb ( higher) ;
532-
533533 if trailing {
534534 let use_lower = self . emit ( ) . i_equal ( bool, None , higher, u32_0) . unwrap ( ) ;
535- let lower_bits =
536- self . emit ( ) . i_add ( u32 , None , lower_bits , u32_32 ) . unwrap ( ) ;
535+ let lower_bits = find_xsb ( lower , 32 ) ;
536+ let higher_bits = find_xsb ( higher , 0 ) ;
537537 self . emit ( )
538538 . select ( u32, None , use_lower, lower_bits, higher_bits)
539539 . unwrap ( )
540540 } else {
541541 let use_higher = self . emit ( ) . i_equal ( bool, None , lower, u32_0) . unwrap ( ) ;
542- let higher_bits =
543- self . emit ( ) . i_add ( u32 , None , higher_bits , u32_32 ) . unwrap ( ) ;
542+ let lower_bits = find_xsb ( lower , 0 ) ;
543+ let higher_bits = find_xsb ( higher , 32 ) ;
544544 self . emit ( )
545545 . select ( u32, None , use_higher, higher_bits, lower_bits)
546546 . unwrap ( )
0 commit comments