|
21 | 21 | #include "clang/AST/Expr.h" |
22 | 22 | #include "clang/AST/HLSLResource.h" |
23 | 23 | #include "clang/AST/Type.h" |
| 24 | +#include "clang/AST/TypeBase.h" |
24 | 25 | #include "clang/AST/TypeLoc.h" |
25 | 26 | #include "clang/Basic/Builtins.h" |
26 | 27 | #include "clang/Basic/DiagnosticSema.h" |
|
31 | 32 | #include "clang/Basic/TargetInfo.h" |
32 | 33 | #include "clang/Sema/Initialization.h" |
33 | 34 | #include "clang/Sema/Lookup.h" |
| 35 | +#include "clang/Sema/Ownership.h" |
34 | 36 | #include "clang/Sema/ParsedAttr.h" |
35 | 37 | #include "clang/Sema/Sema.h" |
36 | 38 | #include "clang/Sema/Template.h" |
@@ -4313,6 +4315,101 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity, |
4313 | 4315 | return true; |
4314 | 4316 | } |
4315 | 4317 |
|
| 4318 | +ExprResult SemaHLSL::tryBuildHLSLMatrixElementAccessor( |
| 4319 | + Expr *Base, SourceLocation MemberLoc, const IdentifierInfo *MemberId) { |
| 4320 | + if (!Base) |
| 4321 | + return ExprError(); |
| 4322 | + |
| 4323 | + QualType T = Base->getType(); |
| 4324 | + const ConstantMatrixType *MT = T->getAs<ConstantMatrixType>(); |
| 4325 | + if (!MT) |
| 4326 | + return ExprError(); |
| 4327 | + |
| 4328 | + auto parseHLSLMatrixAccessor = |
| 4329 | + [this, MemberLoc, MemberId](unsigned &Row, unsigned &Col) -> ExprResult { |
| 4330 | + StringRef Name = MemberId->getNameStart(); |
| 4331 | + bool IsMaccessor = Name.consume_front("_m"); |
| 4332 | + // consume numeric accessor second so if _m exist we don't consume _ to |
| 4333 | + // early. |
| 4334 | + bool IsNumericAccessor = Name.consume_front("_"); |
| 4335 | + if (!IsMaccessor && !IsNumericAccessor) |
| 4336 | + return ExprError( |
| 4337 | + Diag(MemberLoc, diag::err_builtin_matrix_invalid_member)); |
| 4338 | + |
| 4339 | + auto isDigit = [](char c) { return c >= '0' && c <= '9'; }; |
| 4340 | + auto isZeroBasedIndex = [](char c) { return c >= '0' && c <= '3'; }; |
| 4341 | + auto isOneBasedIndex = [](char c) { return c >= '1' && c <= '4'; }; |
| 4342 | + if (Name.empty() || !isDigit(Name[0]) || !isDigit(Name[1])) { |
| 4343 | + return ExprError( |
| 4344 | + Diag(MemberLoc, diag::err_typecheck_subscript_not_integer) |
| 4345 | + << /*matrix*/ 1); |
| 4346 | + } |
| 4347 | + |
| 4348 | + Row = Name[0] - '0'; |
| 4349 | + Col = Name[1] - '0'; |
| 4350 | + bool HasIndexingError = false; |
| 4351 | + if (IsNumericAccessor) { |
| 4352 | + Row--; |
| 4353 | + Col--; |
| 4354 | + // Note: we add diagnostic errors here because otherwise we will return -1 |
| 4355 | + if (!isOneBasedIndex(Name[0])) { |
| 4356 | + Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds) |
| 4357 | + << /*row*/ 0 << /*one*/ 1; |
| 4358 | + HasIndexingError = true; |
| 4359 | + } |
| 4360 | + if (!isOneBasedIndex(Name[1])) { |
| 4361 | + Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds) |
| 4362 | + << /*col*/ 1 << /*one*/ 1; |
| 4363 | + HasIndexingError = true; |
| 4364 | + } |
| 4365 | + } else { |
| 4366 | + if (!isZeroBasedIndex(Name[0])) { |
| 4367 | + Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds) |
| 4368 | + << /*row*/ 0 << /*zero*/ 0; |
| 4369 | + HasIndexingError = true; |
| 4370 | + } |
| 4371 | + if (!isZeroBasedIndex(Name[1])) { |
| 4372 | + Diag(MemberLoc, diag::err_typecheck_subscript_not_in_bounds) |
| 4373 | + << /*col*/ 1 << /*zero*/ 0; |
| 4374 | + HasIndexingError = true; |
| 4375 | + } |
| 4376 | + } |
| 4377 | + if (HasIndexingError) |
| 4378 | + return ExprError(); |
| 4379 | + return ExprEmpty(); |
| 4380 | + }; |
| 4381 | + unsigned Row = 0, Col = 0; |
| 4382 | + ExprResult ParseResult = parseHLSLMatrixAccessor(Row, Col); |
| 4383 | + if (ParseResult.isInvalid()) |
| 4384 | + return ParseResult; |
| 4385 | + |
| 4386 | + unsigned Rows = MT->getNumRows(); |
| 4387 | + unsigned Cols = MT->getNumColumns(); |
| 4388 | + bool HasBoundsError = false; |
| 4389 | + if (Row >= Rows) { |
| 4390 | + Diag(MemberLoc, diag::err_typecheck_subscript_out_of_bounds) << /*Row*/ 0; |
| 4391 | + HasBoundsError = true; |
| 4392 | + } |
| 4393 | + if (Col >= Cols) { |
| 4394 | + Diag(MemberLoc, diag::err_typecheck_subscript_out_of_bounds) << /*Col*/ 1; |
| 4395 | + HasBoundsError = true; |
| 4396 | + } |
| 4397 | + if (HasBoundsError) |
| 4398 | + return ExprError(); |
| 4399 | + |
| 4400 | + auto mkIdx = [&](unsigned v) -> Expr * { |
| 4401 | + ASTContext &Context = SemaRef.getASTContext(); |
| 4402 | + return IntegerLiteral::Create(Context, llvm::APInt(32, v), |
| 4403 | + Context.UnsignedIntTy, MemberLoc); |
| 4404 | + }; |
| 4405 | + Expr *RowIdx = mkIdx(Row); |
| 4406 | + Expr *ColIdx = mkIdx(Col); |
| 4407 | + |
| 4408 | + // Build A[Row][Col], reusing the existing matrix subscript machinery. |
| 4409 | + return SemaRef.CreateBuiltinMatrixSubscriptExpr(Base, RowIdx, ColIdx, |
| 4410 | + MemberLoc); |
| 4411 | +} |
| 4412 | + |
4316 | 4413 | bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) { |
4317 | 4414 | const HLSLVkConstantIdAttr *ConstIdAttr = |
4318 | 4415 | VDecl->getAttr<HLSLVkConstantIdAttr>(); |
|
0 commit comments