@@ -2430,22 +2430,17 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
24302430 llvm::Value *V =
24312431 CGF.Builder .CreateLoad (CGF.CreateIRTemp (DestTy, " flatcast.tmp" ));
24322432 // write to V.
2433- unsigned NumCols = MatTy->getNumColumns ();
2434- unsigned NumRows = MatTy->getNumRows ();
2435- unsigned ColOffset = NumCols;
2436- if (auto *SrcMatTy = SrcVal.getType ()->getAs <ConstantMatrixType>())
2437- ColOffset = SrcMatTy->getNumColumns ();
2438- for (unsigned R = 0 ; R < NumRows; R++) {
2439- for (unsigned C = 0 ; C < NumCols; C++) {
2440- unsigned I = R * ColOffset + C;
2441- RValue RVal = CGF.EmitLoadOfLValue (LoadList[I], Loc);
2442- assert (RVal.isScalar () &&
2443- " All flattened source values should be scalars." );
2444- llvm::Value *Cast =
2445- CGF.EmitScalarConversion (RVal.getScalarVal (), LoadList[I].getType (),
2446- MatTy->getElementType (), Loc);
2447- V = CGF.Builder .CreateInsertElement (V, Cast, I);
2448- }
2433+ for (unsigned I = 0 , E = MatTy->getNumElementsFlattened (); I < E; I++) {
2434+ unsigned ColMajorIndex =
2435+ (I % MatTy->getNumRows ()) * MatTy->getNumColumns () +
2436+ (I / MatTy->getNumRows ());
2437+ RValue RVal = CGF.EmitLoadOfLValue (LoadList[ColMajorIndex], Loc);
2438+ assert (RVal.isScalar () &&
2439+ " All flattened source values should be scalars." );
2440+ llvm::Value *Cast = CGF.EmitScalarConversion (
2441+ RVal.getScalarVal (), LoadList[ColMajorIndex].getType (),
2442+ MatTy->getElementType (), Loc);
2443+ V = CGF.Builder .CreateInsertElement (V, Cast, I);
24492444 }
24502445 return V;
24512446 }
0 commit comments