Skip to content

Commit 57c4d4a

Browse files
akuegelGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Move helper function to indexing_analysis (NFC).
PiperOrigin-RevId: 817057604
1 parent 24346ac commit 57c4d4a

File tree

3 files changed

+35
-32
lines changed

3 files changed

+35
-32
lines changed

xla/hlo/analysis/indexing_analysis.cc

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,7 @@ void GetThreadIdToInputMemoryLayoutsMaps(
16051605
mlir_context);
16061606
// For every operand compute thread ID -> physical layout of operand
16071607
// indexing map.
1608-
for (auto&& [operand, operand_linarized_physical_map] :
1608+
for (auto&& [operand, operand_linearized_physical_map] :
16091609
llvm::zip(operands, operand_logical_to_linearized_physical_maps)) {
16101610
auto operand_indexing_maps_it =
16111611
instr_indexing_keyed_by_operands.find(operand);
@@ -1623,7 +1623,7 @@ void GetThreadIdToInputMemoryLayoutsMaps(
16231623
break;
16241624
}
16251625
IndexingMap logical_output_to_linearized_physical_input_map =
1626-
operand_indexing_map * operand_linarized_physical_map;
1626+
operand_indexing_map * operand_linearized_physical_map;
16271627
IndexingMap thread_id_to_linearized_physical_input_map =
16281628
thread_id_to_hero_operand_map *
16291629
logical_output_to_linearized_physical_input_map;
@@ -1634,6 +1634,36 @@ void GetThreadIdToInputMemoryLayoutsMaps(
16341634
}
16351635
}
16361636

1637+
// Replaces RTVars with the midpoints of the feasible intervals.
1638+
void AssignValuesToRTVars(IndexingMap* indexing_map) {
1639+
// If RTVars are present, replace them with constants.
1640+
if (indexing_map->GetRTVarsCount() == 0) {
1641+
return;
1642+
}
1643+
MLIRContext* mlir_context = indexing_map->GetMLIRContext();
1644+
llvm::SmallVector<AffineExpr, 2> symbol_replacements;
1645+
for (int64_t symbol_id = 0; symbol_id < indexing_map->GetRangeVarsCount();
1646+
++symbol_id) {
1647+
symbol_replacements.push_back(
1648+
mlir::getAffineSymbolExpr(symbol_id, mlir_context));
1649+
}
1650+
for (const IndexingMap::Variable& rt_var : indexing_map->GetRTVars()) {
1651+
// Take midpoint of the feasible interval for the RT variable.
1652+
symbol_replacements.push_back(getAffineConstantExpr(
1653+
(rt_var.bounds.lower + rt_var.bounds.upper) / 2, mlir_context));
1654+
}
1655+
AffineMap thread_x_to_input_no_dim_symbols =
1656+
indexing_map->GetAffineMap().replaceDimsAndSymbols(
1657+
{}, symbol_replacements, indexing_map->GetDimVarsCount(),
1658+
indexing_map->GetRangeVarsCount());
1659+
*indexing_map = IndexingMap{thread_x_to_input_no_dim_symbols,
1660+
indexing_map->GetDimVars(),
1661+
indexing_map->GetRangeVars(),
1662+
{}};
1663+
indexing_map->Simplify();
1664+
indexing_map->RemoveUnusedSymbols();
1665+
}
1666+
16371667
HloInstructionIndexing ComputeOutputToInputAllGatherOpIndexing(
16381668
const HloAllGatherInstruction* instr, MLIRContext* ctx) {
16391669
// CHECK_EQ(instr->all_gather_dimension(), 0);

xla/hlo/analysis/indexing_analysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ void GetThreadIdToInputMemoryLayoutsMaps(
226226
absl::Span<const IndexingMap> operand_logical_to_linearized_physical_maps,
227227
mlir::MLIRContext* mlir_context, GroupedByOpIndexingMap& result);
228228

229+
// Replaces RTVars with the midpoints of the feasible intervals.
230+
void AssignValuesToRTVars(IndexingMap* indexing_map);
231+
229232
// Groups indexing maps by instructions.
230233
GroupedByOpIndexing GroupIndexingMapsByProducers(
231234
const HloInstructionIndexing& indexing, const HloInstruction* instr);

xla/service/gpu/model/coalescing_analysis.cc

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -218,36 +218,6 @@ bool EstimateCoalescingViaMemoryTransactionsCount(
218218
memory_transactions * kIsCoalescedThreshold;
219219
}
220220

221-
// Replaces RTVars with the midpoints of the feasible intervals.
222-
void AssignValuesToRTVars(IndexingMap* indexing_map) {
223-
// If RTVars are present, replace them with constants.
224-
if (indexing_map->GetRTVarsCount() == 0) {
225-
return;
226-
}
227-
MLIRContext* mlir_context = indexing_map->GetMLIRContext();
228-
llvm::SmallVector<AffineExpr, 2> symbol_replacements;
229-
for (int64_t symbol_id = 0; symbol_id < indexing_map->GetRangeVarsCount();
230-
++symbol_id) {
231-
symbol_replacements.push_back(
232-
mlir::getAffineSymbolExpr(symbol_id, mlir_context));
233-
}
234-
for (const IndexingMap::Variable& rt_var : indexing_map->GetRTVars()) {
235-
// Take midpoint of the feasible interval for the RT variable.
236-
symbol_replacements.push_back(getAffineConstantExpr(
237-
(rt_var.bounds.lower + rt_var.bounds.upper) / 2, mlir_context));
238-
}
239-
AffineMap thread_x_to_input_no_dim_symbols =
240-
indexing_map->GetAffineMap().replaceDimsAndSymbols(
241-
{}, symbol_replacements, indexing_map->GetDimVarsCount(),
242-
indexing_map->GetRangeVarsCount());
243-
*indexing_map = IndexingMap{thread_x_to_input_no_dim_symbols,
244-
indexing_map->GetDimVars(),
245-
indexing_map->GetRangeVars(),
246-
{}};
247-
indexing_map->Simplify();
248-
indexing_map->RemoveUnusedSymbols();
249-
}
250-
251221
// Replaces all but one RangeVars with the first elements in the range.
252222
// At the moment, we assume that the last RangeVar symbol corresponds to the
253223
// innermost loop induction variable.

0 commit comments

Comments
 (0)