Skip to content

Commit 4a15512

Browse files
committed
[HLSL] Add _m and _<numeric> based accessors to hlsl::matrix
fixes #159438 HLSL supports 0 based index and one based index accessors that are equivalent to the column row bracket subscripting operators. This change adds support for these accessors by hooking into LookupMemberExpr and adding them similar to how `CheckExtVectorComponent` and the hlsl specific scalar to vectorsplat accessors work. Since these accessors are HLSL specific The implementation details are kept to SemaHLSL via `tryBuildHLSLMatrixElementAccessor`.
1 parent c7d776b commit 4a15512

File tree

7 files changed

+204
-7
lines changed

7 files changed

+204
-7
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7500,7 +7500,11 @@ def ext_subscript_non_lvalue : Extension<
75007500
def err_typecheck_subscript_value : Error<
75017501
"subscripted value is not an array, pointer, or vector">;
75027502
def err_typecheck_subscript_not_integer : Error<
7503-
"array subscript is not an integer">;
7503+
"%select{array|matrix}0 subscript is not an integer">;
7504+
def err_typecheck_subscript_not_in_bounds : Error<
7505+
"%select{row|column}0 subscript is out of bounds of %select{zero|one}1 based indexing">;
7506+
def err_typecheck_subscript_out_of_bounds : Error<
7507+
"%select{row|column}0 subscripted is out of bounds">;
75047508
def err_subscript_function_type : Error<
75057509
"subscript of pointer to function type %0">;
75067510
def err_subscript_incomplete_or_sizeless_type : Error<
@@ -12972,6 +12976,7 @@ def err_builtin_matrix_stride_too_small: Error<
1297212976
"stride must be greater or equal to the number of rows">;
1297312977
def err_builtin_matrix_invalid_dimension: Error<
1297412978
"%0 dimension is outside the allowed range [1, %1]">;
12979+
def err_builtin_matrix_invalid_member: Error<"invalid matrix member">;
1297512980

