@@ -57,6 +57,9 @@ struct MIRef {
5757 ++I, ++Pos)
5858 MI = &*I;
5959 }
60+ MIRef (MachineInstr *MI)
61+ : MI(MI), MBB(MI->getParent ()),
62+ Pos(std::distance(MBB->instr_begin (), ++MI->getIterator())) {}
6063 MIRef (MachineInstr *MI, MachineBasicBlock *MBB)
6164 : MI(MI), MBB(MBB),
6265 Pos(std::distance(MBB->instr_begin (), ++MI->getIterator())) {}
@@ -66,6 +69,7 @@ struct MIRef {
6669 bool operator ==(const MIRef &RHS) const {
6770 return MI == RHS.MI && MBB == RHS.MBB ;
6871 }
72+ bool operator !=(const MIRef &RHS) const { return !(*this == RHS); }
6973 bool operator <(const MIRef &RHS) const {
7074 return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos );
7175 }
@@ -77,7 +81,7 @@ struct MIRef {
/// Per-basic-block bookkeeping; instances are stored in BBVisitedInfo, keyed
/// by MachineBasicBlock.
7781struct BBInfo {
// Tested for truthiness to tell whether the block contains an AMX
// instruction; hoistShapesInBB inserts hoisted shape defs in front of it.
7882 MIRef FirstAMX;
// NOTE(review): presumably the last call in the block; not used in the
// visible part of the file, confirm against the code that fills it in.
7983 MIRef LastCall;
// Diff-removed field: its uses are replaced by ShapeBBs[MBB].back().
80- MIRef LastShape ;
// Set on every non-back-edge successor of a block that has an AMX
// instruction or already carries this flag. A shape def in a flagged block
// makes runOnMachineFunction report a config failure.
84+ bool HasAMXRegLiveIn = false ;
// NOTE(review): presumably marks blocks where ldtilecfg must not be
// inserted; the code that sets it is outside the visible hunks.
8185 bool TileCfgForbidden = false ;
// NOTE(review): presumably marks blocks that need the tile config to be
// live on entry; the insert point is sunk along successors with this flag.
8286 bool NeedTileCfgLiveIn = false ;
8387};
@@ -86,8 +90,8 @@ class X86PreTileConfig : public MachineFunctionPass {
8690 MachineRegisterInfo *MRI;
8791 const MachineLoopInfo *MLI;
8892 SmallSet<MachineInstr *, 8 > DefVisited;
89- SmallSet<MachineBasicBlock *, 8 > ShapeBBs;
9093 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
94+ DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8 >> ShapeBBs;
9195
9296 // / Check if the callee will clobber AMX registers.
9397 bool isDestructiveCall (MachineInstr &MI, BitVector UsableRegs) {
@@ -124,6 +128,32 @@ class X86PreTileConfig : public MachineFunctionPass {
124128 // / Collect the shape def information for later use.
125129 void collectShapeInfo (MachineInstr &MI);
126130
131+ // / Try to hoist shapes definded below AMX instructions.
132+ bool hoistShapesInBB (MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
133+ MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX ;
134+ auto FirstShapeBelowAMX = llvm::lower_bound (Shapes, FirstAMX);
135+ auto InsertPoint = FirstAMX.MI ->getIterator ();
136+ for (auto I = FirstShapeBelowAMX, E = Shapes.end (); I != E; ++I) {
137+ // Do not hoist instructions that access memory.
138+ if (I->MI ->mayLoadOrStore ())
139+ return false ;
140+ for (auto &MO : I->MI ->operands ()) {
141+ if (MO.isDef ())
142+ continue ;
143+ // Do not hoist instructions if the sources' def under AMX instruction.
144+ // TODO: We can handle isMoveImmediate MI here.
145+ if (MO.isReg () && MIRef (MRI->getVRegDef (MO.getReg ())) > FirstAMX)
146+ return false ;
147+ // TODO: Maybe need more checks here.
148+ }
149+ MBB->insert (InsertPoint, I->MI ->removeFromParent ());
150+ }
151+ // We only need to mark the last shape in the BB now.
152+ Shapes.clear ();
153+ Shapes.push_back (MIRef (&*--InsertPoint, MBB));
154+ return true ;
155+ }
156+
127157public:
128158 X86PreTileConfig () : MachineFunctionPass(ID) {}
129159
@@ -165,9 +195,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
165195void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
166196 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
167197 MIRef MIR (MI, MBB);
168- if (BBVisitedInfo [MBB]. LastShape < MIR)
169- BBVisitedInfo [MBB].LastShape = MIR;
170- ShapeBBs.insert (MBB );
198+ auto I = llvm::lower_bound (ShapeBBs [MBB], MIR);
199+ if (I == ShapeBBs [MBB].end () || *I ! = MIR)
200+ ShapeBBs[MBB] .insert (I, MIR );
171201 };
172202
173203 SmallVector<Register, 8 > WorkList (
@@ -229,6 +259,10 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
229259 else
230260 CfgLiveInBBs.push_back (&MBB);
231261 }
262+ if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn )
263+ for (auto *Succ : MBB.successors ())
264+ if (!isLoopBackEdge (Succ, &MBB))
265+ BBVisitedInfo[Succ].HasAMXRegLiveIn = true ;
232266 }
233267
234268 // Update NeedTileCfgLiveIn for predecessors.
@@ -252,8 +286,17 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
252286 return false ;
253287
254288 // Avoid inserting ldtilecfg before any shape defs.
255- SmallVector<MachineBasicBlock *, 8 > WorkList (
256- make_range (ShapeBBs.begin (), ShapeBBs.end ()));
289+ SmallVector<MachineBasicBlock *, 8 > WorkList;
290+ for (auto &I : ShapeBBs) {
291+ // TODO: We can hoist shapes across BBs here.
292+ if (BBVisitedInfo[I.first ].HasAMXRegLiveIn )
293+ REPORT_CONFIG_FAIL
294+ if (BBVisitedInfo[I.first ].FirstAMX &&
295+ BBVisitedInfo[I.first ].FirstAMX < I.second .back () &&
296+ !hoistShapesInBB (I.first , I.second ))
297+ REPORT_CONFIG_FAIL
298+ WorkList.push_back (I.first );
299+ }
257300 while (!WorkList.empty ()) {
258301 MachineBasicBlock *MBB = WorkList.pop_back_val ();
259302 for (auto *Pred : MBB->predecessors ()) {
@@ -282,9 +325,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
282325 } else {
283326 // Avoid visiting the BB more than once.
284327 VisitedOrInserted.insert (I);
285- // We cannot sink it across any AMX instruction.
286- if (BBVisitedInfo[I.MBB ].FirstAMX )
287- REPORT_CONFIG_FAIL;
288328 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
289329 // true when MBB isn't all shapes reachable.
290330 for (auto *Succ : I.MBB ->successors ())
@@ -296,14 +336,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
296336
297337 // A given point might be forked because the shape conditions are not met.
298338 for (MIRef I : InsertPoints) {
299- // Even MBB is all shapes reachable, we still need to check if there's
300- // AMX that intersects with shapes in the same MBB.
301- if (BBVisitedInfo[I.MBB ].FirstAMX &&
302- BBVisitedInfo[I.MBB ].FirstAMX < BBVisitedInfo[I.MBB ].LastShape )
303- REPORT_CONFIG_FAIL;
304339 // Make sure we insert ldtilecfg after the last shape def in MBB.
305- if (I < BBVisitedInfo [I.MBB ].LastShape )
306- I = BBVisitedInfo [I.MBB ].LastShape ;
340+ if (ShapeBBs. count (I. MBB ) && I < ShapeBBs [I.MBB ].back () )
341+ I = ShapeBBs [I.MBB ].back () ;
307342 // The MBB may be sunk more than once. Record it to avoid
308343 // inserting multiple times.
309344 if (VisitedOrInserted.insert (I).second ) {
0 commit comments