Skip to content

Commit fbc77ce

Browse files
committed
abi layout: give Vector a dynamic size and alignment
1 parent 236ee2c commit fbc77ce

File tree

8 files changed

+159
-59
lines changed

8 files changed

+159
-59
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
648648
SpirvType::Vector {
649649
element: elem_spirv,
650650
count: count as u32,
651+
size: self.size,
652+
align: self.align.abi,
651653
}
652654
.def(span, cx)
653655
}
@@ -1229,43 +1231,92 @@ fn trans_intrinsic_type<'tcx>(
12291231
}
12301232
}
12311233
IntrinsicType::Matrix => {
1232-
let span = def_id_for_spirv_type_adt(ty)
1233-
.map(|did| cx.tcx.def_span(did))
1234-
.expect("#[spirv(matrix)] must be added to a type which has DefId");
1235-
1236-
let field_types = (0..ty.fields.count())
1237-
.map(|i| ty.field(cx, i).spirv_type(span, cx))
1238-
.collect::<Vec<_>>();
1239-
if field_types.len() < 2 {
1240-
return Err(cx
1241-
.tcx
1242-
.dcx()
1243-
.span_err(span, "#[spirv(matrix)] type must have at least two fields"));
1244-
}
1245-
let elem_type = field_types[0];
1246-
if !field_types.iter().all(|&ty| ty == elem_type) {
1247-
return Err(cx.tcx.dcx().span_err(
1248-
span,
1249-
"#[spirv(matrix)] type fields must all be the same type",
1250-
));
1251-
}
1252-
match cx.lookup_type(elem_type) {
1234+
let (element, count) =
1235+
trans_glam_like_struct(cx, span, ty, args, "`#[spirv(matrix)]`")?;
1236+
match cx.lookup_type(element) {
12531237
SpirvType::Vector { .. } => (),
12541238
ty => {
12551239
return Err(cx
12561240
.tcx
12571241
.dcx()
1258-
.struct_span_err(span, "#[spirv(matrix)] type fields must all be vectors")
1259-
.with_note(format!("field type is {}", ty.debug(elem_type, cx)))
1242+
.struct_span_err(span, "`#[spirv(matrix)]` type fields must all be vectors")
1243+
.with_note(format!("field type is {}", ty.debug(element, cx)))
12601244
.emit());
12611245
}
12621246
}
1263-
1264-
Ok(SpirvType::Matrix {
1265-
element: elem_type,
1266-
count: field_types.len() as u32,
1247+
Ok(SpirvType::Matrix { element, count }.def(span, cx))
1248+
}
1249+
IntrinsicType::Vector => {
1250+
let (element, count) =
1251+
trans_glam_like_struct(cx, span, ty, args, "`#[spirv(vector)]`")?;
1252+
match cx.lookup_type(element) {
1253+
SpirvType::Float { .. } | SpirvType::Integer { .. } => (),
1254+
ty => {
1255+
return Err(cx
1256+
.tcx
1257+
.dcx()
1258+
.struct_span_err(
1259+
span,
1260+
"`#[spirv(vector)]` type fields must all be floats or integers",
1261+
)
1262+
.with_note(format!("field type is {}", ty.debug(element, cx)))
1263+
.emit());
1264+
}
1265+
}
1266+
Ok(SpirvType::Vector {
1267+
element,
1268+
count,
1269+
size: ty.size,
1270+
align: ty.align.abi,
12671271
}
12681272
.def(span, cx))
12691273
}
12701274
}
12711275
}
1276+
1277+
/// A struct with multiple fields of the same kind
1278+
/// Used for `#[spirv(vector)]` and `#[spirv(matrix)]`
1279+
fn trans_glam_like_struct<'tcx>(
1280+
cx: &CodegenCx<'tcx>,
1281+
span: Span,
1282+
ty: TyAndLayout<'tcx>,
1283+
args: GenericArgsRef<'tcx>,
1284+
err_attr_name: &str,
1285+
) -> Result<(Word, u32), ErrorGuaranteed> {
1286+
let tcx = cx.tcx;
1287+
if let Some(adt) = ty.ty.ty_adt_def()
1288+
&& adt.is_struct()
1289+
{
1290+
let (count, element) = adt
1291+
.non_enum_variant()
1292+
.fields
1293+
.iter()
1294+
.map(|f| f.ty(tcx, args))
1295+
.dedup_with_count()
1296+
.exactly_one()
1297+
.map_err(|_e| {
1298+
tcx.dcx().span_err(
1299+
span,
1300+
format!("{err_attr_name} member types must all be the same"),
1301+
)
1302+
})?;
1303+
1304+
let element = cx.layout_of(element);
1305+
let element_word = element.spirv_type(span, cx);
1306+
let count = u32::try_from(count)
1307+
.ok()
1308+
.filter(|count| *count >= 2)
1309+
.ok_or_else(|| {
1310+
tcx.dcx().span_err(
1311+
span,
1312+
format!("{err_attr_name} must have at least 2 members"),
1313+
)
1314+
})?;
1315+
1316+
Ok((element_word, count))
1317+
} else {
1318+
Err(tcx
1319+
.dcx()
1320+
.span_err(span, "#[spirv(vector)] type must be a struct"))
1321+
}
1322+
}

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ pub enum IntrinsicType {
6868
RuntimeArray,
6969
TypedBuffer,
7070
Matrix,
71+
Vector,
7172
}
7273

