Skip to content

Commit 5ccc897

Browse files
metaflowGoogle-ML-Automation
authored andcommitted
[XLA:GPU] doc updates to experimental tiling
PiperOrigin-RevId: 817064107
1 parent 57c4d4a commit 5ccc897

File tree

4 files changed

+124
-65
lines changed

4 files changed

+124
-65
lines changed

xla/service/gpu/model/experimental/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ cc_library(
3131
"//xla:xla_data_proto_cc",
3232
"//xla/codegen/tiling:constraint_expression",
3333
"//xla/hlo/analysis:indexing_analysis",
34+
"//xla/hlo/analysis:interval",
3435
"//xla/hlo/ir:hlo",
3536
"//xla/hlo/utils:hlo_traversal",
3637
"@com_google_absl//absl/container:flat_hash_map",

xla/service/gpu/model/experimental/symbolic_tile.h

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,51 +30,69 @@ namespace xla::gpu::experimental {
3030

3131
class TilingSpace;
3232

33-
// A map from tile IDs, sizes and runtime variables to tile's offsets, sizes,
34-
// strides and upper bounds. Offsets-sizes-strides define what slice to extract,
35-
// upper bounds define masking, i.e. if the tile attempts to extract elements
36-
// with the indices outside of the bounds, the tile will be masked.
33+
// Tiling of a single dimension.
3734
//
38-
// (tile IDs) [tile sizes] {runtime variables} ->
39-
// offsets [offsets_] sizes [sizes_] strides [strides_]
40-
// upper bounds [upper_bounds_]
35+
// Offsets, sizes and strides define the slice of the tensor dimension. Upper
36+
// bounds set the range [0, upper_bound), values outside of this range are
37+
// masked.
4138
//
42-
// tile IDs correspond to the dimension variables of the affine expressions;
43-
// tile sizes and RT vars correspond to the symbol variables.
39+
// Expressions for offset, size, stride and upper bound are AffineExpr
40+
// functions. The TilingSpace keeps track of all dimensions and symbols we use
41+
// in the expressions and allows to create a consistent mapping from dimensions
42+
// and runtime variables to affine expression dimensions and symbols.
4443
//
45-
// The masking condition of the upper bound can be written as:
46-
// dimension_index < upper_bounds[i](tile IDs)
44+
// N.B.! not all of the symbols that the TilingSpace defines are used in
45+
// every expression. That depends on the position of the instruction in
46+
// the graph and the traversal path that we took.
47+
// Number of dimensions equals to the number of dimensions of the instruction
48+
// output, parallel dimensions the corresponding root instruction are followed
49+
// by sequential dimensions.
4750
//
48-
// In most of the cases, the upper bounds will coincide with the shape of the
49-
// tensor from which the tile is extracted.
50-
//
51-
// One example when upper bound does not match the shape is a reshape:
52-
// output = s32[2, 17] reshape (s32[34] input)
53-
//
54-
// If we propagate the `output` tile with the ts0 == 1,
55-
//
56-
// (tid0, tid1)[ts1] -> offsets [tid0, tid1 * ts1] sizes [1, ts1] strides [1, 1]
57-
// upper bounds [2, 17]
58-
//
59-
// to the `input` we will get a stricter upper bound
60-
//
61-
// (tid0, tid1)[ts1] -> offsets [17 * tid0 + tid1 * ts1] sizes [ts1] strides [1]
62-
// upper bounds [17 * tid0]
51+
// Symbols are:
52+
// - tile sizes of all dimensions, followed by
53+
// - runtime variables.
6354
struct DimTile {
6455
bool operator==(const DimTile& other) const;
6556

6657
mlir::AffineExpr offset;
6758
mlir::AffineExpr size;
6859
mlir::AffineExpr stride;
60+
// The masking condition of the upper bound can be written as:
61+
// dimension_index < upper_bounds(tile IDs)[tile sizes]{runtime variables}
62+
//
63+
// In most of the cases, the upper bounds will coincide with the shape of the
64+
// tensor from which the tile is extracted. One example when upper bound does
65+
// not match the shape is a reshape:
66+
//
67+
// output = s32[2, 17] reshape (s32[34] input)
68+
//
69+
// If we propagate the `output` SymbolicTile with the tile size of first
70+
// dimension equal to 1
71+
//
72+
// (tid0, tid1)[ts1] -> offsets [tid0, tid1 * ts1]
73+
// sizes [1, ts1]
74+
// strides [1, 1]
75+
// upper bounds [2, 17]
76+
//
77+
// then for to the `input` instruction we will get a stricter upper bound
78+
//
79+
// (tid0, tid1)[ts1] -> offsets [17 * tid0 + tid1 * ts1]
80+
// sizes [ts1]
81+
// strides [1]
82+
// upper bounds [17 * tid0]
6983
mlir::AffineExpr upper_bound;
7084
};
85+
7186
template <typename H>
7287
H AbslHashValue(H h, const DimTile& dim_tile) {
7388
llvm::hash_code dim_tile_hash = llvm::hash_combine(
7489
dim_tile.offset, dim_tile.size, dim_tile.stride, dim_tile.upper_bound);
7590
return H::combine(std::move(h), static_cast<size_t>(dim_tile_hash));
7691
}
7792

93+
// SymbolicTile is a collection of tilings for every dimension of output tensor
94+
// of an HLO instruction. SymbolicTiledHloInstruction associates a SymbolicTile
95+
// with an HLO instruction.
7896
class SymbolicTile {
7997
public:
8098
SymbolicTile(const TilingSpace& tiling_space,
@@ -120,8 +138,9 @@ H AbslHashValue(H h, const SymbolicTile& symbolic_tile) {
120138
return h;
121139
}
122140

123-
// Returns a DimTile that covers the entire dimension, i.e.
124-
// offset 0, size = next_power_of_2(dim_size), stride 1, upper_bound = dim_size.
141+
// Returns a DimTile that covers the entire dimension with a single power of 2
142+
// sized tile, i.e. offset 0, size = next_power_of_2(dim_size), stride 1,
143+
// upper_bound = dim_size.
125144
DimTile GetFullDimTile(int64_t dim_size, mlir::MLIRContext* ctx);
126145

127146
// Returns a DimTile that covers the entire dimension, i.e.

xla/service/gpu/model/experimental/tiling_space.cc

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ limitations under the License.
2727
#include "llvm/ADT/STLExtras.h"
2828
#include "llvm/ADT/SmallVector.h"
2929
#include "mlir/IR/MLIRContext.h"
30-
#include "xla/hlo/analysis/indexing_map.h"
30+
#include "xla/hlo/analysis/interval.h"
3131
#include "xla/hlo/ir/hlo_casting_utils.h"
3232
#include "xla/hlo/ir/hlo_instruction.h"
3333
#include "xla/hlo/ir/hlo_instructions.h"
@@ -65,6 +65,23 @@ void TilingSpace::AppendRTVar(const HloInstruction* hlo, int64_t operand_id,
6565
hlo_to_rt_var_[std::make_pair(hlo, operand_id)] = &rt_vars_.back();
6666
}
6767

68+
void TilingSpace::ProcessInstruction(const HloInstruction& hlo) {
69+
switch (hlo.opcode()) {
70+
case HloOpcode::kDot:
71+
ProcessDot(hlo);
72+
break;
73+
case HloOpcode::kReduce:
74+
ProcessReduce(hlo);
75+
break;
76+
case HloOpcode::kDynamicSlice:
77+
ProcessDynamicSlice(hlo);
78+
break;
79+
default:
80+
// TODO(goncharov): should have a explicit list of supported instructions?
81+
break;
82+
}
83+
}
84+
6885
// Add dot contraction dimensions in the order of contracting dimensions.
6986
void TilingSpace::ProcessDot(const HloInstruction& hlo) {
7087
auto dot = Cast<HloDotInstruction>(&hlo);
@@ -143,15 +160,17 @@ const TilingSpace::DimensionInfo& TilingSpace::GetDimensionInfo(
143160
const HloInstruction& hlo, int64_t dim_position) const {
144161
auto it = hlo_to_dimension_.find(std::make_pair(&hlo, dim_position));
145162
CHECK(it != hlo_to_dimension_.end())
146-
<< "Dimension not found: " << hlo.ToString() << " " << dim_position;
163+
<< "Dimension not found for " << hlo.ToString() << " dimension "
164+
<< dim_position;
147165
return *it->second;
148166
}
149167

150168
const TilingSpace::RTVarInfo& TilingSpace::GetRTVarInfo(
151169
const HloInstruction& hlo, int64_t operand_id) const {
152170
auto it = hlo_to_rt_var_.find(std::make_pair(&hlo, operand_id));
153171
CHECK(it != hlo_to_rt_var_.end())
154-
<< "Runtime variable not found: " << hlo.ToString();
172+
<< "Runtime variable not found for " << hlo.ToString() << " operand "
173+
<< operand_id;
155174
return *it->second;
156175
}
157176

@@ -163,8 +182,10 @@ std::unique_ptr<TilingSpace> TilingSpace::Create(const HloFusionAdaptor& fusion,
163182
for (const HloInstructionAdaptor& root : roots) {
164183
const Shape& root_shape = root.shape();
165184
if (!root.shape().IsArray() && root.opcode() != HloOpcode::kReduce) {
166-
LOG(FATAL) << "Unsupported root shape: " << root_shape.ToString();
185+
LOG(FATAL) << "Unsupported root shape " << root_shape.ToString()
186+
<< " for root " << root.instruction().ToString();
167187
}
188+
// TODO(goncharov): why do we only care about the first shape of a tuple?
168189
absl::Span<const int64_t> dims =
169190
GetFirstShape(&root.instruction()).dimensions();
170191
llvm::SmallVector<DimTile> dim_tiles;
@@ -186,19 +207,7 @@ std::unique_ptr<TilingSpace> TilingSpace::Create(const HloFusionAdaptor& fusion,
186207
// Iterator in reversed post-order (use-before-def).
187208
auto post_order = fusion.MakeInstructionPostOrder();
188209
for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
189-
switch (it->instruction().opcode()) {
190-
case HloOpcode::kDot:
191-
tiling_space->ProcessDot(it->instruction());
192-
break;
193-
case HloOpcode::kReduce:
194-
tiling_space->ProcessReduce(it->instruction());
195-
break;
196-
case HloOpcode::kDynamicSlice:
197-
tiling_space->ProcessDynamicSlice(it->instruction());
198-
break;
199-
default:
200-
break;
201-
}
210+
tiling_space->ProcessInstruction(it->instruction());
202211
}
203212
return tiling_space;
204213
}

xla/service/gpu/model/experimental/tiling_space.h

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,26 @@ limitations under the License.
2626
#include "llvm/ADT/SmallVector.h"
2727
#include "mlir/IR/MLIRContext.h"
2828
#include "xla/codegen/tiling/constraint_expression.h"
29-
#include "xla/hlo/analysis/indexing_map.h"
29+
#include "xla/hlo/analysis/interval.h"
3030
#include "xla/hlo/ir/hlo_instruction.h"
3131
#include "xla/hlo/utils/hlo_traversal.h"
3232
#include "xla/service/gpu/model/experimental/symbolic_tile.h"
3333
#include "xla/shape.h"
3434

3535
namespace xla::gpu::experimental {
3636

37-
// TilingSpace contains information about all parallel and sequential dimensions
38-
// and runtime variables in a fusion.
39-
// The parallel dimensions correspond to the dimensions of the outputs of the
40-
// fusion.
41-
// The sequential dimensions correspond to the contraction/reduction dimensions
42-
// of the dots/reduces in the fusion.
43-
// The runtime variables correspond to the offsets of the dynamic slices in the
44-
// fusion.
37+
// TilingSpace holds information about all tiling parameters of a fusion.
38+
//
39+
// It defines symbolic tiles for the fusion roots as symbolic expressions and
40+
// constraints of possible tile "variables":
41+
// * parallel dimensions - output dimensions of the fusion;
42+
// * sequential dimensions - contraction/reduction dimensions of operations in
43+
// the fusion;
44+
// * runtime variables - for example, offsets of the dynamic slices.
45+
//
46+
// This information allows us later to explore the space of all possible tilings
47+
// and assign concrete tilings for every instruction of the fusion with
48+
// SymbolicTilePropagation.
4549
class TilingSpace {
4650
public:
4751
TilingSpace() : constraints_(ConstraintExpression::GetAlwaysSatisfied()) {}
@@ -51,29 +55,46 @@ class TilingSpace {
5155

5256
enum class DimensionSemantics { kParallel, kSequential };
5357
struct DimensionInfo {
54-
// Unique ID for the dimension.
58+
// Unique ID for the dimension within the tiling space.
5559
ID id;
5660
// Size of the dimension.
5761
int64_t dimension_size;
5862
// Type of the dimension.
5963
DimensionSemantics type;
60-
// HLO instruction that defines the dimension.
64+
// HLO instruction that defines (introduces) the dimension. For example
65+
// fusion root instruction defines the parallel dimensions. Dot/reduce
66+
// defines the sequential (contraction) dimensions.
6167
const HloInstruction* hlo;
62-
// Index into the ordered list of dimensions of the HLO instruction.
63-
// All dimensions in the HLO instruction are described as
68+
// Index into the ordered list of dimensions of the HLO instruction `hlo`
69+
// that defines the dimension.
70+
// All dimensions in the HLO instruction are ordered as
6471
// [all parallel dims of the output, all reduction/contraction dims].
6572
//
66-
// Example:
67-
// [output_dims] dot(lhs, rhs, lhs_contracting_dims, rhs_contracting_dims)
68-
// The dimensions are ordered as [output_dims, LHS[lhs_contracting_dims]].
73+
// Example, for `[a,b,c] = dot(lhs, rhs, lhs_contracting_dims={d,e}, ...)`.
74+
// The ordered list of dimensions is [a,b,c,d,e].
6975
int64_t dim_position;
7076
};
7177

78+
// Information about a runtime variable.
79+
// For example:
80+
//
81+
// off = s32[] parameter(0)
82+
// ds = dynamic-slice(tensor, off), ...
83+
//
84+
// `off = s32[] parameter(0)` instruction (`hlo`) defines the runtime
85+
// variable.
86+
// User's (dynamic-slice) semantics sets the `bounds` of possible values.
87+
//
88+
// If the same hlo is used as runtime variable multiple times, there will be
89+
// multiple entries in the `rt_vars_` with different IDs.
90+
//
91+
// RTVarInfo are accessed by (user_hlo, operand_id), in this case it is
92+
// (dynamic-slice, 1).
7293
struct RTVarInfo {
73-
// Unique ID for the runtime variable.
94+
// Unique ID for the runtime variable within the tiling space.
7495
ID id;
75-
// Feasible bounds of the runtime variable. The values outside of the bounds
76-
// will be clamped.
96+
// Feasible bounds of the runtime variable.
97+
// The values outside of the bounds will be clamped.
7798
Interval bounds;
7899
// HLO instruction that defines the runtime variable.
79100
const HloInstruction* hlo;
@@ -90,9 +111,15 @@ class TilingSpace {
90111
sink.Append(space.ToString());
91112
}
92113

114+
// Returns the dimension info for the given `hlo` and `dim_position`.
115+
// `dim_position` is the index into the ordered list of dimensions of the HLO
116+
// instruction `hlo` that defines the dimension. The dimension info must
117+
// exist.
93118
const DimensionInfo& GetDimensionInfo(const HloInstruction& hlo,
94119
int64_t dim_position) const;
95120

121+
// Returns the runtime variable info for `hlo` that uses it and its
122+
// `operand_id`. This runtime variable info must exist.
96123
const RTVarInfo& GetRTVarInfo(const HloInstruction& hlo,
97124
int64_t operand_id) const;
98125

@@ -115,6 +142,7 @@ class TilingSpace {
115142
void ProcessDot(const HloInstruction& hlo);
116143
void ProcessReduce(const HloInstruction& hlo);
117144
void ProcessDynamicSlice(const HloInstruction& hlo);
145+
void ProcessInstruction(const HloInstruction& hlo);
118146

119147
// Maps from (hlo, dim_position) to the dimension info.
120148
absl::flat_hash_map<std::pair<const HloInstruction*, int64_t>,
@@ -130,7 +158,9 @@ class TilingSpace {
130158
// The deque is used to guarantee the pointer stability.
131159
std::deque<RTVarInfo> rt_vars_;
132160

133-
// Root instruction of the fusion.
161+
// Symbolic tiles for the fusion roots.
162+
// For tuple roots, there will be one tile per tuple element. Otherwise,
163+
// there will be only one symbolic tile.
134164
llvm::SmallVector<SymbolicTile, 2> tiled_roots_;
135165

136166
// Constraint expression for the tiling space.

0 commit comments

Comments
 (0)