@@ -1605,7 +1605,7 @@ void GetThreadIdToInputMemoryLayoutsMaps(
16051605 mlir_context);
16061606 // For every operand compute thread ID -> physical layout of operand
16071607 // indexing map.
1608- for (auto && [operand, operand_linarized_physical_map ] :
1608+ for (auto && [operand, operand_linearized_physical_map ] :
16091609 llvm::zip (operands, operand_logical_to_linearized_physical_maps)) {
16101610 auto operand_indexing_maps_it =
16111611 instr_indexing_keyed_by_operands.find (operand);
@@ -1623,7 +1623,7 @@ void GetThreadIdToInputMemoryLayoutsMaps(
16231623 break ;
16241624 }
16251625 IndexingMap logical_output_to_linearized_physical_input_map =
1626- operand_indexing_map * operand_linarized_physical_map ;
1626+ operand_indexing_map * operand_linearized_physical_map ;
16271627 IndexingMap thread_id_to_linearized_physical_input_map =
16281628 thread_id_to_hero_operand_map *
16291629 logical_output_to_linearized_physical_input_map;
@@ -1634,6 +1634,36 @@ void GetThreadIdToInputMemoryLayoutsMaps(
16341634 }
16351635}
16361636
1637+ // Replaces RTVars with the midpoints of the feasible intervals.
1638+ void AssignValuesToRTVars (IndexingMap* indexing_map) {
1639+ // If RTVars are present, replace them with constants.
1640+ if (indexing_map->GetRTVarsCount () == 0 ) {
1641+ return ;
1642+ }
1643+ MLIRContext* mlir_context = indexing_map->GetMLIRContext ();
1644+ llvm::SmallVector<AffineExpr, 2 > symbol_replacements;
1645+ for (int64_t symbol_id = 0 ; symbol_id < indexing_map->GetRangeVarsCount ();
1646+ ++symbol_id) {
1647+ symbol_replacements.push_back (
1648+ mlir::getAffineSymbolExpr (symbol_id, mlir_context));
1649+ }
1650+ for (const IndexingMap::Variable& rt_var : indexing_map->GetRTVars ()) {
1651+ // Take midpoint of the feasible interval for the RT variable.
1652+ symbol_replacements.push_back (getAffineConstantExpr (
1653+ (rt_var.bounds .lower + rt_var.bounds .upper ) / 2 , mlir_context));
1654+ }
1655+ AffineMap thread_x_to_input_no_dim_symbols =
1656+ indexing_map->GetAffineMap ().replaceDimsAndSymbols (
1657+ {}, symbol_replacements, indexing_map->GetDimVarsCount (),
1658+ indexing_map->GetRangeVarsCount ());
1659+ *indexing_map = IndexingMap{thread_x_to_input_no_dim_symbols,
1660+ indexing_map->GetDimVars (),
1661+ indexing_map->GetRangeVars (),
1662+ {}};
1663+ indexing_map->Simplify ();
1664+ indexing_map->RemoveUnusedSymbols ();
1665+ }
1666+
16371667HloInstructionIndexing ComputeOutputToInputAllGatherOpIndexing (
16381668 const HloAllGatherInstruction* instr, MLIRContext* ctx) {
16391669 // CHECK_EQ(instr->all_gather_dimension(), 0);
0 commit comments