7374
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
@@ -548,6 +549,22 @@ fn parse_attrs_for_checking<'a>(
548549
))
549550
}
550551
}
552+
Some(command) if command.name == sym.vector => {
553+
// #[rust_gpu::vector ...]
554+
match s.get(2) {
555+
// #[rust_gpu::vector::v1]
556+
Some(version) if version.name == sym.v1 => {
557+
Ok(SmallVec::from_iter([
558+
Ok((attr.span(), SpirvAttribute::IntrinsicType(IntrinsicType::Vector)))
559+
]))
560+
},
561+
_ => Err((
562+
attr.span(),
563+
"unknown `rust_gpu::vector` version, expected `rust_gpu::vector::v1`"
564+
.to_string(),
565+
)),
566+
}
567+
}
551568
_ => {
552569
// #[rust_gpu::...] but not a know version
553570
let spirv = sym.spirv_attr_with_version.as_str();

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,12 @@ fn memset_dynamic_scalar(
317317
byte_width: usize,
318318
is_float: bool,
319319
) -> Word {
320-
let composite_type = SpirvType::Vector {
321-
element: SpirvType::Integer(8, false).def(builder.span(), builder),
322-
count: byte_width as u32,
323-
}
320+
let composite_type = SpirvType::simd_vector(
321+
builder,
322+
builder.span(),
323+
SpirvType::Integer(8, false),
324+
byte_width as u32,
325+
)
324326
.def(builder.span(), builder);
325327
let composite = builder
326328
.emit()
@@ -417,7 +419,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
417419
_ => self.fatal(format!("memset on float width {width} not implemented yet")),
418420
},
419421
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
420-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
422+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
421423
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
422424
self.constant_composite(
423425
ty.def(self.span(), self),
@@ -478,7 +480,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
478480
)
479481
.unwrap()
480482
}
481-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
483+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
482484
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
483485
self.emit()
484486
.composite_construct(
@@ -2976,11 +2978,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
29762978
}
29772979

29782980
fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value {
2979-
let result_type = SpirvType::Vector {
2980-
element: elt.ty,
2981-
count: num_elts as u32,
2982-
}
2983-
.def(self.span(), self);
2981+
let result_type =
2982+
SpirvType::simd_vector(self, self.span(), self.lookup_type(elt.ty), num_elts as u32)
2983+
.def(self.span(), self);
29842984
if self.builder.lookup_const(elt).is_some() {
29852985
self.constant_composite(result_type, iter::repeat_n(elt.def(self), num_elts))
29862986
} else {

crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
114114
let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
115115
self.bitcast(val, result_type)
116116
}
117-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
117+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => self
118118
.load_vec_mat_arr(
119119
original_type,
120120
result_type,
@@ -314,7 +314,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
314314
let value_u32 = self.bitcast(value, u32_ty);
315315
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
316316
}
317-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
317+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => self
318318
.store_vec_mat_arr(
319319
original_type,
320320
value,

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
728728
SpirvType::Vector {
729729
element: ty,
730730
count: 4,
731+
..
731732
},
732733
)
733734
| (

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,12 @@ impl ConstCodegenMethods for CodegenCx<'_> {
200200
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)))
201201
}
202202
fn const_vector(&self, elts: &[Self::Value]) -> Self::Value {
203-
let vector_ty = SpirvType::Vector {
204-
element: elts[0].ty,
205-
count: elts.len() as u32,
206-
}
203+
let vector_ty = SpirvType::simd_vector(
204+
self,
205+
DUMMY_SP,
206+
self.lookup_type(elts[0].ty),
207+
elts.len() as u32,
208+
)
207209
.def(DUMMY_SP, self);
208210
self.constant_composite(vector_ty, elts.iter().map(|elt| elt.def_cx(self)))
209211
}

