@@ -1194,6 +1194,144 @@ void CustomSafeOptPass::matchDp4a(BinaryOperator &I) {
11941194 I.replaceAllUsesWith (Res);
11951195}
11961196
1197+ // Optimize mix operation if detected.
1198+ // Mix is computed as x*(1 - a) + y*a
1199+ // Replace it with a*(y - x) + x to save one instruction ('add' ISA, 'sub' in IR).
1200+ // This pattern also optimizes a similar operation:
1201+ // x*(a - 1) + y*a which can be replaced with a(x + y) - x
1202+ void CustomSafeOptPass::matchMixOperation (BinaryOperator &I)
1203+ {
1204+ // Pattern Mix check step 1: find a FSub instruction with a constant value of 1.
1205+ if (I.getOpcode () == BinaryOperator::FSub)
1206+ {
1207+ unsigned int fSubOpIdx = 0 ;
1208+ while (fSubOpIdx < 2 && !llvm::isa<llvm::ConstantFP>(I.getOperand (fSubOpIdx )))
1209+ {
1210+ fSubOpIdx ++;
1211+ }
1212+ if ((fSubOpIdx == 1 ) ||
1213+ ((fSubOpIdx == 0 ) && !llvm::isa<llvm::ConstantFP>(I.getOperand (1 ))))
1214+ {
1215+ llvm::ConstantFP* fSubOpConst = llvm::dyn_cast<llvm::ConstantFP>(I.getOperand (fSubOpIdx ));
1216+ const APFloat& APF = fSubOpConst ->getValueAPF ();
1217+ bool isInf = APF.isInfinity ();
1218+ bool isNaN = APF.isNaN ();
1219+ double val = 0.0 ;
1220+ if (!isInf && !isNaN)
1221+ {
1222+ if (&APF.getSemantics () == &APFloat::IEEEdouble ())
1223+ {
1224+ val = APF.convertToDouble ();
1225+ }
1226+ else if (&APF.getSemantics () == &APFloat::IEEEsingle ())
1227+ {
1228+ val = (double )APF.convertToFloat ();
1229+ }
1230+ }
1231+ if (val == 1.0 )
1232+ {
1233+ bool doNotOptimize = false ;
1234+ bool matchFound = false ;
1235+ SmallVector<std::pair<Instruction*, Instruction*>, 3 > fMulInsts ;
1236+
1237+ // Pattern Mix check step 2: there should be only FMul users of this FSub instruction
1238+ for (User* U : I.users ())
1239+ {
1240+ Instruction* fMul = cast<Instruction>(U);
1241+ matchFound = false ;
1242+ if (fMul ->getOpcode () == BinaryOperator::FMul)
1243+ {
1244+ // Pattern Mix check step 3: there should be only one fAdd user for such an FMul instruction
1245+ if ((int )std::distance (fMul ->users ().begin (), fMul ->users ().end ()) == 1 )
1246+ {
1247+ Instruction* fAdd = dyn_cast<Instruction>(*fMul ->users ().begin ());
1248+ if (fAdd ->getOpcode () == BinaryOperator::FAdd)
1249+ {
1250+ // Pattern Mix check step 4: fAdd should be a user of two FMul instructions
1251+ unsigned int opIdx = 0 ;
1252+ while (opIdx < 2 && cast<Value>(fMul ) != fAdd ->getOperand (opIdx))
1253+ {
1254+ opIdx++;
1255+ }
1256+
1257+ if (opIdx < 2 )
1258+ {
1259+ opIdx = (opIdx + 1 ) % 2 ; // 0 -> 1 or 1 -> 0
1260+ Instruction* fMul2nd = cast<Instruction>(fAdd ->getOperand (opIdx));
1261+ if (fMul2nd ->getOpcode () == BinaryOperator::FMul)
1262+ {
1263+ fMulInsts .push_back (std::make_pair (fMul , fMul2nd ));
1264+ matchFound = true ; // Pattern Mix (partially) detected.
1265+ }
1266+ }
1267+ }
1268+ }
1269+
1270+ }
1271+
1272+ if (!matchFound)
1273+ {
1274+ doNotOptimize = true ; // To optimize both FMul instructions and FAdd must be found
1275+ }
1276+ }
1277+
1278+ if (!doNotOptimize && !fMulInsts .empty () && I.users ().begin () != I.users ().end ())
1279+ {
1280+ // Pattern Mix fully detected. Replace sequence of detected instructions with new ones.
1281+ IGC_ASSERT_MESSAGE (
1282+ fMulInsts .size () == (int )std::distance (I.users ().begin (), I.users ().end ()),
1283+ " Incorrect pattern match data" );
1284+ // If Pattern Mix with 1-a in the first instruction was detected then create
1285+ // this sequence of new instructions: FSub, FMul, FAdd.
1286+ // But if Pattern Mix with a-1 in the first instruction was detected then create
1287+ // this sequence of new instructions: FAdd, FMul, FSub.
1288+ Instruction::BinaryOps newFirstInstOp = (fSubOpIdx == 0 ) ? Instruction::FSub : Instruction::FAdd;
1289+ Instruction::BinaryOps newLastInstOp = (fSubOpIdx == 0 ) ? Instruction::FAdd : Instruction::FSub;
1290+
1291+ fSubOpIdx = (fSubOpIdx + 1 ) % 2 ; // 0 -> 1 or 1 -> 0, i.e. get another FSub operand
1292+ Value* r = I.getOperand (fSubOpIdx );
1293+
1294+ for (std::pair<Instruction*, Instruction*> fMulPair : fMulInsts )
1295+ {
1296+ Instruction* fAdd = dyn_cast<Instruction>(*fMulPair .first ->users ().begin ());
1297+
1298+ unsigned int fMul2OpToFirstInstIdx = (r == fMulPair .second ->getOperand (0 )) ? 1 : 0 ;
1299+ Value* newFirstInstOp1 = fMulPair .second ->getOperand (fMul2OpToFirstInstIdx );
1300+ Value* fSubVal = cast<Value>(&I);
1301+
1302+ unsigned int fMul1OpToTakeIdx = (fSubVal == fMulPair .first ->getOperand (0 )) ? 1 : 0 ;
1303+ Instruction* newFirstInst = BinaryOperator::Create (
1304+ newFirstInstOp, fMulPair .first ->getOperand (fMul1OpToTakeIdx ), newFirstInstOp1, " " , &I);
1305+ DILocation* DL1st = I.getDebugLoc ();
1306+ if (DL1st)
1307+ {
1308+ newFirstInst->setDebugLoc (DL1st);
1309+ }
1310+
1311+ Instruction* newFMul = BinaryOperator::CreateFMul (
1312+ fMulPair .second ->getOperand ((fMul2OpToFirstInstIdx + 1 ) % 2 ), newFirstInst, " " , fMulPair .second );
1313+ DILocation* DL2nd = fMulPair .second ->getDebugLoc ();
1314+ if (DL2nd)
1315+ {
1316+ newFMul->setDebugLoc (DL2nd);
1317+ }
1318+
1319+ Instruction* newLastInst = BinaryOperator::Create (
1320+ newLastInstOp, newFMul, fMulPair .first ->getOperand (fMul1OpToTakeIdx ), " " , fAdd );
1321+ DILocation* DL3rd = fMulPair .second ->getDebugLoc ();
1322+ if (DL3rd)
1323+ {
1324+ newLastInst->setDebugLoc (DL3rd);
1325+ }
1326+
1327+ fAdd ->replaceAllUsesWith (newLastInst);
1328+ }
1329+ }
1330+ }
1331+ }
1332+ }
1333+ }
1334+
11971335void CustomSafeOptPass::hoistDp3 (BinaryOperator& I)
11981336{
11991337 if (I.getOpcode () != Instruction::BinaryOps::FAdd)
@@ -1569,6 +1707,10 @@ void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
15691707{
15701708 matchDp4a (I);
15711709
1710+ // Optimize mix operation if detected.
1711+ // Mix is computed as x*(1 - a) + y*a
1712+ matchMixOperation (I);
1713+
15721714 // move immediate value in consecutive integer adds to the last added value.
15731715 // this can allow more chance of doing CSE and memopt.
15741716 // a = b + 8
0 commit comments