@@ -370,18 +370,12 @@ impl Builder<'_, '_> {
370370 pub fn count_ones ( & self , arg : SpirvValue ) -> SpirvValue {
371371 let ty = arg. ty ;
372372 match self . cx . lookup_type ( ty) {
373- SpirvType :: Integer ( bits, signed ) => {
373+ SpirvType :: Integer ( bits, false ) => {
374374 let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
375375
376376 match bits {
377377 8 | 16 => {
378378 let arg = arg. def ( self ) ;
379- let arg = if signed {
380- let unsigned = SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
381- self . emit ( ) . bitcast ( unsigned, None , arg) . unwrap ( )
382- } else {
383- arg
384- } ;
385379 let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
386380 self . emit ( ) . bit_count ( u32, None , arg) . unwrap ( )
387381 }
@@ -413,25 +407,23 @@ impl Builder<'_, '_> {
413407 }
414408 . with_type ( u32)
415409 }
416- _ => self . fatal ( "count_ones() on a non-integer type" ) ,
410+ _ => self . fatal ( format ! (
411+ "count_ones() expected an unsigned integer type, got {:?}" ,
412+ self . cx. lookup_type( ty)
413+ ) ) ,
417414 }
418415 }
419416
420417 pub fn bit_reverse ( & self , arg : SpirvValue ) -> SpirvValue {
421418 let ty = arg. ty ;
422419 match self . cx . lookup_type ( ty) {
423- SpirvType :: Integer ( bits, signed ) => {
420+ SpirvType :: Integer ( bits, false ) => {
424421 let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
425422 let uint = SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
426423
427- match ( bits, signed ) {
428- ( 8 | 16 , signed ) => {
424+ match bits {
425+ 8 | 16 => {
429426 let arg = arg. def ( self ) ;
430- let arg = if signed {
431- self . emit ( ) . bitcast ( uint, None , arg) . unwrap ( )
432- } else {
433- arg
434- } ;
435427 let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
436428
437429 let reverse = self . emit ( ) . bit_reverse ( u32, None , arg) . unwrap ( ) ;
@@ -440,20 +432,10 @@ impl Builder<'_, '_> {
440432 . emit ( )
441433 . shift_right_logical ( u32, None , reverse, shift)
442434 . unwrap ( ) ;
443- let reverse = self . emit ( ) . u_convert ( uint, None , reverse) . unwrap ( ) ;
444- if signed {
445- self . emit ( ) . bitcast ( ty, None , reverse) . unwrap ( )
446- } else {
447- reverse
448- }
449- }
450- ( 32 , false ) => self . emit ( ) . bit_reverse ( u32, None , arg. def ( self ) ) . unwrap ( ) ,
451- ( 32 , true ) => {
452- let arg = self . emit ( ) . bitcast ( u32, None , arg. def ( self ) ) . unwrap ( ) ;
453- let reverse = self . emit ( ) . bit_reverse ( u32, None , arg) . unwrap ( ) ;
454- self . emit ( ) . bitcast ( ty, None , reverse) . unwrap ( )
435+ self . emit ( ) . u_convert ( uint, None , reverse) . unwrap ( )
455436 }
456- ( 64 , signed) => {
437+ 32 => self . emit ( ) . bit_reverse ( u32, None , arg. def ( self ) ) . unwrap ( ) ,
438+ 64 => {
457439 let u32_32 = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
458440 let arg = arg. def ( self ) ;
459441 let lower = self . emit ( ) . s_convert ( u32, None , arg) . unwrap ( ) ;
@@ -475,15 +457,9 @@ impl Builder<'_, '_> {
475457 . unwrap ( ) ;
476458 let lower_bits = self . emit ( ) . u_convert ( uint, None , lower_bits) . unwrap ( ) ;
477459
478- let result = self
479- . emit ( )
460+ self . emit ( )
480461 . bitwise_or ( ty, None , lower_bits, higher_bits)
481- . unwrap ( ) ;
482- if signed {
483- self . emit ( ) . bitcast ( ty, None , result) . unwrap ( )
484- } else {
485- result
486- }
462+ . unwrap ( )
487463 }
488464 _ => {
489465 let undef = self . undef ( ty) . def ( self ) ;
@@ -496,7 +472,10 @@ impl Builder<'_, '_> {
496472 }
497473 . with_type ( ty)
498474 }
499- _ => self . fatal ( "bit_reverse() on a non-integer type" ) ,
475+ _ => self . fatal ( format ! (
476+ "bit_reverse() expected an unsigned integer type, got {:?}" ,
477+ self . cx. lookup_type( ty)
478+ ) ) ,
500479 }
501480 }
502481
@@ -508,7 +487,7 @@ impl Builder<'_, '_> {
508487 ) -> SpirvValue {
509488 let ty = arg. ty ;
510489 match self . cx . lookup_type ( ty) {
511- SpirvType :: Integer ( bits, signed ) => {
490+ SpirvType :: Integer ( bits, false ) => {
512491 let bool = SpirvType :: Bool . def ( self . span ( ) , self ) ;
513492 let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
514493
@@ -542,13 +521,6 @@ impl Builder<'_, '_> {
542521 find_xsb ( arg)
543522 } else {
544523 let arg = arg. def ( self ) ;
545- let arg = if signed {
546- let unsigned =
547- SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
548- self . emit ( ) . bitcast ( unsigned, None , arg) . unwrap ( )
549- } else {
550- arg
551- } ;
552524 let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
553525 let xsb = find_xsb ( arg) ;
554526 let subtrahend = self . constant_u32 ( self . span ( ) , 32 - bits) . def ( self ) ;
@@ -611,7 +583,26 @@ impl Builder<'_, '_> {
611583 }
612584 . with_type ( u32)
613585 }
614- _ => self . fatal ( "count_leading_trailing_zeros() on a non-integer type" ) ,
586+ SpirvType :: Integer ( bits, true ) => {
587+ // rustc wants `[i8,i16,i32,i64]::leading_zeros()` with `non_zero: true` for some reason. I do not know
588+ // how these are reachable, marking them as zombies makes none of our compiletests fail.
589+ let unsigned = SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
590+ let arg = self
591+ . emit ( )
592+ . bitcast ( unsigned, None , arg. def ( self ) )
593+ . unwrap ( )
594+ . with_type ( unsigned) ;
595+ let result = self . count_leading_trailing_zeros ( arg, trailing, non_zero) ;
596+ self . emit ( )
597+ . bitcast ( ty, None , result. def ( self ) )
598+ . unwrap ( )
599+ . with_type ( ty)
600+ }
601+ e => {
602+ self . fatal ( format ! (
603+ "count_leading_trailing_zeros(trailing: {trailing}, non_zero: {non_zero}) expected an integer type, got {e:?}" ,
604+ ) ) ;
605+ }
615606 }
616607 }
617608
0 commit comments