crates/rustc_codegen_spirv/src/spirv_type.rs

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ pub enum SpirvType<'tcx> {
4545
element: Word,
4646
/// Note: vector count is literal.
4747
count: u32,
48+
size: Size,
49+
align: Align,
4850
},
4951
Matrix {
5052
element: Word,
@@ -131,7 +133,9 @@ impl SpirvType<'_> {
131133
}
132134
result
133135
}
134-
Self::Vector { element, count } => cx.emit_global().type_vector_id(id, element, count),
136+
Self::Vector { element, count, .. } => {
137+
cx.emit_global().type_vector_id(id, element, count)
138+
}
135139
Self::Matrix { element, count } => cx.emit_global().type_matrix_id(id, element, count),
136140
Self::Array { element, count } => {
137141
let result = cx
@@ -280,9 +284,7 @@ impl SpirvType<'_> {
280284
Self::Bool => Size::from_bytes(1),
281285
Self::Integer(width, _) | Self::Float(width) => Size::from_bits(width),
282286
Self::Adt { size, .. } => size?,
283-
Self::Vector { element, count } => {
284-
cx.lookup_type(element).sizeof(cx)? * count.next_power_of_two() as u64
285-
}
287+
Self::Vector { size, .. } => size,
286288
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
287289
Self::Array { element, count } => {
288290
cx.lookup_type(element).sizeof(cx)?
@@ -310,14 +312,7 @@ impl SpirvType<'_> {
310312

311313
Self::Bool => Align::from_bytes(1).unwrap(),
312314
Self::Integer(width, _) | Self::Float(width) => Align::from_bits(width as u64).unwrap(),
313-
Self::Adt { align, .. } => align,
314-
// Vectors have size==align
315-
Self::Vector { .. } => Align::from_bytes(
316-
self.sizeof(cx)
317-
.expect("alignof: Vectors must be sized")
318-
.bytes(),
319-
)
320-
.expect("alignof: Vectors must have power-of-2 size"),
315+
Self::Adt { align, .. } | Self::Vector { align, .. } => align,
321316
Self::Array { element, .. }
322317
| Self::RuntimeArray { element }
323318
| Self::Matrix { element, .. } => cx.lookup_type(element).alignof(cx),
@@ -382,7 +377,17 @@ impl SpirvType<'_> {
382377
SpirvType::Bool => SpirvType::Bool,
383378
SpirvType::Integer(width, signedness) => SpirvType::Integer(width, signedness),
384379
SpirvType::Float(width) => SpirvType::Float(width),
385-
SpirvType::Vector { element, count } => SpirvType::Vector { element, count },
380+
SpirvType::Vector {
381+
element,
382+
count,
383+
size,
384+
align,
385+
} => SpirvType::Vector {
386+
element,
387+
count,
388+
size,
389+
align,
390+
},
386391
SpirvType::Matrix { element, count } => SpirvType::Matrix { element, count },
387392
SpirvType::Array { element, count } => SpirvType::Array { element, count },
388393
SpirvType::RuntimeArray { element } => SpirvType::RuntimeArray { element },
@@ -435,6 +440,15 @@ impl SpirvType<'_> {
435440
},
436441
}
437442
}
443+
444+
pub fn simd_vector(cx: &CodegenCx<'_>, span: Span, element: SpirvType<'_>, count: u32) -> Self {
445+
Self::Vector {
446+
element: element.def(span, cx),
447+
count,
448+
size: element.sizeof(cx).unwrap() * count as u64,
449+
align: element.alignof(cx),
450+
}
451+
}
438452
}
439453

440454
impl<'a> SpirvType<'a> {
@@ -501,11 +515,18 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
501515
.field("field_names", &field_names)
502516
.finish()
503517
}
504-
SpirvType::Vector { element, count } => f
518+
SpirvType::Vector {
519+
element,
520+
count,
521+
size,
522+
align,
523+
} => f
505524
.debug_struct("Vector")
506525
.field("id", &self.id)
507526
.field("element", &self.cx.debug_type(element))
508527
.field("count", &count)
528+
.field("size", &size)
529+
.field("align", &align)
509530
.finish(),
510531
SpirvType::Matrix { element, count } => f
511532
.debug_struct("Matrix")
@@ -668,7 +689,7 @@ impl SpirvTypePrinter<'_, '_> {
668689
}
669690
f.write_str(" }")
670691
}
671-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
692+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
672693
ty(self.cx, stack, f, element)?;
673694
write!(f, "x{count}")
674695
}

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ pub struct Symbols {
1515
pub discriminant: Symbol,
1616
pub rust_gpu: Symbol,
1717
pub spirv_attr_with_version: Symbol,
18+
pub vector: Symbol,
19+
pub v1: Symbol,
1820
pub libm: Symbol,
1921
pub entry_point_name: Symbol,
2022
pub spv_khr_vulkan_memory_model: Symbol,
@@ -371,6 +373,10 @@ impl Symbols {
371373
"matrix",
372374
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
373375
),
376+
(
377+
"vector",
378+
SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
379+
),
374380
("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
375381
(
376382
"buffer_store_intrinsic",
@@ -406,6 +412,8 @@ impl Symbols {
406412
discriminant: Symbol::intern("discriminant"),
407413
rust_gpu: Symbol::intern("rust_gpu"),
408414
spirv_attr_with_version: Symbol::intern(&spirv_attr_with_version()),
415+
vector: Symbol::intern("vector"),
416+
v1: Symbol::intern("v1"),
409417
libm: Symbol::intern("libm"),
410418
entry_point_name: Symbol::intern("entry_point_name"),
411419
spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),

0 commit comments

Comments
 (0)