@@ -45,6 +45,8 @@ pub enum SpirvType<'tcx> {
4545 element : Word ,
4646 /// Note: vector count is literal.
4747 count : u32 ,
48+ size : Size ,
49+ align : Align ,
4850 } ,
4951 Matrix {
5052 element : Word ,
@@ -131,7 +133,9 @@ impl SpirvType<'_> {
131133 }
132134 result
133135 }
134- Self :: Vector { element, count } => cx. emit_global ( ) . type_vector_id ( id, element, count) ,
136+ Self :: Vector { element, count, .. } => {
137+ cx. emit_global ( ) . type_vector_id ( id, element, count)
138+ }
135139 Self :: Matrix { element, count } => cx. emit_global ( ) . type_matrix_id ( id, element, count) ,
136140 Self :: Array { element, count } => {
137141 let result = cx
@@ -280,9 +284,7 @@ impl SpirvType<'_> {
280284 Self :: Bool => Size :: from_bytes ( 1 ) ,
281285 Self :: Integer ( width, _) | Self :: Float ( width) => Size :: from_bits ( width) ,
282286 Self :: Adt { size, .. } => size?,
283- Self :: Vector { element, count } => {
284- cx. lookup_type ( element) . sizeof ( cx) ? * count. next_power_of_two ( ) as u64
285- }
287+ Self :: Vector { size, .. } => size,
286288 Self :: Matrix { element, count } => cx. lookup_type ( element) . sizeof ( cx) ? * count as u64 ,
287289 Self :: Array { element, count } => {
288290 cx. lookup_type ( element) . sizeof ( cx) ?
@@ -311,13 +313,7 @@ impl SpirvType<'_> {
311313 Self :: Bool => Align :: from_bytes ( 1 ) . unwrap ( ) ,
312314 Self :: Integer ( width, _) | Self :: Float ( width) => Align :: from_bits ( width as u64 ) . unwrap ( ) ,
313315 Self :: Adt { align, .. } => align,
314- // Vectors have size==align
315- Self :: Vector { .. } => Align :: from_bytes (
316- self . sizeof ( cx)
317- . expect ( "alignof: Vectors must be sized" )
318- . bytes ( ) ,
319- )
320- . expect ( "alignof: Vectors must have power-of-2 size" ) ,
316+ Self :: Vector { align, .. } => align,
321317 Self :: Array { element, .. }
322318 | Self :: RuntimeArray { element }
323319 | Self :: Matrix { element, .. } => cx. lookup_type ( element) . alignof ( cx) ,
@@ -382,7 +378,17 @@ impl SpirvType<'_> {
382378 SpirvType :: Bool => SpirvType :: Bool ,
383379 SpirvType :: Integer ( width, signedness) => SpirvType :: Integer ( width, signedness) ,
384380 SpirvType :: Float ( width) => SpirvType :: Float ( width) ,
385- SpirvType :: Vector { element, count } => SpirvType :: Vector { element, count } ,
381+ SpirvType :: Vector {
382+ element,
383+ count,
384+ size,
385+ align,
386+ } => SpirvType :: Vector {
387+ element,
388+ count,
389+ size,
390+ align,
391+ } ,
386392 SpirvType :: Matrix { element, count } => SpirvType :: Matrix { element, count } ,
387393 SpirvType :: Array { element, count } => SpirvType :: Array { element, count } ,
388394 SpirvType :: RuntimeArray { element } => SpirvType :: RuntimeArray { element } ,
@@ -435,6 +441,32 @@ impl SpirvType<'_> {
435441 } ,
436442 }
437443 }
444+
445+ pub fn simd_vector ( cx : & CodegenCx < ' _ > , span : Span , element : SpirvType < ' _ > , count : u32 ) -> Self {
446+ Self :: Vector {
447+ element : element. def ( span, cx) ,
448+ count,
449+ size : element. sizeof ( cx) . unwrap ( ) * count as u64 ,
450+ align : element. alignof ( cx) ,
451+ }
452+ }
453+
454+ /// Now that we can have different `OpTypeVector` types with various sizes or alignments, having a statement in an
455+ /// `asm!` block like `OpTypeVector %f32 3` doesn't correlate to a specific type anymore. It could be Vec3 or Vec3A,
456+ /// nobody knows.
457+ ///
458+ /// FIXME(@firestar99) This is a giant hack that only works with base glam types, if any other type with
459+ /// `#[spirv(vector)]` is used, this will generate a mismatched type and fail validation. Also, Vec3A doesn't work.
460+ pub fn glam_vector_asm_hack ( cx : & CodegenCx < ' _ > , span : Span , element : Word , count : u32 ) -> Word {
461+ let spirv_type = cx. lookup_type ( element) ;
462+ SpirvType :: Vector {
463+ element,
464+ count,
465+ size : spirv_type. sizeof ( cx) . unwrap ( ) * count as u64 ,
466+ align : spirv_type. alignof ( cx) ,
467+ }
468+ . def ( span, cx)
469+ }
438470}
439471
440472impl < ' a > SpirvType < ' a > {
@@ -501,11 +533,18 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
501533 . field ( "field_names" , & field_names)
502534 . finish ( )
503535 }
504- SpirvType :: Vector { element, count } => f
536+ SpirvType :: Vector {
537+ element,
538+ count,
539+ size,
540+ align,
541+ } => f
505542 . debug_struct ( "Vector" )
506543 . field ( "id" , & self . id )
507544 . field ( "element" , & self . cx . debug_type ( element) )
508545 . field ( "count" , & count)
546+ . field ( "size" , & size)
547+ . field ( "align" , & align)
509548 . finish ( ) ,
510549 SpirvType :: Matrix { element, count } => f
511550 . debug_struct ( "Matrix" )
@@ -668,7 +707,7 @@ impl SpirvTypePrinter<'_, '_> {
668707 }
669708 f. write_str ( " }" )
670709 }
671- SpirvType :: Vector { element, count } | SpirvType :: Matrix { element, count } => {
710+ SpirvType :: Vector { element, count, .. } | SpirvType :: Matrix { element, count } => {
672711 ty ( self . cx , stack, f, element) ?;
673712 write ! ( f, "x{count}" )
674713 }
0 commit comments