1297612981
def warn_mismatched_import : Warning<
1297712982
"import %select{module|name}0 (%1) does not match the import %select{module|name}0 (%2) of the "

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ class SemaHLSL : public SemaBase {
225225
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
226226
bool handleInitialization(VarDecl *VDecl, Expr *&Init);
227227
void deduceAddressSpace(VarDecl *Decl);
228+
ExprResult tryBuildHLSLMatrixElementAccessor(Expr *Base,
229+
SourceLocation MemberLoc,
230+
const IdentifierInfo *MemberId);
228231

229232
private:
230233
// HLSL resource type attributes need to be processed all at once.

clang/lib/Sema/SemaExprMember.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "clang/Sema/Overload.h"
2020
#include "clang/Sema/Scope.h"
2121
#include "clang/Sema/ScopeInfo.h"
22+
#include "clang/Sema/SemaHLSL.h"
2223
#include "clang/Sema/SemaObjC.h"
2324
#include "clang/Sema/SemaOpenMP.h"
2425

@@ -1669,12 +1670,19 @@ static ExprResult LookupMemberExpr(Sema &S, LookupResult &R,
16691670

16701671
// HLSL supports implicit conversion of scalar types to single element vector
16711672
// rvalues in member expressions.
1672-
if (S.getLangOpts().HLSL && BaseType->isScalarType()) {
1673-
QualType VectorTy = S.Context.getExtVectorType(BaseType, 1);
1674-
BaseExpr = S.ImpCastExprToType(BaseExpr.get(), VectorTy, CK_VectorSplat,
1675-
BaseExpr.get()->getValueKind());
1676-
return LookupMemberExpr(S, R, BaseExpr, IsArrow, OpLoc, SS, ObjCImpDecl,
1677-
HasTemplateArgs, TemplateKWLoc);
1673+
if (S.getLangOpts().HLSL) {
1674+
if (BaseType->isScalarType()) {
1675+
QualType VectorTy = S.Context.getExtVectorType(BaseType, 1);
1676+
BaseExpr = S.ImpCastExprToType(BaseExpr.get(), VectorTy, CK_VectorSplat,
1677+
BaseExpr.get()->getValueKind());
1678+
return LookupMemberExpr(S, R, BaseExpr, IsArrow, OpLoc, SS, ObjCImpDecl,
1679+
HasTemplateArgs, TemplateKWLoc);
1680+
}
1681+
if (!IsArrow && BaseType->isConstantMatrixType()) {
1682+
if (const auto *II = MemberName.getAsIdentifierInfo())
1683+
return S.HLSL().tryBuildHLSLMatrixElementAccessor(BaseExpr.get(),
1684+
MemberLoc, II);
1685+
}
16781686
}
16791687

16801688
S.Diag(OpLoc, diag::err_typecheck_member_reference_struct_union)

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "clang/AST/Expr.h"
2222
#include "clang/AST/HLSLResource.h"
2323
#include "clang/AST/Type.h"
24+
#include "clang/AST/TypeBase.h"
2425
#include "clang/AST/TypeLoc.h"
2526
#include "clang/Basic/Builtins.h"
2627
#include "clang/Basic/DiagnosticSema.h"
@@ -31,6 +32,7 @@
3132
#include "clang/Basic/TargetInfo.h"
3233
#include "clang/Sema/Initialization.h"
3334
#include "clang/Sema/Lookup.h"
35+
#include "clang/Sema/Ownership.h"
3436
#include "clang/Sema/ParsedAttr.h"
3537
#include "clang/Sema/Sema.h"
3638
#include "clang/Sema/Template.h"
@@ -4313,6 +4315,101 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
43134315
return true;
43144316
}
43154317

4318+
ExprResult SemaHLSL::tryBuildHLSLMatrixElementAccessor(
4319+
Expr *Base, SourceLocation MemberLoc, const IdentifierInfo *MemberId) {
4320+
if (!Base)
4321+
return ExprError();
4322+
4323+
QualType T = Base->getType();
4324+
const ConstantMatrixType *MT = T->getAs<ConstantMatrixType>();
4325+
if (!MT)
4326+
return ExprError();
4327+
4328+
auto parseHLSLMatrixAccessor =
4329+
[this, MemberLoc, MemberId](unsigned &Row, unsigned &Col) -> ExprResult {
4330+
StringRef Name = MemberId->getNameStart();
4331+
bool IsMaccessor = Name.consume_front("_m");
4332+
// consume numeric accessor second so if _m exist we don't consume _ to
4333+
// early.
4334+
bool IsNumericAccessor = Name.consume_front("_");
4335+
if (!IsMaccessor && !IsNumericAccessor)
4336+
return ExprError(
4337+
Diag(MemberLoc, diag::err_builtin_matrix_invalid_member));
4338+
4339+
auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
4340+
auto isZeroBasedIndex = [](char c) { return c >= '0' && c <= '3'; };
4341+
auto isOneBasedIndex = [](char c) { return c >= '1' && c <= '4'; };
4342+
if (Name.empty() || !isDigit(Name[0]) || !isDigit(Name[1])) {
4343+
return ExprError(
4344+
Diag(MemberLoc, diag::err_typecheck_subscript_not_integer)
4345+
<< /*matrix*/ 1);
4346+
}
4347+
4348+
Row = Name[0] - '0';
4349+
Col = Name[1] - '0';
4350+
bool HasIndexingError = false;
4351+
if (IsNumericAccessor) {
4352+
Row--;
4353+
Col--;
4354+
// Note: we add diagnostic errors here because otherwise we will return -1
4355+
if (!isOneBasedIndex(Name[0])) {
4356+
Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds)
4357+
<< /*row*/ 0 << /*one*/ 1;
4358+
HasIndexingError = true;
4359+
}
4360+
if (!isOneBasedIndex(Name[1])) {
4361+
Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds)
4362+
<< /*col*/ 1 << /*one*/ 1;
4363+
HasIndexingError = true;
4364+
}
4365+
} else {
4366+
if (!isZeroBasedIndex(Name[0])) {
4367+
Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds)
4368+
<< /*row*/ 0 << /*zero*/ 0;
4369+
HasIndexingError = true;
4370+
}
4371+
if (!isZeroBasedIndex(Name[1])) {
4372+
Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds)
4373+
<< /*col*/ 1 << /*zero*/ 0;
4374+
HasIndexingError = true;
4375+
}
4376+
}
4377+
if (HasIndexingError)
4378+
return ExprError();
4379+
return ExprEmpty();
4380+
};
4381+
unsigned Row = 0, Col = 0;
4382+
ExprResult ParseResult = parseHLSLMatrixAccessor(Row, Col);
4383+
if (ParseResult.isInvalid())
4384+
return ParseResult;
4385+
4386+
unsigned Rows = MT->getNumRows();
4387+
unsigned Cols = MT->getNumColumns();
4388+
bool HasBoundsError = false;
4389+
if (Row >= Rows) {
4390+
Diag(MemberLoc, diag::err_typecheck_subscript_out_of_bounds) << /*Row*/ 0;
4391+
HasBoundsError = true;
4392+
}
4393+
if (Col >= Cols) {
4394+
Diag(MemberLoc, diag::err_typecheck_subscript_out_of_bounds) << /*Col*/ 1;
4395+
HasBoundsError = true;
4396+
}
4397+
if (HasBoundsError)
4398+
return ExprError();
4399+
4400+
auto mkIdx = [&](unsigned v) -> Expr * {
4401+
ASTContext &Context = SemaRef.getASTContext();
4402+
return IntegerLiteral::Create(Context, llvm::APInt(32, v),
4403+
Context.UnsignedIntTy, MemberLoc);
4404+
};
4405+
Expr *RowIdx = mkIdx(Row);
4406+
Expr *ColIdx = mkIdx(Col);
4407+
4408+
// Build A[Row][Col], reusing the existing matrix subscript machinery.
4409+
return SemaRef.CreateBuiltinMatrixSubscriptExpr(Base, RowIdx, ColIdx,
4410+
MemberLoc);
4411+
}
4412+
43164413
bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
43174414
const HLSLVkConstantIdAttr *ConstIdAttr =
43184415
VDecl->getAttr<HLSLVkConstantIdAttr>();
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
2+
3+
typedef float float3x3 __attribute__((matrix_type(3,3)));
4+
5+
[numthreads(1,1,1)]
6+
void ok() {
7+
float3x3 A;
8+
// CHECK: BinaryOperator 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:{{[0-9]+}}, col:{{[0-9]+}}> 'float' lvalue matrixcomponent '='
9+
// CHECK-NEXT: MatrixSubscriptExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> 'float' lvalue matrixcomponent
10+
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'float3x3':'matrix<float, 3, 3>' lvalue Var 0x{{[0-9a-fA-F]+}} 'A' 'float3x3':'matrix<float, 3, 3>'
11+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 1
12+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 2
13+
// CHECK-NEXT: FloatingLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'float' 3.140000e+00
14+
A._m12 = 3.14;
15+
16+
// CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> col:{{[0-9]+}} r 'float' cinit
17+
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> 'float' <LValueToRValue>
18+
// CHECK-NEXT: MatrixSubscriptExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> 'float' lvalue matrixcomponent
19+
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'float3x3':'matrix<float, 3, 3>' lvalue Var 0x{{[0-9a-fA-F]+}} 'A' 'float3x3':'matrix<float, 3, 3>'
20+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 0
21+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 0
22+
float r = A._m00;
23+
24+
// CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> col:{{[0-9]+}} good1 'float' cinit
25+
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> 'float' <LValueToRValue>
26+
// CHECK-NEXT: MatrixSubscriptExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> 'float' lvalue matrixcomponent
27+
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'float3x3':'matrix<float, 3, 3>' lvalue Var 0x{{[0-9a-fA-F]+}} 'A' 'float3x3':'matrix<float, 3, 3>'
28+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 0
29+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 0
30+
float good1 = A._11;
31+
32+
// CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> col:{{[0-9]+}} good2 'float' cinit
33+
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> 'float' <LValueToRValue>
34+
// CHECK-NEXT: MatrixSubscriptExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}, col:{{[0-9]+}}> 'float' lvalue matrixcomponent
35+
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'float3x3':'matrix<float, 3, 3>' lvalue Var 0x{{[0-9a-fA-F]+}} 'A' 'float3x3':'matrix<float, 3, 3>'
36+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 2
37+
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:{{[0-9]+}}> 'unsigned int' 2
38+
float good2 = A._33;
39+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
3+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s
4+
5+
6+
half3x3 write_pi(half3x3 A) {
7+
//CHECK: [[MAT_INS:%.*]] = insertelement <9 x half> %{{[0-9]+}}, half 0xH4248, i32 7
8+
//CHECK-NEXT: store <9 x half> [[MAT_INS]], ptr %{{.*}}, align 2
9+
A._m12 = 3.14;
10+
return A;
11+
}
12+
13+
half read_1x1(half3x3 A) {
14+
//CHECK: [[MAT_EXT:%.*]] = extractelement <9 x half> %{{[0-9]+}}, i32 0
15+
return A._11;
16+
}
17+
half read_m0x0(half3x3 A) {
18+
//CHECK: [[MAT_EXT:%.*]] = extractelement <9 x half> %{{[0-9]+}}, i32 0
19+
return A._m00;
20+
}
21+
22+
half read_3x3(half3x3 A) {
23+
//CHECK: [[MAT_EXT:%.*]] = extractelement <9 x half> %{{[0-9]+}}, i32 8
24+
return A._33;
25+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -finclude-default-header -verify %s
2+
3+
void foo() {
4+
float3x3 A;
5+
float r = A._m00; // read is ok
6+
float good1 = A._11;
7+
float good2 = A._33;
8+
9+
float bad0 = A._m44; // expected-error {{row subscript is out of bounds of zero based indexing}} expected-error {{column subscript is out of bounds of zero based indexing}}
10+
float bad1 = A._m33; // expected-error {{row subscripted is out of bounds}} expected-error {{column subscripted is out of bounds}}
11+
float bad2 = A._mA2; // expected-error {{matrix subscript is not an integer}}
12+
float bad3 = A._m2F; // expected-error {{matrix subscript is not an integer}}
13+
14+
float bad4 = A._00; // expected-error {{row subscript is out of bounds of one based indexing}} expected-error {{column subscript is out of bounds of one based indexing}}
15+
float bad5 = A._44; // expected-error {{row subscripted is out of bounds}} expected-error {{column subscripted is out of bounds}}
16+
float bad6 = A._55; // expected-error {{row subscript is out of bounds of one based indexing}} expected-error {{column subscript is out of bounds of one based indexing}}
17+
18+
19+
A._m12 = 3.14; // write is OK
20+
}

0 commit comments

Comments
 (0)